from typing import Union, Tuple, List
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
import time
from copy import deepcopy
from default_resampling import determine_do_sep_z_and_axis
import psutil
import nibabel as nib
import os
from pathlib import Path

# Spacing ratio above which a volume is treated as anisotropic and is
# resampled in two passes (in-plane first, then along the odd axis).
ANISO_THRESHOLD = 3

# F.interpolate modes that accept align_corners; every other mode
# ('nearest', 'nearest-exact', 'area') must receive align_corners=None.
_LINEAR_MODES = ('linear', 'bilinear', 'trilinear')


def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]:
    """Compute the resampled spatial shape from the spacing ratio.

    Args:
        current_shape: spatial shape (z, y, x).
        current_spacing: voxel spacing of the input (z, y, x).
        target_spacing: desired voxel spacing (z, y, x).

    Returns:
        ``round(shape * current_spacing / target_spacing)`` per axis, as ints.
    """
    # NOTE(review): the original file defined this function twice (the later
    # copy shadowed this one and returned a tuple). A single list-returning
    # definition is kept; all callers only iterate or index the result.
    return [int(round(s * (cs / ts)))
            for s, cs, ts in zip(current_shape, current_spacing, target_spacing)]


def optimized_3d_resample(
        data: Union[torch.Tensor, np.ndarray],
        current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
        target_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
        is_seg: bool = False,
        device: torch.device = torch.device('cpu'),
        num_threads: int = 8,
        chunk_size: int = 64,
        force_separate_z: Union[bool, None] = None,
        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
        preserve_range: bool = True
) -> Union[torch.Tensor, np.ndarray]:
    """
    Optimized 3D image resampling with adaptive interpolation and chunked processing.

    Args:
        data: Input 3D volume [C, D, H, W] or [D, H, W]
        current_spacing: Current voxel spacing (z, y, x)
        target_spacing: Target voxel spacing (z, y, x)
        is_seg: Whether the input is a segmentation mask
        device: Torch device for computation
        num_threads: Number of threads for CPU operations
        chunk_size: Size of chunks for large volume processing
        force_separate_z: Force separate z resampling
        separate_z_anisotropy_threshold: Threshold for anisotropic resampling
        preserve_range: Preserve original value range for non-segmentation data

    Returns:
        Resampled 3D volume (numpy array if the input was numpy, else a tensor)
    """
    print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
    input_was_numpy = isinstance(data, np.ndarray)
    if input_was_numpy:
        data = torch.from_numpy(data).to(device)
    else:
        data = data.to(device)
    print(f"Input converted to tensor on {device}, shape: {data.shape}")

    if data.ndim == 3:
        data = data.unsqueeze(0)
    assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"

    new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
    print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")
    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        print("No resampling needed, shapes identical.")
        return data.cpu().numpy() if input_was_numpy else data

    mode = 'nearest' if is_seg else 'trilinear'
    aniso_axis_mode = 'nearest-exact' if is_seg else 'linear'
    print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")

    do_separate_z, axis = determine_do_sep_z_and_axis(
        force_separate_z, current_spacing, target_spacing, separate_z_anisotropy_threshold)
    print(f"Do separate Z: {do_separate_z}, Axis: {axis}")

    if preserve_range and not is_seg:
        v_min, v_max = data.min(), data.max()
        print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")

    torch.set_num_threads(num_threads)
    print(f"Set number of threads to {num_threads}")
    start_time = time.time()

    if do_separate_z:
        # BUGFIX: determine_do_sep_z_and_axis may return the axis as a
        # one-element list (resample_torch_fornnunet already handled this).
        if isinstance(axis, (list, tuple)):
            axis = axis[0]
        letters = "xyz"
        axis_letter = letters[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [letters[i] for i in others_int]
        print(f"Separate Z resampling along axis {axis_letter}, others: {others}")
        tmp_new_shape = [new_shape[i] for i in others_int]
        print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}")
        # BUGFIX: capture the anisotropic-axis extent BEFORE rearranging.
        # The original used data.shape[1] after the first pass, which is the
        # resampled in-plane size, not the axis length, so the rearrange-back
        # produced a wrong (or failing) reshape.
        axis_len = data.shape[axis + 1]
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
        print(f"Rearranged data shape: {data.shape}")
        data = _chunked_resample(data, tmp_new_shape, mode, chunk_size, device, is_seg)
        print(f"After first pass resampling, shape: {data.shape}")
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{axis_letter: axis_len,
                            others[0]: tmp_new_shape[0],
                            others[1]: tmp_new_shape[1]})
        print(f"Rearranged back to shape: {data.shape}")
        data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg)
        print(f"After second pass resampling, final shape: {data.shape}")
    else:
        print(f"Direct resampling to shape: {new_shape}")
        data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg)
        print(f"After direct resampling, final shape: {data.shape}")

    resample_time = time.time() - start_time
    print(f"Resampling completed in {resample_time:.3f}s")

    if is_seg:
        unique_values = torch.unique(data)
        result_dtype = torch.int8 if unique_values.max() < 127 else torch.int16
        data = data.round().to(result_dtype)
        print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}")
    if preserve_range and not is_seg:
        data = torch.clamp(data, v_min, v_max)
        print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}")

    output = data.cpu().numpy() if input_was_numpy else data
    print(f"Output shape: {output.shape}, type: {type(output)}")
    return output


