maxmo2009 committed on
Commit
2af0e94
·
verified ·
1 Parent(s): 9f30236

Sync from local: code + epoch-110 checkpoint, clean README

Browse files

Replace existing repo with current local OmniMorph: full source tree (training/inference/registration scripts), Diffusion/OMorpher modules, dataloader mappings (16 datasets), and Models/all_om_net/000110_all_om_net.pth (final checkpoint, 3.0G). README rewritten to remove internal links/credentials. BERT external model and intermediate checkpoints not bundled.

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .gitignore +1 -0
  3. Config/config_om.yaml +15 -20
  4. Config/config_reg_brain.yaml +36 -0
  5. Config/config_reg_hip.yaml +48 -0
  6. Dataloader/dataLoader.py +172 -69
  7. Dataloader/dataloader_utils.py +3 -3
  8. Dataloader/deal_with_json.py +150 -0
  9. Dataloader/embding_gen.py +10 -2
  10. Dataloader/nifty_mappings/AbdomenAtlas_mappings.json +2 -2
  11. Dataloader/nifty_mappings/AbdomenCT1k_mappings.json +2 -2
  12. Dataloader/nifty_mappings/Brats2019_mappings.json +2 -2
  13. Dataloader/nifty_mappings/Brats2020_mappings.json +2 -2
  14. Dataloader/nifty_mappings/Brats2021_mappings.json +2 -2
  15. Dataloader/nifty_mappings/CIA_mappings.json +2 -2
  16. Dataloader/nifty_mappings/Kaggle_osic_mappings.json +0 -0
  17. Dataloader/nifty_mappings/MSD_mappings.json +2 -2
  18. Dataloader/nifty_mappings/MnMs_mappings.json +0 -0
  19. Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json +3 -0
  20. Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json +3 -0
  21. Dataloader/nifty_mappings/OASIS_1_mappings.json +2 -2
  22. Dataloader/nifty_mappings/OASIS_2_mappings.json +2 -2
  23. Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json +2 -2
  24. Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json +2 -2
  25. Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json +2 -2
  26. Diffusion/diffuser-reg.py +541 -0
  27. Diffusion/diffuser.py +45 -20
  28. Diffusion/diffuser_opt.py +357 -0
  29. Diffusion/losses.py +44 -7
  30. Diffusion/losses_opt.py +141 -0
  31. Diffusion/networks.py +328 -17
  32. Diffusion/networks0.py +1195 -0
  33. Diffusion/networks_opt.py +239 -0
  34. Diffusion/safe_conv_transpose.py +401 -0
  35. Models/all_om_net/000110_all_om_net.pth +3 -0
  36. OM_reg.py +10 -18
  37. OM_reg_flexres.py +382 -0
  38. OM_train_2modes-reg.py +517 -0
  39. OM_train_2modes.py +60 -69
  40. OM_train_3modes-XPU.py +957 -0
  41. OM_train_3modes.py +697 -198
  42. OM_train_3modes_cudaonly.py +512 -0
  43. OM_train_3modes_opt.py +513 -0
  44. OM_train_3modes_original.py +585 -0
  45. OMorpher/__init__.py +3 -0
  46. OMorpher/omorpher.py +1058 -0
  47. README.md +129 -80
  48. Scripts/OM_aug_om.py +239 -0
  49. Scripts/OM_reg_flexres_om.py +315 -0
  50. Scripts/OM_reg_pair_ext.py +676 -0
.gitattributes CHANGED
@@ -46,3 +46,5 @@ Dataloader/nifty_mappings/OASIS_2_mappings.json filter=lfs diff=lfs merge=lfs -t
46
  Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json filter=lfs diff=lfs merge=lfs -text
47
  Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json filter=lfs diff=lfs merge=lfs -text
48
  Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json filter=lfs diff=lfs merge=lfs -text
 
 
 
46
  Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json filter=lfs diff=lfs merge=lfs -text
47
  Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json filter=lfs diff=lfs merge=lfs -text
48
  Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json filter=lfs diff=lfs merge=lfs -text
49
+ Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json filter=lfs diff=lfs merge=lfs -text
50
+ Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -15,6 +15,7 @@ External/
15
 
16
  # Logs
17
  Log/
 
18
  swanlog/
19
  train_log.txt
20
  aug_log.txt
 
15
 
16
  # Logs
17
  Log/
18
+ Logs/
19
  swanlog/
20
  train_log.txt
21
  aug_log.txt
Config/config_om.yaml CHANGED
@@ -1,13 +1,13 @@
1
  data_name: all
2
- # net_name: recresacnet
3
- net_name: recmutattnnet
4
- # net_name: recmutattnnet1
5
  # net_name: defrecmutattnnet
6
  ndims: 3
7
  img_size: 128
8
- batchsize: 2
9
  ddf_pad_mode: border
10
- device: cuda
11
  img_pad_mode: zeros
12
  num_input_chn: 1
13
  padding_mode: border
@@ -19,23 +19,21 @@ v_scale: 5.0e-05
19
  epoch: 10000
20
  epoch_per_save: 1
21
  lr: 0.00001
22
- noise_scale: 0.1
23
  # =========================
24
  # AUGMENTATION SETTING
25
  patients_list: []
26
  # model_id_str: '000000'
27
  # model_id_str: '000180' # before registration training
28
- # model_id_str: '000353' # good augmentation results on msd
29
- model_id_str: '000354' #
30
  # model_id_str: '000157'
31
  # model_id_str: '000171'
32
- start_noise_step: 48 # starting from which noise step to add noise
 
33
  noise_step: 1
34
- aug_coe: 64 # how many times each sample will be augmented
35
- # start_noise_step: 56 # starting from which noise step to add noise
36
- # noise_step: 4
37
- # aug_coe: 4 # how many times each sample will be augmented
38
- condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
39
  # aug_img_savepath: Data/Aug_data/totseg/img/
40
  # aug_msk_savepath: Data/Aug_data/totseg/msk/
41
  # aug_ddf_savepath: Data/Aug_data/totseg/ddf/
@@ -45,9 +43,6 @@ condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample
45
  reg_img_savepath: Data/Reg_data/om/img/
46
  reg_msk_savepath: Data/Reg_data/om/msk/
47
  reg_ddf_savepath: Data/Reg_data/om/ddf/
48
- # aug_img_savepath: Data/Aug_data/msd/img/
49
- # aug_msk_savepath: Data/Aug_data/msd/msk/
50
- # aug_ddf_savepath: Data/Aug_data/msd/ddf/
51
- aug_img_savepath: Data/Aug_data/mnms/img/
52
- aug_msk_savepath: Data/Aug_data/mnms/msk/
53
- aug_ddf_savepath: Data/Aug_data/mnms/ddf/
 
1
  data_name: all
2
+ net_name: om_net
3
+ # net_name: recmutattnnet
4
+ # net_name: recmulmodmutattnnet
5
  # net_name: defrecmutattnnet
6
  ndims: 3
7
  img_size: 128
8
+ batchsize: 3
9
  ddf_pad_mode: border
10
+ device: xpu
11
  img_pad_mode: zeros
12
  num_input_chn: 1
13
  padding_mode: border
 
19
  epoch: 10000
20
  epoch_per_save: 1
21
  lr: 0.00001
22
+ noise_scale: 0.05
23
  # =========================
24
  # AUGMENTATION SETTING
25
  patients_list: []
26
  # model_id_str: '000000'
27
  # model_id_str: '000180' # before registration training
28
+ # model_id_str: '000356'
 
29
  # model_id_str: '000157'
30
  # model_id_str: '000171'
31
+ model_id_str: '000009'
32
+ start_noise_step: 75
33
  noise_step: 1
34
+ # aug_coe: 32 # how many times each sample will be augmented
35
+ aug_coe: 1 # how many times each sample will be augmented
36
+ condition_type: 'slice' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
 
 
37
  # aug_img_savepath: Data/Aug_data/totseg/img/
38
  # aug_msk_savepath: Data/Aug_data/totseg/msk/
39
  # aug_ddf_savepath: Data/Aug_data/totseg/ddf/
 
43
  reg_img_savepath: Data/Reg_data/om/img/
44
  reg_msk_savepath: Data/Reg_data/om/msk/
45
  reg_ddf_savepath: Data/Reg_data/om/ddf/
46
+ aug_img_savepath: Data/Aug_data/msd/img/
47
+ aug_msk_savepath: Data/Aug_data/msd/msk/
48
+ aug_ddf_savepath: Data/Aug_data/msd/ddf/
 
 
 
Config/config_reg_brain.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_name: all
2
+ # net_name: recresacnet
3
+ # net_name: recmutattnnet
4
+ net_name: recmulmodmutattnnet
5
+ # net_name: defrecmutattnnet
6
+ ndims: 3
7
+ img_size: 128
8
+ batchsize: 3
9
+ ddf_pad_mode: border
10
+ device: xpu
11
+ img_pad_mode: zeros
12
+ num_input_chn: 1
13
+ padding_mode: border
14
+ resample_mode: bilinear
15
+ timesteps: 80
16
+ v_scale: 5.0e-05
17
+ # =========================
18
+ # TRAINING SETTING
19
+ epoch: 10000
20
+ epoch_per_save: 1
21
+ lr: 0.00001
22
+ noise_scale: 0.1
23
+ # =========================
24
+ # AUGMENTATION SETTING
25
+ patients_list: []
26
+ model_id_str: '000009'
27
+ start_noise_step: 75
28
+ noise_step: 1
29
+ aug_coe: 1
30
+ condition_type: 'none'
31
+ reg_img_savepath: Data/Reg_data/unpair_brain/img/
32
+ reg_msk_savepath: Data/Reg_data/unpair_brain/msk/
33
+ reg_ddf_savepath: Data/Reg_data/unpair_brain/ddf/
34
+ aug_img_savepath: Data/Aug_data/unpair_brain/img/
35
+ aug_msk_savepath: Data/Aug_data/unpair_brain/msk/
36
+ aug_ddf_savepath: Data/Aug_data/unpair_brain/ddf/
Config/config_reg_hip.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_name: all
2
+ # net_name: recresacnet
3
+ # net_name: recmutattnnet
4
+ net_name: recmulmodmutattnnet
5
+ # net_name: defrecmutattnnet
6
+ ndims: 3
7
+ img_size: 128
8
+ batchsize: 3
9
+ ddf_pad_mode: border
10
+ device: xpu
11
+ img_pad_mode: zeros
12
+ num_input_chn: 1
13
+ padding_mode: border
14
+ resample_mode: bilinear
15
+ timesteps: 80
16
+ v_scale: 5.0e-05
17
+ # =========================
18
+ # TRAINING SETTING
19
+ epoch: 10000
20
+ epoch_per_save: 1
21
+ lr: 0.00001
22
+ noise_scale: 0.1
23
+ # =========================
24
+ # AUGMENTATION SETTING
25
+ patients_list: []
26
+ # model_id_str: '000000'
27
+ # model_id_str: '000180' # before registration training
28
+ # model_id_str: '000356'
29
+ # model_id_str: '000157'
30
+ # model_id_str: '000171'
31
+ model_id_str: '000009'
32
+ start_noise_step: 75
33
+ noise_step: 1
34
+ # aug_coe: 32 # how many times each sample will be augmented
35
+ aug_coe: 1 # how many times each sample will be augmented
36
+ condition_type: 'none' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
37
+ # aug_img_savepath: Data/Aug_data/totseg/img/
38
+ # aug_msk_savepath: Data/Aug_data/totseg/msk/
39
+ # aug_ddf_savepath: Data/Aug_data/totseg/ddf/
40
+ # aug_img_savepath: Data/Aug_data/om/img/
41
+ # aug_msk_savepath: Data/Aug_data/om/msk/
42
+ # aug_ddf_savepath: Data/Aug_data/om/ddf/
43
+ reg_img_savepath: Data/Reg_data/pair_hip/img/
44
+ reg_msk_savepath: Data/Reg_data/pair_hip/msk/
45
+ reg_ddf_savepath: Data/Reg_data/pair_hip/ddf/
46
+ aug_img_savepath: Data/Aug_data/pair_hip/img/
47
+ aug_msk_savepath: Data/Aug_data/pair_hip/msk/
48
+ aug_ddf_savepath: Data/Aug_data/pair_hip/ddf/
Dataloader/dataLoader.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import torch
2
  from torch.utils.data import Dataset, DataLoader
3
  import json
@@ -5,8 +8,8 @@ import SimpleITK as sitk
5
  import numpy as np
6
  from skimage.transform import rescale, resize, downscale_local_mean
7
  # from torchvision.transforms import v2
8
- import sys
9
- sys.path.append('./')
10
  from Dataloader.dataloader_utils import *
11
  import random
12
 
@@ -18,22 +21,42 @@ import random
18
  # }
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  mapping_files = {
22
- 'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json',
23
- 'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
24
- 'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json',
25
- 'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
26
- 'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json',
27
- # 'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json',
28
- 'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json',
29
- 'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json',
30
- 'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json',
31
- 'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json',
32
- 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
33
- 'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
34
- 'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
35
- 'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
 
 
36
  }
 
 
37
 
38
  CLAMP_RANGE = [-400, 400] # default clamp range for the images
39
 
@@ -74,50 +97,9 @@ def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high')
74
  sample_value = np.random.uniform(low, high=sample_value)
75
  return sample_value
76
 
77
- class DummyOMDataset_indiv(Dataset):
78
- """Dummy dataset that generates random 3D volumes and embeddings for XPU testing."""
79
- def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
80
- self.out_sz = out_sz
81
- self.num_samples = num_samples
82
- self.embd_dim = embd_dim
83
- self.transform = transform
84
-
85
- def __len__(self):
86
- return self.num_samples
87
-
88
- def __getitem__(self, idx):
89
- volume = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
90
- embd = np.random.randn(self.embd_dim).astype(np.float32)
91
- if self.transform is not None:
92
- volume = self.transform(volume)
93
- return volume, embd
94
-
95
-
96
- class DummyOMDataset_pair(Dataset):
97
- """Dummy dataset that generates random paired 3D volumes and embeddings for XPU testing."""
98
- def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
99
- self.out_sz = out_sz
100
- self.num_samples = num_samples
101
- self.embd_dim = embd_dim
102
- self.transform = transform
103
-
104
- def __len__(self):
105
- return self.num_samples
106
-
107
- def __getitem__(self, idx):
108
- volume_A = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
109
- volume_B = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
110
- embd_A = np.random.randn(self.embd_dim).astype(np.float32)
111
- embd_B = np.random.randn(self.embd_dim).astype(np.float32)
112
- if self.transform is not None:
113
- volume_A = self.transform(volume_A)
114
- volume_B = self.transform(volume_B)
115
- return [volume_A, volume_B, embd_A, embd_B]
116
-
117
-
118
  class OminiDataset(object):
119
  """Base class for OmniMorph datasets."""
120
- def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
121
 
122
  # self.mappings = mapping_files
123
  self.ALLdata = self.combine_data(mappings = mapping_files)
@@ -155,10 +137,27 @@ class OminiDataset(object):
155
 
156
  def combine_data(self, mappings = mapping_files):
157
  ALLdata = {}
 
 
158
  for j in mappings.keys():
159
  with open(mappings[j], 'r') as f:
160
  mappings_tmp = json.load(f)
161
- ALLdata.update(mappings_tmp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  return ALLdata
163
 
164
  def get_3D_volume(self, volume, select_channel = None):
@@ -301,10 +300,27 @@ class OminiDataset_v1(Dataset):
301
 
302
  def combine_data(self):
303
  ALLdata = {}
 
 
304
  for j in self.mappings.keys():
305
  with open(self.mappings[j], 'r') as f:
306
  mappings = json.load(f)
307
- ALLdata.update(mappings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  return ALLdata
309
 
310
  def __len__(self):
@@ -442,10 +458,27 @@ class OMDataset_indiv(Dataset):
442
 
443
  def combine_data(self, mappings = mapping_files):
444
  ALLdata = {}
 
 
445
  for j in mappings.keys():
446
  with open(mappings[j], 'r') as f:
447
  mappings_tmp = json.load(f)
448
- ALLdata.update(mappings_tmp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
  return ALLdata
450
 
451
  def __len__(self):
@@ -496,7 +529,7 @@ class OMDataset_indiv(Dataset):
496
  return [volume, embd]
497
 
498
  class OminiDataset_paired(Dataset):
499
- def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.9, ROIs = None, modality = None, reverse_axis_order = False):
500
  # self.mappings = mapping_files
501
  self.ALLdata = self.combine_data(mappings=mapping_files)
502
  self.out_sz = out_sz
@@ -525,10 +558,27 @@ class OminiDataset_paired(Dataset):
525
 
526
  def combine_data(self, mappings = mapping_files):
527
  ALLdata = {}
 
 
528
  for j in mappings.keys():
529
  with open(mappings[j], 'r') as f:
530
  mappings_tmp = json.load(f)
531
- ALLdata.update(mappings_tmp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  return ALLdata
533
 
534
  def normalize(self, volume, eps=1e-7):
@@ -747,10 +797,27 @@ class OMDataset_pair(Dataset):
747
 
748
  def combine_data(self, mappings = mapping_files):
749
  ALLdata = {}
 
 
750
  for j in mappings.keys():
751
  with open(mappings[j], 'r') as f:
752
  mappings_tmp = json.load(f)
753
- ALLdata.update(mappings_tmp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
754
  return ALLdata
755
 
756
  def normalize(self, volume, eps=1e-7):
@@ -911,8 +978,8 @@ class OMDataset_pair(Dataset):
911
 
912
  paired_key = random.choice(paired_keys)
913
 
914
- print(f"Key: {key}, Paired Key: {paired_key}")
915
- print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")
916
 
917
 
918
  volume_B = sitk.ReadImage(paired_key)
@@ -1004,10 +1071,27 @@ class OminiDataset_paired_inf(object):
1004
 
1005
  def combine_data(self, mappings = mapping_files):
1006
  ALLdata = {}
 
 
1007
  for j in mappings.keys():
1008
  with open(mappings[j], 'r') as f:
1009
  mappings_tmp = json.load(f)
1010
- ALLdata.update(mappings_tmp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1011
  return ALLdata
1012
 
1013
  def __len__(self):
@@ -1244,10 +1328,27 @@ class OminiDataset_inference_w_all(object):
1244
 
1245
  def combine_data(self, mappings = mapping_files):
1246
  ALLdata = {}
 
 
1247
  for j in mappings.keys():
1248
  with open(mappings[j], 'r') as f:
1249
  mappings_tmp = json.load(f)
1250
- ALLdata.update(mappings_tmp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1251
  return ALLdata
1252
 
1253
  def normalize(self, volume, eps=1e-7):
@@ -1414,6 +1515,7 @@ class OminiDataset_inference_w_all(object):
1414
  # print(f"Label with channels, pad_width_lab: {pad_width_lab}")
1415
  else:
1416
  pad_width_lab = pad_width
 
1417
  label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
1418
  # print(f"After pad and crop, label shape: {label.shape}, key: {key}, label key: {lk}")
1419
  label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
@@ -1442,6 +1544,7 @@ class OminiDataset_inference_w_all(object):
1442
  return return_dict
1443
 
1444
 
 
1445
  class OminiDataset_bertembd(OminiDataset):
1446
  def __init__(self,
1447
  out_sz = 128,
@@ -1453,7 +1556,7 @@ class OminiDataset_bertembd(OminiDataset):
1453
  reverse_axis_order = False,
1454
  min_dim = 3,
1455
  mapping_files = mapping_files):
1456
- super().init(out_sz = out_sz,
1457
  transform = transform,
1458
  clamp_range = clamp_range,
1459
  min_crop_ratio = min_crop_ratio,
 
1
+ import os, sys
2
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
3
+
4
  import torch
5
  from torch.utils.data import Dataset, DataLoader
6
  import json
 
8
  import numpy as np
9
  from skimage.transform import rescale, resize, downscale_local_mean
10
  # from torchvision.transforms import v2
11
+ # sys.path.append('./')
12
+ sys.path.append(ROOT_DIR)
13
  from Dataloader.dataloader_utils import *
14
  import random
15
 
 
21
  # }
22
 
23
 
24
+ # mapping_files = {
25
+ # 'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json',
26
+ # 'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
27
+ # 'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json',
28
+ # 'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
29
+ # 'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json',
30
+ # # 'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json',
31
+ # 'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json',
32
+ # 'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json',
33
+ # 'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json',
34
+ # 'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json',
35
+ # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
36
+ # 'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
37
+ # 'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
38
+ # 'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
39
+ # }
40
  mapping_files = {
41
+ 'MSD': 'nifty_mappings/MSD_mappings.json',
42
+ 'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
43
+ 'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json',
44
+ 'CancerImageArchive': 'nifty_mappings/CIA_mappings.json',
45
+ 'MnMs': 'nifty_mappings/MnMs_mappings.json',
46
+ # 'Brats2019': 'nifty_mappings/Brats2019_mappings.json', # should be commented out after testing
47
+ 'Brats2020': 'nifty_mappings/Brats2020_mappings.json',
48
+ 'Brats2021': 'nifty_mappings/Brats2021_mappings.json',
49
+ 'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json',
50
+ 'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json',
51
+ 'PSMA-FDG-PET-CT-LESION':'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
52
+ 'PSMA-CT':'nifty_mappings/PSMA-CT-Longitud_mappings.json',
53
+ 'AbdomenAtlas':'nifty_mappings/AbdomenAtlas_mappings.json',
54
+ 'AbdomenCT1k':'nifty_mappings/AbdomenCT1k_mappings.json',
55
+ 'OAI_ZIB': 'nifty_mappings/OAI_ZIB_KL_mappings.json',
56
+ # 'OAI_ZIB': 'nifty_mappings/OAI_ZIB_WOMAC_mappings.json', # alternative: WOMAC scores instead of KL-grade
57
  }
58
+ for k,v in mapping_files.items():
59
+ mapping_files[k] = os.path.join(ROOT_DIR, v)
60
 
61
  CLAMP_RANGE = [-400, 400] # default clamp range for the images
62
 
 
97
  sample_value = np.random.uniform(low, high=sample_value)
98
  return sample_value
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  class OminiDataset(object):
101
  """Base class for OmniMorph datasets."""
102
+ def __init__(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
103
 
104
  # self.mappings = mapping_files
105
  self.ALLdata = self.combine_data(mappings = mapping_files)
 
137
 
138
  def combine_data(self, mappings = mapping_files):
139
  ALLdata = {}
140
+ total_entries = 0
141
+ total_skipped = 0
142
  for j in mappings.keys():
143
  with open(mappings[j], 'r') as f:
144
  mappings_tmp = json.load(f)
145
+ skipped = 0
146
+ for k, v in mappings_tmp.items():
147
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
148
+ skipped += 1
149
+ continue
150
+ ALLdata[k] = v
151
+ accessible = len(mappings_tmp) - skipped
152
+ total_entries += len(mappings_tmp)
153
+ total_skipped += skipped
154
+ if skipped > 0:
155
+ print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
156
+ if total_skipped > 0:
157
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
158
+ if len(ALLdata) < 1000:
159
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
160
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
161
  return ALLdata
162
 
163
  def get_3D_volume(self, volume, select_channel = None):
 
300
 
301
  def combine_data(self):
302
  ALLdata = {}
303
+ total_entries = 0
304
+ total_skipped = 0
305
  for j in self.mappings.keys():
306
  with open(self.mappings[j], 'r') as f:
307
  mappings = json.load(f)
308
+ skipped = 0
309
+ for k, v in mappings.items():
310
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
311
+ skipped += 1
312
+ continue
313
+ ALLdata[k] = v
314
+ accessible = len(mappings) - skipped
315
+ total_entries += len(mappings)
316
+ total_skipped += skipped
317
+ if skipped > 0:
318
+ print(f" WARNING: {j}: {accessible}/{len(mappings)} accessible ({skipped} missing/empty)")
319
+ if total_skipped > 0:
320
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
321
+ if len(ALLdata) < 1000:
322
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
323
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
324
  return ALLdata
325
 
326
  def __len__(self):
 
458
 
459
  def combine_data(self, mappings = mapping_files):
460
  ALLdata = {}
461
+ total_entries = 0
462
+ total_skipped = 0
463
  for j in mappings.keys():
464
  with open(mappings[j], 'r') as f:
465
  mappings_tmp = json.load(f)
466
+ skipped = 0
467
+ for k, v in mappings_tmp.items():
468
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
469
+ skipped += 1
470
+ continue
471
+ ALLdata[k] = v
472
+ accessible = len(mappings_tmp) - skipped
473
+ total_entries += len(mappings_tmp)
474
+ total_skipped += skipped
475
+ if skipped > 0:
476
+ print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
477
+ if total_skipped > 0:
478
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
479
+ if len(ALLdata) < 1000:
480
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
481
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
482
  return ALLdata
483
 
484
  def __len__(self):
 
529
  return [volume, embd]
530
 
531
  class OminiDataset_paired(Dataset):
532
+ def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.85, ROIs = None, modality = None, reverse_axis_order = False):
533
  # self.mappings = mapping_files
534
  self.ALLdata = self.combine_data(mappings=mapping_files)
535
  self.out_sz = out_sz
 
558
 
559
  def combine_data(self, mappings = mapping_files):
560
  ALLdata = {}
561
+ total_entries = 0
562
+ total_skipped = 0
563
  for j in mappings.keys():
564
  with open(mappings[j], 'r') as f:
565
  mappings_tmp = json.load(f)
566
+ skipped = 0
567
+ for k, v in mappings_tmp.items():
568
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
569
+ skipped += 1
570
+ continue
571
+ ALLdata[k] = v
572
+ accessible = len(mappings_tmp) - skipped
573
+ total_entries += len(mappings_tmp)
574
+ total_skipped += skipped
575
+ if skipped > 0:
576
+ print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
577
+ if total_skipped > 0:
578
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
579
+ if len(ALLdata) < 1000:
580
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
581
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
582
  return ALLdata
583
 
584
  def normalize(self, volume, eps=1e-7):
 
797
 
798
  def combine_data(self, mappings = mapping_files):
799
  ALLdata = {}
800
+ total_entries = 0
801
+ total_skipped = 0
802
  for j in mappings.keys():
803
  with open(mappings[j], 'r') as f:
804
  mappings_tmp = json.load(f)
805
+ skipped = 0
806
+ for k, v in mappings_tmp.items():
807
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
808
+ skipped += 1
809
+ continue
810
+ ALLdata[k] = v
811
+ accessible = len(mappings_tmp) - skipped
812
+ total_entries += len(mappings_tmp)
813
+ total_skipped += skipped
814
+ if skipped > 0:
815
+ print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
816
+ if total_skipped > 0:
817
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
818
+ if len(ALLdata) < 1000:
819
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
820
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
821
  return ALLdata
822
 
823
  def normalize(self, volume, eps=1e-7):
 
978
 
979
  paired_key = random.choice(paired_keys)
980
 
981
+ # print(f"Key: {key}, Paired Key: {paired_key}")
982
+ # print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")
983
 
984
 
985
  volume_B = sitk.ReadImage(paired_key)
 
1071
 
1072
  def combine_data(self, mappings = mapping_files):
1073
  ALLdata = {}
1074
+ total_entries = 0
1075
+ total_skipped = 0
1076
  for j in mappings.keys():
1077
  with open(mappings[j], 'r') as f:
1078
  mappings_tmp = json.load(f)
1079
+ skipped = 0
1080
+ for k, v in mappings_tmp.items():
1081
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
1082
+ skipped += 1
1083
+ continue
1084
+ ALLdata[k] = v
1085
+ accessible = len(mappings_tmp) - skipped
1086
+ total_entries += len(mappings_tmp)
1087
+ total_skipped += skipped
1088
+ if skipped > 0:
1089
+ print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
1090
+ if total_skipped > 0:
1091
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
1092
+ if len(ALLdata) < 1000:
1093
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
1094
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
1095
  return ALLdata
1096
 
1097
  def __len__(self):
 
1328
 
1329
  def combine_data(self, mappings = mapping_files):
1330
  ALLdata = {}
1331
+ total_entries = 0
1332
+ total_skipped = 0
1333
  for j in mappings.keys():
1334
  with open(mappings[j], 'r') as f:
1335
  mappings_tmp = json.load(f)
1336
+ skipped = 0
1337
+ for k, v in mappings_tmp.items():
1338
+ if not os.path.exists(k) or os.path.getsize(k) == 0:
1339
+ skipped += 1
1340
+ continue
1341
+ ALLdata[k] = v
1342
+ accessible = len(mappings_tmp) - skipped
1343
+ total_entries += len(mappings_tmp)
1344
+ total_skipped += skipped
1345
+ if skipped > 0:
1346
+ print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
1347
+ if total_skipped > 0:
1348
+ print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
1349
+ if len(ALLdata) < 1000:
1350
+ print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
1351
+ f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
1352
  return ALLdata
1353
 
1354
  def normalize(self, volume, eps=1e-7):
 
1515
  # print(f"Label with channels, pad_width_lab: {pad_width_lab}")
1516
  else:
1517
  pad_width_lab = pad_width
1518
+
1519
  label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
1520
  # print(f"After pad and crop, label shape: {label.shape}, key: {key}, label key: {lk}")
1521
  label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
 
1544
  return return_dict
1545
 
1546
 
1547
+
1548
  class OminiDataset_bertembd(OminiDataset):
1549
  def __init__(self,
1550
  out_sz = 128,
 
1556
  reverse_axis_order = False,
1557
  min_dim = 3,
1558
  mapping_files = mapping_files):
1559
+ super().__init__(out_sz = out_sz,
1560
  transform = transform,
1561
  clamp_range = clamp_range,
1562
  min_crop_ratio = min_crop_ratio,
Dataloader/dataloader_utils.py CHANGED
@@ -48,9 +48,9 @@ def get_sizeRange_dict(roi=''):
48
  'abdomen': [240, 1024],
49
  'pelvis': [220, 1024],
50
  'thorax': [220, 1024],
51
- 'arm': [140, 1024],
52
- 'hand': [140, 1024],
53
- 'leg': [160, 1024],
54
  'skeleton': [130, 1024],
55
  }
56
  if roi in sizeRange_dict:
 
48
  'abdomen': [240, 1024],
49
  'pelvis': [220, 1024],
50
  'thorax': [220, 1024],
51
+ 'arm': [100, 1024],
52
+ 'hand': [100, 1024],
53
+ 'leg': [100, 1024],
54
  'skeleton': [130, 1024],
55
  }
56
  if roi in sizeRange_dict:
Dataloader/deal_with_json.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
3
+ sys.path.append(ROOT_DIR)
4
+ import json
5
+
6
+ # CORRECT_DATA_PATH = os.path.join(ROOT_DIR, '../..')
7
+ # CORRECT_DATA_PATH = os.path.join('/hy-tmp')
8
+ CORRECT_DATA_PATH = '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D'
9
+
10
+
11
def traverse_and_print(data, path=()):
    """Print every string key or string value that contains 'DATASETS'.

    Nested dicts are walked recursively; ``path`` carries the tuple of keys
    leading to the current dict.
    """
    marker = 'DATASETS'
    for key, value in data.items():
        nested_path = path + (key,)

        if isinstance(key, str) and marker in key:
            print(f"KEY (str): {key}")

        if isinstance(value, dict):
            traverse_and_print(value, nested_path)
            continue
        if isinstance(value, str) and marker in value:
            print(f" VALUE (str): {value}")
22
+
23
def traverse_and_check(data, path=()):
    """Check that every 'DATASETS_processed' path in a mapping exists on disk.

    String keys and string values containing 'DATASETS_processed' are tested
    with ``os.path.isfile``.  Nested dicts are walked recursively and their
    failures are included in the verdict (previously the result of the
    recursive call was discarded, so nested failures were never reported to
    the caller).

    Args:
        data: (possibly nested) mapping dict.
        path: unused; kept so existing callers keep working.

    Returns:
        True when no checked path is missing, otherwise False.
    """
    failed_files = _collect_missing_files(data)
    if failed_files:
        print(f'\nCheck finished. Failed files: {failed_files}')
        return False
    print('\nAll files check passed!')
    return True


def _collect_missing_files(data):
    """Return the list of 'DATASETS_processed' paths in ``data`` that are not files."""
    missing = []
    for key, value in data.items():
        for candidate in (key, value):
            if not (isinstance(candidate, str) and 'DATASETS_processed' in candidate):
                continue
            if os.path.isfile(candidate):
                # '\r' keeps the progress output on a single console line.
                print(f'\rCheck pass: {candidate}', end='')
            else:
                print(f'\rCheck fail ! : {candidate}')
                missing.append(candidate)
        if isinstance(value, dict):
            missing.extend(_collect_missing_files(value))
    return missing
50
+
51
def traverse_and_revise(data, path=()):
    """Rewrite stale data-root prefixes to CORRECT_DATA_PATH, in place.

    Both string keys and string values that mention 'data_gen_def' are
    rewritten; nested dicts are processed recursively.  Every key is popped
    and re-inserted (renamed or not).

    Returns:
        The same ``data`` dict, after modification.
    """
    stale_roots = [
        '/home/jachin/data/Github/data/data_gen_def',
        '/home/data/Github/data/data_gen_def',
    ]

    def relocate(text):
        # Mirrors the original loop: the last matching root decides the result.
        relocated = text
        for root in stale_roots:
            if root in text:
                relocated = text.replace(root, CORRECT_DATA_PATH)
        return relocated

    for key, _ in list(data.items()):
        needs_rename = isinstance(key, str) and 'data_gen_def' in key
        new_key = relocate(key) if needs_rename else key

        # Move the entry under its (possibly new) key.
        data[new_key] = data.pop(key)
        value = data[new_key]
        current_path = path + (new_key,)

        if isinstance(value, dict):
            traverse_and_revise(value, current_path)
        elif isinstance(value, str) and 'data_gen_def' in value:
            data[new_key] = relocate(value)

    return data
79
+
80
def traverse_and_rename_label(data, old_label, new_label, task_keys=("segmentation", "registration")):
    """Rename a label key under Label_path -> <task> for every entry, in place.

    Entries that are dicts but carry no dict-valued "Label_path" are treated
    as containers and searched recursively.

    Example: rename "brain" -> "brain_tumour" to fix the BraTS mislabel.

    Returns:
        Number of task dicts in which ``old_label`` was renamed.
    """
    renamed = 0
    for entry in data.values():
        if not isinstance(entry, dict):
            continue
        labels = entry.get("Label_path")
        if not isinstance(labels, dict):
            # No label table at this level: descend into the nested dict.
            renamed += traverse_and_rename_label(entry, old_label, new_label, task_keys)
            continue
        for task in task_keys:
            per_task = labels.get(task)
            if isinstance(per_task, dict) and old_label in per_task:
                per_task[new_label] = per_task.pop(old_label)
                renamed += 1
    return renamed
100
+
101
+
102
# Dataset name -> mapping JSON, first given relative to this module's
# directory and then resolved against ROOT_DIR by the loop below.
mapping_files = dict([
    ('MSD', 'nifty_mappings/MSD_mappings.json'),
    ('TotalSegmentor', 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json'),
    ('Kaggle_osic', 'nifty_mappings/Kaggle_osic_mappings.json'),
    ('CancerImageArchive', 'nifty_mappings/CIA_mappings.json'),
    ('MnMs', 'nifty_mappings/MnMs_mappings.json'),
    ('Brats2019', 'nifty_mappings/Brats2019_mappings.json'),
    ('Brats2020', 'nifty_mappings/Brats2020_mappings.json'),
    ('Brats2021', 'nifty_mappings/Brats2021_mappings.json'),
    ('OASIS_1', 'nifty_mappings/OASIS_1_mappings.json'),
    ('OASIS_2', 'nifty_mappings/OASIS_2_mappings.json'),
    ('PSMA-FDG-PET-CT-LESION', 'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json'),
    ('PSMA-CT', 'nifty_mappings/PSMA-CT-Longitud_mappings.json'),
    ('AbdomenAtlas', 'nifty_mappings/AbdomenAtlas_mappings.json'),
    ('AbdomenCT1k', 'nifty_mappings/AbdomenCT1k_mappings.json'),
])
for k, v in list(mapping_files.items()):
    mapping_files[k] = os.path.join(ROOT_DIR, v)
120
+
121
+
122
if __name__ == "__main__":
    # --- Fix BraTS / MSD mislabel: "brain" -> "brain_tumour" ---
    for ds_name in ('Brats2019', 'Brats2020', 'Brats2021', 'MSD'):
        if ds_name in mapping_files:
            json_path = mapping_files[ds_name]
            with open(json_path, 'r') as src:
                mapping = json.load(src)
            renamed = traverse_and_rename_label(mapping, 'brain', 'brain_tumour')
            if renamed <= 0:
                print(f'[{ds_name}] No "brain" labels found (already renamed?)')
            else:
                # Only rewrite the file when something actually changed.
                with open(json_path, 'w') as dst:
                    json.dump(mapping, dst, indent=4)
                print(f'[{ds_name}] Renamed "brain" -> "brain_tumour" in {renamed} entries, saved to {json_path}')

    # --- Path revision (uncomment to run) ---
    # for k,v in mapping_files.items():
    #     with open(v, 'r') as f:
    #         mappings_tmp = json.load(f)
    #     new_mappings_tmp = traverse_and_revise(mappings_tmp)
    #     # traverse_and_print(new_mappings_tmp)
    #     # all_good = traverse_and_check(new_mappings_tmp)
    #     # save in-place
    #     with open(v, 'w') as f:
    #         json.dump(new_mappings_tmp, f, indent=4)
    #     print(f'Saved revised mapping to {v}')
150
+
Dataloader/embding_gen.py CHANGED
@@ -23,7 +23,9 @@ mapping_files = {
23
  # 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
24
  # 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
25
  # 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
26
- 'OASIS_2': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
 
 
27
  # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
28
  # 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
29
  # 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
@@ -45,6 +47,8 @@ save_paths = {
45
  'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
46
  'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
47
  'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
 
 
48
  }
49
  query = {
50
  'MSD': ['description'],
@@ -61,6 +65,8 @@ query = {
61
  'PSMA-CT':[],
62
  'AbdomenAtlas':[],
63
  'AbdomenCT1k':[],
 
 
64
  }
65
  add_text = {
66
  'MSD': {},
@@ -77,11 +83,13 @@ add_text = {
77
  'PSMA-FDG-PET-CT-LESION':{'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
78
  'AbdomenAtlas':{},
79
  'AbdomenCT1k':{},
 
 
80
  }
81
 
82
 
83
  # BERT initialization
84
- model_name = '/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased'
85
  reduce_method = 'mean'
86
  max_words_num = 32 # max number of words in the caption > 2
87
  # max_words_num = 64 # max number of words in the caption > 2
 
23
  # 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
24
  # 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
25
  # 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
26
+ # 'OASIS_2': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
27
+ 'OAI_ZIB_KL': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D/DATASETS_processed/OAI_ZIB/nifti_mappings.json',
28
+ 'OAI_ZIB_WOMAC': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D/DATASETS_processed/OAI_ZIB/nifti_mappings.json',
29
  # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
30
  # 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
31
  # 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
 
47
  'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
48
  'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
49
  'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
50
+ 'OAI_ZIB_KL': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Code/OmniMorph/Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json',
51
+ 'OAI_ZIB_WOMAC': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Code/OmniMorph/Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json',
52
  }
53
  query = {
54
  'MSD': ['description'],
 
65
  'PSMA-CT':[],
66
  'AbdomenAtlas':[],
67
  'AbdomenCT1k':[],
68
+ 'OAI_ZIB_KL': ['Age', 'Gender', 'KL_Grade', 'BMI'],
69
+ 'OAI_ZIB_WOMAC': ['Age', 'Gender', 'WOMAC_Pain', 'WOMAC_ADL', 'WOMAC_Stiffness', 'BMI'],
70
  }
71
  add_text = {
72
  'MSD': {},
 
83
  'PSMA-FDG-PET-CT-LESION':{'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
84
  'AbdomenAtlas':{},
85
  'AbdomenCT1k':{},
86
+ 'OAI_ZIB_KL': {'description': 'right knee osteoarthritis'},
87
+ 'OAI_ZIB_WOMAC': {'description': 'right knee osteoarthritis'},
88
  }
89
 
90
 
91
  # BERT initialization
92
+ model_name = '/rds/project/rds-TWhPgQVLKbA/Code/OmniMorph/External/Models/bert_large_uncased'
93
  reduce_method = 'mean'
94
  max_words_num = 32 # max number of words in the caption > 2
95
  # max_words_num = 64 # max number of words in the caption > 2
Dataloader/nifty_mappings/AbdomenAtlas_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:303c3fb7388e7b3b01cb6f494c3ac3f542da98487039e5b2415786ac4af58ba0
3
- size 179457573
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6000e9ba6b4fac278a1288826696ab7d5f77c97929d7e001dfb8938d7d5aa0a8
3
+ size 182087319
Dataloader/nifty_mappings/AbdomenCT1k_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0abaaa1013fdafe3fae6d5544746a66d8b20892ceb3cf9141a125113984e8350
3
- size 37315918
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a36ccd80e859aefd7334fb99ebca10601bb39be9e6432a1f59b4e98e9c4069a8
3
+ size 30687976
Dataloader/nifty_mappings/Brats2019_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1c5b80fc861484d36d8d6e0f97c404e2c321ee965cc1556a868205f5937d24fe
3
- size 12126490
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f128806b4673b7e1219990f0e2c5732abd1080fd4de271195fa74538c32ab70
3
+ size 12178080
Dataloader/nifty_mappings/Brats2020_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:de345c6a66a4f33552aacbb961cd034ac488500ff5d48810579055f0543162dc
3
- size 17743015
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90659bf584857b9e543163431e3730c6e6ce229b3386dc8ab13e7411a6b00c78
3
+ size 17815563
Dataloader/nifty_mappings/Brats2021_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4990a7031d6ac91e1c33e6db046dddf234f67dd8edecd07691675945b9d00af5
3
- size 44722001
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c758b9cfb8190f3b77eef03ea93a43f95e2d9e89dae4b08f6ae4dabc65024b97
3
+ size 44888384
Dataloader/nifty_mappings/CIA_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:98cbd21d3d5b7f5fb84091705fbbfcd0f8f26cb26ff4b34ffcf546cf1cedb48a
3
- size 32744567
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1aef79728ee6d2ab15ab7225a52d5e437cd10d33cfdcbb6f4d9c2aee1687d5f3
3
+ size 32803157
Dataloader/nifty_mappings/Kaggle_osic_mappings.json CHANGED
The diff for this file is too large to render. See raw diff
 
Dataloader/nifty_mappings/MSD_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a1ab13c61cd6829f088ee92bff4ce12a0f0e19fc9367682291fbd9717b149e83
3
- size 92620864
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8b777fb0d1ab09b22dcb3048b25cf60a31ccc30749888f1f02d7dc4b43715ad6
3
+ size 92732794
Dataloader/nifty_mappings/MnMs_mappings.json CHANGED
The diff for this file is too large to render. See raw diff
 
Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5ab4159932276f0ccd52efe44986ed184b504162f568cec68fc76fa0769efad
3
+ size 18096063
Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dad37ced9f1dbe3819dd6ac0d51b6585c25e641b4d07352d706aaf3ac17c19a
3
+ size 18119154
Dataloader/nifty_mappings/OASIS_1_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:8784bff1bb5c9ba08fccc8ca9776f3f26c9b2993c1c446ef17d5ba1dd2bda490
3
- size 15609846
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a39ccde5fe81bd7b2b5fa1cc64feb7094ff83851bfd40a5287e01d817e45db59
3
+ size 15646470
Dataloader/nifty_mappings/OASIS_2_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4f88910a0846e056b0d4caacd6e6ebfebde52b537828756e217d9a6c6343177c
3
- size 13396017
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7665f7769ef262f1758af1cf42e1610f211c53d35a625a457c5a50bca3841757
3
+ size 13440390
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c3c8729df59b6e9771fa791c5fe1cd7636e83a3c17109613984cdce0d92eefdc
3
- size 11700732
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebd252fec7062df77452b0bdeab47013314aba638cf0b0de295bc62748d2cfec
3
+ size 11728536
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:922363b739e1f14243731ea283ee730bc55724a27360d2f28f32b01b23ede5d9
3
- size 48425273
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cab3cbb5a5a651e1c3446079a3c18b944ed1893893ccd25451c110f13eebe4cc
3
+ size 48538337
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c36ba45053fea97244c259af0151ddb02e8281fce8c8f439cc88733bd71d668f
3
- size 67962146
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a922ecc5c136bcc3427f81e970d1cdd02e3b6c61bedc198e99b6fec8c380b4c3
3
+ size 69966911
Diffusion/diffuser-reg.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import numpy as np
4
+ from torch.nn.utils.stateless import functional_call
5
+
6
+ import Diffusion.utils_diff as utils
7
+ from Diffusion.networks import *
8
+ # from networks import *
9
+
10
+ import random
11
+
12
+ EPS = 1e-8
13
+
14
+
15
+
16
+ class DeformDDPM(nn.Module):
17
def __init__(
    self,
    network,
    n_steps=50,
    beta_schedule_fn = None,
    device='cpu',
    image_chw=(1, 28, 28),
    batch_size = 1,
    img_pad_mode = "zeros",
    ddf_pad_mode="border",
    padding_mode="border",
    v_scale = 0.008/256,
    resample_mode=None,
    inf_mode = False,
):
    """Store the configuration and build the spatial transformers.

    Args:
        network: recovery network, later called with x, y, t, rec_num, text.
        n_steps: number of diffusion steps (stored only).
        beta_schedule_fn: accepted but not read anywhere in this constructor.
        device: device handed to every STN and used for generated fields.
        image_chw: (channels, size, ...); ``image_chw[1]`` is used as the
            spatial side length and ``len(image_chw) - 1`` as the number of
            spatial dimensions.
        batch_size: batch size of the randomly generated deformation fields.
        img_pad_mode: padding mode of the image and mask STNs.
        ddf_pad_mode: padding mode of the control-grid STN.
        padding_mode: padding mode of the full-resolution field STN.
        v_scale: amplitude passed to the random velocity-field generator.
        resample_mode: resample mode of the image STN.
        inf_mode: when True, slice masks use the fixed middle slice.
    """
    super(DeformDDPM, self).__init__()

    # Plain (non-module) configuration.
    self.rec_num = 2
    self.ndims = len(image_chw) - 1
    self.n_steps = n_steps
    self.v_scale = v_scale
    self.device = device
    self.msk_noise_scale = torch.tensor(0)
    self.num_device = torch.cuda.device_count()
    self.batch_size = batch_size
    self.img_pad_mode = img_pad_mode
    self.ddf_pad_mode = ddf_pad_mode
    self.padding_mode = padding_mode
    self.resample_mode = resample_mode
    self.image_chw = image_chw

    # Sub-modules; the assignment order below fixes their registration order.
    self.network = network
    self.ddf_stn_full = STN(
        img_sz=self.image_chw[1],
        ndims=self.ndims,
        padding_mode=self.padding_mode,
        device=self.device,
    )
    self._DDF_Encoder_init()
    self.copy_opt = nn.Identity()
    self.inf_mode = inf_mode
    return
68
+
69
def get_stn(self):
    """Return the pair (image STN, full-resolution deformation-field STN)."""
    image_warper = self.img_stn
    field_warper = self.ddf_stn_full
    return image_warper, field_warper
71
+
72
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
    """Create the STNs used to compose and apply random deformation fields.

    Args:
        ctl_ratio: image side / control-grid side, used when ``ctl_sz`` is None.
        ctl_sz: explicit control-grid side length.
        resample_mode: not read here; the image STN uses ``self.resample_mode``.
    """
    side = self.image_chw[1]
    self.ctl_sz = side // ctl_ratio if ctl_sz is None else ctl_sz
    self.img_sz = side

    # Control-grid warper, used to compose coarse deformation fields.
    self.ddf_stn_rec = STN(
        img_sz=self.ctl_sz,
        ndims=self.ndims,
        device=self.device,
        padding_mode=self.ddf_pad_mode,
    )
    # Image warper at full resolution.
    self.img_stn = STN(
        img_sz=self.img_sz,
        ndims=self.ndims,
        device=self.device,
        padding_mode=self.img_pad_mode,
        resample_mode=self.resample_mode,
    )
    # Mask warper: same geometry as the image warper, nearest resampling.
    self.msk_stn = STN(
        img_sz=self.img_sz,
        ndims=self.ndims,
        device=self.device,
        padding_mode=self.img_pad_mode,
        resample_mode='nearest',
    )
80
+
81
+ def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
82
+ rec_num = 1
83
+ mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
84
+ mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
85
+ # print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
86
+ # mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
87
+ # mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
88
+ mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
89
+ mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
90
+ # print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
91
+ return rec_num,mul_num_ddf,mul_num_dvf
92
+
93
+ # def _sample_random_uniform_multi_order(self, high=None, low=0, order_num=3):
94
+ # # high: tensor of shape (...), low: int or tensor broadcastable to high
95
+ # sample_num = torch.full_like(high, low) if not isinstance(low, torch.Tensor) else low.clone()
96
+ # for _ in range(order_num):
97
+ # # For each element, sample in [sample_num, high]
98
+ # # torch.randint requires scalar low/high, so we use elementwise sampling
99
+ # rand_shape = high.shape
100
+ # # Clamp sample_num to be <= high
101
+ # sample_num = torch.minimum(sample_num, high)
102
+ # # Generate random numbers for each element
103
+ # rand = torch.empty(rand_shape, dtype=high.dtype, device=high.device)
104
+ # for idx in np.ndindex(rand_shape):
105
+ # l = sample_num[idx].item()
106
+ # h = high[idx].item()
107
+ # if l >= h:
108
+ # rand[idx] = l
109
+ # else:
110
+ # rand[idx] = torch.randint(l, h + 1, (1,), device=high.device)
111
+ # sample_num = rand.to(high.dtype)
112
+ # return sample_num
113
+
114
+ def _get_random_ddf(self,img,t):
115
+ rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
116
+ ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf])
117
+ warped_img = self.img_stn(img,ddf_forward)
118
+ return warped_img, dvf_forward,ddf_forward
119
+
120
+ def _multiscale_dvf_generate(self,v_scale,ctl_szs=[4,8,16,32,64], rand_v_scale=True):
121
+ dvf=0
122
+ if self.img_sz is None:
123
+ self.img_sz=max(ctl_szs)
124
+ if 1 in ctl_szs:
125
+ dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
126
+ dvf = dvf + dvf_rot
127
+ for ctl_sz in ctl_szs:
128
+ _v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2) if rand_v_scale else v_scale
129
+ # temp>>
130
+ if ctl_sz <= 2:
131
+ _v_scale = _v_scale/2
132
+ # temp<<
133
+ dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz]*self.ndims) * _v_scale
134
+ dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz]*self.ndims, align_corners=False, mode='bilinear' if self.ndims == 2 else 'trilinear')
135
+ dvf=dvf+dvf_comp
136
+ return dvf
137
+
138
+ def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
139
+ sample_value = low
140
+ for _ in range(order_num):
141
+ sample_value = np.random.uniform(low=sample_value, high=high)
142
+ return sample_value
143
+
144
+ def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=4, flip_ratio=0.5):
145
+ crop_rate=2
146
+ for _ in range(self.ndims+1):
147
+ mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
148
+ # v_scale = v_scale *crop_rate
149
+ ctl_ddf_sz=[self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
150
+ if ddf0 is not None:
151
+ ddf=ddf0
152
+ else:
153
+ ddf = torch.zeros(ctl_ddf_sz) * 0
154
+ dddf = torch.zeros(ctl_ddf_sz) * 0
155
+ scale_num = min(8,int(math.log2(self.ctl_sz))) # allow affine
156
+ # scale_num = min(5,int(math.log2(self.ctl_sz))-1) # semi-allow affine
157
+ # scale_num = min(5,int(math.log2(self.ctl_sz))-2) # avoid coupling between deformation and affine
158
+ ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
159
+
160
+ for i in range(rec_num):
161
+ # Randomly select 5 elements from ctl_szs (if there are at least 5)
162
+ if len(ctl_szs_all) > select_num:
163
+ ctl_szs = random.sample(ctl_szs_all, select_num)
164
+ dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
165
+ # if True:
166
+ if noise_ratio==0:
167
+ dvf0=dvf
168
+ else:
169
+ dvf0=dvf+self.ddf_stn_rec(self._multiscale_dvf_generate(self.v_scale*noise_ratio,ctl_szs=ctl_szs, rand_v_scale=False).to(self.device),dvf)
170
+ # print([num.shape for num in mul_num])
171
+ for j in range(torch.max(mul_num[0]).item()):
172
+ flag = [(n>j).int().to(self.device) for n in mul_num]
173
+ ddf = dvf0*flag[0] + self.ddf_stn_rec(ddf, dvf0*flag[0])
174
+ dddf = dvf*flag[1] + self.ddf_stn_rec(dddf, dvf*flag[1])
175
+
176
+ ddf = F.interpolate(ddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode='bilinear' if self.ndims == 2 else 'trilinear')
177
+ # ddf = ddf[...,img_sz//2:img_sz*3//2,img_sz//2:img_sz*3//2]
178
+ if self.ndims==2:
179
+ ddf = ddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
180
+ else:
181
+ ddf = ddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
182
+ # if rec_num==1:
183
+ if True:
184
+ dddf = F.interpolate(dddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode='bilinear' if self.ndims == 2 else 'trilinear')
185
+ # dddf = dddf[...,img_sz//2:img_sz*3//2,img_sz//2:img_sz*3//2]
186
+ if self.ndims == 2:
187
+ dddf = dddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
188
+ else:
189
+ dddf = dddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
190
+ return ddf,dddf
191
+ else:
192
+ return ddf
193
+
194
def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
    """Return a noise tensor shaped like ``img``, on ``img``'s device.

    ``noise_type`` selects the distribution:
      * 'gaussian': standard normal samples times ``noise_scale``
      * 'uniform':  uniform samples in [-noise_scale, noise_scale)
      * 'binary':   Bernoulli samples whose per-element probability is itself
                    drawn uniformly (``noise_scale`` is ignored)
      * anything else: all zeros
    """
    if noise_type == 'binary':
        sampled = torch.bernoulli(torch.rand_like(img))
    elif noise_type == 'uniform':
        sampled = torch.rand_like(img)*noise_scale*2-noise_scale
    elif noise_type == 'gaussian':
        sampled = torch.randn_like(img) * noise_scale
    else:
        sampled = torch.zeros_like(img)
    return sampled.to(img.device)
205
+
206
def add_noise(self, img, noise_map=None, noise_ratio_range=[0.,1.]):
    """Blend ``img`` with ``noise_map`` using a random mixing ratio.

    The ratio is drawn uniformly from ``noise_ratio_range``.

    Returns:
        (img * (1 - ratio) + noise_map * ratio, ratio)
    """
    lower, upper = noise_ratio_range[0], noise_ratio_range[1]
    noise_ratio = np.random.uniform(lower, upper)
    blended = img * (1-noise_ratio) + noise_map * noise_ratio
    return blended, noise_ratio
209
+
210
def apply_noise(self, img, noise_map=None, apply_mask=None):
    """Keep ``img`` where ``apply_mask`` is 1 and use ``noise_map`` where it is 0."""
    kept = img * apply_mask
    filled = noise_map * (1-apply_mask)
    return kept + filled
212
+
213
def downsample(self, img, down_ratio_range=[1./32,1]):
    """Degrade ``img`` by downsampling with random per-axis factors and resizing back.

    One scale factor per spatial axis is drawn uniformly from
    ``down_ratio_range``.  The image is interpolated by those factors and
    then interpolated back to ``image_chw[1]`` along every spatial axis.

    Returns:
        (degraded image, sqrt of the product of the drawn factors)
    """
    interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
    factors = list(np.random.uniform(down_ratio_range[0], down_ratio_range[1],[self.ndims]))
    shrunk = F.interpolate(img, scale_factor=factors, mode=interp_mode)
    restored = F.interpolate(shrunk, size=[self.image_chw[1]]*self.ndims, mode=interp_mode, align_corners=False)
    # jzheng: cond weight based on entropy
    return restored, np.sqrt(np.prod(factors))
220
+
221
def get_slice_mask(self, img, slice_num_range=(0, 32)):
    """Build a binary mask that keeps a few axis-aligned slices of ``img``.

    For each spatial axis in turn, a set of slice indices is chosen and the
    corresponding slices are set to 1.  The mask's last axis is marked and
    then rotated to the front of the spatial axes, so after ``ndims`` rounds
    every spatial axis has been handled and the layout is back to the input's.

    In inference mode (``self.inf_mode``) exactly one slice, the middle one,
    is used per axis.  Otherwise the number of slices is drawn uniformly from
    ``slice_num_range`` (upper bound clipped to the image side) and the
    indices are sampled without replacement.

    Args:
        img: tensor whose shape/dtype/device the mask copies.
        slice_num_range: (min, max) number of slices per axis.  The argument
            is only read; previously the caller's list (and the shared
            default) was overwritten with the clipped bound.

    Returns:
        (mask, sample_ratio), where sample_ratio is the mean over axes of
        sqrt(slices kept / image side).
    """
    side = self.image_chw[1]
    min_slices = slice_num_range[0]
    max_slices = min(slice_num_range[1], side)

    mask = torch.zeros_like(img)
    sample_ratio = 0
    # Moves the last axis to position 2, cycling through the spatial axes.
    transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
    for _ in range(self.ndims):
        if self.inf_mode:
            # Deterministic choice for inference: the single middle slice.
            slice_num = 1
            slice_idx = [side // 2]
        else:
            slice_num = random.randint(min_slices, max_slices)
            slice_idx = random.sample(range(side), slice_num)
        for idx in slice_idx:
            mask[..., idx] = 1
        mask = mask.permute(*transpose_list)
        # jzheng: cond weight based on entropy
        sample_ratio += np.sqrt(slice_num / side) / self.ndims

    return mask, sample_ratio
242
+
243
def project(self, img):
    """Average ``img`` along a random subset of its spatial axes.

    Each spatial axis is selected with a fair coin flip.  The mean
    projections along the selected axes (broadcast back to ``img``'s shape)
    are summed and divided by the number of selected axes plus EPS.

    Returns:
        (projected image, number of selected axes)
    """
    axis_flags = np.random.randint(0, 2, size=[self.ndims])
    picked = np.sum(axis_flags)
    accumulated = torch.zeros_like(img)
    for offset, chosen in enumerate(axis_flags):
        if not chosen:
            continue
        # Spatial axes start at dim 2 (after batch and channel).
        accumulated += torch.mean(img, dim=2 + offset, keepdim=True)
    return accumulated/(picked+EPS), picked
252
+
253
def proc_cond_img(self, img, proc_type=None,noise_scale=0.1):
    """Degrade a conditioning image and report how much information is left.

    Args:
        img: conditioning image tensor.
        proc_type: one of 'adding', 'independ', 'downsample', 'slice',
            'slice1', 'project', 'none', '' or 'uncon'.  When None, a type is
            drawn at random with 'uncon' weighted three times as heavily as
            each other option ('project' is never drawn).
        noise_scale: amplitude passed to ``create_noise_map``.

    Returns:
        (proc_img, mask, cond_ratio):
          * proc_img: the degraded image (pure noise for 'uncon').
          * mask: 1 / 0 scalar tensor, or the per-element mask used by
            'independ' and the slice modes.
          * cond_ratio: scalar tensor, 1 for an untouched image and 0 for
            'uncon'; otherwise derived from the applied degradation.

    Side effect: resets ``self.msk_noise_scale`` to 0 on ``img``'s device.
    """
    proc_img = img.clone().detach()
    if proc_type is None:
        # Heavily bias towards 'uncon' for efficiency.  One weight per option:
        # random.choices raises ValueError when the lengths differ.
        proc_type = random.choices(
            ['adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon'],
            weights=[1, 1, 1, 1, 1, 1, 3], k=1
        )[0]

    mask = torch.tensor(1, device=img.device)
    cond_ratio = torch.tensor(1., device=img.device)
    self.msk_noise_scale = torch.tensor(0, device=img.device)
    noise_type = random.choice(['gaussian', 'uniform', 'none'])

    if proc_type in ['none', None, '']:
        return proc_img, mask, cond_ratio

    if proc_type == 'uncon':
        # Unconditional: the image is replaced entirely by noise.
        proc_img = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
        mask = torch.tensor(0, device=img.device)
        cond_ratio = torch.tensor(0, device=img.device)
        return proc_img, mask, cond_ratio

    # The noise map is only needed by the modes that mix or fill with noise.
    noise_map = None
    if proc_type in ['adding', 'independ', 'slice', 'slice1']:
        noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)

    if proc_type == 'adding':
        proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
        cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
    elif proc_type == 'independ':
        mask = self.create_noise_map(img, noise_type='binary')
        if self.msk_noise_scale == 0:
            proc_img = img * mask
        else:
            proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
        with torch.no_grad():
            cond_ratio = mask.float().mean()
    elif proc_type == 'downsample':
        proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
        cond_ratio = torch.tensor(down_ratio, device=img.device)
    elif proc_type == 'slice' or proc_type == 'slice1':
        if proc_type == 'slice1':
            slice_num_max = 1
        else:
            # Two nested draws skew the upper bound towards small values.
            slice_num_max = random.randint(1, 64)
            slice_num_max = random.randint(1, slice_num_max)
        mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
        if self.msk_noise_scale == 0:
            proc_img = img * mask
        else:
            proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
        cond_ratio = torch.tensor(sample_ratio, device=img.device)
    elif proc_type == 'project':
        proc_img, proj_num = self.project(proc_img)
        cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
    return proc_img, mask, cond_ratio
312
+
313
def diffuse(self, x_0, t):
    """Deform ``x_0`` with a random field for time step(s) ``t``.

    ``t`` is converted with ``torch.tensor``.  The result is whatever
    ``_get_random_ddf`` returns: (warped image, dvf, ddf).
    """
    step = torch.tensor(t)
    return self._get_random_ddf(img=x_0, t=step)
318
+
319
+
320
def recover(self, x, y, t,rec_num=2, text=None):
    """Run the recovery network on ``x`` conditioned on ``y`` at step(s) ``t``.

    Args:
        x: network input; its device decides where ``t`` is placed.
        y: conditioning input, forwarded unchanged.
        t: time step, or a list of time steps; each is converted to a tensor
            on ``x``'s device.
        rec_num: forwarded to the network; None falls back to ``self.rec_num``.
        text: forwarded unchanged.

    Returns:
        The network's output.
    """
    if isinstance(t, list):
        t = [torch.tensor(t0).to(x.device) for t0 in t]
    else:
        # Tensor.to returns a new tensor; the result must be kept (the former
        # bare ``t.to(x.device)`` left t on its original device).
        t = torch.tensor(t).to(x.device)
    if rec_num is None:
        rec_num = self.rec_num
    return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
330
+
331
def recover_frozen_params_but_grad_input(self, x, y, t, rec_num=2, text=None):
    """Run the network with detached parameters but without ``no_grad``.

    The parameters handed to the network are detached copies, so no gradient
    reaches the weights, while the inputs keep their autograd history.

    Args:
        x: Input tensor passed to the network as ``x``; fixes the device of ``t``.
        y: Conditioning input passed to the network as ``y``.
        t: Timestep, or a list of timesteps.
        rec_num: Forwarded to the network; ``None`` selects ``self.rec_num``.
        text: Optional text input forwarded to the network unchanged.

    Returns:
        The output of ``self.network`` evaluated through ``functional_call``.
    """
    dev = x.device
    if isinstance(t, list):
        t = [torch.tensor(step, device=dev) for step in t]
    else:
        t = torch.tensor(t, device=dev)
    rec_num = self.rec_num if rec_num is None else rec_num

    net = self.network
    # Parameters are detached; buffers are passed through unchanged.
    state = {name: param.detach() for name, param in net.named_parameters()}
    for name, buf in net.named_buffers():
        state[name] = buf

    # A single merged dict is used because older PyTorch versions do not
    # accept separate parameter and buffer dicts.
    call_kwargs = {"x": x, "y": y, "t": t, "rec_num": rec_num, "text": text}
    return functional_call(net, state, (), kwargs=call_kwargs)
373
+
374
+
375
def _single_step(self, x0, t, rec_num=2, proc_type=None, mask=None, cond_imgs=None, text=None):
    """Deform ``x0`` at step ``t`` and predict the recovery in one pass.

    Args:
        x0: Clean image batch.
        t: Timestep handed to ``self.diffuse``.
        rec_num: Forwarded to ``self.recover``.
        proc_type: Conditioning type for ``self.proc_cond_img``; only used
            when ``cond_imgs`` is not supplied.
        mask: Multiplier applied to the deformed image; ``None`` means 1.
        cond_imgs: Pre-built conditioning images, or ``None`` to derive them
            from ``x0``.
        text: Optional text input forwarded to ``self.recover``.

    Returns:
        Tuple ``(network_output, dvf_I)`` where ``dvf_I`` is the second value
        returned by ``self.diffuse``.
    """
    mask = 1 if mask is None else mask
    if cond_imgs is None:
        cond_imgs, _mask_tgt, _cond_ratio = self.proc_cond_img(x0, proc_type=proc_type)
    noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
    if isinstance(self.network, DefRec_MutAttnNet):
        # This network type receives its timesteps as a list.
        t = [t]
    prediction = self.recover(x=noisy_imgs * mask, y=cond_imgs, t=t, rec_num=rec_num, text=text)
    return prediction, dvf_I
385
+
386
def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
    """Dispatch to multi-step recovery or to a single training step.

    Args:
        img_org: Input image batch.
        cond_imgs: Optional conditioning images.
        proc_type: Optional conditioning type.
        T: When not ``None``, selects ``self.diff_recover`` and is passed on
            as ``T``; otherwise ``self._single_step`` is used.
        **kwargs: Forwarded unchanged to the selected method.

    Returns:
        The return value of the selected method.
    """
    if T is None:
        return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
    return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
391
+ # if mask is None:
392
+ # mask = 1
393
+ # cond_imgs = self.proc_cond_img(x0, proc_type=proc_type, **kwargs)
394
+ # noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
395
+ # if isinstance(self.network, DefRec_MutAttnNet):
396
+ # t = [t] * 1
397
+ # return self.recover(x=noisy_imgs * mask, y=cond_imgs, t=t, rec_num=rec_num), dvf_I
398
+
399
+ def diff_recover(self,
400
+ img_org,
401
+ msk_org=None,
402
+ T=[None,None],
403
+ ddf_rand=None,
404
+ v_scale = None,
405
+ t_save=None,
406
+ cond_imgs=None,
407
+ proc_type=None,
408
+ text=None,
409
+ ):
410
+ if cond_imgs is None:
411
+ cond_imgs = img_org.clone().detach()
412
+ # if proc_type is not None:
413
+ cond_imgs,mask_tgt,cond_ratio=self.proc_cond_img(cond_imgs, proc_type=proc_type)
414
+ if ddf_rand is None:
415
+ if v_scale is not None:
416
+ self.v_scale=v_scale
417
+ self._DDF_Encoder_init()
418
+ if T[0] is None or T[0] == 0:
419
+ img_diff = img_org.clone().detach()
420
+ ddf_rand = torch.zeros_like(img_diff)
421
+ else:
422
+ img_diff, _, ddf_rand = self._get_random_ddf(img= img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
423
+ else:
424
+ img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
425
+ ddf_comp = ddf_rand.clone().detach()
426
+ img_rec = img_diff.clone().detach()
427
+ if msk_org is not None:
428
+ msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
429
+ else:
430
+ msk_diff = None
431
+ msk_rec = msk_diff.clone().detach() if msk_org is not None else None
432
+ img_save=[]
433
+ msk_save=[]
434
+
435
+ if isinstance(self.network,DefRec_MutAttnNet):
436
+ # Denosing image via list of t
437
+ t_list = list(range(T[1]-1, -1, -1))
438
+ pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list,rec_num=None, text=text)
439
+ ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
440
+ img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
441
+ if msk_org is not None:
442
+ msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
443
+ else:
444
+ # Denosing image
445
+ if isinstance(T[-1], int):
446
+ time_steps = range(T[-1] - 1, -1, -1)
447
+ trainable_iterations =[]
448
+ else:
449
+ time_steps = T[-1]
450
+
451
+ # # Randomly select k iterations to make their parameters trainable
452
+ # win_len = 2 # Number of iterations to make trainable
453
+ # if len(time_steps) <= win_len:
454
+ # win_start = 0
455
+ # else:
456
+ # win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
457
+ # win_end = win_start + win_len - 1
458
+
459
+ k=2
460
+ # trainable_iterations = time_steps[win_start: win_start + win_len]
461
+ # trainable_iterations = random.sample(time_steps, k)
462
+ trainable_iterations = time_steps[-1:-k-1:-1]
463
+ # print(time_steps)
464
+ # print("trainable_iterations:", trainable_iterations)
465
+ for i in time_steps:
466
+ t = torch.tensor(np.array([i])).to(self.device)
467
+
468
+ if i in trainable_iterations:
469
+ # Make parameters trainable for this iteration
470
+ pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
471
+ else:
472
+ # Freeze parameters for this iteration using torch.no_grad()
473
+ with torch.no_grad():
474
+ pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
475
+ # for idx, i in enumerate(time_steps):
476
+ # t = torch.tensor(np.array([i])).to(self.device)
477
+ # if idx < win_start:
478
+ # # just no_grad
479
+ # with torch.no_grad():
480
+ # pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
481
+ # elif win_start <= idx <= win_end:
482
+ # # normal update
483
+ # pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
484
+ # else:
485
+ # # freeze params but keep grad for input
486
+ # pre_dvf_I = self.recover_frozen_params_but_grad_input(
487
+ # x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text
488
+ # )
489
+
490
+ ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
491
+ # Apply to image
492
+ img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
493
+ if msk_org is not None:
494
+ msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
495
+ if t_save is not None:
496
+ if i in t_save:
497
+ img_save.append(img_rec)
498
+ if msk_org is not None:
499
+ msk_save.append(msk_rec)
500
+
501
+ # for i in time_steps:
502
+ # t = torch.tensor(np.array([i])).to(self.device)
503
+ # pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t,rec_num=None)
504
+ # ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
505
+ # # apply to image
506
+ # img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
507
+ # if msk_org is not None:
508
+ # msk_rec = self.img_stn(msk_org.clone().detach(), ddf_comp)
509
+ # if t_save is not None:
510
+ # if i in t_save:
511
+ # img_save.append(img_rec)
512
+ # if msk_org is not None:
513
+ # msk_save.append(msk_rec)
514
+ # print(torch.max(torch.abs(ddf_comp)))
515
+ # print(torch.max(torch.abs(ddf_rand)))
516
+
517
+ return [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save]
518
+
519
if __name__ == "__main__":
    # Smoke test: build a tiny 2-D model on the CPU and print one
    # conditioning image produced by proc_cond_img.
    H, W = 8, 8
    deformddpm = DeformDDPM(
        network=get_net(name="recmutattnnet")(n_steps=80, ndims=2, num_input_chn=1),
        image_chw=(1, H, W),
        device='cpu',
    )
    # img = torch.zeros([1, 1, H, W])
    img = torch.randn([1, 1, H, W])
    t = 1
    rec_num = 2
    # Other values to try: 'adding', 'independ', 'downsample', 'project', 'none'.
    proc_type = 'slice'
    print(img)
    # proc_cond_img returns (proc_img, mask, cond_ratio); unpacking only two
    # names raised "ValueError: too many values to unpack".
    cond_imgs, mask_tgt, cond_ratio = deformddpm.proc_cond_img(img, proc_type=proc_type)
    print(cond_imgs)
535
+ # img_rec, dvf_I = deformddpm.forward(img, t, rec_num=rec_num, proc_type=proc_type)
536
+ # print(img_rec.shape, dvf_I.shape)
537
+
538
+ # proc_type = 'adding'
539
+ # ddf_comp, ddf_rand = deformddpm.diff_recover(img, T=[1,1], proc_type=proc_type)
540
+
541
+
Diffusion/diffuser.py CHANGED
@@ -27,6 +27,7 @@ class DeformDDPM(nn.Module):
27
  padding_mode="border",
28
  v_scale = 0.008/256,
29
  resample_mode=None,
 
30
  ):
31
  super(DeformDDPM, self).__init__()
32
  self.rec_num=2
@@ -35,6 +36,7 @@ class DeformDDPM(nn.Module):
35
  self.v_scale = v_scale
36
  self.device = device
37
  self.msk_noise_scale = torch.tensor(0)
 
38
 
39
  # print('================')
40
  # print("device:",device)
@@ -61,6 +63,7 @@ class DeformDDPM(nn.Module):
61
  )
62
  self._DDF_Encoder_init()
63
  self.copy_opt = nn.Identity()
 
64
  return
65
 
66
  def get_stn(self):
@@ -78,7 +81,8 @@ class DeformDDPM(nn.Module):
78
  def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
79
  rec_num = 1
80
  mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
81
- mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
 
82
  # print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
83
  # mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
84
  # mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
@@ -110,7 +114,7 @@ class DeformDDPM(nn.Module):
110
 
111
  def _get_random_ddf(self,img,t):
112
  rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
113
- ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf])
114
  warped_img = self.img_stn(img,ddf_forward)
115
  return warped_img, dvf_forward,ddf_forward
116
 
@@ -122,8 +126,10 @@ class DeformDDPM(nn.Module):
122
  dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
123
  dvf = dvf + dvf_rot
124
  for ctl_sz in ctl_szs:
125
- _v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2) if rand_v_scale else v_scale
126
  # temp>>
 
 
127
  if ctl_sz <= 2:
128
  _v_scale = _v_scale/2
129
  # temp<<
@@ -138,7 +144,7 @@ class DeformDDPM(nn.Module):
138
  sample_value = np.random.uniform(low=sample_value, high=high)
139
  return sample_value
140
 
141
- def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=4, flip_ratio=0.5):
142
  crop_rate=2
143
  for _ in range(self.ndims+1):
144
  mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
@@ -188,11 +194,11 @@ class DeformDDPM(nn.Module):
188
  else:
189
  return ddf
190
 
191
- def create_noise_map(self, img, noise_type='gaussian', noise_ratio=0.2):
192
  if noise_type == 'gaussian':
193
- noise_map = torch.randn_like(img) * noise_ratio
194
  elif noise_type == 'uniform':
195
- noise_map = torch.rand_like(img) # 0-1
196
  elif noise_type == 'binary':
197
  noise_map = torch.bernoulli(torch.rand_like(img))
198
  else:
@@ -220,8 +226,18 @@ class DeformDDPM(nn.Module):
220
  mask = torch.zeros_like(img)
221
  sample_ratio = 0
222
  for i in range(self.ndims):
223
- slice_num = random.randint(slice_num_range[0], slice_num_range[1])
224
- slice_idx = random.sample(range(self.image_chw[1]), slice_num)
 
 
 
 
 
 
 
 
 
 
225
  transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
226
  for idx in slice_idx:
227
  mask[..., idx] = 1
@@ -243,7 +259,7 @@ class DeformDDPM(nn.Module):
243
  # print("projecting dim:", i)
244
  return proj_img/(proj_dim_num+EPS), proj_dim_num
245
 
246
- def proc_cond_img(self, img, proc_type=None):
247
  # Remove torch.no_grad() since most operations are not differentiable anyway
248
  proc_img = img.clone().detach()
249
  if proc_type is None:
@@ -251,7 +267,7 @@ class DeformDDPM(nn.Module):
251
  proc_type = random.choices(
252
  # ['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon'],
253
  # weights=[1, 1, 1, 1, 1, 1, 3], k=1
254
- ['adding', 'independ', 'downsample', 'slice', 'none', 'uncon'],
255
  weights=[1, 1, 1, 1, 1, 3], k=1
256
  )[0]
257
  mask = torch.tensor(1, device=img.device)
@@ -262,14 +278,14 @@ class DeformDDPM(nn.Module):
262
  noise_map = None
263
  if proc_type not in ['none', None, '']:
264
  if proc_type == 'uncon':
265
- noise_map = self.create_noise_map(img, noise_type=noise_type)
266
  proc_img = noise_map
267
  mask = torch.tensor(0, device=img.device)
268
  cond_ratio = torch.tensor(0, device=img.device)
269
  return proc_img, mask, cond_ratio
270
- if proc_type in ['adding', 'independ', 'slice']:
271
  # self.msk_noise_scale = 0
272
- noise_map = self.create_noise_map(img, noise_type=noise_type)
273
  if proc_type == 'adding':
274
  proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
275
  cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
@@ -285,9 +301,12 @@ class DeformDDPM(nn.Module):
285
  # proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./32, 1])
286
  proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
287
  cond_ratio = torch.tensor(down_ratio, device=img.device)
288
- elif proc_type == 'slice':
289
- slice_num_max = random.randint(1, 64)
290
- slice_num_max = random.randint(1, slice_num_max)
 
 
 
291
  mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
292
  if self.msk_noise_scale == 0:
293
  proc_img = img * mask
@@ -373,8 +392,14 @@ class DeformDDPM(nn.Module):
373
  t = [t] * 1
374
  return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I
375
 
376
- def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
377
- if T is not None:
 
 
 
 
 
 
378
  return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
379
  else:
380
  return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
@@ -446,7 +471,7 @@ class DeformDDPM(nn.Module):
446
  # win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
447
  # win_end = win_start + win_len - 1
448
 
449
- k=2
450
  # trainable_iterations = time_steps[win_start: win_start + win_len]
451
  # trainable_iterations = random.sample(time_steps, k)
452
  trainable_iterations = time_steps[-1:-k-1:-1]
 
27
  padding_mode="border",
28
  v_scale = 0.008/256,
29
  resample_mode=None,
30
+ inf_mode = False,
31
  ):
32
  super(DeformDDPM, self).__init__()
33
  self.rec_num=2
 
36
  self.v_scale = v_scale
37
  self.device = device
38
  self.msk_noise_scale = torch.tensor(0)
39
+ # self.msk_noise_scale = torch.tensor(1)
40
 
41
  # print('================')
42
  # print("device:",device)
 
63
  )
64
  self._DDF_Encoder_init()
65
  self.copy_opt = nn.Identity()
66
+ self.inf_mode = inf_mode
67
  return
68
 
69
  def get_stn(self):
 
81
  def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
82
  rec_num = 1
83
  mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
84
+ # mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
85
+ mul_num_dvf = torch.floor_divide(torch.pow(t,0.75), divide_num).int() # raise the power number to increase the dvf ratio, which can help the training of ddf_stn_rec and make the model more robust to large deformation
86
  # print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
87
  # mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
88
  # mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
 
114
 
115
  def _get_random_ddf(self,img,t):
116
  rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
117
+ ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf],select_num=random.choice([1, 2, 3, 3, 4, 4]))
118
  warped_img = self.img_stn(img,ddf_forward)
119
  return warped_img, dvf_forward,ddf_forward
120
 
 
126
  dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
127
  dvf = dvf + dvf_rot
128
  for ctl_sz in ctl_szs:
129
+ _v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=0., order_num=random.choice([1, 2])) if rand_v_scale else v_scale
130
  # temp>>
131
+ if ctl_sz <= 4:
132
+ _v_scale = _v_scale/2
133
  if ctl_sz <= 2:
134
  _v_scale = _v_scale/2
135
  # temp<<
 
144
  sample_value = np.random.uniform(low=sample_value, high=high)
145
  return sample_value
146
 
147
+ def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=3, flip_ratio=0.5):
148
  crop_rate=2
149
  for _ in range(self.ndims+1):
150
  mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
 
194
  else:
195
  return ddf
196
 
197
+ def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
198
  if noise_type == 'gaussian':
199
+ noise_map = torch.randn_like(img) * noise_scale
200
  elif noise_type == 'uniform':
201
+ noise_map = torch.rand_like(img)*noise_scale*2-noise_scale # 0-1
202
  elif noise_type == 'binary':
203
  noise_map = torch.bernoulli(torch.rand_like(img))
204
  else:
 
226
  mask = torch.zeros_like(img)
227
  sample_ratio = 0
228
  for i in range(self.ndims):
229
+ if self.inf_mode:
230
+ if i== 0:
231
+ slice_num = 1 # use max slice num for inference for better performance
232
+ slice_idx = [self.image_chw[1]//2] # use middle slice for inference for better performance
233
+ else:
234
+ slice_num = 0
235
+ slice_idx = []
236
+ # slice_num = 1 # use max slice num for inference for better performance
237
+ # slice_idx = [self.image_chw[1]//2] # use middle slice for inference for better performance
238
+ else:
239
+ slice_num = random.randint(slice_num_range[0], slice_num_range[1])
240
+ slice_idx = random.sample(range(self.image_chw[1]), slice_num)
241
  transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
242
  for idx in slice_idx:
243
  mask[..., idx] = 1
 
259
  # print("projecting dim:", i)
260
  return proj_img/(proj_dim_num+EPS), proj_dim_num
261
 
262
+ def proc_cond_img(self, img, proc_type=None,noise_scale=0.1):
263
  # Remove torch.no_grad() since most operations are not differentiable anyway
264
  proc_img = img.clone().detach()
265
  if proc_type is None:
 
267
  proc_type = random.choices(
268
  # ['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon'],
269
  # weights=[1, 1, 1, 1, 1, 1, 3], k=1
270
+ ['adding', 'independ', 'downsample', 'slice','slice1', 'none', 'uncon'],
271
  weights=[1, 1, 1, 1, 1, 3], k=1
272
  )[0]
273
  mask = torch.tensor(1, device=img.device)
 
278
  noise_map = None
279
  if proc_type not in ['none', None, '']:
280
  if proc_type == 'uncon':
281
+ noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
282
  proc_img = noise_map
283
  mask = torch.tensor(0, device=img.device)
284
  cond_ratio = torch.tensor(0, device=img.device)
285
  return proc_img, mask, cond_ratio
286
+ if proc_type in ['adding', 'independ', 'slice','slice1']:
287
  # self.msk_noise_scale = 0
288
+ noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
289
  if proc_type == 'adding':
290
  proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
291
  cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
 
301
  # proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./32, 1])
302
  proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
303
  cond_ratio = torch.tensor(down_ratio, device=img.device)
304
+ elif proc_type == 'slice' or proc_type == 'slice1':
305
+ if proc_type == 'slice1':
306
+ slice_num_max = 1
307
+ else:
308
+ slice_num_max = random.randint(1, 64)
309
+ slice_num_max = random.randint(1, slice_num_max)
310
  mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
311
  if self.msk_noise_scale == 0:
312
  proc_img = img * mask
 
392
  t = [t] * 1
393
  return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I
394
 
395
+ def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, output_embedding=False, **kwargs):
396
+ if output_embedding:
397
+ # Direct network forward for contrastive embedding (no diffusion).
398
+ # Returns img_embd so DDP's prepare_for_backward traces the correct subgraph
399
+ # (encoder + mid + attn + img2txt only, no decoder).
400
+ self.network(x=img_org, y=cond_imgs, t=T, text=kwargs.get('text'), rec_num=1)
401
+ return self.network.img_embd
402
+ elif T is not None:
403
  return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
404
  else:
405
  return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
 
471
  # win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
472
  # win_end = win_start + win_len - 1
473
 
474
+ k = 1 if len(time_steps) > 16 else 2
475
  # trainable_iterations = time_steps[win_start: win_start + win_len]
476
  # trainable_iterations = random.sample(time_steps, k)
477
  trainable_iterations = time_steps[-1:-k-1:-1]
Diffusion/diffuser_opt.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ diffuser_opt.py — Optimized DeformDDPM subclass.
3
+
4
+ Inherits from Diffusion.diffuser.DeformDDPM and overrides only the methods
5
+ that benefit from optimization.
6
+
7
+ Key optimizations:
8
+ 1. diff_recover(): hoist img_org/msk_org .clone().detach() outside the loop,
9
+ pre-compute timestep tensors, use torch.no_grad() for frozen steps
10
+ 2. _random_ddf_generate(): scaling-and-squaring for O(log n) composition
11
+ instead of O(n), crop-first upsampling (4x faster), on-device tensors.
12
+ 3. proc_cond_img(): skip clone for 'uncon' path (most common, ~3/8 weight)
13
+ 4. _DDF_Encoder_init(): use OptSTN (register_buffer, no per-call .to(device))
14
+ 5. recover(): fix t tensor bug (was staying on CPU), avoid redundant torch.tensor()
15
+ 6. _multiscale_dvf_generate(): generate random tensors on device to avoid
16
+ CPU→GPU transfer of 3D volumes.
17
+ """
18
+
19
+ from torch import nn
20
+ import torch
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+ import random
24
+ import math
25
+
26
+ import Diffusion.utils_diff as utils
27
+ from Diffusion.diffuser import DeformDDPM as _BaseDeformDDPM
28
+ from Diffusion.networks import *
29
+ from Diffusion.networks_opt import OptSTN
30
+
31
+ EPS = 1e-8
32
+
33
+
34
+ class DeformDDPM(_BaseDeformDDPM):
35
+ """Drop-in replacement for DeformDDPM with speed optimizations."""
36
+
37
+ # ------------------------------------------------------------------
38
+ # Optimization 4: use OptSTN (register_buffer, no per-call .to())
39
+ # ------------------------------------------------------------------
40
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
    """Create the three ``OptSTN`` warpers used for fields, images and masks.

    Args:
        ctl_ratio: Divisor applied to ``self.image_chw[1]`` to derive the
            control-grid size when ``ctl_sz`` is not given.
        ctl_sz: Explicit control-grid size; overrides ``ctl_ratio``.
        resample_mode: Accepted for signature compatibility; the body reads
            ``self.resample_mode`` instead of this argument.
    """
    full_sz = self.image_chw[1]
    if ctl_sz is None:
        ctl_sz = full_sz // ctl_ratio
    self.ctl_sz = ctl_sz
    self.img_sz = full_sz

    # Settings shared by every warper built below.
    common = dict(ndims=self.ndims, device=self.device)

    # Warper for fields on the control grid.
    self.ddf_stn_rec = OptSTN(img_sz=ctl_sz, padding_mode=self.ddf_pad_mode, **common)
    # Warper for full-resolution images.
    self.img_stn = OptSTN(img_sz=self.img_sz, padding_mode=self.img_pad_mode,
                          resample_mode=self.resample_mode, **common)
    # Warper for masks: same geometry as images, nearest-neighbour resampling.
    self.msk_stn = OptSTN(img_sz=self.img_sz, padding_mode=self.img_pad_mode,
                          resample_mode='nearest', **common)
52
+
53
def __init__(self, network, n_steps=50, beta_schedule_fn=None, device='cpu',
             image_chw=(1, 28, 28), batch_size=1, img_pad_mode="zeros",
             ddf_pad_mode="border", padding_mode="border",
             v_scale=0.008/256, resample_mode=None, inf_mode=False):
    """Initialise the base model, then swap ``ddf_stn_full`` for an ``OptSTN``.

    All arguments are forwarded by keyword, unchanged, to the parent
    constructor.
    """
    parent_kwargs = dict(
        network=network,
        n_steps=n_steps,
        beta_schedule_fn=beta_schedule_fn,
        device=device,
        image_chw=image_chw,
        batch_size=batch_size,
        img_pad_mode=img_pad_mode,
        ddf_pad_mode=ddf_pad_mode,
        padding_mode=padding_mode,
        v_scale=v_scale,
        resample_mode=resample_mode,
        inf_mode=inf_mode,
    )
    super().__init__(**parent_kwargs)
    # Full-resolution field warper, rebuilt from attributes that are read
    # from self after the parent constructor has run.
    self.ddf_stn_full = OptSTN(
        img_sz=self.image_chw[1],
        ndims=self.ndims,
        padding_mode=self.padding_mode,
        device=self.device,
    )
70
+
71
+ # ------------------------------------------------------------------
72
+ # Optimization 5: fix recover() t tensor bug + avoid redundant copies
73
+ # ------------------------------------------------------------------
74
def recover(self, x, y, t, rec_num=2, text=None):
    """Run the network at timestep(s) ``t`` with ``t`` placed on ``x``'s device.

    Tensors that already live on ``x.device`` are passed through as-is;
    non-tensor values are converted with ``torch.tensor``.

    Args:
        x: Input tensor passed to the network as ``x``.
        y: Conditioning input passed to the network as ``y``.
        t: Timestep, or a list of timesteps (numbers and/or tensors).
        rec_num: Forwarded to the network; ``None`` selects ``self.rec_num``.
        text: Optional text input forwarded to the network unchanged.

    Returns:
        The output of ``self.network``.
    """
    target = x.device

    def _on_target(step):
        # Convert plain values; move tensors only when they are elsewhere.
        if not isinstance(step, torch.Tensor):
            return torch.tensor(step, device=target)
        if step.device != target:
            return step.to(target)
        return step

    if isinstance(t, list):
        t = [_on_target(step) for step in t]
    else:
        t = _on_target(t)

    if rec_num is None:
        rec_num = self.rec_num
    return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
89
+
90
+ # ------------------------------------------------------------------
91
+ # Optimization 2: scaling-and-squaring + crop-first upsample
92
+ # ------------------------------------------------------------------
93
def _compose_n_times(self, dvf, n):
    """Return the ``n``-fold self-composition of ``dvf``.

    Walks the binary digits of ``n``: the running power of ``dvf`` is
    squared once per digit and folded into the accumulator whenever the
    digit is set, so the number of ``self.ddf_stn_rec`` calls grows with the
    bit length of ``n`` rather than with ``n`` itself.

    Args:
        dvf: Field to compose with itself.
        n: Number of compositions; values ``<= 0`` yield a zero field.

    Returns:
        A new tensor shaped like ``dvf``.
    """
    if n <= 0:
        return torch.zeros_like(dvf)

    warp = self.ddf_stn_rec
    accumulated = None
    power = dvf  # holds dvf composed 2**k times after k squarings
    while True:
        if n & 1:
            # Fold the current power into the accumulated composition.
            accumulated = power.clone() if accumulated is None else accumulated + warp(power, accumulated)
        n >>= 1
        if n <= 0:
            break
        # Square: compose the current power with itself.
        power = power + warp(power, power)
    return accumulated
120
+
121
def _crop_upsample(self, field):
    """Upsample a control-grid field to ``img_sz`` via a 2x-oversampled centre crop.

    The field is cropped to its central region (plus a small margin so the
    interpolation inside the kept window never touches the crop border),
    rescaled by ``img_sz / ctl_sz``, upsampled, and cropped to ``img_sz``.
    This is intended to match upsampling the whole field to
    ``ctl_sz * upscale`` and centre-cropping ``img_sz`` voxels.

    The crop window is clamped to ``[0, ctl_sz]``. Without clamping, control
    grids smaller than 8 produced a negative start / out-of-range stop, so
    the slice picked the wrong region while the output offsets were still
    computed from the nominal window size.

    Args:
        field: Tensor whose last ``self.ndims`` axes each have length
            ``self.ctl_sz``.

    Returns:
        Tensor whose last ``self.ndims`` axes each have length ``self.img_sz``.
    """
    crop_rate = 2
    margin = 2  # control voxels of slack around the central window
    ctl_sz, img_sz, ndims = self.ctl_sz, self.img_sz, self.ndims

    upscale = img_sz * crop_rate // ctl_sz
    lo = max(0, ctl_sz // 4 - margin)
    hi = min(ctl_sz, -(-3 * ctl_sz // 4) + margin)  # ceil(3/4 * ctl_sz) + margin
    up_sz = (hi - lo) * upscale
    # Offset of the centred img_sz window inside the upsampled crop.
    start = (ctl_sz * upscale - img_sz) // 2 - lo * upscale

    mode = 'bilinear' if ndims == 2 else 'trilinear'
    crop_idx = (Ellipsis,) + (slice(lo, hi),) * ndims
    out_idx = (Ellipsis,) + (slice(start, start + img_sz),) * ndims

    field_crop = field[crop_idx] * img_sz / ctl_sz
    field_up = F.interpolate(field_crop, up_sz, mode=mode)
    return field_up[out_idx]
149
+
150
+ def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])],
151
+ ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=3, flip_ratio=0.5):
152
+ for _ in range(self.ndims + 1):
153
+ mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
154
+ ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
155
+ if ddf0 is not None:
156
+ ddf = ddf0
157
+ else:
158
+ ddf = torch.zeros(ctl_ddf_sz, device=self.device)
159
+ dddf = torch.zeros(ctl_ddf_sz, device=self.device)
160
+ scale_num = min(8, int(math.log2(self.ctl_sz)))
161
+ ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
162
+
163
+ for i in range(rec_num):
164
+ if len(ctl_szs_all) > select_num:
165
+ ctl_szs = random.sample(ctl_szs_all, select_num)
166
+ else:
167
+ ctl_szs = ctl_szs_all
168
+ dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs)
169
+ if noise_ratio == 0:
170
+ dvf0 = dvf
171
+ else:
172
+ dvf0 = dvf + self.ddf_stn_rec(
173
+ self._multiscale_dvf_generate(self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False),
174
+ dvf)
175
+
176
+ mul_num_ddf_val = int(torch.max(mul_num[0]).item())
177
+ mul_num_dvf_val = int(torch.max(mul_num[1]).item())
178
+
179
+ # OPT: scaling-and-squaring — O(log n) STN calls instead of O(n)
180
+ # For t=40: 10 calls instead of 80. For t=79: 9 calls instead of 195.
181
+ ddf = self._compose_n_times(dvf0, mul_num_ddf_val)
182
+ dddf = self._compose_n_times(dvf, mul_num_dvf_val)
183
+
184
+ # OPT: crop-first upsample — 4x fewer voxels to interpolate (bit-identical)
185
+ ddf = self._crop_upsample(ddf)
186
+ dddf = self._crop_upsample(dddf)
187
+ return ddf, dddf
188
+
189
+ # ------------------------------------------------------------------
190
+ # Optimization 6: generate DVF on device to avoid CPU→GPU transfer
191
+ # ------------------------------------------------------------------
192
def _multiscale_dvf_generate(self, v_scale, ctl_szs=[4, 8, 16, 32, 64], rand_v_scale=True):
    """Sum Gaussian random fields drawn at several grid sizes into one field.

    Each component is sampled on ``self.device`` at grid size ``ctl_sz``,
    scaled, and resized to ``self.ctl_sz`` before being added to the total.

    Args:
        v_scale: Amplitude (or upper bound of the sampled amplitude when
            ``rand_v_scale`` is true) of each component.
        ctl_szs: Grid sizes to draw components at. A value of 1 additionally
            adds a field from ``utils.random_ddf``.
        rand_v_scale: When true, the amplitude of every component is drawn
            with ``self._sample_random_uniform_multi_order``.

    Returns:
        The summed field; the integer 0 if ``ctl_szs`` is empty.
    """
    total = 0
    if self.img_sz is None:
        self.img_sz = max(ctl_szs)
    if 1 in ctl_szs:
        total = total + utils.random_ddf(
            batch_size=self.batch_size, ndims=self.ndims,
            img_sz=[self.ctl_sz] * self.ndims, range_gauss=0, rot_range=np.pi / 90)

    out_size = [self.ctl_sz] * self.ndims
    interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
    for res in ctl_szs:
        if rand_v_scale:
            amplitude = self._sample_random_uniform_multi_order(
                high=v_scale, low=0., order_num=random.choice([1, 1, 2]))
        else:
            amplitude = v_scale
        if res <= 2:
            # The coarsest grids get half the amplitude.
            amplitude = amplitude / 2
        # Sample directly on the target device.
        component = torch.randn([self.batch_size, self.ndims] + [res] * self.ndims,
                                device=self.device) * amplitude
        component = F.interpolate(component * self.ctl_sz / res, out_size,
                                  align_corners=False, mode=interp_mode)
        total = total + component
    return total
214
+
215
+ # ------------------------------------------------------------------
216
+ # Optimization 3: skip clone for 'uncon' (most common conditioning type)
217
+ # ------------------------------------------------------------------
218
def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
    """Derive a conditioning image from ``img`` by one of several corruptions.

    Args:
        img: Source image tensor.
        proc_type: One of 'adding', 'independ', 'downsample', 'slice',
            'slice1', 'project', 'uncon', or 'none' / '' for an untouched
            copy. When None, a type is drawn at random. Any other value
            returns an untouched copy as well.
        noise_scale: Forwarded to ``self.create_noise_map``.

    Returns:
        ``(proc_img, mask, cond_ratio)``. ``mask`` and ``cond_ratio`` default
        to scalar tensors 1 and 1.0 and are overridden by the branches below.

    Side effects:
        Resets ``self.msk_noise_scale`` to a zero tensor on every call.
    """
    if proc_type is None:
        # random.choices raises ValueError unless there is exactly one weight
        # per candidate; the list used to hold six weights for seven types.
        # 'uncon' keeps the largest weight (the fast path below treats it as
        # the most common type).
        # NOTE(review): weight 1 for 'none' is an assumption — confirm.
        candidates = ['adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon']
        candidate_weights = [1, 1, 1, 1, 1, 1, 3]
        proc_type = random.choices(candidates, weights=candidate_weights, k=1)[0]

    mask = torch.tensor(1, device=img.device)
    cond_ratio = torch.tensor(1., device=img.device)
    self.msk_noise_scale = torch.tensor(0, device=img.device)
    # Drawn unconditionally so the Python RNG stream does not depend on proc_type.
    noise_type = random.choice(['gaussian', 'uniform', 'none'])

    if proc_type in ['none', None, '']:
        # Pass-through: the caller still gets its own copy.
        return img.clone().detach(), mask, cond_ratio

    if proc_type == 'uncon':
        # Unconditional: the output is pure noise, so img is never cloned.
        proc_img = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
        mask = torch.tensor(0, device=img.device)
        cond_ratio = torch.tensor(0, device=img.device)
        return proc_img, mask, cond_ratio

    # Every remaining branch starts from a private copy of the image.
    proc_img = img.clone().detach()
    noise_map = None
    if proc_type in ['adding', 'independ', 'slice', 'slice1']:
        noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)

    if proc_type == 'adding':
        proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
        cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
    elif proc_type == 'independ':
        mask = self.create_noise_map(img, noise_type='binary')
        # msk_noise_scale was reset to 0 above, so the first branch is the one
        # taken unless that reset is changed.
        if self.msk_noise_scale == 0:
            proc_img = img * mask
        else:
            proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
        with torch.no_grad():
            cond_ratio = mask.float().mean()
    elif proc_type == 'downsample':
        proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1. / 64, 1])
        cond_ratio = torch.tensor(down_ratio, device=img.device)
    elif proc_type == 'slice' or proc_type == 'slice1':
        if proc_type == 'slice1':
            slice_num_max = 1
        else:
            # Two-stage draw: a random cap, then a random count below that cap.
            slice_num_max = random.randint(1, 64)
            slice_num_max = random.randint(1, slice_num_max)
        mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
        if self.msk_noise_scale == 0:
            proc_img = img * mask
        else:
            proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
        cond_ratio = torch.tensor(sample_ratio, device=img.device)
    elif proc_type == 'project':
        # Only reachable when requested explicitly; not in the random pool.
        proc_img, proj_num = self.project(proc_img)
        cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)

    return proc_img, mask, cond_ratio
277
+
278
+ # ------------------------------------------------------------------
279
+ # Optimization 1: hoist clone, pre-compute timestep tensors,
280
+ # use inference_mode for frozen iterations
281
+ # ------------------------------------------------------------------
282
def diff_recover(self, img_org, msk_org=None, T=[None, None], ddf_rand=None,
                 v_scale=None, t_save=None, cond_imgs=None, proc_type=None, text=None):
    """Deform ``img_org`` and then recover it by composing predicted fields.

    Args:
        img_org: Original image tensor.
        msk_org: Optional mask warped alongside the image with ``self.msk_stn``.
        T: ``T[0]`` is the forward step (None or 0 means no deformation).
            ``T[-1]`` is either an int (steps ``T[-1]-1 .. 0``, all run under
            ``torch.no_grad``) or an explicit step sequence whose last two
            entries are run without ``no_grad``. The default list is only read.
        ddf_rand: Optional pre-computed deformation; skips the forward step.
        v_scale: When given (and ``ddf_rand`` is None) overrides
            ``self.v_scale`` and re-runs ``self._DDF_Encoder_init``.
        t_save: Optional collection of steps whose results are recorded.
        cond_imgs: Conditioning images; a copy of ``img_org`` when None.
        proc_type: Forwarded to ``self.proc_cond_img``.
        text: Forwarded to ``self.recover``.

    Returns:
        ``[ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
        [msk_rec, msk_diff, msk_save]``.

    NOTE(review): block nesting was reconstructed from a flattened listing.
    It assumes ``proc_cond_img`` is applied to ``cond_imgs`` unconditionally
    and that ``_DDF_Encoder_init`` only runs when ``v_scale`` is given —
    confirm against the original file.
    """
    has_mask = msk_org is not None

    if cond_imgs is None:
        cond_imgs = img_org.clone().detach()
    cond_imgs, _mask_tgt, _cond_ratio = self.proc_cond_img(cond_imgs, proc_type=proc_type)

    # ---- forward: obtain the deformed image and the field that produced it
    if ddf_rand is not None:
        img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
    else:
        if v_scale is not None:
            self.v_scale = v_scale
            self._DDF_Encoder_init()
        if T[0] is None or T[0] == 0:
            img_diff = img_org.clone().detach()
            ddf_rand = torch.zeros_like(img_diff)
        else:
            t_forward = torch.tensor(np.array([T[0]])).to(self.device)
            img_diff, _, ddf_rand = self._get_random_ddf(img=img_org, t=t_forward)

    ddf_comp = ddf_rand.clone().detach()
    img_rec = img_diff.clone().detach()
    msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand) if has_mask else None
    msk_rec = msk_diff.clone().detach() if has_mask else None
    img_save = []
    msk_save = []

    # One detached copy of the originals is reused by every warp below.
    img_org_ref = img_org.clone().detach()
    msk_org_ref = msk_org.clone().detach() if has_mask else None

    # ---- reverse: compose predicted fields onto ddf_comp
    if isinstance(self.network, DefRec_MutAttnNet):
        # This network receives the whole descending step list in one call.
        t_list = list(range(T[1] - 1, -1, -1))
        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
        ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
        img_rec = self.img_stn(img_org_ref, ddf_comp)
        if has_mask:
            msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
    else:
        if isinstance(T[-1], int):
            time_steps = range(T[-1] - 1, -1, -1)
            trainable_iterations = []
        else:
            time_steps = T[-1]
            k = 2
            trainable_iterations = time_steps[-1:-k - 1:-1]

        t_save_set = None if t_save is None else set(t_save)
        # Steps at or beyond this index keep the caller's grad mode.
        trainable_start_idx = len(time_steps) - len(trainable_iterations)

        for step_idx, i in enumerate(time_steps):
            t = torch.tensor([i], device=self.device)

            if step_idx < trainable_start_idx:
                # no_grad rather than inference_mode: ddf_comp is carried from
                # these steps into the ones that follow.
                with torch.no_grad():
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
            else:
                pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)

            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org_ref, ddf_comp)
            if has_mask:
                msk_rec = self.msk_stn(msk_org_ref, ddf_comp)

            if t_save_set is not None and i in t_save_set:
                img_save.append(img_rec)
                if has_mask:
                    msk_save.append(msk_rec)

    return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]
Diffusion/losses.py CHANGED
@@ -21,7 +21,7 @@ class LMSE(torch.nn.Module):
21
  Labeled Mean Square Error (LMSE)
22
  """
23
 
24
- def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
25
  super(LMSE, self).__init__()
26
  self.eps = eps
27
  self.relate_eps = relate_eps
@@ -72,7 +72,7 @@ class LNCC(torch.nn.Module):
72
  Local (over window) normalized cross-correlation (LNCC)
73
  """
74
 
75
- def __init__(self, win=None, num_ch=1, eps=1e-6, central=True, smooth=True):
76
  super(LNCC, self).__init__()
77
  self.scale = 2e0
78
  self.win = win
@@ -84,11 +84,11 @@ class LNCC(torch.nn.Module):
84
 
85
  # Set window size
86
  if self.win is None:
87
- self.win = [9] * self.ndims
88
  self.padding = [(w-1) // 2 for w in self.win]
89
 
90
  if smooth:
91
- self.kernels = self._build_kernel(std=0.45)
92
  self.sum_filt = self._build_kernel(std=0.0)
93
 
94
  def _build_kernel(self, std=0.0):
@@ -153,7 +153,7 @@ class LNCC(torch.nn.Module):
153
  J_var = J2_sum
154
 
155
  # cc = (cross * cross) / (I_var * J_var + self.eps)
156
- cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
157
  if label is not None:
158
  label = label.float()
159
  cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
@@ -164,6 +164,43 @@ class LNCC(torch.nn.Module):
164
  return -self.lncc(I*self.scale, J*self.scale, label=label)
165
 
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  class NCC(torch.nn.Module):
169
  # def __init__(self, eps_scale=10e-7,img_sz=256):
@@ -236,7 +273,7 @@ class Grad(torch.nn.Module):
236
  N-D gradient loss
237
  """
238
 
239
- def __init__(self, penalty=['l1'],ndims=2, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=2, apear_scale=4, dist=1, sign=1,waive_thresh=10**-5):
240
  super(Grad, self).__init__()
241
  self.penalty = penalty
242
  self.eps = eps
@@ -521,7 +558,7 @@ if __name__ == "__main__":
521
  img3d_t = torch.empty(1,1,size,size,size).uniform_(0,1)#*-0.000001
522
  # img3d_t = img3d.clone().detach()
523
  # img3d_t = torch.zeros_like(img3d)
524
- translation = 2
525
  start = 0
526
  end = 32
527
  # img3d_t[:,:,translation:,translation:,translation:] = img3d[:,:,:size-translation,:size-translation,:size-translation]
 
21
  Labeled Mean Square Error (LMSE)
22
  """
23
 
24
+ def __init__(self, eps=1e-7, relate_eps=1e-1, win=None, smooth=False):
25
  super(LMSE, self).__init__()
26
  self.eps = eps
27
  self.relate_eps = relate_eps
 
72
  Local (over window) normalized cross-correlation (LNCC)
73
  """
74
 
75
+ def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
76
  super(LNCC, self).__init__()
77
  self.scale = 2e0
78
  self.win = win
 
84
 
85
  # Set window size
86
  if self.win is None:
87
+ self.win = [11] * self.ndims
88
  self.padding = [(w-1) // 2 for w in self.win]
89
 
90
  if smooth:
91
+ self.kernels = self._build_kernel(std=0.5)
92
  self.sum_filt = self._build_kernel(std=0.0)
93
 
94
  def _build_kernel(self, std=0.0):
 
153
  J_var = J2_sum
154
 
155
  # cc = (cross * cross) / (I_var * J_var + self.eps)
156
+ cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps) # eps must be large enough to avoid numerical unstability
157
  if label is not None:
158
  label = label.float()
159
  cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
 
164
  return -self.lncc(I*self.scale, J*self.scale, label=label)
165
 
166
 
167
class MSLNCC(LNCC):
    """
    Multi-Scale Local Normalized Cross-Correlation (MSLNCC)
    Computes LNCC at multiple scales and combines them as a weighted average.
    Images are downsampled via average pooling, labels via max pooling.

    Args:
        win, num_ch, eps, central, smooth: forwarded unchanged to ``LNCC``.
        scale_ratios: resolution ratio per scale (``>= 1`` means full resolution).
        scale_weights: weight per scale; must match ``scale_ratios`` in length.

    Raises:
        ValueError: if the two lists are empty, differ in length, or a ratio
            is not positive.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
                 scale_ratios=[1, 0.5, 0.25], scale_weights=[0.25, 0.5, 0.75]):
        super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
                                     central=central, smooth=smooth)
        # The window is whatever LNCC resolved from ``win``. A former
        # ``win = [9] * self.ndims`` fallback placed here only rebound a local
        # name after the parent had already been built, so it had no effect
        # and has been removed.

        # Copy so instances never share (or mutate) the default lists.
        scale_ratios = list(scale_ratios)
        scale_weights = list(scale_weights)
        if not scale_ratios or len(scale_ratios) != len(scale_weights):
            # zip() in forward() would otherwise drop the extra entries silently.
            raise ValueError(
                "scale_ratios and scale_weights must be non-empty and of equal length")
        if any(r <= 0 for r in scale_ratios):
            raise ValueError("scale_ratios must be positive")
        # NOTE(review): the MSLNCC in Diffusion/losses_opt.py defaults to the
        # reversed weights [0.75, 0.5, 0.25] — confirm which order is intended.
        self.scale_ratios = scale_ratios
        self.scale_weights = scale_weights

    def _downsample(self, I, J, label, ratio):
        """Downsample images via average pooling, labels via max pooling."""
        if ratio >= 1.0:
            return I, J, label
        # Integer pooling factor; the fractional part of 1/ratio is truncated.
        factor = int(1.0 / ratio)
        I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
        J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
        label_down = None
        if label is not None:
            label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
        return I_down, J_down, label_down

    def forward(self, I, J, label=None):
        """Return the negated weighted average of ``lncc`` over all scales."""
        total_loss = 0.0
        total_weight = 0.0
        for ratio, weight in zip(self.scale_ratios, self.scale_weights):
            I_s, J_s, label_s = self._downsample(I, J, label, ratio)
            total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
            total_weight += weight
        return -total_loss / total_weight
203
+
204
 
205
  class NCC(torch.nn.Module):
206
  # def __init__(self, eps_scale=10e-7,img_sz=256):
 
273
  N-D gradient loss
274
  """
275
 
276
+ def __init__(self, penalty=['l1'],ndims=3, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=1e4, apear_scale=8, dist=1, sign=1,waive_thresh=10**-4):
277
  super(Grad, self).__init__()
278
  self.penalty = penalty
279
  self.eps = eps
 
558
  img3d_t = torch.empty(1,1,size,size,size).uniform_(0,1)#*-0.000001
559
  # img3d_t = img3d.clone().detach()
560
  # img3d_t = torch.zeros_like(img3d)
561
+ translation = 16
562
  start = 0
563
  end = 32
564
  # img3d_t[:,:,translation:,translation:,translation:] = img3d[:,:,:size-translation,:size-translation,:size-translation]
Diffusion/losses_opt.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ losses_opt.py — Optimized loss functions.
3
+
4
+ Inherits from Diffusion.losses and overrides LNCC and MSLNCC to use
5
+ register_buffer for convolution kernels (auto device transfer, no
6
+ per-call .to(device) overhead).
7
+
8
+ All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) are re-exported
9
+ unchanged.
10
+ """
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ # Re-export unchanged classes
17
+ from Diffusion.losses import (
18
+ LMSE,
19
+ NCC,
20
+ MRSE,
21
+ RMSE,
22
+ Grad,
23
+ avg_std_skew_kurt,
24
+ grad_std,
25
+ avg_std,
26
+ EPS,
27
+ eps_scale,
28
+ )
29
+
30
+
31
class LNCC(torch.nn.Module):
    """
    Local (windowed) normalized cross-correlation loss for 5-D inputs.

    The smoothing kernel and the box (mean) filter are registered as module
    buffers, so ``module.to(device)`` moves them together with the module and
    ``lncc`` never has to transfer them itself.

    Args:
        win: Window size per spatial axis; ``[11, 11, 11]`` when None.
        num_ch: Accepted for signature compatibility; not used here.
        eps: Added to each local variance and to the label sum.
        central: Subtract local means (covariance form) when True.
        smooth: Pre-filter both inputs with a Gaussian kernel when True.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
        super().__init__()
        self.scale = 2e0
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.eps = eps
        self.central = central
        self.smooth = smooth

        self.win = [11] * self.ndims if win is None else win
        self.padding = [(extent - 1) // 2 for extent in self.win]

        if smooth:
            self.tail = None  # filled in by _build_kernel for std > 0
            gauss = self._build_kernel(std=0.5)
            self.register_buffer('kernels', gauss)
        box = self._build_kernel(std=0.0)
        self.register_buffer('sum_filt', box)

    def _build_kernel(self, std=0.0):
        """Return a ``[1, 1, D, H, W]`` filter.

        ``std == 0`` gives a normalized box filter of size ``self.win``.
        Otherwise a separable Gaussian with half-width ``2 * ceil(std)`` is
        built and that half-width is stored in ``self.tail``.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win]) / np.prod(self.win)

        self.tail = int(np.ceil(std)) * 2
        offsets = torch.arange(-self.tail, self.tail + 1, dtype=torch.float32)
        profile = torch.exp(-0.5 * (offsets ** 2) / std ** 2)
        profile = profile / torch.sum(profile)
        # Outer product of the 1-D profile along the three spatial axes.
        volume = profile.view(-1, 1, 1) * profile.view(1, -1, 1) * profile.view(1, 1, -1)
        return volume[None, None]

    def lncc(self, I, J, label=None):
        """Mean squared local correlation of ``I`` and ``J``.

        When ``label`` is given, the map is averaged with ``label`` as weights
        over dims (2, 3, 4) before the final mean.
        """
        if self.smooth:
            I = F.conv3d(I, self.kernels, stride=1, padding=self.tail)
            J = F.conv3d(J, self.kernels, stride=1, padding=self.tail)

        def local_mean(volume):
            return F.conv3d(volume, self.sum_filt, stride=1, padding=self.padding)

        I2_sum = local_mean(I * I)
        J2_sum = local_mean(J * J)
        IJ_sum = local_mean(I * J)

        if self.central:
            I_sum = local_mean(I)
            J_sum = local_mean(J)
            cross = IJ_sum - (I_sum * J_sum)
            I_var = I2_sum - (I_sum * I_sum)
            J_var = J2_sum - (J_sum * J_sum)
        else:
            cross, I_var, J_var = IJ_sum, I2_sum, J2_sum

        # eps is added to each variance separately, not to their product.
        cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the negated ``lncc`` of the inputs scaled by ``self.scale``."""
        return -self.lncc(I * self.scale, J * self.scale, label=label)
106
+
107
+
108
class MSLNCC(LNCC):
    """
    Multi-Scale Local Normalized Cross-Correlation (MSLNCC).
    Optimized: inherits buffer-based kernels from LNCC.

    Evaluates ``lncc`` at several resolutions and returns the negated
    weighted average. Images are average-pooled, labels max-pooled.

    Args:
        win, num_ch, eps, central, smooth: forwarded unchanged to ``LNCC``.
        scale_ratios: resolution ratio per scale (``>= 1`` means full resolution).
        scale_weights: weight per scale; must match ``scale_ratios`` in length.

    Raises:
        ValueError: if the two lists are empty, differ in length, or a ratio
            is not positive.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
                 scale_ratios=[1, 0.5, 0.25], scale_weights=[0.75, 0.5, 0.25]):
        super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
                                     central=central, smooth=smooth)
        # The window is whatever LNCC resolved from ``win``. A former
        # ``win = [9] * self.ndims`` fallback placed here only rebound a local
        # name after the parent had already built its buffers, so it had no
        # effect and has been removed.

        # Copy so instances never share (or mutate) the default lists.
        scale_ratios = list(scale_ratios)
        scale_weights = list(scale_weights)
        if not scale_ratios or len(scale_ratios) != len(scale_weights):
            # zip() in forward() would otherwise drop the extra entries silently.
            raise ValueError(
                "scale_ratios and scale_weights must be non-empty and of equal length")
        if any(r <= 0 for r in scale_ratios):
            raise ValueError("scale_ratios must be positive")
        # NOTE(review): the MSLNCC in Diffusion/losses.py defaults to the
        # reversed weights [0.25, 0.5, 0.75] — confirm which order is intended.
        self.scale_ratios = scale_ratios
        self.scale_weights = scale_weights

    def _downsample(self, I, J, label, ratio):
        """Pool ``I``/``J`` (average) and ``label`` (max) by ``int(1 / ratio)``."""
        if ratio >= 1.0:
            return I, J, label
        # Integer pooling factor; the fractional part of 1/ratio is truncated.
        factor = int(1.0 / ratio)
        I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
        J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
        label_down = None
        if label is not None:
            label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
        return I_down, J_down, label_down

    def forward(self, I, J, label=None):
        """Return the negated weighted average of ``lncc`` over all scales."""
        total_loss = 0.0
        total_weight = 0.0
        for ratio, weight in zip(self.scale_ratios, self.scale_weights):
            I_s, J_s, label_s = self._downsample(I, J, label, ratio)
            total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
            total_weight += weight
        return -total_loss / total_weight
Diffusion/networks.py CHANGED
@@ -1,8 +1,28 @@
1
  from torch import nn
2
  import torch
3
  import torch.nn.functional as F
 
4
  import numpy as np
5
  import math
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def get_net(name="recresnet"):
8
  name = name.lower()
@@ -16,8 +36,10 @@ def get_net(name="recresnet"):
16
  net = RecMutAttnNet1
17
  elif name == "defrecmutattnnet":
18
  net = DefRec_MutAttnNet
19
- elif name == "recmutattnnet_contrastive":
20
- net = RecMutAttnNet_contrastive
 
 
21
  else:
22
  net = None
23
  return net
@@ -440,6 +462,7 @@ class DefRec_MutAttnNet(nn.Module):
440
  nn.Linear(dim_out, dim_out)
441
  )
442
 
 
443
  class RecMutAttnNet1(nn.Module):
444
  def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
445
  super(RecMutAttnNet1, self).__init__()
@@ -749,6 +772,8 @@ class RecMutAttnNet(nn.Module):
749
  else:
750
  ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
751
  img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
 
 
752
 
753
  return ddf
754
 
@@ -759,9 +784,9 @@ class RecMutAttnNet(nn.Module):
759
  nn.Linear(dim_out, dim_out)
760
  )
761
 
762
- class RecMutAttnNet_contrastive(nn.Module):
763
  def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
764
- super(RecMutAttnNet_contrastive, self).__init__()
765
 
766
  # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
767
  self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
@@ -785,16 +810,21 @@ class RecMutAttnNet_contrastive(nn.Module):
785
  self.block_down = nn.ModuleList()
786
  self.block_up = nn.ModuleList()
787
  if self.conditional_input:
 
 
788
  self.block_down_cond = nn.ModuleList()
789
  self.fuse_conv0 = nn.ModuleList()
790
  self.fuse_conv1 = nn.ModuleList()
791
- self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
 
792
  Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
793
  self.global_maxpool = Global_Maxpool(1)
794
  self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
795
  self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
796
  self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
797
- self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
 
 
798
  self.img_res = [res]*self.dimension
799
  self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
800
  [1, self.dimension]+list(self.img_res))
@@ -811,6 +841,11 @@ class RecMutAttnNet_contrastive(nn.Module):
811
  AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
812
  ))
813
  if self.conditional_input:
 
 
 
 
 
814
  self.block_down_cond.append(nn.Sequential(
815
  AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
816
  AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
@@ -829,12 +864,14 @@ class RecMutAttnNet_contrastive(nn.Module):
829
  ))
830
 
831
  # Bottleneck
 
832
  self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
833
  self.b_mid = nn.Sequential(
834
  AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
835
  AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
836
  AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
837
  )
 
838
 
839
  self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
840
 
@@ -860,6 +897,7 @@ class RecMutAttnNet_contrastive(nn.Module):
860
  self.max_sz = [img_sz[0]] * self.dimension
861
  ts_emb_shape=[n,-1]+[1]*self.dimension
862
 
 
863
  self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
864
  if list(img_sz) != self.img_res:
865
  # print ("Reinitialize the ref_grid to match the model's input image size.")
@@ -870,6 +908,13 @@ class RecMutAttnNet_contrastive(nn.Module):
870
 
871
  img = x
872
  t = self.time_embed(t)
 
 
 
 
 
 
 
873
 
874
  for rec_id in range(rec_num):
875
  if self.conditional_input:
@@ -879,7 +924,7 @@ class RecMutAttnNet_contrastive(nn.Module):
879
  for i in range(self.hier_num):
880
  out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
881
  if self.conditional_input:
882
- tgt = self.block_down_cond[i](tgt)
883
  out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
884
  tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
885
  enc_list.append(out)
@@ -893,19 +938,24 @@ class RecMutAttnNet_contrastive(nn.Module):
893
  # out += self.attn_layer(out, tgt, tgt)[0]
894
  out_shape = out.shape
895
  tgt_shape = tgt.shape
896
- # out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
897
- tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
898
- out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt)
 
899
  out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
 
900
  out = out + out_attn
 
 
901
 
902
  if self.conditional_input:
903
- if text is None:
904
- text = self.text
905
- text = text.to(self.device)
906
- text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))
907
- img_embd = self.global_maxpool(self.img2txt(out)).view(n, -1) # [B, 1024]
908
- out_txt = self.img2txt(out) + text
 
909
  out_txt = self.txt_proc(out_txt)
910
  out_txt = self.txt2img(out_txt)
911
  out = out + out_txt
@@ -922,8 +972,264 @@ class RecMutAttnNet_contrastive(nn.Module):
922
  else:
923
  ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
924
  img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
925
 
926
- return ddf, img_embd
927
 
928
  def _make_te(self, dim_in, dim_out):
929
  return nn.Sequential(
@@ -931,6 +1237,8 @@ class RecMutAttnNet_contrastive(nn.Module):
931
  nn.ReLU(),
932
  nn.Linear(dim_out, dim_out)
933
  )
 
 
934
  # class RecMutAttnNet(nn.Module):
935
  # def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
936
  # super(RecMutAttnNet, self).__init__()
@@ -1085,6 +1393,8 @@ def composite(ddfs,stn=None):
1085
  comp_ddf = ddfs[i] + stn(comp_ddf,ddfs[i])
1086
  return comp_ddf
1087
 
 
 
1088
  class STN(nn.Module):
1089
  def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None):
1090
  super(STN, self).__init__()
@@ -1148,6 +1458,7 @@ class STN(nn.Module):
1148
  resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
1149
  return resampled_x
1150
 
 
1151
  if __name__ == '__main__':
1152
  ndims = 3
1153
  res = 128
 
1
  from torch import nn
2
  import torch
3
  import torch.nn.functional as F
4
+ from torch.utils.checkpoint import checkpoint as grad_checkpoint
5
  import numpy as np
6
  import math
7
+ from Diffusion.safe_conv_transpose import SafeConvTranspose3d
8
+
9
class UpsampleConv(nn.Module):
    """Interpolate-then-convolve upsampler with a ConvTranspose-like signature.

    The input is resized by ``stride`` with ``F.interpolate`` and then passed
    through a kernel-3, stride-1, padding-1 convolution.  According to the
    original author's note this exists to avoid a memory leak in the
    ``ConvTranspose3d`` backward pass on Intel XPU (oneDNN), and it also
    sidesteps the checkerboard artifacts transposed convolutions can produce.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Accepted for signature compatibility only; the
            convolution always uses a kernel of 3.
        stride: Upsampling factor handed to ``F.interpolate``.
        padding: Accepted for signature compatibility only; the convolution
            always uses a padding of 1.
        ndims: Number of spatial dimensions (1, 2 or 3).
    """

    # Interpolation mode matching each supported spatial rank.
    _INTERP_MODES = {1: 'linear', 2: 'bilinear', 3: 'trilinear'}

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, ndims=3):
        super().__init__()
        # Resolve the conv class first so an unsupported ``ndims`` still
        # surfaces as the AttributeError callers saw before.
        Conv = getattr(nn, f'Conv{ndims}d')
        self.scale_factor = stride
        # Previously every non-3-D case used 'bilinear', which rejects the
        # 3-D tensors a Conv1d pipeline produces.
        self.mode = self._INTERP_MODES.get(ndims, 'bilinear')
        self.conv = Conv(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor`` and apply the convolution."""
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
        return self.conv(x)
25
+
26
 
27
  def get_net(name="recresnet"):
28
  name = name.lower()
 
36
  net = RecMutAttnNet1
37
  elif name == "defrecmutattnnet":
38
  net = DefRec_MutAttnNet
39
+ elif name == "recmulmodmutattnnet":
40
+ net = RecMulModMutAttnNet
41
+ elif name == "om_net":
42
+ net = OM_net
43
  else:
44
  net = None
45
  return net
 
462
  nn.Linear(dim_out, dim_out)
463
  )
464
 
465
+
466
  class RecMutAttnNet1(nn.Module):
467
  def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
468
  super(RecMutAttnNet1, self).__init__()
 
772
  else:
773
  ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
774
  img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
775
+
776
+ # print(torch.max(torch.abs(ddf)))
777
 
778
  return ddf
779
 
 
784
  nn.Linear(dim_out, dim_out)
785
  )
786
 
787
+ class RecMulModMutAttnNet(nn.Module):
788
  def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
789
+ super(RecMulModMutAttnNet, self).__init__()
790
 
791
  # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
792
  self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
 
810
  self.block_down = nn.ModuleList()
811
  self.block_up = nn.ModuleList()
812
  if self.conditional_input:
813
+ # self.gate_img = nn.ModuleList()
814
+ self.txt_layers = nn.ModuleList()
815
  self.block_down_cond = nn.ModuleList()
816
  self.fuse_conv0 = nn.ModuleList()
817
  self.fuse_conv1 = nn.ModuleList()
818
+ self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
819
+ self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
820
  Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
821
  self.global_maxpool = Global_Maxpool(1)
822
  self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
823
  self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
824
  self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
825
+ # self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
826
+ self.text = torch.zeros(1, self.text_feat_chn)
827
+
828
  self.img_res = [res]*self.dimension
829
  self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
830
  [1, self.dimension]+list(self.img_res))
 
841
  AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
842
  ))
843
  if self.conditional_input:
844
+ # self.gate_img.append(nn.Sequential(
845
+ # nn.ConvNd(self.dimension, self.feat_channels[i], self.feat_channels[i], kernel_size=1, stride=1, padding=0),
846
+ # nn.Sigmoid()
847
+ # ))
848
+ self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i])))
849
  self.block_down_cond.append(nn.Sequential(
850
  AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
851
  AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
 
864
  ))
865
 
866
  # Bottleneck
867
+ self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn)))
868
  self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
869
  self.b_mid = nn.Sequential(
870
  AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
871
  AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
872
  AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
873
  )
874
+ self.fuse = self.Conv(2*self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0)
875
 
876
  self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
877
 
 
897
  self.max_sz = [img_sz[0]] * self.dimension
898
  ts_emb_shape=[n,-1]+[1]*self.dimension
899
 
900
+
901
  self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
902
  if list(img_sz) != self.img_res:
903
  # print ("Reinitialize the ref_grid to match the model's input image size.")
 
908
 
909
  img = x
910
  t = self.time_embed(t)
911
+ if text is None:
912
+ text = self.text
913
+ # print(text.shape)
914
+ text = text.to(self.device)
915
+ txt_shape = [1,-1]+[1]*self.dimension
916
+ else:
917
+ txt_shape = [n,-1]+[1]*self.dimension
918
 
919
  for rec_id in range(rec_num):
920
  if self.conditional_input:
 
924
  for i in range(self.hier_num):
925
  out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
926
  if self.conditional_input:
927
+ tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
928
  out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
929
  tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
930
  enc_list.append(out)
 
938
  # out += self.attn_layer(out, tgt, tgt)[0]
939
  out_shape = out.shape
940
  tgt_shape = tgt.shape
941
+ out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
942
+ tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
943
+ out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
944
+ tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
945
  out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
946
+ tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape) # (H*W, N, C) -> (N, C, H, W)
947
  out = out + out_attn
948
+ tgt = tgt + tgt_attn
949
+ out = self.fuse(torch.cat([out, tgt], dim=1))
950
 
951
  if self.conditional_input:
952
+
953
+ # text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))
954
+
955
+ # out_txt = self.img2txt(out) + text.reshape(txt_shape)
956
+ img_txt_feat = self.img2txt(out)
957
+ self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024]
958
+ out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
959
  out_txt = self.txt_proc(out_txt)
960
  out_txt = self.txt2img(out_txt)
961
  out = out + out_txt
 
972
  else:
973
  ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
974
  img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
975
+
976
+ # print(torch.max(torch.abs(ddf)))
977
+
978
+ return ddf
979
+
980
+ def _make_te(self, dim_in, dim_out):
981
+ return nn.Sequential(
982
+ nn.Linear(dim_in, dim_out),
983
+ nn.ReLU(),
984
+ nn.Linear(dim_out, dim_out)
985
+ )
986
+
987
+
988
+ class OM_net(nn.Module):
989
+ """
990
+ Extended RecMulModMutAttnNet with gated attention mechanisms:
991
+ 1. Text Gate (bottleneck): sigmoid weight w_txt to interpolate between
992
+ text-enhanced features and raw image features. Learns to suppress
993
+ text branch when text embedding is zeros (no text provided).
994
+ 2. Target Gate (each encoder level): per-voxel spatial gate using
995
+ residual AtrousBlock to identify condition vs. noise voxels in the
996
+ target/condition image path, weighting the fuse_conv1 output.
997
+
998
+ Supports gradient checkpointing via `use_checkpoint` flag to reduce
999
+ peak activation memory (trades compute for memory).
1000
+ """
1001
+ def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0,
1002
+ conditional_input=True, text_feat_chn=1024, num_heads=4,
1003
+ use_conv_transpose=False):
1004
+ super(OM_net, self).__init__()
1005
+ self.use_checkpoint = False # Set True to enable gradient checkpointing
1006
+ self.use_conv_transpose = use_conv_transpose
1007
+
1008
+ self.feat_channels = [num_input_chn, 12, 32, 64, 128, 512]
1009
+ self.conditional_input = conditional_input
1010
+ self.num_heads = num_heads
1011
+ self.text_feat_chn = text_feat_chn
1012
+
1013
+ self.dimension = ndims
1014
+ self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
1015
+ self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
1016
+
1017
+ # Sinusoidal embedding
1018
+ self.time_embed = nn.Embedding(n_steps, time_emb_dim)
1019
+ self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
1020
+ self.time_embed.requires_grad_(False)
1021
+ self.hier_num = len(self.feat_channels) - 1
1022
+ self.down_layers = nn.ModuleList()
1023
+ self.up_layers = nn.ModuleList()
1024
+ self.ted_layers = nn.ModuleList()
1025
+ self.teu_layers = nn.ModuleList()
1026
+ self.block_down = nn.ModuleList()
1027
+ self.block_up = nn.ModuleList()
1028
+ if self.conditional_input:
1029
+ self.txt_layers = nn.ModuleList()
1030
+ self.block_down_cond = nn.ModuleList()
1031
+ self.fuse_conv0 = nn.ModuleList()
1032
+ self.fuse_conv1 = nn.ModuleList()
1033
+ self.tgt_gate = nn.ModuleList() # Target gate per encoder level
1034
+ self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
1035
+ self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
1036
+ Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
1037
+ self.global_maxpool = Global_Maxpool(1)
1038
+ self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
1039
+ self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
1040
+ self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
1041
+ self.text = torch.zeros(1, self.text_feat_chn)
1042
+
1043
+ # Text Gate: text-only MLP → sigmoid weight (computed before rec loop)
1044
+ self.text_gate = nn.Sequential(
1045
+ nn.Linear(self.text_feat_chn, self.text_feat_chn // 4),
1046
+ nn.ReLU(),
1047
+ nn.Linear(self.text_feat_chn // 4, 1),
1048
+ nn.Sigmoid()
1049
+ )
1050
+
1051
+ self.img_res = [res]*self.dimension
1052
+ self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
1053
+ [1, self.dimension]+list(self.img_res))
1054
+
1055
+ for i in range(1, self.hier_num + 1):
1056
+ j=-i
1057
+ self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
1058
+ self.up_layers.append(SafeConvTranspose3d(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
1059
+ self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
1060
+ self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
1061
+ self.block_down.append(nn.Sequential(
1062
+ AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
1063
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
1064
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
1065
+ ))
1066
+ if self.conditional_input:
1067
+ self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i])))
1068
+ self.block_down_cond.append(nn.Sequential(
1069
+ AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
1070
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
1071
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
1072
+ ))
1073
+ self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
1074
+ self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
1075
+ # Target Gate: residual AtrousBlock → 2-channel softmax (condition vs noise)
1076
+ self.tgt_gate.append(nn.Sequential(
1077
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims,
1078
+ self.feat_channels[i], self.feat_channels[i], ndims=ndims, atrous_rates=[1, 3]),
1079
+ self.Conv(self.feat_channels[i], 2, 1, 1, 0)
1080
+ ))
1081
+ if i==self.hier_num:
1082
+ k=j
1083
+ else:
1084
+ k=j-1
1085
+ self.block_up.append(nn.Sequential(
1086
+ AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
1087
+ AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
1088
+ AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
1089
+ ))
1090
+
1091
+ # Bottleneck
1092
+ self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn)))
1093
+ self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
1094
+ self.b_mid = nn.Sequential(
1095
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
1096
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
1097
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
1098
+ )
1099
+ self.fuse = self.Conv(2*self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0)
1100
+
1101
+ self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
1102
+
1103
+ # Initialize target gates toward pass-through (condition confidence high)
1104
+ self._init_tgt_gates()
1105
+
1106
+ def _init_tgt_gates(self):
1107
+ """Bias target gates so condition channel starts moderately high (~0.73).
1108
+ Milder than [2,-2] to ensure both cond*tgt and (1-cond)*out halves of
1109
+ fuse_conv1 input have enough signal for healthy early gradient flow."""
1110
+ for gate_seq in self.tgt_gate:
1111
+ final_conv = gate_seq[-1] # the Conv that outputs 2 channels
1112
+ with torch.no_grad():
1113
+ final_conv.bias.data[0] = 1.0 # condition channel → softmax ~0.73
1114
+ final_conv.bias.data[1] = -1.0 # noise channel → softmax ~0.27
1115
+
1116
+ def _encoder_level(self, i, out, tgt, t, ts_emb_shape, text, txt_shape, w_txt):
1117
+ """Single encoder level — extracted for gradient checkpointing."""
1118
+ out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
1119
+ if self.conditional_input and tgt is not None:
1120
+ tgt = self.block_down_cond[i](tgt) + w_txt * self.txt_layers[i](text).reshape(txt_shape)
1121
+ gate_logits = self.tgt_gate[i](tgt)
1122
+ cond_confidence = F.softmax(gate_logits, dim=1)[:, 0:1]
1123
+ tgt = self.fuse_conv1[i](torch.cat([cond_confidence*tgt, (1-cond_confidence)*out], axis=1))
1124
+ out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
1125
+ return out, tgt
1126
+
1127
+ def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
1128
+ sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
1129
+ return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
1130
+ zip(sample_coords, max_sz)], 1)
1131
+
1132
+ def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
1133
+ ref = self.ref_grid if ref is None else ref
1134
+ img_sz = self.max_sz if img_sz is None else img_sz
1135
+ resample_mode = 'bilinear'
1136
+
1137
+ return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
1138
+ np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
1139
+ [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
1140
+ align_corners=True)
1141
+
1142
+ def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
1143
+ self.device = x.device
1144
+ img_sz = x.size()[2:]
1145
+ n = x.size()[0]
1146
+ self.max_sz = [img_sz[0]] * self.dimension
1147
+ ts_emb_shape=[n,-1]+[1]*self.dimension
1148
+
1149
+ self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
1150
+ if list(img_sz) != self.img_res:
1151
+ self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
1152
+ [1, self.dimension]+list(img_sz))
1153
+ self.ref_grid = self.ref_grid.to(self.device)
1154
+
1155
+ img = x
1156
+ t = self.time_embed(t)
1157
+ if text is None:
1158
+ text = self.text
1159
+ text = text.to(self.device)
1160
+ txt_shape = [1,-1]+[1]*self.dimension
1161
+ else:
1162
+ txt_shape = [n,-1]+[1]*self.dimension
1163
+
1164
+ # Text Gate: compute w_txt from text embedding alone before rec loop
1165
+ txt_vec = text.view(text.size(0), -1) # [1, 1024] or [n, 1024]
1166
+ if txt_vec.size(0) == 1 and n > 1:
1167
+ txt_vec = txt_vec.expand(n, -1)
1168
+ w_txt = self.text_gate(txt_vec) # [B, 1]
1169
+ w_txt = w_txt.view([w_txt.size(0), 1] + [1] * self.dimension)
1170
+
1171
+ for rec_id in range(rec_num):
1172
+ if self.conditional_input:
1173
+ tgt = y
1174
+ enc_list = []
1175
+ out = img
1176
+ for i in range(self.hier_num):
1177
+ # Gradient checkpointing on early encoder levels (large feature maps)
1178
+ # to reduce peak activation memory. Levels 0-2 have 128^3, 64^3, 32^3 maps.
1179
+ if self.use_checkpoint and self.training and i < 3:
1180
+ out, tgt = grad_checkpoint(
1181
+ self._encoder_level, i, out, tgt if self.conditional_input else None,
1182
+ t, ts_emb_shape, text, txt_shape, w_txt,
1183
+ use_reentrant=False,
1184
+ )
1185
+ else:
1186
+ out, tgt = self._encoder_level(
1187
+ i, out, tgt if self.conditional_input else None,
1188
+ t, ts_emb_shape, text, txt_shape, w_txt,
1189
+ )
1190
+ enc_list.append(out)
1191
+ out = self.down_layers[i](out)
1192
+ if self.conditional_input:
1193
+ tgt = self.down_layers[i](tgt)
1194
+
1195
+ out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
1196
+ if self.conditional_input:
1197
+ out_shape = out.shape
1198
+ tgt_shape = tgt.shape
1199
+ out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
1200
+ tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
1201
+ out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
1202
+ tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
1203
+ out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
1204
+ tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
1205
+ out = out + out_attn
1206
+ tgt = tgt + tgt_attn
1207
+ out = self.fuse(torch.cat([out, tgt], dim=1))
1208
+
1209
+ if self.conditional_input:
1210
+ img_txt_feat = self.img2txt(out)
1211
+ self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024]
1212
+ out_txt = self.txt_layers[-1](text).reshape(txt_shape) - img_txt_feat
1213
+ out_txt = self.txt_proc(out_txt)
1214
+ out_txt = self.txt2img(out_txt)
1215
+
1216
+ # Text Gate: w_txt precomputed from text embedding alone
1217
+ out = (1 - w_txt) * out + w_txt * out_txt
1218
+
1219
+ for i in range(self.hier_num):
1220
+ out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
1221
+ out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
1222
+
1223
+ out = self.conv_out(out)/128
1224
+
1225
+ ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
1226
+ if rec_id == 0:
1227
+ ddf = ddf_one
1228
+ else:
1229
+ ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
1230
+ img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
1231
 
1232
+ return ddf
1233
 
1234
  def _make_te(self, dim_in, dim_out):
1235
  return nn.Sequential(
 
1237
  nn.ReLU(),
1238
  nn.Linear(dim_out, dim_out)
1239
  )
1240
+
1241
+
1242
  # class RecMutAttnNet(nn.Module):
1243
  # def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
1244
  # super(RecMutAttnNet, self).__init__()
 
1393
  comp_ddf = ddfs[i] + stn(comp_ddf,ddfs[i])
1394
  return comp_ddf
1395
 
1396
+
1397
+
1398
  class STN(nn.Module):
1399
  def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None):
1400
  super(STN, self).__init__()
 
1458
  resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
1459
  return resampled_x
1460
 
1461
+
1462
  if __name__ == '__main__':
1463
  ndims = 3
1464
  res = 128
Diffusion/networks0.py ADDED
@@ -0,0 +1,1195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import math
6
+
7
def get_net(name="recresnet"):
    """Return the network class registered under ``name`` (case-insensitive).

    Recognised names are ``recresacnet``, ``recmutattnnet``,
    ``recmutattnnet0``, ``recmutattnnet1``, ``defrecmutattnnet`` and
    ``recmulmodmutattnnet``.  Any other name -- including the default
    ``"recresnet"`` -- yields ``None``.
    """
    key = name.lower()
    # Early returns keep each class name resolved only when it is requested.
    if key == "recresacnet":
        return RecResACNet
    if key == "recmutattnnet":
        return RecMutAttnNet
    if key == "recmutattnnet0":
        return RecMutAttnNet0
    if key == "recmutattnnet1":
        return RecMutAttnNet1
    if key == "defrecmutattnnet":
        return DefRec_MutAttnNet
    if key == "recmulmodmutattnnet":
        return RecMulModMutAttnNet
    return None
24
+
25
+
26
+
27
def sinusoidal_embedding(n, d):
    """Build an ``(n, d)`` table of sinusoidal position embeddings.

    Row ``p`` holds ``sin(p * w)`` in the even columns and ``cos(p * w)`` in
    the odd columns, where the frequencies ``w`` are the even-indexed entries
    of ``1 / 10000 ** (2 * j / d)`` for ``j`` in ``range(d)``.

    Args:
        n: Number of positions (rows).
        d: Embedding width (columns); may be odd.

    Returns:
        A float tensor of shape ``(n, d)``.
    """
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    freqs = wk[:, ::2]
    embedding[:, ::2] = torch.sin(t * freqs)
    # For odd ``d`` there is one fewer cosine column than sine column, so trim
    # the frequencies; for even ``d`` the slice is a no-op.
    embedding[:, 1::2] = torch.cos(t * freqs[:, :d // 2])
    return embedding
36
+
37
class AtrousBlock(nn.Module):
    """Residual block made of a chain of dilated ("atrous") convolutions.

    An optional projection convolution first maps ``in_c`` to ``out_c``
    channels, followed by optional instance normalisation.  The result is then
    passed through one activation + convolution pair per entry of
    ``atrous_rates`` and added back before a final activation.

    Args:
        shape: Not used by this implementation; kept for call compatibility.
        in_c: Input channels.
        out_c: Output channels.
        kernel_size: Kernel size of the projection and dilated convolutions.
        stride: Not used by this implementation; every convolution has stride 1.
        atrous_rates: Dilation of each convolution in the chain; a rate of 0
            selects a 1x1 convolution instead.
        ndims: Number of spatial dimensions; selects ``nn.Conv{n}d`` and
            ``nn.InstanceNorm{n}d``.
        activation: Activation module; defaults to ``nn.LeakyReLU(1e-6)``.
        normalize: Whether to apply affine instance normalisation.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=(1, 3), ndims=2, activation=None, normalize=True):
        super(AtrousBlock, self).__init__()
        if normalize:
            norm = getattr(nn, 'InstanceNorm%dd' % ndims)
            self.ln = norm(out_c, affine=True)
        else:
            self.ln = nn.Identity()
        Conv = getattr(nn, 'Conv%dd' % ndims)
        half_kernel = (kernel_size - 1) // 2
        if in_c != out_c:
            # Channel projection so the residual sum has matching widths.
            self.conv0 = Conv(in_c, out_c, kernel_size, 1, half_kernel * 1)
        else:
            self.conv0 = None
        # Padding grows with the dilation so the spatial size is preserved
        # (for odd kernel sizes).
        self.convs = nn.ModuleList([
            Conv(out_c, out_c, kernel_size, 1, half_kernel * ar, dilation=ar)
            if ar > 0 else Conv(out_c, out_c, 1, 1, 0)
            for ar in atrous_rates
        ])
        self.activation = nn.LeakyReLU(1e-6) if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        """Apply projection, normalisation and the residual dilated chain."""
        if self.conv0 is not None:
            x = self.conv0(x)
        x = self.ln(x) if self.normalize else x
        out = x
        for conv in self.convs:
            out = conv(self.activation(out))
        return self.activation(out + x)
74
+
75
+ # ==============================================
76
+ # Unconditional Network
77
+ # ==============================================
78
+
79
+ class RecResACNet(nn.Module):
80
+ def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0):
81
+ super(RecResACNet, self).__init__()
82
+
83
+ self.dimension = ndims
84
+ self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
85
+ self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
86
+
87
+ # Sinusoidal embedding
88
+ self.time_embed = nn.Embedding(n_steps, time_emb_dim)
89
+ self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
90
+ self.time_embed.requires_grad_(False)
91
+
92
+ # First half
93
+ self.te1 = self._make_te(time_emb_dim, 1)
94
+ self.b1 = nn.Sequential(
95
+ AtrousBlock([num_input_chn] + [res] * ndims, num_input_chn, 10, ndims=ndims),
96
+ AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims),
97
+ AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims),
98
+
99
+ )
100
+ self.down1 = self.Conv(10, 10, 4, 2, 1)
101
+
102
+ self.te2 = self._make_te(time_emb_dim, 10)
103
+ self.b2 = nn.Sequential(
104
+ AtrousBlock([10] + [res // 2] * ndims, 10, 20, ndims=ndims),
105
+ AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims),
106
+ AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims)
107
+ )
108
+ self.down2 = self.Conv(20, 20, 4, 2, 1)
109
+
110
+ self.te3 = self._make_te(time_emb_dim, 20)
111
+ self.b3 = nn.Sequential(
112
+ AtrousBlock([20] + [res // 4] * ndims, 20, 40, ndims=ndims),
113
+ AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims),
114
+ AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims)
115
+ )
116
+ self.down3 = self.Conv(40, 40, 4, 2, 1)
117
+
118
+ # Bottleneck
119
+ self.te_mid = self._make_te(time_emb_dim, 40)
120
+ self.b_mid = nn.Sequential(
121
+ AtrousBlock([40] + [res // 8] * ndims, 40, 20, ndims=ndims),
122
+ AtrousBlock([20] + [res // 8] * ndims, 20, 20, ndims=ndims),
123
+ AtrousBlock([20] + [res // 8] * ndims, 20, 40, ndims=ndims)
124
+ )
125
+
126
+ # Second half
127
+ self.up1 = self.ConvT(40, 40, 4, 2, 1)
128
+
129
+ self.te4 = self._make_te(time_emb_dim, 80)
130
+ self.b4 = nn.Sequential(
131
+ AtrousBlock([80] + [res // 4] * ndims, 80, 40, ndims=ndims, normalize=False),
132
+ AtrousBlock([40] + [res // 4] * ndims, 40, 20, ndims=ndims, normalize=False),
133
+ AtrousBlock([20] + [res // 4] * ndims, 20, 20, ndims=ndims, normalize=False)
134
+ )
135
+
136
+ self.up2 = self.ConvT(20, 20, 4, 2, 1)
137
+ self.te5 = self._make_te(time_emb_dim, 40)
138
+ self.b5 = nn.Sequential(
139
+ AtrousBlock([40] + [res // 2] * ndims, 40, 20, ndims=ndims, normalize=False),
140
+ AtrousBlock([20] + [res // 2] * ndims, 20, 10, ndims=ndims, normalize=False),
141
+ AtrousBlock([10] + [res // 2] * ndims, 10, 10, ndims=ndims, normalize=False)
142
+ )
143
+
144
+ self.up3 = self.ConvT(10, 10, 4, 2, 1)
145
+ self.te_out = self._make_te(time_emb_dim, 20)
146
+ self.b_out = nn.Sequential(
147
+ AtrousBlock([20] + [res // 1] * ndims, 20, 10, ndims=ndims, normalize=False),
148
+ AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False),
149
+ AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False)
150
+ )
151
+
152
+ self.conv_out = self.Conv(10, ndims, 3, 1, 1)
153
+
154
+ def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
155
+ sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
156
+ return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
157
+ zip(sample_coords, max_sz)], 1)
158
+
159
+ def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
160
+ ref = self.ref_grid if ref is None else ref
161
+ img_sz = self.max_sz if img_sz is None else img_sz
162
+ # resample_mode = 'bicubic'
163
+ resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
164
+ # padding_mode = "border"
165
+
166
+ if True:
167
+ # return F.grid_sample(vol, torch.flip(torch.transpose(ddf * torch.Tensor(np.reshape(np.array(self.max_sz), [1, 1, 1, self.dimension])).cuda() + ref,[0, 2, 3, 1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,align_corners=True)
168
+ return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
169
+ np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
170
+ [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
171
+ align_corners=True)
172
+
173
def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2):
    """Predict a displacement field for ``x`` over ``rec_num`` recursive passes.

    Each pass runs the encoder/decoder on the currently warped image,
    clamps the predicted step with ``boundary_limit`` and composes it with
    the field accumulated so far; the input ``x`` is then re-warped with
    the accumulated field for the next pass.

    ``x.size()[0]`` is used as the batch size and ``x.size()[2:]`` as the
    spatial size. Only the first spatial side is used for ``self.max_sz``.
    ``y`` and ``ndims`` are accepted but not used here.

    Returns the accumulated displacement field.
    """
    self.device = x.device
    spatial = x.size()[2:]
    batch = x.size()[0]
    self.max_sz = [spatial[0]] * self.dimension
    emb_shape = [batch, -1] + [1] * self.dimension

    # Half extents (side - 1) / 2, laid out channel-last for resample().
    half_sides = torch.tensor([(side - 1) / 2 for side in spatial], device=self.device)
    self.img_sz = torch.reshape(half_sides, [1] * (self.dimension + 1) + [self.dimension])

    # Index grid of the input, one channel per spatial axis.
    axes = torch.meshgrid([torch.arange(end=side) for side in spatial])
    self.ref_grid = torch.reshape(torch.stack(axes, 0),
                                  [1, self.dimension] + list(spatial)).to(self.device)

    t = self.time_embed(t)

    def emb(layer):
        # Project the time embedding and make it broadcastable over space.
        return layer(t).reshape(emb_shape)

    moving = x
    for rec_id in range(rec_num):
        # Encoder
        enc1 = self.b1(moving + emb(self.te1))
        enc2 = self.b2(self.down1(enc1) + emb(self.te2))
        enc3 = self.b3(self.down2(enc2) + emb(self.te3))

        # NOTE(review): the bottleneck multiplies by the time embedding
        # while every other level adds it -- confirm this is intended.
        mid = self.b_mid(self.down3(enc3) * emb(self.te_mid))

        # Decoder with skip connections
        dec = torch.cat((enc3, self.up1(mid)), dim=1)
        dec = self.b4(dec + emb(self.te4))

        dec = torch.cat((enc2, self.up2(dec)), dim=1)
        dec = self.b5(dec + emb(self.te5))

        dec = torch.cat((enc1, self.up3(dec)), dim=1)
        dec = self.b_out(dec + emb(self.te_out))

        step_ddf = self.boundary_limit(self.conv_out(dec), max_sz=list(self.max_sz))

        if rec_id == 0:
            ddf = step_ddf
        else:
            # Compose: warp the accumulated field by the new step, then add.
            warped_prev = self.resample(ddf, ddf=step_ddf, img_sz=self.img_sz,
                                        padding_mode="border")
            ddf = step_ddf + warped_prev
        moving = self.resample(x, ddf=ddf, img_sz=self.img_sz)

    return ddf
219
+
220
+ def _make_te(self, dim_in, dim_out):
221
+ # make time embedding
222
+
223
+ return nn.Sequential(
224
+ nn.Linear(dim_in, dim_out),
225
+ # nn.SiLU(),
226
+ nn.ReLU(),
227
+ nn.Linear(dim_out, dim_out)
228
+ )
229
+
230
+ # ==============================================
231
+ # Conditional Network
232
+ # ==============================================
233
+
234
class cross_attn(nn.Module):
    """Single-head cross attention built from caller-supplied projections.

    ``q`` is applied to the query input ``x``; ``k`` and ``v`` are applied to
    the context input ``y``. The attention weights are the softmax (over the
    last axis) of ``q(x) @ k(y)^T``; no scaling factor is applied.
    """

    def __init__(self, q, k, v, ndims=2):
        # nn.Module must be initialised before any sub-module or Parameter
        # is assigned, otherwise the assignments below raise AttributeError.
        super(cross_attn, self).__init__()
        self.q = q
        self.k = k
        self.v = v
        self.ndims = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.ndims)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims)
        self.softmax = nn.Softmax(dim=-1)
        # Registered but not used by forward(); kept so the attribute exists.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, y):
        """Return ``softmax(q(x) @ k(y)^T) @ v(y)``."""
        q = self.q(x)
        k = self.k(y)
        v = self.v(y)
        attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
        out = torch.matmul(attn, v)
        return out
252
+
253
class DefRec_MutAttnNet(nn.Module):
    """Recursive, time-conditioned U-Net that predicts a displacement field.

    The encoder/decoder is built from ``AtrousBlock`` stacks (defined
    elsewhere in this file) with strided-conv down-sampling and transposed
    conv up-sampling. When ``conditional_input`` is set, a second encoder
    processes the conditioning image ``y`` once per call; its per-level
    features are fused into the main encoder, its bottleneck feeds a
    multi-head attention layer, and a text feature is mixed in at the
    bottleneck.

    ``forward`` expects ``t`` to be an iterable of time-step tensors; the
    whole sequence is replayed ``rec_num`` times and each step's field is
    composed with the field accumulated so far.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
        super(DefRec_MutAttnNet, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        self.copy = nn.Identity()
        # Frozen sinusoidal time embedding table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        # NOTE(review): the extent of this conditional block was reconstructed
        # from usage in forward(); confirm the attention/text layers are meant
        # to exist only in the conditional configuration.
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Default (all-zero) text feature used when forward() gets none.
            self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
        self.img_res = [res]*self.dimension
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
                                      [1, self.dimension]+list(self.img_res))

        for i in range(1, self.hier_num + 1):
            j = -i
            c_prev = self.feat_channels[i-1]
            c_cur = self.feat_channels[i]
            c_dec = self.feat_channels[j]
            enc_side = res // (2 ** (i-1))
            # NOTE(review): for i == hier_num the exponent is -1, so this is
            # a float (res // 0.5) -- confirm AtrousBlock tolerates that.
            dec_side = res // (2 ** (self.hier_num-i-1))

            self.down_layers.append(self.Conv(c_cur, c_cur, 4, 2, 1))
            self.up_layers.append(self.ConvT(c_dec, c_dec, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, c_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2*c_dec))
            self.block_down.append(self._atrous_stack([c_prev, c_cur, c_cur, c_cur], enc_side, ndims))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack([c_prev, c_cur, c_cur, c_cur], enc_side, ndims))
                self.fuse_conv0.append(self.Conv(2*c_cur, c_cur, 1, 1, 0))
            # The last decoder stage keeps its width; index j-1 would wrap
            # around to the input channel count.
            k = j if i == self.hier_num else j-1
            self.block_up.append(self._atrous_stack(
                [2*c_dec, c_dec, c_dec, self.feat_channels[k]], dec_side, ndims, normalize=False))

        # Bottleneck
        c_mid = self.feat_channels[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = self._atrous_stack([c_mid, c_mid, c_mid, c_mid], res // (2**self.hier_num), ndims)

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    @staticmethod
    def _atrous_stack(widths, side, ndims, **block_kwargs):
        """Chain AtrousBlocks mapping widths[0] -> widths[1] -> ... at one spatial size."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [side] * ndims, c_in, c_out, ndims=ndims, **block_kwargs)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of ``sample_coords0`` (split on dim 1) after
        scaling by its ``max_sz`` entry, then scale back."""
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` (scaled by ``self.max_sz``) on top of the
        reference grid, using bilinear ``F.grid_sample`` with
        ``align_corners=True``."""
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'

        return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
            np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
            [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
            align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Return the displacement field accumulated over all steps.

        ``x`` is the image to deform, ``y`` the conditioning image (read only
        when ``conditional_input`` is set), ``t`` an iterable of time-step
        tensors and ``text`` an optional text feature that must broadcast
        against the bottleneck text projection. ``ndims`` is unused.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
        # Compare against the cached grid itself rather than the constructor
        # resolution: otherwise a call at a different size followed by a call
        # at the constructor size would reuse a wrongly sized grid.
        if list(self.ref_grid.shape[2:]) != list(img_sz):
            self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
                                          [1, self.dimension]+list(img_sz))
        self.ref_grid = self.ref_grid.to(self.device)

        img = x
        if self.conditional_input:
            # Encode the conditioning image once; it does not depend on the
            # time step or on the current warp.
            tgt = y
            tgt_down_list = []
            for i in range(self.hier_num):
                tgt = self.block_down_cond[i](tgt)
                tgt_down_list.append(self.copy(tgt))
                tgt = self.down_layers[i](tgt)
            tgt_mid = self.copy(tgt)
            tgt_shape = tgt_mid.shape
            # (N, C, *spatial) -> (L, N, C) for nn.MultiheadAttention
            tgt_mid = tgt_mid.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)

            if text is None:
                text = self.text
            text = text.to(self.device)

        t = [t0.to(self.device) for t0 in t]
        # Replay the whole time-step sequence rec_num times.
        t = [t0 for _ in range(rec_num) for t0 in t]
        for rec_id, step in enumerate(t):
            t_emb = self.time_embed(step)

            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
                if self.conditional_input:
                    out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)

            out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt_mid, tgt_mid)
                # (L, N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

                # Mix the text feature in at the bottleneck.
                out_txt = self.img2txt(out) + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i-1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape))

            out = self.conv_out(out)/128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new step, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Time-embedding projection: Linear -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
442
+
443
+
444
class RecMutAttnNet1(nn.Module):
    """Recursive, time-conditioned U-Net with mutual image/condition fusion.

    Built from ``AtrousBlock`` stacks (defined elsewhere in this file). When
    ``conditional_input`` is set, the conditioning image ``y`` is encoded
    alongside the main branch on every recursion; both branches exchange
    features through 1x1 fusion convolutions at each level, and the
    bottleneck attends to the conditioning features.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet1, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal time embedding table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)

        self.block_up = nn.ModuleList()

        for level in range(1, self.hier_num + 1):
            mirror = -level
            c_prev = self.feat_channels[level - 1]
            c_cur = self.feat_channels[level]
            c_dec = self.feat_channels[mirror]
            enc_side = res // (2 ** (level - 1))
            # NOTE(review): for level == hier_num the exponent is -1, so this
            # is a float (res // 0.5) -- confirm AtrousBlock tolerates that.
            dec_side = res // (2 ** (self.hier_num - level - 1))

            self.down_layers.append(self.Conv(c_cur, c_cur, 4, 2, 1))
            self.up_layers.append(self.ConvT(c_dec, c_dec, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, c_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * c_dec))
            self.block_down.append(self._atrous_stack([c_prev, c_cur, c_cur, c_cur], enc_side, ndims))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack([c_prev, c_cur, c_cur, c_cur], enc_side, ndims))
                self.fuse_conv0.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
            # The last decoder stage keeps its width; index mirror-1 would
            # wrap around to the input channel count.
            out_idx = mirror if level == self.hier_num else mirror - 1
            self.block_up.append(self._atrous_stack(
                [2 * c_dec, c_dec, c_dec, self.feat_channels[out_idx]], dec_side, ndims, normalize=False))

        # Bottleneck
        c_mid = self.feat_channels[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = self._atrous_stack([c_mid, c_mid, c_mid, c_mid], res // (2 ** self.hier_num), ndims)

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    @staticmethod
    def _atrous_stack(widths, side, ndims, **block_kwargs):
        """Chain AtrousBlocks mapping widths[0] -> widths[1] -> ... at one spatial size."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [side] * ndims, c_in, c_out, ndims=ndims, **block_kwargs)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of ``sample_coords0`` (split on dim 1) after
        scaling by its ``max_sz`` entry, then scale back."""
        limited = []
        for channel, extent in zip(sample_coords0.split(1, dim=1), max_sz):
            lower = minus - 1 * extent + plus
            upper = 1 * extent - minus + plus
            limited.append(torch.clamp(channel * extent, min=lower, max=upper) / extent)
        return torch.cat(limited, 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` (scaled by ``self.max_sz``) on top of the
        reference grid, using bilinear ``F.grid_sample`` with
        ``align_corners=True``."""
        base_grid = self.ref_grid if ref is None else ref
        half_extent = self.max_sz if img_sz is None else img_sz
        interpolation = 'bilinear'

        scale_shape = [1, self.dimension] + [1] * self.dimension
        scale = torch.Tensor(np.reshape(np.array(self.max_sz), scale_shape)).to(self.device)
        channel_last = [0] + list(range(2, 2 + self.dimension)) + [1]
        normalised = (ddf * scale + base_grid).permute(channel_last) / half_extent - 1
        return F.grid_sample(vol, torch.flip(normalised, dims=[-1]), mode=interpolation,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
        """Return the displacement field accumulated over ``rec_num`` passes.

        ``x`` is the image to deform, ``y`` the conditioning image (read only
        when ``conditional_input`` is set) and ``t`` the time-step index
        tensor fed to the embedding table. ``ndims`` is unused.
        """
        self.device = x.device
        spatial = x.size()[2:]
        batch = x.size()[0]
        self.max_sz = [spatial[0]] * self.dimension
        emb_shape = [batch, -1] + [1] * self.dimension
        half_sides = torch.tensor([(side - 1) / 2 for side in spatial], device=self.device)
        self.img_sz = torch.reshape(half_sides, [1] * (self.dimension + 1) + [self.dimension])
        axes = torch.meshgrid([torch.arange(end=side) for side in spatial])
        self.ref_grid = torch.reshape(torch.stack(axes, 0),
                                      [1, self.dimension] + list(spatial)).to(self.device)
        t = self.time_embed(t)

        def emb(layer):
            # Project the time embedding and make it broadcastable over space.
            return layer(t).reshape(emb_shape)

        moving = x
        for rec_id in range(rec_num):
            if self.conditional_input:
                cond = y
            skips = []
            feat = moving

            # Encoder with mutual fusion of the two branches.
            for level in range(self.hier_num):
                feat = self.block_down[level](feat + emb(self.ted_layers[level]))
                if self.conditional_input:
                    cond = self.block_down_cond[level](cond)
                    feat = self.fuse_conv0[level](torch.cat([feat, cond], axis=1))
                    cond = self.fuse_conv1[level](torch.cat([cond, feat], axis=1))
                skips.append(feat)
                feat = self.down_layers[level](feat)
                if self.conditional_input:
                    cond = self.down_layers[level](cond)

            feat = self.b_mid(feat + emb(self.tmid))
            if self.conditional_input:
                feat_shape = feat.shape
                cond_shape = cond.shape
                # (N, C, *spatial) -> (L, N, C) for nn.MultiheadAttention
                memory = cond.view(cond_shape[0], cond_shape[1], -1).permute(2, 0, 1)
                query = feat.view(feat_shape[0], feat_shape[1], -1).permute(2, 0, 1)
                attended, _ = self.attn_layer(query, memory, memory)
                # (L, N, C) -> (N, C, *spatial)
                feat = feat + attended.permute(1, 2, 0).contiguous().view(feat_shape)

            # Decoder with skip connections.
            for level in range(self.hier_num):
                feat = torch.cat((self.up_layers[level](feat), skips[-level - 1]), dim=1)
                feat = self.block_up[level](feat + emb(self.teu_layers[level]))

            step_ddf = self.boundary_limit(self.conv_out(feat) / 128, max_sz=list(self.max_sz))
            if rec_id == 0:
                ddf = step_ddf
            else:
                # Compose: warp the accumulated field by the new step, then add.
                warped_prev = self.resample(ddf, ddf=step_ddf, img_sz=self.img_sz,
                                            padding_mode="border")
                ddf = step_ddf + warped_prev
            moving = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Time-embedding projection: Linear -> ReLU -> Linear."""
        layers = [
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        ]
        return nn.Sequential(*layers)
590
+
591
class RecMutAttnNet(nn.Module):
    """Recursive, time-conditioned U-Net with mutual fusion and text mixing.

    Built from ``AtrousBlock`` stacks (defined elsewhere in this file). When
    ``conditional_input`` is set, the conditioning image ``y`` is encoded
    alongside the main branch on every recursion, both branches exchange
    features through 1x1 fusion convolutions, the bottleneck attends to the
    conditioning features, and a text feature is mixed in at the bottleneck.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal time embedding table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        # NOTE(review): the extent of this conditional block was reconstructed
        # from usage in forward(); confirm the attention/text layers are meant
        # to exist only in the conditional configuration.
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Default (all-zero) text feature used when forward() gets none.
            self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
        self.img_res = [res]*self.dimension
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
                                      [1, self.dimension]+list(self.img_res))

        for i in range(1, self.hier_num + 1):
            j = -i
            c_prev = self.feat_channels[i-1]
            c_cur = self.feat_channels[i]
            c_dec = self.feat_channels[j]
            enc_side = res // (2 ** (i-1))
            # NOTE(review): for i == hier_num the exponent is -1, so this is
            # a float (res // 0.5) -- confirm AtrousBlock tolerates that.
            dec_side = res // (2 ** (self.hier_num-i-1))

            self.down_layers.append(self.Conv(c_cur, c_cur, 4, 2, 1))
            self.up_layers.append(self.ConvT(c_dec, c_dec, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, c_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2*c_dec))
            self.block_down.append(self._atrous_stack([c_prev, c_cur, c_cur, c_cur], enc_side, ndims))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack([c_prev, c_cur, c_cur, c_cur], enc_side, ndims))
                self.fuse_conv0.append(self.Conv(2*c_cur, c_cur, 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2*c_cur, c_cur, 1, 1, 0))
            # The last decoder stage keeps its width; index j-1 would wrap
            # around to the input channel count.
            k = j if i == self.hier_num else j-1
            self.block_up.append(self._atrous_stack(
                [2*c_dec, c_dec, c_dec, self.feat_channels[k]], dec_side, ndims, normalize=False))

        # Bottleneck
        c_mid = self.feat_channels[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = self._atrous_stack([c_mid, c_mid, c_mid, c_mid], res // (2**self.hier_num), ndims)

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    @staticmethod
    def _atrous_stack(widths, side, ndims, **block_kwargs):
        """Chain AtrousBlocks mapping widths[0] -> widths[1] -> ... at one spatial size."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [side] * ndims, c_in, c_out, ndims=ndims, **block_kwargs)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ])

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of ``sample_coords0`` (split on dim 1) after
        scaling by its ``max_sz`` entry, then scale back."""
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` (scaled by ``self.max_sz``) on top of the
        reference grid, using bilinear ``F.grid_sample`` with
        ``align_corners=True``."""
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'

        return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
            np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
            [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
            align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Return the displacement field accumulated over ``rec_num`` passes.

        ``x`` is the image to deform, ``y`` the conditioning image (read only
        when ``conditional_input`` is set), ``t`` the time-step index tensor
        fed to the embedding table and ``text`` an optional text feature that
        is viewed as ``(-1, text_feat_chn, 1, ...)``. ``ndims`` is unused.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
        # Compare against the cached grid itself rather than the constructor
        # resolution: otherwise a call at a different size followed by a call
        # at the constructor size would reuse a wrongly sized grid.
        if list(self.ref_grid.shape[2:]) != list(img_sz):
            self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
                                          [1, self.dimension]+list(img_sz))
        self.ref_grid = self.ref_grid.to(self.device)

        img = x
        t = self.time_embed(t)

        if self.conditional_input:
            if text is None:
                text = self.text
            text = text.to(self.device)
            text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    # Mutual fusion: each branch sees the other's features.
                    tgt = self.block_down_cond[i](tgt)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                tgt_shape = tgt.shape
                # (N, C, *spatial) -> (L, N, C) for nn.MultiheadAttention
                tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt)
                # (L, N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

                # Mix the text feature in at the bottleneck.
                out_txt = self.img2txt(out) + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i-1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out)/128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new step, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Time-embedding projection: Linear -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
764
+
765
+ class RecMulModMutAttnNet(nn.Module):
766
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,
             text_feat_chn=1024, num_heads=4):
    """Build the encoder, bottleneck and decoder of the recursive registration network.

    Args:
        n_steps: Number of rows in the frozen sinusoidal time-embedding table.
        time_emb_dim: Width of each time embedding.
        ndims: Spatial dimensionality; selects ``nn.Conv{ndims}d`` and friends.
        num_input_chn: Channel count of the input image.
        res: Nominal side length used for the initial reference grid and the
            shape hints handed to ``AtrousBlock``.
        conditional_input: When true, build the second (target) encoder branch,
            the text projections and the bottleneck attention.
        text_feat_chn: Width of the text feature vector.
        num_heads: Number of heads of the bottleneck attention layer.

    Sub-modules are created in the same order as before so that parameter
    ordering is unchanged for ``conditional_input=True``.
    """
    super(RecMulModMutAttnNet, self).__init__()

    self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
    self.conditional_input = conditional_input
    self.num_heads = num_heads
    self.text_feat_chn = text_feat_chn

    self.dimension = ndims
    self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
    self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

    # Sinusoidal embedding: the table is overwritten and frozen.
    self.time_embed = nn.Embedding(n_steps, time_emb_dim)
    self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
    self.time_embed.requires_grad_(False)
    self.hier_num = len(self.feat_channels) - 1
    self.down_layers = nn.ModuleList()
    self.up_layers = nn.ModuleList()
    self.ted_layers = nn.ModuleList()
    self.teu_layers = nn.ModuleList()
    self.block_down = nn.ModuleList()
    self.block_up = nn.ModuleList()
    if self.conditional_input:
        self.txt_layers = nn.ModuleList()
        self.block_down_cond = nn.ModuleList()
        self.fuse_conv0 = nn.ModuleList()
        self.fuse_conv1 = nn.ModuleList()
        self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
        Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
        self.global_maxpool = Global_Maxpool(1)
        self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
        self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn,
                                    ndims=ndims, normalize=False, atrous_rates=[0, 0])
        self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
    # Default (all-zero) text feature used by forward() when no text is given.
    # Defined unconditionally because forward() reads it in both modes.
    self.text = torch.zeros(1, self.text_feat_chn)

    # Integer voxel-index grid of shape [1, ndims, res, ..., res].
    self.img_res = [res] * self.dimension
    self.ref_grid = torch.reshape(
        torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res], indexing="ij"), 0),
        [1, self.dimension] + list(self.img_res))

    for i in range(1, self.hier_num + 1):
        j = -i
        c_prev = self.feat_channels[i - 1]
        c_cur = self.feat_channels[i]
        c_up = self.feat_channels[j]
        down_res = res // (2 ** (i - 1))
        # NOTE(review): for i == hier_num the exponent is -1, so this value is a
        # float (res // 0.5); kept as-is because AtrousBlock's use of it is not
        # visible here.
        up_res = res // (2 ** (self.hier_num - i - 1))

        self.down_layers.append(self.Conv(c_cur, c_cur, 4, 2, 1))
        self.up_layers.append(self.ConvT(c_up, c_up, 4, 2, 1))
        self.ted_layers.append(self._make_te(time_emb_dim, c_prev))
        self.teu_layers.append(self._make_te(time_emb_dim, 2 * c_up))
        self.block_down.append(nn.Sequential(
            AtrousBlock([c_prev] + [down_res] * ndims, c_prev, c_cur, ndims=ndims),
            AtrousBlock([c_cur] + [down_res] * ndims, c_cur, c_cur, ndims=ndims),
            AtrousBlock([c_cur] + [down_res] * ndims, c_cur, c_cur, ndims=ndims)
        ))
        if self.conditional_input:
            self.txt_layers.append(self._make_te(self.text_feat_chn, c_cur))
            self.block_down_cond.append(nn.Sequential(
                AtrousBlock([c_prev] + [down_res] * ndims, c_prev, c_cur, ndims=ndims),
                AtrousBlock([c_cur] + [down_res] * ndims, c_cur, c_cur, ndims=ndims),
                AtrousBlock([c_cur] + [down_res] * ndims, c_cur, c_cur, ndims=ndims)
            ))
            self.fuse_conv0.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
            self.fuse_conv1.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
        # The last decoder stage keeps its width; earlier stages step down to
        # the next-shallower encoder width.
        k = j if i == self.hier_num else j - 1
        c_next = self.feat_channels[k]
        self.block_up.append(nn.Sequential(
            AtrousBlock([2 * c_up] + [up_res] * ndims, 2 * c_up, c_up, ndims=ndims, normalize=False),
            AtrousBlock([c_up] + [up_res] * ndims, c_up, c_up, ndims=ndims, normalize=False),
            AtrousBlock([c_up] + [up_res] * ndims, c_up, c_next, ndims=ndims, normalize=False)
        ))

    # Bottleneck
    if self.conditional_input:
        # txt_layers only exists in conditional mode; appending unguarded made
        # conditional_input=False fail with AttributeError.
        self.txt_layers.append(self._make_te(self.text_feat_chn, self.text_feat_chn))
    self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
    c_mid = self.feat_channels[self.hier_num]
    mid_res = res // (2 ** self.hier_num)
    self.b_mid = nn.Sequential(
        AtrousBlock([c_mid] + [mid_res] * ndims, c_mid, c_mid, ndims=ndims),
        AtrousBlock([c_mid] + [mid_res] * ndims, c_mid, c_mid, ndims=ndims),
        AtrousBlock([c_mid] + [mid_res] * ndims, c_mid, c_mid, ndims=ndims)
    )

    self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
853
+
854
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
    """Clamp a normalised displacement field channel by channel.

    Each channel ``c`` of ``sample_coords0`` (split along dim 1) is scaled by
    ``max_sz[c]``, clamped to ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]``
    and scaled back. Channels beyond ``len(max_sz)`` are dropped by ``zip``.

    Returns:
        The clamped channels concatenated along dim 1.
    """
    limited = []
    channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
    for chan, sz in zip(channels, max_sz):
        lower = minus - 1 * sz + plus
        upper = 1 * sz - minus + plus
        limited.append(torch.clamp(chan * sz, min=lower, max=upper) / sz)
    return torch.cat(limited, 1)
858
+
859
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
    """Warp ``vol`` with the normalised displacement field ``ddf``.

    Sampling positions are ``ddf * self.max_sz + ref`` in voxel units. They are
    moved channel-last, mapped to ``[-1, 1]`` via ``/ img_sz - 1``, and flipped
    on the last axis before being passed to ``F.grid_sample`` with
    ``align_corners=True``.

    Args:
        vol: Tensor to sample from, ``[N, C, *spatial]``.
        ddf: Displacement field, ``[N, ndims, *spatial]``.
        ref: Voxel-index grid; defaults to ``self.ref_grid``.
        img_sz: Half-extent tensor ``(size - 1) / 2`` broadcastable against the
            channel-last grid. When omitted it is derived from ``vol``.
            Previously the fallback was ``self.max_sz``, a Python list, which
            raised ``TypeError`` on the division.
        padding_mode: Forwarded to ``F.grid_sample``.

    Returns:
        The resampled tensor.
    """
    ref = self.ref_grid if ref is None else ref
    if img_sz is None:
        img_sz = torch.reshape(
            torch.tensor([(s - 1) / 2 for s in vol.shape[2:]], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
    resample_mode = 'bilinear'

    # float32 scale, matching what torch.Tensor(np.array(...)) used to produce.
    scale = torch.tensor(self.max_sz, dtype=torch.float32, device=self.device).reshape(
        [1, self.dimension] + [1] * self.dimension)
    channel_last = [0] + list(range(2, 2 + self.dimension)) + [1]
    coords = (ddf * scale + ref).permute(channel_last)
    grid = torch.flip(coords / img_sz - 1, dims=[-1])
    return F.grid_sample(vol, grid, mode=resample_mode, padding_mode=padding_mode,
                         align_corners=True)
868
+
869
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
    """Predict a displacement field by running the U-Net ``rec_num`` times.

    Each pass encodes the current (warped) image, fuses it with the target
    branch and text features when ``conditional_input`` is set, decodes a
    displacement increment, composes it with the running field and re-warps
    ``x`` for the next pass.

    Args:
        x: Moving image, ``[N, C, *spatial]``.
        y: Target image fed to the conditional branch.
        t: Integer time-step indices for the time-embedding table.
        text: Optional text feature; ``self.text`` is used when omitted.
        rec_num: Number of recursive passes.
        ndims: Unused; kept for interface compatibility.

    Returns:
        The composed displacement field after the last pass.

    Side effects: updates ``self.device``, ``self.max_sz``, ``self.img_sz``,
    ``self.ref_grid`` and (conditional mode) ``self.img_embd``.
    """
    self.device = x.device
    img_sz = x.size()[2:]
    n = x.size()[0]
    # NOTE(review): uses the first spatial extent for every axis — presumably
    # inputs are isotropic; confirm against callers.
    self.max_sz = [img_sz[0]] * self.dimension
    ts_emb_shape = [n, -1] + [1] * self.dimension

    self.img_sz = torch.reshape(
        torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
        [1] * (self.dimension + 1) + [self.dimension])
    # Compare against the grid actually held, not the constructor resolution:
    # the old check rebuilt the grid on every call for non-default sizes and
    # left a stale grid when the size returned to the constructor resolution.
    if list(self.ref_grid.shape[2:]) != list(img_sz):
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz], indexing="ij"), 0),
            [1, self.dimension] + list(img_sz))
    self.ref_grid = self.ref_grid.to(self.device)

    img = x
    t = self.time_embed(t)
    if text is None:
        text = self.text
        text = text.to(self.device)
        txt_shape = [1, -1] + [1] * self.dimension
    else:
        txt_shape = [n, -1] + [1] * self.dimension

    for rec_id in range(rec_num):
        if self.conditional_input:
            tgt = y
        enc_list = []
        out = img

        # Encoder
        for i in range(self.hier_num):
            out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
            if self.conditional_input:
                tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
            enc_list.append(out)
            out = self.down_layers[i](out)
            if self.conditional_input:
                # The target branch shares the image branch's down-sampling conv.
                tgt = self.down_layers[i](tgt)

        # Bottleneck
        out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
        if self.conditional_input:
            out_shape = out.shape
            tgt_shape = tgt.shape
            # (N, C, *spatial) -> (L, N, C) as expected by nn.MultiheadAttention.
            tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
            query = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
            out_attn, _ = self.attn_layer(query, tgt, tgt)
            # (L, N, C) -> (N, C, *spatial)
            out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
            out = out + out_attn

        if self.conditional_input:
            img_txt_feat = self.img2txt(out)
            # Globally max-pooled image embedding in text-feature space.
            self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
            out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
            out_txt = self.txt_proc(out_txt)
            out_txt = self.txt2img(out_txt)
            out = out + out_txt

        # Decoder with skip connections taken deepest-first.
        for i in range(self.hier_num):
            out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
            out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

        out = self.conv_out(out) / 128

        ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
        if rec_id == 0:
            ddf = ddf_one
        else:
            # Compose: warp the running field by the new increment, then add it.
            ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
        img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

    return ddf
951
+
952
+ def _make_te(self, dim_in, dim_out):
953
+ return nn.Sequential(
954
+ nn.Linear(dim_in, dim_out),
955
+ nn.ReLU(),
956
+ nn.Linear(dim_out, dim_out)
957
+ )
958
+
959
+ # class RecMutAttnNet(nn.Module):
960
+ # def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
961
+ # super(RecMutAttnNet, self).__init__()
962
+
963
+ # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
964
+ # self.conditional_input = conditional_input
965
+
966
+ # self.dimension = ndims
967
+ # self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
968
+ # self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
969
+
970
+ # # Sinusoidal embedding
971
+ # self.time_embed = nn.Embedding(n_steps, time_emb_dim)
972
+ # self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
973
+ # self.time_embed.requires_grad_(False)
974
+ # self.hier_num = len(self.feat_channels) - 1
975
+ # self.down_layers = nn.ModuleList()
976
+ # self.up_layers = nn.ModuleList()
977
+ # self.ted_layers = nn.ModuleList()
978
+ # self.teu_layers = nn.ModuleList()
979
+ # self.block_down = nn.ModuleList()
980
+ # if self.conditional_input:
981
+ # self.block_down_cond = nn.ModuleList()
982
+ # self.fuse_conv0 = nn.ModuleList()
983
+ # self.fuse_conv1 = nn.ModuleList()
984
+ # self.block_up = nn.ModuleList()
985
+
986
+ # for i in range(1, self.hier_num + 1):
987
+ # j=-i
988
+ # self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
989
+ # self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
990
+ # self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
991
+ # self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
992
+ # self.block_down.append(nn.Sequential(
993
+ # AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
994
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
995
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
996
+ # ))
997
+ # if self.conditional_input:
998
+ # self.block_down_cond.append(nn.Sequential(
999
+ # AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
1000
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
1001
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
1002
+ # ))
1003
+ # self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
1004
+ # self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
1005
+ # if i==self.hier_num:
1006
+ # k=j
1007
+ # else:
1008
+ # k=j-1
1009
+ # self.block_up.append(nn.Sequential(
1010
+ # AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
1011
+ # AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
1012
+ # AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
1013
+ # ))
1014
+
1015
+ # # Bottleneck
1016
+ # self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
1017
+ # self.b_mid = nn.Sequential(
1018
+ # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
1019
+ # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
1020
+ # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
1021
+ # )
1022
+
1023
+ # self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
1024
+
1025
+ # def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
1026
+ # sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
1027
+ # return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
1028
+ # zip(sample_coords, max_sz)], 1)
1029
+
1030
+ # def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
1031
+ # ref = self.ref_grid if ref is None else ref
1032
+ # img_sz = self.max_sz if img_sz is None else img_sz
1033
+ # resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
1034
+
1035
+ # return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
1036
+ # np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
1037
+ # [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
1038
+ # align_corners=True)
1039
+
1040
+ # def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
1041
+ # self.device = x.device
1042
+ # img_sz = x.size()[2:]
1043
+ # n = x.size()[0]
1044
+ # self.max_sz = [img_sz[0]] * self.dimension
1045
+ # ts_emb_shape=[n,-1]+[1]*self.dimension
1046
+ # self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
1047
+ # self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
1048
+ # [1, self.dimension]+list(img_sz)).to(self.device)
1049
+ # img = x
1050
+ # t = self.time_embed(t)
1051
+
1052
+ # for rec_id in range(rec_num):
1053
+ # if self.conditional_input:
1054
+ # tgt = y
1055
+ # enc_list = []
1056
+ # out = img
1057
+ # for i in range(self.hier_num):
1058
+ # out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
1059
+ # if self.conditional_input:
1060
+ # tgt = self.block_down_cond[i](tgt)
1061
+ # out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
1062
+ # tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
1063
+ # enc_list.append(out)
1064
+ # out = self.down_layers[i](out)
1065
+ # if self.conditional_input:
1066
+ # tgt = self.down_layers[i](tgt)
1067
+
1068
+ # out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
1069
+ # if self.conditional_input:
1070
+ # out = out + tgt
1071
+
1072
+ # for i in range(self.hier_num):
1073
+ # out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
1074
+ # out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
1075
+
1076
+ # out = self.conv_out(out)/128
1077
+
1078
+ # ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
1079
+ # if rec_id == 0:
1080
+ # ddf = ddf_one
1081
+ # else:
1082
+ # ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
1083
+ # img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
1084
+
1085
+ # return ddf
1086
+
1087
+ # def _make_te(self, dim_in, dim_out):
1088
+ # return nn.Sequential(
1089
+ # nn.Linear(dim_in, dim_out),
1090
+ # nn.ReLU(),
1091
+ # nn.Linear(dim_out, dim_out)
1092
+ # )
1093
+ # ==============================================
1094
+ # Layers
1095
+ # ==============================================
1096
+
1097
+
1098
def ddf_multiplier(dvf, mul_num=10, stn=None):
    """Compose ``dvf`` with itself ``mul_num`` times.

    Starting from ``ddf = dvf``, each step computes ``ddf = dvf + stn(ddf, dvf)``.

    Args:
        dvf: Field to compose, ``[N, ndims, *spatial]``.
        mul_num: Number of composition steps; ``0`` returns ``dvf`` unchanged.
        stn: Callable ``stn(field, ddf)`` that warps ``field`` by ``ddf``. When
            omitted, a border-padded ``STN`` sized from ``dvf`` is created
            (previously ``None`` was called directly and raised ``TypeError``).
            The default assumes all spatial extents equal ``dvf.shape[2]``.

    Returns:
        The composed field.
    """
    if stn is None:
        stn = STN(ndims=dvf.dim() - 2, img_sz=dvf.shape[2], device=dvf.device, padding_mode="border")
    ddf = dvf
    for _ in range(mul_num):
        ddf = dvf + stn(ddf, dvf)
    return ddf
1103
+
1104
+
1105
def composite(ddfs, stn=None):
    """Compose a sequence of displacement fields into one.

    Folds left to right: ``comp = ddfs[i] + stn(comp, ddfs[i])``.

    Args:
        ddfs: Non-empty sequence of fields, each ``[N, ndims, *spatial]``.
        stn: Callable ``stn(field, ddf)`` that warps ``field`` by ``ddf``. When
            omitted, a border-padded ``STN`` is built from the first field's
            dimensionality, size and device. The old default passed no size,
            which failed inside the ``STN`` constructor. The default assumes
            all spatial extents equal ``ddfs[0].shape[2]``.

    Returns:
        The composed field (``ddfs[0]`` itself when only one field is given).
    """
    if stn is None:
        first = ddfs[0]
        stn = STN(ndims=first.dim() - 2, img_sz=first.shape[2], device=first.device, padding_mode="border")
    comp_ddf = ddfs[0]
    for nxt in ddfs[1:]:
        comp_ddf = nxt + stn(comp_ddf, nxt)
    return comp_ddf
1112
+
1113
+
1114
+
1115
class STN(nn.Module):
    """Spatial transformer that warps a tensor with a normalised displacement field.

    The field is multiplied by ``max_sz`` to get voxel offsets, added to an
    integer reference grid and handed to ``F.grid_sample`` with
    ``align_corners=True``.

    Args:
        ndims: Number of spatial dimensions.
        img_sz: Spatial size. A scalar is repeated for every axis; a
            list/tuple/``torch.Size`` gives per-axis sizes. ``None`` defers
            grid construction to the first ``forward`` call.
        max_sz: Accepted for interface compatibility; it has never been used,
            the displacement scale always equals ``img_sz``.
        device: Device for the cached grid; defaults to the first input's device.
        padding_mode: ``F.grid_sample`` padding mode used by ``forward``.
        resample_mode: ``F.grid_sample`` interpolation mode; ``None`` means
            ``'bilinear'``.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None, padding_mode="border", resample_mode=None):
        super(STN, self).__init__()
        self.ndims = ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        # Kept as plain attributes (not buffers) so state_dict keys are unchanged.
        self.img_sz = None
        if img_sz is not None:
            self._set_size(img_sz)

    def _set_size(self, img_sz):
        """Store the spatial size and build the matching scale tensor and reference grid."""
        if isinstance(img_sz, (list, tuple, torch.Size)):
            sizes = list(img_sz)
        else:
            sizes = [img_sz] * self.ndims
        max_sz = torch.tensor(sizes, dtype=torch.float32).reshape([1, self.ndims] + [1] * self.ndims)
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid([torch.arange(end=s) for s in sizes], indexing="ij"), 0),
            [1, self.ndims] + sizes)
        if self.device is not None:
            max_sz = max_sz.to(self.device)
            ref_grid = ref_grid.to(self.device)
        self.img_sz = sizes
        self.max_sz = max_sz
        self.ref_grid = ref_grid

    def max_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp each displacement channel to ``[minus - sz + plus, sz - minus + plus]`` voxels.

        Channel ``c`` is scaled by its own ``max_sz`` entry, clamped and scaled
        back. The previous code zipped over the leading size-1 axis of
        ``max_sz`` and therefore only ever read channel 0.
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        sizes = self.max_sz.flatten().tolist()
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
                          for x, sz in zip(channels, sizes)], 1)

    def boundary_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp absolute sampling positions (displacement + reference grid) per channel.

        For channel ``c`` the position ``x * sz + ref`` is clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and converted back to a
        normalised displacement. Channels are paired with their own ``max_sz``
        and ``ref_grid`` entries (the old zip only processed channel 0).
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        sizes = self.max_sz.flatten().tolist()
        refs = self.ref_grid[0].to(sample_coords0.device)
        return torch.cat([(torch.clamp(x * sz + ref, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) - ref) / sz
                          for x, sz, ref in zip(channels, sizes, refs)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at ``ddf * max_sz + ref``.

        Args:
            vol: Tensor to sample, ``[N, C, *spatial]``; moved to ``ddf``'s device.
            ddf: Normalised displacement field, ``[N, ndims, *spatial]``.
            ref: Reference grid; defaults to ``self.ref_grid``.
            img_sz: Iterable of spatial sizes used to normalise coordinates to
                ``[-1, 1]``. Defaults to ``vol``'s spatial shape (the old
                fallback divided by ``max_sz``, whose shape does not broadcast
                against the channel-last grid).
            padding_mode: Forwarded to ``F.grid_sample``.
        """
        device = ddf.device
        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            img_sz = vol.shape[2:]
        half_extent = torch.reshape(torch.tensor([(s - 1) / 2. for s in img_sz], device=device),
                                    [1] + [1] * self.ndims + [self.ndims])
        resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode
        channel_last = [0] + list(range(2, 2 + self.ndims)) + [1]
        coords = (ddf * self.max_sz.to(device) + ref.to(device)).permute(channel_last)
        # grid_sample expects the last axis ordered (x, y[, z]), i.e. reversed.
        grid = torch.flip(coords / half_extent - 1, dims=[-1])
        return F.grid_sample(vol.to(device), grid, mode=resample_mode,
                             padding_mode=padding_mode,
                             align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` by ``ddf`` using the module's padding mode."""
        self.device = x.device if self.device is None else self.device
        if self.img_sz is None:
            # Lazy initialisation from the first input's spatial shape.
            self._set_size(list(x.size()[2:]))
        resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
        return resampled_x
1177
+
1178
+
1179
if __name__ == '__main__':
    # Smoke test: build a 3-D network, run one forward pass, report sizes.
    # NOTE(review): RecMutAttnNet only appears as commented-out code in the
    # visible part of this file, and the network defined above feeds `text`
    # to Linear layers (expects [N, 1024], not [N, 1024, 1, 1, 1]) — confirm
    # which class and text shape this check is meant to exercise.
    ndims = 3
    res = 128
    spatial = [res] * ndims
    x = torch.rand([1, 1] + spatial)
    t = torch.randint(0, 1000, (1,))
    text = torch.rand([1, 1024] + [1] * ndims)
    model = RecMutAttnNet(n_steps=1000, time_emb_dim=100, ndims=ndims, num_input_chn=1, res=res,
                          conditional_input=True)
    y = model(x, x, t, text=text)
    print("Ouput shape", y.shape)

    # Parameter counts: all parameters, then only those requiring gradients.
    all_params = list(model.parameters())
    total_params = sum(p.numel() for p in all_params)
    trainable_params = sum(p.numel() for p in all_params if p.requires_grad)

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
Diffusion/networks_opt.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ networks_opt.py — Optimized network components.
3
+
4
+ Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:
5
+ 1. OptSTN: register_buffer for ref_grid/max_sz — no .to(device) per call
6
+ 2. OptRecMulModMutAttnNet: cached max_sz/img_sz tensors, ref_grid device —
7
+ eliminates ~80 NumPy→GPU transfers and ~32 tensor recreations per registration step
8
+
9
+ All optimizations are mathematically equivalent to the originals.
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ from Diffusion.networks import RecMulModMutAttnNet, STN
18
+
19
+
20
+ # ======================================================================
21
+ # Optimized STN
22
+ # ======================================================================
23
+
24
class OptSTN(STN):
    """STN variant that stores its grids as registered buffers.

    ``max_sz``, ``ref_grid`` and the pre-computed half-extent tensor are
    buffers, so they follow ``module.to(device)`` and ``resample()`` does not
    move them on every call.

    Args:
        ndims: Number of spatial dimensions.
        img_sz: Scalar spatial size, repeated for every axis. Required.
        max_sz: Accepted for interface compatibility; unused.
        device: Stored on the instance only; buffers are moved via ``.to()``.
        padding_mode: ``F.grid_sample`` padding mode used by ``forward``.
        resample_mode: Interpolation mode; ``None`` means ``'bilinear'``.

    Raises:
        ValueError: If ``img_sz`` is ``None`` (this previously surfaced as an
            opaque tensor-construction error).
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        # Skip parent __init__ to avoid creating plain tensor attributes
        nn.Module.__init__(self)
        if img_sz is None:
            raise ValueError("OptSTN requires img_sz to pre-compute its buffers")
        self.ndims = ndims
        self.img_sz = [img_sz] * ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode

        # Displacement scale, shape [1, ndims, 1, ..., 1].
        max_sz_tensor = torch.tensor(self.img_sz, dtype=torch.float32).reshape(
            [1, self.ndims] + [1] * self.ndims)
        self.register_buffer('max_sz', max_sz_tensor)

        # Integer voxel-index grid, shape [1, ndims, *img_sz].
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                [torch.arange(end=s) for s in self.img_sz], indexing="ij"
            ), 0),
            [1, self.ndims] + self.img_sz
        )
        self.register_buffer('ref_grid', ref_grid)

        # Half-extent (size - 1) / 2 used when forward() calls resample().
        img_sz_for_resample = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in self.img_sz]),
            [1] + [1] * self.ndims + [self.ndims]
        )
        self.register_buffer('_img_sz_for_resample', img_sz_for_resample)

        # Constant permutation moving the channel axis last.
        self._perm = [0] + list(range(2, 2 + self.ndims)) + [1]

    def _half_extent(self, sizes, device):
        """Return the ``(size - 1) / 2`` tensor for an arbitrary iterable of sizes."""
        return torch.reshape(
            torch.tensor([(s - 1) / 2. for s in sizes], device=device),
            [1] + [1] * self.ndims + [self.ndims]
        )

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at ``ddf * max_sz + ref``.

        ``img_sz`` is an iterable of spatial sizes. The cached half-extent is
        used only when it equals the module's own size; any other value is
        honoured instead of being silently replaced, and ``None`` falls back
        to ``vol``'s spatial shape.
        """
        ref = self.ref_grid if ref is None else ref

        if img_sz is None:
            img_sz_t = self._half_extent(vol.shape[2:], ddf.device)
        elif list(img_sz) == self.img_sz:
            # Common case (called from forward): reuse the buffer.
            img_sz_t = self._img_sz_for_resample
        else:
            img_sz_t = self._half_extent(img_sz, ddf.device)

        resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode

        grid = torch.flip(
            (ddf * self.max_sz + ref).permute(self._perm) / img_sz_t - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` by ``ddf`` using the module's padding mode."""
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)
90
+
91
+
92
+ # ======================================================================
93
+ # Optimized RecMulModMutAttnNet
94
+ # ======================================================================
95
+
96
class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
    """RecMulModMutAttnNet with cached tensors for resample/forward.

    The displacement-scale tensor, the half-extent tensor and the reference
    grid are cached per ``(input size, device)`` instead of being rebuilt on
    every call. The network computation itself follows the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache slots — populated on first forward
        self._cached_input_key = None
        self._cached_max_sz_tensor = None
        self._cached_img_sz_tensor = None
        # Constant permutation moving the channel axis last.
        self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]

    def _ensure_cache(self, img_sz, device):
        """Populate cached tensors if input size or device changed."""
        key = (tuple(img_sz), device)
        if key == self._cached_input_key:
            return
        self._cached_input_key = key
        max_sz_list = [img_sz[0]] * self.dimension
        self.max_sz = max_sz_list

        # Displacement scale, shape [1, ndims, 1, ..., 1].
        self._cached_max_sz_tensor = torch.tensor(
            max_sz_list, dtype=torch.float32, device=device
        ).reshape([1, self.dimension] + [1] * self.dimension)

        # Half-extent (size - 1) / 2, channel-last.
        self._cached_img_sz_tensor = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=device),
            [1] * (self.dimension + 1) + [self.dimension]
        )

        # Compare against the grid actually held: checking the constructor
        # resolution left a stale grid when the size changed away from it and
        # then back.
        if list(self.ref_grid.shape[2:]) != list(img_sz):
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    [torch.arange(end=imsz) for imsz in img_sz], indexing="ij"
                ), 0),
                [1, self.dimension] + list(img_sz)
            ).to(device)
        elif self.ref_grid.device != torch.device(device):
            self.ref_grid = self.ref_grid.to(device)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` using the cached displacement scale.

        ``img_sz`` is the half-extent tensor, as in the parent class; when
        omitted it is derived from ``vol``'s spatial shape. Must be called
        after ``_ensure_cache`` (``forward`` does this).
        """
        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            img_sz = torch.reshape(
                torch.tensor([(s - 1) / 2 for s in vol.shape[2:]], device=ddf.device),
                [1] * (self.dimension + 1) + [self.dimension]
            )

        grid = torch.flip(
            (ddf * self._cached_max_sz_tensor + ref).permute(self._perm) / img_sz - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode='bilinear',
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Same computation as the parent ``forward`` with cached geometry tensors.

        Args:
            x: Moving image, ``[N, C, *spatial]``.
            y: Target image for the conditional branch.
            t: Integer time-step indices.
            text: Optional text feature; ``self.text`` is used when omitted.
            rec_num: Number of recursive passes.
            ndims: Unused; kept for interface compatibility.

        Returns:
            The composed displacement field after the last pass.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # Only recreate geometry tensors if input size/device changes.
        self._ensure_cache(img_sz, self.device)
        self.img_sz = self._cached_img_sz_tensor

        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
            text = text.to(self.device)
            txt_shape = [1, -1] + [1] * self.dimension
        else:
            txt_shape = [n, -1] + [1] * self.dimension

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Mirror the parent bottleneck: one cross-attention from the
                # image branch onto the target branch through self.attn_layer.
                # The former attn_layer0/attn_layer1/fuse modules are not
                # created by the parent constructor.
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(out_flat, tgt_flat, tgt_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

            if self.conditional_input:
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out) / 128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf
227
+
228
+
229
+ # ======================================================================
230
+ # Factory function
231
+ # ======================================================================
232
+
233
def get_net_opt(name):
    """Look up a network class by name, preferring the optimized variant.

    ``"recmulmodmutattnnet"`` maps to ``OptRecMulModMutAttnNet``; every other
    name is delegated to ``Diffusion.networks.get_net``.
    """
    if name != "recmulmodmutattnnet":
        # Imported lazily, only on the fallback path.
        from Diffusion.networks import get_net
        return get_net(name)
    return OptRecMulModMutAttnNet
Diffusion/safe_conv_transpose.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SafeConvTranspose3d: Drop-in replacement for nn.ConvTranspose3d that avoids
3
+ the XPU memory leak in the ConvTranspose3d backward pass (oneDNN autograd bug).
4
+
5
+ Mathematical Background
6
+ =======================
7
+
8
+ ConvTranspose3d (a.k.a. "transposed convolution" or "fractionally-strided
9
+ convolution") with parameters:
10
+ in_channels=C_in, out_channels=C_out, kernel_size=K, stride=S, padding=P
11
+
12
+ is the gradient (adjoint) of Conv3d with the same parameters. For an input x
13
+ of shape [B, C_in, D, H, W], the output has shape:
14
+ [B, C_out, S*(D-1) + K - 2*P, S*(H-1) + K - 2*P, S*(W-1) + K - 2*P]
15
+
16
+ For our specific case (K=4, S=2, P=1):
17
+ output_size = 2*(D-1) + 4 - 2 = 2*D (likewise for H, W)
18
+
19
+ The operation is mathematically equivalent to:
20
+ 1. Stride insertion: insert (S-1) zeros between each input element
21
+ 2. Padding: pad with (K - P - 1) zeros on each side
22
+ 3. Regular Conv3d with spatially-flipped, channel-transposed weight
23
+
24
+ Specifically:
25
+
26
+ Step 1 - Stride insertion:
27
+ Input [B, C_in, D, H, W] -> [B, C_in, S*(D-1)+1, S*(H-1)+1, S*(W-1)+1]
28
+ For S=2: [B, C_in, 2*D-1, 2*H-1, 2*W-1]
29
+ Original values placed at positions 0, S, 2S, ... ; zeros elsewhere.
30
+
31
+ Step 2 - Padding:
32
+ Pad each spatial dimension with (K - P - 1) zeros on each side.
33
+ For K=4, P=1: pad = 2 on each side.
34
+ Shape becomes: [B, C_in, 2*D+3, 2*H+3, 2*W+3]
35
+
36
+ Step 3 - Conv3d with transformed weight:
37
+ ConvTranspose3d weight shape: [C_in, C_out, K, K, K]
38
+ Equivalent Conv3d weight: weight.flip(2,3,4).transpose(0,1)
39
+ -> shape [C_out, C_in, K, K, K]
40
+
41
+ Conv3d(stride=1, padding=0) on the padded input gives:
42
+ [B, C_out, (2*D+3 - K + 1), ...] = [B, C_out, 2*D, 2*H, 2*W] (correct!)
43
+
44
+ Why this is safe on XPU:
45
+ The forward uses F.pad (ZERO leak) and F.conv3d (negligible leak).
46
+ The backward is computed automatically by PyTorch's autograd through these
47
+ same safe ops — no ConvTranspose3d backward kernel is ever invoked.
48
+ Specifically:
49
+ - F.conv3d backward -> uses Conv3d backward (safe, 0.004 GiB/step)
50
+ - F.pad backward -> tensor slicing (trivially safe)
51
+ - Stride insertion backward -> gather at stride positions (trivially safe)
52
+ - weight.flip().transpose() backward -> indexing (trivially safe)
53
+
54
+ Forward precision:
55
+ Not bit-for-bit identical to nn.ConvTranspose3d due to different summation
56
+ order (stride-insert + pad + conv3d vs native transposed conv), but the
57
+ difference is negligible: max absolute diff < 5e-7 in float32, no elements
58
+ exceeding 1e-6. This is on the order of float32 rounding error for typical
59
+ activation magnitudes.
60
+
61
+ Backward precision:
62
+ Gradients match nn.ConvTranspose3d within 1e-5 (input) and 1e-4 (weight)
63
+ for float32. Verified across all channel configurations used in the
64
+ codebase (16-256 channels).
65
+
66
+ Implementation choices:
67
+ We also provide SafeConvTranspose3d_v2 which uses a custom autograd function
68
+ to call F.conv_transpose3d in the forward (bit-for-bit identical) but
69
+ replaces the backward with safe Conv3d-based gradient computation.
70
+
71
+ RECOMMENDATION: Use SafeConvTranspose3d (V1, decomposed forward) because:
72
+ - Simpler implementation with no custom autograd
73
+ - Fully transparent to PyTorch's autograd
74
+ - Compatible with gradient checkpointing, torch.compile, etc.
75
+ - The ~5e-7 forward precision loss is negligible for training
76
+ - V2's custom autograd requires careful maintenance and is fragile
77
+ """
78
+
79
+ import torch
80
+ import torch.nn as nn
81
+ import torch.nn.functional as F
82
+ from torch.autograd import Function
83
+
84
+
85
+ # =============================================================================
86
+ # Approach 1 (RECOMMENDED): Decomposed forward pass
87
+ # =============================================================================
88
+
89
class SafeConvTranspose3d(nn.Module):
    """Drop-in replacement for nn.ConvTranspose3d built from Conv3d-only ops.

    The transposed convolution is decomposed into stride insertion, zero
    padding and a regular stride-1 ``F.conv3d`` with a spatially flipped,
    channel-transposed weight. Autograd differentiates through exactly these
    ops, so no ConvTranspose3d backward kernel is ever invoked.

    Supports: kernel_size, stride, padding (int or length-3 sequence), bias.
    Does NOT support: non-zero output_padding, dilation != 1, groups != 1,
    padding_mode other than 'zeros'.

    The weight tensor has the SAME shape as nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]
    so checkpoints can be loaded directly with load_state_dict().

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the layer.
        kernel_size: Int or length-3 sequence (kD, kH, kW).
        stride: Int or length-3 sequence.
        padding: Int or length-3 sequence (ConvTranspose3d semantics).
        output_padding: Must be zero (an int 0 or a sequence of zeros).
        groups: Must be 1.
        bias: If True, a learnable bias of shape [out_channels] is added.
        dilation: Must be 1 (an int 1 or a sequence of ones).
        padding_mode: Must be 'zeros'.

    Raises:
        NotImplementedError: For groups != 1, non-zero output_padding or
            dilation != 1.
        ValueError: For a padding_mode other than 'zeros' or a sequence
            argument whose length is not 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros'):
        super().__init__()

        # Normalize every geometric argument to a 3-tuple first so that the
        # support checks below treat 0 and (0, 0, 0) (or 1 and [1, 1, 1])
        # identically.
        kernel_size = self._triple(kernel_size, 'kernel_size')
        stride = self._triple(stride, 'stride')
        padding = self._triple(padding, 'padding')
        output_padding = self._triple(output_padding, 'output_padding')
        dilation = self._triple(dilation, 'dilation')

        if groups != 1:
            raise NotImplementedError("SafeConvTranspose3d only supports groups=1")
        if any(op != 0 for op in output_padding):
            raise NotImplementedError("SafeConvTranspose3d does not support output_padding")
        if dilation != (1, 1, 1):
            raise NotImplementedError("SafeConvTranspose3d does not support dilation != 1")
        if padding_mode != 'zeros':
            # Same restriction as nn.ConvTranspose3d; previously the argument
            # was accepted and silently ignored.
            raise ValueError('Only "zeros" padding mode is supported for SafeConvTranspose3d')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.dilation = dilation
        self.groups = groups

        # Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, *kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Same initialization scheme (and RNG consumption order) as before:
        # Kaiming-uniform weight, then uniform bias bounded by 1/sqrt(fan_in).
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    @staticmethod
    def _triple(value, name):
        """Return ``value`` as a 3-tuple; ints are repeated, sequences copied."""
        if isinstance(value, int):
            return (value, value, value)
        value = tuple(value)
        if len(value) != 3:
            raise ValueError(f"{name} must be an int or a sequence of length 3, got {value}")
        return value

    def forward(self, x):
        """Apply the transposed convolution.

        Args:
            x: Tensor of shape [B, C_in, D, H, W], or [C_in, D, H, W] for a
                single unbatched sample.

        Returns:
            Tensor with spatial size S*(n-1) + K - 2*P per dimension, batched
            or unbatched to match the input.
        """
        unbatched = x.dim() == 4
        if unbatched:
            x = x.unsqueeze(0)
        if x.dim() != 5:
            raise ValueError(
                f"SafeConvTranspose3d expects a 4-D or 5-D input, got {x.dim()}-D"
            )

        B, C_in, D, H, W = x.shape
        sd, sh, sw = self.stride
        kd, kh, kw = self.kernel_size
        pd, ph, pw = self.padding

        # Step 1: Stride insertion — original values land on every S-th
        # position of a zero tensor (the "fractionally-strided" part).
        if sd > 1 or sh > 1 or sw > 1:
            x_inserted = x.new_zeros(
                B, C_in, sd * (D - 1) + 1, sh * (H - 1) + 1, sw * (W - 1) + 1
            )
            x_inserted[:, :, ::sd, ::sh, ::sw] = x
        else:
            x_inserted = x

        # Step 2: Pad with (kernel_size - padding - 1) zeros on each side.
        # A negative amount (padding > kernel_size - 1) makes F.pad crop,
        # which is the matching behaviour for the transposed convolution.
        # F.pad argument order: (W_left, W_right, H_left, H_right, D_left, D_right)
        pad_d = kd - pd - 1
        pad_h = kh - ph - 1
        pad_w = kw - pw - 1
        x_padded = F.pad(x_inserted, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d))

        # Step 3: ConvTranspose3d weight [C_in, C_out, kD, kH, kW] becomes the
        # equivalent Conv3d weight [C_out, C_in, kD, kH, kW], spatially flipped.
        w_conv = self.weight.flip(2, 3, 4).transpose(0, 1)

        # Step 4: Standard Conv3d (stride=1, padding=0)
        out = F.conv3d(x_padded, w_conv, self.bias, stride=1, padding=0)
        return out.squeeze(0) if unbatched else out

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}, bias={self.bias is not None}')
188
+
189
+
190
+ # =============================================================================
191
+ # Approach 2: Custom autograd — real forward, safe backward
192
+ # =============================================================================
193
+
194
class _SafeConvTranspose3dFunc(Function):
    """Custom autograd function: native conv_transpose3d forward, Conv3d-only backward.

    The forward calls F.conv_transpose3d directly, so its output is the
    native result. The backward never calls a ConvTranspose3d backward
    kernel; every gradient is expressed with F.conv3d, F.pad, slicing and
    flips.

    Gradient derivation (per spatial dimension, stride S, padding P,
    dilation d, kernel size K, output_padding op):

        y[o] = sum_{i,k} x[i] * w[k]   with   o = i*S + k*d - P

        grad_x = conv3d(grad_y, w, stride=S, padding=P, dilation=d)
            ConvTranspose3d is the adjoint of Conv3d with the same parameters.

        grad_w[k] = sum_i x[i] * grad_y[i*S + k*d - P]
            Computed as a correlation of the stride-inserted input, padded
            with d*(K-1) - P zeros in front and d*(K-1) - P + op behind,
            against grad_y sampled with stride d. Batch and channel axes are
            swapped so the sum over the batch becomes the channel reduction
            of conv3d. The result is indexed by K-1-k, hence a final spatial
            flip.

        grad_bias = grad_y summed over batch and spatial axes.

    The weight gradient is only implemented for groups=1 and batched
    (5-D) input.
    """

    @staticmethod
    def _expand(value, n):
        """Return ``value`` as an n-tuple; ints are repeated, sequences copied."""
        if isinstance(value, int):
            return (value,) * n
        return tuple(value)

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding, output_padding, groups, dilation):
        # Use the real conv_transpose3d so the forward output is the native one.
        output = F.conv_transpose3d(
            input, weight, bias,
            stride=stride, padding=padding,
            output_padding=output_padding, groups=groups, dilation=dilation
        )
        ctx.save_for_backward(input, weight, bias)
        ctx.stride = stride
        ctx.padding = padding
        ctx.output_padding = output_padding
        ctx.groups = groups
        ctx.dilation = dilation
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        n_spatial = weight.dim() - 2
        expand = _SafeConvTranspose3dFunc._expand
        stride = expand(ctx.stride, n_spatial)
        padding = expand(ctx.padding, n_spatial)
        output_padding = expand(ctx.output_padding, n_spatial)
        dilation = expand(ctx.dilation, n_spatial)
        groups = ctx.groups

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            # grad_input of ConvTranspose3d = Conv3d(grad_output, weight):
            # ConvTranspose3d IS the adjoint of Conv3d.
            grad_input = F.conv3d(
                grad_output, weight,
                bias=None, stride=stride, padding=padding,
                dilation=dilation, groups=groups
            )

        if ctx.needs_input_grad[1]:
            if groups != 1:
                # The batch/channel-swap trick below mixes all channels and is
                # therefore only valid for a single group.
                raise NotImplementedError(
                    "_SafeConvTranspose3dFunc weight gradient only supports groups=1"
                )

            B, C_in = input.shape[:2]
            spatial = input.shape[2:]

            # Stride-insert the input: values on every S-th position, zeros elsewhere.
            if any(s > 1 for s in stride):
                new_spatial = tuple(s * (n - 1) + 1 for s, n in zip(stride, spatial))
                input_inserted = input.new_zeros(B, C_in, *new_spatial)
                slices = (slice(None), slice(None)) + tuple(
                    slice(None, None, s) for s in stride
                )
                input_inserted[slices] = input
            else:
                input_inserted = input

            # Pad d*(K-1) - P in front and the same plus output_padding behind.
            # F.pad lists the LAST dimension first, hence the reversed order.
            # Negative amounts make F.pad crop, which is what is needed when
            # padding exceeds d*(K-1).
            kernel_size = tuple(weight.shape[2:])
            pad_sizes = []
            for k, p, d, op in zip(reversed(kernel_size), reversed(padding),
                                   reversed(dilation), reversed(output_padding)):
                lead = d * (k - 1) - p
                pad_sizes.extend([lead, lead + op])
            x_padded = F.pad(input_inserted, pad_sizes)

            # Swap batch and channel so conv3d reduces over the batch:
            # conv3d([C_in, B, D_pad...], [C_out, B, D_out...]) -> [C_in, C_out, K...]
            # stride=dilation samples the correlation at the dilated kernel taps.
            grad_w_conv = F.conv3d(
                x_padded.transpose(0, 1), grad_output.transpose(0, 1),
                stride=dilation
            )

            # Undo the spatial flip of the forward decomposition.
            grad_weight = grad_w_conv.flip(*range(2, 2 + n_spatial))

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(dim=(0,) + tuple(range(2, grad_output.ndim)))

        # One gradient slot per forward argument; the five non-tensor
        # arguments receive None.
        return grad_input, grad_weight, grad_bias, None, None, None, None, None


class SafeConvTranspose3d_v2(nn.Module):
    """Drop-in replacement for nn.ConvTranspose3d using custom autograd.

    Forward pass: Uses the real F.conv_transpose3d.
    Backward pass: Computes gradients using F.conv3d (no ConvTranspose3d
    backward kernel is invoked).

    Weight shape is identical to nn.ConvTranspose3d:
        [in_channels, out_channels, *kernel_size]

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the layer.
        kernel_size: Int or length-3 sequence (kD, kH, kW).
        stride: Int or length-3 sequence.
        padding: Int or length-3 sequence (ConvTranspose3d semantics).
        output_padding: Must be zero (an int 0 or a sequence of zeros).
        groups: Must be 1.
        bias: If True, a learnable bias of shape [out_channels] is added.
        dilation: Int or length-3 sequence.
        padding_mode: Must be 'zeros'.

    Raises:
        NotImplementedError: For groups != 1 or non-zero output_padding.
        ValueError: For a padding_mode other than 'zeros' or a sequence
            argument whose length is not 3.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros'):
        super().__init__()

        # Normalize to 3-tuples before validating so that 0 and (0, 0, 0)
        # are treated the same.
        kernel_size = self._triple(kernel_size, 'kernel_size')
        stride = self._triple(stride, 'stride')
        padding = self._triple(padding, 'padding')
        output_padding = self._triple(output_padding, 'output_padding')
        dilation = self._triple(dilation, 'dilation')

        if groups != 1:
            raise NotImplementedError("SafeConvTranspose3d_v2 only supports groups=1")
        if any(op != 0 for op in output_padding):
            raise NotImplementedError("SafeConvTranspose3d_v2 does not support output_padding")
        if padding_mode != 'zeros':
            # Same restriction as nn.ConvTranspose3d; previously the argument
            # was accepted and silently ignored.
            raise ValueError('Only "zeros" padding mode is supported for SafeConvTranspose3d_v2')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups
        self.dilation = dilation

        # Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, *kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        # Kaiming-uniform weight, then uniform bias bounded by 1/sqrt(fan_in).
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            if fan_in != 0:
                bound = 1 / fan_in**0.5
                nn.init.uniform_(self.bias, -bound, bound)

    @staticmethod
    def _triple(value, name):
        """Return ``value`` as a 3-tuple; ints are repeated, sequences copied."""
        if isinstance(value, int):
            return (value, value, value)
        value = tuple(value)
        if len(value) != 3:
            raise ValueError(f"{name} must be an int or a sequence of length 3, got {value}")
        return value

    def forward(self, x):
        """Run the transposed convolution on ``x`` ([B, C_in, D, H, W])."""
        return _SafeConvTranspose3dFunc.apply(
            x, self.weight, self.bias,
            self.stride, self.padding, self.output_padding,
            self.groups, self.dilation
        )

    def extra_repr(self):
        return (f'{self.in_channels}, {self.out_channels}, '
                f'kernel_size={self.kernel_size}, stride={self.stride}, '
                f'padding={self.padding}, bias={self.bias is not None}')
364
+
365
+
366
+ # =============================================================================
367
+ # Utility: in-place replacement of ConvTranspose3d in existing models
368
+ # =============================================================================
369
+
370
def replace_conv_transpose3d(module, target_cls=SafeConvTranspose3d):
    """Recursively replace all nn.ConvTranspose3d children of a module.

    Each nn.ConvTranspose3d found below ``module`` is swapped in place for an
    instance of ``target_cls`` that carries the same weights and bias, lives
    on the same device with the same dtype, and keeps the original layer's
    ``requires_grad`` flags and train/eval state. ``module`` itself is never
    replaced, only its descendants.

    Usage:
        model = MyModel()
        replace_conv_transpose3d(model)  # in-place modification

    Args:
        module: The nn.Module to modify in-place.
        target_cls: Replacement class (default: SafeConvTranspose3d). It is
            called as ``target_cls(in_channels, out_channels, kernel_size,
            stride=..., padding=..., bias=...)``; ``dilation=...`` is passed
            additionally only when the original layer uses a dilation other
            than 1, so the replacement can either honour or reject it.

    Raises:
        NotImplementedError: If a layer uses groups != 1 or a non-zero
            output_padding.
    """
    # Snapshot the children: entries are reassigned while we walk them.
    for name, child in list(module.named_children()):
        if not isinstance(child, nn.ConvTranspose3d):
            replace_conv_transpose3d(child, target_cls)
            continue

        ct = child
        # Explicit raises instead of assert: asserts vanish under `python -O`.
        if ct.groups != 1:
            raise NotImplementedError(f"groups={ct.groups} not supported")
        if any(op != 0 for op in ct.output_padding):
            raise NotImplementedError(f"output_padding={ct.output_padding} not supported")

        extra_kwargs = {}
        if any(d != 1 for d in ct.dilation):
            # Never drop a non-default dilation silently.
            extra_kwargs['dilation'] = ct.dilation

        replacement = target_cls(
            ct.in_channels, ct.out_channels, ct.kernel_size,
            stride=ct.stride, padding=ct.padding,
            bias=ct.bias is not None,
            **extra_kwargs
        )
        # Follow the original layer's placement; otherwise converting a model
        # that was already moved or cast leaves these layers on the default
        # device in float32.
        replacement.to(device=ct.weight.device, dtype=ct.weight.dtype)

        # Copy weights — same tensor shape, no conversion needed.
        with torch.no_grad():
            replacement.weight.copy_(ct.weight)
            if ct.bias is not None:
                replacement.bias.copy_(ct.bias)

        # Preserve frozen/trainable status and train/eval mode.
        replacement.weight.requires_grad_(ct.weight.requires_grad)
        if ct.bias is not None:
            replacement.bias.requires_grad_(ct.bias.requires_grad)
        replacement.train(ct.training)

        setattr(module, name, replacement)
Models/all_om_net/000110_all_om_net.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9c2c90820aba95bfd89d870820574461963450ca50617ee44fb5af2b17385b3
3
+ size 3017380171
OM_reg.py CHANGED
@@ -72,7 +72,8 @@ min_crop_ratio = 0.9
72
  # label_keys = ['heart']
73
  label_keys = ['brain']
74
  # label_keys = ['pancreas']
75
- database = ['MSD']
 
76
 
77
  dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database)
78
  Infer_Loader = DataLoader(
@@ -112,6 +113,7 @@ Deformddpm = DeformDDPM(
112
  padding_mode = hyp_parameters["padding_mode"],
113
  v_scale = hyp_parameters["v_scale"],
114
  resample_mode = hyp_parameters["resample_mode"],
 
115
  )
116
  Deformddpm.to(hyp_parameters["device"])
117
 
@@ -125,7 +127,7 @@ ddf_stn.to(hyp_parameters["device"])
125
 
126
  print("Loading model from:", model_save_path)
127
  # Deformddpm.load_state_dict(torch.load(model_save_path))
128
- checkpoint = torch.load(model_save_path)
129
  Deformddpm.load_state_dict(checkpoint['model_state_dict'])
130
  Deformddpm.eval()
131
 
@@ -162,12 +164,8 @@ for e, d in tqdm(enumerate(Infer_Loader)):
162
  # print(pid, image_original.shape, mask_original.max())
163
 
164
 
165
- if hyp_parameters["ndims"] == 2:
166
- nifti_img = nib.Nifti1Image(image_original[0,0,:,:], np.eye(4))
167
- nifti_mask = nib.Nifti1Image(mask_original[0,:,:,:], np.eye(4))
168
- elif hyp_parameters["ndims"] == 3:
169
- nifti_img = nib.Nifti1Image(image_original[0,0,:,:,:], np.eye(4))
170
- nifti_mask = nib.Nifti1Image(mask_original[0,0,:,:,:], np.eye(4))
171
 
172
  # Saving original (undeformed image)
173
  # CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
@@ -198,16 +196,10 @@ for e, d in tqdm(enumerate(Infer_Loader)):
198
  noisy_imgs_np = img_diff.cpu().detach().numpy()
199
  noisy_msks_np = msk_diff.cpu().detach().numpy()
200
 
201
- if hyp_parameters["ndims"] == 2:
202
- nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:], np.eye(4))
203
- nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,:,:,:], np.eye(4))
204
- nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:], np.eye(4))
205
- nifti_mask = nib.Nifti1Image(noisy_msks_np[0, :, :, :], np.eye(4))
206
- elif hyp_parameters["ndims"] == 3:
207
- nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:,:], np.eye(4))
208
- nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,0,:,:,:], np.eye(4))
209
- nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:,:], np.eye(4))
210
- nifti_mask = nib.Nifti1Image(noisy_msks_np[0, 0, :, :], np.eye(4))
211
 
212
  nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
213
  nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
 
72
  # label_keys = ['heart']
73
  label_keys = ['brain']
74
  # label_keys = ['pancreas']
75
+ # database = ['MSD']
76
+ database = ['Brats2019']
77
 
78
  dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database)
79
  Infer_Loader = DataLoader(
 
113
  padding_mode = hyp_parameters["padding_mode"],
114
  v_scale = hyp_parameters["v_scale"],
115
  resample_mode = hyp_parameters["resample_mode"],
116
+ inf_mode = True, # set to True for inference, which will use fixed slice num and slice idx for better evaluation
117
  )
118
  Deformddpm.to(hyp_parameters["device"])
119
 
 
127
 
128
  print("Loading model from:", model_save_path)
129
  # Deformddpm.load_state_dict(torch.load(model_save_path))
130
+ checkpoint = torch.load(model_save_path, map_location='cpu')
131
  Deformddpm.load_state_dict(checkpoint['model_state_dict'])
132
  Deformddpm.eval()
133
 
 
164
  # print(pid, image_original.shape, mask_original.max())
165
 
166
 
167
+ nifti_img = utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"])
168
+ nifti_mask = utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"])
 
 
 
 
169
 
170
  # Saving original (undeformed image)
171
  # CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
 
196
  noisy_imgs_np = img_diff.cpu().detach().numpy()
197
  noisy_msks_np = msk_diff.cpu().detach().numpy()
198
 
199
+ nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"])
200
+ nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"])
201
+ nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"])
202
+ nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"])
 
 
 
 
 
 
203
 
204
  nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
205
  nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
OM_reg_flexres.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision.utils import save_image
5
+ from torch.utils.data import DataLoader
6
+ from torch.optim import Adam
7
+ from torchvision.utils import make_grid
8
+ from Diffusion.diffuser import DeformDDPM
9
+ from Diffusion.networks import get_net, STN
10
+ from torchvision.transforms import Lambda
11
+ import random
12
+ import os
13
+ import utils
14
+ from Dataloader.dataloader0 import get_dataloader
15
+ from Dataloader.dataLoader import *
16
+
17
+ from torchvision.utils import save_image
18
+ from einops import rearrange, reduce, repeat
19
+ import numpy as np
20
+ import nibabel as nib
21
+ from tqdm import tqdm
22
+ import yaml
23
+ import argparse
24
+ import torch.nn.functional as F
25
+ import SimpleITK as sitk
26
+ from skimage.transform import resize
27
+
28
# NOTE: the literal 10e-8 evaluates to 1e-7 (not 1e-8).
EPS = 10e-8

parser = argparse.ArgumentParser()

parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_om.yaml",
    required=False,
)
args = parser.parse_args()
#=======================================================================================================================

# Load the YAML file into a dictionary
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

# Create the augmentation output directories. exist_ok=True avoids the
# check-then-create race of os.path.exists() + os.makedirs() and matches the
# way the registration output directories are created further down.
for _aug_dir_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_aug_dir_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Inference processes one volume at a time.
hyp_parameters['batchsize'] = 1
model_img_sz = hyp_parameters['img_size']  # e.g. 128
58
+
59
# =======================================================================================================================
# The dataset object is only consulted for its filtering logic (which keys
# survive, plus their metadata). Volumes are read directly instead of through
# a DataLoader so that the center-padding is deterministic and identical for
# the model-resolution input and the full-resolution volume.
label_keys = ['brain']
database = ['Brats2019']

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=1.0,
    label_key=label_keys,
    database=database,
)
# =======================================================================================================================


# Checkpoint path: Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = os.path.join(
    f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
    str(epoch) + '.pth',
)


Net = get_net(hyp_parameters["net_name"])

# Diffusion wrapper around the backbone network, built in inference mode.
Deformddpm = DeformDDPM(
    network=Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=hyp_parameters["ndims"],
        num_input_chn=hyp_parameters["num_input_chn"],
        res=model_img_sz,
    ),
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [model_img_sz] * hyp_parameters["ndims"],
    device=hyp_parameters["device"],
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
    inf_mode=True,
)
Deformddpm.to(hyp_parameters["device"])

# Spatial transformer at model resolution.
ddf_stn = STN(
    img_sz=model_img_sz,
    ndims=hyp_parameters["ndims"],
    padding_mode=hyp_parameters['padding_mode'],
    device=hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])

# Load the checkpoint onto the CPU first; load_state_dict copies the values
# into the already-placed model parameters.
print("Loading model from:", model_save_path)
checkpoint = torch.load(model_save_path, map_location='cpu')
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

# Full-res output directories (append _fullres to the standard paths)
reg_img_savepath_fullres = hyp_parameters['reg_img_savepath'].rstrip('/') + '_fullres/'
reg_msk_savepath_fullres = hyp_parameters['reg_msk_savepath'].rstrip('/') + '_fullres/'
reg_ddf_savepath_fullres = hyp_parameters['reg_ddf_savepath'].rstrip('/') + '_fullres/'

# Standard-resolution directories first, then their full-resolution twins.
for _out_dir in (
    hyp_parameters['reg_img_savepath'],
    hyp_parameters['reg_msk_savepath'],
    hyp_parameters['reg_ddf_savepath'],
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
):
    os.makedirs(_out_dir, exist_ok=True)
121
+
122
+
123
+ # ========== Helper functions ==========
124
+
125
def center_pad_to_cube(volume):
    """Zero-pad the leading (up to three) axes of ``volume`` to a common length.

    The target length is the largest of the first three axis sizes. Each of
    those axes is padded symmetrically; when the missing amount is odd the
    extra element goes after the data. Any axes beyond the third (e.g. a
    trailing channel axis) are left unpadded.

    Args:
        volume: numpy array.

    Returns:
        A new, zero-padded numpy array.
    """
    leading_shape = volume.shape[:3]
    cube_edge = max(leading_shape)

    margins = []
    for extent in leading_shape:
        before = (cube_edge - extent) // 2
        margins.append((before, cube_edge - extent - before))
    # Trailing axes (if any) keep their size.
    margins += [(0, 0)] * (volume.ndim - 3)

    return np.pad(volume, margins, mode='constant', constant_values=0)
138
+
139
+
140
def load_fullres_volume(key, ds):
    """Read one image at its native resolution and prepare it as a cube.

    Steps: read ``key`` with SimpleITK, pass the array through
    ``reverse_axis_order``, reduce a 4-D array to a single channel, clamp CT
    intensities when the dataset defines a clamp range, normalize with
    ``ds.normalize`` and finally center-pad to a cube.

    Args:
        key: Image path; also the key into ``ds.ALLdata_filtered``.
        ds: Dataset object providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        The padded volume (cubic over its three spatial axes).
    """
    volume = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    # Multi-channel image: keep the first channel the dataset lists, or
    # channel 0 when it lists none.
    if volume.ndim == 4:
        channel_ids = ds.get_channel_ids(key)
        channel_id = 0 if len(channel_ids) == 0 else channel_ids[0]
        volume = volume[:, :, :, channel_id]

    # Intensity clamping applies to CT only, and only when a range is set.
    if ds.clamp_range is not None and ds.ALLdata_filtered[key].get("Modality", None) == "CT":
        volume = np.clip(volume, ds.clamp_range[0], ds.clamp_range[1])

    normalized = ds.normalize(volume)
    return center_pad_to_cube(normalized)  # shape: [D, D, D] (cubic)
157
+
158
+
159
def load_fullres_label(key, ds, label_key):
    """Read one segmentation label at its native resolution (no resizing).

    The label path is looked up under
    ``ds.ALLdata_filtered[key]['Label_path']['segmentation'][label_key]``.
    The array is passed through ``reverse_axis_order``; when it has more than
    three axes and the dataset lists channel ids, only those channels of the
    last axis are kept. The result is center-padded to a cube.

    Args:
        key: Image key into ``ds.ALLdata_filtered``.
        ds: Dataset object providing ``ALLdata_filtered`` and
            ``get_channel_ids``.
        label_key: Name of the segmentation label to load.

    Returns:
        The padded label array, or None when the entry has no segmentation
        label named ``label_key``.
    """
    segmentation_paths = (
        ds.ALLdata_filtered[key].get('Label_path', {}).get('segmentation', {})
    )
    if label_key not in segmentation_paths:
        return None

    label = reverse_axis_order(
        sitk.GetArrayFromImage(sitk.ReadImage(segmentation_paths[label_key]))
    )

    if label.ndim > 3:
        channel_ids = ds.get_channel_ids(key)
        if len(channel_ids) != 0:
            label = label[..., channel_ids]

    return center_pad_to_cube(label)
174
+
175
+
176
def apply_ddf(volume_tensor, ddf, padding_mode='border', resample_mode='bilinear'):
    """Warp ``volume_tensor`` with a dense displacement field at any resolution.

    The DDF stores fractional displacements: ``value * axis_size`` is the
    displacement in voxels along that axis. Because the values are
    fractional, a DDF that was spatially upsampled from model resolution to
    full resolution stays valid; the axis sizes of ``volume_tensor`` are
    used as the scale.

    Args:
        volume_tensor: Tensor of shape [B, C, *spatial] with 2 or 3 spatial
            axes.
        ddf: Tensor of shape [B, ndims, *spatial]; channel ``c`` displaces
            along spatial axis ``c``.
        padding_mode: Forwarded to ``F.grid_sample``.
        resample_mode: Forwarded to ``F.grid_sample`` as ``mode``.

    Returns:
        The resampled tensor, same shape as ``volume_tensor``.

    Raises:
        ValueError: If ``volume_tensor`` is not 4-D/5-D or ``ddf`` does not
            have one channel per spatial axis.
    """
    ndims = volume_tensor.dim() - 2
    if ndims not in (2, 3):
        raise ValueError(
            f"apply_ddf expects a 4-D or 5-D volume tensor, got {volume_tensor.dim()}-D"
        )
    if ddf.dim() != volume_tensor.dim() or ddf.shape[1] != ndims:
        raise ValueError(
            f"ddf must have shape [B, {ndims}, *spatial], got {tuple(ddf.shape)}"
        )

    device = ddf.device
    img_sz = list(volume_tensor.shape[2:])

    # Per-axis scale turning fractional displacements into voxels.
    max_sz = torch.reshape(
        torch.tensor(img_sz, dtype=torch.float32, device=device),
        [1, ndims] + [1] * ndims)
    # Identity sampling positions in voxel coordinates.
    ref_grid = torch.reshape(
        torch.stack(torch.meshgrid(
            [torch.arange(s, device=device) for s in img_sz], indexing='ij'), 0),
        [1, ndims] + img_sz)
    # Half extent per axis for mapping voxel index -> [-1, 1]. A size-1 axis
    # would give 0 here and turn the grid into NaN/inf, so clamp it; its only
    # voxel then maps to -1, which align_corners=True resolves to index 0.
    half_extent = torch.reshape(
        torch.tensor([max(s - 1, 1) / 2. for s in img_sz],
                     dtype=torch.float32, device=device),
        [1] + [1] * ndims + [ndims])

    # Channels-last voxel coordinates, normalized to [-1, 1]. grid_sample
    # expects the last axis ordered (x, y[, z]), i.e. reversed w.r.t. the
    # tensor's spatial axes, hence the flip.
    voxel_coords = (ddf * max_sz + ref_grid).permute(
        [0] + list(range(2, 2 + ndims)) + [1])
    grid = torch.flip(voxel_coords / half_extent - 1, dims=[-1])

    # grid_sample needs input and grid on the same device with the same dtype.
    grid_dtype = volume_tensor.dtype if volume_tensor.is_floating_point() else torch.float32
    grid = grid.to(device=volume_tensor.device, dtype=grid_dtype)

    return F.grid_sample(volume_tensor, grid, mode=resample_mode,
                         padding_mode=padding_mode, align_corners=True)
203
+
204
+
205
+ # ========== Main inference loop ==========
206
+
207
+ keys = list(dataset.ALLdata_filtered.keys())
208
+ print("total num of images:", len(keys))
209
+
210
+ for e, key in enumerate(tqdm(keys)):
211
+ pid = e
212
+ print(f'Processing patient {pid}, image {e}, key: {key}')
213
+
214
+ # --- Load full-resolution volume (center-padded to cube) ---
215
+ fullres_vol = load_fullres_volume(key, dataset)
216
+ orig_sz = list(fullres_vol.shape) # e.g. [240, 240, 240]
217
+ print(f" Full-res padded shape: {orig_sz}")
218
+
219
+ # --- Resize to model resolution for inference ---
220
+ vol_model = resize(fullres_vol, [model_img_sz] * 3,
221
+ anti_aliasing=True, preserve_range=True)
222
+ img = torch.tensor(vol_model[None, None, :, :, :],
223
+ dtype=torch.float32, device=hyp_parameters["device"])
224
+
225
+ # --- Load full-res labels and resize to model resolution ---
226
+ fullres_labels = {}
227
+ for lk in label_keys:
228
+ lab = load_fullres_label(key, dataset, lk)
229
+ if lab is not None:
230
+ fullres_labels[lk] = lab
231
+
232
+ # Build mask at model resolution (128^3)
233
+ label_arrays_model = []
234
+ label_arrays_fullres = []
235
+ for lk in label_keys:
236
+ if lk in fullres_labels:
237
+ lab = fullres_labels[lk]
238
+ lab_model = resize(lab, [model_img_sz] * 3,
239
+ anti_aliasing=False, preserve_range=True, order=0)
240
+ if lab_model.ndim == 3:
241
+ lab_model = lab_model[None, :, :, :]
242
+ elif lab_model.ndim > 3:
243
+ lab_model = np.transpose(lab_model, (3, 0, 1, 2))
244
+ label_arrays_model.append(lab_model)
245
+
246
+ if lab.ndim == 3:
247
+ lab = lab[None, :, :, :]
248
+ elif lab.ndim > 3:
249
+ lab = np.transpose(lab, (3, 0, 1, 2))
250
+ label_arrays_fullres.append(lab)
251
+ else:
252
+ label_arrays_model.append(np.full([1] + [model_img_sz] * 3, -1))
253
+ label_arrays_fullres.append(np.full([1] + orig_sz, -1))
254
+
255
+ if len(label_arrays_model) > 0:
256
+ mask_model_np = np.concatenate(label_arrays_model, axis=0)
257
+ mask = torch.tensor(mask_model_np[None], dtype=torch.float32,
258
+ device=hyp_parameters["device"])
259
+ fullres_msk_np = np.concatenate(label_arrays_fullres, axis=0)
260
+ fullres_msk_tensor = torch.tensor(fullres_msk_np[None], dtype=torch.float32,
261
+ device=hyp_parameters["device"])
262
+ else:
263
+ mask = None
264
+ fullres_msk_np = None
265
+ fullres_msk_tensor = None
266
+
267
+ # Build full-res image tensor
268
+ fullres_img_tensor = torch.tensor(fullres_vol[None, None, :, :, :],
269
+ dtype=torch.float32,
270
+ device=hyp_parameters["device"])
271
+
272
+ # --- Save target conditioning image (first subject) ---
273
+ if e <= 0:
274
+ target_img = img.clone().detach()
275
+
276
+ # --- Save original images at 128^3 ---
277
+ image_original = img.cpu().numpy()
278
+ nib.save(utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"]),
279
+ os.path.join(hyp_parameters['reg_img_savepath'],
280
+ utils.get_barcode([pid, e]) + '.nii.gz'))
281
+ if mask is not None:
282
+ mask_original = mask.cpu().numpy()
283
+ nib.save(utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"]),
284
+ os.path.join(hyp_parameters['reg_msk_savepath'],
285
+ utils.get_barcode([pid, e]) + '_GT.nii.gz'))
286
+
287
+ # --- Save original at full-res ---
288
+ # fullres_vol is [D,D,D], wrap as [1,1,D,D,D] for converet_to_nibabel
289
+ nib.save(utils.converet_to_nibabel(fullres_vol[None, None], ndims=hyp_parameters["ndims"]),
290
+ os.path.join(reg_img_savepath_fullres,
291
+ utils.get_barcode([pid, e]) + '.nii.gz'))
292
+ if fullres_msk_np is not None:
293
+ # fullres_msk_np is [C,D,D,D], wrap as [1,C,D,D,D]
294
+ nib.save(utils.converet_to_nibabel(fullres_msk_np[None], ndims=hyp_parameters["ndims"]),
295
+ os.path.join(reg_msk_savepath_fullres,
296
+ utils.get_barcode([pid, e]) + '_GT.nii.gz'))
297
+
298
+ # --- Diffusion recovery at model resolution ---
299
+ noise_step = hyp_parameters["start_noise_step"]
300
+ with torch.no_grad():
301
+ for im in range(1):
302
+ print(f' Generating -> Subject-{pid}, Scan-{e} ({im}/{hyp_parameters["aug_coe"]})', end='\r')
303
+
304
+ [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = \
305
+ Deformddpm.diff_recover(
306
+ img_org=img,
307
+ cond_imgs=target_img.clone().detach(),
308
+ msk_org=mask,
309
+ T=[None, hyp_parameters["timesteps"]],
310
+ v_scale=hyp_parameters["v_scale"],
311
+ t_save=None,
312
+ proc_type=hyp_parameters["condition_type"])
313
+
314
+ # --- Save 128^3 results (same as OM_reg.py) ---
315
+ denoise_imgs = img_rec.cpu().numpy()
316
+ noisy_imgs_np = img_diff.cpu().numpy()
317
+
318
+ nib.save(utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"]),
319
+ os.path.join(hyp_parameters['reg_img_savepath'],
320
+ utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
321
+ nib.save(utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"]),
322
+ os.path.join(hyp_parameters['reg_img_savepath'],
323
+ utils.get_barcode([pid, e, im, noise_step],
324
+ header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '.nii.gz'))
325
+
326
+ if msk_rec is not None:
327
+ denoise_msks = msk_rec.cpu().numpy()
328
+ noisy_msks_np = msk_diff.cpu().numpy()
329
+ nib.save(utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"]),
330
+ os.path.join(hyp_parameters['reg_msk_savepath'],
331
+ utils.get_barcode([pid, e, im, noise_step]) + '_GT.nii.gz'))
332
+ nib.save(utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"]),
333
+ os.path.join(hyp_parameters['reg_msk_savepath'],
334
+ utils.get_barcode([pid, e, im, noise_step],
335
+ header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '_GT.nii.gz'))
336
+
337
+ # --- Upscale DDFs to original resolution ---
338
+ ddf_fullres = F.interpolate(ddf_comp, size=orig_sz,
339
+ mode='trilinear', align_corners=False)
340
+ ddf_rand_fullres = F.interpolate(ddf_rand, size=orig_sz,
341
+ mode='trilinear', align_corners=False)
342
+
343
+ # --- Apply DDFs at original resolution ---
344
+ img_rec_fullres = apply_ddf(fullres_img_tensor, ddf_fullres,
345
+ padding_mode='border')
346
+ img_noisy_fullres = apply_ddf(fullres_img_tensor, ddf_rand_fullres,
347
+ padding_mode='border')
348
+
349
+ if fullres_msk_tensor is not None:
350
+ msk_rec_fullres = apply_ddf(fullres_msk_tensor, ddf_fullres,
351
+ padding_mode='zeros', resample_mode='nearest')
352
+ msk_noisy_fullres = apply_ddf(fullres_msk_tensor, ddf_rand_fullres,
353
+ padding_mode='zeros', resample_mode='nearest')
354
+
355
+ # --- Save full-res results ---
356
+ nib.save(utils.converet_to_nibabel(img_rec_fullres, ndims=hyp_parameters["ndims"]),
357
+ os.path.join(reg_img_savepath_fullres,
358
+ utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
359
+ nib.save(utils.converet_to_nibabel(img_noisy_fullres, ndims=hyp_parameters["ndims"]),
360
+ os.path.join(reg_img_savepath_fullres,
361
+ utils.get_barcode([pid, e, im, noise_step],
362
+ header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '.nii.gz'))
363
+
364
+ if fullres_msk_tensor is not None:
365
+ nib.save(utils.converet_to_nibabel(msk_rec_fullres, ndims=hyp_parameters["ndims"]),
366
+ os.path.join(reg_msk_savepath_fullres,
367
+ utils.get_barcode([pid, e, im, noise_step]) + '_GT.nii.gz'))
368
+ nib.save(utils.converet_to_nibabel(msk_noisy_fullres, ndims=hyp_parameters["ndims"]),
369
+ os.path.join(reg_msk_savepath_fullres,
370
+ utils.get_barcode([pid, e, im, noise_step],
371
+ header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '_GT.nii.gz'))
372
+
373
+ # Save full-res DDF (converet_to_nibabel handles multi-channel → channel-last)
374
+ nib.save(utils.converet_to_nibabel(ddf_fullres, ndims=hyp_parameters["ndims"]),
375
+ os.path.join(reg_ddf_savepath_fullres,
376
+ utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
377
+
378
+ if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
379
+ noise_step = noise_step + hyp_parameters["noise_step"]
380
+
381
+ if e > 5:
382
+ break
OM_train_2modes-reg.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(ROOT_DIR)
5
+
6
+ import gc
7
+ import torch
8
+ import torchvision
9
+ from torch import nn
10
+ from torchvision.utils import save_image
11
+ from torch.utils.data import DataLoader
12
+
13
+ from torch.optim import Adam, SGD
14
+ from Diffusion.diffuser import DeformDDPM
15
+ from Diffusion.networks import get_net, STN
16
+ from torchvision.transforms import Lambda
17
+ import Diffusion.losses as losses
18
+ import random
19
+ import glob
20
+ import numpy as np
21
+ import utils
22
+ from tqdm import tqdm
23
+
24
+ from Dataloader.dataloader0 import get_dataloader
25
+ from Dataloader.dataLoader import *
26
+
27
+ from Dataloader.dataloader_utils import thresh_img
28
+ import yaml
29
+ import argparse
30
+
31
+ ####################
32
+ import torch.multiprocessing as mp
33
+ from torch.utils.data.distributed import DistributedSampler
34
+ from torch.nn.parallel import DistributedDataParallel as DDP
35
+ import torch.distributed as dist
36
+ # from torch.distributed import init_process_group
37
+ ###############
38
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one single-node worker.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index for this process.
        world_size: Total number of processes.
    """
    # Single-node launch via mp.spawn: the rendezvous host is always local.
    os.environ["MASTER_ADDR"] = "localhost"
    # Keep 12355 as the default, but let the caller pre-set MASTER_PORT so two
    # jobs can share a node without colliding on the rendezvous port.
    os.environ.setdefault("MASTER_PORT", "12355")
    # Bind the device before creating the group so NCCL collectives issued
    # later (e.g. barrier) use this rank's GPU.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
48
+
49
# --- Launch mode -------------------------------------------------------------
# True  -> one DDP worker per visible GPU (spawned in __main__).
# False -> single process, single device.
use_distributed = True

# --- Numerical constants -----------------------------------------------------
EPS = 1e-5
MSK_EPS = 0.01

# --- Augmentation / conditioning probabilities -------------------------------
TEXT_EMBED_PROB = 0.7
AUG_RESAMPLE_PROB = 0.6

# --- Loss weights --------------------------------------------------------------
# Diffusion branch: [ang, dist, reg]
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16]
# Registration branch: [imgsim, imgmse, ddf]
LOSS_WEIGHTS_REGIST = [1.0, 0.2, 1e2]

# The paired (registration) loader uses batchsize // DIFF_REG_BATCH_RATIO.
DIFF_REG_BATCH_RATIO = 2

# --- Command line ----------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_all.yaml",
    required=False,
)
args = parser.parse_args()
77
+ #=======================================================================================================================
78
+
79
+
80
+
81
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
    """Train the deformation diffusion model in two alternating modes.

    Each step first runs a diffusion update on a single-image batch, then
    (every ``train_mode_ratio`` steps) a registration update on a paired
    batch. Checkpoints are written every ``epoch_per_save`` epochs and the
    latest ``*.pth`` in the model directory is used to resume.

    NOTE(review): this block was reconstructed from a diff view that had lost
    its indentation; nesting was inferred from control flow and should be
    compared with the original file before merging.

    Args:
        rank: Process rank; also used as the CUDA device index under DDP.
        world_size: Number of DDP processes (forwarded to ``ddp_setup``).
        train_mode_ratio: Run the registration update on steps where
            ``step % train_mode_ratio == 0``.
        thresh_imgsim: Intensity threshold applied to the target image to
            build the label mask passed to the image-similarity loss.

    Raises:
        ValueError: If more than 5 diffusion steps in one epoch produce a
            NaN/Inf loss.
    """
    if use_distributed:
        ddp_setup(rank, world_size)

    if torch.distributed.is_initialized():
        print(f"World size: {torch.distributed.get_world_size()}")
        print(f"Communication backend: {torch.distributed.get_backend()}")
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    batchsize = hyp_parameters["batchsize"]
    noise_scale = hyp_parameters["noise_scale"]
    epoch_per_save = hyp_parameters['epoch_per_save']

    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path
    # Currently unused (the transformer(x) calls are disabled below); kept so
    # that construction-time behaviour of utils.get_transformer is unchanged.
    transformer = utils.get_transformer(img_sz=ndims * [img_size])

    # Single-image loader for the diffusion branch.
    dataset = OMDataset_indiv(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        drop_last=True,
    )

    # Paired loader for the registration branch (smaller batch).
    # NOTE(review): no DistributedSampler is used, so under DDP every rank
    # draws its own shuffle of the full dataset -- confirm this is intended.
    datasetp = OMDataset_pair(transform=None)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=batchsize // DIFF_REG_BATCH_RATIO,
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size] * ndims,
        device=device,
        batch_size=batchsize,
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    # Handle to the un-wrapped model: helper methods such as proc_cond_img and
    # the checkpoint state dict live on the inner module when DDP wraps it.
    # The original always used `.module`, which fails in single-process mode.
    model_core = Deformddpm.module if isinstance(Deformddpm, DDP) else Deformddpm

    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.6, outrange_weight=1e3)

    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim = losses.LNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # Resume from the lexicographically last checkpoint, if any.
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Pass the real launch mode; the helper's default (True) would access
        # `.module` on an un-wrapped model in single-process runs.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1],
            use_distributed=use_distributed)
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # Training loop
    for epoch in range(initial_epoch, hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        Deformddpm.train()

        loss_nan_step = 0  # number of diffusion steps skipped for NaN/Inf loss

        # zip() stops at the shorter loader, so that is the step count.
        total = min(len(train_loader), len(train_loader_p))
        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):

            # ==================================================================
            # diffusion train on single image
            [x0, embd] = batch
            x0 = x0.to(device).type(torch.float32)
            # Randomly drop the text embedding so the model also learns the
            # text-free case.
            if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                embd = embd.to(device).type(torch.float32)
            else:
                embd = None

            n = x0.size()[0]  # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(device)

            # random resample or axis permutation (3D only)
            if ndims > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
            # random intensity threshold / scale / shift
            if noise_scale > 0:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 2 * noise_scale])
                # NOTE(review): nesting reconstructed -- scale/shift assumed to
                # apply whenever noise_scale > 0, not only when thresholding.
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # One random timestep per image in the batch.
            t = torch.randint(0, timesteps, (n,)).to(device)

            proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
            cond_img, _, cond_ratio = model_core.proc_cond_img(x0, proc_type=proc_type)

            pre_dvf_I, dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I, img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)

            loss_tot = LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Check the limit before `continue`; previously it sat after
                # the `continue` and only fired on a later finite step.
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                # NOTE(review): under DDP, skipping backward on one rank only
                # may desynchronise gradient all-reduce -- verify.
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() / total
            epoch_loss_gen_d += loss_gen_d.item() / total
            epoch_loss_gen_a += loss_gen_a.item() / total
            epoch_loss_reg += loss_ddf.item() / total

            run_regist = step % train_mode_ratio == 0
            if run_regist:
                # ==============================================================
                # registration train on paired images
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)
                n = x1.size()[0]  # paired batch size

                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if noise_scale > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2 * noise_scale])
                    # Same scale/shift for both images so the pair stays
                    # intensity-consistent.
                    random_scale = np.random.normal(1, noise_scale * 1)
                    random_shift = np.random.normal(0, noise_scale * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                # Sample 8..16 distinct timesteps from [scale*T, T), descending.
                scale_regist = np.random.uniform(0.0, 0.7)
                select_timestep = np.random.randint(8, 17)
                t_candidates = range(int(timesteps * scale_regist), timesteps)
                # Clamp so a small `timesteps` config cannot make
                # random.sample raise ValueError.
                n_select = min(select_timestep, len(t_candidates))
                T_regist = sorted(random.sample(t_candidates, n_select), reverse=True)
                # Replicate each timestep once per sample of the paired batch.
                # The original hard-coded batchsize//2, which only matches the
                # loader while DIFF_REG_BATCH_RATIO == 2.
                T_regist = [[t_sel for _ in range(n)] for t_sel in T_regist]

                proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
                y1_proc, msk_tgt, cond_ratio = model_core.proc_cond_img(y1, proc_type=proc_type)
                [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], _ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[], text=embd_y)
                loss_sim = loss_imgsim(img_rec, y1, label=(y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=(y1 >= 0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)

                loss_regist = (LOSS_WEIGHTS_REGIST[0] * loss_sim
                               + LOSS_WEIGHTS_REGIST[1] * loss_mse
                               + LOSS_WEIGHTS_REGIST[2] * loss_ddf1)

                # Probe this branch's own inputs (the original re-checked x0).
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in input images x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1 > 0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
                optimizer.zero_grad()
                loss_regist.backward()
                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.4)
                optimizer.step()

                epoch_loss_regist += loss_regist.item() / total
                epoch_loss_imgsim += loss_sim.item() / total
                epoch_loss_imgmse += loss_mse.item() / total
                epoch_loss_ddfreg += loss_ddf1.item() / total

            if step % 10 == 0:
                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                # Only report registration losses computed in this step.
                if run_regist:
                    print(f'      loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')

        print('==================')
        print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg, ' (ang+dist+regul)')
        print(f'      loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
        print('==================')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Only one process writes; weights are always saved un-wrapped so
            # the file loads the same way in either launch mode.
            if (not use_distributed) or gpu_id == 0:
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': model_core.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
471
+
472
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Every process reads the checkpoint onto the CPU and loads it into its own
    replica. This keeps model and optimizer state identical on all ranks
    without a manual broadcast; restoring the optimizer on rank 0 only would
    let Adam statistics diverge between ranks.

    Args:
        gpu_id: Rank of the calling process; rank 0 prints memory usage.
        Deformddpm: The model, DDP-wrapped when ``use_distributed`` is True.
        optimizer: Optimizer whose state is restored when ``load_strict``.
        model_file: Path to a checkpoint holding ``model_state_dict``,
            ``optimizer_state_dict`` and ``epoch``.
        use_distributed: If True, load into ``Deformddpm.module`` and
            synchronise ranks with a barrier afterwards.
        load_strict: Passed as ``strict`` to ``load_state_dict``; the
            optimizer state is restored only when this is True.

    Returns:
        Tuple ``(initial_epoch, Deformddpm, optimizer)`` where
        ``initial_epoch`` is the saved epoch plus one.
    """
    is_main = gpu_id == 0
    if is_main:
        utils.print_memory_usage("Before Loading Model")
    gc.collect()
    torch.cuda.empty_cache()

    # map_location="cpu" stops every rank from materialising the checkpoint on
    # the GPU it was saved from; load_state_dict copies onto the right device.
    checkpoint = torch.load(model_file, map_location="cpu")
    target = Deformddpm.module if use_distributed else Deformddpm
    target.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
    if load_strict:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if is_main:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Ensure every rank has finished loading before training resumes.
        dist.barrier()

    # The epoch is encoded as the first six characters of the file name;
    # fall back to the value stored in the checkpoint for other file names.
    stem = os.path.basename(model_file).split('.')[0]
    try:
        initial_epoch = int(stem[:6]) + 1
    except ValueError:
        initial_epoch = int(checkpoint['epoch']) + 1

    return initial_epoch, Deformddpm, optimizer
508
+
509
+
510
+
511
if __name__ == "__main__":
    if use_distributed:
        # One worker process per visible GPU; mp.spawn passes the process
        # index as the first argument (rank) of main_train.
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            # With nprocs=0 mp.spawn returns without starting any worker, so
            # the script would exit "successfully" having trained nothing.
            raise RuntimeError(
                "use_distributed=True but no CUDA device is visible; "
                "set use_distributed=False or expose at least one GPU."
            )
        mp.spawn(main_train,args = (world_size,),nprocs = world_size)
    else:
        main_train(0,1)
OM_train_2modes.py CHANGED
@@ -1,4 +1,8 @@
1
- import os
 
 
 
 
2
  import gc
3
  import torch
4
  import torchvision
@@ -48,12 +52,11 @@ use_distributed = True
48
  EPS = 1e-5
49
  MSK_EPS = 0.01
50
  TEXT_EMBED_PROB = 0.7
51
- AUG_RESAMPLE_PROB = 0.6
52
- LOSS_WEIGHTS_DIFF = [2.0, 1.0, 30] # [ang, dist, reg]
53
  # LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
54
- # LOSS_WEIGHTS_REGIST = [10.0, 1.0, 1.0] # [imgsim, imgmse, ddf]
55
- # LOSS_WEIGHTS_REGIST = [2.0, 0.1, 1e3] # [imgsim, imgmse, ddf]
56
- LOSS_WEIGHTS_REGIST = [2.0, 0.1, 256] # [imgsim, imgmse, ddf]
57
 
58
  # AUG_PERMUTE_PROB = 0.35
59
 
@@ -130,7 +133,7 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
130
  datasetp = OMDataset_pair(transform=None)
131
  train_loader_p = DataLoader(
132
  datasetp,
133
- batch_size=hyp_parameters['batchsize']//2,
134
  shuffle=True,
135
  drop_last=True,
136
  )
@@ -174,12 +177,15 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
174
 
175
  # mse = nn.MSELoss()
176
  # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
 
177
  loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
178
  loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
 
179
  loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
180
  # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
181
  loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
182
- loss_imgsim = losses.LNCC()
 
183
  loss_imgmse = losses.LMSE()
184
 
185
  optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
@@ -220,15 +226,15 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
220
  epoch_loss_ddfreg = 0.0
221
  # Set model inside to train model
222
  Deformddpm.train()
223
-
224
  loss_nan_step = 0 # yu: count the number of nan loss steps
225
 
226
  total = min(len(train_loader), len(train_loader_p))
227
- for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
228
  # for step, batch in tqdm(enumerate(train_loader)):
229
  # for step, batch in tqdm(enumerate(train_loader)):
230
-
231
  # for step, batch in enumerate(train_loader_omni):
 
 
232
  # x0, _ = batch
233
 
234
 
@@ -258,10 +264,10 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
258
  # elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
259
  else:
260
  [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
261
- x0 = transformer(x0)
262
  if hyp_parameters['noise_scale']>0:
263
  if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
264
- x0 = thresh_img(x0, [0, 1*hyp_parameters['noise_scale']])
265
  x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
266
 
267
  # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
@@ -270,12 +276,15 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
270
  ) # pick up a seq of rand number from 0 to 'timestep'
271
 
272
  # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
273
- proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
274
  # print('proc_type:', proc_type)
275
  cond_img, _, cond_ratio = Deformddpm.module.proc_cond_img(x0,proc_type=proc_type)
276
 
277
  pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
278
 
 
 
 
279
  loss_tot=0
280
 
281
  loss_ddf = loss_reg(pre_dvf_I,img=x0)
@@ -302,15 +311,14 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
302
  print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
303
  raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
304
 
305
-
306
  optimizer.zero_grad()
307
  loss_tot.backward()
308
  optimizer.step()
309
 
310
- epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
311
- epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
312
- epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
313
- epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)
314
 
315
  # print(loss_gen_a.item())
316
  # if 0:
@@ -336,8 +344,8 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
336
  # if np.random.uniform(0,1)<0.6:
337
  # x1 = utils.random_resample(x1, deform_scale=0)
338
  # y1 = utils.random_resample(y1, deform_scale=0)
339
- x1 = transformer(x1)
340
- y1 = transformer(y1)
341
  [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
342
  if hyp_parameters['noise_scale']>0:
343
  [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
@@ -355,10 +363,13 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
355
  # ) # pick up a seq of rand number from 0 to 'timestep'
356
 
357
 
358
- # scale_regist = np.random.uniform(0.2,0.25)
359
  # T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist) + 1), 16), reverse=True)
360
- scale_regist = np.random.uniform(0.05,0.7)
361
- T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), 16), reverse=True)
 
 
 
362
  # scale_regist = np.random.uniform(0.4,1.)
363
  # T_regist = [int(hyp_parameters["timesteps"]*scale_regist)]
364
  # scale_regist = np.random.uniform(0.6,1.)
@@ -369,33 +380,30 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
369
 
370
  # print('T_regist:', T_regist)
371
  # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
372
- proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'none'])
373
  # proc_type = random.choice(['project'])
374
  y1_proc, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
375
- # msk_tgt = msk_tgt + MSK_EPS
376
  [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
377
- # loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=msk_tgt) # calculate loss for the registration process
378
- # loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
379
- # loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>0.0)) # calculate loss for the registration process
380
- loss_sim = loss_imgsim(img_rec, y1, label=(y1>thresh_imgsim)) # calculate loss for the registration process
381
- loss_mse = loss_imgmse(img_rec, y1, label=(y1>0.0)) # calculate loss for the registration process
382
- loss_ddf1 = loss_reg1(ddf_comp,img=y1) # calculate loss for the registration process
383
-
384
  loss_regist = 0
385
  loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
386
  loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
387
  loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
388
  # print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
389
  # print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
390
-
391
  # >> JZ: print nan in x0
392
  if torch.isnan(x0).any():
393
  print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
394
-
395
-
396
-
397
  # >> JZ: print loss of ddf
398
- if loss_ddf1>0.001:
399
  print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
400
  # # Print gradients for each parameter
401
  # for name, param in Deformddpm.named_parameters():
@@ -403,43 +411,25 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
403
  # print(f"Gradient for {name}: {param.grad.norm()}")
404
  # else:
405
  # print(f"Gradient for {name}: None")
406
-
407
  loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
408
  optimizer.zero_grad()
409
  loss_regist.backward()
410
 
411
 
412
 
413
- torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
414
  optimizer.step()
415
 
416
- epoch_loss_regist += loss_regist.item() * len(x0) / len(train_loader.dataset)
417
- epoch_loss_imgsim += loss_sim.item() * len(x0) / len(train_loader.dataset)
418
- epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
419
- epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)
420
 
421
 
422
- print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
423
- print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
424
- # >> JZ: if loss_imgsim is zero
425
- if loss_sim.item()>-0.001:
426
- print(f"*** Zero image similarity loss at epoch {epoch}, step {step}.")
427
- def save_niftiimage(tensor, filename):
428
- import nibabel as nib
429
- import numpy as np
430
- array = tensor.squeeze().cpu().detach().numpy()
431
- nifti_img = nib.Nifti1Image(array, affine=np.eye(4))
432
- nib.save(nifti_img, filename)
433
- # save the x1 and y1 images for debugging
434
- save_path = os.path.join('/home/data/Github/OmniMorph/Log/error_files',f"debug_images_epoch{epoch}_step{step}/")
435
- os.makedirs(save_path, exist_ok=True)
436
- save_niftiimage(img_rec, os.path.join(save_path, 'img_rec.nii.gz'))
437
- save_niftiimage(x1, os.path.join(save_path, 'x1.nii.gz'))
438
- save_niftiimage(y1, os.path.join(save_path, 'y1.nii.gz'))
439
- save_niftiimage(y1_proc, os.path.join(save_path, 'y1_proc.nii.gz'))
440
- exit()
441
- # print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
442
-
443
  # break # FOR TESTING
444
  # else:
445
  # print('loss_gen_a:',loss_gen_a.item()) # FOR TESTING
@@ -481,7 +471,7 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
481
  if use_distributed and dist.is_initialized():
482
  dist.destroy_process_group()
483
 
484
- def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
485
 
486
  if gpu_id == 0:
487
  # if 0:
@@ -494,10 +484,11 @@ def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True
494
  checkpoint = torch.load(model_file)
495
  # checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
496
  if use_distributed:
497
- Deformddpm.module.load_state_dict(checkpoint['model_state_dict'])
498
  else:
499
- Deformddpm.load_state_dict(checkpoint['model_state_dict'])
500
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 
501
  utils.print_memory_usage("After Loading Checkpoint on GPU")
502
 
503
  if use_distributed:
 
1
+ import os, sys
2
+
3
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(ROOT_DIR)
5
+
6
  import gc
7
  import torch
8
  import torchvision
 
52
  EPS = 1e-5
53
  MSK_EPS = 0.01
54
  TEXT_EMBED_PROB = 0.7
55
+ AUG_RESAMPLE_PROB = 0.5
56
+ LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0] # [ang, dist, reg]
57
  # LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
58
+ LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128] # [imgsim, imgmse, ddf]
59
+ DIFF_REG_BATCH_RATIO = 2
 
60
 
61
  # AUG_PERMUTE_PROB = 0.35
62
 
 
133
  datasetp = OMDataset_pair(transform=None)
134
  train_loader_p = DataLoader(
135
  datasetp,
136
+ batch_size=hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO,
137
  shuffle=True,
138
  drop_last=True,
139
  )
 
177
 
178
  # mse = nn.MSELoss()
179
  # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
180
+ # loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
181
  loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
182
  loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
183
+
184
  loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
185
  # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
186
  loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
187
+ # loss_imgsim = losses.LNCC()
188
+ loss_imgsim = losses.MSLNCC()
189
  loss_imgmse = losses.LMSE()
190
 
191
  optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
 
226
  epoch_loss_ddfreg = 0.0
227
  # Set model inside to train model
228
  Deformddpm.train()
229
+
230
  loss_nan_step = 0 # yu: count the number of nan loss steps
231
 
232
  total = min(len(train_loader), len(train_loader_p))
 
233
  # for step, batch in tqdm(enumerate(train_loader)):
234
  # for step, batch in tqdm(enumerate(train_loader)):
 
235
  # for step, batch in enumerate(train_loader_omni):
236
+ for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
237
+
238
  # x0, _ = batch
239
 
240
 
 
264
  # elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
265
  else:
266
  [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
267
+ # x0 = transformer(x0)
268
  if hyp_parameters['noise_scale']>0:
269
  if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
270
+ x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
271
  x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
272
 
273
  # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
 
276
  ) # pick up a seq of rand number from 0 to 'timestep'
277
 
278
  # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
279
+ proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
280
  # print('proc_type:', proc_type)
281
  cond_img, _, cond_ratio = Deformddpm.module.proc_cond_img(x0,proc_type=proc_type)
282
 
283
  pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
284
 
285
+ # print(torch.max(torch.abs(pre_dvf_I)))
286
+ # print(torch.max(torch.abs(dvf_I)))
287
+
288
  loss_tot=0
289
 
290
  loss_ddf = loss_reg(pre_dvf_I,img=x0)
 
311
  print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
312
  raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
313
 
 
314
  optimizer.zero_grad()
315
  loss_tot.backward()
316
  optimizer.step()
317
 
318
+ epoch_loss_tot += loss_tot.item() / total
319
+ epoch_loss_gen_d += loss_gen_d.item() / total
320
+ epoch_loss_gen_a += loss_gen_a.item() / total
321
+ epoch_loss_reg += loss_ddf.item() / total
322
 
323
  # print(loss_gen_a.item())
324
  # if 0:
 
344
  # if np.random.uniform(0,1)<0.6:
345
  # x1 = utils.random_resample(x1, deform_scale=0)
346
  # y1 = utils.random_resample(y1, deform_scale=0)
347
+ # x1 = transformer(x1)
348
+ # y1 = transformer(y1)
349
  [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
350
  if hyp_parameters['noise_scale']>0:
351
  [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
 
363
  # ) # pick up a seq of rand number from 0 to 'timestep'
364
 
365
 
366
+ # scale_regist = np.random.uniform(0.6,1.)
367
  # T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist) + 1), 16), reverse=True)
368
+ # print('T_regist (0.6,1) sampling range:', T_regist)
369
+ scale_regist = np.random.uniform(0.0,0.7)
370
+ select_timestep = np.random.randint(8, 17) # select a random number of timesteps to sample, between 8 and 16
371
+ T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
372
+ # print('T_regist (0.1,0.7) sampling range:', T_regist)
373
  # scale_regist = np.random.uniform(0.4,1.)
374
  # T_regist = [int(hyp_parameters["timesteps"]*scale_regist)]
375
  # scale_regist = np.random.uniform(0.6,1.)
 
380
 
381
  # print('T_regist:', T_regist)
382
  # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
383
+ proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
384
  # proc_type = random.choice(['project'])
385
  y1_proc, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
386
+ msk_tgt = msk_tgt+MSK_EPS
387
  [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
388
+ loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
389
+ loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
390
+ # loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=(msk_tgt+MSK_EPS)) # calculate loss for the registration process
391
+ # loss_sim = loss_imgsim(img_rec, y1, label=(y1>thresh_imgsim)) # calculate loss for the registration process
392
+ # loss_mse = loss_imgmse(img_rec, y1, label=(y1>=0.0)) # calculate loss for the registration process
393
+ loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
394
+
395
  loss_regist = 0
396
  loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
397
  loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
398
  loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
399
  # print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
400
  # print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
401
+
402
  # >> JZ: print nan in x0
403
  if torch.isnan(x0).any():
404
  print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
 
 
 
405
  # >> JZ: print loss of ddf
406
+ if loss_ddf1>0.002:
407
  print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
408
  # # Print gradients for each parameter
409
  # for name, param in Deformddpm.named_parameters():
 
411
  # print(f"Gradient for {name}: {param.grad.norm()}")
412
  # else:
413
  # print(f"Gradient for {name}: None")
414
+
415
  loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
416
  optimizer.zero_grad()
417
  loss_regist.backward()
418
 
419
 
420
 
421
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.2)
422
  optimizer.step()
423
 
424
+ epoch_loss_regist += loss_regist.item() / total
425
+ epoch_loss_imgsim += loss_sim.item() / total
426
+ epoch_loss_imgmse += loss_mse.item() / total
427
+ epoch_loss_ddfreg += loss_ddf1.item() / total
428
 
429
 
430
+ if step % 10 == 0:
431
+ print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
432
+ print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  # break # FOR TESTING
434
  # else:
435
  # print('loss_gen_a:',loss_gen_a.item()) # FOR TESTING
 
471
  if use_distributed and dist.is_initialized():
472
  dist.destroy_process_group()
473
 
474
+ def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
475
 
476
  if gpu_id == 0:
477
  # if 0:
 
484
  checkpoint = torch.load(model_file)
485
  # checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
486
  if use_distributed:
487
+ Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
488
  else:
489
+ Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
490
+ if load_strict:
491
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
492
  utils.print_memory_usage("After Loading Checkpoint on GPU")
493
 
494
  if use_distributed:
OM_train_3modes-XPU.py ADDED
@@ -0,0 +1,957 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, contextlib
2
+
3
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(ROOT_DIR)
5
+
6
+ import gc
7
+ import torch
8
+ import torchvision
9
+ from torch import nn
10
+ from torchvision.utils import save_image
11
+ from torch.utils.data import DataLoader
12
+
13
+ from torch.optim import Adam, SGD
14
+ from Diffusion.diffuser import DeformDDPM
15
+ from Diffusion.networks import get_net, STN
16
+ from torchvision.transforms import Lambda
17
+ import torch.nn.functional as F
18
+ import Diffusion.losses as losses
19
+ import random
20
+ import glob
21
+ import numpy as np
22
+ import utils
23
+ from tqdm import tqdm
24
+
25
+ from Dataloader.dataloader0 import get_dataloader
26
+ from Dataloader.dataLoader import *
27
+
28
+ from Dataloader.dataloader_utils import thresh_img
29
+ import yaml
30
+ import argparse
31
+
32
+ # XPU support: import Intel Extension for PyTorch and oneCCL bindings if available
33
+ try:
34
+ import intel_extension_for_pytorch as ipex
35
+ except ImportError:
36
+ ipex = None
37
+ try:
38
+ import oneccl_bindings_for_pytorch
39
+ except (ImportError, Exception) as e:
40
+ print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
41
+
42
+ ####################
43
+ import torch.multiprocessing as mp
44
+ from torch.utils.data.distributed import DistributedSampler
45
+ from torch.nn.parallel import DistributedDataParallel as DDP
46
+ import torch.distributed as dist
47
+ # from torch.distributed import init_process_group
48
+ ###############
49
+ def _device_available(device_type):
50
+ if device_type == 'xpu':
51
+ return hasattr(torch, 'xpu') and torch.xpu.is_available()
52
+ return torch.cuda.is_available()
53
+
54
+ def _device_count(device_type):
55
+ if device_type == 'xpu':
56
+ return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
57
+ return torch.cuda.device_count()
58
+
59
def _set_device(rank, device_type):
    """Make device index `rank` the current device of the chosen backend.

    Args:
        rank: Device index handed to the backend's set_device().
        device_type: 'xpu' selects torch.xpu; any other value selects
            torch.cuda.
    """
    # The conditional expression only touches torch.xpu when it was requested.
    backend = torch.xpu if device_type == 'xpu' else torch.cuda
    backend.set_device(rank)
64
+
65
+ def _empty_cache(device_type):
66
+ if device_type == 'xpu' and hasattr(torch, 'xpu'):
67
+ torch.xpu.empty_cache()
68
+ elif torch.cuda.is_available():
69
+ torch.cuda.empty_cache()
70
+
71
def ddp_setup(rank, world_size):
    """Initialise the default process group and bind this process to a device.

    Args:
        rank: Unique identifier of each process. It is only used when the
            LOCAL_RANK environment variable is absent; otherwise the device
            index is taken from LOCAL_RANK.
        world_size: Total number of processes (only passed on explicitly
            when LOCAL_RANK is absent).
    """
    backend = "nccl" if DEVICE_TYPE != "xpu" else "ccl"
    local_rank = os.environ.get("LOCAL_RANK")

    if local_rank is not None:
        # Launcher path (e.g. torchrun): rendezvous settings are expected to
        # already be present in the environment.
        dist.init_process_group(backend=backend)
        _set_device(int(local_rank), DEVICE_TYPE)
        return

    # Single-node mp.spawn path: fixed local rendezvous address and port.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12355")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    _set_device(rank, DEVICE_TYPE)
88
+
89
+ EPS = 1e-5
90
+ MSK_EPS = 0.01
91
+ TEXT_EMBED_PROB = 0.5
92
+ AUG_RESAMPLE_PROB = 0.5
93
+ LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0] # [ang, dist, reg]
94
+ # LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
95
+ LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2] # [imgsim, imgmse, ddf]
96
+ DIFF_REG_BATCH_RATIO = 2
97
+ LOSS_WEIGHT_CONTRASTIVE = 1e-4
98
+ REGISTRATION_STEP_RATIO = 1
99
+ CONTRASTIVE_STEP_RATIO = 1
100
+ MID_EPOCH_SAVE_STEPS = 10 # Save mid-epoch checkpoint every N steps for crash recovery.
101
+ # XPU autograd leaks ~1.0 GiB/step of device memory (Intel bug).
102
+ # With gradient checkpointing, training survives ~26 steps from fresh start,
103
+ # but fewer when carrying leaked memory from previous epoch.
104
+ # Save every 10 steps to minimize lost work on OOM crash.
105
+ EXIT_CODE_RESTART = 42 # Exit code signaling proactive restart (not a crash).
106
+
107
+ # AUG_PERMUTE_PROB = 0.35
108
+
109
+ parser = argparse.ArgumentParser()
110
+
111
+ # config_file_path = 'Config/config_cmr.yaml'
112
+ parser.add_argument(
113
+ "--config",
114
+ "-C",
115
+ help="Path for the config file",
116
+ type=str,
117
+ # default="Config/config_cmr.yaml",
118
+ # default="Config/config_lct.yaml",
119
+ default="Config/config_all.yaml",
120
+ required=False,
121
+ )
122
+ parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
123
+ parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
124
+ parser.add_argument("--max-steps-before-restart", type=int, default=0,
125
+ help="Proactive restart: exit after N training steps to reset XPU memory leak. "
126
+ "0=disabled (rely on OOM crash + auto-resubmit). "
127
+ "Recommended: 20 for XPU (survives ~26 steps max).")
128
+ parser.add_argument("--no-save", action="store_true",
129
+ help="Disable all checkpoint saving (for diagnostic/validation runs)")
130
+ parser.add_argument("--reset-optimizer", action="store_true",
131
+ help="Skip optimizer state loading from checkpoint (use when architecture changed)")
132
+ parser.add_argument("--eval-only", action="store_true",
133
+ help="Forward pass only: compute and print losses without backward/optimizer (no memory leak)")
134
+ args = parser.parse_args()
135
+
136
+ # Read config early to determine device type for DDP setup
137
+ with open(args.config, 'r') as _f:
138
+ _cfg = yaml.safe_load(_f)
139
+ DEVICE_TYPE = _cfg.get('device', 'cuda') # 'cuda' or 'xpu'
140
+
141
+ # Auto-detect: use DDP only when multiple devices are available
142
+ use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1
143
+ # use_distributed = True
144
+ # use_distributed = False
145
+ #=======================================================================================================================
146
+
147
+ class _DummyIndiv(torch.utils.data.Dataset):
148
+ def __init__(self, n, sz, embd_dim=1024):
149
+ self.n, self.sz, self.embd_dim = n, sz, embd_dim
150
+ def __len__(self): return self.n
151
+ def __getitem__(self, i):
152
+ return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
153
+
154
+ class _DummyPair(torch.utils.data.Dataset):
155
+ def __init__(self, n, sz, embd_dim=1024):
156
+ self.n, self.sz, self.embd_dim = n, sz, embd_dim
157
+ def __len__(self): return self.n
158
+ def __getitem__(self, i):
159
+ return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
160
+ np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
161
+ np.random.randn(self.embd_dim).astype(np.float32),
162
+ np.random.randn(self.embd_dim).astype(np.float32))
163
+
164
+
165
+ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
166
+ if use_distributed:
167
+ ddp_setup(rank,world_size)
168
+
169
+ if torch.distributed.is_initialized() and rank == 0:
170
+ print(f"World size: {torch.distributed.get_world_size()}")
171
+ print(f"Communication backend: {torch.distributed.get_backend()}")
172
+ print(f"PYTORCH_ALLOC_CONF: {os.environ.get('PYTORCH_ALLOC_CONF', 'not set')}")
173
+ if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
174
+ props = torch.xpu.get_device_properties(0)
175
+ print(f"XPU device: {props.name}, total memory: {props.total_memory / 1024**3:.2f} GiB")
176
+ # gpu_id = global rank (for save/print guards); rank = local device index
177
+ if "RANK" in os.environ:
178
+ gpu_id = int(os.environ["RANK"])
179
+ rank = int(os.environ["LOCAL_RANK"])
180
+ else:
181
+ gpu_id = rank
182
+
183
+ # Load the YAML file into a dictionary
184
+ with open(args.config, 'r') as file:
185
+ hyp_parameters = yaml.safe_load(file)
186
+ if args.batchsize > 0:
187
+ hyp_parameters['batchsize'] = args.batchsize
188
+ if gpu_id == 0:
189
+ print(hyp_parameters)
190
+
191
+ # epoch_per_save=10
192
+ epoch_per_save=hyp_parameters['epoch_per_save']
193
+
194
+ data_name=hyp_parameters['data_name']
195
+ net_name = hyp_parameters['net_name']
196
+
197
+ Net=get_net(net_name)
198
+
199
+ suffix_pth=f'_{data_name}_{net_name}.pth'
200
+ model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
201
+ model_dir=model_save_path
202
+ transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
203
+
204
+ # Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
205
+
206
+ # tsfm = torchvision.transforms.Compose([
207
+ # torchvision.transforms.ToTensor(),
208
+ # ])
209
+
210
+ # dataset = Data_Loader(target_res = [hyp_parameters["img_size"]]*hyp_parameters["ndims"], transforms=None, noise_scale=hyp_parameters['noise_scale'])
211
+ # train_loader = DataLoader(
212
+ # dataset,
213
+ # batch_size=hyp_parameters['batchsize'],
214
+ # # shuffle=False,
215
+ # shuffle=True,
216
+ # drop_last=True,
217
+ # )
218
+
219
+ if args.dummy_samples > 0:
220
+ dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
221
+ datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
222
+ else:
223
+ # dataset = OminiDataset_v1(transform=None)
224
+ dataset = OMDataset_indiv(transform=None)
225
+ # datasetp = OminiDataset_paired(transform=None)
226
+ datasetp = OMDataset_pair(transform=None)
227
+
228
+ if use_distributed:
229
+ sampler = DistributedSampler(dataset, shuffle=True)
230
+ sampler_p = DistributedSampler(datasetp, shuffle=True)
231
+ else:
232
+ sampler = None
233
+ sampler_p = None
234
+
235
+ train_loader = DataLoader(
236
+ dataset,
237
+ batch_size=hyp_parameters['batchsize'],
238
+ shuffle=(sampler is None),
239
+ drop_last=True,
240
+ sampler=sampler,
241
+ )
242
+ train_loader_p = DataLoader(
243
+ datasetp,
244
+ batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
245
+ shuffle=(sampler_p is None),
246
+ drop_last=True,
247
+ sampler=sampler_p,
248
+ )
249
+
250
+
251
+
252
+ network = Net(
253
+ n_steps=hyp_parameters["timesteps"],
254
+ ndims=hyp_parameters["ndims"],
255
+ num_input_chn = hyp_parameters["num_input_chn"],
256
+ res = hyp_parameters['img_size']
257
+ )
258
+ # Enable gradient checkpointing on XPU to reduce peak activation memory.
259
+ # XPU autograd leaks ~1.0 GiB/step; lower peak buys more steps before OOM.
260
+ if DEVICE_TYPE == 'xpu' and hasattr(network, 'use_checkpoint'):
261
+ network.use_checkpoint = True
262
+ if gpu_id == 0:
263
+ print(" [init] Gradient checkpointing enabled for XPU", flush=True)
264
+
265
+ Deformddpm = DeformDDPM(
266
+ network=network,
267
+ n_steps=hyp_parameters["timesteps"],
268
+ image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
269
+ device=hyp_parameters["device"],
270
+ batch_size=hyp_parameters["batchsize"],
271
+ img_pad_mode=hyp_parameters["img_pad_mode"],
272
+ v_scale=hyp_parameters["v_scale"],
273
+ )
274
+
275
+
276
+ ddf_stn = STN(
277
+ img_sz=hyp_parameters["img_size"],
278
+ ndims=hyp_parameters["ndims"],
279
+ # padding_mode="zeros",
280
+ padding_mode=hyp_parameters["padding_mode"],
281
+ device=hyp_parameters["device"],
282
+ )
283
+
284
+
285
+ if use_distributed:
286
+ device = f"{DEVICE_TYPE}:{rank}"
287
+ # NO pre-allocation. CCL/oneDNN accumulate ~1.4 GiB/step of device memory outside
288
+ # PyTorch's caching allocator. Pre-allocating steals from that budget:
289
+ # 92% pre-alloc → crash at step 3, 78% → step 10, none (70% cap) → step 14.
290
+ # Instead, use empty_cache() between training phases to release unused cached memory
291
+ # back to the device for CCL/oneDNN.
292
+ if gpu_id == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
293
+ total_mem = torch.xpu.get_device_properties(rank).total_memory
294
+ print(f" [init] XPU device memory: {total_mem/1024**3:.1f} GiB, no pre-allocation (relying on empty_cache between phases)", flush=True)
295
+ Deformddpm.to(device)
296
+ Deformddpm = DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)
297
+ ddf_stn.to(device)
298
+ else:
299
+ Deformddpm.to(hyp_parameters["device"])
300
+ ddf_stn.to(hyp_parameters["device"])
301
+ # ddf_stn = DDP(ddf_stn, device_ids=[rank])
302
+
303
+
304
+ # mse = nn.MSELoss()
305
+ # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
306
+ # loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
307
+ loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
308
+ loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
309
+
310
+ loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
311
+ # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
312
+ loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
313
+ loss_imgsim = losses.MSLNCC()
314
+ loss_imgmse = losses.LMSE()
315
+
316
+ optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
317
+ # hyp_parameters["lr"]=0.00000001
318
+ # optimizer_regist = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01)
319
+ # optimizer_regist = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01, momentum=0.98)
320
+ # optimizer = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"], momentum=0.9)
321
+
322
+ # # LR scheduler ----- YHM
323
+ # scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, hyp_parameters["lr"], hyp_parameters["lr"]*10, step_size_up=500, step_size_down=500, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
324
+
325
+ # Deformddpm.network.load_state_dict(torch.load('/home/data/jzheng/Adaptive_Motion_Generator-master/models/1000.pth'))
326
+
327
+ # check for existing models
328
+ if not os.path.exists(model_dir):
329
+ os.makedirs(model_dir, exist_ok=True)
330
+ # Check for checkpoints: first check tmp/ for mid-epoch, then main dir for epoch-level
331
+ tmp_dir = os.path.join(model_dir, "tmp")
332
+ tmp_files = sorted(glob.glob(os.path.join(tmp_dir, "*.pth")))
333
+ model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
334
+ initial_step = 0
335
+
336
+ # Epoch stats and RNG states to restore when resuming from mid-epoch checkpoint
337
+ _resume_epoch_stats = None
338
+ _resume_rng = None
339
+
340
+ if tmp_files and not args.eval_only and args.max_steps_before_restart > 0:
341
+ # Mid-epoch checkpoint: only use when proactive restart is enabled
342
+ latest = tmp_files[-1]
343
+ if gpu_id == 0:
344
+ print(f" [resume] Found mid-epoch checkpoint: {latest}")
345
+ initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
346
+ basename = os.path.basename(latest)
347
+ initial_step = int(basename.split('_step')[1].split('_')[0].split('.')[0])
348
+ _ckpt = torch.load(latest, map_location='cpu', weights_only=False)
349
+ _resume_epoch_stats = _ckpt.get('epoch_stats', None)
350
+ del _ckpt
351
+ if gpu_id == 0:
352
+ print(f" [resume] Resuming epoch {initial_epoch} from step {initial_step}"
353
+ f"{' (with epoch_stats)' if _resume_epoch_stats else ''}", flush=True)
354
+ elif model_files:
355
+ if gpu_id == 0:
356
+ print(model_files)
357
+ latest = model_files[-1]
358
+ initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
359
+ else:
360
+ initial_epoch = 0
361
+
362
+ if gpu_id == 0:
363
+ print('len_train_data: ',len(dataset))
364
+
365
+ # Proactive restart: track steps since process start to exit before OOM.
366
+ max_steps_restart = args.max_steps_before_restart
367
+ steps_since_start = 0
368
+
369
+ # Training loop
370
+ for epoch in range(initial_epoch,hyp_parameters["epoch"]):
371
+ if use_distributed and sampler is not None:
372
+ sampler.set_epoch(epoch)
373
+ sampler_p.set_epoch(epoch)
374
+
375
+ epoch_loss_tot = 0.0
376
+ epoch_loss_gen_d = 0.0
377
+ epoch_loss_gen_a = 0.0
378
+ epoch_loss_reg = 0.0
379
+ epoch_loss_regist = 0.0
380
+ epoch_loss_imgsim = 0.0
381
+ epoch_loss_imgmse = 0.0
382
+ epoch_loss_ddfreg = 0.0
383
+ epoch_loss_contrastive = 0.0
384
+ total_contra = 0
385
+ total_reg_restored = None
386
+ total_contra_restored = None
387
+
388
+ # Restore epoch accumulators from mid-epoch checkpoint (only for the resumed epoch)
389
+ if _resume_epoch_stats is not None and epoch == initial_epoch:
390
+ epoch_loss_tot = _resume_epoch_stats.get('epoch_loss_tot', 0.0)
391
+ epoch_loss_gen_d = _resume_epoch_stats.get('epoch_loss_gen_d', 0.0)
392
+ epoch_loss_gen_a = _resume_epoch_stats.get('epoch_loss_gen_a', 0.0)
393
+ epoch_loss_reg = _resume_epoch_stats.get('epoch_loss_reg', 0.0)
394
+ epoch_loss_regist = _resume_epoch_stats.get('epoch_loss_regist', 0.0)
395
+ epoch_loss_imgsim = _resume_epoch_stats.get('epoch_loss_imgsim', 0.0)
396
+ epoch_loss_imgmse = _resume_epoch_stats.get('epoch_loss_imgmse', 0.0)
397
+ epoch_loss_ddfreg = _resume_epoch_stats.get('epoch_loss_ddfreg', 0.0)
398
+ epoch_loss_contrastive = _resume_epoch_stats.get('epoch_loss_contrastive', 0.0)
399
+ total_reg_restored = _resume_epoch_stats.get('total_reg', None)
400
+ total_contra_restored = _resume_epoch_stats.get('total_contra', None)
401
+ loss_nan_step = _resume_epoch_stats.get('loss_nan_step', 0)
402
+ # RNG states are restored INSIDE the skip loop (at the last skipped step)
403
+ # to avoid DataLoader __getitem__ calls corrupting the restored state.
404
+ _resume_rng = {k: _resume_epoch_stats[k] for k in
405
+ ('rng_torch', 'rng_numpy', 'rng_python', 'rng_xpu', 'rng_cuda')
406
+ if k in _resume_epoch_stats}
407
+ if gpu_id == 0:
408
+ print(f" [resume] Restored epoch stats from checkpoint (loss_tot={epoch_loss_tot:.4f})", flush=True)
409
+ _resume_epoch_stats = None # Only restore once
410
+ else:
411
+ loss_nan_step = 0 # only reset when NOT resuming mid-epoch
412
+
413
+ # Set model inside to train model
414
+ Deformddpm.train()
415
+
416
+ total = min(len(train_loader), len(train_loader_p))
417
+ total_reg = total // REGISTRATION_STEP_RATIO
418
+ # Restore total_reg and total_contra from checkpoint if available (mid-epoch resume)
419
+ if total_reg_restored is not None:
420
+ total_reg = total_reg_restored
421
+ total_reg_restored = None
422
+ if total_contra_restored is not None:
423
+ total_contra = total_contra_restored
424
+ total_contra_restored = None
425
+ # for step, batch in tqdm(enumerate(train_loader)):
426
+ # for step, batch in tqdm(enumerate(train_loader)):
427
+ # for step, batch in enumerate(train_loader_omni):
428
+ for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
429
+
430
+ # Skip steps already completed (mid-epoch resume).
431
+ # Checkpoint at step N is saved AFTER step N's training completes,
432
+ # so step N itself must also be skipped (use <=, not <).
433
+ if epoch == initial_epoch and initial_step > 0 and step <= initial_step:
434
+ # Restore RNG at the last skipped step, AFTER DataLoader __getitem__
435
+ # has consumed RNG for all skipped batches. This way the first
436
+ # non-skipped step starts with exactly the saved RNG state.
437
+ if step == initial_step and _resume_rng is not None:
438
+ # Restore rank 0's RNG as base state, then re-seed per-rank
439
+ # so each rank has independent RNG (matching continuous run's
440
+ # divergent-per-rank behavior). Without this, all ranks would
441
+ # share rank 0's RNG → correlated augmentation/dropout decisions.
442
+ if 'rng_torch' in _resume_rng:
443
+ torch.set_rng_state(_resume_rng['rng_torch'])
444
+ if 'rng_numpy' in _resume_rng:
445
+ np.random.set_state(_resume_rng['rng_numpy'])
446
+ if 'rng_python' in _resume_rng:
447
+ random.setstate(_resume_rng['rng_python'])
448
+ if 'rng_xpu' in _resume_rng and DEVICE_TYPE == 'xpu':
449
+ torch.xpu.set_rng_state(_resume_rng['rng_xpu'])
450
+ elif 'rng_cuda' in _resume_rng and torch.cuda.is_available():
451
+ torch.cuda.set_rng_state(_resume_rng['rng_cuda'])
452
+ # Per-rank re-seed: checkpoint only has rank 0's RNG state.
453
+ # Advance each rank's RNG by a deterministic offset so they
454
+ # diverge (as they would in a continuous run).
455
+ if gpu_id > 0:
456
+ rank_seed = gpu_id * 100003 + initial_step * 31
457
+ torch.manual_seed(torch.initial_seed() + rank_seed)
458
+ np.random.seed((np.random.get_state()[1][0] + rank_seed) % (2**31))
459
+ random.seed(random.getrandbits(32) + rank_seed)
460
+ if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
461
+ torch.xpu.manual_seed(torch.initial_seed() + rank_seed)
462
+ elif torch.cuda.is_available():
463
+ torch.cuda.manual_seed(torch.initial_seed() + rank_seed)
464
+ _resume_rng = None
465
+ if gpu_id == 0:
466
+ print(f" [resume] RNG states restored at step {step} (per-rank re-seeded)", flush=True)
467
+ continue
468
+
469
+ # Free registration tensors from previous step
470
+ x1 = y1 = ddf_comp = img_rec = img_diff = None
471
+ ddf_rand = y1_proc = msk_tgt = img_save = None
472
+ loss_regist = loss_sim = loss_mse = loss_ddf1 = None
473
+
474
+ # Memory diagnostic (one per node via local rank 0) — only warn when abnormal
475
+ # Normal at step start: ~16 GiB reserved, ~48 GiB free (of 64 GiB total)
476
+ if rank == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
477
+ torch.xpu.reset_peak_memory_stats(rank)
478
+ free_mem, total_mem_dev = torch.xpu.mem_get_info(rank)
479
+ used_gib = (total_mem_dev - free_mem) / 1024**3
480
+ if used_gib > 24: # Normal is ~16 GiB at step start; warn if accumulating
481
+ alloc = torch.xpu.memory_allocated() / 1024**3
482
+ reserved = torch.xpu.memory_reserved() / 1024**3
483
+ free_gib = free_mem / 1024**3
484
+ print(f" [mem WARNING] gpu_id={gpu_id} epoch {epoch} step {step}: "
485
+ f"{used_gib:.1f} GiB used ({alloc:.1f} alloc / {reserved:.1f} reserved), "
486
+ f"{free_gib:.1f} GiB free", flush=True)
487
+
488
+ # ==========================================================================
489
+ # diffusion train on single image
490
+
491
+ # x0 = batch # for omni dataset
492
+ [x0,embd] = batch # for om dataset
493
+ x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
494
+ # print('embd:', embd.shape)
495
+ embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
496
+ if np.random.uniform(0,1)<TEXT_EMBED_PROB:
497
+ embd_in = embd_dev
498
+ else:
499
+ embd_in = None
500
+
501
+
502
+
503
+ n = x0.size()[0] # batch_size -> n
504
+ x0 = x0.to(hyp_parameters["device"])
505
+
506
+ blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
507
+
508
+ # random deformation + rotation
509
+ if hyp_parameters["ndims"]>2:
510
+ if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
511
+ x0 = utils.random_resample(x0, deform_scale=0)
512
+ # elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
513
+ else:
514
+ [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
515
+ # x0 = transformer(x0)
516
+ if hyp_parameters['noise_scale']>0:
517
+ if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
518
+ x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
519
+ x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
520
+
521
+ # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
522
+ t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
523
+ hyp_parameters["device"]
524
+ ) # pick up a seq of rand number from 0 to 'timestep'
525
+
526
+ # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
527
+ proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
528
+ # print('proc_type:', proc_type)
529
+ ddpm = Deformddpm.module if use_distributed else Deformddpm
530
+ cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
531
+
532
+ pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
533
+
534
+ loss_tot=0
535
+
536
+ loss_ddf = loss_reg(pre_dvf_I,img=x0)
537
+ trm_pred = ddf_stn(pre_dvf_I, dvf_I)
538
+ loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
539
+ loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
540
+
541
+ loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
542
+ loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
543
+ loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
544
+
545
+ # >> JZ: print nan in x0
546
+ if torch.isnan(x0).any():
547
+ print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
548
+ # >> JZ: print loss of ddf
549
+ if loss_ddf>0.001:
550
+ print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
551
+ # yu: check if loss_tot==nan or inf
552
+ # Synchronize NaN skip across all DDP ranks to avoid collective desync
553
+ # Use broadcast from rank 0 instead of all_reduce to avoid CCL hang on single-node XPU
554
+ is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
555
+ if use_distributed:
556
+ nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
557
+ dist.broadcast(nan_flag, src=0)
558
+ is_nan = nan_flag.item() > 0
559
+ if is_nan:
560
+ if gpu_id == 0:
561
+ print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
562
+ loss_nan_step += 1
563
+ continue
564
+ if loss_nan_step > 5:
565
+ print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
566
+ raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
567
+
568
+ # ==========================================================================
569
+ # Diffusion backward (no gradient clipping — diffusion dominates training)
570
+ if not args.eval_only:
571
+ optimizer.zero_grad()
572
+ loss_tot.backward()
573
+ optimizer.step()
574
+
575
+ epoch_loss_tot += loss_tot.item() / total
576
+ epoch_loss_gen_d += loss_gen_d.item() / total
577
+ epoch_loss_gen_a += loss_gen_a.item() / total
578
+ epoch_loss_reg += loss_ddf.item() / total
579
+
580
+ # Print running average every 20 steps in eval-only mode
581
+ if args.eval_only and gpu_id == 0 and (step + 1) % 20 == 0:
582
+ n = step + 1
583
+ print(f" [eval] step {step}: running_avg ang={epoch_loss_gen_a*total/n:.4f} "
584
+ f"dist={epoch_loss_gen_d*total/n:.4f} regul={epoch_loss_reg*total/n:.6f}", flush=True)
585
+
586
+ # Free diffusion intermediates and aggressively release all memory to device.
587
+ # XPU runtime leaks ~1.3 GiB/step outside the caching allocator.
588
+ # gc.collect() + synchronize() + empty_cache() attempts to reclaim deferred/lazy allocations.
589
+ loss_gen_a_val = loss_gen_a.item()
590
+ del pre_dvf_I, dvf_I, trm_pred, loss_tot, loss_gen_a, loss_gen_d, loss_ddf
591
+ gc.collect()
592
+ if DEVICE_TYPE == 'xpu':
593
+ torch.xpu.synchronize()
594
+ _empty_cache(DEVICE_TYPE)
595
+
596
+ # Sync loss_gen_a across DDP ranks for contrastive and registration gating
597
+ if use_distributed:
598
+ loss_gen_a_sync = torch.tensor([loss_gen_a_val], device=f"{DEVICE_TYPE}:{rank}")
599
+ dist.broadcast(loss_gen_a_sync, src=0)
600
+ loss_gen_a_gate = loss_gen_a_sync.item()
601
+ else:
602
+ loss_gen_a_gate = loss_gen_a_val
603
+
604
+ # ==========================================================================
605
+ # Contrastive train on single image (text-image alignment)
606
+ # Separate backward with gradient clipping to prevent destabilizing diffusion.
607
+ loss_contra_val = None
608
+ if step % CONTRASTIVE_STEP_RATIO == 0:
609
+ n_contra = x0.size()[0]
610
+ t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
611
+ # Route through DDP wrapper and return img_embd directly so DDP
612
+ # traces the correct subgraph (encoder + mid + attn + img2txt).
613
+ img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(), T=t_contra, output_embedding=True, text=None) # [B, 1024]
614
+ loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.25)
615
+
616
+ if not args.eval_only:
617
+ optimizer.zero_grad()
618
+ loss_contra.backward()
619
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=1e-3)
620
+ optimizer.step()
621
+ loss_contra_val = loss_contra.item()
622
+ epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
623
+
624
+ # Free remaining intermediates and aggressively release memory before registration
625
+ if cond_img is not None:
626
+ del cond_img
627
+ if blind_mask is not None:
628
+ del blind_mask
629
+ gc.collect()
630
+ if DEVICE_TYPE == 'xpu':
631
+ torch.xpu.synchronize()
632
+ _empty_cache(DEVICE_TYPE)
633
+
634
+ # ==========================================================================
635
+ # registration train on paired images
636
+ # loss_gen_a_gate already synced across DDP ranks above
637
+ do_regist = step % REGISTRATION_STEP_RATIO == 0 and loss_gen_a_gate < -0.8
638
+ if do_regist:
639
+ [x1, y1, _, embd_y] = batch_p
640
+ if np.random.uniform(0,1)<TEXT_EMBED_PROB:
641
+ embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
642
+ else:
643
+ embd_y = None
644
+
645
+ x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
646
+ y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
647
+ n = x1.size()[0] # batch_size -> n
648
+ [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
649
+ if hyp_parameters['noise_scale']>0:
650
+ [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
651
+ random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
652
+ random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
653
+ x1 = x1 * random_scale + random_shift
654
+ y1 = y1 * random_scale + random_shift
655
+
656
+ scale_regist = np.random.uniform(0.0,0.5)
657
+ select_timestep = np.random.randint(12, 32) # select a random number of timesteps to sample, between 12 and 31 (randint's upper bound is exclusive)
658
+ T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
659
+
660
+ T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
661
+
662
+ proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
663
+ ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
664
+ y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
665
+ msk_tgt = msk_tgt+MSK_EPS
666
+ [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
667
+ loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
668
+ loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
669
+ loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
670
+
671
+ loss_regist = 0
672
+ loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
673
+ loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
674
+ loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
675
+
676
+ # >> JZ: print nan in x0
677
+ if torch.isnan(x0).any():
678
+ print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
679
+ # >> JZ: print loss of ddf
680
+ if loss_ddf1>0.002:
681
+ print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
682
+
683
+ loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
684
+ if not args.eval_only:
685
+ optimizer.zero_grad()
686
+ loss_regist.backward()
687
+
688
+ # torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
689
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
690
+ optimizer.step()
691
+
692
+ epoch_loss_regist += loss_regist.item()
693
+ epoch_loss_imgsim += loss_sim.item()
694
+ epoch_loss_imgmse += loss_mse.item()
695
+ epoch_loss_ddfreg += loss_ddf1.item()
696
+ else:
697
+ loss_sim = torch.tensor(0.0)
698
+ loss_mse = torch.tensor(0.0)
699
+ loss_ddf1 = torch.tensor(0.0)
700
+ loss_regist = torch.tensor(0.0)
701
+ if step % REGISTRATION_STEP_RATIO==0:
702
+ total_reg = total_reg-1
703
+
704
+ # Mid-epoch checkpoint and proactive restart (only when --max-steps-before-restart > 0)
705
+ if max_steps_restart > 0 and step > 0 and step % MID_EPOCH_SAVE_STEPS == 0 and gpu_id == 0 and not args.no_save:
706
+ _epoch_stats = {
707
+ 'epoch_loss_tot': epoch_loss_tot,
708
+ 'epoch_loss_gen_d': epoch_loss_gen_d,
709
+ 'epoch_loss_gen_a': epoch_loss_gen_a,
710
+ 'epoch_loss_reg': epoch_loss_reg,
711
+ 'epoch_loss_regist': epoch_loss_regist,
712
+ 'epoch_loss_imgsim': epoch_loss_imgsim,
713
+ 'epoch_loss_imgmse': epoch_loss_imgmse,
714
+ 'epoch_loss_ddfreg': epoch_loss_ddfreg,
715
+ 'epoch_loss_contrastive': epoch_loss_contrastive,
716
+ 'total_reg': total_reg,
717
+ 'total_contra': total_contra,
718
+ 'loss_nan_step': loss_nan_step,
719
+ 'rng_torch': torch.get_rng_state(),
720
+ 'rng_numpy': np.random.get_state(),
721
+ 'rng_python': random.getstate(),
722
+ **(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
723
+ {'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
724
+ }
725
+ tmp_dir = os.path.join(model_save_path, "tmp")
726
+ os.makedirs(tmp_dir, exist_ok=True)
727
+ for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
728
+ os.remove(old_f)
729
+ mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
730
+ state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
731
+ torch.save({
732
+ 'model_state_dict': state,
733
+ 'optimizer_state_dict': optimizer.state_dict(),
734
+ 'epoch': epoch,
735
+ 'step': step,
736
+ 'epoch_stats': _epoch_stats,
737
+ }, mid_save)
738
+ print(f" [mid-epoch] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
739
+
740
+ # Proactive restart: exit cleanly after N steps to reset XPU memory leak.
741
+ # The bash wrapper will re-launch srun within the same SLURM allocation.
742
+ steps_since_start += 1
743
+ if max_steps_restart > 0 and steps_since_start >= max_steps_restart:
744
+ # Save checkpoint at current position (if not just saved above)
745
+ if not (step > 0 and step % MID_EPOCH_SAVE_STEPS == 0) and gpu_id == 0 and not args.no_save:
746
+ _epoch_stats = {
747
+ 'epoch_loss_tot': epoch_loss_tot, 'epoch_loss_gen_d': epoch_loss_gen_d,
748
+ 'epoch_loss_gen_a': epoch_loss_gen_a, 'epoch_loss_reg': epoch_loss_reg,
749
+ 'epoch_loss_regist': epoch_loss_regist, 'epoch_loss_imgsim': epoch_loss_imgsim,
750
+ 'epoch_loss_imgmse': epoch_loss_imgmse, 'epoch_loss_ddfreg': epoch_loss_ddfreg,
751
+ 'epoch_loss_contrastive': epoch_loss_contrastive, 'total_reg': total_reg, 'total_contra': total_contra,
752
+ 'loss_nan_step': loss_nan_step,
753
+ 'rng_torch': torch.get_rng_state(), 'rng_numpy': np.random.get_state(),
754
+ 'rng_python': random.getstate(),
755
+ **(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
756
+ {'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
757
+ }
758
+ tmp_dir = os.path.join(model_save_path, "tmp")
759
+ os.makedirs(tmp_dir, exist_ok=True)
760
+ for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
761
+ os.remove(old_f)
762
+ mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
763
+ state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
764
+ torch.save({
765
+ 'model_state_dict': state,
766
+ 'optimizer_state_dict': optimizer.state_dict(),
767
+ 'epoch': epoch,
768
+ 'step': step,
769
+ 'epoch_stats': _epoch_stats,
770
+ }, mid_save)
771
+ print(f" [restart] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
772
+ if gpu_id == 0:
773
+ print(f" [restart] Proactive restart after {steps_since_start} steps "
774
+ f"(limit {max_steps_restart}). Exiting with code {EXIT_CODE_RESTART}.", flush=True)
775
+ # Clean shutdown
776
+ _empty_cache(DEVICE_TYPE)
777
+ gc.collect()
778
+ if use_distributed and dist.is_initialized():
779
+ dist.barrier()
780
+ dist.destroy_process_group()
781
+ sys.exit(EXIT_CODE_RESTART)
782
+
783
+ if gpu_id == 0:
784
+ print('==================')
785
+ print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
786
+ print(f' loss_contrastive: {epoch_loss_contrastive}')
787
+ total_reg_safe = max(total_reg, 1)
788
+ print(f' loss_regist: {epoch_loss_regist/total_reg_safe} = {epoch_loss_imgsim/total_reg_safe} (imgsim) + {epoch_loss_imgmse/total_reg_safe} (imgmse) + {epoch_loss_ddfreg/total_reg_safe} (ddf)')
789
+ print('==================')
790
+
791
+
792
+ if 0 == epoch % epoch_per_save and not args.no_save:
793
+ save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
794
+ os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
795
+ # break # FOR TESTING
796
+ if not use_distributed:
797
+ print(f"saved in {save_dir}")
798
+ # torch.save(Deformddpm.state_dict(), save_dir)
799
+ torch.save({
800
+ 'model_state_dict': Deformddpm.state_dict(),
801
+ 'optimizer_state_dict': optimizer.state_dict(),
802
+ 'epoch': epoch
803
+ }, save_dir)
804
+ elif gpu_id == 0:
805
+ print(f"saved in {save_dir}")
806
+ # torch.save(Deformddpm.module.state_dict(), save_dir)
807
+ torch.save({
808
+ 'model_state_dict': Deformddpm.module.state_dict(),
809
+ 'optimizer_state_dict': optimizer.state_dict(),
810
+ 'epoch': epoch
811
+ }, save_dir)
812
+ # Clean up tmp/ mid-epoch checkpoints after completed epoch
813
+ if gpu_id == 0 and not args.no_save:
814
+ tmp_dir = os.path.join(model_dir, "tmp")
815
+ tmp_pths = glob.glob(os.path.join(tmp_dir, "*.pth"))
816
+ if tmp_pths:
817
+ for f in tmp_pths:
818
+ os.remove(f)
819
+ print(f" [cleanup] Cleared {len(tmp_pths)} tmp/ mid-epoch checkpoints", flush=True)
820
+ # Reset initial_step after first epoch completes (no more skipping)
821
+ initial_step = 0
822
+
823
+ # XPU CCL workaround: restart after each epoch to avoid CCL hang on 2nd epoch.
824
+ # CCL's Level Zero IPC handles accumulate and cause deadlock after ~200+ collectives.
825
+ # A fresh process resets the L0 context. The bash loop catches exit code 42 and restarts.
826
+ if DEVICE_TYPE == 'xpu' and use_distributed:
827
+ if gpu_id == 0:
828
+ print(f" [xpu-restart] Epoch {epoch} done. Restarting to reset CCL state.", flush=True)
829
+ _empty_cache(DEVICE_TYPE)
830
+ gc.collect()
831
+ if dist.is_initialized():
832
+ dist.barrier()
833
+ dist.destroy_process_group()
834
+ sys.exit(EXIT_CODE_RESTART)
835
+
836
+ # Resource cleanup at the end of training
837
+ _empty_cache(DEVICE_TYPE)
838
+ gc.collect()
839
+ if use_distributed and dist.is_initialized():
840
+ dist.destroy_process_group()
841
+
842
def _optimizer_state_fits(entry, param):
    """Return True when every non-scalar tensor in a saved optimizer-state entry has ``param``'s shape."""
    return all(
        v.shape == param.shape
        for v in entry.values()
        if isinstance(v, torch.Tensor) and v.dim() > 0
    )


def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file, use_distributed=True, load_strict=False):
    """Load model (and, when possible, optimizer) state from a checkpoint file.

    All ranks load the checkpoint so optimizer state is consistent across DDP
    processes. (Optimizer state includes per-parameter Adam momentum/variance
    which are NOT broadcast -- only model weights are broadcast. Without this,
    non-rank-0 processes would have fresh Adam state after restart.)

    Reads the module-level names ``args.reset_optimizer``, ``DEVICE_TYPE``,
    ``_empty_cache``, ``utils`` and ``dist``.

    Args:
        gpu_id: Rank identifier; only ``gpu_id == 0`` prints diagnostics.
        Deformddpm: The model. When ``use_distributed`` is true its ``.module``
            attribute receives the weights.
        optimizer: Optimizer whose state is restored unless
            ``args.reset_optimizer`` is set or the checkpoint has no
            ``'optimizer_state_dict'`` entry.
        model_file: Checkpoint path. The first six characters of its basename
            are parsed as the epoch number; a basename containing ``'_step'``
            is treated as a mid-epoch checkpoint.
        use_distributed: Whether to load into ``Deformddpm.module`` and
            broadcast parameters from rank 0 afterwards.
        load_strict: Forwarded as ``strict=`` to ``load_state_dict``.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)``. ``initial_epoch`` is the
        epoch parsed from the filename for a mid-epoch checkpoint, or that
        epoch plus one for an end-of-epoch checkpoint.

    Raises:
        ValueError: If the first six characters of the basename are not an integer.
    """
    gc.collect()
    _empty_cache(DEVICE_TYPE)
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Mid-epoch checkpoints written by this script store non-tensor RNG state
    # in 'epoch_stats', so only load checkpoint files from a trusted source.
    checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)
    target = Deformddpm.module if use_distributed else Deformddpm
    target.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)

    # Restore optimizer state when available (needed for mid-epoch resume).
    # Selective loading: load states for parameters with matching shapes, skip
    # mismatched ones (e.g., UpsampleConv replaced ConvTranspose3d -- different
    # kernel shapes). After one epoch, the saved checkpoint will have correct
    # state for ALL parameters.
    if 'optimizer_state_dict' in checkpoint and not args.reset_optimizer:
        saved_opt = checkpoint['optimizer_state_dict']
        saved_state = saved_opt.get('state', {})
        saved_groups = saved_opt.get('param_groups', [])
        param_list = [p for group in optimizer.param_groups for p in group['params']]

        # optimizer.load_state_dict() raises ValueError unless the number of
        # groups and the number of params per group match, so the fast path
        # must verify the layout as well as the tensor shapes.
        layout_match = (
            len(saved_groups) == len(optimizer.param_groups)
            and all(
                len(saved_g['params']) == len(group['params'])
                for saved_g, group in zip(saved_groups, optimizer.param_groups)
            )
        )
        all_match = layout_match and all(
            int(idx) < len(param_list) and _optimizer_state_fits(s, param_list[int(idx)])
            for idx, s in saved_state.items()
        )

        if all_match:
            # Fast path: everything lines up, do a full load.
            optimizer.load_state_dict(saved_opt)
        else:
            # Selective load: restore param_groups settings (lr, betas, etc.)
            for saved_g, group in zip(saved_groups, optimizer.param_groups):
                for k, v in saved_g.items():
                    if k != 'params':
                        group[k] = v
            # Restore per-parameter state only where the index exists and shapes match.
            skipped = 0
            for idx, s in saved_state.items():
                idx_int = int(idx)
                if idx_int >= len(param_list) or not _optimizer_state_fits(s, param_list[idx_int]):
                    # Entries with no matching parameter are counted too, so
                    # the "restored" figure below is not overstated.
                    skipped += 1
                    continue
                p = param_list[idx_int]
                # Cast non-scalar state tensors to the parameter's dtype/device;
                # scalar tensors and plain values are kept as saved.
                optimizer.state[p] = {
                    k: (v.to(dtype=p.dtype, device=p.device)
                        if isinstance(v, torch.Tensor) and v.dim() > 0 else v)
                    for k, v in s.items()
                }
            if gpu_id == 0:
                loaded = len(saved_state) - skipped
                print(f" [checkpoint] Selective optimizer load: {loaded} params restored, "
                      f"{skipped} skipped (shape mismatch, fresh Adam for those)", flush=True)
    elif args.reset_optimizer and gpu_id == 0:
        print(" [checkpoint] --reset-optimizer: skipping optimizer state, starting fresh Adam", flush=True)
    del checkpoint
    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Broadcast model weights from rank 0 to ensure exact consistency
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)

    # Get the epoch number from the filename.
    basename = os.path.basename(model_file)
    epoch_from_file = int(basename[:6])
    if '_step' in basename:
        # Mid-epoch checkpoint: resume at same epoch (don't +1)
        initial_epoch = epoch_from_file
    else:
        # End-of-epoch checkpoint: start next epoch
        initial_epoch = epoch_from_file + 1

    return initial_epoch, Deformddpm, optimizer
934
+
935
+
936
+
937
if __name__ == "__main__":
    # Script entry point: choose how to start main_train.
    #   1. LOCAL_RANK present in the environment -> one process per rank was
    #      already started by an external launcher (torchrun / srun).
    #   2. use_distributed set, no LOCAL_RANK    -> spawn one process per device.
    #   3. otherwise                             -> single process, rank 0 of 1.
    if "LOCAL_RANK" in os.environ:
        # Multi-node: launched by torchrun / srun
        use_distributed = True
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
        try:
            main_train(local_rank, world_size)
        except Exception:
            # Print a per-rank banner and traceback before re-raising, so the
            # failing rank is identifiable in interleaved multi-process logs.
            import traceback
            print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True)
            traceback.print_exc()
            raise
    elif use_distributed:
        # Single-node multi-GPU: use mp.spawn
        world_size = _device_count(DEVICE_TYPE)
        print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
    else:
        main_train(0, 1)
OM_train_3modes.py CHANGED
@@ -1,4 +1,8 @@
1
- import os
 
 
 
 
2
  import gc
3
  import torch
4
  import torchvision
@@ -9,21 +13,32 @@ from torch.utils.data import DataLoader
9
  from torch.optim import Adam, SGD
10
  from Diffusion.diffuser import DeformDDPM
11
  from Diffusion.networks import get_net, STN
12
- from torchvision.transforms import Lambda
 
13
  import Diffusion.losses as losses
14
  import random
15
  import glob
16
  import numpy as np
17
  import utils
18
- from tqdm import tqdm
19
 
20
- from Dataloader.dataloader0 import get_dataloader
21
  from Dataloader.dataLoader import *
22
 
23
  from Dataloader.dataloader_utils import thresh_img
24
  import yaml
25
  import argparse
26
 
 
 
 
 
 
 
 
 
 
 
27
  ####################
28
  import torch.multiprocessing as mp
29
  from torch.utils.data.distributed import DistributedSampler
@@ -31,27 +46,66 @@ from torch.nn.parallel import DistributedDataParallel as DDP
31
  import torch.distributed as dist
32
  # from torch.distributed import init_process_group
33
  ###############
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def ddp_setup(rank, world_size):
35
  """
36
  Args:
37
- rank: Unique identifier of each process
38
  world_size: Total number of processes
39
  """
40
- os.environ["MASTER_ADDR"] = "localhost"
41
- os.environ["MASTER_PORT"] = "12355"
42
- dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
43
- torch.cuda.set_device(rank)
44
-
45
- use_distributed = True
46
- # use_distributed = False
 
 
 
 
47
 
48
  EPS = 1e-5
49
  MSK_EPS = 0.01
50
- TEXT_EMBED_PROB = 0.7
51
- AUG_RESAMPLE_PROB = 0.6
52
- LOSS_WEIGHTS_DIFF = [2.0, 1.0, 3.0] # [ang, dist, reg]
53
  # LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
54
- LOSS_WEIGHTS_REGIST = [1.0, 0.2, 1e3] # [imgsim, imgmse, ddf]
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  # AUG_PERMUTE_PROB = 0.35
57
 
@@ -68,23 +122,73 @@ parser.add_argument(
68
  default="Config/config_all.yaml",
69
  required=False,
70
  )
 
 
 
 
 
 
 
 
 
 
 
 
71
  args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
72
  #=======================================================================================================================
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
 
76
  def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
77
  if use_distributed:
78
  ddp_setup(rank,world_size)
79
 
80
- if torch.distributed.is_initialized():
81
  print(f"World size: {torch.distributed.get_world_size()}")
82
  print(f"Communication backend: {torch.distributed.get_backend()}")
83
- gpu_id = rank
 
 
 
 
 
 
 
 
 
84
 
85
  # Load the YAML file into a dictionary
86
  with open(args.config, 'r') as file:
87
  hyp_parameters = yaml.safe_load(file)
 
 
 
88
  print(hyp_parameters)
89
 
90
  # epoch_per_save=10
@@ -98,7 +202,7 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
98
  suffix_pth=f'_{data_name}_{net_name}.pth'
99
  model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
100
  model_dir=model_save_path
101
- transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
102
 
103
  # Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
104
 
@@ -115,33 +219,54 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
115
  # drop_last=True,
116
  # )
117
 
118
- # dataset = OminiDataset_v1(transform=None)
119
- dataset = OMDataset_indiv(transform=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  train_loader = DataLoader(
121
  dataset,
122
  batch_size=hyp_parameters['batchsize'],
123
- shuffle=True,
124
  drop_last=True,
 
125
  )
126
-
127
- # datasetp = OminiDataset_paired(transform=None)
128
- datasetp = OMDataset_pair(transform=None)
129
  train_loader_p = DataLoader(
130
  datasetp,
131
- batch_size=hyp_parameters['batchsize']//2,
132
- shuffle=True,
133
  drop_last=True,
 
134
  )
135
 
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  Deformddpm = DeformDDPM(
139
- network=Net(
140
- n_steps=hyp_parameters["timesteps"],
141
- ndims=hyp_parameters["ndims"],
142
- num_input_chn = hyp_parameters["num_input_chn"],
143
- res = hyp_parameters['img_size']
144
- ),
145
  n_steps=hyp_parameters["timesteps"],
146
  image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
147
  device=hyp_parameters["device"],
@@ -161,9 +286,18 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
161
 
162
 
163
  if use_distributed:
164
- Deformddpm.to(rank)
165
- Deformddpm = DDP(Deformddpm, device_ids=[rank])
166
- ddf_stn.to(rank)
 
 
 
 
 
 
 
 
 
167
  else:
168
  Deformddpm.to(hyp_parameters["device"])
169
  ddf_stn.to(hyp_parameters["device"])
@@ -172,12 +306,14 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
172
 
173
  # mse = nn.MSELoss()
174
  # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
175
- loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e2)
176
- loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e2)
 
 
177
  loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
178
  # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
179
  loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
180
- loss_imgsim = losses.LNCC()
181
  loss_imgmse = losses.LMSE()
182
 
183
  optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
@@ -194,19 +330,51 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
194
  # check for existing models
195
  if not os.path.exists(model_dir):
196
  os.makedirs(model_dir, exist_ok=True)
197
- model_files = glob.glob(os.path.join(model_dir, "*.pth"))
198
- model_files.sort()
199
- if model_files:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  if gpu_id == 0:
201
  print(model_files)
202
- initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1])
 
203
  else:
204
  initial_epoch = 0
205
 
206
  if gpu_id == 0:
207
  print('len_train_data: ',len(dataset))
 
 
 
 
 
 
208
  # Training loop
209
  for epoch in range(initial_epoch,hyp_parameters["epoch"]):
 
 
 
210
 
211
  epoch_loss_tot = 0.0
212
  epoch_loss_gen_d = 0.0
@@ -216,17 +384,110 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
216
  epoch_loss_imgsim = 0.0
217
  epoch_loss_imgmse = 0.0
218
  epoch_loss_ddfreg = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  # Set model inside to train model
220
  Deformddpm.train()
221
-
222
- loss_nan_step = 0 # yu: count the number of nan loss steps
223
 
224
- for step, batch in tqdm(enumerate(train_loader)):
 
 
 
 
 
 
 
 
 
225
  # for step, batch in tqdm(enumerate(train_loader)):
226
-
227
  # for step, batch in enumerate(train_loader_omni):
228
- # x0, _ = batch
229
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  # ==========================================================================
232
  # diffusion train on single image
@@ -235,12 +496,11 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
235
  [x0,embd] = batch # for om dataset
236
  x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
237
  # print('embd:', embd.shape)
 
238
  if np.random.uniform(0,1)<TEXT_EMBED_PROB:
239
- embd = embd.to(hyp_parameters["device"]).type(torch.float32)
240
  else:
241
- embd = None
242
-
243
-
244
 
245
  n = x0.size()[0] # batch_size -> n
246
  x0 = x0.to(hyp_parameters["device"])
@@ -254,10 +514,10 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
254
  # elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
255
  else:
256
  [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
257
- x0 = transformer(x0)
258
  if hyp_parameters['noise_scale']>0:
259
  if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
260
- x0 = thresh_img(x0, [0, 1*hyp_parameters['noise_scale']])
261
  x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
262
 
263
  # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
@@ -266,157 +526,301 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
266
  ) # pick up a seq of rand number from 0 to 'timestep'
267
 
268
  # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
269
- proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
270
  # print('proc_type:', proc_type)
271
- cond_img, _, cond_ratio = Deformddpm.module.proc_cond_img(x0,proc_type=proc_type)
272
-
273
- pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
274
-
275
- loss_tot=0
276
-
277
- loss_ddf = loss_reg(pre_dvf_I,img=x0)
278
- trm_pred = ddf_stn(pre_dvf_I, dvf_I)
279
- loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
280
- loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
281
-
282
- loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
283
- loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
284
- loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
285
-
286
- # >> JZ: print nan in x0
287
- if torch.isnan(x0).any():
288
- print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
289
- # >> JZ: print loss of ddf
290
- if loss_ddf>0.001:
291
- print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
292
- # yu: check if loss_tot==nan or inf
293
- if torch.isnan(loss_tot) or torch.isinf(loss_tot):
294
- print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
295
- loss_nan_step += 1
296
- continue
297
- if loss_nan_step > 5:
298
- print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
299
- raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
300
 
 
 
 
 
 
301
 
302
- optimizer.zero_grad()
303
- loss_tot.backward()
304
- optimizer.step()
 
305
 
306
- epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
307
- epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
308
- epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
309
- epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
- # print(loss_gen_a.item())
312
- # if 0:
313
- # if loss_gen_a.item() < -0.3 and step%train_mode_ratio == 0:
314
- if step%train_mode_ratio == 0:
315
  # ==========================================================================
316
- # registration train on paired images
317
- # x1, y1 = next(iter(train_loader_p))
318
- [x1, y1, _, embd_y] = next(iter(train_loader_p))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  if np.random.uniform(0,1)<TEXT_EMBED_PROB:
320
- # embd_x = embd_x.to(hyp_parameters["device"]).type(torch.float32)
321
  embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
322
  else:
323
- # embd_x = None
324
  embd_y = None
325
 
326
  x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
327
  y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
328
  n = x1.size()[0] # batch_size -> n
329
- # random deformation + rotation
330
- # if hyp_parameters["ndims"]>2:
331
- # if np.random.uniform(0,1)<0.6:
332
- # x1 = utils.random_resample(x1, deform_scale=0)
333
- # y1 = utils.random_resample(y1, deform_scale=0)
334
- x1 = transformer(x1)
335
- y1 = transformer(y1)
336
  [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
337
  if hyp_parameters['noise_scale']>0:
338
- x1 = thresh_img(x1, [0, 2*hyp_parameters['noise_scale']])
339
- x1 = x1 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
340
- y1 = thresh_img(y1, [0, 2*hyp_parameters['noise_scale']])
341
- y1 = y1 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
342
- # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
343
- t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
344
- hyp_parameters["device"]
345
- ) # pick up a seq of rand number from 0 to 'timestep'
346
-
347
-
348
- scale_regist = np.random.uniform(0.6,1.)
349
- T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist) + 1), 16), reverse=True)
350
- # scale_regist = np.random.uniform(0.4,1.)
351
- # T_regist = [int(hyp_parameters["timesteps"]*scale_regist)]
352
- # scale_regist = np.random.uniform(0.6,1.)
353
- # init_T = int(hyp_parameters["timesteps"] * scale_regist)
354
- # T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist)), 2)+list(range(init_T,hyp_parameters["timesteps"]+1)), reverse=True)
355
-
356
- T_regist = [[t for _ in range(hyp_parameters["batchsize"]//2)] for t in T_regist]
357
-
358
- # print('T_regist:', T_regist)
359
- # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
360
- proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'none'])
361
- # proc_type = random.choice(['project'])
362
- y1, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
363
- msk_tgt = msk_tgt + MSK_EPS
364
- [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
365
- loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=msk_tgt) # calculate loss for the registration process
366
  loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
367
- loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>0.0)) # calculate loss for the registration process
 
368
 
369
  loss_regist = 0
370
  loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
371
  loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
372
  loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
373
- # print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
374
- # print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
375
 
376
  # >> JZ: print nan in x0
377
  if torch.isnan(x0).any():
378
  print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
379
  # >> JZ: print loss of ddf
380
- if loss_ddf1>0.001:
381
  print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
382
-
383
- loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
384
- optimizer.zero_grad()
385
- loss_regist.backward()
386
 
387
- # # Print gradients for each parameter
388
- # for name, param in Deformddpm.named_parameters():
389
- # if param.grad is not None:
390
- # print(f"Gradient for {name}: {param.grad.norm()}")
391
- # else:
392
- # print(f"Gradient for {name}: None")
393
-
394
- torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
395
- optimizer.step()
396
 
397
- epoch_loss_regist += loss_regist.item() * len(x0) / len(train_loader.dataset)
398
- epoch_loss_imgsim += loss_sim.item() * len(x0) / len(train_loader.dataset)
399
- epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
400
- epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)
401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
-
404
- # print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
405
-
406
- # break # FOR TESTING
407
- # else:
408
- # print('loss_gen_a:',loss_gen_a.item()) # FOR TESTING
409
- # pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
- if 1:
412
- # if gpu_id == 0:
413
  print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
414
- print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
 
 
 
415
 
416
- # # LR schedular step ----- YHM
417
- # scheduler.step()
418
 
419
- if 0 == epoch % epoch_per_save:
420
  save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
421
  os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
422
  # break # FOR TESTING
@@ -436,55 +840,150 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
436
  'optimizer_state_dict': optimizer.state_dict(),
437
  'epoch': epoch
438
  }, save_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
 
440
  # Resource cleanup at the end of training
441
- torch.cuda.empty_cache()
442
  gc.collect()
443
  if use_distributed and dist.is_initialized():
444
  dist.destroy_process_group()
445
 
446
- def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
447
-
 
 
 
 
 
 
448
  if gpu_id == 0:
449
- # if 0:
450
  utils.print_memory_usage("Before Loading Model")
451
- if 1:
452
- gc.collect()
453
- torch.cuda.empty_cache()
454
- # Deformddpm.network.load_state_dict(torch.load(latest_model_file))
455
- # Deformddpm.load_state_dict(torch.load(latest_model_file), strict=False)
456
- checkpoint = torch.load(model_file)
457
- # checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
458
- if use_distributed:
459
- Deformddpm.module.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  else:
461
- Deformddpm.load_state_dict(checkpoint['model_state_dict'])
462
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  utils.print_memory_usage("After Loading Checkpoint on GPU")
464
 
465
  if use_distributed:
466
- # Broadcast model weights from rank 0 to all other GPUs
467
  dist.barrier()
468
  for param in Deformddpm.parameters():
469
- dist.broadcast(param.data, src=0) # Synchronize model across ranks
470
- dist.barrier()
471
- for param_group in optimizer.param_groups:
472
- for param in param_group['params']:
473
- if param.grad is not None:
474
- dist.broadcast(param.grad, src=0) # Sync optimizer gradients
475
 
476
- # initial_epoch = checkpoint['epoch'] + 1
477
- # get the epoch number from the filename and add 1 to set as initial_epoch
478
- initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
 
 
 
 
 
 
479
 
480
  return initial_epoch, Deformddpm, optimizer
481
 
482
 
483
 
484
  if __name__ == "__main__":
485
- if use_distributed:
486
- world_size = torch.cuda.device_count()
487
- print(f"Distributed GPU number = {world_size}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  mp.spawn(main_train,args = (world_size,),nprocs = world_size)
489
  else:
490
  main_train(0,1)
 
1
+ import os, sys, contextlib
2
+
3
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(ROOT_DIR)
5
+
6
  import gc
7
  import torch
8
  import torchvision
 
13
  from torch.optim import Adam, SGD
14
  from Diffusion.diffuser import DeformDDPM
15
  from Diffusion.networks import get_net, STN
16
+ # from torchvision.transforms import Lambda
17
+ import torch.nn.functional as F
18
  import Diffusion.losses as losses
19
  import random
20
  import glob
21
  import numpy as np
22
  import utils
23
+ from tqdm import tqdm
24
 
25
+ # from Dataloader.dataloader0 import get_dataloader
26
  from Dataloader.dataLoader import *
27
 
28
  from Dataloader.dataloader_utils import thresh_img
29
  import yaml
30
  import argparse
31
 
32
+ # XPU support: import Intel Extension for PyTorch and oneCCL bindings if available
33
+ try:
34
+ import intel_extension_for_pytorch as ipex
35
+ except ImportError:
36
+ ipex = None
37
+ try:
38
+ import oneccl_bindings_for_pytorch
39
+ except (ImportError, Exception) as e:
40
+ print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
41
+
42
  ####################
43
  import torch.multiprocessing as mp
44
  from torch.utils.data.distributed import DistributedSampler
 
46
  import torch.distributed as dist
47
  # from torch.distributed import init_process_group
48
  ###############
49
+ def _device_available(device_type):
50
+ if device_type == 'xpu':
51
+ return hasattr(torch, 'xpu') and torch.xpu.is_available()
52
+ return torch.cuda.is_available()
53
+
54
+ def _device_count(device_type):
55
+ if device_type == 'xpu':
56
+ return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
57
+ return torch.cuda.device_count()
58
+
59
+ def _set_device(rank, device_type):
60
+ if device_type == 'xpu':
61
+ torch.xpu.set_device(rank)
62
+ else:
63
+ torch.cuda.set_device(rank)
64
+
65
+ def _empty_cache(device_type):
66
+ if device_type == 'xpu' and hasattr(torch, 'xpu'):
67
+ torch.xpu.empty_cache()
68
+ elif torch.cuda.is_available():
69
+ torch.cuda.empty_cache()
70
+
71
def ddp_setup(rank, world_size):
    """Initialise the default process group and bind this process to a device.

    The backend is "ccl" when the module-level DEVICE_TYPE is "xpu" and "nccl"
    otherwise.

    Args:
        rank: Unique identifier of each process (local_rank when launched by
            torchrun). Not used when LOCAL_RANK is present in the environment;
            the device index is then taken from LOCAL_RANK.
        world_size: Total number of processes. Not passed on when LOCAL_RANK is
            present in the environment.
    """
    backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
    if "LOCAL_RANK" in os.environ:
        # Launched by torchrun: MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE already set
        dist.init_process_group(backend=backend)
        _set_device(int(os.environ["LOCAL_RANK"]), DEVICE_TYPE)
    else:
        # Single-node mp.spawn: every worker lives on this host.
        os.environ["MASTER_ADDR"] = "localhost"
        # 12355 stays the default, but a MASTER_PORT exported by the caller is
        # honoured so that two jobs on the same host do not fight over one port.
        os.environ.setdefault("MASTER_PORT", "12355")
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        _set_device(rank, DEVICE_TYPE)
88
 
89
  EPS = 1e-5
90
  MSK_EPS = 0.01
91
+ TEXT_EMBED_PROB = 0.5
92
+ AUG_RESAMPLE_PROB = 0.5
93
+ LOSS_WEIGHTS_DIFF = [4.0, 2.0, 8.0] # [ang, dist, reg]
94
  # LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
95
+ LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2] # [imgsim, imgmse, ddf]
96
+ DIFF_REG_BATCH_RATIO = 2
97
+ # LOSS_WEIGHT_CONTRASTIVE = 1e-4
98
+ LOSS_WEIGHT_CONTRASTIVE = 1e-1
99
+ REGISTRATION_STEP_RATIO = 1
100
+ CONTRASTIVE_STEP_RATIO = 1
101
+ ACCEPT_THRESH_CONTRASTIVE = 0.1
102
+ ACCEPT_THRESH_ANGLE = -0.8
103
+ MID_EPOCH_SAVE_STEPS = 1e4 # Save mid-epoch checkpoint every N steps for crash recovery.
104
+ # XPU autograd leaks ~1.0 GiB/step of device memory (Intel bug).
105
+ # With gradient checkpointing, training survives ~26 steps from fresh start,
106
+ # but fewer when carrying leaked memory from previous epoch.
107
+ # Save every 10 steps to minimize lost work on OOM crash. (NOTE(review): MID_EPOCH_SAVE_STEPS above is 1e4, not 10 — confirm which value is intended.)
108
+ EXIT_CODE_RESTART = 42 # Exit code signaling proactive restart (not a crash).
109
 
110
  # AUG_PERMUTE_PROB = 0.35
111
 
 
122
  default="Config/config_all.yaml",
123
  required=False,
124
  )
125
+ parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
126
+ parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
127
+ parser.add_argument("--max-steps-before-restart", type=int, default=0,
128
+ help="Proactive restart: exit after N training steps to reset XPU memory leak. "
129
+ "0=disabled (rely on OOM crash + auto-resubmit). "
130
+ "Recommended: 20 for XPU (survives ~26 steps max).")
131
+ parser.add_argument("--no-save", action="store_true", default=False,
132
+ help="Disable all checkpoint saving (for diagnostic/validation runs)")
133
+ parser.add_argument("--reset-optimizer", action="store_true",
134
+ help="Skip optimizer state loading from checkpoint (use when architecture changed)")
135
+ parser.add_argument("--eval-only", action="store_true",
136
+ help="Forward pass only: compute and print losses without backward/optimizer (no memory leak)")
137
  args = parser.parse_args()
138
+
139
+ # Read config early to determine device type for DDP setup
140
+ with open(args.config, 'r') as _f:
141
+ _cfg = yaml.safe_load(_f)
142
+ DEVICE_TYPE = _cfg.get('device', 'cuda') # 'cuda' or 'xpu'
143
+
144
+ # Auto-detect: use DDP only when multiple devices are available
145
+ use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1
146
+ # use_distributed = True
147
+ # use_distributed = False
148
  #=======================================================================================================================
149
 
150
+ class _DummyIndiv(torch.utils.data.Dataset):
151
+ def __init__(self, n, sz, embd_dim=1024):
152
+ self.n, self.sz, self.embd_dim = n, sz, embd_dim
153
+ def __len__(self): return self.n
154
+ def __getitem__(self, i):
155
+ return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
156
+
157
+ class _DummyPair(torch.utils.data.Dataset):
158
+ def __init__(self, n, sz, embd_dim=1024):
159
+ self.n, self.sz, self.embd_dim = n, sz, embd_dim
160
+ def __len__(self): return self.n
161
+ def __getitem__(self, i):
162
+ return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
163
+ np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
164
+ np.random.randn(self.embd_dim).astype(np.float32),
165
+ np.random.randn(self.embd_dim).astype(np.float32))
166
 
167
 
168
  def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
169
  if use_distributed:
170
  ddp_setup(rank,world_size)
171
 
172
+ if torch.distributed.is_initialized() and rank == 0:
173
  print(f"World size: {torch.distributed.get_world_size()}")
174
  print(f"Communication backend: {torch.distributed.get_backend()}")
175
+ print(f"PYTORCH_ALLOC_CONF: {os.environ.get('PYTORCH_ALLOC_CONF', 'not set')}")
176
+ if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
177
+ props = torch.xpu.get_device_properties(0)
178
+ print(f"XPU device: {props.name}, total memory: {props.total_memory / 1024**3:.2f} GiB")
179
+ # gpu_id = global rank (for save/print guards); rank = local device index
180
+ if "RANK" in os.environ:
181
+ gpu_id = int(os.environ["RANK"])
182
+ rank = int(os.environ["LOCAL_RANK"])
183
+ else:
184
+ gpu_id = rank
185
 
186
  # Load the YAML file into a dictionary
187
  with open(args.config, 'r') as file:
188
  hyp_parameters = yaml.safe_load(file)
189
+ if args.batchsize > 0:
190
+ hyp_parameters['batchsize'] = args.batchsize
191
+ if gpu_id == 0:
192
  print(hyp_parameters)
193
 
194
  # epoch_per_save=10
 
202
  suffix_pth=f'_{data_name}_{net_name}.pth'
203
  model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
204
  model_dir=model_save_path
205
+ # transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
206
 
207
  # Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
208
 
 
219
  # drop_last=True,
220
  # )
221
 
222
+ if args.dummy_samples > 0:
223
+ dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
224
+ datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
225
+ else:
226
+ # dataset = OminiDataset_v1(transform=None)
227
+ dataset = OMDataset_indiv(transform=None)
228
+ # datasetp = OminiDataset_paired(transform=None)
229
+ datasetp = OMDataset_pair(transform=None)
230
+
231
+ if use_distributed:
232
+ sampler = DistributedSampler(dataset, shuffle=True)
233
+ sampler_p = DistributedSampler(datasetp, shuffle=True)
234
+ else:
235
+ sampler = None
236
+ sampler_p = None
237
+
238
  train_loader = DataLoader(
239
  dataset,
240
  batch_size=hyp_parameters['batchsize'],
241
+ shuffle=(sampler is None),
242
  drop_last=True,
243
+ sampler=sampler,
244
  )
 
 
 
245
  train_loader_p = DataLoader(
246
  datasetp,
247
+ batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
248
+ shuffle=(sampler_p is None),
249
  drop_last=True,
250
+ sampler=sampler_p,
251
  )
252
 
253
 
254
 
255
+ network = Net(
256
+ n_steps=hyp_parameters["timesteps"],
257
+ ndims=hyp_parameters["ndims"],
258
+ num_input_chn = hyp_parameters["num_input_chn"],
259
+ res = hyp_parameters['img_size']
260
+ )
261
+ # Enable gradient checkpointing on XPU to reduce peak activation memory.
262
+ # XPU autograd leaks ~1.0 GiB/step; lower peak buys more steps before OOM.
263
+ if DEVICE_TYPE == 'xpu' and hasattr(network, 'use_checkpoint'):
264
+ network.use_checkpoint = True
265
+ if gpu_id == 0:
266
+ print(" [init] Gradient checkpointing enabled for XPU", flush=True)
267
+
268
  Deformddpm = DeformDDPM(
269
+ network=network,
 
 
 
 
 
270
  n_steps=hyp_parameters["timesteps"],
271
  image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
272
  device=hyp_parameters["device"],
 
286
 
287
 
288
  if use_distributed:
289
+ device = f"{DEVICE_TYPE}:{rank}"
290
+ # NO pre-allocation. CCL/oneDNN accumulate ~1.4 GiB/step of device memory outside
291
+ # PyTorch's caching allocator. Pre-allocating steals from that budget:
292
+ # 92% pre-alloc → crash at step 3, 78% → step 10, none (70% cap) → step 14.
293
+ # Instead, use empty_cache() between training phases to release unused cached memory
294
+ # back to the device for CCL/oneDNN.
295
+ if gpu_id == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
296
+ total_mem = torch.xpu.get_device_properties(rank).total_memory
297
+ print(f" [init] XPU device memory: {total_mem/1024**3:.1f} GiB, no pre-allocation (relying on empty_cache between phases)", flush=True)
298
+ Deformddpm.to(device)
299
+ Deformddpm = DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)
300
+ ddf_stn.to(device)
301
  else:
302
  Deformddpm.to(hyp_parameters["device"])
303
  ddf_stn.to(hyp_parameters["device"])
 
306
 
307
  # mse = nn.MSELoss()
308
  # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
309
+ # loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
310
+ loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
311
+ loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
312
+
313
  loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
314
  # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
315
  loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
316
+ loss_imgsim = losses.MSLNCC()
317
  loss_imgmse = losses.LMSE()
318
 
319
  optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
 
330
  # check for existing models
331
  if not os.path.exists(model_dir):
332
  os.makedirs(model_dir, exist_ok=True)
333
+ # Check for checkpoints: first check tmp/ for mid-epoch, then main dir for epoch-level
334
+ tmp_dir = os.path.join(model_dir, "tmp")
335
+ tmp_files = sorted(glob.glob(os.path.join(tmp_dir, "*.pth")))
336
+ model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
337
+ initial_step = 0
338
+
339
+ # Epoch stats and RNG states to restore when resuming from mid-epoch checkpoint
340
+ _resume_epoch_stats = None
341
+ _resume_rng = None
342
+
343
+ if tmp_files and not args.eval_only and args.max_steps_before_restart > 0:
344
+ # Mid-epoch checkpoint: only use when proactive restart is enabled
345
+ latest = tmp_files[-1]
346
+ if gpu_id == 0:
347
+ print(f" [resume] Found mid-epoch checkpoint: {latest}")
348
+ initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
349
+ basename = os.path.basename(latest)
350
+ initial_step = int(basename.split('_step')[1].split('_')[0].split('.')[0])
351
+ _ckpt = torch.load(latest, map_location='cpu', weights_only=False)
352
+ _resume_epoch_stats = _ckpt.get('epoch_stats', None)
353
+ del _ckpt
354
+ if gpu_id == 0:
355
+ print(f" [resume] Resuming epoch {initial_epoch} from step {initial_step}"
356
+ f"{' (with epoch_stats)' if _resume_epoch_stats else ''}", flush=True)
357
+ elif model_files:
358
  if gpu_id == 0:
359
  print(model_files)
360
+ latest = model_files[-1]
361
+ initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
362
  else:
363
  initial_epoch = 0
364
 
365
  if gpu_id == 0:
366
  print('len_train_data: ',len(dataset))
367
+
368
+ # Proactive restart: track steps since process start to exit before OOM.
369
+ max_steps_restart = args.max_steps_before_restart
370
+ steps_since_start = 0
371
+ loss_contra_gate = 0.0
372
+
373
  # Training loop
374
  for epoch in range(initial_epoch,hyp_parameters["epoch"]):
375
+ if use_distributed and sampler is not None:
376
+ sampler.set_epoch(epoch)
377
+ sampler_p.set_epoch(epoch)
378
 
379
  epoch_loss_tot = 0.0
380
  epoch_loss_gen_d = 0.0
 
384
  epoch_loss_imgsim = 0.0
385
  epoch_loss_imgmse = 0.0
386
  epoch_loss_ddfreg = 0.0
387
+ epoch_loss_contrastive = 0.0
388
+ total_contra = 0
389
+ total_reg_restored = None
390
+ total_contra_restored = None
391
+
392
+ # Restore epoch accumulators from mid-epoch checkpoint (only for the resumed epoch)
393
+ if _resume_epoch_stats is not None and epoch == initial_epoch:
394
+ epoch_loss_tot = _resume_epoch_stats.get('epoch_loss_tot', 0.0)
395
+ epoch_loss_gen_d = _resume_epoch_stats.get('epoch_loss_gen_d', 0.0)
396
+ epoch_loss_gen_a = _resume_epoch_stats.get('epoch_loss_gen_a', 0.0)
397
+ epoch_loss_reg = _resume_epoch_stats.get('epoch_loss_reg', 0.0)
398
+ epoch_loss_regist = _resume_epoch_stats.get('epoch_loss_regist', 0.0)
399
+ epoch_loss_imgsim = _resume_epoch_stats.get('epoch_loss_imgsim', 0.0)
400
+ epoch_loss_imgmse = _resume_epoch_stats.get('epoch_loss_imgmse', 0.0)
401
+ epoch_loss_ddfreg = _resume_epoch_stats.get('epoch_loss_ddfreg', 0.0)
402
+ epoch_loss_contrastive = _resume_epoch_stats.get('epoch_loss_contrastive', 0.0)
403
+ total_reg_restored = _resume_epoch_stats.get('total_reg', None)
404
+ total_contra_restored = _resume_epoch_stats.get('total_contra', None)
405
+ loss_nan_step = _resume_epoch_stats.get('loss_nan_step', 0)
406
+ # RNG states are restored INSIDE the skip loop (at the last skipped step)
407
+ # to avoid DataLoader __getitem__ calls corrupting the restored state.
408
+ _resume_rng = {k: _resume_epoch_stats[k] for k in
409
+ ('rng_torch', 'rng_numpy', 'rng_python', 'rng_xpu', 'rng_cuda')
410
+ if k in _resume_epoch_stats}
411
+ if gpu_id == 0:
412
+ print(f" [resume] Restored epoch stats from checkpoint (loss_tot={epoch_loss_tot:.4f})", flush=True)
413
+ _resume_epoch_stats = None # Only restore once
414
+ else:
415
+ loss_nan_step = 0 # only reset when NOT resuming mid-epoch
416
+
417
  # Set model inside to train model
418
  Deformddpm.train()
 
 
419
 
420
+ total = min(len(train_loader), len(train_loader_p))
421
+ total_reg = total // REGISTRATION_STEP_RATIO
422
+ # Restore total_reg and total_contra from checkpoint if available (mid-epoch resume)
423
+ if total_reg_restored is not None:
424
+ total_reg = total_reg_restored
425
+ total_reg_restored = None
426
+ if total_contra_restored is not None:
427
+ total_contra = total_contra_restored
428
+ total_contra_restored = None
429
+ # for step, batch in tqdm(enumerate(train_loader)):
430
  # for step, batch in tqdm(enumerate(train_loader)):
 
431
  # for step, batch in enumerate(train_loader_omni):
432
+ for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
433
+
434
+ # Skip steps already completed (mid-epoch resume).
435
+ # Checkpoint at step N is saved AFTER step N's training completes,
436
+ # so step N itself must also be skipped (use <=, not <).
437
+ if epoch == initial_epoch and initial_step > 0 and step <= initial_step:
438
+ # Restore RNG at the last skipped step, AFTER DataLoader __getitem__
439
+ # has consumed RNG for all skipped batches. This way the first
440
+ # non-skipped step starts with exactly the saved RNG state.
441
+ if step == initial_step and _resume_rng is not None:
442
+ # Restore rank 0's RNG as base state, then re-seed per-rank
443
+ # so each rank has independent RNG (matching continuous run's
444
+ # divergent-per-rank behavior). Without this, all ranks would
445
+ # share rank 0's RNG → correlated augmentation/dropout decisions.
446
+ if 'rng_torch' in _resume_rng:
447
+ torch.set_rng_state(_resume_rng['rng_torch'])
448
+ if 'rng_numpy' in _resume_rng:
449
+ np.random.set_state(_resume_rng['rng_numpy'])
450
+ if 'rng_python' in _resume_rng:
451
+ random.setstate(_resume_rng['rng_python'])
452
+ if 'rng_xpu' in _resume_rng and DEVICE_TYPE == 'xpu':
453
+ torch.xpu.set_rng_state(_resume_rng['rng_xpu'])
454
+ elif 'rng_cuda' in _resume_rng and torch.cuda.is_available():
455
+ torch.cuda.set_rng_state(_resume_rng['rng_cuda'])
456
+ # Per-rank re-seed: checkpoint only has rank 0's RNG state.
457
+ # Advance each rank's RNG by a deterministic offset so they
458
+ # diverge (as they would in a continuous run).
459
+ if gpu_id > 0:
460
+ rank_seed = gpu_id * 100003 + initial_step * 31
461
+ torch.manual_seed(torch.initial_seed() + rank_seed)
462
+ np.random.seed((np.random.get_state()[1][0] + rank_seed) % (2**31))
463
+ random.seed(random.getrandbits(32) + rank_seed)
464
+ if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
465
+ torch.xpu.manual_seed(torch.initial_seed() + rank_seed)
466
+ elif torch.cuda.is_available():
467
+ torch.cuda.manual_seed(torch.initial_seed() + rank_seed)
468
+ _resume_rng = None
469
+ if gpu_id == 0:
470
+ print(f" [resume] RNG states restored at step {step} (per-rank re-seeded)", flush=True)
471
+ continue
472
+
473
+ # Free registration tensors from previous step
474
+ x1 = y1 = ddf_comp = img_rec = img_diff = None
475
+ ddf_rand = y1_proc = msk_tgt = img_save = None
476
+ loss_regist = loss_sim = loss_mse = loss_ddf1 = None
477
+
478
+ # Memory diagnostic (one per node via local rank 0) — only warn when abnormal
479
+ # Normal at step start: ~16 GiB reserved, ~48 GiB free (of 64 GiB total)
480
+ if rank == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
481
+ torch.xpu.reset_peak_memory_stats(rank)
482
+ free_mem, total_mem_dev = torch.xpu.mem_get_info(rank)
483
+ used_gib = (total_mem_dev - free_mem) / 1024**3
484
+ if used_gib > 24: # Normal is ~16 GiB at step start; warn if accumulating
485
+ alloc = torch.xpu.memory_allocated() / 1024**3
486
+ reserved = torch.xpu.memory_reserved() / 1024**3
487
+ free_gib = free_mem / 1024**3
488
+ print(f" [mem WARNING] gpu_id={gpu_id} epoch {epoch} step {step}: "
489
+ f"{used_gib:.1f} GiB used ({alloc:.1f} alloc / {reserved:.1f} reserved), "
490
+ f"{free_gib:.1f} GiB free", flush=True)
491
 
492
  # ==========================================================================
493
  # diffusion train on single image
 
496
  [x0,embd] = batch # for om dataset
497
  x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
498
  # print('embd:', embd.shape)
499
+ embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
500
  if np.random.uniform(0,1)<TEXT_EMBED_PROB:
501
+ embd_in = embd_dev
502
  else:
503
+ embd_in = None
 
 
504
 
505
  n = x0.size()[0] # batch_size -> n
506
  x0 = x0.to(hyp_parameters["device"])
 
514
  # elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
515
  else:
516
  [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
517
+ # x0 = transformer(x0)
518
  if hyp_parameters['noise_scale']>0:
519
  if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
520
+ x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
521
  x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
522
 
523
  # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
 
526
  ) # pick up a seq of rand number from 0 to 'timestep'
527
 
528
  # proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
529
+ proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
530
  # print('proc_type:', proc_type)
531
+ ddpm = Deformddpm.module if use_distributed else Deformddpm
532
+ cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
 
534
+ if loss_contra_gate < ACCEPT_THRESH_CONTRASTIVE:
535
+
536
+ pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
537
+
538
+ loss_tot=0
539
 
540
+ loss_ddf = loss_reg(pre_dvf_I,img=x0)
541
+ trm_pred = ddf_stn(pre_dvf_I, dvf_I)
542
+ loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
543
+ loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
544
 
545
+ loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
546
+ loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
547
+ loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
548
+
549
+ # >> JZ: print nan in x0
550
+ if torch.isnan(x0).any():
551
+ print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
552
+ # >> JZ: print loss of ddf
553
+ if loss_ddf>0.001:
554
+ print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
555
+ # yu: check if loss_tot==nan or inf
556
+ # Synchronize NaN skip across all DDP ranks to avoid collective desync
557
+ # Use broadcast from rank 0 instead of all_reduce to avoid CCL hang on single-node XPU
558
+ is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
559
+ if use_distributed:
560
+ nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
561
+ dist.broadcast(nan_flag, src=0)
562
+ is_nan = nan_flag.item() > 0
563
+ if is_nan:
564
+ if gpu_id == 0:
565
+ print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
566
+ loss_nan_step += 1
567
+ continue
568
+ if loss_nan_step > 5:
569
+ print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
570
+ raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
571
 
 
 
 
 
572
  # ==========================================================================
573
+ # Diffusion backward (no gradient clipping — diffusion dominates training)
574
+ # print(loss_contra_gate)
575
+ if (not args.eval_only): # Skip backward when contrastive loss is high to avoid destabilizing diffusion training (especially early on)
576
+ optimizer.zero_grad()
577
+ loss_tot.backward()
578
+ optimizer.step()
579
+
580
+ epoch_loss_tot += loss_tot.item() / total
581
+ epoch_loss_gen_d += loss_gen_d.item() / total
582
+ epoch_loss_gen_a += loss_gen_a.item() / total
583
+ epoch_loss_reg += loss_ddf.item() / total
584
+
585
+ # Print running average every 20 steps in eval-only mode
586
+ if args.eval_only and gpu_id == 0 and (step + 1) % 20 == 0:
587
+ n = step + 1
588
+ print(f" [eval] step {step}: running_avg ang={epoch_loss_gen_a*total/n:.4f} "
589
+ f"dist={epoch_loss_gen_d*total/n:.4f} regul={epoch_loss_reg*total/n:.6f}", flush=True)
590
+
591
+ # Free diffusion intermediates and aggressively release all memory to device.
592
+ # XPU runtime leaks ~1.3 GiB/step outside the caching allocator.
593
+ # gc.collect() + synchronize() + empty_cache() attempts to reclaim deferred/lazy allocations.
594
+ loss_gen_a_val = loss_gen_a.item()
595
+
596
+ # del pre_dvf_I, dvf_I, trm_pred, loss_tot, loss_gen_a, loss_gen_d, loss_ddf
597
+ gc.collect()
598
+ if DEVICE_TYPE == 'xpu':
599
+ torch.xpu.synchronize()
600
+ _empty_cache(DEVICE_TYPE)
601
+
602
+ # Sync loss_gen_a across DDP ranks for contrastive and registration gating
603
+ if use_distributed:
604
+ loss_gen_a_sync = torch.tensor([loss_gen_a_val], device=f"{DEVICE_TYPE}:{rank}")
605
+ dist.broadcast(loss_gen_a_sync, src=0)
606
+ loss_gen_a_gate = loss_gen_a_sync.item()
607
+ else:
608
+ loss_gen_a_gate = loss_gen_a_val
609
+
610
+ LOSS_WEIGHT_CONTRASTIVE=1e-4
611
+ else:
612
+ LOSS_WEIGHT_CONTRASTIVE=1e-1
613
+ if gpu_id == 0:
614
+ print(f" [train] step {step}: Skipping backward (contra_gate={loss_contra_gate:.4f})", flush=True)
615
+
616
+
617
+ # ==========================================================================
618
+ # Contrastive train on single image (text-image alignment)
619
+ # Separate backward with gradient clipping to prevent destabilizing diffusion.
620
+ loss_contra_val = None
621
+ if step % CONTRASTIVE_STEP_RATIO == 0:
622
+ n_contra = x0.size()[0]
623
+ t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
624
+ # Route through DDP wrapper and return img_embd directly so DDP
625
+ # traces the correct subgraph (encoder + mid + attn + img2txt).
626
+ img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(), T=t_contra, output_embedding=True, text=None) # [B, 1024]
627
+ loss_contra_preweight = F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1)-0.25).mean()
628
+ loss_contra = LOSS_WEIGHT_CONTRASTIVE * loss_contra_preweight
629
+
630
+ if not args.eval_only:
631
+ optimizer.zero_grad()
632
+ loss_contra.backward()
633
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=LOSS_WEIGHT_CONTRASTIVE*1)
634
+ optimizer.step()
635
+ loss_contra_val = loss_contra.item()
636
+ epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
637
+
638
+ # else:
639
+ # if gpu_id == 0:
640
+ # print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")
641
+
642
+ # Free remaining intermediates and aggressively release memory before registration
643
+ if cond_img is not None:
644
+ del cond_img
645
+ if blind_mask is not None:
646
+ del blind_mask
647
+ gc.collect()
648
+ if DEVICE_TYPE == 'xpu':
649
+ torch.xpu.synchronize()
650
+ _empty_cache(DEVICE_TYPE)
651
+
652
+ # Sync loss_gen_a across DDP ranks for contrastive and registration gating
653
+ if use_distributed:
654
+ loss_contra_sync = torch.tensor([loss_contra_preweight], device=f"{DEVICE_TYPE}:{rank}")
655
+ dist.broadcast(loss_contra_sync, src=0)
656
+ loss_contra_gate = loss_contra_sync.item()
657
+ else:
658
+ loss_contra_gate = loss_contra_preweight
659
+
660
+ # ==========================================================================
661
+ # registration train on paired images
662
+ # loss_gen_a_gate already synced across DDP ranks above
663
+ do_regist = step % REGISTRATION_STEP_RATIO == 0 and (loss_contra_gate < ACCEPT_THRESH_CONTRASTIVE) and loss_gen_a_gate < ACCEPT_THRESH_ANGLE
664
+ if do_regist:
665
+ [x1, y1, _, embd_y] = batch_p
666
  if np.random.uniform(0,1)<TEXT_EMBED_PROB:
 
667
  embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
668
  else:
 
669
  embd_y = None
670
 
671
  x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
672
  y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
673
  n = x1.size()[0] # batch_size -> n
 
 
 
 
 
 
 
674
  [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
675
  if hyp_parameters['noise_scale']>0:
676
+ [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
677
+ random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
678
+ random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
679
+ x1 = x1 * random_scale + random_shift
680
+ y1 = y1 * random_scale + random_shift
681
+
682
+ scale_regist = np.random.uniform(0.0,0.5)
683
+ select_timestep = np.random.randint(12, 32) # select a random number of timesteps to sample, between 8 and 16
684
+ T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
685
+
686
+ T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
687
+
688
+ proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
689
+ ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
690
+ y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
691
+ msk_tgt = msk_tgt+MSK_EPS
692
+ [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
 
 
 
 
 
 
 
 
 
 
 
693
  loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
694
+ loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
695
+ loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
696
 
697
  loss_regist = 0
698
  loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
699
  loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
700
  loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
 
 
701
 
702
  # >> JZ: print nan in x0
703
  if torch.isnan(x0).any():
704
  print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
705
  # >> JZ: print loss of ddf
706
+ if loss_ddf1>0.002:
707
  print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
 
 
 
 
708
 
709
+ loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
710
+ if not args.eval_only:
711
+ optimizer.zero_grad()
712
+ loss_regist.backward()
 
 
 
 
 
713
 
714
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
715
+ optimizer.step()
 
 
716
 
717
+ epoch_loss_regist += loss_regist.item()
718
+ epoch_loss_imgsim += loss_sim.item()
719
+ epoch_loss_imgmse += loss_mse.item()
720
+ epoch_loss_ddfreg += loss_ddf1.item()
721
+ else:
722
+ loss_sim = torch.tensor(0.0)
723
+ loss_mse = torch.tensor(0.0)
724
+ loss_ddf1 = torch.tensor(0.0)
725
+ loss_regist = torch.tensor(0.0)
726
+ if step % REGISTRATION_STEP_RATIO==0:
727
+ total_reg = total_reg-1
728
+
729
+ # print for checking
730
+ if step % 10 == 0:
731
+ print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
732
+ print(f'- loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
733
+ print(f'- loss_contra: {loss_contra}')
734
 
735
+ # Mid-epoch checkpoint and proactive restart (only when --max-steps-before-restart > 0)
736
+ if max_steps_restart > 0 and step > 0 and step % MID_EPOCH_SAVE_STEPS == 0 and gpu_id == 0 and not args.no_save:
737
+ _epoch_stats = {
738
+ 'epoch_loss_tot': epoch_loss_tot,
739
+ 'epoch_loss_gen_d': epoch_loss_gen_d,
740
+ 'epoch_loss_gen_a': epoch_loss_gen_a,
741
+ 'epoch_loss_reg': epoch_loss_reg,
742
+ 'epoch_loss_regist': epoch_loss_regist,
743
+ 'epoch_loss_imgsim': epoch_loss_imgsim,
744
+ 'epoch_loss_imgmse': epoch_loss_imgmse,
745
+ 'epoch_loss_ddfreg': epoch_loss_ddfreg,
746
+ 'epoch_loss_contrastive': epoch_loss_contrastive,
747
+ 'total_reg': total_reg,
748
+ 'total_contra': total_contra,
749
+ 'loss_nan_step': loss_nan_step,
750
+ 'rng_torch': torch.get_rng_state(),
751
+ 'rng_numpy': np.random.get_state(),
752
+ 'rng_python': random.getstate(),
753
+ **(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
754
+ {'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
755
+ }
756
+ tmp_dir = os.path.join(model_save_path, "tmp")
757
+ os.makedirs(tmp_dir, exist_ok=True)
758
+ for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
759
+ os.remove(old_f)
760
+ mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
761
+ state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
762
+ torch.save({
763
+ 'model_state_dict': state,
764
+ 'optimizer_state_dict': optimizer.state_dict(),
765
+ 'epoch': epoch,
766
+ 'step': step,
767
+ 'epoch_stats': _epoch_stats,
768
+ }, mid_save)
769
+ print(f" [mid-epoch] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
770
+
771
+ # Proactive restart: exit cleanly after N steps to reset XPU memory leak.
772
+ # The bash wrapper will re-launch srun within the same SLURM allocation.
773
+ steps_since_start += 1
774
+ if max_steps_restart > 0 and steps_since_start >= max_steps_restart:
775
+ # Save checkpoint at current position (if not just saved above)
776
+ if not (step > 0 and step % MID_EPOCH_SAVE_STEPS == 0) and gpu_id == 0 and not args.no_save:
777
+ _epoch_stats = {
778
+ 'epoch_loss_tot': epoch_loss_tot, 'epoch_loss_gen_d': epoch_loss_gen_d,
779
+ 'epoch_loss_gen_a': epoch_loss_gen_a, 'epoch_loss_reg': epoch_loss_reg,
780
+ 'epoch_loss_regist': epoch_loss_regist, 'epoch_loss_imgsim': epoch_loss_imgsim,
781
+ 'epoch_loss_imgmse': epoch_loss_imgmse, 'epoch_loss_ddfreg': epoch_loss_ddfreg,
782
+ 'epoch_loss_contrastive': epoch_loss_contrastive, 'total_reg': total_reg, 'total_contra': total_contra,
783
+ 'loss_nan_step': loss_nan_step,
784
+ 'rng_torch': torch.get_rng_state(), 'rng_numpy': np.random.get_state(),
785
+ 'rng_python': random.getstate(),
786
+ **(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
787
+ {'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
788
+ }
789
+ tmp_dir = os.path.join(model_save_path, "tmp")
790
+ os.makedirs(tmp_dir, exist_ok=True)
791
+ for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
792
+ os.remove(old_f)
793
+ mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
794
+ state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
795
+ torch.save({
796
+ 'model_state_dict': state,
797
+ 'optimizer_state_dict': optimizer.state_dict(),
798
+ 'epoch': epoch,
799
+ 'step': step,
800
+ 'epoch_stats': _epoch_stats,
801
+ }, mid_save)
802
+ print(f" [restart] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
803
+ if gpu_id == 0:
804
+ print(f" [restart] Proactive restart after {steps_since_start} steps "
805
+ f"(limit {max_steps_restart}). Exiting with code {EXIT_CODE_RESTART}.", flush=True)
806
+ # Clean shutdown
807
+ _empty_cache(DEVICE_TYPE)
808
+ gc.collect()
809
+ if use_distributed and dist.is_initialized():
810
+ dist.barrier()
811
+ dist.destroy_process_group()
812
+ sys.exit(EXIT_CODE_RESTART)
813
 
814
+ if gpu_id == 0:
815
+ print('==================')
816
  print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
817
+ print(f' loss_contrastive: {epoch_loss_contrastive}')
818
+ total_reg_safe = max(total_reg, 1)
819
+ print(f' loss_regist: {epoch_loss_regist/total_reg_safe} = {epoch_loss_imgsim/total_reg_safe} (imgsim) + {epoch_loss_imgmse/total_reg_safe} (imgmse) + {epoch_loss_ddfreg/total_reg_safe} (ddf)')
820
+ print('==================')
821
 
 
 
822
 
823
+ if 0 == epoch % epoch_per_save and not args.no_save:
824
  save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
825
  os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
826
  # break # FOR TESTING
 
840
  'optimizer_state_dict': optimizer.state_dict(),
841
  'epoch': epoch
842
  }, save_dir)
843
+ # Clean up tmp/ mid-epoch checkpoints after completed epoch
844
+ if gpu_id == 0 and not args.no_save:
845
+ tmp_dir = os.path.join(model_dir, "tmp")
846
+ tmp_pths = glob.glob(os.path.join(tmp_dir, "*.pth"))
847
+ if tmp_pths:
848
+ for f in tmp_pths:
849
+ os.remove(f)
850
+ print(f" [cleanup] Cleared {len(tmp_pths)} tmp/ mid-epoch checkpoints", flush=True)
851
+ # Reset initial_step after first epoch completes (no more skipping)
852
+ initial_step = 0
853
+
854
+ # XPU CCL workaround: restart after each epoch to avoid CCL hang on 2nd epoch.
855
+ # CCL's Level Zero IPC handles accumulate and cause deadlock after ~200+ collectives.
856
+ # A fresh process resets the L0 context. The bash loop catches exit code 42 and restarts.
857
+ if DEVICE_TYPE == 'xpu' and use_distributed:
858
+ if gpu_id == 0:
859
+ print(f" [xpu-restart] Epoch {epoch} done. Restarting to reset CCL state.", flush=True)
860
+ _empty_cache(DEVICE_TYPE)
861
+ gc.collect()
862
+ if dist.is_initialized():
863
+ dist.barrier()
864
+ dist.destroy_process_group()
865
+ sys.exit(EXIT_CODE_RESTART)
866
 
867
  # Resource cleanup at the end of training
868
+ _empty_cache(DEVICE_TYPE)
869
  gc.collect()
870
  if use_distributed and dist.is_initialized():
871
  dist.destroy_process_group()
872
 
873
def _optimizer_state_fits(param_state, param):
    """Return True when every non-scalar tensor in ``param_state`` has ``param``'s shape."""
    return all(
        value.shape == param.shape
        for value in param_state.values()
        if isinstance(value, torch.Tensor) and value.dim() > 0
    )


def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
    """Load model (and, when possible, optimizer) state from a checkpoint file.

    Every rank loads the checkpoint itself so that the optimizer state
    (per-parameter Adam moments, which are NOT broadcast) is consistent across
    DDP processes; only the model weights are broadcast from rank 0 afterwards.

    Args:
        gpu_id: Rank of this process; rank 0 prints diagnostics.
        Deformddpm: The model, wrapped in DDP when ``use_distributed`` is true
            (its ``.module`` attribute is used in that case).
        optimizer: Optimizer whose state is restored in place.
        model_file: Checkpoint path. The first six characters of its basename
            must be the epoch number; a basename containing ``_step`` is
            treated as a mid-epoch checkpoint.
        use_distributed: Whether the model is DDP-wrapped and weights must be
            broadcast from rank 0 after loading.
        load_strict: Forwarded as ``strict`` to ``load_state_dict``.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the epoch to resume at: the checkpoint's own epoch for a mid-epoch
        checkpoint, otherwise the following epoch.

    Raises:
        ValueError: If the basename does not start with a six-digit integer.
    """
    gc.collect()
    _empty_cache(DEVICE_TYPE)
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    # Mid-epoch checkpoints carry NumPy / Python RNG states in 'epoch_stats',
    # which the weights-only unpickler (the default since PyTorch 2.6) rejects.
    # Full unpickling is therefore requested explicitly, matching how the
    # training loop re-reads the same file.
    # NOTE(security): weights_only=False executes arbitrary pickled code; only
    # load checkpoints produced by this training script / trusted sources.
    checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)

    target_model = Deformddpm.module if use_distributed else Deformddpm
    target_model.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)

    # Restore optimizer state when available (needed for mid-epoch resume).
    # Selective loading: load states for parameters with matching shapes and
    # skip mismatched ones (e.g. a layer whose kernel shape changed). After one
    # epoch the newly saved checkpoint has correct state for ALL parameters.
    if 'optimizer_state_dict' in checkpoint and not args.reset_optimizer:
        saved_opt = checkpoint['optimizer_state_dict']
        saved_state = saved_opt.get('state', {})
        param_list = [p for group in optimizer.param_groups for p in group['params']]

        # Fast path: every saved per-parameter state fits its parameter.
        all_match = all(
            _optimizer_state_fits(param_state, param_list[int(idx)])
            for idx, param_state in saved_state.items()
            if int(idx) < len(param_list)
        )

        if all_match:
            optimizer.load_state_dict(saved_opt)
        else:
            # Restore param_groups settings (lr, betas, etc.) but keep the
            # current optimizer's own parameter lists.
            for saved_group, group in zip(saved_opt['param_groups'], optimizer.param_groups):
                group.update({k: v for k, v in saved_group.items() if k != 'params'})

            # Restore per-parameter state only where shapes match.
            loaded = 0
            skipped = 0
            for idx, param_state in saved_state.items():
                idx_int = int(idx)
                if idx_int >= len(param_list):
                    # No corresponding parameter in the current optimizer.
                    continue
                param = param_list[idx_int]
                if not _optimizer_state_fits(param_state, param):
                    skipped += 1
                    continue
                # Cast non-scalar state tensors to the parameter's dtype/device;
                # scalar tensors (e.g. the step counter) are kept as saved.
                optimizer.state[param] = {
                    k: (v.to(dtype=param.dtype, device=param.device)
                        if isinstance(v, torch.Tensor) and v.dim() > 0 else v)
                    for k, v in param_state.items()
                }
                loaded += 1
            if gpu_id == 0:
                print(f" [checkpoint] Selective optimizer load: {loaded} params restored, "
                      f"{skipped} skipped (shape mismatch, fresh Adam for those)", flush=True)
    elif args.reset_optimizer and gpu_id == 0:
        print(" [checkpoint] --reset-optimizer: skipping optimizer state, starting fresh Adam", flush=True)

    del checkpoint
    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Broadcast model weights from rank 0 to ensure exact consistency.
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)

    # The epoch number is encoded in the first six characters of the filename.
    basename = os.path.basename(model_file)
    epoch_from_file = int(basename[:6])
    if '_step' in basename:
        # Mid-epoch checkpoint: resume at the same epoch (don't +1).
        initial_epoch = epoch_from_file
    else:
        # End-of-epoch checkpoint: start the next epoch.
        initial_epoch = epoch_from_file + 1

    return initial_epoch, Deformddpm, optimizer
966
 
967
 
968
 
969
if __name__ == "__main__":
    # Script entry point: choose the launch mode.
    if "LOCAL_RANK" not in os.environ:
        if use_distributed:
            # Single-node multi-GPU: spawn one worker process per device.
            world_size = _device_count(DEVICE_TYPE)
            print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
            mp.spawn(main_train, args=(world_size,), nprocs=world_size)
        else:
            # Single process, single device.
            main_train(0, 1)
    else:
        # An external launcher (torchrun / srun) already started this process
        # and exported LOCAL_RANK / WORLD_SIZE, so force distributed mode.
        use_distributed = True
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
        try:
            main_train(local_rank, world_size)
        except Exception:
            # Print a per-rank banner and traceback before propagating, so the
            # failing rank can be identified in interleaved multi-process logs.
            import traceback
            print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True)
            traceback.print_exc()
            raise
OM_train_3modes_cudaonly.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(ROOT_DIR)
5
+
6
+ import gc
7
+ import torch
8
+ import torchvision
9
+ from torch import nn
10
+ from torchvision.utils import save_image
11
+ from torch.utils.data import DataLoader
12
+
13
+ from torch.optim import Adam, SGD
14
+ from Diffusion.diffuser import DeformDDPM
15
+ from Diffusion.networks import get_net, STN
16
+ from torchvision.transforms import Lambda
17
+ import torch.nn.functional as F
18
+ import Diffusion.losses as losses
19
+ import random
20
+ import glob
21
+ import numpy as np
22
+ import utils
23
+ from tqdm import tqdm
24
+
25
+ from Dataloader.dataloader0 import get_dataloader
26
+ from Dataloader.dataLoader import *
27
+
28
+ from Dataloader.dataloader_utils import thresh_img
29
+ import yaml
30
+ import argparse
31
+
32
+ ####################
33
+ import torch.multiprocessing as mp
34
+ from torch.utils.data.distributed import DistributedSampler
35
+ from torch.nn.parallel import DistributedDataParallel as DDP
36
+ import torch.distributed as dist
37
+ # from torch.distributed import init_process_group
38
+ ###############
39
def ddp_setup(rank, world_size):
    """Join the NCCL process group and bind this process to its CUDA device.

    Args:
        rank: Unique identifier of each process; also used as the CUDA device
            index passed to ``torch.cuda.set_device``.
        world_size: Total number of processes
    """
    # Rendezvous endpoint is fixed to this machine on port 12355.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
49
+
50
# Auto-detect: use DDP only when multiple CUDA GPUs are available.
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1

# --- numerical guards -------------------------------------------------------
EPS = 1e-5
MSK_EPS = 0.01  # offset added to masks / condition ratios in main_train

# --- stochastic switches ----------------------------------------------------
TEXT_EMBED_PROB = 0.5    # chance that the text embedding is passed to the model
AUG_RESAMPLE_PROB = 0.5  # chance of the resample (vs. permute) augmentation

# --- loss weighting ---------------------------------------------------------
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0]  # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2]  # [imgsim, imgmse, ddf]
LOSS_WEIGHT_CONTRASTIVE = 0.001

# --- scheduling of the three training modes ---------------------------------
DIFF_REG_BATCH_RATIO = 2     # paired-loader batch = batchsize // this ratio
REGISTRATION_STEP_RATIO = 1  # run the registration pass every N-th step
CONTRASTIVE_STEP_RATIO = 1   # run the contrastive pass every N-th step
66
+
67
+ # AUG_PERMUTE_PROB = 0.35
68
+
69
# Command-line interface. Parsed at import time so that main_train (including
# the copies started through mp.spawn) can read the module-level `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    default="Config/config_all.yaml",
    required=False,
    help="Path for the config file",
)
parser.add_argument(
    "--batchsize",
    type=int,
    default=0,
    help="Override batch size from config (0=use config value)",
)
args = parser.parse_args()
85
+ #=======================================================================================================================
86
+
87
+
88
+
89
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
    """Run the three-mode training loop (diffusion, contrastive, registration).

    Hyper-parameters are read from the YAML file named by the module-level
    ``args.config``. The newest ``*.pth`` file in the model directory, if any,
    is used to resume. Every optimisation step trains the diffusion objective
    on a single image; the contrastive (image/text embedding) and the
    paired-image registration objectives are run on the steps selected by
    ``CONTRASTIVE_STEP_RATIO`` and ``REGISTRATION_STEP_RATIO``.

    Args:
        rank: Process rank; under DDP it is also the CUDA device index.
        world_size: Number of processes taking part in distributed training.
        train_mode_ratio: Unused; kept for interface compatibility.
        thresh_imgsim: Intensity threshold applied to the target image when
            building the mask passed as ``label`` to the image-similarity loss.

    Raises:
        ValueError: If more than five steps of one epoch produce a NaN/Inf
            diffusion loss.
    """
    if use_distributed:
        ddp_setup(rank,world_size)

    if torch.distributed.is_initialized():
        print(f"World size: {torch.distributed.get_world_size()}")
        print(f"Communication backend: {torch.distributed.get_backend()}")
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    if args.batchsize > 0:
        hyp_parameters['batchsize'] = args.batchsize
    print(hyp_parameters)

    device = hyp_parameters["device"]
    epoch_per_save=hyp_parameters['epoch_per_save']

    data_name=hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net=get_net(net_name)

    suffix_pth=f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
    model_dir=model_save_path
    # Not used below; the call is kept in case get_transformer has side effects.
    transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])

    dataset = OMDataset_indiv(transform=None)
    datasetp = OMDataset_pair(transform=None)

    # NOTE(review): no DistributedSampler is attached, so under DDP every rank
    # draws its own shuffle of the full dataset.
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )
    train_loader_p = DataLoader(
        datasetp,
        batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=hyp_parameters["timesteps"],
            ndims=hyp_parameters["ndims"],
            num_input_chn = hyp_parameters["num_input_chn"],
            res = hyp_parameters['img_size']
        ),
        n_steps=hyp_parameters["timesteps"],
        image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=hyp_parameters["img_size"],
        ndims=hyp_parameters["ndims"],
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    # Two regularisers on the displacement field: a tighter out-of-range
    # threshold for the diffusion pass, a looser one for registration.
    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)

    loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
    loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
    loss_imgsim = losses.MSLNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # check for existing models
    if not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)
    model_files = glob.glob(os.path.join(model_dir, "*.pth"))
    model_files.sort()
    if model_files:
        if gpu_id == 0:
            print(model_files)
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed)
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ',len(dataset))

    # Training loop
    for epoch in range(initial_epoch,hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        epoch_loss_contrastive = 0.0
        # Set model inside to train model
        Deformddpm.train()

        loss_nan_step = 0  # number of NaN/Inf diffusion losses seen this epoch
        # Registration steps actually executed this epoch. Used as the
        # denominator of the registration averages, so an epoch in which the
        # registration gate never opens no longer divides by zero.
        regist_steps_done = 0

        total = min(len(train_loader), len(train_loader_p))
        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):

            # ==========================================================================
            # diffusion train on single image

            [x0,embd] = batch # for om dataset
            x0 = x0.to(device).type(torch.float32)
            embd_dev = embd.to(device).type(torch.float32)
            # Randomly drop the text condition.
            if np.random.uniform(0,1)<TEXT_EMBED_PROB:
                embd_in = embd_dev
            else:
                embd_in = None

            n = x0.size()[0] # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(device)

            # random deformation + rotation
            if hyp_parameters["ndims"]>2:
                if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
            if hyp_parameters['noise_scale']>0:
                if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
                # NOTE(review): the flattened source does not show whether this
                # intensity jitter sits inside the random gate above or only
                # inside the noise_scale check; confirm against the repository.
                x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)

            # One random timestep per image in the batch.
            t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(device)

            proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
            ddpm = Deformddpm.module if use_distributed else Deformddpm
            cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)

            pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process

            loss_tot=0

            loss_ddf = loss_reg(pre_dvf_I,img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)

            loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot

            # Diagnostics: NaN in the input and unusually large DDF penalty.
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf>0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            # Skip non-finite losses. The abort threshold is tested before
            # `continue`; testing it afterwards would never fire when every
            # step of the epoch is NaN.
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() / total
            epoch_loss_gen_d += loss_gen_d.item() / total
            epoch_loss_gen_a += loss_gen_a.item() / total
            epoch_loss_reg += loss_ddf.item() / total

            # ==========================================================================
            # contrastive train on single image (text-image alignment)
            loss_contra_val = None
            if step % CONTRASTIVE_STEP_RATIO == 0:
                # NOTE(review): this calls the unwrapped network, bypassing the
                # DDP wrapper; confirm gradients of this pass are meant to stay
                # rank-local in distributed runs.
                raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
                n_contra = x0.size()[0]
                t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(device)
                _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
                if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
                    img_embd = raw_network.img_embd
                    # Hinge on (1 - mean cosine similarity) with a margin of 0.05.
                    loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.05)

                    optimizer.zero_grad()
                    loss_contra.backward()
                    torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
                    optimizer.step()
                    loss_contra_val = loss_contra.item()
                    epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
                else:
                    if gpu_id == 0:
                        print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")

            # ==========================================================================
            # registration train on paired images
            # Only run once the diffusion angle loss is already low (< -0.6),
            # to avoid large registration losses destabilising early training.
            # NOTE(review): under DDP this gate is evaluated per rank; confirm
            # ranks cannot disagree and stall the gradient all-reduce.
            if step%REGISTRATION_STEP_RATIO == 0 and loss_gen_a.item()<-0.6:
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0,1)<TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)
                n = x1.size()[0] # batch_size -> n
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
                if hyp_parameters['noise_scale']>0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
                    # Same jitter for both images so the pair stays comparable.
                    random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
                    random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                scale_regist = np.random.uniform(0.0,0.7)
                # Draw 12..24 distinct timesteps from the upper part of the
                # schedule, clamped so random.sample cannot be asked for more
                # items than the window holds.
                t_candidates = range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"])
                select_timestep = min(np.random.randint(12, 25), len(t_candidates))
                T_regist = sorted(random.sample(t_candidates, select_timestep), reverse=True)

                # Repeat every timestep once per sample of the paired batch
                # (equals batchsize//2 while DIFF_REG_BATCH_RATIO is 2).
                T_regist = [[t for _ in range(n)] for t in T_regist]

                proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
                ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
                y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
                msk_tgt = msk_tgt+MSK_EPS
                [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y)
                loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)

                loss_regist = 0
                loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # Diagnostics for the registration inputs (the original
                # re-checked x0 here, which was already tested above).
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in input images x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1>0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
                optimizer.step()

                regist_steps_done += 1
                epoch_loss_regist += loss_regist.item()
                epoch_loss_imgsim += loss_sim.item()
                epoch_loss_imgmse += loss_mse.item()
                epoch_loss_ddfreg += loss_ddf1.item()
            else:
                # Placeholders so the periodic log line below can be printed.
                loss_sim = torch.tensor(0.0)
                loss_mse = torch.tensor(0.0)
                loss_ddf1 = torch.tensor(0.0)
                loss_regist = torch.tensor(0.0)

            if step % 10 == 0:
                print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
                if loss_contra_val is not None:
                    print(f'       loss_contrastive: {loss_contra_val:.6f}')
                print(f'       loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')

        # Epoch summary; registration averages use the executed-step count.
        reg_denom = max(regist_steps_done, 1)
        print('==================')
        print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
        print(f'       loss_contrastive: {epoch_loss_contrastive}')
        print(f'       loss_regist: {epoch_loss_regist/reg_denom} = {epoch_loss_imgsim/reg_denom} (imgsim) + {epoch_loss_imgmse/reg_denom} (imgmse) + {epoch_loss_ddfreg/reg_denom} (ddf)')
        print('==================')

        if 0 == epoch % epoch_per_save:
            save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Single-process runs always save; under DDP only rank 0 writes,
            # and the unwrapped module is stored so keys carry no "module." prefix.
            if not use_distributed or gpu_id == 0:
                print(f"saved in {save_dir}")
                raw_model = Deformddpm.module if use_distributed else Deformddpm
                torch.save({
                    'model_state_dict': raw_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
466
+
467
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
    """Restore a checkpoint on rank 0 and derive the epoch to resume from.

    Only the process with ``gpu_id == 0`` reads ``model_file``; in distributed
    mode its parameters are then broadcast to the other ranks. The optimizer
    state is restored only when ``load_strict`` is true (and only on rank 0).

    Args:
        gpu_id: Rank of the calling process.
        Deformddpm: The model, DDP-wrapped when ``use_distributed`` is true.
        optimizer: Optimizer whose state may be restored.
        model_file: Checkpoint path; its basename must begin with the epoch
            number (e.g. ``000110_<data>_<net>.pth``).
        use_distributed: Whether the model is DDP-wrapped and weights must be
            broadcast.
        load_strict: Passed as ``strict`` to ``load_state_dict``; also enables
            restoring the optimizer state.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the epoch encoded in the filename plus one.

    Raises:
        ValueError: If the checkpoint filename does not start with digits.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
        # Load to CPU first; load_state_dict copies into the model's device.
        checkpoint = torch.load(model_file, map_location='cpu')
        if use_distributed:
            Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
        else:
            Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
        if load_strict:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Broadcast model weights from rank 0 to all other GPUs
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)  # Synchronize model across ranks
        dist.barrier()
        # NOTE(review): this broadcasts gradients, not optimizer state, and is
        # skipped for every parameter whose .grad is None; confirm it is needed.
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is not None:
                    dist.broadcast(param.grad, src=0)

    # The epoch is taken from the filename (not from the checkpoint dict) so
    # that ranks which did not load the file can compute it too. Parse the
    # whole leading digit run instead of a fixed six characters, so longer
    # epoch numbers are not truncated.
    base_name = os.path.basename(model_file)
    epoch_digits = base_name[:len(base_name) - len(base_name.lstrip("0123456789"))]
    if not epoch_digits:
        raise ValueError(
            f"Cannot derive the epoch from checkpoint name {base_name!r}: "
            "expected it to start with the epoch number."
        )
    initial_epoch = int(epoch_digits) + 1

    return initial_epoch, Deformddpm, optimizer
503
+
504
+
505
+
506
if __name__ == "__main__":
    # Single process unless several CUDA devices were detected at import time;
    # otherwise spawn one training process per visible GPU.
    if not use_distributed:
        main_train(0,1)
    else:
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
OM_train_3modes_opt.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OM_train_3modes_opt.py — Optimized 3-mode training (diffusion + contrastive + registration).
3
+
4
+ Speed optimizations over OM_train_3modes.py (all mathematically equivalent):
5
+ 1. DataLoader: num_workers, pin_memory, persistent_workers for I/O overlap
6
+ 2. optimizer.zero_grad(set_to_none=True) — avoids zero-fill overhead
7
+ 3. Fixed-length T_regist (16 steps) — avoids XPU dynamic shape recompilation
8
+ 4. Removed redundant x0.to(device) call
9
+ 5. Uses diffuser_opt.DeformDDPM (hoisted clone, no *0 redundancy, OptSTN, inference_mode)
10
+ 6. Uses losses_opt.MSLNCC/LNCC (register_buffer for kernels)
11
+ 7. Pre-compute proc_type lists to reduce Python overhead in hot loop
12
+ 8. Uses OptRecMulModMutAttnNet (cached resample tensors, ~300 fewer CPU→GPU transfers)
13
+ 9. Uses OptSTN for ddf_stn (register_buffer, no per-call .to())
14
+ """
15
+
16
+ import os, sys
17
+
18
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
19
+ sys.path.append(ROOT_DIR)
20
+
21
+ import gc
22
+ import torch
23
+ import torchvision
24
+ from torch import nn
25
+ from torchvision.utils import save_image
26
+ from torch.utils.data import DataLoader
27
+
28
+ from torch.optim import Adam, SGD
29
+ from Diffusion.diffuser_opt import DeformDDPM
30
+ from Diffusion.networks_opt import get_net_opt, OptSTN
31
+ from torchvision.transforms import Lambda
32
+ import torch.nn.functional as F
33
+ import Diffusion.losses_opt as losses
34
+ import random
35
+ import glob
36
+ import numpy as np
37
+ import utils
38
+ from tqdm import tqdm
39
+
40
+ from Dataloader.dataloader0 import get_dataloader
41
+ from Dataloader.dataLoader import *
42
+
43
+ from Dataloader.dataloader_utils import thresh_img
44
+ import yaml
45
+ import argparse
46
+
47
+ ####################
48
+ import torch.multiprocessing as mp
49
+ from torch.utils.data.distributed import DistributedSampler
50
+ from torch.nn.parallel import DistributedDataParallel as DDP
51
+ import torch.distributed as dist
52
+ ###############
53
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for this process and select its GPU.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index.
        world_size: Total number of processes.
    """
    # Only fall back to the single-node defaults when the launcher has not
    # already exported a rendezvous address/port; overwriting them would
    # break multi-node launches and make concurrent jobs on one host collide.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
63
+
64
# Auto-detect: use DDP only when multiple CUDA GPUs are available.
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1

# --- numerical guards -------------------------------------------------------
EPS = 1e-5
MSK_EPS = 0.01

# --- stochastic switches ----------------------------------------------------
TEXT_EMBED_PROB = 0.7    # chance that the text embedding is passed to the model
AUG_RESAMPLE_PROB = 0.5  # chance of the resample (vs. permute) augmentation

# --- loss weighting ---------------------------------------------------------
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0]  # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]  # [imgsim, imgmse, ddf]
LOSS_WEIGHT_CONTRASTIVE = 1.0

# --- scheduling of the training modes ---------------------------------------
DIFF_REG_BATCH_RATIO = 2  # paired-loader batch = batchsize // this ratio
CONTRASTIVE_STEP_RATIO = 2

# OPT: Fixed registration timestep count to avoid XPU dynamic shape recompilation
FIXED_T_REGIST_LEN = 16

# OPT: DataLoader workers (set to 0 to disable multiprocessing if needed)
NUM_WORKERS = 4    # default of the --num-workers option
PIN_MEMORY = True  # main_train disables pinning when the device is "cpu"
86
+
87
+ # AUG_PERMUTE_PROB = 0.35
88
+
89
# Command-line interface. Parsed at import time so that main_train (including
# the copies started through mp.spawn) can read the module-level `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    default="Config/config_all.yaml",
    required=False,
    help="Path for the config file",
)
parser.add_argument(
    "--dummy-samples",
    type=int,
    default=0,
    help="Use dummy random data for testing (0=use real data)",
)
parser.add_argument(
    "--batchsize",
    type=int,
    default=0,
    help="Override batch size from config (0=use config value)",
)
parser.add_argument(
    "--num-workers",
    type=int,
    default=NUM_WORKERS,
    help="DataLoader num_workers (default: 4)",
)
args = parser.parse_args()
106
+ #=======================================================================================================================
107
+
108
class _DummyIndiv(torch.utils.data.Dataset):
    """Synthetic single-image dataset used instead of real data for smoke tests.

    Each item is ``(image, embedding)``: a float64 array of shape
    ``(1, sz, ..., sz)`` with ``ndims`` spatial axes drawn uniformly from
    [0, 1), and a float32 standard-normal vector of length ``embd_dim``.

    Args:
        n: Number of items reported by ``len()``.
        sz: Edge length of every spatial axis.
        embd_dim: Length of the embedding vector.
        ndims: Number of spatial axes (defaults to 3, the original behaviour).
    """

    def __init__(self, n, sz, embd_dim=1024, ndims=3):
        self.n, self.sz, self.embd_dim, self.ndims = n, sz, embd_dim, ndims

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Without this bound check plain iteration (e.g. list(ds)) would never
        # stop, because Python's sequence protocol ends on IndexError.
        if i >= self.n or i < -self.n:
            raise IndexError(f"index {i} is out of range for dataset of size {self.n}")
        spatial = (self.sz,) * self.ndims
        image = np.random.rand(1, *spatial).astype(np.float64)
        embedding = np.random.randn(self.embd_dim).astype(np.float32)
        return image, embedding
114
+
115
class _DummyPair(torch.utils.data.Dataset):
    """Synthetic paired-image dataset used instead of real data for smoke tests.

    Each item is ``(image_a, image_b, embedding_a, embedding_b)``: two float64
    arrays of shape ``(1, sz, ..., sz)`` with ``ndims`` spatial axes drawn
    uniformly from [0, 1), and two float32 standard-normal vectors of length
    ``embd_dim``.

    Args:
        n: Number of items reported by ``len()``.
        sz: Edge length of every spatial axis.
        embd_dim: Length of each embedding vector.
        ndims: Number of spatial axes (defaults to 3, the original behaviour).
    """

    def __init__(self, n, sz, embd_dim=1024, ndims=3):
        self.n, self.sz, self.embd_dim, self.ndims = n, sz, embd_dim, ndims

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Without this bound check plain iteration (e.g. list(ds)) would never
        # stop, because Python's sequence protocol ends on IndexError.
        if i >= self.n or i < -self.n:
            raise IndexError(f"index {i} is out of range for dataset of size {self.n}")
        shape = (1,) + (self.sz,) * self.ndims
        image_a = np.random.rand(*shape).astype(np.float64)
        image_b = np.random.rand(*shape).astype(np.float64)
        embedding_a = np.random.randn(self.embd_dim).astype(np.float32)
        embedding_b = np.random.randn(self.embd_dim).astype(np.float32)
        return image_a, image_b, embedding_a, embedding_b
124
+
125
+
126
+ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
127
+ if use_distributed:
128
+ ddp_setup(rank,world_size)
129
+
130
+ if torch.distributed.is_initialized():
131
+ print(f"World size: {torch.distributed.get_world_size()}")
132
+ print(f"Communication backend: {torch.distributed.get_backend()}")
133
+ gpu_id = rank
134
+
135
+ # Load the YAML file into a dictionary
136
+ with open(args.config, 'r') as file:
137
+ hyp_parameters = yaml.safe_load(file)
138
+ if args.batchsize > 0:
139
+ hyp_parameters['batchsize'] = args.batchsize
140
+ print(hyp_parameters)
141
+
142
+ # epoch_per_save=10
143
+ epoch_per_save=hyp_parameters['epoch_per_save']
144
+
145
+ data_name=hyp_parameters['data_name']
146
+ net_name = hyp_parameters['net_name']
147
+
148
+ Net=get_net_opt(net_name)
149
+
150
+ suffix_pth=f'_{data_name}_{net_name}.pth'
151
+ model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
152
+ model_dir=model_save_path
153
+ transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
154
+
155
+ # OPT: DataLoader with num_workers, pin_memory, persistent_workers
156
+ num_workers = args.num_workers
157
+ use_pin_memory = PIN_MEMORY and hyp_parameters["device"] != "cpu"
158
+
159
+ if args.dummy_samples > 0:
160
+ dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
161
+ datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
162
+ else:
163
+ dataset = OMDataset_indiv(transform=None)
164
+ datasetp = OMDataset_pair(transform=None)
165
+
166
+ train_loader = DataLoader(
167
+ dataset,
168
+ batch_size=hyp_parameters['batchsize'],
169
+ shuffle=True,
170
+ drop_last=True,
171
+ num_workers=num_workers, # OPT
172
+ pin_memory=use_pin_memory, # OPT
173
+ persistent_workers=num_workers > 0, # OPT
174
+ )
175
+ train_loader_p = DataLoader(
176
+ datasetp,
177
+ batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
178
+ shuffle=True,
179
+ drop_last=True,
180
+ num_workers=num_workers, # OPT
181
+ pin_memory=use_pin_memory, # OPT
182
+ persistent_workers=num_workers > 0, # OPT
183
+ )
184
+
185
+
186
+
187
+ Deformddpm = DeformDDPM(
188
+ network=Net(
189
+ n_steps=hyp_parameters["timesteps"],
190
+ ndims=hyp_parameters["ndims"],
191
+ num_input_chn = hyp_parameters["num_input_chn"],
192
+ res = hyp_parameters['img_size']
193
+ ),
194
+ n_steps=hyp_parameters["timesteps"],
195
+ image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
196
+ device=hyp_parameters["device"],
197
+ batch_size=hyp_parameters["batchsize"],
198
+ img_pad_mode=hyp_parameters["img_pad_mode"],
199
+ v_scale=hyp_parameters["v_scale"],
200
+ )
201
+
202
+
203
+ ddf_stn = OptSTN(
204
+ img_sz=hyp_parameters["img_size"],
205
+ ndims=hyp_parameters["ndims"],
206
+ # padding_mode="zeros",
207
+ padding_mode=hyp_parameters["padding_mode"],
208
+ device=hyp_parameters["device"],
209
+ )
210
+
211
+
212
+ if use_distributed:
213
+ Deformddpm.to(rank)
214
+ Deformddpm = DDP(Deformddpm, device_ids=[rank])
215
+ ddf_stn.to(rank)
216
+ else:
217
+ Deformddpm.to(hyp_parameters["device"])
218
+ ddf_stn.to(hyp_parameters["device"])
219
+ # ddf_stn = DDP(ddf_stn, device_ids=[rank])
220
+
221
+
222
+ # mse = nn.MSELoss()
223
+ # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
224
+ # loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
225
+ loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
226
+ loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
227
+
228
+ loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
229
+ # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
230
+ loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
231
+ loss_imgsim = losses.MSLNCC()
232
+ loss_imgmse = losses.LMSE()
233
+
234
+ optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
235
+
236
+ # check for existing models
237
+ if not os.path.exists(model_dir):
238
+ os.makedirs(model_dir, exist_ok=True)
239
+ model_files = glob.glob(os.path.join(model_dir, "*.pth"))
240
+ model_files.sort()
241
+ if model_files:
242
+ if gpu_id == 0:
243
+ print(model_files)
244
+ initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed)
245
+ else:
246
+ initial_epoch = 0
247
+
248
+ if gpu_id == 0:
249
+ print('len_train_data: ',len(dataset))
250
+ # Training loop
251
+ for epoch in range(initial_epoch,hyp_parameters["epoch"]):
252
+
253
+ epoch_loss_tot = 0.0
254
+ epoch_loss_gen_d = 0.0
255
+ epoch_loss_gen_a = 0.0
256
+ epoch_loss_reg = 0.0
257
+ epoch_loss_regist = 0.0
258
+ epoch_loss_imgsim = 0.0
259
+ epoch_loss_imgmse = 0.0
260
+ epoch_loss_ddfreg = 0.0
261
+ epoch_loss_contrastive = 0.0
262
+ # Set model inside to train model
263
+ Deformddpm.train()
264
+
265
+ loss_nan_step = 0 # yu: count the number of nan loss steps
266
+
267
+ total = min(len(train_loader), len(train_loader_p))
268
+ for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
269
+
270
+ # ==========================================================================
271
+ # diffusion train on single image
272
+
273
+ [x0,embd] = batch # for om dataset
274
+ x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
275
+ embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
276
+ if np.random.uniform(0,1)<TEXT_EMBED_PROB:
277
+ embd_in = embd_dev
278
+ else:
279
+ embd_in = None
280
+
281
+ n = x0.size()[0] # batch_size -> n
282
+ # OPT: removed redundant x0.to(device) — already done above
283
+
284
+ blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
285
+
286
+ # random deformation + rotation
287
+ if hyp_parameters["ndims"]>2:
288
+ if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
289
+ x0 = utils.random_resample(x0, deform_scale=0)
290
+ # elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
291
+ else:
292
+ [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
293
+ # x0 = transformer(x0)
294
+ if hyp_parameters['noise_scale']>0:
295
+ if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
296
+ x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
297
+ x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
298
+
299
+ # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
300
+ t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
301
+ hyp_parameters["device"]
302
+ ) # pick up a seq of rand number from 0 to 'timestep'
303
+
304
+ proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
305
+ ddpm = Deformddpm.module if use_distributed else Deformddpm
306
+ cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
307
+
308
+ pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
309
+
310
+ loss_tot=0
311
+
312
+ loss_ddf = loss_reg(pre_dvf_I,img=x0)
313
+ trm_pred = ddf_stn(pre_dvf_I, dvf_I)
314
+ loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
315
+ loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
316
+
317
+ loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
318
+ loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
319
+ loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
320
+
321
+ # >> JZ: print nan in x0
322
+ if torch.isnan(x0).any():
323
+ print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
324
+ # >> JZ: print loss of ddf
325
+ if loss_ddf>0.001:
326
+ print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
327
+ # yu: check if loss_tot==nan or inf
328
+ if torch.isnan(loss_tot) or torch.isinf(loss_tot):
329
+ print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
330
+ loss_nan_step += 1
331
+ continue
332
+ if loss_nan_step > 5:
333
+ print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
334
+ raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
335
+
336
+ optimizer.zero_grad(set_to_none=True) # OPT: set_to_none faster than zero-fill
337
+ loss_tot.backward()
338
+ optimizer.step()
339
+
340
+ epoch_loss_tot += loss_tot.item() / total
341
+ epoch_loss_gen_d += loss_gen_d.item() / total
342
+ epoch_loss_gen_a += loss_gen_a.item() / total
343
+ epoch_loss_reg += loss_ddf.item() / total
344
+
345
+ # ==========================================================================
346
+ # contrastive train on single image (text-image alignment)
347
+ loss_contra_val = None
348
+ if step % CONTRASTIVE_STEP_RATIO == 0:
349
+ raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
350
+ n_contra = x0.size()[0]
351
+ t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
352
+ _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
353
+ if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
354
+ img_embd = raw_network.img_embd # [B, 1024]
355
+ loss_contra = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean())
356
+
357
+ optimizer.zero_grad(set_to_none=True) # OPT
358
+ loss_contra.backward()
359
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.05)
360
+ optimizer.step()
361
+ loss_contra_val = loss_contra.item()
362
+ epoch_loss_contrastive += loss_contra_val / total
363
+ else:
364
+ if gpu_id == 0:
365
+ print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")
366
+
367
+ # ==========================================================================
368
+ # registration train on paired images
369
+ if step%train_mode_ratio == 0:
370
+ [x1, y1, _, embd_y] = batch_p
371
+ if np.random.uniform(0,1)<TEXT_EMBED_PROB:
372
+ embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
373
+ else:
374
+ embd_y = None
375
+
376
+ x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
377
+ y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
378
+ n = x1.size()[0] # batch_size -> n
379
+ [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
380
+ if hyp_parameters['noise_scale']>0:
381
+ [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
382
+ random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
383
+ random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
384
+ x1 = x1 * random_scale + random_shift
385
+ y1 = y1 * random_scale + random_shift
386
+
387
+ scale_regist = np.random.uniform(0.0,0.7)
388
+ # OPT: fixed-length T_regist to avoid XPU dynamic shape recompilation
389
+ # Sample FIXED_T_REGIST_LEN timesteps (was: random 8-16), always same loop length
390
+ t_pool = list(range(int(hyp_parameters["timesteps"] * scale_regist), hyp_parameters["timesteps"]))
391
+ select_timestep = min(FIXED_T_REGIST_LEN, len(t_pool))
392
+ T_regist = sorted(random.sample(t_pool, select_timestep), reverse=True)
393
+
394
+ T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
395
+
396
+ proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
397
+ ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
398
+ y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
399
+ msk_tgt = msk_tgt+MSK_EPS
400
+ [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
401
+ loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
402
+ loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
403
+ loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
404
+
405
+ loss_regist = 0
406
+ loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
407
+ loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
408
+ loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
409
+
410
+ # >> JZ: print nan in x0
411
+ if torch.isnan(x0).any():
412
+ print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
413
+ # >> JZ: print loss of ddf
414
+ if loss_ddf1>0.002:
415
+ print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
416
+
417
+ loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
418
+ optimizer.zero_grad(set_to_none=True) # OPT
419
+ loss_regist.backward()
420
+
421
+ torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.2)
422
+ optimizer.step()
423
+
424
+ epoch_loss_regist += loss_regist.item() / total
425
+ epoch_loss_imgsim += loss_sim.item() / total
426
+ epoch_loss_imgmse += loss_mse.item() / total
427
+ epoch_loss_ddfreg += loss_ddf1.item() / total
428
+
429
+ if step % 10 == 0:
430
+ print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
431
+ if loss_contra_val is not None:
432
+ print(f' loss_contrastive: {loss_contra_val:.6f}')
433
+ print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
434
+
435
+ if 1:
436
+ print('==================')
437
+ print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
438
+ print(f' loss_contrastive: {epoch_loss_contrastive}')
439
+ print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
440
+ print('==================')
441
+
442
+
443
+ if 0 == epoch % epoch_per_save:
444
+ save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
445
+ os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
446
+ # break # FOR TESTING
447
+ if not use_distributed:
448
+ print(f"saved in {save_dir}")
449
+ # torch.save(Deformddpm.state_dict(), save_dir)
450
+ torch.save({
451
+ 'model_state_dict': Deformddpm.state_dict(),
452
+ 'optimizer_state_dict': optimizer.state_dict(),
453
+ 'epoch': epoch
454
+ }, save_dir)
455
+ elif gpu_id == 0:
456
+ print(f"saved in {save_dir}")
457
+ # torch.save(Deformddpm.module.state_dict(), save_dir)
458
+ torch.save({
459
+ 'model_state_dict': Deformddpm.module.state_dict(),
460
+ 'optimizer_state_dict': optimizer.state_dict(),
461
+ 'epoch': epoch
462
+ }, save_dir)
463
+
464
+ # Resource cleanup at the end of training
465
+ if torch.cuda.is_available():
466
+ torch.cuda.empty_cache()
467
+ gc.collect()
468
+ if use_distributed and dist.is_initialized():
469
+ dist.destroy_process_group()
470
+
471
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
    """Restore a checkpoint and derive the epoch at which training resumes.

    Args:
        gpu_id: Global rank of the calling process; rank 0 reads the file.
        Deformddpm: The model, wrapped in DDP when ``use_distributed`` is set.
        optimizer: Optimizer whose state is restored only when ``load_strict``.
        model_file: Path of the checkpoint; its name must start with the
            six-digit epoch number.
        use_distributed: Whether ``Deformddpm`` is a DDP wrapper and the
            weights must be broadcast from rank 0.
        load_strict: Forwarded as ``strict`` to ``load_state_dict`` and also
            gates restoring the optimizer state.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the epoch encoded in the file name plus one.
    """
    # NOTE(review): the source indentation was lost; this assumes the whole
    # load happens on rank 0 only (the epoch is parsed from the file name
    # rather than from ``checkpoint``, and weights are broadcast below) — confirm.
    is_main_rank = gpu_id == 0
    if is_main_rank:
        utils.print_memory_usage("Before Loading Model")
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
        checkpoint = torch.load(model_file, map_location='cpu')
        # Under DDP the real model sits behind ``.module``.
        target = Deformddpm.module if use_distributed else Deformddpm
        target.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
        if load_strict:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Push rank 0's parameters to every other rank.
        dist.barrier()
        for weight in Deformddpm.parameters():
            dist.broadcast(weight.data, src=0)
        dist.barrier()
        # Gradients held by optimizer-managed parameters are broadcast as well.
        for group in optimizer.param_groups:
            for weight in group['params']:
                if weight.grad is None:
                    continue
                dist.broadcast(weight.grad, src=0)

    # The resume epoch comes from the file name, not from the checkpoint dict.
    file_stem = os.path.basename(model_file).split('.')[0]
    initial_epoch = int(file_stem[:6]) + 1

    return initial_epoch, Deformddpm, optimizer
504
+
505
+
506
+
507
if __name__ == "__main__":
    # Single-process training unless distributed mode was selected above.
    if not use_distributed:
        main_train(0,1)
    else:
        # One worker process per visible CUDA device.
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
OM_train_3modes_original.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys
2
+
3
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
4
+ sys.path.append(ROOT_DIR)
5
+
6
+ import gc
7
+ import torch
8
+ import torchvision
9
+ from torch import nn
10
+ from torchvision.utils import save_image
11
+ from torch.utils.data import DataLoader
12
+
13
+ from torch.optim import Adam, SGD
14
+ from Diffusion.diffuser import DeformDDPM
15
+ from Diffusion.networks import get_net, STN
16
+ from torchvision.transforms import Lambda
17
+ import torch.nn.functional as F
18
+ import Diffusion.losses as losses
19
+ import random
20
+ import glob
21
+ import numpy as np
22
+ import utils
23
+ from tqdm import tqdm
24
+
25
+ from Dataloader.dataloader0 import get_dataloader
26
+ from Dataloader.dataLoader import *
27
+
28
+ from Dataloader.dataloader_utils import thresh_img
29
+ import yaml
30
+ import argparse
31
+
32
+ # XPU support: import Intel Extension for PyTorch and oneCCL bindings if available
33
+ try:
34
+ import intel_extension_for_pytorch as ipex
35
+ except ImportError:
36
+ ipex = None
37
+ try:
38
+ import oneccl_bindings_for_pytorch
39
+ except (ImportError, Exception) as e:
40
+ print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
41
+
42
+ ####################
43
+ import torch.multiprocessing as mp
44
+ from torch.utils.data.distributed import DistributedSampler
45
+ from torch.nn.parallel import DistributedDataParallel as DDP
46
+ import torch.distributed as dist
47
+ # from torch.distributed import init_process_group
48
+ ###############
49
def _device_available(device_type):
    """Report whether an accelerator of the requested kind can be used.

    ``'xpu'`` requires both a ``torch.xpu`` module and a positive
    availability check; every other value is answered by the CUDA check.
    """
    if device_type != 'xpu':
        return torch.cuda.is_available()
    return hasattr(torch, 'xpu') and torch.xpu.is_available()
53
+
54
def _device_count(device_type):
    """Return how many devices of the requested kind PyTorch reports.

    ``'xpu'`` yields 0 when this PyTorch build has no ``torch.xpu`` module;
    every other value is answered by the CUDA device count.
    """
    if device_type != 'xpu':
        return torch.cuda.device_count()
    if not hasattr(torch, 'xpu'):
        return 0
    return torch.xpu.device_count()
58
+
59
def _set_device(rank, device_type):
    """Make device index ``rank`` current on the XPU or CUDA backend."""
    # ``torch.xpu`` is only touched when XPU was explicitly requested.
    backend_module = torch.xpu if device_type == 'xpu' else torch.cuda
    backend_module.set_device(rank)
64
+
65
def _empty_cache(device_type):
    """Release cached allocator memory on the selected backend.

    Does nothing when neither the XPU branch applies nor CUDA is available.
    """
    use_xpu = device_type == 'xpu' and hasattr(torch, 'xpu')
    if use_xpu:
        torch.xpu.empty_cache()
        return
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
70
+
71
def ddp_setup(rank, world_size):
    """Initialise the process group and bind this process to its device.

    Args:
        rank: Unique identifier of each process (local_rank when launched by torchrun)
        world_size: Total number of processes

    The backend is ``ccl`` when ``DEVICE_TYPE`` is ``"xpu"`` and ``nccl``
    otherwise.
    """
    backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is not None:
        # Launched by torchrun: MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE already set
        dist.init_process_group(backend=backend)
        _set_device(int(local_rank), DEVICE_TYPE)
        return

    # Single-node mp.spawn. Only fill in the rendezvous address when the
    # caller has not provided one, so a second job on the same machine can
    # pick a different MASTER_PORT instead of colliding on the default.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    _set_device(rank, DEVICE_TYPE)
88
+
89
# ---- numerical constants ----
EPS = 1e-5
MSK_EPS = 0.01

# ---- augmentation / conditioning probabilities ----
TEXT_EMBED_PROB = 0.5
AUG_RESAMPLE_PROB = 0.5

# ---- loss weights ----
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0]     # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2]  # [imgsim, imgmse, ddf]
LOSS_WEIGHT_CONTRASTIVE = 0.001

# ---- scheduling of the three training modes ----
DIFF_REG_BATCH_RATIO = 2
REGISTRATION_STEP_RATIO = 1
CONTRASTIVE_STEP_RATIO = 1

# ---- command line ----
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    default="Config/config_all.yaml",
    required=False,
    help="Path for the config file",
)
parser.add_argument(
    "--batchsize",
    type=int,
    default=0,
    help="Override batch size from config (0=use config value)",
)
args = parser.parse_args()

# The config is read here, ahead of training, because the device type
# decides how the distributed setup is done.
with open(args.config) as _f:
    _cfg = yaml.safe_load(_f)
DEVICE_TYPE = _cfg.get('device', 'cuda')  # 'cuda' or 'xpu'

# Auto-detect: use DDP only when multiple devices are available
use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1
126
+ # use_distributed = True
127
+ # use_distributed = False
128
+ #=======================================================================================================================
129
+
130
+
131
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
    """Train the model with interleaved diffusion, contrastive and registration updates.

    Every step runs a diffusion update on a single-image batch, then (every
    ``CONTRASTIVE_STEP_RATIO`` steps) a text/image contrastive update, then
    (every ``REGISTRATION_STEP_RATIO`` steps, and only while the diffusion
    angular loss is below -0.6) a registration update on a paired batch.
    Training resumes from the last ``*.pth`` file (sorted by name) in the
    model folder, and a checkpoint is written every ``epoch_per_save`` epochs.

    Args:
        rank: Process index supplied by ``mp.spawn``; replaced by
            ``LOCAL_RANK`` when the ``RANK`` environment variable is set.
        world_size: Number of participating processes.
        train_mode_ratio: Not used by this function; kept so existing callers
            keep working.
        thresh_imgsim: Intensity threshold on the target image used to build
            the mask of the image-similarity loss.

    Raises:
        ValueError: When more than five non-finite diffusion losses occur
            within one epoch.
    """
    if use_distributed:
        ddp_setup(rank, world_size)
        if torch.distributed.is_initialized() and rank == 0:
            print(f"World size: {torch.distributed.get_world_size()}")
            print(f"Communication backend: {torch.distributed.get_backend()}")

    # gpu_id = global rank (for save/print guards); rank = local device index
    if "RANK" in os.environ:
        gpu_id = int(os.environ["RANK"])
        rank = int(os.environ["LOCAL_RANK"])
    else:
        gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    if args.batchsize > 0:
        hyp_parameters['batchsize'] = args.batchsize
    if gpu_id == 0:
        print(hyp_parameters)

    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    noise_scale = hyp_parameters['noise_scale']
    data_device = hyp_parameters["device"]

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path
    # Built as in the original script; it is not applied anywhere below.
    transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])

    dataset = OMDataset_indiv(transform=None)
    datasetp = OMDataset_pair(transform=None)

    if use_distributed:
        sampler = DistributedSampler(dataset, shuffle=True)
        sampler_p = DistributedSampler(datasetp, shuffle=True)
    else:
        sampler = None
        sampler_p = None

    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=(sampler is None),
        drop_last=True,
        sampler=sampler,
    )
    # The paired loader uses a smaller batch than the single-image loader.
    train_loader_p = DataLoader(
        datasetp,
        batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
        shuffle=(sampler_p is None),
        drop_last=True,
        sampler=sampler_p,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=hyp_parameters["timesteps"],
            ndims=hyp_parameters["ndims"],
            num_input_chn=hyp_parameters["num_input_chn"],
            res=hyp_parameters['img_size'],
        ),
        n_steps=hyp_parameters["timesteps"],
        image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
        device=hyp_parameters["device"],
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=hyp_parameters["img_size"],
        ndims=hyp_parameters["ndims"],
        padding_mode=hyp_parameters["padding_mode"],
        device=hyp_parameters["device"],
    )

    if use_distributed:
        device = f"{DEVICE_TYPE}:{rank}"
        Deformddpm.to(device)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(device)
    else:
        Deformddpm.to(hyp_parameters["device"])
        ddf_stn.to(hyp_parameters["device"])

    # Two regularisers with different out-of-range thresholds: one for the
    # diffusion step, one for the registration step.
    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)

    loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
    loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
    loss_imgsim = losses.MSLNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # Resume from the last checkpoint (sorted by file name) if one exists.
    os.makedirs(model_dir, exist_ok=True)
    model_files = glob.glob(os.path.join(model_dir, "*.pth"))
    model_files.sort()
    if model_files:
        if gpu_id == 0:
            print(model_files)
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed)
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ',len(dataset))

    # Training loop
    for epoch in range(initial_epoch, hyp_parameters["epoch"]):
        if use_distributed and sampler is not None:
            sampler.set_epoch(epoch)
            sampler_p.set_epoch(epoch)

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        epoch_loss_contrastive = 0.0
        Deformddpm.train()

        loss_nan_step = 0  # number of non-finite diffusion losses this epoch
        n_reg_steps = 0    # number of registration updates actually performed

        total = min(len(train_loader), len(train_loader_p))
        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):

            # ==========================================================================
            # diffusion train on single image
            [x0, embd] = batch  # for om dataset
            x0 = x0.to(data_device).type(torch.float32)
            embd_dev = embd.to(data_device).type(torch.float32)
            # The text embedding is passed to the model only part of the time.
            if np.random.uniform(0,1) < TEXT_EMBED_PROB:
                embd_in = embd_dev
            else:
                embd_in = None

            n = x0.size()[0]  # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(data_device)

            # random deformation + rotation
            if hyp_parameters["ndims"] > 2:
                if np.random.uniform(0,1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
            if noise_scale > 0:
                if np.random.uniform(0,1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 2*noise_scale])
                # NOTE(review): source indentation was lost; the intensity jitter is
                # assumed to apply whenever noise_scale > 0, as in the paired branch — confirm.
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # One random timestep per image in the batch.
            t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(data_device)

            proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
            ddpm = Deformddpm.module if use_distributed else Deformddpm
            cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)

            pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I,img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)

            loss_tot = 0
            loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot

            # >> JZ: print nan in x0
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            # >> JZ: print loss of ddf
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            # yu: check if loss_tot==nan or inf
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Abort as soon as the limit is exceeded instead of waiting for
                # the next finite step to reach the check.
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() / total
            epoch_loss_gen_d += loss_gen_d.item() / total
            epoch_loss_gen_a += loss_gen_a.item() / total
            epoch_loss_reg += loss_ddf.item() / total

            # ==========================================================================
            # contrastive train on single image (text-image alignment)
            loss_contra_val = None
            if step % CONTRASTIVE_STEP_RATIO == 0:
                raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
                n_contra = x0.size()[0]
                t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(data_device)
                # The forward pass is run for the image embedding it leaves on the network.
                _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
                if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
                    img_embd = raw_network.img_embd
                    # Cosine-distance loss between image and text embeddings,
                    # zeroed once the mean distance falls below the 0.05 margin.
                    loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.05)

                    optimizer.zero_grad()
                    loss_contra.backward()
                    torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
                    optimizer.step()
                    loss_contra_val = loss_contra.item()
                    epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
                else:
                    if gpu_id == 0:
                        print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")

            # ==========================================================================
            # registration train on paired images
            # Only run once the diffusion angular loss is below -0.6, to avoid
            # large registration losses and unstable training in the early stage.
            if step % REGISTRATION_STEP_RATIO == 0 and loss_gen_a.item() < -0.6:
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0,1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(data_device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(data_device).type(torch.float32)
                y1 = y1.to(data_device).type(torch.float32)
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
                if noise_scale > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2*noise_scale])
                    # Same random scale and shift for both images of the pair.
                    random_scale = np.random.normal(1, noise_scale * 1)
                    random_shift = np.random.normal(0, noise_scale * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                scale_regist = np.random.uniform(0.0,0.7)
                # Draw 12..24 distinct timesteps from the upper part of the schedule,
                # never more than the pool holds (random.sample would raise otherwise).
                select_timestep = np.random.randint(12, 25)
                t_pool = range(int(hyp_parameters["timesteps"] * scale_regist), hyp_parameters["timesteps"])
                select_timestep = min(select_timestep, len(t_pool))
                T_regist = sorted(random.sample(t_pool, select_timestep), reverse=True)

                # Repeat each sampled timestep max(1, batchsize // 2) times.
                T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]

                proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
                ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
                y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
                msk_tgt = msk_tgt+MSK_EPS
                [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y)
                loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)

                loss_regist = 0
                loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # Check the paired inputs of this branch (x0 was already checked above).
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in paired input images x1/y1 at epoch {epoch}, step {step}.")
                # >> JZ: print loss of ddf
                if loss_ddf1 > 0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio+MSK_EPS) * loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
                optimizer.step()

                epoch_loss_regist += loss_regist.item()
                epoch_loss_imgsim += loss_sim.item()
                epoch_loss_imgmse += loss_mse.item()
                epoch_loss_ddfreg += loss_ddf1.item()
                n_reg_steps += 1

        if gpu_id == 0:
            # Average over the registration steps that ran; the clamp avoids a
            # division by zero when the gate never opened during this epoch.
            reg_denominator = max(n_reg_steps, 1)
            print('==================')
            print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
            print(f' loss_contrastive: {epoch_loss_contrastive}')
            print(f' loss_regist: {epoch_loss_regist/reg_denominator} = {epoch_loss_imgsim/reg_denominator} (imgsim) + {epoch_loss_imgmse/reg_denominator} (imgmse) + {epoch_loss_ddfreg/reg_denominator} (ddf)')
            print('==================')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Only one process writes; under DDP the bare module is saved so
            # the state-dict keys do not depend on the wrapper.
            if not use_distributed or gpu_id == 0:
                model_to_save = Deformddpm.module if use_distributed else Deformddpm
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    _empty_cache(DEVICE_TYPE)
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
526
+
527
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
    """Restore model (and optionally optimizer) state from *model_file*.

    Only the process with ``gpu_id == 0`` reads the checkpoint from disk; in
    distributed mode the parameters are then broadcast from rank 0 to the
    other ranks. The optimizer state is restored only when ``load_strict`` is
    true.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the number encoded in the first six characters of the checkpoint file
        name, plus one.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")
        gc.collect()
        _empty_cache(DEVICE_TYPE)
        state = torch.load(model_file, map_location='cpu')
        # Under DDP the real model sits behind ``.module``.
        target = Deformddpm.module if use_distributed else Deformddpm
        target.load_state_dict(state['model_state_dict'], strict=load_strict)
        if load_strict:
            optimizer.load_state_dict(state['optimizer_state_dict'])
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Push the weights loaded on rank 0 to every other rank.
        dist.barrier()
        for weight in Deformddpm.parameters():
            dist.broadcast(weight.data, src=0)  # Synchronize model across ranks
        dist.barrier()
        for group in optimizer.param_groups:
            for weight in group['params']:
                if weight.grad is None:
                    continue
                dist.broadcast(weight.grad, src=0)  # Sync optimizer gradients

    # The epoch is taken from the file name rather than from the checkpoint.
    file_stem = os.path.basename(model_file).split('.')[0]
    initial_epoch = int(file_stem[:6]) + 1

    return initial_epoch, Deformddpm, optimizer
562
+
563
+
564
+
565
if __name__ == "__main__":
    if "LOCAL_RANK" not in os.environ:
        if use_distributed:
            # Single-node multi-GPU: one spawned process per visible device.
            world_size = _device_count(DEVICE_TYPE)
            print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
            mp.spawn(main_train, args=(world_size,), nprocs=world_size)
        else:
            # Plain single-process run.
            main_train(0, 1)
    else:
        # Launched by torchrun / srun: rank information comes from the environment.
        use_distributed = True
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
        try:
            main_train(local_rank, world_size)
        except Exception:
            # Report which rank failed before propagating the error.
            import traceback
            print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True)
            traceback.print_exc()
            raise
OMorpher/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .omorpher import OMorpher
2
+
3
+ __all__ = ['OMorpher']
OMorpher/omorpher.py ADDED
@@ -0,0 +1,1058 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OMorpher — Object-oriented wrapper for OmniMorph diffusion-based deformation.
3
+
4
+ Stores original high-res images and composes all intermediate deformations as
5
+ deformation fields (DDFs), resampling only once at the end to avoid blurring.
6
+ Independent of DeformDDPM at runtime; reimplements the diffusion logic using
7
+ the network / STN / loss building blocks from Diffusion.*.
8
+ """
9
+
10
+ import os
11
+ import glob
12
+ import math
13
+ import random
14
+ from typing import Optional, Union, List, Tuple, Dict
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ import yaml
22
+ import SimpleITK as sitk
23
+ from skimage.transform import resize as sk_resize
24
+
25
+ from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
26
+ from Diffusion.losses import Grad, MRSE, NCC
27
+
28
+ EPS = 1e-8
29
+
30
+
31
+ class OMorpher:
32
+ """High-level interface for OmniMorph deformation diffusion.
33
+
34
+ All images are kept at their original resolution internally. Deformation
35
+ fields are composed at model resolution and up-scaled on demand so that the
36
+ original image is resampled at most *once*.
37
+ """
38
+
39
+ # ------------------------------------------------------------------
40
+ # Construction
41
+ # ------------------------------------------------------------------
42
+
43
def __init__(
    self,
    config: Union[str, dict],
    checkpoint_path: Optional[str] = None,
    device: Optional[str] = None,
    bert_model_path: Optional[str] = None,
):
    """Build the network, spatial transformers and losses from *config*.

    Args:
        config: Either a dict of settings or a path to a YAML file holding them.
        checkpoint_path: Checkpoint to load. When omitted, the newest ``*.pth``
            found by ``_auto_find_checkpoint()`` is used, if there is one.
        device: Explicit device string; when omitted the ``device`` config
            entry (or auto-detection) decides.
        bert_model_path: Location of the BERT model; defaults to
            ``External/Models/bert_large_uncased`` two directories above this file.
    """
    # ---- Config ----
    if isinstance(config, str):
        with open(config, "r") as cfg_file:
            config = yaml.safe_load(cfg_file)
    self.config: dict = config
    opt = config.get

    self.net_name: str = opt("net_name", "recmutattnnet")
    self.ndims: int = opt("ndims", 3)
    self.img_size: int = opt("img_size", 128)
    self.timesteps: int = opt("timesteps", 80)
    self.v_scale: float = opt("v_scale", 5e-5)
    self.noise_scale: float = opt("noise_scale", 0.1)
    self.condition_type: str = opt("condition_type", "none")
    self.num_input_chn: int = opt("num_input_chn", 1)
    self.img_pad_mode: str = opt("img_pad_mode", "zeros")
    self.ddf_pad_mode: str = opt("ddf_pad_mode", "border")
    self.padding_mode: str = opt("padding_mode", "border")
    self.resample_mode: str = opt("resample_mode", "bilinear")
    self.batch_size: int = opt("batchsize", 1)
    self.data_name: str = opt("data_name", "all")
    self.clamp_range: list = opt("clamp_range", [-400, 400])
    self.inf_mode: bool = opt("inf_mode", True)

    # ---- Device ----
    if device is None:
        self.device = self._resolve_device(opt("device", None))
    else:
        self.device = torch.device(device)

    # ---- BERT (lazy: only the path is recorded here) ----
    if not bert_model_path:
        repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        bert_model_path = os.path.join(
            repo_root, "External", "Models", "bert_large_uncased",
        )
    self.bert_model_path = bert_model_path
    self._bert_model = None
    self._bert_tokenizer = None

    # ---- Network ----
    net_cls = get_net(self.net_name)
    self.network = net_cls(
        n_steps=self.timesteps,
        ndims=self.ndims,
        num_input_chn=self.num_input_chn,
        res=self.img_size,
    )
    self.network.to(self.device)

    # ---- STN instances ----
    self.ctl_ratio = 4
    self.ctl_sz = self.img_size // self.ctl_ratio

    shared = {"ndims": self.ndims, "device": self.device}
    self.stn_full = STN(
        img_sz=self.img_size, padding_mode=self.padding_mode, **shared,
    )
    self.stn_ctl = STN(
        img_sz=self.ctl_sz, padding_mode=self.ddf_pad_mode, **shared,
    )
    # "bilinear" is passed through as None; any other mode is forwarded as-is.
    img_resample = None if self.resample_mode == "bilinear" else self.resample_mode
    self.img_stn = STN(
        img_sz=self.img_size,
        padding_mode=self.img_pad_mode,
        resample_mode=img_resample,
        **shared,
    )
    self.msk_stn = STN(
        img_sz=self.img_size,
        padding_mode=self.img_pad_mode,
        resample_mode="nearest",
        **shared,
    )

    # ---- Loss functions (for fine-tuning) ----
    self._loss_grad = Grad(penalty=["l1"], ndims=self.ndims)
    self._loss_dist = MRSE(img_sz=self.img_size)
    self._loss_ang = NCC(img_sz=self.img_size)

    # ---- Load checkpoint (explicit path wins over auto-discovery) ----
    ckpt_path = checkpoint_path
    if ckpt_path is None:
        ckpt_path = self._auto_find_checkpoint()
    if ckpt_path is not None:
        self._load_checkpoint(ckpt_path)

    self.network.eval()

    # ---- State ----
    self._init_img: Optional[torch.Tensor] = None          # [B,1,S,S,S] model-res
    self._init_img_raw: Optional[torch.Tensor] = None      # [B,1,D,H,W] full-res
    self._init_img_original_shape: Optional[tuple] = None
    self._init_ddf: Optional[torch.Tensor] = None          # [B,ndims,S,S,S]
    self._cond_img: Optional[torch.Tensor] = None          # [B,1,S,S,S]
    self._cond_txt: Optional[torch.Tensor] = None          # [B,1024]
    self._predicted_ddf: Optional[torch.Tensor] = None     # [B,ndims,S,S,S]
    self._intermediate_ddfs: List[Tuple[int, torch.Tensor]] = []

    # ---- Fine-tuning state ----
    self._optimizer: Optional[torch.optim.Optimizer] = None
155
+
156
+ # ------------------------------------------------------------------
157
+ # Device resolution
158
+ # ------------------------------------------------------------------
159
+
160
+ @staticmethod
161
+ def _resolve_device(hint: Optional[str] = None) -> torch.device:
162
+ if hint is not None:
163
+ s = str(hint).lower()
164
+ if s not in ("auto", ""):
165
+ return torch.device(s)
166
+ # XPU → CUDA → CPU
167
+ try:
168
+ import intel_extension_for_pytorch # noqa: F401
169
+ if torch.xpu.is_available():
170
+ return torch.device("xpu")
171
+ except (ImportError, AttributeError):
172
+ pass
173
+ if torch.cuda.is_available():
174
+ return torch.device("cuda")
175
+ return torch.device("cpu")
176
+
177
+ # ------------------------------------------------------------------
178
+ # Checkpoint helpers
179
+ # ------------------------------------------------------------------
180
+
181
+ def _auto_find_checkpoint(self) -> Optional[str]:
182
+ pattern = os.path.join(
183
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
184
+ "Models",
185
+ f"{self.data_name}_{self.net_name}",
186
+ "*.pth",
187
+ )
188
+ files = sorted(glob.glob(pattern))
189
+ return files[-1] if files else None
190
+
191
+ def _load_checkpoint(self, path: str):
192
+ ckpt = torch.load(path, map_location="cpu")
193
+ state_dict = ckpt.get("model_state_dict", ckpt)
194
+ # Strip DDP 'module.' prefix and DeformDDPM wrapper keys
195
+ cleaned = {}
196
+ for k, v in state_dict.items():
197
+ k = k.replace("module.", "")
198
+ if k.startswith("network."):
199
+ k = k[len("network."):]
200
+ cleaned[k] = v
201
+ # Only load keys that exist in the network
202
+ net_keys = set(self.network.state_dict().keys())
203
+ filtered = {k: v for k, v in cleaned.items() if k in net_keys}
204
+ if filtered:
205
+ self.network.load_state_dict(filtered, strict=False)
206
+
207
+ # ------------------------------------------------------------------
208
+ # Public — Input setters
209
+ # ------------------------------------------------------------------
210
+
211
def set_init_img(
    self,
    img,
    modality: Optional[str] = None,
) -> "OMorpher":
    """Store the initial image and reset the initial deformation field.

    *img* is handed to ``_standardize_img``; a tuple or list is interpreted
    as ``(img, ddf)``, in which case the second item becomes the initial DDF.
    Without a supplied DDF an all-zero field at model resolution is stored.
    Returns ``self`` for chaining.
    """
    supplied_ddf = None
    if isinstance(img, (tuple, list)):
        img, supplied_ddf = img[0], img[1]

    standardized = self._standardize_img(img, modality=modality, keep_raw=True)
    self._init_img, self._init_img_raw, self._init_img_original_shape = standardized

    if supplied_ddf is None:
        batch = self._init_img.shape[0]
        zero_shape = [batch, self.ndims] + [self.img_size] * self.ndims
        self._init_ddf = torch.zeros(
            zero_shape, dtype=torch.float32, device=self.device,
        )
    else:
        self._init_ddf = self._to_ddf_tensor(supplied_ddf)
    return self
238
+
239
def set_cond_img(
    self,
    img=None,
    modality: Optional[str] = None,
) -> "OMorpher":
    """Store the conditioning image and return ``self``.

    With no image, Gaussian noise scaled by 0.1 is used; its batch size
    follows the initial image when one is set, otherwise the configured
    batch size.
    """
    if img is not None:
        self._cond_img = self._standardize_img(img, modality=modality, keep_raw=False)[0]
        return self

    if self._init_img is not None:
        batch = self._init_img.shape[0]
    else:
        batch = self.batch_size
    noise_shape = [batch, 1] + [self.img_size] * self.ndims
    noise = torch.randn(noise_shape, dtype=torch.float32, device=self.device)
    self._cond_img = noise * 0.1
    return self
256
+
257
def set_cond_txt(self, txt=None) -> "OMorpher":
    """Store the text conditioning produced by ``_standardize_txt(txt)``; returns ``self``."""
    embedding = self._standardize_txt(txt)
    self._cond_txt = embedding
    return self
261
+
262
def set_init_def(self, ddf=None) -> "OMorpher":
    """Set the initial deformation field, or draw a random one.

    With ``ddf=None`` a random DDF is generated by the forward-diffusion
    sampler at the timestep given by the ``start_noise_step`` config entry
    (default: half of ``timesteps``). Returns ``self``.

    Raises:
        RuntimeError: if a random DDF is requested before ``set_init_img()``.
    """
    if ddf is not None:
        self._init_ddf = self._to_ddf_tensor(ddf)
        return self

    if self._init_img is None:
        raise RuntimeError("set_init_img() must be called before set_init_def()")

    step = self.config.get("start_noise_step", self.timesteps // 2)
    t = torch.tensor([step], dtype=torch.long, device=self.device)
    _warped, _dvf, sampled_ddf = self._get_random_ddf(self._init_img, t)
    self._init_ddf = sampled_ddf
    return self
278
+
279
+ # ------------------------------------------------------------------
280
+ # Public — Core operations (inference)
281
+ # ------------------------------------------------------------------
282
+
283
def predict(
    self,
    T: Optional[list] = None,
    proc_type: Optional[str] = None,
    t_save: Optional[list] = None,
) -> "OMorpher":
    """Run the reverse diffusion and keep the composed DDF on the instance.

    Args:
        T: ``[start_step, end_step]``; defaults to the ``start_noise_step``
            config entry (0 if absent) and ``timesteps``.
        proc_type: Conditioning strategy; defaults to ``condition_type``.
        t_save: Timesteps whose intermediate DDFs should be recorded.

    Returns:
        ``self``, so calls can be chained.

    Raises:
        RuntimeError: if no initial image has been set.
    """
    if self._init_img is None:
        raise RuntimeError("set_init_img() must be called before predict()")

    # Defaults
    if T is None:
        T = [self.config.get("start_noise_step", 0), self.timesteps]
    if proc_type is None:
        proc_type = self.condition_type

    batch = self._init_img.shape[0]
    uses_defrec = isinstance(self.network, DefRec_MutAttnNet)

    # Conditioning image: fall back to a copy of the initial image.
    if self._cond_img is not None:
        cond_source = self._cond_img
    else:
        cond_source = self._init_img.clone().detach()
    cond_img, _mask, _cond_ratio = self._proc_cond_img(cond_source, proc_type=proc_type)

    # Text embedding: zeros when none was provided.
    txt = self._cond_txt
    if txt is None:
        txt = torch.zeros([batch, 1024], dtype=torch.float32, device=self.device)
    if uses_defrec:
        txt = txt.view(batch, -1, *([1] * self.ndims))

    # Starting point: stored DDF > random DDF at T[0] > identity.
    has_init_ddf = self._init_ddf is not None and not torch.all(self._init_ddf == 0)
    if has_init_ddf:
        ddf_comp = self._init_ddf.clone()
        img_rec = self.img_stn(self._init_img, ddf_comp)
    elif T[0] is not None and T[0] > 0:
        t_start = torch.tensor(np.array([T[0]]), device=self.device)
        img_rec, _, ddf_comp = self._get_random_ddf(self._init_img, t_start)
    else:
        img_rec = self._init_img.clone()
        ddf_comp = torch.zeros(
            [batch, self.ndims] + [self.img_size] * self.ndims,
            dtype=torch.float32, device=self.device,
        )

    self._intermediate_ddfs = []
    rec_num = 2  # matches DeformDDPM.rec_num default

    if uses_defrec:
        # This network consumes the whole descending schedule in one call.
        schedule = list(range(T[1] - 1, -1, -1))
        with torch.no_grad():
            pre_dvf = self.network(
                x=img_rec, y=cond_img, t=schedule, rec_num=rec_num, text=txt,
            )
        ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
        img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
        if t_save:
            self._intermediate_ddfs.append((0, ddf_comp.clone()))
    else:
        # One network call per timestep, always re-warping the untouched
        # initial image with the composed field.
        for step in reversed(range(T[1])):
            t = torch.tensor(np.array([step]), device=self.device)
            with torch.no_grad():
                pre_dvf = self.network(
                    x=img_rec, y=cond_img, t=t, rec_num=rec_num, text=txt,
                )
            ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
            img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
            if t_save is not None and step in t_save:
                self._intermediate_ddfs.append((step, ddf_comp.clone()))

    self._predicted_ddf = ddf_comp
    return self
364
+
365
def get_def(
    self,
    t_list: Optional[list] = None,
) -> Union[torch.Tensor, Dict[int, torch.Tensor]]:
    """Return the final DDF, or a ``{timestep: ddf}`` dict for *t_list*.

    Raises:
        RuntimeError: if the final DDF is requested before ``predict()``.
    """
    if t_list is not None:
        return {
            step: field
            for step, field in self._intermediate_ddfs
            if step in t_list
        }
    if self._predicted_ddf is None:
        raise RuntimeError("predict() must be called before get_def()")
    return self._predicted_ddf
379
+
380
def apply_def(
    self,
    img=None,
    ddf: Optional[torch.Tensor] = None,
    padding_mode: Optional[str] = None,
    resample_mode: Optional[str] = None,
) -> torch.Tensor:
    """Warp an image with a DDF, resizing the DDF to the image when needed.

    Defaults: the full-resolution initial image (or the model-resolution one
    if no raw image is stored), the predicted DDF, the instance padding mode
    and ``"bilinear"`` resampling.

    Raises:
        RuntimeError: if no DDF is given and ``predict()`` has not run.
    """
    padding_mode = self.padding_mode if padding_mode is None else padding_mode
    resample_mode = "bilinear" if resample_mode is None else resample_mode

    if ddf is None:
        if self._predicted_ddf is None:
            raise RuntimeError("predict() must be called before apply_def()")
        ddf = self._predicted_ddf

    if img is not None:
        volume = self._ensure_tensor(img)
    elif self._init_img_raw is not None:
        volume = self._init_img_raw
    else:
        volume = self._init_img

    # Bring the field onto the image's spatial grid if the sizes disagree.
    spatial = list(volume.shape[2:])
    if spatial != list(ddf.shape[2:]):
        interp = "bilinear" if self.ndims == 2 else "trilinear"
        ddf = F.interpolate(ddf, size=spatial, mode=interp, align_corners=False)

    return self._apply_ddf(volume, ddf, padding_mode=padding_mode, resample_mode=resample_mode)
422
+
423
+ # ------------------------------------------------------------------
424
+ # Public — Fine-tuning
425
+ # ------------------------------------------------------------------
426
+
427
def finetune_setup(
    self,
    lr: float = 1e-4,
    optimizer_cls=None,
) -> "OMorpher":
    """Put the network in training mode and build its optimizer.

    ``optimizer_cls`` defaults to ``torch.optim.Adam`` and is called with the
    network parameters and ``lr``. Returns ``self``.
    """
    self.network.train()
    self.inf_mode = False
    factory = torch.optim.Adam if optimizer_cls is None else optimizer_cls
    self._optimizer = factory(self.network.parameters(), lr=lr)
    return self
439
+
440
def finetune_step(
    self,
    img_batch,
    cond_batch=None,
    text_batch=None,
    t=None,
    proc_type=None,
) -> dict:
    """Run one optimisation step and return the scalar losses.

    The image is randomly deformed at timestep *t* (drawn uniformly per
    sample when omitted), the network predicts a velocity field from the
    deformed image and the processed conditioning image, and the weighted
    sum ``2*ang + 1*dist + 16*grad`` is back-propagated.

    Returns:
        Dict with ``loss_total``, ``loss_grad``, ``loss_dist`` and
        ``loss_ang`` as Python floats.

    Raises:
        RuntimeError: if ``finetune_setup()`` has not been called.
    """
    if self._optimizer is None:
        raise RuntimeError("finetune_setup() must be called first")

    img, _raw, _shape = self._standardize_img(img_batch, keep_raw=False)
    if cond_batch is not None:
        cond = self._standardize_img(cond_batch, keep_raw=False)[0]
    else:
        cond = img.clone()
    text = self._standardize_txt(text_batch)

    batch = img.shape[0]
    if t is None:
        t = torch.randint(0, self.timesteps, (batch,), device=self.device)
    elif isinstance(t, torch.Tensor):
        t = t.to(self.device)
    else:
        t = torch.tensor(t, device=self.device)

    proc_type = proc_type or self.condition_type
    cond_img, mask, _cond_ratio = self._proc_cond_img(cond, proc_type=proc_type)
    noisy_img, dvf_gt, _ddf = self._get_random_ddf(img, t)

    # The DefRec network takes a list of timesteps and a text tensor with
    # trailing singleton spatial axes.
    if isinstance(self.network, DefRec_MutAttnNet):
        if text is not None:
            text = text.view(batch, -1, *([1] * self.ndims))
        t_input = [t]
    else:
        t_input = t

    pre_dvf = self.network(x=noisy_img * mask, y=cond_img, t=t_input, rec_num=2, text=text)

    loss_grad = self._loss_grad(y_pred=pre_dvf, img=img)
    trm_pred = self.stn_full(pre_dvf, dvf_gt)
    loss_dist = self._loss_dist(pred=trm_pred, inv_lab=dvf_gt)
    loss_ang = self._loss_ang(pred=trm_pred, inv_lab=dvf_gt)
    loss_total = 2.0 * loss_ang + 1.0 * loss_dist + 16.0 * loss_grad

    self._optimizer.zero_grad()
    loss_total.backward()
    self._optimizer.step()

    named_losses = (
        ("loss_total", loss_total),
        ("loss_grad", loss_grad),
        ("loss_dist", loss_dist),
        ("loss_ang", loss_ang),
    )
    return {name: value.item() for name, value in named_losses}
492
+
493
def finetune_save(self, path: str, epoch: int = 0):
    """Write model state, optimizer state (or ``None``) and epoch to *path*.

    The parent directory is created when missing.
    """
    parent_dir = os.path.dirname(path) or "."
    os.makedirs(parent_dir, exist_ok=True)

    payload = {
        "model_state_dict": self.network.state_dict(),
        "optimizer_state_dict": self._optimizer.state_dict() if self._optimizer else None,
        "epoch": epoch,
    }
    torch.save(payload, path)
504
+
505
def finetune_teardown(self) -> "OMorpher":
    """Drop the optimizer and return the network to inference mode; returns ``self``."""
    self._optimizer = None
    self.inf_mode = True
    self.network.eval()
    return self
511
+
512
+ # ------------------------------------------------------------------
513
+ # Private — Diffusion logic
514
+ # ------------------------------------------------------------------
515
+
516
+ def _get_ddf_scale(
517
+ self, t: torch.Tensor, divide_num: int = 1, max_ddf_num: int = 200,
518
+ ) -> Tuple[int, torch.Tensor, torch.Tensor]:
519
+ """Timestep-dependent deformation magnitude. Mirrors DeformDDPM._get_ddf_scale()."""
520
+ rec_num = 1
521
+ mul_num_ddf = torch.floor_divide(2 * torch.pow(t.float(), 1.3), 3 * divide_num).int()
522
+ mul_num_dvf = torch.floor_divide(torch.pow(t.float(), 0.6), divide_num).int()
523
+ mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
524
+ mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
525
+ return rec_num, mul_num_ddf, mul_num_dvf
526
+
527
+ def _sample_random_uniform_multi_order(
528
+ self, high=None, low=0.0, order_num=3,
529
+ ) -> float:
530
+ sample_value = low
531
+ for _ in range(order_num):
532
+ sample_value = np.random.uniform(low=sample_value, high=high)
533
+ return sample_value
534
+
535
+ def _multiscale_dvf_generate(
536
+ self, v_scale: float, ctl_szs: list = None, rand_v_scale: bool = True,
537
+ ) -> torch.Tensor:
538
+ """Multi-scale Gaussian DVF at control-point sizes."""
539
+ if ctl_szs is None:
540
+ ctl_szs = [4, 8, 16, 32, 64]
541
+ dvf = 0
542
+ for ctl_sz in ctl_szs:
543
+ _v = (
544
+ self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2)
545
+ if rand_v_scale
546
+ else v_scale
547
+ )
548
+ if ctl_sz <= 2:
549
+ _v = _v / 2
550
+ dvf_comp = torch.randn(
551
+ [self.batch_size, self.ndims] + [ctl_sz] * self.ndims
552
+ ) * _v
553
+ dvf_comp = F.interpolate(
554
+ dvf_comp * self.ctl_sz / ctl_sz,
555
+ [self.ctl_sz] * self.ndims,
556
+ align_corners=False,
557
+ mode="bilinear" if self.ndims == 2 else "trilinear",
558
+ )
559
+ dvf = dvf + dvf_comp
560
+ return dvf
561
+
562
def _random_ddf_generate(
    self,
    rec_num: int = 3,
    mul_num: list = None,
    noise_ratio: float = 0.08,
    select_num: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compose random DVFs into a DDF. Mirrors DeformDDPM._random_ddf_generate().

    Two fields are accumulated on the control grid: one from the noisy
    velocity field (repeated ``mul_num[0]`` times per sample) and one from
    the clean velocity field (repeated ``mul_num[1]`` times, but never more
    than the largest ``mul_num[0]``). Both are then up-sampled to twice the
    image size and centre-cropped back to ``img_size``.

    Returns:
        ``(ddf, dddf)`` — the noisy-field and clean-field compositions.
    """
    if mul_num is None:
        mul_num = [torch.tensor([5]), torch.tensor([5])]

    crop_rate = 2
    # Append ndims + 1 singleton axes so the counts broadcast against fields.
    trailing = [1] * (self.ndims + 1)
    mul_num = [n.reshape(*n.shape, *trailing) for n in mul_num]

    field_shape = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
    ddf = torch.zeros(field_shape)
    dddf = torch.zeros(field_shape)
    scale_num = min(8, int(math.log2(self.ctl_sz)))
    all_sizes = [self.ctl_sz // (2 ** level) for level in range(scale_num)]

    for _round in range(rec_num):
        if len(all_sizes) > select_num:
            sizes = random.sample(all_sizes, select_num)
        else:
            sizes = all_sizes
        dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=sizes).to(self.device)
        if noise_ratio == 0:
            dvf0 = dvf
        else:
            perturbation = self._multiscale_dvf_generate(
                self.v_scale * noise_ratio, ctl_szs=sizes, rand_v_scale=False,
            ).to(self.device)
            dvf0 = dvf + self.stn_ctl(perturbation, dvf)

        repeats = torch.max(mul_num[0]).item()
        for j in range(repeats):
            # Per-sample switches: 1 while sample still has compositions left.
            active = [(n > j).int().to(self.device) for n in mul_num]
            noisy_step = dvf0 * active[0]
            clean_step = dvf * active[1]
            ddf = noisy_step + self.stn_ctl(ddf, noisy_step)
            dddf = clean_step + self.stn_ctl(dddf, clean_step)

    # Upscale and center-crop
    interp_mode = "bilinear" if self.ndims == 2 else "trilinear"
    gain = self.img_size / self.ctl_sz
    enlarged = self.img_size * crop_rate
    ddf = F.interpolate(ddf * self.img_size / self.ctl_sz, enlarged, mode=interp_mode)
    dddf = F.interpolate(dddf * self.img_size / self.ctl_sz, enlarged, mode=interp_mode)

    window = slice(self.img_size // 2, self.img_size * 3 // 2)
    if self.ndims == 2:
        crop = (..., window, window)
    else:
        crop = (..., window, window, window)
    return ddf[crop], dddf[crop]
625
+
626
+ def _get_random_ddf(
627
+ self, img: torch.Tensor, t: torch.Tensor,
628
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
629
+ """Forward-diffuse: generate random DDF and warp image."""
630
+ rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
631
+ ddf_forward, dvf_forward = self._random_ddf_generate(
632
+ rec_num=rec_num, mul_num=[mul_num_ddf, mul_num_dvf],
633
+ )
634
+ warped_img = self.img_stn(img, ddf_forward)
635
+ return warped_img, dvf_forward, ddf_forward
636
+
637
+ # ------------------------------------------------------------------
638
+ # Private — Conditioning processing
639
+ # ------------------------------------------------------------------
640
+
641
+ def _proc_cond_img(
642
+ self,
643
+ img: torch.Tensor,
644
+ proc_type: Optional[str] = None,
645
+ noise_scale: float = 0.1,
646
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
647
+ """Conditioning strategies. Mirrors DeformDDPM.proc_cond_img()."""
648
+ proc_img = img.clone().detach()
649
+ if proc_type is None:
650
+ proc_type = random.choices(
651
+ ["adding", "independ", "downsample", "slice", "none", "uncon"],
652
+ weights=[1, 1, 1, 1, 1, 3],
653
+ k=1,
654
+ )[0]
655
+
656
+ mask = torch.tensor(1, device=img.device)
657
+ cond_ratio = torch.tensor(1.0, device=img.device)
658
+
659
+ if proc_type in ["none", None, "", "None"]:
660
+ return proc_img, mask, cond_ratio
661
+
662
+ noise_type = random.choice(["gaussian", "uniform", "none"])
663
+
664
+ if proc_type == "uncon":
665
+ noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
666
+ return noise_map, torch.tensor(0, device=img.device), torch.tensor(0, device=img.device)
667
+
668
+ noise_map = None
669
+ if proc_type in ["adding", "independ", "slice"]:
670
+ noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
671
+
672
+ if proc_type == "adding":
673
+ noise_ratio = np.random.uniform(0.0, 1.0)
674
+ proc_img = proc_img * (1 - noise_ratio) + noise_map * noise_ratio
675
+ cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
676
+ elif proc_type == "independ":
677
+ mask = self._create_noise_map(img, noise_type="binary")
678
+ proc_img = img * mask
679
+ cond_ratio = mask.float().mean()
680
+ elif proc_type == "downsample":
681
+ down_ratio = list(np.random.uniform(1.0 / 64, 1, [self.ndims]))
682
+ down_img = F.interpolate(
683
+ proc_img, scale_factor=down_ratio,
684
+ mode="bilinear" if self.ndims == 2 else "trilinear",
685
+ )
686
+ proc_img = F.interpolate(
687
+ down_img, size=[self.img_size] * self.ndims,
688
+ mode="bilinear" if self.ndims == 2 else "trilinear",
689
+ align_corners=False,
690
+ )
691
+ cond_ratio = torch.tensor(np.sqrt(np.prod(down_ratio)), device=img.device)
692
+ elif proc_type == "slice":
693
+ slice_num_max = random.randint(1, 64)
694
+ slice_num_max = random.randint(1, slice_num_max)
695
+ mask, sample_ratio = self._get_slice_mask(img, slice_num_range=[0, slice_num_max])
696
+ proc_img = img * mask
697
+ cond_ratio = torch.tensor(sample_ratio, device=img.device)
698
+ elif proc_type == "project":
699
+ proj_img = torch.zeros_like(img)
700
+ rand_bourn = np.random.randint(0, 2, size=[self.ndims])
701
+ proj_dim_num = np.sum(rand_bourn)
702
+ for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
703
+ if pflag:
704
+ proj_img += torch.mean(img, dim=i, keepdim=True)
705
+ proc_img = proj_img / (proj_dim_num + EPS)
706
+ cond_ratio = torch.tensor(proj_dim_num / (128 * self.ndims), device=img.device)
707
+
708
+ return proc_img, mask, cond_ratio
709
+
710
+ def _create_noise_map(
711
+ self,
712
+ img: torch.Tensor,
713
+ noise_type: str = "gaussian",
714
+ noise_scale: float = 0.1,
715
+ ) -> torch.Tensor:
716
+ if noise_type == "gaussian":
717
+ return (torch.randn_like(img) * noise_scale).to(img.device)
718
+ elif noise_type == "uniform":
719
+ return (torch.rand_like(img) * noise_scale * 2 - noise_scale).to(img.device)
720
+ elif noise_type == "binary":
721
+ return torch.bernoulli(torch.rand_like(img)).to(img.device)
722
+ return torch.zeros_like(img).to(img.device)
723
+
724
+ def _get_slice_mask(
725
+ self,
726
+ img: torch.Tensor,
727
+ slice_num_range: list = None,
728
+ ) -> Tuple[torch.Tensor, float]:
729
+ if slice_num_range is None:
730
+ slice_num_range = [0, 32]
731
+ slice_num_range[1] = min(slice_num_range[1], self.img_size)
732
+ mask = torch.zeros_like(img)
733
+ sample_ratio = 0.0
734
+ for i in range(self.ndims):
735
+ if self.inf_mode:
736
+ slice_num = 1
737
+ slice_idx = [self.img_size // 2]
738
+ else:
739
+ slice_num = random.randint(slice_num_range[0], slice_num_range[1])
740
+ slice_idx = random.sample(range(self.img_size), slice_num)
741
+ transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
742
+ for idx in slice_idx:
743
+ mask[..., idx] = 1
744
+ mask = mask.permute(*transpose_list)
745
+ sample_ratio += np.sqrt(slice_num / self.img_size) / self.ndims
746
+ return mask, sample_ratio
747
+
748
+ # ------------------------------------------------------------------
749
+ # Private — Standardization
750
+ # ------------------------------------------------------------------
751
+
752
+ def _standardize_img(
753
+ self,
754
+ img,
755
+ modality: Optional[str] = None,
756
+ keep_raw: bool = False,
757
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
758
+ """Deterministic inference variant of the dataloader pipeline.
759
+
760
+ Returns ``(model_tensor, fullres_tensor_or_None, orig_shape_or_None)``.
761
+
762
+ * *model_tensor*: ``[B, C, S, S, S]`` at model resolution.
763
+ * *fullres_tensor*: ``[B, C, D, H, W]`` at original padded resolution
764
+ (only when *keep_raw=True*).
765
+ * *orig_shape*: spatial dims of padded volume before resize.
766
+
767
+ Accepts numpy arrays, torch tensors (any dimensionality), or a
768
+ file path (loaded via SimpleITK). Torch tensors with >= 4 dims
769
+ are treated as already-batched and are passed through with
770
+ appropriate device/dtype conversion.
771
+ """
772
+ fullres_tensor = None
773
+ orig_shape = None
774
+
775
+ # 1. Load from path
776
+ if isinstance(img, str):
777
+ sitk_img = sitk.ReadImage(img)
778
+ vol = sitk.GetArrayFromImage(sitk_img)
779
+ vol = self._reverse_axis_order(vol)
780
+ elif isinstance(img, np.ndarray):
781
+ vol = img.copy()
782
+ elif isinstance(img, torch.Tensor):
783
+ # If already a batched tensor [B,C,...], pass through
784
+ if img.ndim >= 4:
785
+ t = img.float().to(self.device)
786
+ if keep_raw:
787
+ fullres_tensor = t.clone()
788
+ return t, fullres_tensor, None
789
+ # 1-3D tensor — treat as spatial-only numpy
790
+ vol = img.numpy()
791
+ else:
792
+ raise TypeError(f"Unsupported image type: {type(img)}")
793
+
794
+ # 2. Extract 3D from 4D
795
+ if vol.ndim == 4:
796
+ vol = vol[:, :, :, 0]
797
+
798
+ # 3. CT clamping
799
+ if modality is not None and modality.upper() == "CT" and self.clamp_range is not None:
800
+ vol = np.clip(vol, self.clamp_range[0], self.clamp_range[1])
801
+
802
+ # 4. Normalize [0, 1]
803
+ vol = vol.astype(np.float64)
804
+ vol = (vol - np.min(vol)) / (np.ptp(vol) + 1e-7)
805
+
806
+ # 5. Center-pad to cube
807
+ vol = self._center_pad_to_cube(vol)
808
+ orig_shape = vol.shape[:3]
809
+
810
+ # 6. Full-res tensor (before resize)
811
+ if keep_raw:
812
+ fullres_tensor = torch.tensor(
813
+ vol[None, None, ...], dtype=torch.float32, device=self.device,
814
+ )
815
+
816
+ # 7. Resize to model resolution
817
+ target_sz = [self.img_size] * self.ndims
818
+ vol_resized = sk_resize(
819
+ vol, target_sz, anti_aliasing=True, preserve_range=True,
820
+ )
821
+
822
+ # 8. Add batch + channel dims
823
+ model_tensor = torch.tensor(
824
+ vol_resized[None, None, ...], dtype=torch.float32, device=self.device,
825
+ )
826
+ return model_tensor, fullres_tensor, orig_shape
827
+
828
+ def _standardize_label(
829
+ self,
830
+ label,
831
+ fill_value: float = -1,
832
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
833
+ """Standardize a label volume for inference.
834
+
835
+ Returns ``(model_tensor, fullres_tensor)``.
836
+
837
+ * *model_tensor*: ``[1, C, S, S, S]`` at model resolution
838
+ (nearest-neighbor resize, no anti-aliasing).
839
+ * *fullres_tensor*: ``[1, C, D, H, W]`` at original padded resolution.
840
+
841
+ If *label* is ``None``, returns *fill_value*-filled placeholders
842
+ shaped to match the current init image (model-res and full-res).
843
+
844
+ Accepts numpy arrays or torch tensors. Does NOT apply
845
+ normalization or clamping (labels are discrete indices).
846
+ """
847
+ # --- Placeholder for missing labels ---
848
+ if label is None:
849
+ model_sz = [self.img_size] * self.ndims
850
+ model_t = torch.full(
851
+ [1, 1] + model_sz, fill_value,
852
+ dtype=torch.float32, device=self.device,
853
+ )
854
+ if self._init_img_raw is not None:
855
+ fullres_sz = list(self._init_img_raw.shape[2:])
856
+ else:
857
+ fullres_sz = model_sz
858
+ fullres_t = torch.full(
859
+ [1, 1] + fullres_sz, fill_value,
860
+ dtype=torch.float32, device=self.device,
861
+ )
862
+ return model_t, fullres_t
863
+
864
+ # --- Convert to numpy if needed ---
865
+ if isinstance(label, torch.Tensor):
866
+ if label.ndim >= 4:
867
+ # Already batched tensor — pass through
868
+ fullres_t = label.float().to(self.device)
869
+ target_sz = [self.img_size] * self.ndims
870
+ model_t = F.interpolate(
871
+ fullres_t, size=target_sz, mode="nearest",
872
+ )
873
+ return model_t, fullres_t
874
+ lab = label.numpy()
875
+ elif isinstance(label, np.ndarray):
876
+ lab = label.copy()
877
+ else:
878
+ raise TypeError(f"Unsupported label type: {type(label)}")
879
+
880
+ # --- Center-pad to cube ---
881
+ lab = self._center_pad_to_cube(lab)
882
+
883
+ # --- Channel dim: 3D→[C=1,...], 4D→channels-first [C,...] ---
884
+ if lab.ndim == 3:
885
+ lab = lab[None, :, :, :] # [1, D, H, W]
886
+ elif lab.ndim > 3:
887
+ lab = np.transpose(lab, (3, 0, 1, 2)) # [C, D, H, W]
888
+
889
+ # --- Full-res tensor ---
890
+ fullres_t = torch.tensor(
891
+ lab[None, ...], dtype=torch.float32, device=self.device,
892
+ ) # [1, C, D, H, W]
893
+
894
+ # --- Resize to model resolution (nearest-neighbor) ---
895
+ target_sz = [self.img_size] * self.ndims
896
+ # Resize each channel separately to avoid resizing the channel dim
897
+ channels = []
898
+ for c in range(lab.shape[0]):
899
+ ch = sk_resize(
900
+ lab[c], target_sz,
901
+ anti_aliasing=False, preserve_range=True, order=0,
902
+ )
903
+ channels.append(ch)
904
+ lab_model = np.stack(channels, axis=0) # [C, S, S, S]
905
+ model_t = torch.tensor(
906
+ lab_model[None, ...], dtype=torch.float32, device=self.device,
907
+ ) # [1, C, S, S, S]
908
+
909
+ return model_t, fullres_t
910
+
911
+ def _standardize_txt(self, txt) -> Optional[torch.Tensor]:
912
+ """Convert text input to [B, 1024] tensor."""
913
+ if txt is None:
914
+ return None
915
+ if isinstance(txt, str):
916
+ self._ensure_bert()
917
+ from Dataloader.bert_helper import str2emb
918
+ emb = str2emb(
919
+ txt, max_words_num=100,
920
+ embeder=self._bert_model, tokenizer=self._bert_tokenizer,
921
+ reduce_method="mean",
922
+ )
923
+ return emb.to(self.device) # [1, 1024]
924
+ if isinstance(txt, np.ndarray):
925
+ t = torch.tensor(txt, dtype=torch.float32, device=self.device)
926
+ if t.ndim == 1:
927
+ t = t.unsqueeze(0)
928
+ return t
929
+ if isinstance(txt, torch.Tensor):
930
+ t = txt.float().to(self.device)
931
+ if t.ndim == 1:
932
+ t = t.unsqueeze(0)
933
+ return t
934
+ raise TypeError(f"Unsupported text type: {type(txt)}")
935
+
936
+ def _ensure_bert(self):
937
+ if self._bert_model is None:
938
+ from Dataloader.bert_helper import get_frozen_embeder
939
+ self._bert_model, self._bert_tokenizer = get_frozen_embeder(self.bert_model_path)
940
+
941
+ # ------------------------------------------------------------------
942
+ # Private — Spatial utilities
943
+ # ------------------------------------------------------------------
944
+
945
+ @staticmethod
946
+ def _reverse_axis_order(arr: np.ndarray) -> np.ndarray:
947
+ """SimpleITK → NumPy axis order."""
948
+ return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1])))
949
+
950
+ @staticmethod
951
+ def _center_pad_to_cube(volume: np.ndarray) -> np.ndarray:
952
+ """Pad volume to a cube using the max dimension, with symmetric padding."""
953
+ max_dim = max(volume.shape[:3])
954
+ pad_width = []
955
+ for s in volume.shape[:3]:
956
+ total_pad = max_dim - s
957
+ pad_before = total_pad // 2
958
+ pad_after = total_pad - pad_before
959
+ pad_width.append((pad_before, pad_after))
960
+ for _ in range(volume.ndim - 3):
961
+ pad_width.append((0, 0))
962
+ return np.pad(volume, pad_width, mode="constant", constant_values=0)
963
+
964
+ def _apply_ddf(
965
+ self,
966
+ volume_tensor: torch.Tensor,
967
+ ddf: torch.Tensor,
968
+ padding_mode: str = "border",
969
+ resample_mode: str = "bilinear",
970
+ ) -> torch.Tensor:
971
+ """Apply DDF to volume tensor at any resolution via grid_sample."""
972
+ device = ddf.device
973
+ ndims = self.ndims
974
+ img_sz = list(volume_tensor.shape[2:])
975
+ max_sz = torch.reshape(
976
+ torch.tensor(img_sz, dtype=torch.float32, device=device),
977
+ [1, ndims] + [1] * ndims,
978
+ )
979
+ ref_grid = torch.reshape(
980
+ torch.stack(
981
+ torch.meshgrid(
982
+ [torch.arange(s, device=device, dtype=torch.float32) for s in img_sz],
983
+ indexing="ij",
984
+ ),
985
+ 0,
986
+ ),
987
+ [1, ndims] + img_sz,
988
+ )
989
+ img_shape = torch.reshape(
990
+ torch.tensor(
991
+ [(s - 1) / 2.0 for s in img_sz], dtype=torch.float32, device=device,
992
+ ),
993
+ [1] + [1] * ndims + [ndims],
994
+ )
995
+ grid = torch.flip(
996
+ (ddf * max_sz + ref_grid).permute(
997
+ [0] + list(range(2, 2 + ndims)) + [1]
998
+ )
999
+ / img_shape
1000
+ - 1,
1001
+ dims=[-1],
1002
+ )
1003
+ return F.grid_sample(
1004
+ volume_tensor.to(device),
1005
+ grid.float(),
1006
+ mode=resample_mode,
1007
+ padding_mode=padding_mode,
1008
+ align_corners=True,
1009
+ )
1010
+
1011
+ def _ensure_tensor(self, img) -> torch.Tensor:
1012
+ """Convert numpy/torch input to a [B, C, ...] float tensor on device."""
1013
+ if isinstance(img, np.ndarray):
1014
+ t = torch.tensor(img, dtype=torch.float32, device=self.device)
1015
+ elif isinstance(img, torch.Tensor):
1016
+ t = img.float().to(self.device)
1017
+ else:
1018
+ raise TypeError(f"Unsupported image type: {type(img)}")
1019
+ if t.ndim == self.ndims: # spatial only → [B=1, C=1, ...]
1020
+ t = t[None, None, ...]
1021
+ elif t.ndim == self.ndims + 1: # [C, ...] → [B=1, C, ...]
1022
+ t = t[None, ...]
1023
+ return t
1024
+
1025
+ def _to_ddf_tensor(self, ddf) -> torch.Tensor:
1026
+ """Convert ddf input to proper tensor on device."""
1027
+ if isinstance(ddf, np.ndarray):
1028
+ ddf = torch.tensor(ddf, dtype=torch.float32)
1029
+ ddf = ddf.float().to(self.device)
1030
+ if ddf.ndim == self.ndims + 1:
1031
+ ddf = ddf.unsqueeze(0)
1032
+ # Resize to model resolution if needed
1033
+ model_sz = [self.img_size] * self.ndims
1034
+ if list(ddf.shape[2:]) != model_sz:
1035
+ ddf = F.interpolate(
1036
+ ddf, size=model_sz,
1037
+ mode="bilinear" if self.ndims == 2 else "trilinear",
1038
+ align_corners=False,
1039
+ )
1040
+ return ddf
1041
+
1042
+ # ------------------------------------------------------------------
1043
+ # Convenience / repr
1044
+ # ------------------------------------------------------------------
1045
+
1046
+ def __repr__(self) -> str:
1047
+ status_parts = []
1048
+ if self._init_img is not None:
1049
+ status_parts.append(f"init_img={list(self._init_img.shape)}")
1050
+ if self._cond_img is not None:
1051
+ status_parts.append(f"cond_img={list(self._cond_img.shape)}")
1052
+ if self._predicted_ddf is not None:
1053
+ status_parts.append(f"predicted_ddf={list(self._predicted_ddf.shape)}")
1054
+ status = ", ".join(status_parts) if status_parts else "empty"
1055
+ return (
1056
+ f"OMorpher(net={self.net_name}, ndims={self.ndims}, "
1057
+ f"img_size={self.img_size}, device={self.device}, {status})"
1058
+ )
README.md CHANGED
@@ -1,80 +1,129 @@
1
- # OmniMorph: Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on conditional Deformation-Recovery Diffusion Model
2
-
3
- ## Links
4
-
5
- - **Google Drive**: [Dataset & Resources](https://drive.google.com/drive/folders/1N72SeYKwnaMmFq9_NqqEXxZ1jUcw2SwG?usp=drive_link)
6
- - **Notion**: [Dataset Documentation](https://www.notion.so/Dataset-2bc2300266fe48dfafef580dacf16d50?pvs=4)
7
- - **Overleaf**: [Paper Draft](https://www.overleaf.com/4489753418kstfhwsxgtkw#a0dbad)
8
- - **Discord**: [Channel Invite](https://discord.gg/6HrD29T2)
9
- - **GitHub Repository**: `/home/data/Github/OmniMorph`
10
-
11
- ## Environments
12
-
13
- ### Data Processing
14
- - Library: **SimpleITK**
15
- - Environment:
16
- ```bash
17
- conda activate torch
18
- conda deactivate
19
- ```
20
-
21
- ### Diffusion Model / DataEngineer (with BERT)
22
- > Note: 暂不更新,等 MIA 审稿
23
-
24
- ```bash
25
- source /home/data/jzheng/Adaptive_Motion_Generator-master/pipenv/bin/activate
26
- deactivate
27
- ```
28
-
29
- Or:
30
- ```bash
31
- source /home/data/Github/OmniMorph/ominenv/bin/activate
32
- ```
33
-
34
- ### nnUNet
35
- ```bash
36
- source ~/PycharmProjects/pythonProject/venv/bin/activate
37
- ```
38
-
39
- ### Masking CUDA
40
- ```bash
41
- CUDA_VISIBLE_DEVICES=0,1,3 python ...
42
- ```
43
-
44
- ## Rental Server (租赁服务器)
45
-
46
- ```bash
47
- ssh -p 49419 root@i-2.gpushare.com
48
- # Password: aFwd98tamsHPtDDhWzUqvXfTagUqfNg8
49
- ```
50
-
51
- SSH Config:
52
- ```
53
- Host gpushare
54
- HostName i-2.gpushare.com
55
- User root
56
- Port 49419
57
- ```
58
-
59
- Conda environments on server:
60
- ```bash
61
- conda activate OM
62
- conda activate unigrad
63
- ```
64
-
65
- Data path: `/hy-tmp`
66
-
67
- ## Data Paths
68
-
69
- | Item | Path |
70
- |------|------|
71
- | Dataset | `/home/data/Github/data/data_gen_def/DATASETS` |
72
- | Processed Data | `/home/data/Github/data/data_gen_def/DATASETS_processed` |
73
- | Data Processing Template | `/home/data/jzheng/Data_Engineering/dataclean_TotSeg.py` |
74
-
75
- ## Related Documentation
76
-
77
- 1. **DataEngineer**:
78
- - `/home/data/jzheng/Data_Engineering/README.md`
79
- - `/home/data/jzheng/data_process`
80
- 2. **OmniMorph**: `/home/data/Github/OmniMorph/README.md`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - medical-imaging
5
+ - registration
6
+ - diffusion
7
+ - 3d
8
+ - image-generation
9
+ - image-restoration
10
+ - pytorch
11
+ library_name: pytorch
12
+ ---
13
+
14
+ # OmniMorph
15
+
16
+ **Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**
17
+
18
+ OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:
19
+
20
+ - **Generation** — text-conditioned image synthesis via BERT embeddings.
21
+ - **Restoration** — recover anatomically plausible images from degraded inputs.
22
+ - **Registration** — paired / unpaired / flexible-resolution registration via diffused deformation vector fields.
23
+
24
+ ## Repository Contents
25
+
26
+ | Path | Description |
27
+ |---|---|
28
+ | `OM_train*.py` | Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU) |
29
+ | `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference / augmentation / registration / contrastive scripts |
30
+ | `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utils |
31
+ | `OMorpher/` | Higher-level model wrapper |
32
+ | `Dataloader/` | Multi-modality dataloaders + dataset mappings (16 datasets) |
33
+ | `Config/` | YAML training/inference configs |
34
+ | `Scripts/` | Auxiliary scripts (registration, evaluation) |
35
+ | `tests/` | Pytest suite for `OMorpher` and loss functions |
36
+ | `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
37
+ | `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint (epoch 110, multi-modal `recmulmodmutattnnet`) |
38
+
39
+ > **Note:** Only the final checkpoint (epoch 110) is shipped here. Earlier epochs and the `bert_large_uncased` weights are not bundled — download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
40
+
41
+ ## Setup
42
+
43
+ ```bash
44
+ git clone https://huggingface.co/DRDMsig/Omini3D
45
+ cd Omini3D
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ For Intel XPU / Dawn cluster, install the matching `intel-extension-for-pytorch` build before installing the rest of the requirements.
50
+
51
+ ## Quick Start
52
+
53
+ ### Training
54
+
55
+ ```bash
56
+ # Single-mode diffusion
57
+ CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml
58
+
59
+ # Dual mode (diffusion + registration)
60
+ CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml
61
+
62
+ # Triple mode (diffusion + contrastive + registration)
63
+ CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml
64
+
65
+ # Intel XPU (single node)
66
+ sbatch bash_train_single_node.sh
67
+ ```
68
+
69
+ ### Inference
70
+
71
+ ```bash
72
+ # Augmentation / restoration with a trained model
73
+ python OM_aug.py -C Config/config_om.yaml
74
+
75
+ # Paired registration
76
+ python OM_reg.py -C Config/config_om.yaml
77
+
78
+ # Flexible-resolution registration
79
+ python OM_reg_flexres.py -C Config/config_om.yaml
80
+ ```
81
+
82
+ ### Loading the checkpoint
83
+
84
+ ```python
85
+ import torch
86
+ from Diffusion.networks import get_net
87
+
88
+ # Production network (multi-modal `recmulmodmutattnnet`)
89
+ net = get_net("recmulmodmutattnnet")
90
+ state = torch.load("Models/all_om_net/000110_all_om_net.pth", map_location="cpu")
91
+ net.load_state_dict(state["model"] if "model" in state else state)
92
+ net.eval()
93
+ ```
94
+
95
+ ## Architecture
96
+
97
+ ```
98
+ Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
99
+ ```
100
+
101
+ - **`DeformDDPM`** (`Diffusion/diffuser.py`) — forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`.
102
+ - **Networks** (`Diffusion/networks.py`) — selectable via `get_net(name)`:
103
+ - `recmulmodmutattnnet` — current production multi-modal multi-head-attention net (used by `000110_all_om_net.pth`)
104
+ - `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
105
+ - **`STN`** — Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
106
+ - **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`) — `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.
107
+
108
+ ## Datasets Supported
109
+
110
+ `Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including:
111
+ AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.
112
+
113
+ The dataset files themselves are **not** included; obtain them from their respective sources and update the mapping paths.
114
+
115
+ ## Citation
116
+
117
+ ```bibtex
118
+ @article{omnimorph,
119
+ title = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
120
+ Restoration and Registration via Conditional Deformation-Recovery
121
+ Diffusion Models},
122
+ author = {Zheng, J. and Mo, M. and others},
123
+ year = {2025}
124
+ }
125
+ ```
126
+
127
+ ## License
128
+
129
+ MIT — see `LICENSE`.
Scripts/OM_aug_om.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OM_aug_om.py — Augmentation using OMorpher.
3
+
4
+ Drop-in replacement for OM_aug.py. Produces identical outputs but uses
5
+ OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
6
+
7
+ Usage:
8
+ python Scripts/OM_aug_om.py -C Config/config_om.yaml
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import argparse
14
+
15
+ # Add project root to path so imports work from Scripts/
16
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ import numpy as np
19
+ import torch
20
+ import nibabel as nib
21
+ import yaml
22
+ from tqdm import tqdm
23
+
24
+ import utils
25
+ from Dataloader.dataLoader import OminiDataset_inference_w_all
26
+ from torch.utils.data import DataLoader
27
+ from OMorpher import OMorpher
28
+
29
# ========== CLI ==========

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    # NOTE(review): the module docstring shows Config/config_om.yaml in its
    # usage line while the default here is config_cmr.yaml — confirm which
    # one is intended.
    default="Config/config_cmr.yaml",
    help="Path for the config file",
)
args = parser.parse_args()

# ========== Config ==========

with open(args.config, "r") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Create the save directories named in the config. These paths are
# overridden further down, so the directories actually written to are
# created again in the "Output directories" section.
for _save_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    if not os.path.exists(hyp_parameters[_save_key]):
        os.makedirs(hyp_parameters[_save_key])
print(hyp_parameters["aug_img_savepath"])

# Volumes are processed one at a time.
hyp_parameters["batchsize"] = 1

# ========== Dataset (identical to OM_aug.py) ==========

min_crop_ratio = 0.9

label_keys = ["heart"]
database = ["MnMs"]
subtype = "es"

# Hard-coded output locations replace whatever the config supplied.
_aug_root = f"Data/Aug_data/mnms_{subtype}"
hyp_parameters["aug_img_savepath"] = f"{_aug_root}/img/"
hyp_parameters["aug_msk_savepath"] = f"{_aug_root}/msk/"
hyp_parameters["aug_ddf_savepath"] = f"{_aug_root}/ddf/"
select_channels_dict = {"ImgDict": [subtype]}

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
    select_channels_dict=select_channels_dict,
)
Infer_Loader = DataLoader(
    dataset,
    batch_size=hyp_parameters["batchsize"],
    shuffle=False,
)

# ========== OMorpher setup ==========

# Checkpoint layout: Models/<data>_<net>/<model_id>_<data>_<net>.pth
_run_tag = f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
epoch = f'{hyp_parameters["model_id_str"]}_{_run_tag}'
model_save_path = os.path.join(f"Models/{_run_tag}/", str(epoch) + ".pth")
print("Loading model from:", model_save_path)

om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)

# ========== Output directories ==========

for _save_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_save_key], exist_ok=True)

# ========== Main inference loop ==========

device = om.device
print("total num of image:", len(Infer_Loader))
109
+
110
# For each loaded volume: save the original image/mask, then generate
# ``aug_coe`` augmented samples by forward-diffusing a random deformation
# and recovering it with the model.
for e, d in tqdm(enumerate(Infer_Loader)):
    img = d["img"]
    mask = d["labels"]
    label_str = str(d["label_channels"])
    pid = e  # the loader index doubles as the patient id

    print("Processing to patient:", pid, " image:", e)

    ndims = hyp_parameters["ndims"]
    img_dir = hyp_parameters["aug_img_savepath"]
    msk_dir = hyp_parameters["aug_msk_savepath"]

    img = img.type(torch.float32).to(device)
    image_original = img.cpu().detach().numpy()

    mask = mask.type(torch.float32).to(device)
    mask_original = mask.cpu().detach().numpy()

    # Save original image and mask
    nifti_img = utils.converet_to_nibabel(image_original, ndims=ndims)
    nifti_mask = utils.converet_to_nibabel(mask_original, ndims=ndims)

    nib.save(
        nifti_img,
        os.path.join(img_dir, utils.get_barcode([pid, e]) + ".nii.gz"),
    )
    nib.save(
        nifti_mask,
        os.path.join(msk_dir, utils.get_barcode([pid, e]) + "_GT.nii.gz"),
    )

    # Augmentation loop
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(hyp_parameters["aug_coe"]):
            print(
                f"Generating -> Subject-{pid}, Scan-{e} "
                f'({im}/{hyp_parameters["aug_coe"]})',
                end="\r",
            )

            # 1. Set init image (DataLoader tensor passes through)
            om.set_init_img(img)

            # 2. Self-conditioning (matches: cond_imgs = img_org.clone().detach())
            om.set_cond_img(img)

            # 3. Forward diffuse to get noisy image + random DDF
            t_start = torch.tensor(np.array([noise_step]), device=device)
            img_diff, _, ddf_rand = om._get_random_ddf(om._init_img, t_start)

            # 4. Get noisy mask (nearest-neighbor resampling)
            msk_diff = om.apply_def(
                img=mask,
                ddf=ddf_rand,
                padding_mode="zeros",
                resample_mode="nearest",
            )

            # 5. Set random DDF as initial DDF
            om.set_init_def(ddf=ddf_rand.clone().detach())

            # 6. Run reverse diffusion
            om.predict(
                T=[noise_step, hyp_parameters["timesteps"]],
                proc_type=hyp_parameters["condition_type"],
            )

            # 7. Get recovered outputs
            ddf_comp = om.get_def()
            img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
            msk_rec = om.apply_def(
                img=mask,
                ddf=ddf_comp,
                padding_mode="zeros",
                resample_mode="nearest",
            )

            # Convert to numpy for saving
            denoise_imgs = img_rec.cpu().detach().numpy()
            denoise_msks = msk_rec.cpu().detach().numpy()
            noisy_imgs_np = img_diff.cpu().detach().numpy()
            noisy_msks_np = msk_diff.cpu().detach().numpy()

            nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=ndims)
            nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=ndims)
            nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=ndims)
            nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=ndims)

            sample_id = [pid, e, im, noise_step]
            noisy_header = ["Patient", "Slice", "NoiseImg", "NoiseStep"]

            # Save augmented (recovered) outputs
            nib.save(
                nifti_img_aug,
                os.path.join(img_dir, utils.get_barcode(sample_id) + ".nii.gz"),
            )
            nib.save(
                nifti_mask_aug,
                os.path.join(msk_dir, utils.get_barcode(sample_id) + "_GT.nii.gz"),
            )

            # Save noisy image/mask
            nib.save(
                nifti_img,
                os.path.join(
                    img_dir,
                    utils.get_barcode(sample_id, header=noisy_header) + ".nii.gz",
                ),
            )
            nib.save(
                nifti_mask,
                os.path.join(
                    msk_dir,
                    utils.get_barcode(sample_id, header=noisy_header) + "_GT.nii.gz",
                ),
            )

            # Raise the noise level on iterations where the offset from the
            # start step is even.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # NOTE(review): this condition is always true, so the script stops after
    # the first loaded volume — confirm this limit is intentional.
    if e >= 0:
        exit()
Scripts/OM_reg_flexres_om.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OM_reg_flexres_om.py — Full-resolution registration using OMorpher.
3
+
4
+ Drop-in replacement for OM_reg_flexres.py. Produces identical outputs but
5
+ uses OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
6
+
7
+ Usage:
8
+ python Scripts/OM_reg_flexres_om.py -C Config/config_om.yaml
9
+ """
10
+
11
+ import os
12
+ import sys
13
+ import argparse
14
+
15
+ # Add project root to path so imports work from Scripts/
16
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ import nibabel as nib
22
+ import yaml
23
+ import SimpleITK as sitk
24
+ from tqdm import tqdm
25
+
26
+ import utils
27
+ from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
28
+ from OMorpher import OMorpher
29
+
30
# ========== CLI ==========

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    default="Config/config_om.yaml",
    help="Path for the config file",
)
args = parser.parse_args()

# ========== Config ==========

with open(args.config, "r") as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Make sure the augmentation save directories named in the config exist.
for _save_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    if not os.path.exists(hyp_parameters[_save_key]):
        os.makedirs(hyp_parameters[_save_key])
print(hyp_parameters["aug_img_savepath"])

# Volumes are processed one at a time.
hyp_parameters["batchsize"] = 1
model_img_sz = hyp_parameters["img_size"]

# ========== Dataset (unchanged — used only for filtering/metadata) ==========

label_keys = ["brain"]
database = ["Brats2019"]

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=1.0,
    label_key=label_keys,
    database=database,
)

# ========== OMorpher setup ==========

# Checkpoint layout: Models/<data>_<net>/<model_id>_<data>_<net>.pth
_run_tag = f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
epoch = f'{hyp_parameters["model_id_str"]}_{_run_tag}'
model_save_path = os.path.join(f"Models/{_run_tag}/", str(epoch) + ".pth")
print("Loading model from:", model_save_path)

om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)

# ========== Output directories ==========

# Full-resolution results go to sibling directories suffixed "_fullres".
reg_img_savepath_fullres = hyp_parameters["reg_img_savepath"].rstrip("/") + "_fullres/"
reg_msk_savepath_fullres = hyp_parameters["reg_msk_savepath"].rstrip("/") + "_fullres/"
reg_ddf_savepath_fullres = hyp_parameters["reg_ddf_savepath"].rstrip("/") + "_fullres/"

_output_dirs = (
    hyp_parameters["reg_img_savepath"],
    hyp_parameters["reg_msk_savepath"],
    hyp_parameters["reg_ddf_savepath"],
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
)
for p in _output_dirs:
    os.makedirs(p, exist_ok=True)
99
+
100
+
101
+ # ========== Helper: load full-res data (same as original) ==========
102
+
103
def center_pad_to_cube(volume):
    """Zero-pad the first three axes of ``volume`` so they all share the largest extent.

    The padding is split as evenly as possible around the data; when the
    amount to add is odd, the extra voxel goes on the trailing side. Any axes
    beyond the third (e.g. a channel axis) are left unpadded.

    Args:
        volume: numpy array whose leading (up to three) axes are spatial.

    Returns:
        A new array padded with zeros.
    """
    spatial = volume.shape[:3]
    target = max(spatial)

    pad_width = []
    for extent in spatial:
        before = (target - extent) // 2
        pad_width.append((before, target - extent - before))

    # Non-spatial trailing axes get no padding.
    pad_width.extend([(0, 0)] * (volume.ndim - 3))

    return np.pad(volume, pad_width, mode="constant", constant_values=0)
115
+
116
+
117
def load_fullres_volume(key, ds):
    """Read the image at ``key`` and prepare it at its original resolution.

    Steps, in order: read with SimpleITK, reorder axes with
    ``reverse_axis_order``, pick one channel if the array is 4-D, clip CT
    intensities to ``ds.clamp_range`` (when set), apply ``ds.normalize`` and
    finally center-pad to a cube.

    Args:
        key: path of the image file; also the lookup key into
            ``ds.ALLdata_filtered``.
        ds: dataset object providing ``get_channel_ids``, ``clamp_range``,
            ``ALLdata_filtered`` and ``normalize``.

    Returns:
        The cube-padded array produced by ``center_pad_to_cube``.
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(key)))

    if array.ndim == 4:
        # Use the first channel the dataset lists for this key, else channel 0.
        ids = ds.get_channel_ids(key)
        selected = ids[0] if len(ids) > 0 else 0
        array = array[:, :, :, selected]

    # Clamping is applied only to entries whose metadata says Modality == "CT".
    if ds.clamp_range is not None and ds.ALLdata_filtered[key].get("Modality", None) == "CT":
        array = np.clip(array, ds.clamp_range[0], ds.clamp_range[1])

    return center_pad_to_cube(ds.normalize(array))
133
+
134
+
135
def load_fullres_label(key, ds, label_key):
    """Read the segmentation label ``label_key`` belonging to image ``key``.

    The label path is looked up under
    ``ds.ALLdata_filtered[key]["Label_path"]["segmentation"][label_key]``.

    Args:
        key: lookup key into ``ds.ALLdata_filtered``.
        ds: dataset object providing ``ALLdata_filtered`` and ``get_channel_ids``.
        label_key: name of the segmentation label to load.

    Returns:
        The axis-reordered, cube-padded label array, or ``None`` when no
        segmentation path is registered for ``label_key``.
    """
    seg_paths = ds.ALLdata_filtered[key].get("Label_path", {}).get("segmentation", {})
    if label_key not in seg_paths:
        return None

    array = reverse_axis_order(
        sitk.GetArrayFromImage(sitk.ReadImage(seg_paths[label_key]))
    )

    if array.ndim > 3:
        # Keep only the channels the dataset lists for this key (if any).
        ids = ds.get_channel_ids(key)
        if len(ids) != 0:
            array = array[..., ids]

    return center_pad_to_cube(array)
150
+
151
+
152
+ # ========== Main inference loop ==========
153
+
154
+ keys = list(dataset.ALLdata_filtered.keys())
155
+ print("total num of images:", len(keys))
156
+ device = om.device
157
+
158
+ for e, key in enumerate(tqdm(keys)):
159
+ pid = e
160
+ print(f"Processing patient {pid}, image {e}, key: {key}")
161
+
162
+ # --- Load & standardize volume via OMorpher ---
163
+ fullres_vol = load_fullres_volume(key, dataset)
164
+ om.set_init_img(fullres_vol)
165
+ img = om._init_img # [1, 1, model_sz, model_sz, model_sz]
166
+ fullres_img_tensor = om._init_img_raw # [1, 1, D, H, W] full-res tensor
167
+ orig_sz = list(fullres_img_tensor.shape[2:])
168
+ print(f" Full-res padded shape: {orig_sz}")
169
+
170
+ # --- Load & standardize labels via OMorpher ---
171
+ masks_model = []
172
+ masks_fullres = []
173
+ for lk in label_keys:
174
+ lab = load_fullres_label(key, dataset, lk)
175
+ model_t, fullres_t = om._standardize_label(lab) # None → -1 placeholder
176
+ masks_model.append(model_t)
177
+ masks_fullres.append(fullres_t)
178
+
179
+ if masks_model:
180
+ mask = torch.cat(masks_model, dim=1) # [1, C_total, S, S, S]
181
+ fullres_msk_tensor = torch.cat(masks_fullres, dim=1) # [1, C_total, D, H, W]
182
+ else:
183
+ mask = None
184
+ fullres_msk_tensor = None
185
+
186
+ # --- Save target conditioning image (first subject) ---
187
+ if e <= 0:
188
+ target_img = img.clone().detach()
189
+
190
+ # --- Save original images at model resolution ---
191
+ image_original = img.cpu().numpy()
192
+ nib.save(
193
+ utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"]),
194
+ os.path.join(hyp_parameters["reg_img_savepath"],
195
+ utils.get_barcode([pid, e]) + ".nii.gz"),
196
+ )
197
+ if mask is not None:
198
+ mask_original = mask.cpu().numpy()
199
+ nib.save(
200
+ utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"]),
201
+ os.path.join(hyp_parameters["reg_msk_savepath"],
202
+ utils.get_barcode([pid, e]) + "_GT.nii.gz"),
203
+ )
204
+
205
+ # --- Save original at full-res ---
206
+ nib.save(
207
+ utils.converet_to_nibabel(fullres_img_tensor, ndims=hyp_parameters["ndims"]),
208
+ os.path.join(reg_img_savepath_fullres,
209
+ utils.get_barcode([pid, e]) + ".nii.gz"),
210
+ )
211
+ if fullres_msk_tensor is not None:
212
+ nib.save(
213
+ utils.converet_to_nibabel(fullres_msk_tensor, ndims=hyp_parameters["ndims"]),
214
+ os.path.join(reg_msk_savepath_fullres,
215
+ utils.get_barcode([pid, e]) + "_GT.nii.gz"),
216
+ )
217
+
218
+ # --- Diffusion recovery via OMorpher ---
219
+ noise_step = hyp_parameters["start_noise_step"]
220
+ with torch.no_grad():
221
+ for im in range(1):
222
+ print(
223
+ f" Generating -> Subject-{pid}, Scan-{e} "
224
+ f'({im}/{hyp_parameters["aug_coe"]})',
225
+ end="\r",
226
+ )
227
+
228
+ # Set up OMorpher inputs
229
+ om.set_init_img(img)
230
+ om.set_cond_img(target_img.clone().detach())
231
+
232
+ # Run diffusion recovery
233
+ # T=[None, timesteps] in original means: no initial noise, full reverse diffusion
234
+ om.predict(
235
+ T=[None, hyp_parameters["timesteps"]],
236
+ proc_type=hyp_parameters["condition_type"],
237
+ )
238
+
239
+ ddf_comp = om.get_def()
240
+
241
+ # Reconstruct images at model resolution using OMorpher
242
+ img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
243
+
244
+ # --- Save model-resolution results ---
245
+ denoise_imgs = img_rec.cpu().numpy()
246
+
247
+ nib.save(
248
+ utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"]),
249
+ os.path.join(
250
+ hyp_parameters["reg_img_savepath"],
251
+ utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
252
+ ),
253
+ )
254
+
255
+ if mask is not None:
256
+ msk_rec = om.apply_def(
257
+ img=mask, ddf=ddf_comp,
258
+ padding_mode="zeros", resample_mode="nearest",
259
+ )
260
+ denoise_msks = msk_rec.cpu().numpy()
261
+ nib.save(
262
+ utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"]),
263
+ os.path.join(
264
+ hyp_parameters["reg_msk_savepath"],
265
+ utils.get_barcode([pid, e, im, noise_step]) + "_GT.nii.gz",
266
+ ),
267
+ )
268
+
269
+ # --- Upscale DDF and apply at full resolution via OMorpher ---
270
+ img_rec_fullres = om.apply_def(
271
+ img=fullres_img_tensor, ddf=ddf_comp, padding_mode="border",
272
+ )
273
+
274
+ if fullres_msk_tensor is not None:
275
+ msk_rec_fullres = om.apply_def(
276
+ img=fullres_msk_tensor, ddf=ddf_comp,
277
+ padding_mode="zeros", resample_mode="nearest",
278
+ )
279
+
280
+ # Upscale DDF for saving
281
+ ddf_fullres = F.interpolate(
282
+ ddf_comp, size=orig_sz, mode="trilinear", align_corners=False,
283
+ )
284
+
285
+ # --- Save full-res results ---
286
+ nib.save(
287
+ utils.converet_to_nibabel(img_rec_fullres, ndims=hyp_parameters["ndims"]),
288
+ os.path.join(
289
+ reg_img_savepath_fullres,
290
+ utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
291
+ ),
292
+ )
293
+
294
+ if fullres_msk_tensor is not None:
295
+ nib.save(
296
+ utils.converet_to_nibabel(msk_rec_fullres, ndims=hyp_parameters["ndims"]),
297
+ os.path.join(
298
+ reg_msk_savepath_fullres,
299
+ utils.get_barcode([pid, e, im, noise_step]) + "_GT.nii.gz",
300
+ ),
301
+ )
302
+
303
+ nib.save(
304
+ utils.converet_to_nibabel(ddf_fullres, ndims=hyp_parameters["ndims"]),
305
+ os.path.join(
306
+ reg_ddf_savepath_fullres,
307
+ utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
308
+ ),
309
+ )
310
+
311
+ if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
312
+ noise_step = noise_step + hyp_parameters["noise_step"]
313
+
314
+ if e > 5:
315
+ break
Scripts/OM_reg_pair_ext.py ADDED
@@ -0,0 +1,676 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OM_reg_pair_ext.py — Paired registration using OMorpher with an external dataset.
3
+
4
+ Loads fixed/moving pairs from a Learn2Reg-style JSON dataset file
5
+ (e.g. HippocampusMR_dataset.json) and registers each moving image to its
6
+ paired fixed image. Saves registered images, masks, DDFs, source originals,
7
+ and evaluation metrics (DSC, ASD, HD) per organ label.
8
+
9
+ Usage:
10
+ python Scripts/OM_reg_pair_ext.py -C Config/config_om.yaml \
11
+ --dataset-json /path/to/HippocampusMR_dataset.json \
12
+ --split val
13
+
14
+ python Scripts/OM_reg_pair_ext.py -C Config/config_om.yaml \
15
+ --dataset-json /path/to/HippocampusMR_dataset.json \
16
+ --split test -N 10
17
+ """
18
+
19
+ import os
20
+ import sys
21
+
22
+ # Add project root to path so imports work from Scripts/
23
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
+
25
+ import csv
26
+ import json
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import nibabel as nib
31
+ import yaml
32
+ import SimpleITK as sitk
33
+ from scipy.ndimage import distance_transform_edt, binary_erosion
34
+ from tqdm import tqdm
35
+
36
+ import utils
37
+ from Dataloader.dataLoader import reverse_axis_order
38
+ from OMorpher import OMorpher
39
+
40
+ # ========== CLI ==========
41
+
42
+ import argparse
43
+
44
# Command-line interface: config file, dataset JSON, split and pair limit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    default="Config/config_om.yaml",
    required=False,
    help="Path for the config file",
)
parser.add_argument(
    "--dataset-json",
    type=str,
    default="~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/HippocampusMR/HippocampusMR_dataset.json",
    help="Path to the Learn2Reg-style dataset JSON",
)
parser.add_argument(
    "--split",
    type=str,
    choices=["val", "test"],
    default="val",
    help="Which registration split to use: 'val' or 'test'",
)
parser.add_argument(
    "--max-samples", "-N",
    type=int,
    default=0,
    help="Max number of pairs to register (0 = all)",
)
args = parser.parse_args()

# ========== Config ==========

with open(args.config, "r") as cfg_file:
    hyp_parameters = yaml.safe_load(cfg_file)
print(hyp_parameters)

# Pairs are registered one at a time, so the batch size is forced to 1.
hyp_parameters["batchsize"] = 1
model_img_sz = hyp_parameters["img_size"]
timesteps = hyp_parameters["timesteps"]
condition_type = hyp_parameters["condition_type"]
ndims = hyp_parameters["ndims"]

# ========== Load external dataset JSON ==========

dataset_json_path = os.path.expanduser(args.dataset_json)
# Relative paths inside the JSON are resolved against its own directory.
dataset_root = os.path.dirname(dataset_json_path)

with open(dataset_json_path, "r") as json_file:
    dataset_meta = json.load(json_file)

dataset_name = dataset_meta.get("name", "UnknownDataset")
print(f"Dataset: {dataset_name}")

# Map the CLI split name onto the JSON key holding its fixed/moving pairs.
_split_keys = {"val": "registration_val", "test": "registration_test"}
if args.split not in _split_keys:
    raise ValueError(f"Unknown split: {args.split}")
pairs = dataset_meta.get(_split_keys[args.split], [])

if args.max_samples > 0:
    pairs = pairs[: args.max_samples]

print(f"Split: {args.split}, Pairs: {len(pairs)}")

# Label lookup built from the "training" entries: image basename -> label path
# (value is None for entries without a "label" field).
_label_lookup = {
    os.path.basename(entry["image"]): entry.get("label")
    for entry in dataset_meta.get("training", [])
}

# Class names are read from labels["0"], e.g. {"0": "background", "1": "head", ...}.
_label_names = dataset_meta.get("labels", {}).get("0", {})
# Every class with a positive id is evaluated as an organ label.
organ_label_ids = {}
for raw_id, class_name in _label_names.items():
    class_id = int(raw_id)
    if class_id > 0:
        organ_label_ids[class_id] = class_name
print(f"Organ labels for evaluation: {organ_label_ids}")

# ========== OMorpher setup ==========

_run_name = f'{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
epoch = f'{hyp_parameters["model_id_str"]}_{_run_name}'
model_save_path = os.path.join(f"Models/{_run_name}/", epoch + ".pth")
print("Loading model from:", model_save_path)

om = OMorpher(
    config=hyp_parameters,
    checkpoint_path=model_save_path,
    device=str(hyp_parameters.get("device", "cpu")),
)
print(om)

# ========== Output directories ==========

reg_img_savepath = hyp_parameters["reg_img_savepath"]
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]

# Full-resolution outputs go to sibling directories suffixed with "_fullres".
reg_img_savepath_fullres = reg_img_savepath.rstrip("/") + "_fullres/"
reg_msk_savepath_fullres = reg_msk_savepath.rstrip("/") + "_fullres/"
reg_ddf_savepath_fullres = reg_ddf_savepath.rstrip("/") + "_fullres/"

# Evaluation CSVs live next to the registered-image directory.
eval_dir = os.path.join(reg_img_savepath, "..", "eval")

for out_dir in (
    reg_img_savepath,
    reg_msk_savepath,
    reg_ddf_savepath,
    reg_img_savepath_fullres,
    reg_msk_savepath_fullres,
    reg_ddf_savepath_fullres,
    eval_dir,
):
    os.makedirs(out_dir, exist_ok=True)
156
+
157
+
158
+ # ========== Helper functions ==========
159
+
160
+
161
def resolve_path(rel_path):
    """Turn a path taken from the dataset JSON into an absolute path.

    Absolute inputs are returned untouched; anything else is joined onto
    ``dataset_root`` (the directory containing the dataset JSON) and
    normalised.
    """
    if not os.path.isabs(rel_path):
        rel_path = os.path.normpath(os.path.join(dataset_root, rel_path))
    return rel_path
166
+
167
+
168
def load_volume(nifti_path):
    """Read an image file and reorder its axes; no intensity processing here.

    A 4-D array is reduced to its first channel. According to the original
    author's note, normalisation, cube padding and resizing to model
    resolution are left to ``OMorpher._standardize_img``.

    Args:
        nifti_path: path of the image to read with SimpleITK.

    Returns:
        The axis-reordered array (first channel only when the input is 4-D).
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    if array.ndim == 4:
        array = array[:, :, :, 0]
    return array
179
+
180
+
181
def load_label(nifti_path):
    """Read a label map and reorder its axes; no resampling here.

    Arrays with more than three dimensions are reduced to index 0 of the
    fourth axis. According to the original author's note, cube padding and
    nearest-neighbour resizing are left to ``OMorpher._standardize_label``.

    Args:
        nifti_path: path of the label file to read with SimpleITK.

    Returns:
        The axis-reordered label array.
    """
    array = reverse_axis_order(sitk.GetArrayFromImage(sitk.ReadImage(nifti_path)))
    if array.ndim > 3:
        array = array[:, :, :, 0]
    return array
192
+
193
+
194
def get_label_path_for_image(image_rel_path):
    """Return the absolute label path registered for an image, or ``None``.

    The match is made on the image's basename against ``_label_lookup``,
    which is built from the "training" entries of the dataset JSON.
    """
    label_rel = _label_lookup.get(os.path.basename(image_rel_path))
    return None if label_rel is None else resolve_path(label_rel)
201
+
202
+
203
def split_label_classes(label_map, class_ids):
    """Build one binary float32 mask per requested class id.

    Args:
        label_map: integer-valued numpy label map.
        class_ids: iterable of class ids to extract.

    Returns:
        Dict ``{class_id: mask}`` where ``mask`` is 1.0 where
        ``label_map == class_id`` and 0.0 elsewhere.
    """
    return {cid: (label_map == cid).astype(np.float32) for cid in class_ids}
212
+
213
+
214
def get_volume_name(path):
    """Return the file's basename with a trailing ``.nii.gz`` or ``.nii`` removed.

    Names with any other extension are returned unchanged.
    """
    base = os.path.basename(path)
    # Check the longer suffix first so "x.nii.gz" does not become "x.nii.gz" minus nothing.
    if base.endswith(".nii.gz"):
        return base[: -len(".nii.gz")]
    if base.endswith(".nii"):
        return base[: -len(".nii")]
    return base
222
+
223
+
224
+ # ---------- Evaluation metrics ----------
225
+
226
+
227
+ def _surface_distances(pred, gt):
228
+ """Compute directed surface distances between two binary masks."""
229
+ pred_bool = pred > 0.5
230
+ gt_bool = gt > 0.5
231
+
232
+ if not np.any(pred_bool) or not np.any(gt_bool):
233
+ return None, None
234
+
235
+ struct = None
236
+ pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
237
+ gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
238
+
239
+ if not np.any(pred_surface):
240
+ pred_surface = pred_bool
241
+ if not np.any(gt_surface):
242
+ gt_surface = gt_bool
243
+
244
+ dt_gt = distance_transform_edt(~gt_surface)
245
+ dt_pred = distance_transform_edt(~pred_surface)
246
+
247
+ return dt_gt[pred_surface], dt_pred[gt_surface]
248
+
249
+
250
+ def compute_dsc(pred, gt):
251
+ """Dice Similarity Coefficient."""
252
+ pred_bool = pred > 0.5
253
+ gt_bool = gt > 0.5
254
+ intersection = np.sum(pred_bool & gt_bool)
255
+ denom = np.sum(pred_bool) + np.sum(gt_bool)
256
+ if denom == 0:
257
+ return 1.0
258
+ return 2.0 * float(intersection) / float(denom)
259
+
260
+
261
+ def compute_asd(pred, gt):
262
+ """Average (symmetric) Surface Distance."""
263
+ d1, d2 = _surface_distances(pred, gt)
264
+ if d1 is None:
265
+ return float("nan")
266
+ return (np.mean(d1) + np.mean(d2)) / 2.0
267
+
268
+
269
+ def compute_hd(pred, gt):
270
+ """Hausdorff Distance (maximum of directed HDs)."""
271
+ d1, d2 = _surface_distances(pred, gt)
272
+ if d1 is None:
273
+ return float("nan")
274
+ return float(max(np.max(d1), np.max(d2)))
275
+
276
+
277
def compute_negdetj_pct(ddf, ndims=3):
    """Percentage of voxels whose Jacobian determinant is negative.

    The Jacobian is ``I + du/dx`` where the displacement gradients are
    forward differences along each spatial axis; the last slice along an axis
    is differenced against itself, so its gradient there is zero.

    Args:
        ddf: displacement field, torch tensor or numpy array, laid out as
            ``[1, C, ...]`` or ``[C, ...]``; only the first ``ndims``
            components are used.
        ndims: number of spatial dimensions, 2 or 3.

    Returns:
        ``100 * (#voxels with det < 0) / (#voxels)`` as a float.

    Raises:
        ValueError: if ``ndims`` is neither 2 nor 3.
    """
    if isinstance(ddf, torch.Tensor):
        ddf = ddf.detach().cpu().numpy()
    if ddf.ndim == ndims + 2:
        ddf = ddf[0]  # drop the batch axis -> [C, ...]

    if ndims not in (2, 3):
        raise ValueError(f"Unsupported ndims={ndims}")

    def _forward_diff(component, axis):
        # Append a copy of the last slice so the output keeps the input shape.
        index = [slice(None)] * component.ndim
        index[axis] = slice(-1, None)
        return np.diff(component, axis=axis, append=component[tuple(index)])

    # grad[c][a] = d(u_c) / d(x_a)
    grad = [
        [_forward_diff(ddf[comp], axis) for axis in range(ndims)]
        for comp in range(ndims)
    ]

    if ndims == 3:
        j11 = 1.0 + grad[0][0]
        j12 = grad[0][1]
        j13 = grad[0][2]
        j21 = grad[1][0]
        j22 = 1.0 + grad[1][1]
        j23 = grad[1][2]
        j31 = grad[2][0]
        j32 = grad[2][1]
        j33 = 1.0 + grad[2][2]

        # Cofactor expansion along the first row.
        detj = (
            j11 * (j22 * j33 - j23 * j32)
            - j12 * (j21 * j33 - j23 * j31)
            + j13 * (j21 * j32 - j22 * j31)
        )
    else:
        detj = (1.0 + grad[0][0]) * (1.0 + grad[1][1]) - grad[0][1] * grad[1][0]

    negative = np.sum(detj < 0)
    return 100.0 * float(negative) / float(detj.size)
332
+
333
+
334
+ # ========== Prepare evaluation structures ==========
335
+
336
# Names of the per-class overlap/surface metrics collected for every pair.
_metric_names = ("dsc", "asd", "hd")

# metrics[class_id][metric_name][pair_idx] = value after registration.
metrics = {cid: {name: {} for name in _metric_names} for cid in organ_label_ids}

# metrics_pre: same layout, but source vs target before any deformation.
metrics_pre = {cid: {name: {} for name in _metric_names} for cid in organ_label_ids}

# DDF quality per pair (not per class):
# pair_idx -> percentage of voxels with a negative Jacobian determinant.
negdetj_pct = {}

# Rows for the CSVs: (pair_idx, fixed_name, moving_name).
pair_info = []
352
+
353
+ # ========== Paired registration ==========
354
+
355
def _save_nifti(data, out_dir, filename):
    """Convert ``data`` with ``utils.converet_to_nibabel`` and save it as ``out_dir/filename``."""
    nib.save(
        utils.converet_to_nibabel(data, ndims=ndims),
        os.path.join(out_dir, filename),
    )


def _standardize_class_masks(label_map, reference_vol, class_ids):
    """Standardize one binary mask per class and stack them along the channel axis.

    ``om.set_init_img(reference_vol)`` is called first because, per the
    original author's note, ``om._standardize_label`` relies on the most
    recently set init image for its target shape.

    Args:
        label_map: multi-class numpy label map.
        reference_vol: the image volume this label map belongs to.
        class_ids: non-empty, ordered list of class ids (one channel each).

    Returns:
        ``(model_res_masks, full_res_masks)`` concatenated on ``dim=1``.
    """
    class_masks = split_label_classes(label_map, class_ids)
    om.set_init_img(reference_vol)
    model_parts = []
    fullres_parts = []
    for cid in class_ids:
        m_model, m_fullres = om._standardize_label(class_masks[cid])
        model_parts.append(m_model)
        fullres_parts.append(m_fullres)
    return torch.cat(model_parts, dim=1), torch.cat(fullres_parts, dim=1)


# Channel order used for every stacked mask tensor below.
_class_ids = sorted(organ_label_ids)

with torch.no_grad():
    for pair_idx, pair in enumerate(tqdm(pairs, desc="Pairs")):
        fixed_rel = pair["fixed"]
        moving_rel = pair["moving"]

        fixed_path = resolve_path(fixed_rel)
        moving_path = resolve_path(moving_rel)

        fixed_name = get_volume_name(fixed_rel)
        moving_name = get_volume_name(moving_rel)
        pair_tag = f"Tgt{pair_idx:04d}_Src{pair_idx:04d}"
        src_tag = f"Src{pair_idx:04d}"

        pair_info.append((pair_idx, fixed_name, moving_name))
        print(f"\n [{pair_idx}] Fixed: {fixed_name}, Moving: {moving_name}")

        # --- Load volumes ---
        fixed_vol = load_volume(fixed_path)
        moving_vol = load_volume(moving_path)

        # --- Load labels (if available) ---
        fixed_label_path = get_label_path_for_image(fixed_rel)
        moving_label_path = get_label_path_for_image(moving_rel)

        fixed_label_map = None
        moving_label_map = None
        if fixed_label_path is not None and os.path.exists(fixed_label_path):
            fixed_label_map = load_label(fixed_label_path)
        if moving_label_path is not None and os.path.exists(moving_label_path):
            moving_label_map = load_label(moving_label_path)

        # --- Prepare tensors via OMorpher ---
        # Moving image = source to be deformed.
        om.set_init_img(moving_vol)
        src_img_model = om._init_img.clone()
        src_img_fullres = om._init_img_raw.clone()
        src_orig_sz = list(src_img_fullres.shape[2:])

        # Fixed image = registration target.
        om.set_init_img(fixed_vol)
        tgt_img_model = om._init_img.clone()
        tgt_img_fullres = om._init_img_raw.clone()

        # Standardize labels through OMorpher. The `_class_ids` guard matters:
        # with label files present but no organ classes declared in the JSON,
        # torch.cat would be handed an empty list and raise.
        src_mask_model, src_mask_fullres = None, None
        tgt_mask_model, tgt_mask_fullres = None, None

        if moving_label_map is not None and _class_ids:
            src_mask_model, src_mask_fullres = _standardize_class_masks(
                moving_label_map, moving_vol, _class_ids,
            )

        if fixed_label_map is not None and _class_ids:
            tgt_mask_model, tgt_mask_fullres = _standardize_class_masks(
                fixed_label_map, fixed_vol, _class_ids,
            )

        # --- Save target (fixed) original at model resolution ---
        _save_nifti(tgt_img_model, reg_img_savepath, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_model is not None:
            _save_nifti(tgt_mask_model, reg_msk_savepath, f"{pair_tag}_TGT_ORG_GT.nii.gz")

        # --- Save source (moving) original at model resolution ---
        _save_nifti(src_img_model, reg_img_savepath, f"{src_tag}_ORG.nii.gz")
        if src_mask_model is not None:
            _save_nifti(src_mask_model, reg_msk_savepath, f"{src_tag}_ORG_GT.nii.gz")

        # --- Save target original at full resolution ---
        _save_nifti(tgt_img_fullres, reg_img_savepath_fullres, f"{pair_tag}_TGT_ORG.nii.gz")
        if tgt_mask_fullres is not None:
            _save_nifti(
                tgt_mask_fullres, reg_msk_savepath_fullres, f"{pair_tag}_TGT_ORG_GT.nii.gz",
            )

        # --- Save source original at full resolution ---
        _save_nifti(src_img_fullres, reg_img_savepath_fullres, f"{src_tag}_ORG.nii.gz")
        if src_mask_fullres is not None:
            _save_nifti(
                src_mask_fullres, reg_msk_savepath_fullres, f"{src_tag}_ORG_GT.nii.gz",
            )

        # --- Register moving to fixed ---
        om.set_init_img(src_img_model)
        om.set_cond_img(tgt_img_model.clone().detach())

        om.predict(
            T=[None, timesteps],
            proc_type=condition_type,
        )

        ddf_comp = om.get_def()

        # --- DDF quality: percent negative Jacobian determinant ---
        neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
        negdetj_pct[pair_idx] = neg_pct
        print(f" %|J|<0 = {neg_pct:.4f}%")

        # --- Model-resolution registered image ---
        img_rec = om.apply_def(
            img=src_img_model, ddf=ddf_comp, padding_mode="zeros",
        )
        _save_nifti(img_rec, reg_img_savepath, f"{pair_tag}.nii.gz")

        # --- Model-resolution registered mask ---
        msk_rec = None
        if src_mask_model is not None:
            msk_rec = om.apply_def(
                img=src_mask_model, ddf=ddf_comp,
                padding_mode="zeros", resample_mode="nearest",
            )
            _save_nifti(msk_rec, reg_msk_savepath, f"{pair_tag}_GT.nii.gz")

        # --- Model-resolution DDF ---
        _save_nifti(ddf_comp, reg_ddf_savepath, f"{pair_tag}.nii.gz")

        # --- Full-resolution registered image ---
        img_rec_fullres = om.apply_def(
            img=src_img_fullres, ddf=ddf_comp, padding_mode="border",
        )
        _save_nifti(img_rec_fullres, reg_img_savepath_fullres, f"{pair_tag}.nii.gz")

        # --- Full-resolution registered mask ---
        msk_rec_fullres = None
        if src_mask_fullres is not None:
            msk_rec_fullres = om.apply_def(
                img=src_mask_fullres, ddf=ddf_comp,
                padding_mode="zeros", resample_mode="nearest",
            )
            _save_nifti(msk_rec_fullres, reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz")

        # --- Full-resolution DDF (resampled to the source's full-res grid) ---
        ddf_fullres = F.interpolate(
            ddf_comp, size=src_orig_sz, mode="trilinear", align_corners=False,
        )
        _save_nifti(ddf_fullres, reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz")

        # --- Evaluation metrics (full-res organ labels) ---
        if (
            organ_label_ids
            and src_mask_fullres is not None
            and tgt_mask_fullres is not None
        ):
            for ch_idx, cid in enumerate(_class_ids):
                lk = organ_label_ids[cid]
                tgt_mask_np = tgt_mask_fullres[0, ch_idx].cpu().numpy()
                src_mask_np = src_mask_fullres[0, ch_idx].cpu().numpy()

                # An all-negative channel is skipped — presumably the -1
                # "label missing" placeholder from OMorpher; TODO confirm.
                if np.all(tgt_mask_np < 0) or np.all(src_mask_np < 0):
                    continue

                # Pre-registration: source vs target (no deformation)
                pre_dsc = compute_dsc(src_mask_np, tgt_mask_np)
                pre_asd = compute_asd(src_mask_np, tgt_mask_np)
                pre_hd = compute_hd(src_mask_np, tgt_mask_np)

                metrics_pre[cid]["dsc"][pair_idx] = pre_dsc
                metrics_pre[cid]["asd"][pair_idx] = pre_asd
                metrics_pre[cid]["hd"][pair_idx] = pre_hd

                # Post-registration: registered mask vs target
                if msk_rec_fullres is not None:
                    reg_mask_np = msk_rec_fullres[0, ch_idx].cpu().numpy()
                    post_dsc = compute_dsc(reg_mask_np, tgt_mask_np)
                    post_asd = compute_asd(reg_mask_np, tgt_mask_np)
                    post_hd = compute_hd(reg_mask_np, tgt_mask_np)
                else:
                    post_dsc = float("nan")
                    post_asd = float("nan")
                    post_hd = float("nan")

                metrics[cid]["dsc"][pair_idx] = post_dsc
                metrics[cid]["asd"][pair_idx] = post_asd
                metrics[cid]["hd"][pair_idx] = post_hd

                print(
                    f" [{lk}] PRE DSC={pre_dsc:.4f} ASD={pre_asd:.2f} HD={pre_hd:.2f}"
                )
                print(
                    f" [{lk}] POST DSC={post_dsc:.4f} ASD={post_asd:.2f} HD={post_hd:.2f}"
                )

print("\nPaired registration complete.")
589
+
590
+ # ========== Write evaluation CSVs ==========
591
+
592
+ n_pairs = len(pairs)
593
+
594
+ def _fmt(val):
595
+ if val is None:
596
+ return ""
597
+ if np.isnan(val):
598
+ return "NaN"
599
+ return f"{val:.6f}"
600
+
601
+
602
+ # --- Per-pair %|J|<0 CSV ---
603
def _mean_std(values):
    """Return ``(mean, std, count)`` over the non-NaN entries of ``values``.

    Mean and std are NaN when no non-NaN entry remains.
    """
    finite = [v for v in values if not np.isnan(v)]
    if not finite:
        return float("nan"), float("nan"), 0
    return np.mean(finite), np.std(finite), len(finite)


# --- Per-pair %|J|<0 CSV ---
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
with open(negdetj_csv_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["pair_idx", "fixed", "moving", "negdetj_pct"])
    writer.writerows(
        [pi, fixed_name, moving_name, _fmt(negdetj_pct.get(pi))]
        for pi, fixed_name, moving_name in pair_info
    )
print(f"Saved {negdetj_csv_path}")

# --- One CSV per (class, metric) with pre/post values for every pair ---
for cid in sorted(organ_label_ids.keys()):
    lk = organ_label_ids[cid]
    # File names carry the class name only when several classes are evaluated.
    prefix = f"{lk}_" if len(organ_label_ids) > 1 else ""

    for metric_name in ["dsc", "asd", "hd"]:
        mn_upper = metric_name.upper()
        pre_table = metrics_pre[cid][metric_name]
        post_table = metrics[cid][metric_name]
        csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
        with open(csv_path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([
                "pair_idx", "fixed", "moving",
                f"pre_{mn_upper}", f"post_{mn_upper}",
            ])
            # Pairs without a stored value get empty cells (see _fmt(None)).
            writer.writerows(
                [
                    pi, fixed_name, moving_name,
                    _fmt(pre_table.get(pi)), _fmt(post_table.get(pi)),
                ]
                for pi, fixed_name, moving_name in pair_info
            )
        print(f"Saved {csv_path}")

# --- Overall summary ---
overall_path = os.path.join(eval_dir, "overall.csv")
with open(overall_path, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow([
        "label", "metric",
        "pre_mean", "pre_std",
        "post_mean", "post_std",
        "n_pairs",
    ])

    # %|J|<0 summary (not per-label); it has no pre-registration counterpart.
    jac_mean, jac_std, jac_n = _mean_std(negdetj_pct.values())
    writer.writerow([
        "ALL",
        "%|J|<0",
        "", "",
        _fmt(jac_mean),
        _fmt(jac_std),
        jac_n,
    ])

    for cid in sorted(organ_label_ids.keys()):
        lk = organ_label_ids[cid]
        for metric_name in ["dsc", "asd", "hd"]:
            pre_mean, pre_std, pre_n = _mean_std(metrics_pre[cid][metric_name].values())
            post_mean, post_std, post_n = _mean_std(metrics[cid][metric_name].values())
            writer.writerow([
                lk,
                metric_name.upper(),
                _fmt(pre_mean), _fmt(pre_std),
                _fmt(post_mean), _fmt(post_std),
                max(pre_n, post_n),
            ])
print(f"Saved {overall_path}")