# File size: 27,646 Bytes
# 7f3dfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
from typing import Union, Tuple, List
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
import time
from copy import deepcopy
from default_resampling import determine_do_sep_z_and_axis
import psutil
import nibabel as nib
import os
from pathlib import Path

# Default `separate_z_anisotropy_threshold` for the resampling functions in this module; it is
# forwarded to determine_do_sep_z_and_axis to decide whether the low-resolution axis is
# resampled in a separate pass (presumably a max/min spacing ratio — TODO confirm).
ANISO_THRESHOLD = 3

def compute_new_shape(current_shape: Union[Tuple[int, ...], List[int], np.ndarray],
                      current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
                      target_spacing: Union[Tuple[float, ...], List[float], np.ndarray]) -> List[int]:
    """Scale every axis length by current/target spacing and round to the nearest integer.

    Python's ``round`` is used, so exact halves go to the even neighbour.

    NOTE: a second ``compute_new_shape`` defined further down in this module replaces
    this one at import time (and returns a tuple), so module-level callers use that one.
    """
    shape_arr = np.array(current_shape)
    cur_arr = np.array(current_spacing)
    tgt_arr = np.array(target_spacing)

    out: List[int] = []
    for extent, cur, tgt in zip(shape_arr, cur_arr, tgt_arr):
        ratio = cur / tgt
        out.append(int(round(extent * ratio)))
    return out

def optimized_3d_resample(
    data: Union[torch.Tensor, np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    target_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    device: torch.device = torch.device('cpu'),
    num_threads: int = 8,
    chunk_size: int = 64,
    force_separate_z: Union[bool, None] = None,
    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
    preserve_range: bool = True
) -> Union[torch.Tensor, np.ndarray]:
    """
    Optimized 3D image resampling with adaptive interpolation and chunked processing.

    Args:
        data: Input volume [C, D, H, W] or [D, H, W]. A 3D input gets a leading channel
            axis, and the result is always 4D.
        current_spacing: Current voxel spacing (z, y, x)
        target_spacing: Target voxel spacing (z, y, x)
        is_seg: Whether the input is a segmentation mask (nearest-neighbour interpolation)
        device: Torch device for computation
        num_threads: Number of torch threads used during this call; the previous setting
            is restored afterwards.
        chunk_size: Upper bound on the slab size used by _chunked_resample
        force_separate_z: Force (True) / forbid (False) separate resampling of the
            low-resolution axis; None lets determine_do_sep_z_and_axis decide.
        separate_z_anisotropy_threshold: Threshold forwarded to determine_do_sep_z_and_axis
        preserve_range: Clamp non-segmentation output to the input's [min, max]

    Returns:
        Resampled [C, D', H', W'] volume: a numpy array if ``data`` was numpy, otherwise a
        tensor on ``device``. Segmentations are returned as int8, int16 or int32,
        depending on the largest label.
    """
    print(f"\nStarting optimized_3d_resample with input shape: {data.shape}, is_seg: {is_seg}")
    input_was_numpy = isinstance(data, np.ndarray)
    data = torch.from_numpy(data).to(device) if input_was_numpy else data.to(device)
    print(f"Input converted to tensor on {device}, shape: {data.shape}")

    if data.ndim == 3:
        data = data.unsqueeze(0)
    assert data.ndim == 4, "Data must be 3D or 4D (C, D, H, W)"

    new_shape = compute_new_shape(data.shape[1:], current_spacing, target_spacing)
    print(f"Computed new shape: {new_shape} from current_spacing: {current_spacing}, target_spacing: {target_spacing}")

    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        print("No resampling needed, shapes identical.")
        return data.cpu().numpy() if input_was_numpy else data

    # Linear kernels are only defined for floating point, so integer images are promoted.
    if not is_seg and not data.is_floating_point():
        data = data.float()

    mode = 'nearest' if is_seg else 'trilinear'
    # Both passes hand _chunked_resample a 4D volume, i.e. a 5D F.interpolate input, on
    # which 'linear' is not a valid mode. 'trilinear' with the in-plane axes already at
    # their target size only interpolates along the anisotropic axis.
    aniso_axis_mode = 'nearest-exact' if is_seg else 'trilinear'
    print(f"Interpolation mode: {mode}, Anisotropic axis mode: {aniso_axis_mode}")

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing,
                                                      target_spacing, separate_z_anisotropy_threshold)
    # Same normalisation as resample_torch_fornnunet: accept a one-element list axis.
    if isinstance(axis, list):
        axis = axis[0]
    print(f"Do separate Z: {do_separate_z}, Axis: {axis}")

    if preserve_range and not is_seg:
        v_min, v_max = data.min(), data.max()
        print(f"Preserving range for non-segmentation data: min={v_min.item():.4f}, max={v_max.item():.4f}")

    prev_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    print(f"Set number of threads to {num_threads}")
    try:
        start_time = time.time()
        if do_separate_z:
            tmp = "xyz"
            axis_letter = tmp[axis]
            others_int = [i for i in range(3) if i != axis]
            others = [tmp[i] for i in others_int]
            print(f"Separate Z resampling along axis {axis_letter}, others: {others}")

            # Length of the low-resolution axis before resampling; needed to undo the rearrange.
            axis_len = data.shape[axis + 1]
            tmp_new_shape = [new_shape[i] for i in others_int]
            print(f"First pass: Resampling to shape {tmp_new_shape} for axes {others}")
            data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
            print(f"Rearranged data shape: {data.shape}")
            # _chunked_resample expects (C, D, H, W). A singleton depth axis turns the stack of
            # 2D slices into such a volume, and that axis is left untouched (size 1 -> 1).
            data = _chunked_resample(data.unsqueeze(1), [1, *tmp_new_shape], mode,
                                     chunk_size, device, is_seg).squeeze(1)
            print(f"After first pass resampling, shape: {data.shape}")

            data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                             **{axis_letter: axis_len, others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
            print(f"Rearranged back to shape: {data.shape}")
            data = _chunked_resample(data, new_shape, aniso_axis_mode, chunk_size, device, is_seg)
            print(f"After second pass resampling, final shape: {data.shape}")
        else:
            print(f"Direct resampling to shape: {new_shape}")
            data = _chunked_resample(data, new_shape, mode, chunk_size, device, is_seg)
            print(f"After direct resampling, final shape: {data.shape}")
        resample_time = time.time() - start_time
        print(f"Resampling completed in {resample_time:.3f}s")
    finally:
        # Do not leak the thread setting to the caller.
        torch.set_num_threads(prev_threads)

    if is_seg:
        unique_values = torch.unique(data)
        max_label = unique_values.max().item() if unique_values.numel() else 0
        # Smallest signed integer type that holds every label (int16 alone overflows >= 32767).
        if max_label < 127:
            result_dtype = torch.int8
        elif max_label < 32767:
            result_dtype = torch.int16
        else:
            result_dtype = torch.int32
        if data.is_floating_point():
            data = data.round()
        data = data.to(result_dtype)
        print(f"Segmentation data rounded and converted to {result_dtype}, unique values: {unique_values.tolist()}")

    if preserve_range and not is_seg:
        data = torch.clamp(data, v_min, v_max)
        print(f"Clamped data to original range: min={v_min.item():.4f}, max={v_max.item():.4f}")

    output = data.cpu().numpy() if input_was_numpy else data
    print(f"Output shape: {output.shape}, type: {type(output)}")
    return output

def _chunked_resample(
    volume: torch.Tensor,
    target_shape: Tuple[int, ...],
    mode: str,
    chunk_size: int,
    device: torch.device,
    is_seg: bool
) -> torch.Tensor:
    """Resample a (C, D, H, W) volume to ``target_shape`` in memory-bounded slabs.

    Small volumes (D*H*W <= 128**3) are resampled with a single F.interpolate call.
    Larger ones are resampled in two separable passes:
    - first H and W, slab by slab along D;
    - then D, slab by slab along the output H axis.

    In each pass the slab axis keeps its size (an identity mapping), so every slab uses
    exactly the source-coordinate mapping of a whole-volume call. Nearest modes give
    identical results; trilinear differs only by floating-point rounding.

    Args:
        volume: (C, D, H, W) tensor, or an (N, H, W) stack together with a 2-element
            ``target_shape``. A stack is resampled in-plane only and returned as (N, tH, tW).
        target_shape: Output spatial size (tD, tH, tW), or (tH, tW) for a stack.
        mode: F.interpolate mode. 'linear'/'bilinear' are run as 'trilinear' because the
            interpolation input here is always 5D.
        chunk_size: Upper bound on the slab edge. Each slab holds roughly edge**3 voxels,
            with edge = max(32, min(chunk_size, memory-derived bound)).
        device: Device whose free memory bounds the slab size.
        is_seg: Interpolate in float32, then round back to the input dtype.

    Returns:
        Resampled tensor. It keeps the input dtype for segmentations and floating inputs,
        and is float32 for integer non-segmentation inputs.

    Raises:
        ValueError: if the volume/target ranks are not (4D, 3) or (3D, 2).
    """
    print(f"\nStarting _chunked_resample with input shape: {volume.shape}, target shape: {target_shape}")
    target_shape = tuple(int(s) for s in target_shape)

    # An (N, H, W) stack with a 2-element target (in-plane pass of separate-z resampling)
    # is handled as an (N, 1, H, W) volume whose singleton depth axis is kept as is.
    squeeze_depth = volume.ndim == 3 and len(target_shape) == 2
    if squeeze_depth:
        volume = volume.unsqueeze(1)
        target_shape = (1, *target_shape)
    if volume.ndim != 4 or len(target_shape) != 3:
        raise ValueError(
            f"expected a (C, D, H, W) volume with a 3-element target shape, "
            f"got {tuple(volume.shape)} and {target_shape}"
        )

    C, D, H, W = volume.shape
    tD, tH, tW = target_shape

    if mode in ('linear', 'bilinear'):
        mode = 'trilinear'
    # F.interpolate only accepts align_corners for linear-type modes; nearest/nearest-exact need None.
    align_corners = False if mode == 'trilinear' else None

    orig_dtype = volume.dtype
    # Interpolation kernels need floating point input; segmentation labels are rounded back at the end.
    needs_float = is_seg or not volume.is_floating_point()
    work_dtype = torch.float32 if needs_float else orig_dtype

    def _interp(block: torch.Tensor, size: Tuple[int, int, int]) -> torch.Tensor:
        # Resample one (C, d, h, w) block; F.interpolate needs a leading batch axis.
        if needs_float:
            block = block.float()
        return F.interpolate(block.unsqueeze(0), size=size, mode=mode,
                             align_corners=align_corners).squeeze(0)

    def _finish(res: torch.Tensor) -> torch.Tensor:
        # Restore the label dtype for segmentations and undo the stack-to-volume unsqueeze.
        if is_seg:
            res = res.round().to(orig_dtype)
        return res.squeeze(1) if squeeze_depth else res

    # Small volumes: one call, no chunking.
    if D * H * W <= 128**3:
        return _finish(_interp(volume, target_shape))

    # Slab budget from available memory (bytes, consistent with bytes_per_voxel).
    if device.type == 'cpu':
        available_bytes = psutil.virtual_memory().available
    else:
        total_bytes = torch.cuda.get_device_properties(device).total_memory
        available_bytes = total_bytes - torch.cuda.memory_allocated(device)
    bytes_per_voxel = 4 if needs_float else volume.element_size()
    chunk_mem_ratio = 0.5 if device.type == 'cpu' else 0.3
    mem_edge = int((max(available_bytes, 0) * chunk_mem_ratio / bytes_per_voxel / C) ** (1 / 3))
    edge = max(32, min(chunk_size, mem_edge))
    voxel_budget = edge ** 3

    # Pass 1: resample H and W, slab by slab along D (D is mapped 1:1 inside each slab).
    in_plane = torch.empty((C, D, tH, tW), dtype=work_dtype, device=volume.device)
    z_step = max(1, voxel_budget // max(1, C * (H * W + tH * tW)))
    for z0 in range(0, D, z_step):
        z1 = min(z0 + z_step, D)
        in_plane[:, z0:z1] = _interp(volume[:, z0:z1], (z1 - z0, tH, tW))

    # Pass 2: resample D, slab by slab along the output H axis (H and W are mapped 1:1).
    result = torch.empty((C, tD, tH, tW), dtype=work_dtype, device=volume.device)
    y_step = max(1, voxel_budget // max(1, C * tW * (D + tD)))
    for y0 in range(0, tH, y_step):
        y1 = min(y0 + y_step, tH)
        result[:, :, y0:y1] = _interp(in_plane[:, :, y0:y1], (tD, y1 - y0, tW))
    del in_plane

    return _finish(result)

def resample_torch_simple(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    mode: str = 'linear'
) -> Union[torch.Tensor, np.ndarray]:
    """Resample a channel-first array/tensor to ``new_shape`` with F.interpolate.

    Args:
        data: (C, *spatial) numpy array or tensor. 'linear' maps to 'trilinear' for
            4D input and to 'bilinear' otherwise.
        new_shape: Target spatial shape.
        is_seg: Resample each label as a one-hot channel and pick labels from the
            interpolated maps, instead of interpolating values directly.
        num_threads: Torch thread count used for this call; the previous value is restored.
        device: Device to compute on.
        memefficient_seg_resampling: For segmentations, process labels one at a time with
            a 0.5 threshold instead of keeping all interpolated label maps. Label 0 is
            skipped and left as the background.
        mode: 'linear' or any other F.interpolate mode name.

    Returns:
        The input unchanged if the shapes already match. Otherwise a numpy array if
        ``data`` was numpy, or a tensor on ``data``'s original device. Segmentations
        are int8/int16/int32 (smallest type that holds the largest label); images are float32.
    """
    if mode == 'linear':
        torch_mode = 'trilinear' if data.ndim == 4 else 'bilinear'
    else:
        torch_mode = mode

    if isinstance(new_shape, np.ndarray):
        new_shape = [int(i) for i in new_shape]

    if all(i == j for i, j in zip(new_shape, data.shape[1:])):
        return data

    new_shape = tuple(new_shape)
    n_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        with torch.no_grad():
            input_was_numpy = isinstance(data, np.ndarray)
            orig_device = None if input_was_numpy else data.device
            data = torch.from_numpy(data).to(device) if input_was_numpy else data.to(device)

            if is_seg:
                unique_values = torch.unique(data)
                max_label = unique_values.max().item() if unique_values.numel() else 0
                # int16 alone would overflow for labels >= 32767.
                if max_label < 127:
                    result_dtype = torch.int8
                elif max_label < 32767:
                    result_dtype = torch.int16
                else:
                    result_dtype = torch.int32
                result = torch.zeros((data.shape[0], *new_shape), dtype=result_dtype, device=device)
                if not memefficient_seg_resampling:
                    # Keep every label's interpolated one-hot map; scaled so float16 keeps precision.
                    result_tmp = torch.zeros((len(unique_values), data.shape[0], *new_shape),
                                             dtype=torch.float16, device=device)
                    scale_factor = 1000
                    done_mask = torch.zeros_like(result, dtype=torch.bool, device=device)
                    for i, u in enumerate(unique_values):
                        result_tmp[i] = F.interpolate((data[None] == u).float() * scale_factor, new_shape,
                                                      mode=torch_mode, antialias=False)[0]
                        mask = result_tmp[i] > (0.7 * scale_factor)
                        result[mask] = u.item()
                        done_mask |= mask
                    # Voxels without a confident (>0.7) label take the label with the highest response.
                    if not torch.all(done_mask):
                        result[~done_mask] = unique_values[result_tmp[:, ~done_mask].argmax(0)].to(result_dtype)
                else:
                    for u in unique_values:
                        if u == 0:
                            continue
                        resampled = F.interpolate((data[None] == u).float(), new_shape,
                                                  mode=torch_mode, antialias=False)[0]
                        result[resampled > 0.5] = u
            else:
                result = F.interpolate(data[None].float(), new_shape, mode=torch_mode, antialias=False)[0]

            if input_was_numpy:
                result = result.cpu().numpy()
            else:
                result = result.to(orig_device)
    finally:
        # Restore the caller's thread setting even if interpolation raised.
        torch.set_num_threads(n_threads)
    return result

def resample_torch_fornnunet(
    data: Union[torch.Tensor, np.ndarray],
    new_shape: Union[Tuple[int, ...], List[int], np.ndarray],
    current_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    new_spacing: Union[Tuple[float, ...], List[float], np.ndarray],
    is_seg: bool = False,
    num_threads: int = 4,
    device: torch.device = torch.device('cpu'),
    memefficient_seg_resampling: bool = False,
    force_separate_z: Union[bool, None] = None,
    separate_z_anisotropy_threshold: float = ANISO_THRESHOLD,
    mode: str = 'linear',
    aniso_axis_mode: str = 'nearest-exact'
) -> Union[torch.Tensor, np.ndarray]:
    """Resample a (c, x, y, z) volume, optionally treating the low-resolution axis separately.

    When determine_do_sep_z_and_axis requests it, the two other axes are resampled first
    with ``mode``, then the separate axis with ``aniso_axis_mode``. Otherwise a single
    resample_torch_simple call with ``mode`` is made.

    Args:
        data: 4D (c, x, y, z) numpy array or tensor.
        new_shape: Target spatial shape.
        current_spacing / new_spacing: Spacings used to decide on separate-axis resampling.
        is_seg, num_threads, device, memefficient_seg_resampling: Forwarded to resample_torch_simple.
        force_separate_z, separate_z_anisotropy_threshold: Forwarded to determine_do_sep_z_and_axis.
        mode: Interpolation mode for the main (or only) pass.
        aniso_axis_mode: Interpolation mode along the separately handled axis.

    Returns:
        The resampled volume, numpy if ``data`` was numpy.
    """
    assert data.ndim == 4, "data must be c, x, y, z"
    new_shape = [int(i) for i in new_shape]
    orig_shape = data.shape

    do_separate_z, axis = determine_do_sep_z_and_axis(force_separate_z, current_spacing, new_spacing,
                                                      separate_z_anisotropy_threshold)

    if not do_separate_z:
        # Forward `mode` so this path honours it just like the separate-z path does.
        return resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                     memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)

    was_numpy = isinstance(data, np.ndarray)
    if was_numpy:
        data = torch.from_numpy(data)

    if isinstance(axis, list):
        axis = axis[0]

    tmp = "xyz"
    axis_letter = tmp[axis]
    others_int = [i for i in range(3) if i != axis]
    others = [tmp[i] for i in others_int]

    # Fold the separate axis into the channel axis so the first pass only touches the other two.
    data = rearrange(data, f"c x y z -> (c {axis_letter}) {others[0]} {others[1]}")
    tmp_new_shape = [new_shape[i] for i in others_int]
    data = resample_torch_simple(data, tmp_new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                 memefficient_seg_resampling=memefficient_seg_resampling, mode=mode)
    # The separate axis still has its original length at this point.
    data = rearrange(data, f"(c {axis_letter}) {others[0]} {others[1]} -> c x y z",
                     **{axis_letter: orig_shape[axis + 1], others[0]: tmp_new_shape[0], others[1]: tmp_new_shape[1]})
    data = resample_torch_simple(data, new_shape, is_seg=is_seg, num_threads=num_threads, device=device,
                                 memefficient_seg_resampling=memefficient_seg_resampling, mode=aniso_axis_mode)
    if was_numpy:
        data = data.numpy()
    return data

def dice_score(pred: np.ndarray, true: np.ndarray, per_label: bool = False) -> float:
    """Compute the Dice coefficient between two segmentation masks.

    Non-zero voxels count as foreground, so for boolean or {0, 1} masks this is the
    usual 2|P ∩ T| / (|P| + |T|). The masks are compared as booleans rather than
    multiplied, which avoids integer overflow for small dtypes such as int8. It also
    stops label values from weighting the score.

    Args:
        pred: Predicted mask (any shape; flattened).
        true: Reference mask with the same number of elements as ``pred``.
        per_label: If True, return the mean Dice over every non-zero label present in
            either mask, instead of a single foreground Dice.

    Returns:
        Dice score as ``np.float64``. It is 0.0 when both masks are empty, because the
        1e-8 term keeps the division defined.

    Raises:
        ValueError: if ``pred`` and ``true`` have different sizes.
    """
    pred = np.asarray(pred).ravel()
    true = np.asarray(true).ravel()
    if pred.size != true.size:
        raise ValueError(f"pred and true must have the same size, got {pred.size} and {true.size}")

    def _binary_dice(p: np.ndarray, t: np.ndarray) -> np.float64:
        # Dice of two boolean masks, kept as np.float64 so callers dividing by it get inf/nan, not an exception.
        intersection = np.count_nonzero(p & t)
        return np.float64(2.0 * intersection) / (np.count_nonzero(p) + np.count_nonzero(t) + 1e-8)

    if not per_label:
        return _binary_dice(pred != 0, true != 0)

    labels = np.union1d(pred, true)
    labels = labels[labels != 0]
    if labels.size == 0:
        return np.float64(0.0)
    return np.float64(np.mean([_binary_dice(pred == lab, true == lab) for lab in labels]))

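# NOTE: this unconditionally redefines the compute_new_shape declared earlier in the module
# (and returns a tuple instead of a list); from here on, callers use this version.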
def compute_new_shape(original_shape, current_spacing, target_spacing):
    """
    Return the shape obtained by rescaling each axis by current/target spacing.

    original_shape: (z, y, x) axis lengths
    current_spacing: (z, y, x) spacing of the input
    target_spacing: (z, y, x) desired spacing

    Each scaled length is rounded with Python's ``round`` (ties to even) and the
    result is returned as a tuple of ints.
    """
    scale = [cur / tgt for cur, tgt in zip(current_spacing, target_spacing)]
    return tuple(int(round(extent * factor)) for extent, factor in zip(original_shape, scale))

# Function to save as NIfTI
def save_nii(array, spacing, output_path, is_seg=False):
    """
    Save a channel-first array as a NIfTI file with a diagonal, spacing-scaled affine.

    array: (C, Z, Y, X) or (Z, Y, X) numpy array or torch tensor; written as (X, Y, Z, C) / (X, Y, Z).
    spacing: voxel spacing in the array's (z, y, x) axis order.
    output_path: destination passed to nib.save.
    is_seg: If True, convert to int32 for segmentation masks; otherwise float32.

    Raises:
        ValueError: if ``array`` is neither 3D nor 4D.
    """
    # Convert torch tensor to numpy if necessary (detach so tensors requiring grad work too).
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    # Convert data type for NIfTI compatibility.
    array = array.astype(np.int32 if is_seg else np.float32)

    # Reverse the spatial axes for NIfTI's (X, Y, Z[, C]) layout.
    if array.ndim == 4:
        array = array.transpose(3, 2, 1, 0)  # From (C, Z, Y, X) to (X, Y, Z, C)
    elif array.ndim == 3:
        array = array.transpose(2, 1, 0)  # From (Z, Y, X) to (X, Y, Z)
    else:
        raise ValueError(f"save_nii expects a 3D or 4D array, got shape {array.shape}")

    # The affine's diagonal must follow the transposed (x, y, z) axis order.
    affine = np.diag([float(s) for s in reversed(list(spacing))] + [1.0])
    nii_img = nib.Nifti1Image(array, affine=affine)
    nib.save(nii_img, output_path)
    print(f"Saved: {output_path}")

# Benchmark entry point: compares the resampling implementations on one image/mask pair.
def _sync_device(device: torch.device) -> None:
    """Block until queued CUDA work has finished so wall-clock timings cover it."""
    if device.type == 'cuda':
        torch.cuda.synchronize()


def _memory_used_mb(device: torch.device) -> float:
    """Return system RAM in use (CPU) or memory allocated by PyTorch on ``device`` (CUDA), in MB."""
    if device.type == 'cpu':
        return psutil.virtual_memory().used / 1024**2
    return torch.cuda.memory_allocated(device) / 1024**2


def _timed_call(device: torch.device, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and return ``(result, elapsed seconds, memory delta in MB)``."""
    _sync_device(device)
    mem_before = _memory_used_mb(device)
    start_time = time.time()
    result = fn(*args, **kwargs)
    _sync_device(device)
    elapsed = time.time() - start_time
    return result, elapsed, _memory_used_mb(device) - mem_before


def _percent(delta: float, base: float) -> float:
    """Return ``delta / base * 100``, or NaN when ``base`` is zero (instead of inf or ZeroDivisionError)."""
    return delta / base * 100 if base else float('nan')


def main(
    npz_file_path: str = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/inputs/Microscopy_cremi_000_sc.npz",
    gt_path: str = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/gts/Microscopy_cremi_000_sc.npz",
    output_dir: str = "/media/shipc/hhd_8T/spc/code/CVPR2025_Text_guided_seg_submission/workspace_teamx/outputs_test_resample",
    device: Union[torch.device, None] = None,
) -> None:
    """Benchmark optimized_3d_resample and resample_torch_fornnunet against resample_torch_simple.

    For each target spacing, the image and its mask are resampled three ways. Each result
    is saved as NIfTI in ``output_dir``. The script prints timing, memory delta, image MAE
    and mask Dice, with MAE and Dice measured against resample_torch_simple's output.

    Args:
        npz_file_path: .npz file holding the image under key 'imgs'.
        gt_path: .npz file holding the segmentation under key 'gts'.
        output_dir: Directory for the NIfTI outputs (created if missing).
        device: Torch device; defaults to CUDA when available, otherwise CPU.
    """
    torch.set_num_threads(4)
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nRunning tests on device: {device}")

    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
    # arbitrary code; keep it only for trusted files. `with` closes the npz archives.
    with np.load(npz_file_path, allow_pickle=True) as data:
        img_array = data['imgs'].astype(np.float32)  # Shape: (C, Z, Y, X) or (Z, Y, X)
    with np.load(gt_path, allow_pickle=True) as gt_data:
        gt_array = gt_data['gts'].astype(np.int32)  # Shape: (C, Z, Y, X) or (Z, Y, X)

    # The stored spacing is deliberately replaced by unit spacing (z, y, x).
    current_spacing = [1.0, 1.0, 1.0]

    # Ensure img_array and gt_array have a channel dimension: (1, Z, Y, X)
    if img_array.ndim == 3:
        img_array = img_array[np.newaxis, ...]
    if gt_array.ndim == 3:
        gt_array = gt_array[np.newaxis, ...]

    target_spacings = [
        (1.2, 1.2, 1.2),
        (1.5, 1.5, 1.5),
        (2.0, 2.0, 2.0),
    ]

    original_shape = img_array.shape[1:]  # (Z, Y, X)
    print(f"\nOriginal image shape: {original_shape}, Current spacing (z,y,x): {current_spacing}")

    for target_spacing in target_spacings:
        print(f"\n=== Resampling to Target Spacing: {target_spacing} ===")

        new_shape = compute_new_shape(original_shape, current_spacing, target_spacing)
        print(f"Computed target shape: {new_shape}")
        tag = f"{target_spacing[0]}_{target_spacing[1]}_{target_spacing[2]}"

        # === Image Resampling ===
        print("\nResampling image...")

        # Ground truth resampling
        print("Computing ground truth with resample_torch_simple...")
        gt_img, gt_time, _ = _timed_call(
            device, resample_torch_simple, img_array,
            new_shape=new_shape, is_seg=False, num_threads=4, device=device,
        )
        print(f"Ground truth image shape: {gt_img.shape}, Time: {gt_time:.3f}s")
        save_nii(gt_img, target_spacing, os.path.join(output_dir, f"img_gt_spacing_{tag}.nii.gz"), is_seg=False)

        # Optimized resampling
        print("Running optimized_3d_resample...")
        resampled_img_opt, opt_time, opt_mem = _timed_call(
            device, optimized_3d_resample, img_array, current_spacing, target_spacing,
            is_seg=False, device=device, num_threads=4, chunk_size=64,
        )
        opt_mae = np.mean(np.abs(resampled_img_opt - gt_img))
        print(f"Optimized image shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, "
              f"Memory used: {opt_mem:.2f} MB, MAE: {opt_mae:.6f}")
        save_nii(resampled_img_opt, target_spacing, os.path.join(output_dir, f"img_opt_spacing_{tag}.nii.gz"),
                 is_seg=False)

        # Original resampling
        print("Running resample_torch_fornnunet...")
        resampled_img_orig, orig_time, orig_mem = _timed_call(
            device, resample_torch_fornnunet, img_array, new_shape, current_spacing, target_spacing,
            is_seg=False, num_threads=4, device=device,
        )
        orig_mae = np.mean(np.abs(resampled_img_orig - gt_img))
        print(f"Original image shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, "
              f"Memory used: {orig_mem:.2f} MB, MAE: {orig_mae:.6f}")
        save_nii(resampled_img_orig, target_spacing, os.path.join(output_dir, f"img_orig_spacing_{tag}.nii.gz"),
                 is_seg=False)

        # === Segmentation Mask Resampling ===
        print("\nResampling segmentation mask...")

        # Ground truth resampling
        print("Computing ground truth with resample_torch_simple...")
        gt_seg, gt_seg_time, _ = _timed_call(
            device, resample_torch_simple, gt_array,
            new_shape=new_shape, is_seg=True, num_threads=4, device=device,
        )
        print(f"Ground truth segmentation shape: {gt_seg.shape}, Time: {gt_seg_time:.3f}s")
        save_nii(gt_seg, target_spacing, os.path.join(output_dir, f"seg_gt_spacing_{tag}.nii.gz"), is_seg=True)

        # Optimized resampling
        print("Running optimized_3d_resample for segmentation...")
        resampled_seg_opt, opt_seg_time, opt_seg_mem = _timed_call(
            device, optimized_3d_resample, gt_array, current_spacing, target_spacing,
            is_seg=True, device=device, num_threads=4, chunk_size=64,
        )
        opt_dice = dice_score(resampled_seg_opt, gt_seg)
        print(f"Optimized segmentation shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, "
              f"Memory used: {opt_seg_mem:.2f} MB, Dice: {opt_dice:.6f}")
        save_nii(resampled_seg_opt, target_spacing, os.path.join(output_dir, f"seg_opt_spacing_{tag}.nii.gz"),
                 is_seg=True)

        # Original resampling
        print("Running resample_torch_fornnunet for segmentation...")
        resampled_seg_orig, orig_seg_time, orig_seg_mem = _timed_call(
            device, resample_torch_fornnunet, gt_array, new_shape, current_spacing, target_spacing,
            is_seg=True, num_threads=4, device=device,
        )
        orig_dice = dice_score(resampled_seg_orig, gt_seg)
        print(f"Original segmentation shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, "
              f"Memory used: {orig_seg_mem:.2f} MB, Dice: {orig_dice:.6f}")
        save_nii(resampled_seg_orig, target_spacing, os.path.join(output_dir, f"seg_orig_spacing_{tag}.nii.gz"),
                 is_seg=True)

        # Summary. A zero baseline (e.g. fornnunet matching the reference exactly) prints nan%.
        print(f"\n=== Summary for Target Spacing: {target_spacing} ===")
        print("Image Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_img_opt.shape}, Time: {opt_time:.3f}s, MAE: {opt_mae:.6f}")
        print(f"Original  - Shape: {resampled_img_orig.shape}, Time: {orig_time:.3f}s, MAE: {orig_mae:.6f}")
        print(f"Time Improvement: {_percent(orig_time - opt_time, orig_time):.2f}%")
        print(f"MAE Improvement: {_percent(orig_mae - opt_mae, orig_mae):.2f}%")
        print("Segmentation Mask Resampling Metrics:")
        print(f"Optimized - Shape: {resampled_seg_opt.shape}, Time: {opt_seg_time:.3f}s, Dice: {opt_dice:.6f}")
        print(f"Original  - Shape: {resampled_seg_orig.shape}, Time: {orig_seg_time:.3f}s, Dice: {orig_dice:.6f}")
        print(f"Time Improvement: {_percent(orig_seg_time - opt_seg_time, orig_seg_time):.2f}%")
        print(f"Dice Improvement: {_percent(opt_dice - orig_dice, orig_dice):.2f}%")

if __name__ == '__main__':
    # Run the resampling benchmark only when executed as a script, not on import.
    main()