|
|
from typing import Union, Tuple, List
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
from einops import rearrange
|
|
|
import time
|
|
|
from copy import deepcopy
|
|
|
from default_resampling import determine_do_sep_z_and_axis
|
|
|
import psutil
|
|
|
import nibabel as nib
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
|
|
|
# Spacing-ratio threshold above which a volume is treated as anisotropic and
# the low-resolution axis is resampled separately (see determine_do_sep_z_and_axis).
ANISO_THRESHOLD = 3
|
|
|
|
|
|
def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]:
    """Scale each axis of ``current_shape`` by the ratio current/target spacing.

    Returns the rounded per-axis sizes as a list of ints.
    """
    shape_arr = np.array(current_shape)
    cur_arr = np.array(current_spacing)
    tgt_arr = np.array(target_spacing)
    new_shape = []
    for size, cur, tgt in zip(shape_arr, cur_arr, tgt_arr):
        new_shape.append(int(round(size * (cur / tgt))))
    return new_shape
|
|
|
|
|
|
def optimized_3d_resample(
    data: Union[torch.Tensor, np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    target_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    device: torch.device = torch.device('cpu'),
    num_threads: int = 8,
    chunk_size: int = 64,
    force_separate_z: Union[bool, None] = None,
    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
    preserve_range: bool = True
) -> Union[torch.Tensor, np.ndarray]:
    """
    Optimized 3D image resampling with adaptive interpolation and chunked processing.

    Args:
        data: Input 3D volume [C, D, H, W] or [D, H, W]
        current_spacing: Current voxel spacing (z, y, x)
        target_spacing: Target voxel spacing (z, y, x)
        is_seg: Whether the input is a segmentation mask
        device: Torch device for computation
        num_threads: Number of threads for CPU operations
        chunk_size: Size of chunks for large volume processing
        force_separate_z: Force separate z resampling
        separate_z_anisotropy_threshold: Threshold for anisotropic resampling
        preserve_range: Preserve original value range for non-segmentation data

    Returns:
        Resampled 3D volume (numpy array if the input was numpy, tensor otherwise)
    """
    print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
    # Remember the input container so the output can be returned in kind.
    input_was_numpy = isinstance(data, np.ndarray)
    if input_was_numpy:
        data = torch.from_numpy(data).to(device)
    else:
        data = data.to(device)
    print(f"Input converted to tensor on {device}, shape: {data.shape}")

    # Normalize to channel-first 4D: [C, D, H, W].
    if data.ndim == 3:
        data = data.unsqueeze(0)
    assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"

    # NOTE(review): this resolves to the second, tuple-returning
    # compute_new_shape defined later in this module — confirm intended.
    new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
    print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")

    # Fast path: nothing to do when the target shape equals the input shape.
    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        print("No resampling needed, shapes identical.")
        return data.cpu().numpy() if input_was_numpy else data

    # Segmentations must not blend label values -> nearest neighbour.
    mode = 'nearest' if is_seg else 'trilinear'
    aniso_axis_mode = 'nearest-exact' if is_seg else 'linear'
    print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing,
                                                      target_spacing, separate_z_anisotropy_threshold)
    print(f"Do separate Z: {do_separate_z}, Axis: {axis}")

    # Record the intensity range so interpolation over/undershoot can be clamped later.
    if preserve_range and not is_seg:
        v_min, v_max = data.min(), data.max()
        print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")

    torch.set_num_threads(num_threads)
    print(f"Set number of threads to {num_threads}")

    start_time = time.time()
    if do_separate_z:
        # Anisotropic case: resample the two in-plane axes first, then the
        # low-resolution axis in a second pass with its own mode.
        tmp = "xyz"
        axis_letter = tmp[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [tmp[i] for i in others_int]
        print(f"Separate Z resampling along axis {axis_letter}, others: {others}")

        tmp_new_shape = [new_shape[i] for i in others_int]
        print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}")
        # Fold the anisotropic axis into the channel dim so each slice is
        # resampled independently in-plane.
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
        print(f"Rearranged data shape: {data.shape}")
        # NOTE(review): after the rearrange the tensor is 3D and tmp_new_shape
        # has two entries, but _chunked_resample unpacks `C, D, H, W =
        # volume.shape` and `tD, tH, tW = target_shape` — this path looks like
        # it would raise; confirm it has actually been exercised.
        data = _chunked_resample(data, tmp_new_shape, mode, chunk_size, device, is_seg)
        print(f"After first pass resampling, shape: {data.shape}")

        # NOTE(review): `axis_letter: data.shape[1]` feeds the first *spatial*
        # dim of the resampled tensor as the folded-axis length — verify this
        # matches the intent (cf. resample_torch_fornnunet, which uses
        # orig_shape[axis + 1]).
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{axis_letter: data.shape[1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
        print(f"Rearranged back to shape: {data.shape}")
        # Second pass: interpolate along the remaining (anisotropic) axis.
        # NOTE(review): 'linear' is not a valid F.interpolate mode for the 5D
        # input _chunked_resample builds ('trilinear' would be) — verify.
        data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg)
        print(f"After second pass resampling, final shape: {data.shape}")
    else:
        print(f"Direct resampling to shape: {new_shape}")
        data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg)
        print(f"After direct resampling, final shape: {data.shape}")
    resample_time = time.time() - start_time
    print(f"Resampling completed in {resample_time:.3f}s")

    if is_seg:
        # Pick the smallest integer dtype that can hold every label value.
        unique_values = torch.unique(data)
        result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
        data = data.round().to(result_dtype)
        print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}")

    # Clamp interpolation over/undershoot back into the original range.
    if preserve_range and not is_seg:
        data = torch.clamp(data, v_min, v_max)
        print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}")

    output = data.cpu().numpy() if input_was_numpy else data
    print(f"Output shape: {output.shape}, type: {type(output)}")
    return output