def _interpolate_once(block: torch.Tensor,
                      size: Union[Tuple[int, ...], List[int]],
                      mode: str,
                      is_seg: bool) -> torch.Tensor:
    """Run one F.interpolate call with segmentation-aware dtype handling.

    Segmentation blocks are cast to float for interpolation (F.interpolate
    does not accept integer dtypes for these modes) and rounded back to the
    input dtype afterwards.
    """
    # BUGFIX: align_corners is only legal for linear-family modes; the
    # original passed False for 'nearest-exact', which F.interpolate rejects.
    align = False if mode in _LINEAR_MODES else None
    inp = block.float() if is_seg else block
    with torch.cuda.amp.autocast(enabled=not is_seg):
        out = F.interpolate(inp.unsqueeze(0), size=tuple(size), mode=mode,
                            align_corners=align).squeeze(0)
    if is_seg:
        out = out.round().to(block.dtype)
    return out


def _chunked_resample(
        volume: torch.Tensor,
        target_shape: Union[Tuple[int, ...], List[int]],
        mode: str,
        chunk_size: int,
        device: torch.device,
        is_seg: bool
) -> torch.Tensor:
    """Chunked resampling for large volumes with adaptive chunk sizing.

    Accepts either a 3-D spatial volume (C, D, H, W) with a 3-element target
    shape, or a 2-D slice stack (N, H, W) with a 2-element target shape (the
    first stage of separate-z resampling).
    """
    print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
    spatial_rank = volume.ndim - 1
    # BUGFIX: normalise the mode to the spatial rank. The original passed
    # 'trilinear' for 2-D slice stacks and 'linear' for 4-D (C,D,H,W) input,
    # both of which F.interpolate rejects.
    if mode in _LINEAR_MODES:
        mode = 'bilinear' if spatial_rank == 2 else 'trilinear'

    if spatial_rank == 2:
        # BUGFIX: the original unconditionally unpacked C, D, H, W and
        # crashed on the 3-D tensors produced by the separate-z first pass.
        # 2-D stacks are interpolated directly (no chunking needed).
        return _interpolate_once(volume, target_shape, mode, is_seg)

    C, D, H, W = volume.shape
    tD, tH, tW = target_shape

    # Adaptive chunk size based on available memory (all quantities in bytes).
    if device.type == 'cpu':
        available_memory = psutil.virtual_memory().available
    else:
        total_memory = torch.cuda.get_device_properties(device).total_memory
        available_memory = total_memory - torch.cuda.memory_allocated(device)
    # BUGFIX: the original computed element_size()*nelement()/numel() (which
    # is just element_size()) and mixed megabytes with bytes in the formula.
    bytes_per_voxel = volume.element_size()
    chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
    adaptive_chunk_size = max(
        32,
        min(chunk_size,
            int((available_memory * chunk_mem_ratio / bytes_per_voxel / C) ** (1 / 3)))
    )

    # Small volumes: one direct interpolation call, no chunk bookkeeping.
    if D * H * W <= 128 ** 3:
        return _interpolate_once(volume, target_shape, mode, is_seg)

    result = torch.zeros((C, tD, tH, tW), device=device, dtype=volume.dtype)
    out_chunk_size = max(1, int(adaptive_chunk_size * min(tD / D, tH / H, tW / W)))
    for c in range(C):
        for z in range(0, tD, out_chunk_size):
            z_end = min(z + out_chunk_size, tD)
            for y in range(0, tH, out_chunk_size):
                y_end = min(y + out_chunk_size, tH)
                for x in range(0, tW, out_chunk_size):
                    x_end = min(x + out_chunk_size, tW)
                    # Input window with a small halo so border voxels of the
                    # chunk have neighbours to interpolate from.
                    # NOTE(review): per-chunk interpolation only approximates
                    # whole-volume interpolation near chunk borders; this
                    # matches the original implementation's trade-off.
                    in_z = max(0, int(z * D / tD) - 1)
                    in_z_end = min(D, int(z_end * D / tD) + 2)
                    in_y = max(0, int(y * H / tH) - 1)
                    in_y_end = min(H, int(y_end * H / tH) + 2)
                    in_x = max(0, int(x * W / tW) - 1)
                    in_x_end = min(W, int(x_end * W / tW) + 2)

                    chunk = volume[c:c + 1, in_z:in_z_end, in_y:in_y_end, in_x:in_x_end]
                    chunk_target = (z_end - z, y_end - y, x_end - x)
                    resampled_chunk = _interpolate_once(chunk, chunk_target, mode, is_seg)
                    result[c, z:z_end, y:y_end, x:x_end] = resampled_chunk
                    del chunk, resampled_chunk
    if device.type == 'cuda':
        torch.cuda.empty_cache()
    return result


