# Medal-S-V1.0 / data / resampling_test.py
# (Hugging Face upload header: user spc819, "Upload 69 files", commit 7f3dfd7 verified)
from typing import Union, Tuple, List
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
import time
from copy import deepcopy
from default_resampling import determine_do_sep_z_and_axis
import psutil
import nibabel as nib
import os
from pathlib import Path
ANISO_THRESHOLD = 3
def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray],
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]:
"""Compute new shape based on spacing ratios."""
current_shape = np.array(current_shape)
current_spacing = np.array(current_spacing)
target_spacing = np.array(target_spacing)
return [int(round(s * (cs / ts))) for s, cs, ts in zip(current_shape, current_spacing, target_spacing)]
def optimized_3d_resample(
data: Union[torch.Tensor, np.ndarray],
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
target_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
device: torch.device = torch.device('cpu'),
num_threads: int = 8,
chunk_size: int = 64,
force_separate_z: Union[bool, None] = None,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
preserve_range: bool = True
) -> Union[torch.Tensor, np.ndarray]:
"""
Optimized 3D image resampling with adaptive interpolation and chunked processing.
Args:
data: Input 3D volume [C, D, H, W] or [D, H, W]
current_spacing: Current voxel spacing (z, y, x)
target_spacing: Target voxel spacing (z, y, x)
is_seg: Whether the input is a segmentation mask
device: Torch device for computation
num_threads: Number of threads for CPU operations
chunk_size: Size of chunks for large volume processing
force_separate_z: Force separate z resampling
separate_z_anisotropy_threshold: Threshold for anisotropic resampling
preserve_range: Preserve original value range for non-segmentation data
Returns:
Resampled 3D volume
"""
print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
input_was_numpy = isinstance(data, np.ndarray)
if input_was_numpy:
data = torch.from_numpy(data).to(device)
else:
data = data.to(device)
print(f"Input converted to tensor on {device}, shape: {data.shape}")
if data.ndim == 3:
data = data.unsqueeze(0)
assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"
new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")
if all(i == j for i, j in zip(new_shape, data.shape[1:])):
print("No resampling needed, shapes identical.")
return data.cpu().numpy() if input_was_numpy else data
mode = 'nearest' if is_seg else 'trilinear'
aniso_axis_mode = 'nearest-exact' if is_seg else 'linear'
print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing,
target_spacing, separate_z_anisotropy_threshold)
print(f"Do separate Z: {do_separate_z}, Axis: {axis}")
if preserve_range and not is_seg:
v_min, v_max = data.min(), data.max()
print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")
torch.set_num_threads(num_threads)
print(f"Set number of threads to {num_threads}")
start_time = time.time()
if do_separate_z:
tmp = "xyz"
axis_letter = tmp[axis]
others_int = [i for i in range(3) if i != axis]
others = [tmp[i] for i in others_int]
print(f"Separate Z resampling along axis {axis_letter}, others: {others}")
tmp_new_shape = [new_shape[i] for i in others_int]
print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}")
data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
print(f"Rearranged data shape: {data.shape}")
data = _chunked_resample(data, tmp_new_shape, mode, chunk_size, device, is_seg)
print(f"After first pass resampling, shape: {data.shape}")
data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
**{axis_letter: data.shape[1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
print(f"Rearranged back to shape: {data.shape}")
data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg)
print(f"After second pass resampling, final shape: {data.shape}")
else:
print(f"Direct resampling to shape: {new_shape}")
data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg)
print(f"After direct resampling, final shape: {data.shape}")
resample_time = time.time() - start_time
print(f"Resampling completed in {resample_time:.3f}s")
if is_seg:
unique_values = torch.unique(data)
result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
data = data.round().to(result_dtype)
print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}")
if preserve_range and not is_seg:
data = torch.clamp(data, v_min, v_max)
print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}")
output = data.cpu().numpy() if input_was_numpy else data
print(f"Output shape: {output.shape}, type: {type(output)}")
return output
def _chunked_resample(
volume: torch.Tensor,
target_shape: Tuple[int, ...],
mode: str,
chunk_size: int,
device: torch.device,
is_seg: bool
) -> torch.Tensor:
"""Chunked resampling for large volumes with adaptive chunk sizing."""
print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
C, D, H, W = volume.shape
tD, tH, tW = target_shape
# Adaptive chunk size based on available memory
if device.type == 'cpu':
available_memory = psutil.virtual_memory().available / 1024**2 # in MB
else:
total_memory = torch.cuda.get_device_properties(device).total_memory / 1024**2 # in MB
allocated_memory = torch.cuda.memory_allocated(device) / 1024**2
available_memory = total_memory - allocated_memory
mem_per_voxel = volume.element_size() * volume.nelement() / volume.numel()
target_voxel_count = C * tD * tH * tW
chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
adaptive_chunk_size = max(
32,
min(chunk_size, int((available_memory * chunk_mem_ratio / mem_per_voxel / C) ** (1/3)))
)
# Early return for small volumes
if D * H * W <= 128**3:
with torch.cuda.amp.autocast(enabled=not is_seg):
start_time = time.time()
# Cast to float for interpolation if is_seg and mode is nearest
input_tensor = volume.float() if is_seg and mode == 'nearest' else volume
result = F.interpolate(
input_tensor.unsqueeze(0),
size=target_shape,
mode=mode,
align_corners=False if mode != 'nearest' else None
).squeeze(0)
# Convert back to original dtype for segmentation
if is_seg:
result = result.round().to(volume.dtype)
# print(f"Direct interpolation completed in {time.time() - start_time:.3f}s, output shape: {result.shape}")
return result
result = torch.zeros((C, tD, tH, tW), device=device, dtype=volume.dtype)
out_chunk_size = max(1, int(adaptive_chunk_size * min(tD/D, tH/H, tW/W)))
for c in range(C):
for z in range(0, tD, out_chunk_size):
z_end = min(z + out_chunk_size, tD)
for y in range(0, tH, out_chunk_size):
y_end = min(y + out_chunk_size, tH)
for x in range(0, tW, out_chunk_size):
x_end = min(x + out_chunk_size, tW)
in_z = max(0, int(z * D / tD) - 1)
in_z_end = min(D, int(z_end * D / tD) + 2)
in_y = max(0, int(y * H / tH) - 1)
in_y_end = min(H, int(y_end * H / tH) + 2)
in_x = max(0, int(x * W / tW) - 1)
in_x_end = min(W, int(x_end * W / tW) + 2)
chunk = volume[c:c+1, in_z:in_z_end, in_y:in_y_end, in_x:in_x_end]
chunk_target = (z_end - z, y_end - y, x_end - x)
with torch.cuda.amp.autocast(enabled=not is_seg):
start_time = time.time()
# Cast to float for interpolation if is_seg and mode is nearest
input_chunk = chunk.float() if is_seg and mode == 'nearest' else chunk
resampled_chunk = F.interpolate(
input_chunk.unsqueeze(0),
size=chunk_target,
mode=mode,
align_corners=False if mode != 'nearest' else None
).squeeze(0)
# Convert back to original dtype for segmentation
if is_seg:
resampled_chunk = resampled_chunk.round().to(volume.dtype)
# print(f"Chunk interpolation completed in {time.time() - start_time:.3f}s, shape: {resampled_chunk.shape}")
result[c, z:z_end, y:y_end, x:x_end] = resampled_chunk
del chunk, resampled_chunk
if device.type == 'cuda':
torch.cuda.empty_cache()
return result
def resample_torch_simple(
data: Union[torch.Tensor, np.ndarray],
new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
is_seg: bool = False,
num_threads: int = 4,
device: torch.device = torch.device('cpu'),
memefficient_seg_resampling: bool = False,
mode: str = 'linear'
) -> Union[torch.Tensor, np.ndarray]:
if mode == 'linear':
torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
else:
torch_mode = mode
if isinstance(new_shape, np.ndarray):
new_shape = [int(i) for i in new_shape]
if all([i == j for i, j in zip(new_shape, data.shape[1:])]):
return data
n_threads = torch.get_num_threads()
torch.set_num_threads(num_threads)
new_shape = tuple(new_shape)
with torch.no_grad():
input_was_numpy = isinstance(data, np.ndarray)
if input_was_numpy:
data = torch.from_numpy(data).to(device)
else:
orig_device = deepcopy(data.device)
data = data.to(device)
if is_seg:
unique_values = torch.unique(data)
result_dtype = torch.int8 if max(unique_values) < 127 else torch.int16
result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
if not memefficient_seg_resampling:
result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape), dtype=torch.float16,
device=device)
scale_factor = 1000
done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
for i, u in enumerate(unique_values):
result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor, new_shape, mode=torch_mode,
antialias=False)[0]
mask = result_tmp[i] > (0.7 * scale_factor)
result[mask] = u.item()
done_mask |= mask
if not torch.all(done_mask):
result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
else:
for i, u in enumerate(unique_values):
if u == 0:
continue
result[F.interpolate((data[None] == u).float(), new_shape, mode=torch_mode, antialias=False)[0] > 0.5] = u
else:
result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]
if input_was_numpy:
result = result.cpu().numpy()
else:
result = result.to(orig_device)
torch.set_num_threads(n_threads)
return result
def resample_torch_fornnunet(
data: Union[torch.Tensor, np.ndarray],
new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
is_seg: bool = False,
num_threads: int = 4,
device: torch.device = torch.device('cpu'),
memefficient_seg_resampling: bool = False,
force_separate_z: Union[bool, None] = None,
separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
mode: str = 'linear',
aniso_axis_mode: str = 'nearest-exact'
) -> Union[torch.Tensor, np.ndarray]:
assert data.ndim == 4, "data must be c, x, y, z"
new_shape = [int(i) for i in new_shape]
orig_shape = data.shape
do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
separate_z_anisotropy_threshold)
if do_separate_z:
was_numpy = isinstance(data, np.ndarray)
if was_numpy:
data = torch.from_numpy(data)
if isinstance(axis, list):
axis = axis[0]
tmp = "xyz"
axis_letter = tmp[axis]
others_int = [i for i in range(3) if i != axis]
others = [tmp[i] for i in others_int]
data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
tmp_new_shape = [new_shape[i] for i in others_int]
data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
**{axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
if was_numpy:
data = data.numpy()
return data
else:
return resample_torch_simple(data, new_shape, is_seg, num_threads, device, memefficient_seg_resampling)
def dice_score(pred: np.ndarray, true: np.ndarray) -> float:
"""Compute Dice score for segmentation masks."""
pred = pred.flatten()
true = true.flatten()
intersection = np.sum(pred * true)
return (2. * intersection) / (np.sum(pred) + np.sum(true) + 1e-8)
# Placeholder for compute_new_shape if not provided
def compute_new_shape(original_shape, current_spacing, target_spacing):
"""
Compute the new shape based on the spacing ratio.
original_shape: (z, y, x)
current_spacing: (z, y, x)
target_spacing: (z, y, x)
"""
zoom_factors = [c / t for c, t in zip(current_spacing, target_spacing)]
new_shape = [int(round(s * z)) for s, z in zip(original_shape, zoom_factors)]
return tuple(new_shape)
# Function to save as NIfTI
def save_nii(array, spacing, output_path, is_seg=False):
"""
Save numpy array as NIfTI file with specified spacing.
is_seg: If True, convert to int32 for segmentation masks.
"""
# Convert torch tensor to numpy if necessary
if isinstance(array, torch.Tensor):
array = array.cpu().numpy()
# Convert data type for NIfTI compatibility
if is_seg:
array = array.astype(np.int32) # Convert segmentation to int32
else:
array = array.astype(np.float32) # Ensure image is float32
# Transpose to (X, Y, Z, C) for NIfTI
if array.ndim == 4:
array = array.transpose(2, 3, 1, 0) # From (C, Z, Y, X) to (X, Y, Z, C)
else:
array = array.transpose(2, 3, 1) # From (Z, Y, X) to (X, Y, Z)
# Create NIfTI image with affine based on spacing
affine = np.diag(list(spacing) + [1.0])
nii_img = nib.Nifti1Image(array, affine=affine)
nib.save(nii_img, output_path)
print(f"Saved: {output_path}")
# Main resampling function
def main():
torch.set_num_threads(4)
device = torch.device('cuda') #torch.device('cpu') # Force CPU as per provided code
print(f"\nRunning tests on device: {device}")
# Define paths
npz_file_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz"
gt_path = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz"
output_dir = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample"
# Ensure output directory exists
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Load input data
data = np.load(npz_file_path, allow_pickle=True)
img_array = data['imgs'] # Shape: (C, Z, Y, X) or (Z, Y, X)
img_spacing = data['spacing'] # (z, y, x)
img_spacing = [1.0, 1.0, 1.0] # Override as per provided code
gt_data = np.load(gt_path, allow_pickle=True)
gt_array = gt_data['gts'] # Shape: (C, Z, Y, X) or (Z, Y, X)
# Convert data types to PyTorch-compatible types
img_array = img_array.astype(np.float32) # Convert image to float32
gt_array = gt_array.astype(np.int32) # Convert segmentation mask to int32
# Ensure img_array and gt_array have channel dimension
if img_array.ndim == 3:
img_array = img_array[np.newaxis, ...] # Add channel dimension: (1, Z, Y, X)
if gt_array.ndim == 3:
gt_array = gt_array[np.newaxis, ...] # Add channel dimension: (1, Z, Y, X)
# Define target spacings to test
target_spacings = [
(1.2, 1.2, 1.2),
(1.5, 1.5, 1.5),
(2.0, 2.0, 2.0),
]
# Original shape and spacing
original_shape = img_array.shape[1:] # (Z, Y, X)
current_spacing = img_spacing
print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}")
for target_spacing in target_spacings:
print(f"\n=== Resampling to Target Spacing: {target_spacing} ===")
# Compute new shape
new_shape = compute_new_shape(original_shape, current_spacing, target_spacing)
print(f"Computed target shape: {new_shape}")
# === Image Resampling ===
print("\nResampling image...")
# Ground truth resampling
print("Computing ground truth with resample_torch_simple...")
start_time = time.time()
if device.type == 'cuda':
torch.cuda.synchronize() # Ensure GPU operations are complete
gt_img = resample_torch_simple(
img_array,
new_shape=new_shape,
is_seg=False,
num_threads=4,
device=device
)
if device.type == 'cuda':
torch.cuda.synchronize() # Ensure GPU operations are complete
gt_time = time.time() - start_time
output_path = os.path.join(output_dir, f"img_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s")
save_nii(gt_img, target_spacing, output_path, is_seg=False)
# Optimized resampling
print("Running optimized_3d_resample...")
start_time = time.time()
if device.type == 'cuda':
torch.cuda.synchronize()
mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
resampled_img_opt = optimized_3d_resample(
img_array,
current_spacing,
target_spacing,
is_seg=False,
device=device,
num_threads=4,
chunk_size=64
)
if device.type == 'cuda':
torch.cuda.synchronize()
opt_time = time.time() - start_time
mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
opt_mae = np.mean(np.abs(resampled_img_opt - gt_img))
output_path = os.path.join(output_dir, f"img_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, "
f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {opt_mae:.6f}")
save_nii(resampled_img_opt, target_spacing, output_path, is_seg=False)
# Original resampling
print("Running resample_torch_fornnunet...")
start_time = time.time()
if device.type == 'cuda':
torch.cuda.synchronize()
mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
resampled_img_orig = resample_torch_fornnunet(
img_array,
new_shape,
current_spacing,
target_spacing,
is_seg=False,
num_threads=4,
device=device
)
if device.type == 'cuda':
torch.cuda.synchronize()
orig_time = time.time() - start_time
mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
orig_mae = np.mean(np.abs(resampled_img_orig - gt_img))
output_path = os.path.join(output_dir, f"img_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, "
f"Memory used: {mem_after - mem_before:.2f} MB, MAE: {orig_mae:.6f}")
save_nii(resampled_img_orig, target_spacing, output_path, is_seg=False)
# === Segmentation Mask Resampling ===
print("\nResampling segmentation mask...")
# Ground truth resampling
print("Computing ground truth with resample_torch_simple...")
start_time = time.time()
if device.type == 'cuda':
torch.cuda.synchronize()
gt_seg = resample_torch_simple(
gt_array,
new_shape=new_shape,
is_seg=True,
num_threads=4,
device=device
)
if device.type == 'cuda':
torch.cuda.synchronize()
gt_seg_time = time.time() - start_time
output_path = os.path.join(output_dir, f"seg_gt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s")
save_nii(gt_seg, target_spacing, output_path, is_seg=True)
# Optimized resampling
print("Running optimized_3d_resample for segmentation...")
start_time = time.time()
if device.type == 'cuda':
torch.cuda.synchronize()
mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
resampled_seg_opt = optimized_3d_resample(
gt_array,
current_spacing,
target_spacing,
is_seg=True,
device=device,
num_threads=4,
chunk_size=64
)
if device.type == 'cuda':
torch.cuda.synchronize()
opt_seg_time = time.time() - start_time
mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
opt_dice = dice_score(resampled_seg_opt, gt_seg)
output_path = os.path.join(output_dir, f"seg_opt_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, "
f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {opt_dice:.6f}")
save_nii(resampled_seg_opt, target_spacing, output_path, is_seg=True)
# Original resampling
print("Running resample_torch_fornnunet for segmentation...")
start_time = time.time()
if device.type == 'cuda':
torch.cuda.synchronize()
mem_before = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
resampled_seg_orig = resample_torch_fornnunet(
gt_array,
new_shape,
current_spacing,
target_spacing,
is_seg=True,
num_threads=4,
device=device
)
if device.type == 'cuda':
torch.cuda.synchronize()
orig_seg_time = time.time() - start_time
mem_after = psutil.virtual_memory().used / 1024**2 if device.type == 'cpu' else torch.cuda.memory_allocated(device) / 1024**2
orig_dice = dice_score(resampled_seg_orig, gt_seg)
output_path = os.path.join(output_dir, f"seg_orig_spacing_{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}.nii.gz")
print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, "
f"Memory used: {mem_after - mem_before:.2f} MB, Dice: {orig_dice:.6f}")
save_nii(resampled_seg_orig, target_spacing, output_path, is_seg=True)
# Summary
print(f"\n=== Summary for Target Spacing: {target_spacing} ===")
print("Image Resampling Metrics:")
print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}")
print(f"Original - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}")
print(f"Time Improvement: {(orig_time - opt_time) / orig_time * 100:.2f}%")
print(f"MAE Improvement: {(orig_mae - opt_mae) / orig_mae * 100:.2f}%")
print("Segmentation Mask Resampling Metrics:")
print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}")
print(f"Original - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}")
print(f"Time Improvement: {(orig_seg_time - opt_seg_time) / orig_seg_time * 100:.2f}%")
print(f"Dice Improvement: {(opt_dice - orig_dice) / orig_dice * 100:.2f}%")
if __name__ == '__main__':
main()