|
|
|
|
|
|
def _chunked_resample(
    volume: torch.Tensor,
    target_shape: Tuple[int, ...],
    mode: str,
    chunk_size: int,
    device: torch.device,
    is_seg: bool
) -> torch.Tensor:
    """Chunked resampling for large volumes with adaptive chunk sizing.

    Args:
        volume: 4D tensor [C, D, H, W] to resample (the unpacking below
            requires exactly 4 dims and a 3-element target_shape).
        target_shape: desired spatial shape (tD, tH, tW).
        mode: F.interpolate mode string.
        chunk_size: upper bound on the adaptive chunk edge length.
        device: device used for memory accounting / cache clearing.
        is_seg: disables autocast and rounds results back to integer labels.

    Returns:
        Resampled tensor of shape [C, tD, tH, tW].
    """
    print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
    C, D, H, W = volume.shape
    tD, tH, tW = target_shape

    # Estimate available memory (MiB) to size chunks adaptively.
    if device.type == 'cpu':
        available_memory = psutil.virtual_memory().available / 1024**2
    else:
        total_memory = torch.cuda.get_device_properties(device).total_memory / 1024**2
        allocated_memory = torch.cuda.memory_allocated(device) / 1024**2
        available_memory = total_memory - allocated_memory

    # NOTE(review): nelement() == numel(), so this reduces to
    # volume.element_size() (bytes per voxel); the division below then mixes
    # MiB with bytes — verify the resulting chunk size is as intended.
    mem_per_voxel = volume.element_size() * volume.nelement() / volume.numel()
    target_voxel_count = C * tD * tH * tW  # currently unused
    # Use a smaller fraction of free memory on GPU than on CPU.
    chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
    adaptive_chunk_size = max(
        32,
        min(chunk_size, int((available_memory * chunk_mem_ratio / mem_per_voxel / C) ** (1/3)))
    )

    # Small volumes (<= 128^3 voxels): resample in one shot, no tiling.
    if D * H * W <= 128**3:
        # Mixed precision for images only; segmentations stay full precision.
        # NOTE(review): torch.cuda.amp.autocast is deprecated in favor of
        # torch.amp.autocast('cuda', ...) and merely warns/disables on CPU.
        with torch.cuda.amp.autocast(enabled=not is_seg):
            start_time = time.time()

            # 'nearest' interpolation needs a float tensor for integer labels.
            input_tensor = volume.float() if is_seg and mode == 'nearest' else volume
            result = F.interpolate(
                input_tensor.unsqueeze(0),
                size=target_shape,
                mode=mode,
                # align_corners is only legal for linear-family modes.
                align_corners=False if mode != 'nearest' else None
            ).squeeze(0)

            # Cast labels back to the original integer dtype.
            if is_seg:
                result = result.round().to(volume.dtype)

            return result

    # Large volumes: fill the output tile by tile.
    result = torch.zeros((C, tD, tH, tW), device=device, dtype=volume.dtype)

    # Translate the input-space chunk size into an output-space tile size.
    out_chunk_size = max(1, int(adaptive_chunk_size * min(tD/D, tH/H, tW/W)))

    for c in range(C):
        for z in range(0, tD, out_chunk_size):
            z_end = min(z + out_chunk_size, tD)
            for y in range(0, tH, out_chunk_size):
                y_end = min(y + out_chunk_size, tH)
                for x in range(0, tW, out_chunk_size):
                    x_end = min(x + out_chunk_size, tW)

                    # Corresponding input window with a 1–2 voxel halo.
                    # NOTE(review): interpolating this padded sub-window down
                    # to the tile size is not coordinate-exact with respect to
                    # resampling the whole volume at once — expect small
                    # seams/offsets at tile borders; confirm acceptable.
                    in_z = max(0, int(z * D / tD) - 1)
                    in_z_end = min(D, int(z_end * D / tD) + 2)
                    in_y = max(0, int(y * H / tH) - 1)
                    in_y_end = min(H, int(y_end * H / tH) + 2)
                    in_x = max(0, int(x * W / tW) - 1)
                    in_x_end = min(W, int(x_end * W / tW) + 2)

                    # Keep a singleton channel dim so interpolate sees 5D input.
                    chunk = volume[c:c+1, in_z:in_z_end, in_y:in_y_end, in_x:in_x_end]
                    chunk_target = (z_end - z, y_end - y, x_end - x)

                    with torch.cuda.amp.autocast(enabled=not is_seg):
                        start_time = time.time()

                        input_chunk = chunk.float() if is_seg and mode == 'nearest' else chunk
                        resampled_chunk = F.interpolate(
                            input_chunk.unsqueeze(0),
                            size=chunk_target,
                            mode=mode,
                            align_corners=False if mode != 'nearest' else None
                        ).squeeze(0)

                        if is_seg:
                            resampled_chunk = resampled_chunk.round().to(volume.dtype)

                    result[c, z:z_end, y:y_end, x:x_end] = resampled_chunk
                    # Free tile memory eagerly between chunks.
                    del chunk, resampled_chunk
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

    return result