def resample_torch_simple(
        data: Union[torch.Tensor, np.ndarray],
        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
        is_seg: bool = False,
        num_threads: int = 4,
        device: torch.device = torch.device('cpu'),
        memefficient_seg_resampling: bool = False,
        mode: str = 'linear'
) -> Union[torch.Tensor, np.ndarray]:
    """Resample (c, *spatial) data to new_shape with F.interpolate.

    Segmentation masks are resampled one label at a time so interpolation
    never mixes label ids.

    Args:
        data: (c, x, y, z) volume or (c, x, y) slice stack.
        new_shape: target spatial shape (data.ndim - 1 entries).
        is_seg: treat data as a label map.
        num_threads: torch thread count used for the call (restored after).
        device: computation device.
        memefficient_seg_resampling: per-label thresholding variant that
            avoids the (labels, c, *shape) float16 scratch buffer.
        mode: 'linear' (mapped to bi-/trilinear based on rank) or any other
            F.interpolate mode such as 'nearest-exact'.

    Returns:
        Resampled data; numpy in/numpy out, tensors returned on their
        original device.
    """
    if mode == 'linear':
        torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
    else:
        torch_mode = mode
    if isinstance(new_shape, np.ndarray):
        new_shape = [int(i) for i in new_shape]
    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        return data  # nothing to do

    n_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    new_shape = tuple(new_shape)
    with torch.no_grad():
        input_was_numpy = isinstance(data, np.ndarray)
        if input_was_numpy:
            data = torch.from_numpy(data).to(device)
        else:
            orig_device = data.device  # torch.device is immutable; no deepcopy needed
            data = data.to(device)
        if is_seg:
            unique_values = torch.unique(data)
            result_dtype = torch.int8 if unique_values.max() < 127 else torch.int16
            result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
            if not memefficient_seg_resampling:
                result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape),
                                         dtype=torch.float16, device=device)
                # Scale the one-hot maps up so float16 keeps enough precision.
                scale_factor = 1000
                done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
                for i, u in enumerate(unique_values):
                    result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor,
                                                  new_shape, mode=torch_mode, antialias=False)[0]
                    mask = result_tmp[i] > (0.7 * scale_factor)
                    result[mask] = u.item()
                    done_mask |= mask
                # Voxels no label claimed confidently get the argmax label.
                if not torch.all(done_mask):
                    result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
            else:
                for i, u in enumerate(unique_values):
                    if u == 0:
                        continue  # background stays 0 in the pre-zeroed result
                    result[F.interpolate((data[None] == u).float(), new_shape,
                                         mode=torch_mode, antialias=False)[0] > 0.5] = u
        else:
            result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]
        if input_was_numpy:
            result = result.cpu().numpy()
        else:
            result = result.to(orig_device)
    torch.set_num_threads(n_threads)
    return result