|
|
|
|
|
|
def resample_torch_simple(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    mode: str = 'linear'
) -> Union[torch.Tensor, np.ndarray]:
    """Resample channel-first data to ``new_shape`` using F.interpolate.

    Images are interpolated directly. Segmentations are resampled one label at
    a time (interpolating per-label indicator masks) so distinct label values
    are never blended by the interpolation.

    Args:
        data: channel-first array/tensor; spatial dims follow dim 0.
        new_shape: desired spatial shape.
        is_seg: treat data as an integer label map.
        num_threads: torch thread count while resampling (restored on exit).
        device: device on which the interpolation runs.
        memefficient_seg_resampling: label-by-label variant that avoids the
            large per-label temporary buffer (skips label 0, no tie-breaking).
        mode: 'linear' is mapped to trilinear/bilinear by input rank; any
            other value is passed to F.interpolate verbatim.

    Returns:
        Resampled data — numpy if the input was numpy, otherwise a tensor
        moved back to the input's original device.
    """
    # Map the generic 'linear' request onto the rank-specific torch mode.
    if mode == 'linear':
        torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
    else:
        torch_mode = mode

    if isinstance(new_shape, np.ndarray):
        new_shape = [int(i) for i in new_shape]

    # Fast path: shape already matches, return the input untouched.
    if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
        return data

    # Temporarily pin the thread count; restored before returning.
    n_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    new_shape = tuple(new_shape)
    with torch.no_grad():
        input_was_numpy = isinstance(data, np.ndarray)
        if input_was_numpy:
            data = torch.from_numpy(data).to(device)
        else:
            # Remember where the caller's tensor lived so the result can be
            # moved back there.
            orig_device = deepcopy(data.device)
            data = data.to(device)

        if is_seg:
            unique_values = torch.unique(data)
            # Smallest integer dtype that can hold all label values.
            result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
            result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
            if not memefficient_seg_resampling:
                # One fp16 "soft mask" per label; voxels nobody claims
                # confidently are tie-broken by argmax afterwards.
                result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
                                         device=device)
                # Indicator masks are scaled up before fp16 interpolation —
                # presumably to keep values in a well-represented fp16 range;
                # verify before changing.
                scale_factor = 1000
                done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
                for i, u in enumerate(unique_values):
                    result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
                                                  antialias=False)[0]
                    # Voxels with > 70% weight for this label are assigned now.
                    mask = result_tmp[i] > (0.7 * scale_factor)
                    result[mask] = u.item()
                    done_mask |= mask
                # Unassigned voxels get whichever label has the largest weight.
                if not torch.all(done_mask):
                    result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
            else:
                # Memory-efficient variant: labels processed independently,
                # a voxel takes the label whose mask exceeds 0.5 (later labels
                # overwrite earlier ones on overlap).
                for i, u in enumerate(unique_values):
                    if u == 0:
                        # Background is whatever no foreground label claims.
                        continue
                    result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[0] > 0.5] = u
        else:
            result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]

        if input_was_numpy:
            result = result.cpu().numpy()
        else:
            result = result.to(orig_device)

    # Restore the caller's thread configuration.
    torch.set_num_threads(n_threads)
    return result