def resample_torch_fornnunet(
        data: Union[torch.Tensor, np.ndarray],
        new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
        current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
        new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
        is_seg: bool = False,
        num_threads: int = 4,
        device: torch.device = torch.device('cpu'),
        memefficient_seg_resampling: bool = False,
        force_separate_z: Union[bool, None] = None,
        separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
        mode: str = 'linear',
        aniso_axis_mode: str = 'nearest-exact'
) -> Union[torch.Tensor, np.ndarray]:
    """nnU-Net-style resampling: optionally resample the anisotropic axis
    separately (in-plane pass with `mode`, then the odd axis with
    `aniso_axis_mode`); otherwise delegate to resample_torch_simple.
    """
    assert data.ndim == 4, "data must be c, x, y, z"
    new_shape = [int(i) for i in new_shape]
    orig_shape = data.shape
    do_separate_z, axis = determine_do_sep_z_and_axis(
        force_separate_z, current_spacing, new_spacing, separate_z_anisotropy_threshold)
    if do_separate_z:
        was_numpy = isinstance(data, np.ndarray)
        if was_numpy:
            data = torch.from_numpy(data)
        if isinstance(axis, list):
            axis = axis[0]
        letters = "xyz"
        axis_letter = letters[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [letters[i] for i in others_int]
        # Fold the anisotropic axis into the batch dim and resample in-plane.
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
        tmp_new_shape = [new_shape[i] for i in others_int]
        data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads,
                                     device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling,
                                     mode=mode)
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{axis_letter: orig_shape[axis + 1],
                            others[0]: tmp_new_shape[0],
                            others[1]: tmp_new_shape[1]})
        # Second pass: only the anisotropic axis still differs from new_shape.
        data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads,
                                     device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling,
                                     mode=aniso_axis_mode)
        if was_numpy:
            data = data.numpy()
        return data
    else:
        return resample_torch_simple(data, new_shape, is_seg, num_threads, device,
                                     memefficient_seg_resampling)


def dice_score(pred: np.ndarray, true: np.ndarray) -> float:
    """Compute the Dice score between two segmentation masks.

    NOTE(review): the product `pred * true` is only a meaningful overlap for
    binary (0/1) masks; for multi-label int masks this over-counts — confirm
    callers binarise if that matters.
    """
    pred = pred.flatten()
    true = true.flatten()
    intersection = np.sum(pred * true)
    return (2. * intersection) / (np.sum(pred) + np.sum(true) + 1e-8)


def save_nii(array, spacing, output_path, is_seg=False):
    """Save a numpy array (or torch tensor) as a NIfTI file.

    Args:
        array: (C, Z, Y, X) or (Z, Y, X) data, numpy or torch.
        spacing: voxel spacing (z, y, x).
        output_path: destination .nii/.nii.gz path.
        is_seg: if True, store as int32 (segmentation); else float32.
    """
    if isinstance(array, torch.Tensor):
        array = array.cpu().numpy()
    # Convert data type for NIfTI compatibility.
    if is_seg:
        array = array.astype(np.int32)
    else:
        array = array.astype(np.float32)
    # Reorder axes to NIfTI's x-fastest convention.
    if array.ndim == 4:
        # BUGFIX: the original used transpose(2, 3, 1, 0), which yields
        # (Y, X, Z, C) despite the stated (X, Y, Z, C) intent.
        array = array.transpose(3, 2, 1, 0)  # (C, Z, Y, X) -> (X, Y, Z, C)
    else:
        # BUGFIX: the original called transpose(2, 3, 1) on a 3-D array,
        # which raises (axis 3 does not exist).
        array = array.transpose(2, 1, 0)  # (Z, Y, X) -> (X, Y, Z)
    # BUGFIX: spacing arrives as (z, y, x) but the affine diagonal must match
    # the saved (x, y, z) axis order.
    affine = np.diag(list(spacing)[::-1] + [1.0])
    nii_img = nib.Nifti1Image(array, affine=affine)
    nib.save(nii_img, output_path)
    print(f"Saved: {output_path}")