|
|
|
|
|
|
def resample_torch_fornnunet(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    force_separate_z: Union[bool, None] = None,
    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
    mode: str = 'linear',
    aniso_axis_mode: str = 'nearest-exact'
) -> Union[torch.Tensor, np.ndarray]:
    """nnU-Net-style resampling with optional separate handling of the
    anisotropic axis.

    When the spacing is sufficiently anisotropic (per
    ``determine_do_sep_z_and_axis``), the two in-plane axes are resampled
    first with ``mode`` and the remaining axis is then resampled with
    ``aniso_axis_mode``; otherwise the whole volume is resampled in one pass.

    Args:
        data: channel-first 4D volume (c, x, y, z).
        new_shape: desired spatial shape.
        current_spacing: voxel spacing of the input.
        new_spacing: voxel spacing of the output.
        is_seg: treat data as an integer label map.
        num_threads: torch thread count during resampling.
        device: computation device.
        memefficient_seg_resampling: forwarded to resample_torch_simple.
        force_separate_z: force/suppress the separate-axis path (None = auto).
        separate_z_anisotropy_threshold: spacing-ratio threshold for the
            automatic decision.
        mode: interpolation mode for the in-plane (or single-pass) resampling.
        aniso_axis_mode: interpolation mode for the anisotropic axis.

    Returns:
        Resampled volume; numpy if the input was numpy.
    """
    assert data.ndim == 4, "data must be c, x, y, z"
    new_shape = [int(i) for i in new_shape]
    orig_shape = data.shape

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
                                                      separate_z_anisotropy_threshold)

    if do_separate_z:
        was_numpy = isinstance(data, np.ndarray)
        if was_numpy:
            data = torch.from_numpy(data)

        # determine_do_sep_z_and_axis may return the axis wrapped in a list.
        if isinstance(axis, list):
            axis = axis[0]

        tmp = "xyz"
        axis_letter = tmp[axis]
        others_int = [i for i in range(3) if i != axis]
        others = [tmp[i] for i in others_int]

        # Fold the anisotropic axis into the batch/channel dim so the in-plane
        # axes are resampled slice by slice first.
        data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
        tmp_new_shape = [new_shape[i] for i in others_int]
        data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
        # Unfold back to (c, x, y, z); the folded axis still has its original
        # length (orig_shape is (c, x, y, z), hence axis + 1).
        data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                         **{axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
        # Second pass: only the anisotropic axis changes, using aniso_axis_mode.
        data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
        if was_numpy:
            data = data.numpy()
        return data
    else:
        # BUG FIX: `mode` was previously dropped on this path, so a caller's
        # requested interpolation mode was silently ignored (always 'linear').
        return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling,
                                     mode=mode)
|
|
|
|
|
|
def dice_score(pred: np.ndarray, true: np.ndarray) -> float:
    """Compute the Dice overlap coefficient between two masks.

    Both inputs are flattened and multiplied element-wise to get the overlap;
    a small epsilon in the denominator avoids division by zero when both
    masks are empty.
    """
    flat_pred = pred.flatten()
    flat_true = true.flatten()
    overlap = np.sum(flat_pred * flat_true)
    denominator = np.sum(flat_pred) + np.sum(flat_true) + 1e-8
    return (2. * overlap) / denominator
|
|
|
|
|
|
|
|
|
def compute_new_shape(original_shape, current_spacing, target_spacing):
    """
    Compute the new shape based on the spacing ratio.

    original_shape: (z, y, x)
    current_spacing: (z, y, x)
    target_spacing: (z, y, x)

    NOTE(review): this redefinition shadows the earlier, annotated
    compute_new_shape (which returns a list); module-level calls resolve to
    this tuple-returning version — confirm that is intended.
    """
    resized = []
    for size, cur, tgt in zip(original_shape, current_spacing, target_spacing):
        zoom = cur / tgt
        resized.append(int(round(size * zoom)))
    return tuple(resized)
|
|
|
|
|
|
|
|
|
def save_nii(array, spacing, output_path, is_seg=False):
    """
    Save numpy array as NIfTI file with specified spacing.

    array: volume as (c, z, y, x) or (z, y, x); torch tensors are moved to CPU.
    spacing: per-axis voxel spacing used to build a diagonal affine.
    output_path: destination .nii/.nii.gz path.
    is_seg: If True, convert to int32 for segmentation masks.
    """
    if isinstance(array, torch.Tensor):
        array = array.cpu().numpy()

    # Integer labels for segmentations, float32 for images.
    if is_seg:
        array = array.astype(np.int32)
    else:
        array = array.astype(np.float32)

    # Reorder axes for NIfTI. The 4D case maps (c, z, y, x) -> (y, x, z, c);
    # the 3D case mirrors it without the channel axis: (z, y, x) -> (y, x, z).
    # BUG FIX: the 3D branch previously used transpose(2, 3, 1), which
    # references a fourth axis and raises ValueError on any 3D array.
    if array.ndim == 4:
        array = array.transpose(2, 3, 1, 0)
    else:
        array = array.transpose(1, 2, 0)

    # NOTE(review): the affine applies `spacing` in the order given, but the
    # transposes above reorder the data axes — confirm the spacing order
    # matches what downstream viewers expect.
    affine = np.diag(list(spacing) + [1.0])
    nii_img = nib.Nifti1Image(array, affine=affine)
    nib.save(nii_img, output_path)
    print(f"Saved: {output_path}")