def main():
    """Benchmark optimized_3d_resample against resample_torch_fornnunet,
    using resample_torch_simple as the reference, for image and segmentation
    data at several target spacings, saving all outputs as NIfTI."""
    torch.set_num_threads(4)
    # BUGFIX: the original hard-coded cuda and crashed on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nRunning tests on device: {device}")

    # Define paths
    npz_file_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz"
    gt_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz"
    output_dir = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample"
    os.makedirs(output_dir, exist_ok=True)

    # Load input data
    data = np.load(npz_file_path, allow_pickle=True)
    img_array = data['imgs']        # (C, Z, Y, X) or (Z, Y, X)
    img_spacing = data['spacing']   # (z, y, x)
    img_spacing = [1.0, 1.0, 1.0]   # Override as per provided code
    gt_data = np.load(gt_path, allow_pickle=True)
    gt_array = gt_data['gts']       # (C, Z, Y, X) or (Z, Y, X)

    # Convert to PyTorch-compatible dtypes.
    img_array = img_array.astype(np.float32)
    gt_array = gt_array.astype(np.int32)

    # Ensure both arrays carry a channel dimension: (1, Z, Y, X).
    if img_array.ndim == 3:
        img_array = img_array[np.newaxis, ...]
    if gt_array.ndim == 3:
        gt_array = gt_array[np.newaxis, ...]

    target_spacings = [
        (1.2, 1.2, 1.2),
        (1.5, 1.5, 1.5),
        (2.0, 2.0, 2.0),
    ]

    original_shape = img_array.shape[1:]  # (Z, Y, X)
    current_spacing = img_spacing
    print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}")

    for target_spacing in target_spacings:
        print(f"\n=== Resampling to Target Spacing: {target_spacing} ===")
        new_shape = compute_new_shape(original_shape, current_spacing, target_spacing)
        print(f"Computed target shape: {new_shape}")
        spacing_tag = f"{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}"

        # === Image Resampling ===
        print("\nResampling image...")
        print("Computing ground truth with resample_torch_simple...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()  # make timings meaningful on GPU
        gt_img = resample_torch_simple(
            img_array, new_shape=new_shape, is_seg=False, num_threads=4, device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_time = time.time() - start_time
        output_path = os.path.join(output_dir, f"img_gt_spacing_{spacing_tag}.nii.gz")
        print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s")
        save_nii(gt_img, target_spacing, output_path, is_seg=False)

        print("Running optimized_3d_resample...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                      else torch.cuda.memory_allocated(device) / 1024**2)
        resampled_img_opt = optimized_3d_resample(
            img_array, current_spacing, target_spacing,
            is_seg=False, device=device, num_threads=4, chunk_size=64
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        opt_time = time.time() - start_time
        mem_after = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                     else torch.cuda.memory_allocated(device) / 1024**2)
        opt_mae = np.mean(np.abs(resampled_img_opt - gt_img))
        output_path = os.path.join(output_dir, f"img_opt_spacing_{spacing_tag}.nii.gz")
        print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {opt_mae:.6f}")
        save_nii(resampled_img_opt, target_spacing, output_path, is_seg=False)

        print("Running resample_torch_fornnunet...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                      else torch.cuda.memory_allocated(device) / 1024**2)
        resampled_img_orig = resample_torch_fornnunet(
            img_array, new_shape, current_spacing, target_spacing,
            is_seg=False, num_threads=4, device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        orig_time = time.time() - start_time
        mem_after = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                     else torch.cuda.memory_allocated(device) / 1024**2)
        orig_mae = np.mean(np.abs(resampled_img_orig - gt_img))
        output_path = os.path.join(output_dir, f"img_orig_spacing_{spacing_tag}.nii.gz")
        print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {orig_mae:.6f}")
        save_nii(resampled_img_orig, target_spacing, output_path, is_seg=False)

        # === Segmentation Mask Resampling ===
        print("\nResampling segmentation mask...")
        print("Computing ground truth with resample_torch_simple...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_seg = resample_torch_simple(
            gt_array, new_shape=new_shape, is_seg=True, num_threads=4, device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_seg_time = time.time() - start_time
        output_path = os.path.join(output_dir, f"seg_gt_spacing_{spacing_tag}.nii.gz")
        print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s")
        save_nii(gt_seg, target_spacing, output_path, is_seg=True)

        print("Running optimized_3d_resample for segmentation...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                      else torch.cuda.memory_allocated(device) / 1024**2)
        resampled_seg_opt = optimized_3d_resample(
            gt_array, current_spacing, target_spacing,
            is_seg=True, device=device, num_threads=4, chunk_size=64
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        opt_seg_time = time.time() - start_time
        mem_after = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                     else torch.cuda.memory_allocated(device) / 1024**2)
        opt_dice = dice_score(resampled_seg_opt, gt_seg)
        output_path = os.path.join(output_dir, f"seg_opt_spacing_{spacing_tag}.nii.gz")
        print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {opt_dice:.6f}")
        save_nii(resampled_seg_opt, target_spacing, output_path, is_seg=True)

        print("Running resample_torch_fornnunet for segmentation...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                      else torch.cuda.memory_allocated(device) / 1024**2)
        resampled_seg_orig = resample_torch_fornnunet(
            gt_array, new_shape, current_spacing, target_spacing,
            is_seg=True, num_threads=4, device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        orig_seg_time = time.time() - start_time
        mem_after = (psutil.virtual_memory().used / 1024**2 if device.type == 'cpu'
                     else torch.cuda.memory_allocated(device) / 1024**2)
        orig_dice = dice_score(resampled_seg_orig, gt_seg)
        output_path = os.path.join(output_dir, f"seg_orig_spacing_{spacing_tag}.nii.gz")
        print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {orig_dice:.6f}")
        save_nii(resampled_seg_orig, target_spacing, output_path, is_seg=True)

        # Summary
        print(f"\n=== Summary for Target Spacing: {target_spacing} ===")
        print("Image Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}")
        print(f"Original - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}")
        print(f"Time Improvement: {(orig_time - opt_time) / orig_time * 100:.2f}%")
        print(f"MAE Improvement: {(orig_mae - opt_mae) / orig_mae * 100:.2f}%")
        print("Segmentation Mask Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}")
        print(f"Original - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}")
        print(f"Time Improvement: {(orig_seg_time - opt_seg_time) / orig_seg_time * 100:.2f}%")
        print(f"Dice Improvement: {(opt_dice - orig_dice) / orig_dice * 100:.2f}%")


if __name__ == '__main__':
    main()