|
|
|
|
|
|
|
|
|
def main():
    """Benchmark optimized_3d_resample and resample_torch_fornnunet against
    resample_torch_simple on one image/segmentation pair, saving all outputs
    as NIfTI and printing time/memory/MAE/Dice comparisons."""
    torch.set_num_threads(4)
    # NOTE(review): device is hard-coded to CUDA; this will fail on machines
    # without a GPU — consider guarding with torch.cuda.is_available().
    device = torch.device('cuda')
    print(f"\nRunning tests on device: {device}")

    # Hard-coded, machine-specific benchmark paths.
    npz_file_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz"
    gt_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz"
    output_dir = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample"

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Load image and ground-truth label volumes from the npz archives.
    data = np.load(npz_file_path, allow_pickle=True)
    img_array = data['imgs']
    img_spacing = data['spacing']
    # NOTE(review): the spacing loaded above is immediately discarded and
    # replaced with isotropic 1.0 — looks like a debug override; confirm.
    img_spacing = [1.0, 1.0, 1.0]
    gt_data = np.load(gt_path, allow_pickle=True)
    gt_array = gt_data['gts']

    img_array = img_array.astype(np.float32)
    gt_array = gt_array.astype(np.int32)

    # Promote 3D volumes to single-channel 4D (c, z, y, x).
    if img_array.ndim == 3:
        img_array = img_array[np.newaxis, ...]
    if gt_array.ndim == 3:
        gt_array = gt_array[np.newaxis, ...]

    # Target spacings to benchmark (all isotropic downsampling).
    target_spacings = [
        (1.2, 1.2, 1.2),
        (1.5, 1.5, 1.5),
        (2.0, 2.0, 2.0),
    ]

    original_shape = img_array.shape[1:]
    current_spacing = img_spacing
    print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}")

    for target_spacing in target_spacings:
        print(f"\n=== Resampling to Target Spacing: {target_spacing} ===")

        new_shape = compute_new_shape(original_shape, current_spacing, target_spacing)
        print(f"Computed target shape: {new_shape}")

        print("\nResampling image...")

        # --- Reference ("ground truth") image via resample_torch_simple ---
        print("Computing ground truth with resample_torch_simple...")
        start_time = time.time()
        # Synchronize so wall-clock timings reflect finished GPU work.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_img = resample_torch_simple(
            img_array,
            new_shape=new_shape,
            is_seg=False,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_time = time.time() - start_time
        output_path = os.path.join(output_dir, f"img_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s")
        save_nii(gt_img, target_spacing, output_path, is_seg=False)

        # --- Candidate: optimized_3d_resample on the image ---
        print("Running optimized_3d_resample...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_img_opt = optimized_3d_resample(
            img_array,
            current_spacing,
            target_spacing,
            is_seg=False,
            device=device,
            num_threads=4,
            chunk_size=64
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()

        opt_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        # Mean absolute error versus the reference resampling.
        opt_mae = np.mean(np.abs(resampled_img_opt - gt_img))
        output_path = os.path.join(output_dir, f"img_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {opt_mae:.6f}")
        save_nii(resampled_img_opt, target_spacing, output_path, is_seg=False)

        # --- Baseline: resample_torch_fornnunet on the image ---
        print("Running resample_torch_fornnunet...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_img_orig = resample_torch_fornnunet(
            img_array,
            new_shape,
            current_spacing,
            target_spacing,
            is_seg=False,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        orig_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        orig_mae = np.mean(np.abs(resampled_img_orig - gt_img))
        output_path = os.path.join(output_dir, f"img_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {orig_mae:.6f}")
        save_nii(resampled_img_orig, target_spacing, output_path, is_seg=False)

        print("\nResampling segmentation mask...")

        # --- Reference segmentation via resample_torch_simple ---
        print("Computing ground truth with resample_torch_simple...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_seg = resample_torch_simple(
            gt_array,
            new_shape=new_shape,
            is_seg=True,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()
        gt_seg_time = time.time() - start_time
        output_path = os.path.join(output_dir, f"seg_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s")
        save_nii(gt_seg, target_spacing, output_path, is_seg=True)

        # --- Candidate: optimized_3d_resample on the segmentation ---
        print("Running optimized_3d_resample for segmentation...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_seg_opt = optimized_3d_resample(
            gt_array,
            current_spacing,
            target_spacing,
            is_seg=True,
            device=device,
            num_threads=4,
            chunk_size=64
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()

        opt_seg_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        # NOTE(review): dice_score multiplies raw label values, so for
        # multi-label masks this is not a proper per-class Dice — confirm the
        # masks here are binary or accept this as a rough agreement metric.
        opt_dice = dice_score(resampled_seg_opt, gt_seg)
        output_path = os.path.join(output_dir, f"seg_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {opt_dice:.6f}")
        save_nii(resampled_seg_opt, target_spacing, output_path, is_seg=True)

        # --- Baseline: resample_torch_fornnunet on the segmentation ---
        print("Running resample_torch_fornnunet for segmentation...")
        start_time = time.time()
        if device.type == 'cuda':
            torch.cuda.synchronize()
        mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        resampled_seg_orig = resample_torch_fornnunet(
            gt_array,
            new_shape,
            current_spacing,
            target_spacing,
            is_seg=True,
            num_threads=4,
            device=device
        )
        if device.type == 'cuda':
            torch.cuda.synchronize()

        orig_seg_time = time.time() - start_time
        mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
        orig_dice = dice_score(resampled_seg_orig, gt_seg)
        output_path = os.path.join(output_dir, f"seg_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
        print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, "
              f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {orig_dice:.6f}")
        save_nii(resampled_seg_orig, target_spacing, output_path, is_seg=True)

        # --- Per-spacing summary of all collected metrics ---
        print(f"\n=== Summary for Target Spacing: {target_spacing} ===")
        print("Image Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}")
        print(f"Original - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}")
        print(f"Time Improvement: {(orig_time - opt_time) / orig_time * 100:.2f}%")
        print(f"MAE Improvement: {(orig_mae - opt_mae) / orig_mae * 100:.2f}%")
        print("Segmentation Mask Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}")
        print(f"Original - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}")
        print(f"Time Improvement: {(orig_seg_time - opt_seg_time) / orig_seg_time * 100:.2f}%")
        print(f"Dice Improvement: {(opt_dice - orig_dice) / orig_dice * 100:.2f}%")
|
|
|
|
|
|
# Run the resampling benchmark only when executed as a script.
if __name__ == '__main__':
    main()