maxmo2009 committed on
Commit
75854b3
·
verified ·
1 Parent(s): ac27c09

Initial upload: OmniMorph codebase

Browse files
Files changed (48) hide show
  1. .gitattributes +13 -0
  2. .gitignore +29 -0
  3. Config/config_cmr.yaml +29 -0
  4. Config/config_lct.yaml +31 -0
  5. Config/config_om.yaml +53 -0
  6. Config/config_om_contrastive.yaml +51 -0
  7. Dataloader/PSMA-CT_mappings.json +3 -0
  8. Dataloader/bert_helper.py +258 -0
  9. Dataloader/dataLoader.py +1473 -0
  10. Dataloader/dataloader0.py +421 -0
  11. Dataloader/dataloader_tester.py +39 -0
  12. Dataloader/dataloader_utils.py +193 -0
  13. Dataloader/embding_gen.py +149 -0
  14. Dataloader/nifty_mappings/AbdomenAtlas_mappings.json +3 -0
  15. Dataloader/nifty_mappings/AbdomenCT1k_mappings.json +3 -0
  16. Dataloader/nifty_mappings/Brats2019_mappings.json +3 -0
  17. Dataloader/nifty_mappings/Brats2020_mappings.json +3 -0
  18. Dataloader/nifty_mappings/Brats2021_mappings.json +3 -0
  19. Dataloader/nifty_mappings/CIA_mappings.json +3 -0
  20. Dataloader/nifty_mappings/Kaggle_osic_mappings.json +0 -0
  21. Dataloader/nifty_mappings/MSD_mappings.json +3 -0
  22. Dataloader/nifty_mappings/MnMs_mappings.json +0 -0
  23. Dataloader/nifty_mappings/OASIS_1_mappings.json +3 -0
  24. Dataloader/nifty_mappings/OASIS_2_mappings.json +3 -0
  25. Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json +3 -0
  26. Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json +3 -0
  27. Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json +3 -0
  28. Diffusion/__init__.py +8 -0
  29. Diffusion/diffuser.py +531 -0
  30. Diffusion/losses.py +534 -0
  31. Diffusion/losses_ncc0.py +496 -0
  32. Diffusion/networks.py +1167 -0
  33. Diffusion/utils_diff.py +477 -0
  34. LICENSE +201 -0
  35. OM_aug.py +254 -0
  36. OM_aug_highres.py +233 -0
  37. OM_contrastive.py +72 -0
  38. OM_reg.py +240 -0
  39. OM_train.py +309 -0
  40. OM_train_2modes.py +528 -0
  41. OM_train_3modes.py +490 -0
  42. OM_train_uncon.py +258 -0
  43. README.md +11 -0
  44. bash_infer.sh +9 -0
  45. bash_train.sh +12 -0
  46. dataloader_tester.py +65 -0
  47. requirements.txt +57 -0
  48. utils.py +498 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Dataloader/PSMA-CT_mappings.json filter=lfs diff=lfs merge=lfs -text
37
+ Dataloader/nifty_mappings/AbdomenAtlas_mappings.json filter=lfs diff=lfs merge=lfs -text
38
+ Dataloader/nifty_mappings/AbdomenCT1k_mappings.json filter=lfs diff=lfs merge=lfs -text
39
+ Dataloader/nifty_mappings/Brats2019_mappings.json filter=lfs diff=lfs merge=lfs -text
40
+ Dataloader/nifty_mappings/Brats2020_mappings.json filter=lfs diff=lfs merge=lfs -text
41
+ Dataloader/nifty_mappings/Brats2021_mappings.json filter=lfs diff=lfs merge=lfs -text
42
+ Dataloader/nifty_mappings/CIA_mappings.json filter=lfs diff=lfs merge=lfs -text
43
+ Dataloader/nifty_mappings/MSD_mappings.json filter=lfs diff=lfs merge=lfs -text
44
+ Dataloader/nifty_mappings/OASIS_1_mappings.json filter=lfs diff=lfs merge=lfs -text
45
+ Dataloader/nifty_mappings/OASIS_2_mappings.json filter=lfs diff=lfs merge=lfs -text
46
+ Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json filter=lfs diff=lfs merge=lfs -text
47
+ Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json filter=lfs diff=lfs merge=lfs -text
48
+ Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model checkpoints
2
+ Models/
3
+
4
+ # Data files
5
+ Data/
6
+
7
+ # Python cache
8
+ __pycache__/
9
+
10
+ # Virtual environment
11
+ ominenv/
12
+
13
+ # External libraries
14
+ External/
15
+
16
+ # Logs
17
+ Log/
18
+ swanlog/
19
+ train_log.txt
20
+ aug_log.txt
21
+
22
+ # Reference implementation
23
+ def_diff_rec/
24
+
25
+ # IDE
26
+ .vscode/
27
+
28
+ # Misc
29
+ CLAUDE.md
Config/config_cmr.yaml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_name: cmr
2
+ net_name: recresacnet
3
+ ndims: 2
4
+ img_size: 256
5
+ batchsize: 1
6
+ ddf_pad_mode: border
7
+ device: cuda
8
+ img_pad_mode: zeros
9
+ num_input_chn: 1
10
+ padding_mode: zeros
11
+ resample_mode: bicubic
12
+ timesteps: 80
13
+ v_scale: 4.0e-05
14
+ # =========================
15
+ # TRAINING SETTING
16
+ epoch: 10000
17
+ epoch_per_save: 1
18
+ lr: 0.0001
19
+ noise_scale: 0.1
20
+ # =========================
21
+ # AUGMENTATION SETTING
22
+ patients_list: []
23
+ model_id_str: '000000'
24
+ start_noise_step: 48
25
+ noise_step: 2
26
+ aug_coe: 32 # how many times each sample will be augmented
27
+ aug_img_savepath: Data/Aug_data/cmr/img/
28
+ aug_msk_savepath: Data/Aug_data/cmr/msk/
29
+ aug_ddf_savepath: Data/Aug_data/cmr/ddf/
Config/config_lct.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_name: lct
2
+ net_name: recmutattnnet
3
+ # net_name: recresacnet
4
+ ndims: 3
5
+ img_size: 128 #was 128
6
+ batchsize: 2
7
+ ddf_pad_mode: border
8
+ device: cuda
9
+ img_pad_mode: zeros
10
+ num_input_chn: 1
11
+ padding_mode: border
12
+ resample_mode: bilinear
13
+ timesteps: 80
14
+ v_scale: 4.0e-05
15
+ # =========================
16
+ # TRAINING SETTING
17
+ epoch: 10000
18
+ epoch_per_save: 1
19
+ lr: 0.00001
20
+ noise_scale: 0.1
21
+ # =========================
22
+ # AUGMENTATION SETTING
23
+ patients_list: []
24
+ model_id_str: '001157'
25
+ start_noise_step: 64
26
+ noise_step: 1
27
+ aug_coe: 32 # how many times each sample will be augmented
28
+ condition_type: 'project' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
29
+ aug_img_savepath: Data/Aug_data/lct/img/
30
+ aug_msk_savepath: Data/Aug_data/lct/msk/
31
+ aug_ddf_savepath: Data/Aug_data/lct/ddf/
Config/config_om.yaml ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_name: all
2
+ # net_name: recresacnet
3
+ net_name: recmutattnnet
4
+ # net_name: recmutattnnet1
5
+ # net_name: defrecmutattnnet
6
+ ndims: 3
7
+ img_size: 128
8
+ batchsize: 2
9
+ ddf_pad_mode: border
10
+ device: cuda
11
+ img_pad_mode: zeros
12
+ num_input_chn: 1
13
+ padding_mode: border
14
+ resample_mode: bilinear
15
+ timesteps: 80
16
+ v_scale: 5.0e-05
17
+ # =========================
18
+ # TRAINING SETTING
19
+ epoch: 10000
20
+ epoch_per_save: 1
21
+ lr: 0.00001
22
+ noise_scale: 0.1
23
+ # =========================
24
+ # AUGMENTATION SETTING
25
+ patients_list: []
26
+ # model_id_str: '000000'
27
+ # model_id_str: '000180' # before registration training
28
+ # model_id_str: '000353' # good augmentation results on msd
29
+ model_id_str: '000354' #
30
+ # model_id_str: '000157'
31
+ # model_id_str: '000171'
32
+ start_noise_step: 48 # starting from which noise step to add noise
33
+ noise_step: 1
34
+ aug_coe: 64 # how many times each sample will be augmented
35
+ # start_noise_step: 56 # starting from which noise step to add noise
36
+ # noise_step: 4
37
+ # aug_coe: 4 # how many times each sample will be augmented
38
+ condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
39
+ # aug_img_savepath: Data/Aug_data/totseg/img/
40
+ # aug_msk_savepath: Data/Aug_data/totseg/msk/
41
+ # aug_ddf_savepath: Data/Aug_data/totseg/ddf/
42
+ # aug_img_savepath: Data/Aug_data/om/img/
43
+ # aug_msk_savepath: Data/Aug_data/om/msk/
44
+ # aug_ddf_savepath: Data/Aug_data/om/ddf/
45
+ reg_img_savepath: Data/Reg_data/om/img/
46
+ reg_msk_savepath: Data/Reg_data/om/msk/
47
+ reg_ddf_savepath: Data/Reg_data/om/ddf/
48
+ # aug_img_savepath: Data/Aug_data/msd/img/
49
+ # aug_msk_savepath: Data/Aug_data/msd/msk/
50
+ # aug_ddf_savepath: Data/Aug_data/msd/ddf/
51
+ aug_img_savepath: Data/Aug_data/mnms/img/
52
+ aug_msk_savepath: Data/Aug_data/mnms/msk/
53
+ aug_ddf_savepath: Data/Aug_data/mnms/ddf/
Config/config_om_contrastive.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data_name: all
2
+ # net_name: recresacnet
3
+ # net_name: recmutattnnet
4
+ net_name: recmutattnnet_contrastive
5
+ # net_name: recmutattnnet1
6
+ # net_name: defrecmutattnnet
7
+ ndims: 3
8
+ img_size: 128
9
+ batchsize: 1 #1 for testing
10
+ ddf_pad_mode: border
11
+ device: cuda
12
+ img_pad_mode: zeros
13
+ num_input_chn: 1
14
+ padding_mode: border
15
+ resample_mode: bilinear
16
+ timesteps: 80
17
+ v_scale: 5.0e-05
18
+ # =========================
19
+ # TRAINING SETTING
20
+ epoch: 10000
21
+ epoch_per_save: 1
22
+ lr: 0.00001
23
+ noise_scale: 0.1
24
+ # =========================
25
+ # AUGMENTATION SETTING
26
+ patients_list: []
27
+ # model_id_str: '000000'
28
+ # model_id_str: '000180' # before registration training
29
+ # model_id_str: '000353' # good augmentation results on msd
30
+ model_id_str: '000354' #
31
+ # model_id_str: '000157'
32
+ # model_id_str: '000171'
33
+ start_noise_step: 48 # starting from which noise step to add noise
34
+ noise_step: 1
35
+ aug_coe: 64 # how many times each sample will be augmented
36
+ # start_noise_step: 56 # starting from which noise step to add noise
37
+ # noise_step: 4
38
+ # aug_coe: 4 # how many times each sample will be augmented
39
+ condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
40
+ # aug_img_savepath: Data/Aug_data/totseg/img/
41
+ # aug_msk_savepath: Data/Aug_data/totseg/msk/
42
+ # aug_ddf_savepath: Data/Aug_data/totseg/ddf/
43
+ # aug_img_savepath: Data/Aug_data/om/img/
44
+ # aug_msk_savepath: Data/Aug_data/om/msk/
45
+ # aug_ddf_savepath: Data/Aug_data/om/ddf/
46
+ reg_img_savepath: Data/Reg_data/om/img/
47
+ reg_msk_savepath: Data/Reg_data/om/msk/
48
+ reg_ddf_savepath: Data/Reg_data/om/ddf/
49
+ aug_img_savepath: Data/Aug_data/msd/img/
50
+ aug_msk_savepath: Data/Aug_data/msd/msk/
51
+ aug_ddf_savepath: Data/Aug_data/msd/ddf/
Dataloader/PSMA-CT_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fbbdc9b4b48688a37c4f828eea2823820a1ee27f954d5987d8cbf3b67d6d9bf
3
+ size 179285490
Dataloader/bert_helper.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import math
4
+ from torch.nn import Tanh, BatchNorm1d
5
+ from typing import Optional
6
+ import torch.nn as nn
7
+ import torch
8
+ from transformers import BertModel, BertForSequenceClassification
9
+ from transformers import BertTokenizer
10
+ from transformers import AutoTokenizer, AutoModel
11
+
12
+ from torch.utils.data import Dataset as Dataset_n
13
+ from torch.utils.data import DataLoader as DataLoader_n
14
+ from torch.utils.data import WeightedRandomSampler
15
+
16
+ def _freeze_bert(
17
+ bert_model: BertModel, freeze_bert=True, freeze_layer_count=-1
18
+ ):
19
+ """Freeze parameters in BertModel (in place)
20
+ Args:
21
+ bert_model: HuggingFace bert model
22
+ freeze_bert: Bool whether to freeze the bert model
23
+ freeze_layer_count: If freeze_bert, up to what layer to freeze.
24
+ Returns:
25
+ bert_model
26
+ """
27
+ if freeze_bert:
28
+ # freeze the entire bert model
29
+ for param in bert_model.parameters():
30
+ param.requires_grad = False
31
+ else:
32
+ # freeze the embeddings
33
+ for param in bert_model.embeddings.parameters():
34
+ param.requires_grad = False
35
+ if freeze_layer_count != -1:
36
+ if freeze_layer_count > 0 :
37
+ # freeze layers in bert_model.encoder
38
+ for layer in bert_model.encoder.layer[:freeze_layer_count]:
39
+ for param in layer.parameters():
40
+ param.requires_grad = False
41
+
42
+ if freeze_layer_count < 0 :
43
+ # freeze layers in bert_model.encoder
44
+ for layer in bert_model.encoder.layer[freeze_layer_count:]:
45
+ for param in layer.parameters():
46
+ param.requires_grad = False
47
+ return None
48
+
49
def get_frozen_embeder(key_word="bert-large-uncased"):
    """Load a pretrained encoder and its tokenizer, with the encoder frozen.

    Args:
        key_word: Model identifier or local path handed to
            ``AutoTokenizer.from_pretrained`` / ``AutoModel.from_pretrained``.

    Returns:
        Tuple ``(model, tokenizer)``; every parameter of ``model`` has
        ``requires_grad`` set to False.
    """
    tokenizer = AutoTokenizer.from_pretrained(key_word, do_lower_case=False)
    encoder = AutoModel.from_pretrained(key_word)

    # Freeze the whole encoder so it is used purely as a feature extractor.
    for weight in encoder.parameters():
        weight.requires_grad = False
    return encoder, tokenizer
55
+
56
+
57
def str2emb(string, max_words_num=100, embeder=None, tokenizer=None, reduce_method='mean'):
    """Embed a text string with the given tokenizer and encoder.

    The string is lower-cased, tokenized with padding/truncation to
    ``max_words_num`` tokens, and passed through ``embeder``.

    Args:
        string: Text to embed.
        max_words_num: ``max_length`` handed to the tokenizer.
        embeder: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Callable producing the keyword arguments for ``embeder``.
        reduce_method: ``'mean'`` or ``'max'`` reduces ``last_hidden_state``
            over dim 1; any other value returns it unreduced.

    Returns:
        The (optionally reduced) ``last_hidden_state`` tensor.
    """
    encoded = tokenizer(
        string.lower(),
        return_tensors='pt',
        max_length=max_words_num,
        padding='max_length',
        truncation=True,
    )
    hidden = embeder(**encoded).last_hidden_state

    if reduce_method == 'mean':
        return torch.mean(hidden, dim=1)
    if reduce_method == 'max':
        # torch.max with dim returns (values, indices); keep the values only
        return torch.max(hidden, dim=1)[0]
    return hidden
69
+
70
def get_synonyms_dict(dict_type=None):
    '''
    Get the dictionary of synonyms for the specified dictionary type.

    Args:
        dict_type: One of 'ROI', 'Label_tissue', 'Task' or 'Modality'. Any
            other value (including None) selects the generic metadata
            dictionary whose terms are wrapped in single quotes.

    Returns:
        dict mapping each standard term to the list of synonyms that should be
        normalised to it. Dictionary order matters to callers that return the
        first match, so more specific (multi-region) entries come first.
    '''
    if dict_type == 'ROI':
        dict_synonyms = {
            'whole-body': ['whole-body', 'whole body', 'wholebody', 'whole body', 'whole-body', 'whole body', 'wholebody','polytrauma','head-neck-thorax-abdomen-pelvis-leg','head-neck-thorax-abdomen-pelvis'],
            'neck-thorax-abdomen-pelvis-leg': ['neck-thorax-abdomen-pelvis-leg','neck-thx-abd-pelvis-leg', 'angiography neck-thx-abd-pelvis-leg', 'neck thorax abdomen pelvis leg', 'neck and thorax and abdomen and pelvis and leg', 'neck, thorax, abdomen, pelvis & leg', 'neck/thorax/abdomen/pelvis/leg', 'neck, thorax, abdomen, pelvis and leg', 'neck thorax abdomen pelvis leg'],
            'neck-thorax-abdomen-pelvis': ['neck-thorax-abdomen-pelvis', 'neck-thx-abd-pelvis', 'neck thorax abdomen pelvis', 'neck and thorax and abdomen and pelvis', 'neck, thorax, abdomen & pelvis', 'neck/thorax/abdomen/pelvis', 'neck, thorax, abdomen and pelvis', 'neck thorax abdomen & pelvis'],
            'thorax-abdomen-pelvis-leg': ['thorax-abdomen-pelvis-leg','thx-abd-pelvis-leg', 'angiography thx-abd-pelvis-leg', 'thorax abdomen pelvis leg', 'thorax and abdomen and pelvis and leg', 'thorax, abdomen, pelvis & leg', 'thorax/abdomen/pelvis/leg', 'thorax, abdomen, pelvis and leg', 'thorax abdomen pelvis leg'],
            'neck-thorax-abdomen': ['neck-thorax-abdomen', 'neck-thorax-abdomen', 'neck thorax abdomen', 'neck and thorax and abdomen', 'neck, thorax, abdomen', 'neck/thorax/abdomen', 'neck, thorax, abdomen', 'neck thorax abdomen'],
            'head-neck-thorax-abdomen': ['head-neck-thorax-abdomen', 'head-neck-thorax-abdomen', 'head neck thorax abdomen', 'head and neck and thorax and abdomen', 'head, neck, thorax, abdomen', 'head/thorax/abdomen', 'head, thorax, abdomen', 'head thorax abdomen'],
            'head-neck-thorax': ['head-neck-thorax', 'head neck thorax', 'head and neck and thorax', 'head, neck, thorax', 'head/thorax', 'head, thorax', 'head thorax'],
            'thorax-abdomen-pelvis': ['thorax-abdomen-pelvis', 'thx-abd-pelvis', 'polytrauma', 'thorax abdomen pelvis', 'thorax and abdomen and pelvis', 'thorax, abdomen & pelvis', 'thorax/abdomen/pelvis', 'thorax, abdomen and pelvis', 'thorax abdomen & pelvis'],
            'abdomen-pelvis-leg': ['abdomen-pelvis-leg', 'angiography abdomen-pelvis-leg', 'abd-pelvis-leg', 'abdomen pelvis leg', 'abdomen and pelvis and leg', 'abdomen, pelvis & leg', 'abdomen/pelvis/leg', 'abdomen, pelvis, leg', 'abdomen pelvis leg'],
            'neck-thorax': ['neck-thorax', 'neck thorax', 'neck and thorax', 'neck, thorax', 'thorax-neck', 'thorax neck', 'thorax and neck', 'thorax, neck','thorax/neck'],
            'thorax-abdomen': ['thorax-abdomen', 'thorax abdomen', 'thorax and abdomen', 'thorax, abdomen', 'aortic valve'],
            'abdomen-pelvis': ['abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis', 'abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis'],
            'pelvis-leg': ['pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg', 'pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg'],
            'head-neck': ['head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck', 'head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck'],
            'abdomen': ['abdomen', 'abdominal', 'belly', 'stomach', 'tummy', 'gut', 'guts', 'viscera', 'bowels', 'intestines', 'gastrointestinal', 'digestive', 'peritoneum','gastric', 'liver', 'spleen', 'pancreas','kidney','lumbar','renal','hepatic','splenic','pancreatic','intervention'],
            'thorax': ['chest', 'thorax', 'breast', 'lung', 'heart','heart-thorakale aorta', 'heart-thorakale', 'mediastinum', 'pleura', 'bronchus', 'bronchi', 'trachea', 'esophagus', 'diaphragm', 'rib', 'sternum', 'clavicle', 'scapula', 'axilla', 'armpit','breast biopsy','thoracic','mammary','caeiothoracic','mediastinal','pleural','bronchial','bronchial tree','tracheal','esophageal','diaphragmatic','costal','sternal','clavicular','scapular','axillary','axillar','cardiac','pericardial','pericardiac','pericardium'],
            'head': ['head', 'headbasis', 'brain', 'skull', 'face','nose','ear','eye','mouth','jaw','cheek','chin','forehead','temporal','parietal','occipital','frontal','mandible','maxilla','mandibular','maxillary','nasal','orbital','orbita','ocular','auricular','otic','oral','buccal','labial','lingual','palatal'],
            'neck': ['neck', 'throat', 'cervical', 'thyroid', 'trachea', 'larynx', 'pharynx', 'esophagus','pharyngeal','laryngeal','cervical','thyroid','trachea','esophagus','carotid','jugular'],
            'hand': ['hand', 'finger', 'thumb', 'palm', 'wrist', 'knuckle', 'fingernail', 'phalanx', 'metacarpal', 'carpal', 'radius'],
            # 'armpit' and 'clavicle' were previously fused into the single
            # bogus synonym 'armpitclavicle' by a missing comma.
            'arm': ['arm', 'forearm', 'upper arm', 'bicep', 'tricep', 'brachium', 'brachial', 'humerus', 'radius', 'ulna', 'elbow', 'shoulder', 'armpit', 'clavicle', 'scapula', 'acromion', 'acromioclavicular'],
            'leg': ['leg', 'felsenleg','thigh', 'calf', 'shin', 'knee', 'foot', 'ankle', 'toe', 'heel', 'sole', 'arch', 'instep', 'metatarsal', 'phalanx', 'tibia', 'fibula', 'femur', 'patella', 'kneecap','achilles tendon','achilles'],
            'pelvis': ['pelvis', 'hip', 'groin', 'buttock', 'gluteus', 'gluteal', 'ischium', 'pubis', 'sacrum', 'coccyx', 'acetabulum', 'iliac', 'iliac crest', 'iliac spine', 'iliac wing', 'sacroiliac', 'sacroiliac joint', 'sacroiliac ligament', 'sacroiliac spine', 'ureter', 'bladder', 'urethra', 'prostate', 'testicle', 'ovary', 'uterus',],
            'skeleton': ['skeleton','bone','spine', 'back', 'vertebra', 'sacrum', 'coccyx'],
        }
    elif dict_type == 'Label_tissue':
        dict_synonyms = {
            'liver': ['liver','hepatic'],
            'spleen': ['spleen','splenic'],
            'kidney': ['kidney','renal'],
            'pancreas': ['pancreas','pancreatic'],
            'stomach': ['stomach','gastric'],
            'intestine': ['large intestine', 'small intestine','large bowel','small bowel'],
            'gallbladder': ['gallbladder'],
            'adrenal_gland': ['adrenal_gland','adrenal gland'],
            'bladder': ['bladder'],
            'prostate': ['prostate'],
            'uterus': ['uterus'],
            'ovary': ['ovary'],
            'testicle': ['testicle'],
            'lymph_node': ['lymph_node','lymph node'],
            'bone': ['bone'],
            'lung': ['lung'],
            'heart': ['heart'],
            'esophagus': ['esophagus'],
            'muscle': ['muscle'],
            'fat': ['fat'],
            'skin': ['skin'],
            'vessel': ['vessel'],
            'tumor': ['tumor'],
            'other': ['other']
        }
    elif dict_type == 'Task':
        dict_synonyms = {
            'segmentation': ['segmentation', 'seg', 'mask'],
            'classification': ['classification', 'class', 'diagnosis','identify','identification'],
            'localization': ['localization', 'locate', 'location', 'position'],
            'registration': ['registration', 'register', 'align', 'alignment'],
            'detection': ['detection', 'detect', 'find', 'locate'],
            'quantification': ['quantification', 'quantify', 'measure', 'measurement'],
        }
    elif dict_type == 'Modality':
        dict_synonyms = {
            'CT': ['CT', 'computed tomography'],
            'MRI': ['MRI', 'MR', 'magnetic resonance imaging'],
            'PET': ['PET', 'positron emission tomography'],
            'US': ['US', 'ultrasound'],
            'X-ray': ['X-ray', 'radiography'],
            # spelling fixed: was 'tomlogy', which could never match real text
            'SPECT': ['SPECT', 'single-photon emission computed tomography'],
        }
    else:
        # Generic metadata terms; entries include the surrounding single
        # quotes, so they only match quoted tokens (e.g. in a repr'd dict).
        dict_synonyms = {
            '\'gender\'': ['\'gender\'', '\'sex\'', '\'M/F\'', '\'m/f\''],
            '\'modality\'': ['\'modality\'', '\'modal\''],
            '\'male\'': ['\'male\'', '\'m\''],
            '\'female\'': ['\'female\'', '\'f\'','\'woman\''],
            '\'high-grade glioma\'': ['\'high-grade glioma\'', '\'high grade glioma\'', '\'HGG\''],
            '\'low-grade glioma\'': ['\'low-grade glioma\'', '\'low grade glioma\'', '\'LGG\''],
            '\'atlas scaling factor\'': ['\'atlas scaling factor\'', '\'asf\''],
            '\'age\'': ['\'age\'', '\'years\'', '\'year\'', '\'y/o\'', '\'y.o.\''],
            '\'education\'': ['\'educ\'', '\'educat\'', '\'education\''],
            '\'roi\'': ['\'roi\'', '\'region of interest\'', '\'region\''],
            '\'mini-mental state examination\'': ['\'mini-mental state examination\'', '\'mmse\''],
            '\'clinical dementia rating\'': ['\'clinical dementia rating\'', '\'cdr\''],
            '\'socio-economic status\'': ['\'socio-economic status\'', '\'ses\''],
            '\'unknown\'': ['\'unknown\'', '\'unkn\'', '\'not available\'', '\'nan\'', '\'n/a\'', '\'none\'', '\'n.a.\'', '\'not applicable\'','\'not specified\'', '\'unspecified\'', '\'not given\'', '\'null\''],
            # empty standard term: these fragments are removed by replace_text
            '': [' segmentation', '\'seg\'', '\'registration\''],
        }
    return dict_synonyms
164
+
165
def replace_text(text, dict_synonyms):
    '''
    Replace every synonym occurring in ``text`` with its standard term.

    Args:
        text: A string, a list (processed element-wise, returning a new list)
            or a dict (processed in place: values are processed recursively
            and string keys are normalised like any other string). Any other
            type is returned unchanged.
        dict_synonyms: Mapping ``standard term -> list of synonyms``.

    Returns:
        The processed string / new list / same (mutated) dict object.
    '''
    if isinstance(text, str):
        for key, value in dict_synonyms.items():
            for v in value:
                if v.lower() in text.lower():
                    # NOTE(review): detection is case-insensitive but
                    # str.replace is case-sensitive, so a synonym that only
                    # matches in a different case is left untouched. Kept
                    # as-is to avoid changing existing results.
                    text = text.replace(v, key)
        return text
    if isinstance(text, list):
        return [replace_text(t, dict_synonyms) for t in text]
    if isinstance(text, dict):
        # Build the normalised mapping separately: renaming keys while
        # iterating the dict is unsafe, and the previous code additionally
        # tried to use the (unhashable) synonym list as the new key.
        renamed = {}
        for key, value in text.items():
            new_key = replace_text(key, dict_synonyms) if isinstance(key, str) else key
            # If two keys normalise to the same term, the later one wins.
            renamed[new_key] = replace_text(value, dict_synonyms)
        # Mutate in place so callers holding a reference see the update.
        text.clear()
        text.update(renamed)
    return text
186
+
187
+
188
def _match_standard_term(text, dict_synonyms):
    '''Return the first standard term with a synonym contained in ``text``
    (case-insensitive substring match), or None when nothing matches.'''
    lowered = text.lower()
    for key, value in dict_synonyms.items():
        for v in value:
            if v.lower() in lowered:
                return key
    return None


def replace_synonyms(text, dict_synonyms):
    '''
    Map ``text`` to the standard term of the first synonym it contains.

    Unlike ``replace_text``, a matching string is replaced as a whole by the
    standard term rather than being edited in place.

    Args:
        text: A string, a list (processed element-wise, returning a new list)
            or a dict (processed in place: values are processed recursively
            and string keys that match are renamed to their standard term).
            Any other type is returned unchanged.
        dict_synonyms: Mapping ``standard term -> list of synonyms``; the
            first matching entry in dictionary order wins.

    Returns:
        The standard term for a matching string; the original string (after
        emitting a warning) when nothing matches; a new list; or the same
        (mutated) dict object.
    '''
    if isinstance(text, str):
        match = _match_standard_term(text, dict_synonyms)
        if match is not None:
            return match
        # The original built a Warning instance without emitting it.
        import warnings
        warnings.warn(f"Value {text} is not in the correct format", stacklevel=2)
        return text
    if isinstance(text, list):
        return [replace_synonyms(t, dict_synonyms) for t in text]
    if isinstance(text, dict):
        # Collect into a fresh mapping: the previous code popped keys while
        # iterating and used the synonym list itself (unhashable) as a key.
        renamed = {}
        for key, value in text.items():
            new_key = key
            if isinstance(key, str):
                match = _match_standard_term(key, dict_synonyms)
                if match is not None:
                    new_key = match
            # If two keys map to the same standard term, the later one wins.
            renamed[new_key] = replace_synonyms(value, dict_synonyms)
        # Mutate in place so callers holding a reference see the update.
        text.clear()
        text.update(renamed)
    return text
208
+
209
if __name__ == "__main__":
    # Manual smoke test: embed two metadata strings and print how far apart
    # they are in embedding space.
    # Other checkpoints that were tried here: "bert-base-uncased",
    # "bert-large-uncased", "Rostlab/prot_bert", "fspanda/Medical-Bio-BERT2",
    # "GerMedBERT/medbert-512".
    # NOTE(review): machine-specific absolute path; adjust before running.
    model_name = "/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased"

    reduce_method = 'mean'
    max_words_num = 32  # tokenizer max_length (padding/truncation target)

    embeder, tokenizer = get_frozen_embeder(model_name)

    string1 = "modality: ct, gender: female, age: 51, roi: abdomen"
    string2 = "modality: ct, gender: female, age: 50, roi: head"
    embeder_output1, embeder_output2 = [
        str2emb(sentence, max_words_num, embeder, tokenizer, reduce_method=reduce_method)
        for sentence in (string1, string2)
    ]

    input_size = embeder.config.vocab_size
    in_size = embeder.config.hidden_size

    print(embeder, input_size, in_size)
    print(tokenizer)

    # With reduce_method 'mean' (or 'max') str2emb reduces dim 1 of the
    # encoder output, so the printed tensors no longer have a token axis.
    for embedding in (embeder_output1, embeder_output2):
        print(embedding)
        print(embedding.shape)

    # element-wise absolute difference between the two sentence embeddings
    error = torch.abs(embeder_output1 - embeder_output2)
    print(error)
    print("Embedding distance between the two sentences: ")
    print(f"String1: {string1}")
    print(f"String2: {string2}")
    print(torch.mean(error))
    exit()
Dataloader/dataLoader.py ADDED
@@ -0,0 +1,1473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import json
4
+ import SimpleITK as sitk
5
+ import numpy as np
6
+ from skimage.transform import rescale, resize, downscale_local_mean
7
+ # from torchvision.transforms import v2
8
+ import sys
9
+ sys.path.append('./')
10
+ from Dataloader.dataloader_utils import *
11
+ import random
12
+
13
# Registry of dataset name -> nifti mapping JSON. Add your mapping files here.
# NOTE(review): these are machine-specific absolute paths under two different
# roots; they must be adapted when the repository lives elsewhere.
_MAPPING_DIR_DATA = '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/'
_MAPPING_DIR_JACHIN = '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/'

# (dataset name, file name) pairs, grouped by the directory they live in.
# 'Brats2019' (Brats2019_mappings.json) is intentionally left out.
_MAPPINGS_UNDER_DATA = (
    ('MSD', 'MSD_mappings.json'),
    ('TotalSegmentor', 'TotalSegmentorCT_MRI_mappings.json'),
    ('Kaggle_osic', 'Kaggle_osic_mappings.json'),
    ('CancerImageArchive', 'CIA_mappings.json'),
)
_MAPPINGS_UNDER_JACHIN = (
    ('MnMs', 'MnMs_mappings.json'),
    ('Brats2020', 'Brats2020_mappings.json'),
    ('Brats2021', 'Brats2021_mappings.json'),
    ('OASIS_1', 'OASIS_1_mappings.json'),
    ('OASIS_2', 'OASIS_2_mappings.json'),
    ('PSMA-FDG-PET-CT-LESION', 'PSMA-FDG-PET-CT-LESION_mappings.json'),
    ('PSMA-CT', 'PSMA-CT-Longitud_mappings.json'),
    ('AbdomenAtlas', 'AbdomenAtlas_mappings.json'),
    ('AbdomenCT1k', 'AbdomenCT1k_mappings.json'),
)

mapping_files = {}
for _name, _fname in _MAPPINGS_UNDER_DATA:
    mapping_files[_name] = _MAPPING_DIR_DATA + _fname
for _name, _fname in _MAPPINGS_UNDER_JACHIN:
    mapping_files[_name] = _MAPPING_DIR_JACHIN + _fname

# default clamp range for the images
CLAMP_RANGE = [-400, 400]

# single-region ROI names
indivi_ROI_list = [
    'abdomen', 'arm', 'brain', 'hand', 'head',
    'leg', 'neck', 'pelvis', 'skeleton', 'thorax',
]
41
+
42
def reverse_axis_order(arr):
    """Reverse the axis order of ``arr`` (SimpleITK <-> NumPy convention).

    The fully transposed array is passed through ``np.ascontiguousarray``,
    so the result is C-contiguous.
    """
    flipped = arr.T  # .T reverses all axes, same as transpose(ndim-1, ..., 0)
    return np.ascontiguousarray(flipped)
46
+
47
def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high'):
    """Sample from [low, high] by iterated uniform draws that tighten one bound.

    Args:
        high (float): Upper bound of the sampling interval.
        low (float): Lower bound of the sampling interval.
        order_num (int): Number of successive uniform draws.
        type (str): 'high' or 'low', determines which bound is tightened.

    Returns:
        sample_value (float): The value of the last draw.

    Raises:
        ValueError: If ``type`` is neither 'high' nor 'low'.

    Notes:
        - If type is 'high', each draw is uniform on [previous sample, high],
          starting from ``low``; larger ``order_num`` pushes the result
          towards ``high``.
        - If type is 'low', each draw is uniform on [low, previous sample],
          starting from ``high``; larger ``order_num`` pushes the result
          towards ``low``.
        - If order_num is 0, no draw happens: ``low`` is returned for
          type 'high' and ``high`` for type 'low'.
        - If order_num is 1, the result is a single uniform sample.
    """
    if type == 'high':
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
    elif type == 'low':
        sample_value = high
        for _ in range(order_num):
            sample_value = np.random.uniform(low, high=sample_value)
    else:
        # Previously fell through to an UnboundLocalError on return.
        raise ValueError(f"type must be 'high' or 'low', got {type!r}")
    return sample_value
76
+
77
class OminiDataset(object):
    """Base class for OmniMorph datasets.

    Holds the merged metadata of all mapping files (``self.ALLdata``) together
    with filtering, normalisation and cropping helpers.

    NOTE(review): set-up lives in ``init`` (not ``__init__``), so it has to be
    called explicitly after construction. ``init`` does not create
    ``self.ALLdata_filtered``; ``get_all_ROI`` and ``get_filter_ROIs`` read it,
    so it must be assigned (e.g. from ``get_filter_mindim``) before use.
    """

    def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,
             reverse_axis_order, min_dim, mapping_files):
        """Load the metadata and store the dataset configuration.

        Args:
            out_sz: Default cube side used by ``random_crop_3d`` and the size
                threshold base of ``get_filter_mindim``.
            transform: Stored as ``self.transform``; not used in this class.
            clamp_range: Stored as ``self.clamp_range``; not used in this class.
            min_crop_ratio: Stored as ``self.min_crop_ratio``.
            ROIs: Collection of ROI labels read by ``get_filter_ROIs``.
            modality: Stored as ``self.modality``; not used in this class.
            reverse_axis_order: If truthy, ``get_3D_volume`` reverses the axes.
            min_dim: Stored as ``self.min_dim``; not used in this class.
            mapping_files: Dict of dataset name -> JSON mapping file path.
        """
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_dim = min_dim
        self.clamp_range = clamp_range
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        # Both were accepted but silently dropped before; get_filter_ROIs
        # needs self.ROIs.
        self.ROIs = ROIs
        self.modality = modality
        self.ndims = 3

    def get_ALLdata(self):
        """Return the unfiltered metadata dictionary."""
        return self.ALLdata

    def get_all_ROI(self):
        """Return the set of distinct ROI labels among the filtered entries."""
        return {self.ALLdata[k]['ROI'] for k in self.ALLdata_filtered}

    def get_filter_ROIs(self, keep_single_roi=False):
        """Return the filtered entries whose ROI is listed in ``self.ROIs``.

        Args:
            keep_single_roi: Currently unused; kept for interface compatibility.

        Returns:
            A new dict; ``self.ALLdata_filtered`` is left untouched.
        """
        # Build a new dict instead of deleting from the one being iterated,
        # which raised "dictionary changed size during iteration".
        return {k: v for k, v in self.ALLdata_filtered.items() if v['ROI'] in self.ROIs}

    def combine_data(self, mappings = mapping_files):
        """Merge every JSON mapping file into a single dictionary.

        Args:
            mappings: Dict of dataset name -> path of a JSON mapping file.

        Returns:
            Dict with the entries of all files; on duplicate keys the file
            read later wins.
        """
        ALLdata = {}
        for path in mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def get_3D_volume(self, volume, select_channel = None):
        """Return a 3D volume, picking one channel when the input is 4D.

        Args:
            volume: NumPy array with 3 or 4 dimensions.
            select_channel: Index along the last axis of a 4D input; drawn at
                random when None.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # itself so the last channel can be drawn and a single-channel
                # volume does not raise.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Return the entries whose smallest spatial size is >= ``out_sz / 2``."""
        threshold = self.out_sz / 2
        return {
            k: v for k, v in self.ALLdata.items()
            if not (min(v['Size'][:self.ndims]) < threshold)
        }

    def normalize(self, volume, eps=1e-7):
        """Min-max scale `volume` to [0, 1) as float64 (`eps` avoids 0-division)."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of side `crop_size` out of a 3D volume.

        Axes shorter than `crop_size` are zero-padded first, with the padding
        split randomly between the two sides of the axis.

        Args:
            volume: 3D NumPy array.
            crop_size: Side length of the cube; defaults to ``self.out_sz``.

        Returns:
            Array of shape ``(crop_size, crop_size, crop_size)``.
        """
        d, h, w = volume.shape  # also rejects non-3D input
        if crop_size is None:
            crop_size = self.out_sz

        deficits = [max(0, crop_size - n) for n in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: both sides must add up to exactly
                # `deficit`, otherwise the padded axis can end up shorter than
                # `crop_size` and the crop comes out too small.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        starts = [np.random.randint(0, n - crop_size + 1) if n > crop_size else 0
                  for n in volume.shape]
        s_d, s_h, s_w = starts
        return volume[s_d:s_d + crop_size, s_h:s_h + crop_size, s_w:s_w + crop_size]
175
+
176
class OminiDataset_v1(Dataset):
    """Dataset yielding single, randomly cropped and resized 3D volumes.

    Metadata comes from the module-level ``mapping_files``; each key of the
    merged dict is passed to ``sitk.ReadImage`` when an item is loaded.
    """

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.2, reverse_axis_order = False):
        """Load the metadata and pre-filter it by minimum size.

        Args:
            out_sz: Side length of the returned cube.
            transform: Optional callable applied to the final volume.
            clamp_range: [min, max] clip window for "CT" volumes, or None.
            min_crop_ratio: Lower bound of the random crop ratio; None
                disables the ratio crop and crops ``out_sz`` directly.
            reverse_axis_order: If truthy, axes are reversed after loading.
        """
        self.mappings = mapping_files
        self.ALLdata = self.combine_data()
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering starts here.
        self.ALLdata_filtered = self.get_filter_mindim()

    def find_min_dim(self):
        """Return the smallest 'Size' value over all entries (capped at 100000)."""
        min_dim = 100000
        for meta in self.ALLdata.values():
            min_dim = min(min_dim, min(meta['Size']))
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of side `crop_size` out of a 3D volume.

        Axes shorter than `crop_size` are zero-padded first, with the padding
        split randomly between the two sides of the axis.

        Args:
            volume: 3D NumPy array.
            crop_size: Side length of the cube; defaults to ``self.out_sz``.

        Returns:
            Array of shape ``(crop_size, crop_size, crop_size)``.
        """
        d, h, w = volume.shape  # also rejects non-3D input
        if crop_size is None:
            crop_size = self.out_sz

        deficits = [max(0, crop_size - n) for n in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: both sides must add up to exactly
                # `deficit`, otherwise the padded axis can end up shorter than
                # `crop_size` and the crop comes out too small.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        starts = [np.random.randint(0, n - crop_size + 1) if n > crop_size else 0
                  for n in volume.shape]
        s_d, s_h, s_w = starts
        return volume[s_d:s_d + crop_size, s_h:s_h + crop_size, s_w:s_w + crop_size]

    def get_ALLdata(self):
        """Return the unfiltered metadata dictionary."""
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel = None):
        """Return a 3D volume, picking one channel when the input is 4D.

        Args:
            volume: NumPy array with 3 or 4 dimensions.
            select_channel: Index along the last axis of a 4D input; drawn at
                random when None.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # itself so the last channel can be drawn and a single-channel
                # volume does not raise.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        """Return the entries of ``self.ALLdata`` whose ROI contains `key_word`."""
        # The ROI lives in the metadata record, not in the key string itself
        # (indexing the key with "ROI" raised a TypeError).
        return {k: meta for k, meta in self.ALLdata.items() if key_word in meta["ROI"]}

    def get_filter_mindim(self):
        """Return the entries whose smallest spatial size is >= ``out_sz / 2``."""
        threshold = self.out_sz / 2
        return {
            k: v for k, v in self.ALLdata.items()
            if not (min(v['Size'][:self.ndims]) < threshold)
        }

    def combine_data(self):
        """Merge every JSON file listed in ``self.mappings`` into one dict.

        On duplicate keys the file read later wins.
        """
        ALLdata = {}
        for path in self.mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def __len__(self):
        """Return the number of entries that survived filtering."""
        return len(self.ALLdata_filtered.keys())

    def normalize(self, volume, eps=1e-7):
        """Min-max scale `volume` to [0, 1) as float64 (`eps` avoids 0-division)."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def __getitem__(self, idx):
        """Load, clamp, normalise and crop the volume at position `idx`.

        Returns:
            The volume with a leading channel axis, i.e. shape
            ``(1, out_sz, out_sz, out_sz)``, or ``self.transform(volume)``
            when a transform is set.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        volume = sitk.GetArrayFromImage(sitk.ReadImage(key))
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            # Crop a cube sized relative to the largest axis, then resample it
            # to the output size.
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return volume
311
+
312
class OMDataset_indiv(Dataset):
    """Dataset yielding a single cropped 3D volume plus its 'embd' vector.

    Metadata comes from the module-level ``mapping_files``; each key of the
    merged dict is passed to ``sitk.ReadImage`` when an item is loaded.
    """

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, reverse_axis_order = False):
        """Load the metadata and pre-filter it by minimum size.

        Args:
            out_sz: Side length of the returned cube.
            transform: Optional callable applied to the final volume.
            clamp_range: [min, max] clip window for "CT" volumes, or None.
            min_crop_ratio: Lower bound of the random crop ratio; None
                disables the ratio crop and crops ``out_sz`` directly.
            reverse_axis_order: If truthy, axes are reversed after loading.
        """
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.max_sz = out_sz*8  # upper bound for the ratio-based crop size
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        # Filtering starts here.
        print(f"Diffusion mode: Total data size before filtering: {len(self.ALLdata)}")
        self.ALLdata_filtered = self.get_filter_mindim()
        print(f"Diffusion mode: Filtered data size: {len(self.ALLdata_filtered)}")

    def find_min_dim(self):
        """Return the smallest 'Size' value over all entries (capped at 100000)."""
        min_dim = 100000
        for meta in self.ALLdata.values():
            min_dim = min(min_dim, min(meta['Size']))
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of side `crop_size` out of a 3D volume.

        Axes shorter than `crop_size` are zero-padded first, with the padding
        split randomly between the two sides of the axis.

        Args:
            volume: 3D NumPy array.
            crop_size: Side length of the cube; defaults to ``self.out_sz``.

        Returns:
            Array of shape ``(crop_size, crop_size, crop_size)``.
        """
        d, h, w = volume.shape  # also rejects non-3D input
        if crop_size is None:
            crop_size = self.out_sz

        deficits = [max(0, crop_size - n) for n in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: both sides must add up to exactly
                # `deficit`, otherwise the padded axis can end up shorter than
                # `crop_size` and the crop comes out too small.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        starts = [np.random.randint(0, n - crop_size + 1) if n > crop_size else 0
                  for n in volume.shape]
        s_d, s_h, s_w = starts
        return volume[s_d:s_d + crop_size, s_h:s_h + crop_size, s_w:s_w + crop_size]

    def get_ALLdata(self):
        """Return the unfiltered metadata dictionary."""
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel = None):
        """Return a 3D volume, picking one channel when the input is 4D.

        Args:
            volume: NumPy array with 3 or 4 dimensions.
            select_channel: Index along the last axis of a 4D input; drawn at
                random when None.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # itself so the last channel can be drawn and a single-channel
                # volume does not raise.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        """Return the entries of ``self.ALLdata`` whose ROI contains `key_word`."""
        # The ROI lives in the metadata record, not in the key string itself
        # (indexing the key with "ROI" raised a TypeError).
        return {k: meta for k, meta in self.ALLdata.items() if key_word in meta["ROI"]}

    def get_filter_mindim(self):
        """Return the entries whose smallest spatial size is >= ``out_sz / 2``."""
        threshold = self.out_sz / 2
        return {
            k: v for k, v in self.ALLdata.items()
            if not (min(v['Size'][:self.ndims]) < threshold)
        }

    def combine_data(self, mappings = mapping_files):
        """Merge every JSON mapping file into a single dictionary.

        Args:
            mappings: Dict of dataset name -> path of a JSON mapping file.

        Returns:
            Dict with the entries of all files; on duplicate keys the file
            read later wins.
        """
        ALLdata = {}
        for path in mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def __len__(self):
        """Return the number of entries that survived filtering."""
        return len(self.ALLdata_filtered.keys())

    def normalize(self, volume, eps=1e-7):
        """Min-max scale `volume` to [0, 1) as float64 (`eps` avoids 0-division)."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def __getitem__(self, idx):
        """Load the volume at position `idx` together with its embedding.

        Returns:
            ``[volume, embd]`` where `volume` has shape
            ``(1, out_sz, out_sz, out_sz)`` and `embd` is the entry's 'embd'
            value as a float32 array. When a transform is set, only
            ``self.transform(volume)`` is returned.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = np.array(self.ALLdata_filtered[key]['embd'], dtype=np.float32)

        volume = sitk.GetArrayFromImage(sitk.ReadImage(key))
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            # Crop a cube sized relative to the largest axis (capped at
            # max_sz), then resample it to the output size.
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            crop_size = min(crop_size, self.max_sz)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            # NOTE(review): this path drops `embd`, unlike the default return
            # below — confirm that callers using a transform expect that.
            return self.transform(volume)

        return [volume, embd]
456
+
457
class OminiDataset_paired(Dataset):
    """Dataset yielding pairs of volumes that share the same ROI label.

    Metadata comes from the module-level ``mapping_files``; each key of the
    merged dict is passed to ``sitk.ReadImage`` when an item is loaded.
    """

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.9, ROIs = None, modality = None, reverse_axis_order = False):
        """Load the metadata and filter it by size, modality and ROI.

        Args:
            out_sz: Side length of the returned cubes.
            transform: Optional callable applied to each final volume.
            clamp_range: [min, max] clip window for "CT" volumes, or None.
            min_crop_ratio: Lower bound of the random crop ratio; None
                disables the ratio crop and crops ``out_sz`` directly.
            ROIs: Collection of ROI labels to keep; None keeps every ROI
                present after the size/modality filters.
            modality: Collection of modalities to keep; None keeps all.
            reverse_axis_order: If truthy, axes are reversed after loading.
        """
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.sz_range = get_sizeRange_dict()
        self.min_dim_ratio = 0.5
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering starts here: size first, then modality, then ROI.
        self.ALLdata_filtered = self.get_filter_mindim()
        self.ALLdata_filtered = self.get_filter_modality(modality)
        if ROIs is None:
            # No ROIs given: use every ROI present in the filtered data.
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # Filtering ends here.

    def combine_data(self, mappings = mapping_files):
        """Merge every JSON mapping file into a single dictionary.

        Args:
            mappings: Dict of dataset name -> path of a JSON mapping file.

        Returns:
            Dict with the entries of all files; on duplicate keys the file
            read later wins.
        """
        ALLdata = {}
        for path in mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        """Min-max scale `volume` to [0, 1) as float64 (`eps` avoids 0-division)."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of side `crop_size` out of a 3D volume.

        Axes shorter than `crop_size` are zero-padded first, with the padding
        split randomly between the two sides of the axis.

        Args:
            volume: 3D NumPy array.
            crop_size: Side length of the cube; defaults to ``self.out_sz``.

        Returns:
            Array of shape ``(crop_size, crop_size, crop_size)``.
        """
        d, h, w = volume.shape  # also rejects non-3D input
        if crop_size is None:
            crop_size = self.out_sz

        deficits = [max(0, crop_size - n) for n in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: both sides must add up to exactly
                # `deficit`, otherwise the padded axis can end up shorter than
                # `crop_size` and the crop comes out too small.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        starts = [np.random.randint(0, n - crop_size + 1) if n > crop_size else 0
                  for n in volume.shape]
        s_d, s_h, s_w = starts
        return volume[s_d:s_d + crop_size, s_h:s_h + crop_size, s_w:s_w + crop_size]

    def get_all_ROI(self):
        """Return the set of distinct ROI labels among the filtered entries."""
        return {self.ALLdata[k]['ROI'] for k in self.ALLdata_filtered}

    def find_min_dim(self):
        """Return the smallest 'Size' value over all entries (capped at 100000)."""
        min_dim = 100000
        for meta in self.ALLdata.values():
            min_dim = min(min_dim, min(meta['Size']))
        return min_dim

    def get_ALLdata(self):
        """Return the unfiltered metadata dictionary."""
        return self.ALLdata

    def get_filter_modality(self, key_words=None):
        """Return the filtered entries whose "Modality" is in `key_words`.

        With ``key_words=None`` a plain copy of the filtered dict is returned.
        """
        if key_words is None:
            return self.ALLdata_filtered.copy()
        return {
            k: meta for k, meta in self.ALLdata_filtered.items()
            if meta["Modality"] in key_words
        }

    def get_filter_ROI(self, key_word):
        """Return the filtered entries whose ROI contains `key_word`."""
        # The ROI lives in the metadata record, not in the key string itself
        # (indexing the key with "ROI" raised a TypeError).
        return {
            k: meta for k, meta in self.ALLdata_filtered.items()
            if key_word in meta["ROI"]
        }

    def get_key_by_ROI(self, key_word):
        """Return the keys of the filtered entries whose ROI equals `key_word`."""
        return [k for k, meta in self.ALLdata_filtered.items() if key_word == meta["ROI"]]

    def get_filter_ROIs(self):
        """Return the filtered entries whose ROI is listed in ``self.ROIs``."""
        return {
            k: meta for k, meta in self.ALLdata_filtered.items()
            if meta['ROI'] in self.ROIs
        }

    def get_3D_volume(self, volume, select_channel = None):
        """Return a 3D volume, picking one channel when the input is 4D.

        Args:
            volume: NumPy array with 3 or 4 dimensions.
            select_channel: Index along the last axis of a 4D input; drawn at
                random when None.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # itself so the last channel can be drawn and a single-channel
                # volume does not raise.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Return the entries of ``self.ALLdata`` that pass the size checks.

        An entry is dropped when its smallest spatial size is below
        ``out_sz``, when smallest size * 'Spacing_mm' is below the lower bound
        of ``self.sz_range`` for its ROI, or when the ratio smallest/largest
        size is below ``self.min_dim_ratio``.
        """
        kept = {}
        for k, meta in self.ALLdata.items():
            img_sz = meta['Size'][:self.ndims]
            # Checks are chained with `or` so later lookups are skipped once
            # an entry is already rejected.
            # NOTE(review): assumes 'Spacing_mm' is a scalar — confirm.
            rejected = (
                min(img_sz) < self.out_sz
                or (min(img_sz)*meta['Spacing_mm']) < self.sz_range[meta['ROI']][0]
                or (min(img_sz)/max(img_sz) < self.min_dim_ratio)
            )
            if not rejected:
                kept[k] = meta
        return kept

    def __getitem__(self,idx):
        """Load the volume at `idx` and a random partner with the same ROI.

        Returns:
            ``(volume_A, volume_B)``, each of shape
            ``(1, out_sz, out_sz, out_sz)``; when a transform is set, it is
            applied to both volumes.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        volume_A = sitk.GetArrayFromImage(sitk.ReadImage(key))

        # The partner is drawn from all entries sharing A's ROI (A included).
        paired_keys = self.get_key_by_ROI(self.ALLdata_filtered[key]['ROI'])
        paired_key = random.choice(paired_keys)
        volume_B = sitk.GetArrayFromImage(sitk.ReadImage(paired_key))

        volume_A = self.get_3D_volume(volume_A)
        volume_B = self.get_3D_volume(volume_B)

        if self.clamp_range is not None:
            lo, hi = self.clamp_range[0], self.clamp_range[1]
            # Clamp each volume according to its OWN modality: the partner is
            # matched by ROI only, so it may differ in modality from A.
            if self.ALLdata_filtered[key].get("Modality", None) == "CT":
                volume_A = np.clip(volume_A, lo, hi)
            if self.ALLdata_filtered[paired_key].get("Modality", None) == "CT":
                volume_B = np.clip(volume_B, lo, hi)
        volume_A = self.normalize(volume_A)
        volume_B = self.normalize(volume_B)

        if self.min_crop_ratio is not None:
            # One shared ratio, applied to each volume's smallest axis.
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size_A = int(min(volume_A.shape) * crop_ratio)
            crop_size_B = int(min(volume_B.shape) * crop_ratio)
            volume_A = self.random_crop_3d(volume_A, crop_size_A)
            volume_B = self.random_crop_3d(volume_B, crop_size_B)
            volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
            volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume_A = self.random_crop_3d(volume_A, self.out_sz)
            volume_B = self.random_crop_3d(volume_B, self.out_sz)
        volume_A = volume_A[None, :, :, :]
        volume_B = volume_B[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume_A), self.transform(volume_B)

        return volume_A, volume_B

    def __len__(self):
        """Return the number of entries that survived filtering."""
        return len(self.ALLdata_filtered.keys())
675
+
676
+ class OMDataset_pair(Dataset):
677
+ def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = indivi_ROI_list, modality = None, reverse_axis_order = False):
678
+ # self.mappings = mapping_files
679
+ self.ALLdata = self.combine_data(mappings=mapping_files)
680
+ self.out_sz = out_sz
681
+ self.max_sz = out_sz*8
682
+ self.sz_range = get_sizeRange_dict()
683
+ self.min_dim_ratio = 0.7
684
+ self.reverse_axis_order = reverse_axis_order
685
+ self.min_crop_ratio = min_crop_ratio
686
+ self.transform = transform
687
+ self.clamp_range = clamp_range
688
+ self.ndims = 3
689
+ # Start you filtering here
690
+ # print(f"Number of images before filtering: {len(self.ALLdata.keys())}")
691
+ print(f"Registration mode: Total data size before filtering: {len(self.ALLdata)}")
692
+
693
+ self.ALLdata_filtered = self.get_filter_mindim()
694
+ # print(f"Number of images after filtering: {len(self.ALLdata_filtered.keys())}")
695
+ self.ALLdata_filtered = self.get_filter_modality(modality)
696
+ # print(f"Number of images after modality filtering: {len(self.ALLdata_filtered.keys())}")
697
+ if ROIs is None:# if no ROIs are provided, get all the ROIs from filtered data
698
+ self.ROIs = self.get_all_ROI()
699
+ else:
700
+ self.ROIs = ROIs
701
+ self.ALLdata_filtered = self.get_filter_ROIs()
702
+ print(f"Registration mode: Number of images after filtering: {len(self.ALLdata_filtered.keys())}")
703
+ # filtering ends here
704
+
705
+
706
+
707
+ def combine_data(self, mappings = mapping_files):
708
+ ALLdata = {}
709
+ for j in mappings.keys():
710
+ with open(mappings[j], 'r') as f:
711
+ mappings_tmp = json.load(f)
712
+ ALLdata.update(mappings_tmp)
713
+ return ALLdata
714
+
715
+ def normalize(self, volume, eps=1e-7):
716
+ # Normalize the image (0-1)
717
+ volume = volume.astype(np.float64)
718
+ volume = (volume - np.min(volume)) / (np.ptp(volume) + eps)
719
+ return volume
720
+
721
+ def random_crop_3d(self, volume, crop_size=None):
722
+ # Fast random crop with optional padding using NumPy
723
+ d, h, w = volume.shape
724
+ if crop_size is None:
725
+ crop_size = self.out_sz
726
+ crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
727
+
728
+ # Only pad if needed (avoid np.pad if not necessary)
729
+ pad_d = max(0, crop_d - d)
730
+ pad_h = max(0, crop_h - h)
731
+ pad_w = max(0, crop_w - w)
732
+ if pad_d or pad_h or pad_w:
733
+ pad_width = (
734
+ (np.random.randint(0, pad_d + 1), pad_d - np.random.randint(0, pad_d + 1)),
735
+ (np.random.randint(0, pad_h + 1), pad_h - np.random.randint(0, pad_h + 1)),
736
+ (np.random.randint(0, pad_w + 1), pad_w - np.random.randint(0, pad_w + 1)),
737
+ )
738
+ volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
739
+ d, h, w = volume.shape
740
+
741
+ # Crop indices
742
+ start_d = np.random.randint(0, d - crop_d + 1) if d > crop_d else 0
743
+ start_h = np.random.randint(0, h - crop_h + 1) if h > crop_h else 0
744
+ start_w = np.random.randint(0, w - crop_w + 1) if w > crop_w else 0
745
+
746
+ # Use NumPy slicing (very fast)
747
+ return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
748
+
749
+ # def random_crop_3d(self, volume, crop_size=None):
750
+ # # Randomly crop the image
751
+ # d, h, w = volume.shape
752
+ # if crop_size is None:
753
+ # crop_size = self.out_sz
754
+ # crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
755
+
756
+ # if crop_d > d or crop_h > h or crop_w > w:
757
+ # raise ValueError("Crop size must be smaller than the original array size")
758
+
759
+ # start_d = np.random.randint(0, d - crop_d + 1)
760
+ # start_h = np.random.randint(0, h - crop_h + 1)
761
+ # start_w = np.random.randint(0, w - crop_w + 1)
762
+
763
+ # cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
764
+
765
+ # return cropped_array
766
+
767
+ def get_all_ROI(self):
768
+ # Get all the ROI options. and remove the reduntant ones
769
+ ROIs = []
770
+ for k in self.ALLdata_filtered.keys():
771
+ ROIs.append(self.ALLdata[k]['ROI'])
772
+ ROIs = set(ROIs)
773
+ return ROIs
774
+
775
+ def find_min_dim(self):
776
+ # Find the minimum dimension of the images
777
+ min_dim = 100000
778
+ for k in self.ALLdata.keys():
779
+ value = self.ALLdata[k]
780
+ if min(value['Size']) < min_dim:
781
+ min_dim = min(value['Size'])
782
+ return min_dim
783
+
784
    def get_ALLdata(self):
        """Return the unfiltered metadata dictionary ``self.ALLdata``."""
        # Return all data
        return self.ALLdata
787
+
788
+ def get_filter_modality(self, key_words=None):
789
+ ALLdata_filtered = self.ALLdata_filtered.copy()
790
+ if key_words is not None:
791
+ for k in self.ALLdata_filtered.keys():
792
+ if ALLdata_filtered[k]["Modality"] not in key_words:
793
+ del ALLdata_filtered[k]
794
+ return ALLdata_filtered
795
+
796
+ def get_filter_ROI(self, key_word):
797
+ # Filter out images with a key word
798
+ ALLdata_filtered = self.ALLdata_filtered.copy()
799
+ for k in self.ALLdata_filtered.keys():
800
+ if key_word not in k["ROI"]:
801
+ del ALLdata_filtered[k]
802
+ return ALLdata_filtered
803
+
804
+ def get_key_by_ROI(self, key_word):
805
+ # Get all the keys with a key word
806
+ keys = []
807
+ for k in self.ALLdata_filtered.keys():
808
+ if key_word == self.ALLdata_filtered[k]["ROI"]:
809
+ keys.append(k)
810
+ return keys
811
+
812
+ def filter_keys_by_xx(self, key_word, keys=None, term="ROI"):
813
+ # Filter out images with a key word
814
+ filtered_keys = []
815
+ if keys is None:
816
+ keys = self.ALLdata_filtered.keys()
817
+ for k in keys:
818
+ value = self.ALLdata_filtered[k].get(term, None)
819
+ if value is not None and key_word == value:
820
+ filtered_keys.append(k)
821
+ return filtered_keys
822
+
823
+ def get_filter_ROIs(self):
824
+ ALLdata_filtered = self.ALLdata_filtered.copy()
825
+ for k in self.ALLdata_filtered.keys():
826
+ if self.ALLdata_filtered[k]['ROI'] not in self.ROIs:
827
+ del ALLdata_filtered[k]
828
+ return ALLdata_filtered
829
+
830
+ def get_3D_volume(self, volume, select_channel = None):
831
+ if self.reverse_axis_order:
832
+ volume = reverse_axis_order(volume)
833
+ if volume.ndim == 4:
834
+ if select_channel is None:
835
+ select_channel = np.random.randint(0, volume.shape[3] - 1)
836
+ volume = volume[:, :, :, select_channel]
837
+ return volume
838
+
839
def get_filter_mindim(self):
    """Return the entries of ``self.ALLdata`` that pass the size checks.

    Only the first ``self.ndims`` values of ``'Size'`` are considered.  An
    entry is dropped as soon as one of these holds (checked in this order,
    later checks are skipped once one fails):

    1. its smallest dimension is below ``self.out_sz``;
    2. smallest dimension times ``'Spacing_mm'`` is below the lower bound
       ``self.sz_range[ROI][0]``
       (NOTE(review): assumes ``'Spacing_mm'`` is a scalar - confirm);
    3. smallest / largest dimension is below ``self.min_dim_ratio``.

    ``self.ALLdata`` is not modified.
    """
    kept = {}
    for scan_key, meta in self.ALLdata.items():
        dims = meta['Size'][:self.ndims]
        smallest = min(dims)
        reject = smallest < self.out_sz
        reject = reject or (smallest * meta['Spacing_mm']) < self.sz_range[meta['ROI']][0]
        reject = reject or (smallest / max(dims) < self.min_dim_ratio)
        if not reject:
            kept[scan_key] = meta
    return kept
855
+
856
+
857
+
858
def __getitem__(self,idx):
    """Load the scan at position ``idx`` plus a randomly chosen partner scan.

    The partner is drawn from the entries that share both the ``'ROI'`` and
    the ``'Modality'`` of the anchor scan.  Both volumes are reduced to 3D,
    clipped to ``self.clamp_range`` when the anchor's modality is ``"CT"``,
    normalised, cropped (and resized to ``self.out_sz`` when
    ``self.min_crop_ratio`` is set) and given a leading channel axis.

    Returns:
        ``(transform(volume_A), transform(volume_B))`` when ``self.transform``
        is set, otherwise ``[volume_A, volume_B, embd_A, embd_B]``.
    """
    key = list(self.ALLdata_filtered.keys())[idx]
    volume_A = sitk.GetArrayFromImage(sitk.ReadImage(key))

    anchor_meta = self.ALLdata_filtered[key]
    embd_A = np.array(anchor_meta['embd'], dtype=np.float32)

    # Candidate partners: same ROI first, then same modality.
    same_roi = self.filter_keys_by_xx(anchor_meta['ROI'], list(self.ALLdata_filtered.keys()), term="ROI")
    paired_keys = self.filter_keys_by_xx(anchor_meta['Modality'], same_roi, term="Modality")
    paired_key = random.choice(paired_keys)

    print(f"Key: {key}, Paired Key: {paired_key}")
    print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")

    volume_B = sitk.GetArrayFromImage(sitk.ReadImage(paired_key))
    embd_B = np.array(self.ALLdata_filtered[paired_key]['embd'], dtype=np.float32)

    volume_A = self.get_3D_volume(volume_A)
    volume_B = self.get_3D_volume(volume_B)

    # Intensity clipping is applied to CT only, decided by the anchor's modality.
    if self.clamp_range is not None and anchor_meta.get("Modality", None) == "CT":
        lower, upper = self.clamp_range[0], self.clamp_range[1]
        volume_A = np.clip(volume_A, lower, upper)
        volume_B = np.clip(volume_B, lower, upper)
    volume_A = self.normalize(volume_A)
    volume_B = self.normalize(volume_B)

    if self.min_crop_ratio is None:
        volume_A = self.random_crop_3d(volume_A, self.out_sz)
        volume_B = self.random_crop_3d(volume_B, self.out_sz)
    else:
        # One ratio for both volumes; the cube edge is derived from each
        # volume's largest dimension and capped at self.max_sz.
        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        edge_A = min(int(max(volume_A.shape) * crop_ratio), self.max_sz)
        edge_B = min(int(max(volume_B.shape) * crop_ratio), self.max_sz)
        volume_A = self.random_crop_3d(volume_A, edge_A)
        volume_B = self.random_crop_3d(volume_B, edge_B)
        target_shape = [self.out_sz]*self.ndims
        volume_A = resize(volume_A, target_shape, anti_aliasing = True, preserve_range = True)
        volume_B = resize(volume_B, target_shape, anti_aliasing = True, preserve_range = True)

    # Add the leading channel axis.
    volume_A = volume_A[None, :, :, :]
    volume_B = volume_B[None, :, :, :]

    if self.transform is not None:
        return self.transform(volume_A), self.transform(volume_B)

    return [volume_A, volume_B, embd_A, embd_B]
922
+
923
def __len__(self):
    """Number of entries left in ``self.ALLdata_filtered``."""
    return len(self.ALLdata_filtered)
925
+
926
class OminiDataset_paired_inf(object):
    """Random sampler of scan pairs that share a region of interest (ROI).

    Scan metadata is merged from the JSON files listed in the module-level
    ``mapping_files`` (defined outside this block).  Each metadata key is
    used as the path handed to ``sitk.ReadImage``.  The constructor filters
    the metadata by minimum size and by ROI; ``get_random_2_items`` and
    ``build_batch`` then draw pairs of volumes from a single ROI.
    """

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, ROIs = None):
        """Load the metadata and build the filtered view and ROI index.

        Args:
            out_sz: edge length of the cubic output volumes.
            transform: optional callable applied to each output volume.
            clamp_range: ``(low, high)`` intensity window applied to CT
                scans, or ``None`` to disable clipping.
            min_crop_ratio: lower bound of the random crop ratio; ``None``
                switches to a plain ``out_sz`` crop without resizing.
            ROIs: iterable of ROI names to keep; all ROIs when ``None``.
        """
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Filtering starts here.
        # 1) drop scans whose smallest dimension is too small
        self.ALLdata_filtered = self.get_filter_mindim()
        # 2) drop scans whose ROI was not requested
        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # Filtering ends here.

        self.roi_scan_mapping = self.build_ROI_scan_mapping()
        self.keys_dist, self.total = self.get_keys_dist()

    def get_all_ROI(self):
        """Return the set of distinct ROI names present in the filtered data."""
        return {meta['ROI'] for meta in self.ALLdata_filtered.values()}

    def get_ALLdata(self):
        """Return the unfiltered metadata dictionary."""
        return self.ALLdata

    def combine_data(self, mappings = mapping_files):
        """Merge the JSON files named by ``mappings`` into one dictionary.

        Later files overwrite earlier ones on duplicate keys.
        """
        ALLdata = {}
        for name in mappings.keys():
            with open(mappings[name], 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def __len__(self):
        """Number of scans left after filtering."""
        return len(self.ALLdata_filtered)

    def random_crop_3d(self, volume, crop_size=None):
        """Cut a random cube of edge ``crop_size`` out of a 3D ``volume``.

        Args:
            volume: 3D array.
            crop_size: cube edge length; ``self.out_sz`` when ``None``.

        Raises:
            ValueError: if the cube does not fit inside ``volume``.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        if crop_d > d or crop_h > h or crop_w > w:
            raise ValueError("Crop size must be smaller than the original array size")

        start_d = np.random.randint(0, d - crop_d + 1)
        start_h = np.random.randint(0, h - crop_h + 1)
        start_w = np.random.randint(0, w - crop_w + 1)

        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]

    def normalize(self, volume, eps=1e-7):
        """Rescale ``volume`` to the 0-1 range as float64 (``eps`` avoids division by zero)."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def get_3D_volume(self, volume, select_channel = None):
        """Reverse the axis order of ``volume`` and reduce 4D input to 3D.

        The reversal is done by the module-level ``reverse_axis_order``
        helper (defined outside this block).  For 4D input the last axis is
        indexed with ``select_channel``, or with a random channel when it is
        ``None``.
        """
        volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: passing the channel
                # count lets the last channel be drawn and keeps
                # single-channel input from raising ValueError.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Return the entries of ``self.ALLdata`` whose smallest spatial size is at least ``out_sz / 2``."""
        return {
            k: meta
            for k, meta in self.ALLdata.items()
            if not min(meta['Size'][:self.ndims]) < self.out_sz/2
        }

    def get_filter_ROI(self, key_word):
        """Return the filtered entries whose ``"ROI"`` contains ``key_word``."""
        # The ROI lives in the metadata; the former ``k["ROI"]`` indexed the
        # key string itself and raised TypeError.
        return {
            k: meta
            for k, meta in self.ALLdata_filtered.items()
            if key_word in meta["ROI"]
        }

    def get_filter_ROIs(self):
        """Return the filtered entries whose ``'ROI'`` is listed in ``self.ROIs``."""
        return {
            k: meta
            for k, meta in self.ALLdata_filtered.items()
            if meta['ROI'] in self.ROIs
        }

    def get_keys_dist(self):
        """Count the filtered scans per ROI.

        Returns:
            ``(keys_dist, total)``: a dict mapping ROI name to scan count and
            the overall number of scans.
        """
        keys_dist = {}
        for meta in self.ALLdata_filtered.values():
            keys_dist[meta['ROI']] = keys_dist.get(meta['ROI'], 0) + 1
        # ``total`` used to stay at its initial 0 because it was never updated.
        total = sum(keys_dist.values())
        return keys_dist, total

    def build_ROI_scan_mapping(self):
        """Return a dict mapping each ROI name to the list of its scan keys."""
        ROI_scan_mapping = {}
        for item, meta in self.ALLdata_filtered.items():
            ROI_scan_mapping.setdefault(meta['ROI'], []).append(item)
        return ROI_scan_mapping

    def get_random_2_items(self, mode = 'uniform'):
        """Draw two scans that share an ROI and return them as processed volumes.

        In ``'uniform'`` mode the ROI is chosen uniformly among the ROIs, then
        two scans are drawn independently from it (they may be the same scan).

        NOTE(review): the volumes are not passed through ``get_3D_volume``
        here, so no axis reversal happens and 4D scans fail in
        ``random_crop_3d`` - confirm whether that call is missing.

        Returns:
            ``(volume_A, volume_B)``, each with a leading channel axis (passed
            through ``self.transform`` when it is set).  Any mode other than
            ``'uniform'`` is not implemented and returns ``None``.
        """
        if mode == 'uniform':
            idx = random.randint(0, len(self.keys_dist.keys()) - 1)
            roi = list(self.keys_dist.keys())[idx]
            path_1 = random.choice(self.roi_scan_mapping[roi])
            path_2 = random.choice(self.roi_scan_mapping[roi])

            volume_A = sitk.GetArrayFromImage(sitk.ReadImage(path_1))
            volume_B = sitk.GetArrayFromImage(sitk.ReadImage(path_2))

            if self.clamp_range is not None:
                # Metadata is keyed by scan path.  The previous lookup used
                # the ROI name and raised KeyError.  The two scans only share
                # an ROI, so each one is clipped according to its own modality.
                if self.ALLdata_filtered[path_1].get("Modality", None) == "CT":
                    volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
                if self.ALLdata_filtered[path_2].get("Modality", None) == "CT":
                    volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
            volume_A = self.normalize(volume_A)
            volume_B = self.normalize(volume_B)

            if self.min_crop_ratio is not None:
                crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
                crop_size_A = int(min(volume_A.shape) * crop_ratio)
                crop_size_B = int(min(volume_B.shape) * crop_ratio)
                volume_A = self.random_crop_3d(volume_A, crop_size_A)
                volume_B = self.random_crop_3d(volume_B, crop_size_B)
                volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
                volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
            else:
                # was ``self.radndom_crop_3d`` (typo -> AttributeError)
                volume_A = self.random_crop_3d(volume_A, self.out_sz)
                volume_B = self.random_crop_3d(volume_B, self.out_sz)
            volume_A = volume_A[None, :, :, :]
            volume_B = volume_B[None, :, :, :]
            if self.transform is not None:
                return self.transform(volume_A), self.transform(volume_B)
            return volume_A, volume_B

        # 'original' (and any other mode) is not implemented yet.
        return None

    def build_batch(self, batch_size = 2):
        """Stack ``batch_size`` random pairs into two arrays ``(batch_A, batch_B)``."""
        batch_1 = []
        batch_2 = []
        for _ in range(batch_size):
            V_a, V_b = self.get_random_2_items()
            batch_1.append(V_a)
            batch_2.append(V_b)
        return np.array(batch_1), np.array(batch_2)
1103
+
1104
class OminiDataset_inference_w_all(object):
    """Inference dataset returning one resized volume with its label maps.

    Scan metadata is merged from the JSON files in the module-level
    ``mapping_files`` (defined outside this block), optionally restricted to
    selected databases.  Each metadata key is used as the path handed to
    ``sitk.ReadImage``.  Items are dictionaries with the keys ``'img'``,
    ``'labels'`` and ``'label_channels'``.

    ``transform`` is stored but not applied by ``__getitem__``.
    """

    # NOTE: ``label_key`` and ``select_channels_dict`` have mutable defaults;
    # they are only read, never modified, inside this class.
    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = None, label_key = ['brain'], task_key = 'segmentation', database = None, select_channels_dict = {}):
        """Load, filter and index the scan metadata.

        Args:
            out_sz: edge length of the cubic output volume.
            transform: stored as ``self.transform`` (unused in this class).
            clamp_range: ``(low, high)`` window applied to CT scans, or
                ``None`` to disable clipping.
            min_crop_ratio: lower bound of the random crop ratio.
            ROIs: ROI names to keep; all ROIs when ``None``.
            label_key: list of label names to load for each scan.
            task_key: key inside ``'Label_path'`` that holds the labels.
            database: names of mapping files to use, e.g.
                ``['MSD', 'TotalSegmentor']``; all when ``None``.
            select_channels_dict: channel selection for 4D scans, e.g.
                ``{"ImgDict": ["ed", "es"]}``.
        """
        self.mappings = mapping_files
        if database is not None:
            self.mappings = {db: self.mappings[db] for db in database if db in self.mappings}
        self.select_channels_dict = select_channels_dict
        self.ALLdata = self.combine_data(mappings=self.mappings)
        self.out_sz = out_sz
        self.label_key = label_key
        # Remembered so __getitem__ reads labels from the same task that
        # get_filter_labels filtered on (it used to hard-code 'segmentation').
        self.task_key = task_key
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        self.is_reverse_axis_order = True  # for inference, always reverse axis order (NIfTI order is the reverse of numpy)

        # Filtering starts here.
        # 1) drop scans whose smallest dimension is too small
        self.ALLdata_filtered = self.get_filter_mindim()
        # 2) drop scans whose ROI was not requested
        if ROIs is None:
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # 3) drop scans that have none of the requested labels
        self.ALLdata_filtered = self.get_filter_labels(task_key=task_key,label_keys=label_key)
        # Filtering ends here.

        self.roi_scan_mapping = self.build_ROI_scan_mapping()
        self.keys_dist, self.total = self.get_keys_dist()

    def get_all_ROI(self):
        """Return the set of distinct ROI names present in the filtered data."""
        return {meta['ROI'] for meta in self.ALLdata_filtered.values()}

    def get_keys_dist(self):
        """Count the filtered scans per ROI.

        Returns:
            ``(keys_dist, total)``: a dict mapping ROI name to scan count and
            the overall number of scans.
        """
        keys_dist = {}
        for meta in self.ALLdata_filtered.values():
            keys_dist[meta['ROI']] = keys_dist.get(meta['ROI'], 0) + 1
        # ``total`` used to stay at its initial 0 because it was never updated.
        total = sum(keys_dist.values())
        return keys_dist, total

    def build_ROI_scan_mapping(self):
        """Return a dict mapping each ROI name to the list of its scan keys."""
        ROI_scan_mapping = {}
        for item, meta in self.ALLdata_filtered.items():
            ROI_scan_mapping.setdefault(meta['ROI'], []).append(item)
        return ROI_scan_mapping

    def get_3D_volume(self, volume, select_channel = None):
        """Optionally reverse the axis order of ``volume`` and reduce 4D input to 3D.

        The reversal uses the module-level ``reverse_axis_order`` helper
        (defined outside this block) and happens when
        ``self.is_reverse_axis_order`` is set.  For 4D input the last axis is
        indexed with ``select_channel``, or a random channel when ``None``.
        """
        volume = reverse_axis_order(volume) if self.is_reverse_axis_order else volume
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: passing the channel
                # count lets the last channel be drawn and keeps
                # single-channel input from raising ValueError.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Return the entries of ``self.ALLdata`` whose smallest spatial size is at least ``out_sz / 2``."""
        return {
            k: meta
            for k, meta in self.ALLdata.items()
            if not min(meta['Size'][:self.ndims]) < self.out_sz/2
        }

    def find_min_dim(self):
        """Return the smallest ``'Size'`` value over all scans (capped at 100000)."""
        min_dim = 100000
        for value in self.ALLdata.values():
            smallest = min(value['Size'])
            if smallest < min_dim:
                min_dim = smallest
        return min_dim

    def combine_data(self, mappings = mapping_files):
        """Merge the JSON files named by ``mappings`` into one dictionary.

        Later files overwrite earlier ones on duplicate keys.
        """
        ALLdata = {}
        for name in mappings.keys():
            with open(mappings[name], 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def normalize(self, volume, eps=1e-7):
        """Rescale ``volume`` to the 0-1 range as float64 (``eps`` avoids division by zero)."""
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def get_key_by_ROI(self, key_word):
        """Return the filtered keys whose ``"ROI"`` equals ``key_word``."""
        return [k for k, meta in self.ALLdata_filtered.items() if key_word == meta["ROI"]]

    def get_filter_task(self, task_key = 'segmentation'):
        """Return the filtered entries that have ``task_key`` under ``'Label_path'``.

        A warning is issued for every entry that gets removed.
        """
        import warnings

        kept = {}
        for k, meta in self.ALLdata_filtered.items():
            if 'Label_path' not in meta or task_key not in meta['Label_path']:
                # ``Warning(...)`` alone only built an exception object;
                # warnings.warn actually reports it.
                warnings.warn(f"Label path not found for {k} with task key {task_key}. This image will be removed from the dataset.")
                continue
            kept[k] = meta
        return kept

    def get_filter_labels(self, task_key='segmentation', label_keys=['heart']):
        """Return the filtered entries offering at least one of ``label_keys``.

        An entry is kept when ``entry['Label_path'][task_key]`` contains any
        key listed in ``label_keys``; missing ``'Label_path'`` or ``task_key``
        counts as having no labels.
        """
        kept = {}
        for k, meta in self.ALLdata_filtered.items():
            task_labels = meta.get('Label_path', {}).get(task_key, {})
            if any((tk in label_keys) for tk in task_labels.keys()):
                kept[k] = meta
        return kept

    def get_random_pad_crop_params(self, volume_shape, crop_size=None, random=True):
        """Compute padding and crop parameters for a cubic crop.

        Every axis shorter than ``crop_size`` is padded up to it, with the
        padding split randomly between both sides.  The crop start is then
        random (``random=True``) or centred (``random=False``).

        Args:
            volume_shape: shape whose first three entries are the spatial size.
            crop_size: cube edge length; ``self.out_sz`` when ``None``.
            random: choose a random crop position instead of the centre.

        Returns:
            ``(pad_width, crop_slices)`` where ``pad_width`` is a list of three
            ``(before, after)`` tuples and ``crop_slices`` is
            ``(start_d, start_h, start_w, crop_d, crop_h, crop_w)``.
        """
        d, h, w = volume_shape[:3]
        if crop_size is None:
            crop_size = self.out_sz
        crop_d, crop_h, crop_w = crop_size, crop_size, crop_size

        # Padding per axis.
        pad_width = []
        for size, crop in zip((d, h, w), (crop_d, crop_h, crop_w)):
            if crop > size:
                total_pad = crop - size
                pad_before = np.random.randint(0, total_pad + 1)
                pad_width.append((pad_before, total_pad - pad_before))
            else:
                pad_width.append((0, 0))

        # Spatial size after padding.
        d_p = d + pad_width[0][0] + pad_width[0][1]
        h_p = h + pad_width[1][0] + pad_width[1][1]
        w_p = w + pad_width[2][0] + pad_width[2][1]

        if random:
            start_d = np.random.randint(0, d_p - crop_d + 1) if d_p > crop_d else 0
            start_h = np.random.randint(0, h_p - crop_h + 1) if h_p > crop_h else 0
            start_w = np.random.randint(0, w_p - crop_w + 1) if w_p > crop_w else 0
        else:
            start_d = max((d_p - crop_d) // 2, 0)
            start_h = max((h_p - crop_h) // 2, 0)
            start_w = max((w_p - crop_w) // 2, 0)

        crop_slices = (start_d, start_h, start_w, crop_d, crop_h, crop_w)
        return pad_width, crop_slices

    def apply_pad_crop(self, volume, pad_width, crop_slices):
        """Zero-pad ``volume`` with ``pad_width`` and crop its first three axes with ``crop_slices``."""
        if any(pad != (0, 0) for pad in pad_width):
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        start_d, start_h, start_w, crop_d, crop_h, crop_w = crop_slices
        return volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]

    def get_filter_ROIs(self):
        """Return the filtered entries whose ``'ROI'`` is listed in ``self.ROIs``."""
        return {
            k: meta
            for k, meta in self.ALLdata_filtered.items()
            if meta['ROI'] in self.ROIs
        }

    def get_channel_ids(self, key):
        """Return the channel indices of scan ``key`` that match the selected channels.

        The scan's ``"ImgDict"`` maps an index to a channel name (e.g. 'ed',
        'es'); the names requested in ``self.select_channels_dict["ImgDict"]``
        are translated back to integer indices.

        Returns:
            list: indices in the order of the requested names; names the scan
            does not provide are skipped.
        """
        img_dict = self.ALLdata_filtered[key].get("ImgDict", {})
        selected_values = self.select_channels_dict.get("ImgDict", [])
        # Reverse mapping: channel name -> index.
        value_to_idx = {value: int(idx) for idx, value in img_dict.items()}
        return [value_to_idx[val] for val in selected_values if val in value_to_idx]

    def __len__(self):
        """Number of scans left after filtering."""
        return len(self.ALLdata_filtered)

    def __getitem__(self, idx):
        """Load scan ``idx`` and its labels, padded/cropped together and resized.

        Returns:
            dict with
              * ``'img'``: the volume with a leading channel axis,
              * ``'labels'``: the label arrays of all ``self.label_key``
                entries concatenated along axis 1 (labels that are missing
                are filled with -1),
              * ``'label_channels'``: the channel names requested in
                ``self.select_channels_dict["ImgDict"]``.
        """
        import warnings

        key = list(self.ALLdata_filtered.keys())[idx]
        meta = self.ALLdata_filtered[key]
        return_dict = dict()

        print(f"Processing key: {key}")

        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)

        # Defaults for 3D scans; previously both names were only bound inside
        # the 4D branch and the code below failed with UnboundLocalError.
        channel_ids = []
        channel_id = None
        if volume.ndim == 4:
            channel_ids = self.get_channel_ids(key)
            if len(channel_ids) == 0:
                warnings.warn(f"No matching channels found for key: {key} with ImgDict: {self.ALLdata_filtered[key].get('ImgDict', {})} and selected channels: {self.select_channels_dict.get('ImgDict', [])}. Using random channel.")
            else:
                channel_id = channel_ids[0]

        volume = self.get_3D_volume(volume, select_channel = channel_id)

        if self.clamp_range is not None:
            if meta.get("Modality", None) == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        # One set of pad/crop parameters, shared by the image and its labels.
        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        crop_size = int(max(volume.shape) * crop_ratio)
        pad_width, crop_slices = self.get_random_pad_crop_params(volume.shape, crop_size)
        volume = self.apply_pad_crop(volume, pad_width, crop_slices)

        label_dict = dict()
        if 'Label_path' in meta:
            task_labels = meta['Label_path'].get(self.task_key, {})
            for lk in self.label_key:
                if lk in task_labels:
                    label = sitk.ReadImage(task_labels[lk])
                    label = sitk.GetArrayFromImage(label)
                    label = reverse_axis_order(label) if self.is_reverse_axis_order else label

                    if label.ndim > self.ndims:
                        if len(channel_ids) != 0:
                            label = label[...,channel_ids]  # assuming channel last
                        # extra (channel) axes are never padded
                        pad_width_lab = pad_width + [(0,0)]*(label.ndim - self.ndims)
                    else:
                        pad_width_lab = pad_width
                    label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
                    # order=0 without anti-aliasing keeps label values intact
                    label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
                    if label.ndim > self.ndims:
                        # move the channel axis to the front (assuming channel last)
                        if self.ndims==3:
                            label_dict[lk] = np.transpose(label_dict[lk], (3,0,1,2))
                        elif self.ndims==4:
                            label_dict[lk] = np.transpose(label_dict[lk], (4,0,1,2,3))
                else:
                    label_dict[lk] = np.full([self.out_sz]*self.ndims, -1)
                    warnings.warn(f"Label path not found for {key} with label key {lk}.")
                label_dict[lk] = label_dict[lk][None, :, :, :] if label_dict[lk].ndim == 3 else label_dict[lk]
        else:
            for lk in self.label_key:
                label_dict[lk] = np.full([self.out_sz]*self.ndims, -1)
                warnings.warn(f"Label path not found for {key} with label key {lk}.")
                label_dict[lk] = label_dict[lk][None, :, :, :]

        volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        # NOTE(review): the per-label arrays are channel-first without a batch
        # axis, so axis=1 joins them along the first spatial axis, not along
        # channels - confirm this is intended when several label keys are used.
        return_dict['labels'] = np.concatenate([v for v in label_dict.values()], axis=1)

        return_dict['img'] = volume[None, :, :, :]
        return_dict['label_channels'] = list(self.select_channels_dict.get("ImgDict", []))
        return return_dict
1402
+
1403
+
1404
class OminiDataset_bertembd(OminiDataset):
    """``OminiDataset`` variant whose items carry the stored ``'embd'`` vector.

    Each item is ``(volume, embedding)`` where the embedding is read from the
    scan's metadata field ``'embd'``.
    """

    def __init__(self,
                 out_sz = 128,
                 transform=None,
                 clamp_range = CLAMP_RANGE,
                 min_crop_ratio = 0.85,
                 ROIs = None,
                 modality = None,
                 reverse_axis_order = False,
                 min_dim = 3,
                 mapping_files = mapping_files):
        """Initialise the parent dataset, then redo the size and ROI filtering.

        All arguments are forwarded unchanged to ``OminiDataset.__init__``
        (defined outside this block).
        """
        # Was ``super().init(...)``, which is not the constructor and raised
        # AttributeError unless the parent happens to define ``init``.
        super().__init__(out_sz = out_sz,
                         transform = transform,
                         clamp_range = clamp_range,
                         min_crop_ratio = min_crop_ratio,
                         ROIs = ROIs,
                         modality = modality,
                         reverse_axis_order = reverse_axis_order,
                         min_dim = min_dim,
                         mapping_files=mapping_files)
        # Filtering starts here.
        self.ALLdata_filtered = self.get_filter_mindim()
        if ROIs is None:
            # no ROIs given: keep every ROI found in the filtered data
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # Filtering ends here.

    def __getitem__(self, idx):
        """Return the processed volume of scan ``idx`` and its embedding.

        Returns:
            ``(volume, np.array(embd))`` with a leading channel axis on
            ``volume``.  NOTE(review): when ``self.transform`` is set only
            ``self.transform(volume)`` is returned and the embedding is
            dropped - confirm callers expect that.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = self.ALLdata_filtered[key]['embd']

        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return volume,np.array(embd)

    def __len__(self):
        """Number of scans left after filtering."""
        return len(self.ALLdata_filtered)

    def filter_embd(self):
        """Remove, in place, every entry without ``'BERT_embedding_keys'`` in its ``'Metadata'``.

        Returns:
            ``self.ALLdata_filtered`` (the same, now reduced, dictionary).
        """
        # Iterate over a snapshot of the keys: deleting from a dict while
        # iterating it raises RuntimeError.
        for k in list(self.ALLdata_filtered.keys()):
            if 'BERT_embedding_keys' not in self.ALLdata_filtered[k]['Metadata']:
                del self.ALLdata_filtered[k]
        return self.ALLdata_filtered
1473
+
Dataloader/dataloader0.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from torchvision import datasets, transforms
6
+ import nibabel as nib
7
+ from skimage.transform import rescale, resize, downscale_local_mean
8
+ from scipy.ndimage import zoom
9
+ import numpy as np
10
+ # import SimpleITK as sitk
11
+
12
+ # print(os.getcwd())
13
+ import sys
14
+ sys.path.append('./')
15
+ from Dataloader.dataloader_utils import *
16
+
17
+ EPS = 1e-7
18
+
19
def get_dataloader(data_name='cmr',mode='train'):
    """Return the dataset class for ``data_name`` and ``mode``.

    Args:
        data_name: ``'cmr'`` or ``'lct'``.
        mode: ``'train'`` for the source-data loader, ``'aug'`` for the
            target-data loader.

    Returns:
        The dataset class itself (not an instance).

    Raises:
        ValueError: if ``data_name`` or ``mode`` is not one of the supported
            values.  Previously a message was printed and the function then
            failed with UnboundLocalError on its return statement.
    """
    if data_name == 'cmr':
        if mode == 'train':
            return CMR_loader
        if mode == 'aug':
            return CMR_tgt_loader
    elif data_name == 'lct':
        if mode == 'train':
            return LCT_loader
        if mode == 'aug':
            return LCT_tgt_loader
    else:
        raise ValueError(f"dataloader not exist: {data_name!r}")
    raise ValueError(f"mode not exist: {mode!r}")
37
+
38
class LCT_loader(Dataset):
    """Dataset over the ``.npy`` volumes stored directly inside one directory.

    Each item is ``(array, array, index)``: the same volume twice (with a
    leading channel axis) plus its index.
    """

    def __init__(self, data_root_path = 'Data/Src_data/CTLung_processed/', target_res = (256, 256),transforms = None, noise_scale=0.0, patient_index = None):
        """Collect the ``.npy`` files found in ``data_root_path``.

        Args:
            data_root_path: directory holding the volumes (a trailing slash
                is optional).
            target_res: accepted for interface compatibility; unused here.
            transforms: optional callable applied to every loaded array.
            noise_scale: stored as ``self.noise_scale``; unused here.
            patient_index: accepted for interface compatibility; unused here.
        """
        # os.path.join instead of string concatenation so the root works with
        # or without a trailing slash; sorted because os.listdir returns the
        # entries in arbitrary order, which made the index -> file mapping
        # differ between machines.
        self.files = [os.path.join(data_root_path, f) for f in sorted(os.listdir(data_root_path)) if f.endswith('.npy')]
        self.transforms = transforms
        self.noise_scale = noise_scale
        self.d_p = data_root_path

    def __getitem__(self, item):
        """Load volume ``item`` and return ``(array, array, item)``."""
        array = np.load(self.files[item])
        # Directories whose path contains 'process' are taken to be
        # normalised already; everything else is rescaled to 0-1 here.
        if 'process' not in self.d_p:
            lowest = array.min()
            array = (array - lowest) / (array.max() - lowest + EPS)
        array = array[None,:,:,:]  # add a channel axis: (C, H, W, Z)
        if self.transforms is not None:
            array = self.transforms(array)
        return array, array, item  # batches to (B, C, H, W, Z)

    def __len__(self):
        """Number of ``.npy`` files found."""
        return len(self.files)
59
+
60
class LCT_tgt_loader(Dataset):
    """Dataset of image/mask pairs read from ``<root>/Tr`` and ``<root>/Gt``.

    Files of the two directories are paired by their position in the sorted
    listings.  Each item is ``(image, mask, index)`` with a leading channel
    axis on both arrays.
    """

    def __init__(self, data_root_path = "Data/Tgt_data/lct/",noise_scale=0.0, patient_index = None, transforms = None):
        """List the image (``Tr``) and mask (``Gt``) files under ``data_root_path``.

        Args:
            data_root_path: directory containing the ``Gt`` and ``Tr``
                sub-directories (a trailing slash is optional).
            noise_scale: stored as ``self.noise_scale``; unused here.
            patient_index: accepted for interface compatibility; unused here.
            transforms: stored as ``self.transforms``; not applied by
                ``__getitem__``.
        """
        gt_dir = os.path.join(data_root_path, "Gt")
        tr_dir = os.path.join(data_root_path, "Tr")
        self.files_gt = sorted(os.path.join(gt_dir, f) for f in os.listdir(gt_dir))
        self.files_tr = sorted(os.path.join(tr_dir, f) for f in os.listdir(tr_dir))

        # There used to be no ``transforms`` parameter, so this assignment
        # silently stored the imported torchvision ``transforms`` module.
        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Load pair ``item`` and return ``(image, mask, item)``."""
        image = nib.load(self.files_tr[item]).get_fdata()
        mask = nib.load(self.files_gt[item]).get_fdata()

        # add the leading channel axis
        image = image[None,:,:,:]
        mask = mask[None,:,:,:]

        print(self.files_tr[item],self.files_gt[item])

        return image, mask, item

    def __len__(self):
        """Number of image/mask pairs.

        Raises:
            ValueError: if the two directories hold a different number of
                files (an explicit check, since ``assert`` is stripped under
                ``python -O``).
        """
        if len(self.files_gt) != len(self.files_tr):
            raise ValueError(
                f"Gt/Tr file count mismatch: {len(self.files_gt)} masks vs {len(self.files_tr)} images"
            )
        return len(self.files_gt)
90
+
91
class LCT_seg(Dataset):
    """Dataset of image/mask pairs read from ``<root>/Tr`` and ``<root>/Gt``.

    Files of the two directories are paired by their position in the sorted
    listings.  Each item is ``(image, mask, index)`` with a leading channel
    axis on both arrays.
    """

    def __init__(self, data_root_path = "/home/data/jzheng/CTLung_processed/testset/modality_0001/",noise_scale=0.0, patient_index = None, transforms = None):
        """List the image (``Tr``) and mask (``Gt``) files under ``data_root_path``.

        Args:
            data_root_path: directory containing the ``Gt`` and ``Tr``
                sub-directories (a trailing slash is optional).
            noise_scale: stored as ``self.noise_scale``; unused here.
            patient_index: accepted for interface compatibility; unused here.
            transforms: stored as ``self.transforms``; not applied by
                ``__getitem__``.
        """
        gt_dir = os.path.join(data_root_path, "Gt")
        tr_dir = os.path.join(data_root_path, "Tr")
        self.files_gt = sorted(os.path.join(gt_dir, f) for f in os.listdir(gt_dir))
        self.files_tr = sorted(os.path.join(tr_dir, f) for f in os.listdir(tr_dir))

        # There used to be no ``transforms`` parameter, so this assignment
        # silently stored the imported torchvision ``transforms`` module.
        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Load pair ``item`` and return ``(image, mask, item)``."""
        image = nib.load(self.files_tr[item]).get_fdata()
        mask = nib.load(self.files_gt[item]).get_fdata()

        # add the leading channel axis
        image = image[None,:,:,:]
        mask = mask[None,:,:,:]

        print(self.files_tr[item],self.files_gt[item])

        return image, mask, item

    def __len__(self):
        """Number of image/mask pairs.

        Raises:
            ValueError: if the two directories hold a different number of
                files (an explicit check, since ``assert`` is stripped under
                ``python -O``).
        """
        if len(self.files_gt) != len(self.files_tr):
            raise ValueError(
                f"Gt/Tr file count mismatch: {len(self.files_gt)} masks vs {len(self.files_tr)} images"
            )
        return len(self.files_gt)
121
+
122
class CMR_loader_preprocess(Dataset):
    """Offline pre-processing loader for CMR images; not used for training.

    Each item is ``(array, file_path)``: the image resized to ``target_res``,
    with a leading channel axis, background removed and intensities rescaled
    to 0-1.

    NOTE(review): the default ``data_path`` points at a CTLung directory
    although this class is documented as CMR pre-processing - confirm.
    """

    def __init__(self, data_path = 'Data/CTLung_processed/', target_res = (256, 256), transforms = None, noise_scale=0.0):
        """List every file in ``data_path``.

        Args:
            data_path: directory holding the images (a trailing slash is
                optional).
            target_res: output resolution passed to ``resize``.
            transforms: optional callable applied to every array.
            noise_scale: when positive, enables thresholding plus a random
                global intensity scaling.
        """
        self.d_p = data_path
        self.target_res = target_res
        # os.path.join works with or without a trailing slash; sorted because
        # os.listdir returns the entries in arbitrary order.
        self.files = [os.path.join(self.d_p, x) for x in sorted(os.listdir(self.d_p))]
        self.transforms = transforms
        self.noise_scale = noise_scale

    def __getitem__(self, item):
        """Load and pre-process image ``item``; returns ``(array, file_path)``."""
        array = nib.load(self.files[item]).get_fdata()
        array = resize(array, self.target_res, anti_aliasing = True, preserve_range = True)
        array = array[None, :, :]
        array = remove_background(array)  # jzheng 20240228
        lowest = array.min()
        array = (array - lowest) / (array.max() - lowest + EPS)

        if self.noise_scale > 0:
            array = thresh_img(array,[0,self.noise_scale])
            # one random gain, drawn around 1, for the whole image
            array = array * (np.random.normal(1, self.noise_scale*2))

        if self.transforms is not None:
            array = self.transforms(array)
        return array, self.files[item]

    def __len__(self):
        """Number of files found in ``data_path``."""
        return len(self.files)
149
+
150
class CMR_loader(Dataset):
    """2-D CMR slice loader; NIfTI arrays are expected as (H, W).

    When ``data_path`` contains the substring ``'resize'`` the data is assumed
    to have been resized and background-stripped offline, and those two steps
    are skipped at load time to keep training fast.

    Args:
        data_path: Folder whose entries are all treated as NIfTI files.
        target_res: Output size handed to ``resize`` (only for raw data).
        transforms: Optional callable applied to the final array.
        noise_scale: Stored on the instance; the augmentation that used it is
            disabled in ``__getitem__``.
    """

    def __init__(self, data_path = 'Data/Src_data/CMR_processed_rmbg_resize/', target_res = (256, 256), transforms = None, noise_scale=0.0):
        self.d_p = data_path
        self.ndims = 2
        self.target_res = target_res
        # os.path.join keeps working when data_path has no trailing slash.
        self.files = [os.path.join(self.d_p, x) for x in os.listdir(self.d_p)]
        self.transforms = transforms
        self.noise_scale = noise_scale
        self.preprocessed = 'resize' in data_path

    def __getitem__(self, item):
        """Return ``(array, array, item)``; the image is yielded twice."""
        array = nib.load(self.files[item]).get_fdata()
        if not self.preprocessed:
            array = resize(array, self.target_res, anti_aliasing = True, preserve_range = True)
        array = array[None, :, :]
        if not self.preprocessed:
            array = remove_background(array)  # jzheng 20240228
        array = (array - array.min()) / (array.max() - array.min() + EPS)

        if self.transforms is not None:
            array = self.transforms(array)
        return array, array, item

    def __len__(self):
        """Return the number of files found in ``data_path``."""
        return len(self.files)

    def get_transform(self, degrees=np.pi, translate=0.125):
        """Install a ToTensor + RandomAffine pipeline as ``self.transforms``.

        NOTE(review): torchvision's RandomAffine takes ``degrees`` in degrees,
        so the default ``np.pi`` is roughly +/-3.14 degrees - confirm that a
        value in radians was not intended.
        """
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.RandomAffine(
                degrees=degrees,
                translate=[translate]*self.ndims,
                interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
            ),
        ])
        return
195
+
196
class CMR_tgt_loader(Dataset):
    """In-memory 2-D slice dataset built from paired ``Tr/`` and ``Gt/`` volumes.

    All volumes of the requested patients are loaded in ``__init__`` and cut
    into slices along the third axis. Every item is
    ``(image[1, H, W], labels[4, H, W], patient_id)``.

    Args:
        data_path: Folder containing the ``Tr`` (images) and ``Gt`` (labels)
            sub-folders. Files are paired by sorted position.
        target_res: 2-D size every slice is resized to.
        is_3d: Accepted for signature compatibility; unused.
        patient_index: Patient numbers to load. ``None`` or an empty sequence
            selects patients 1..100.
    """

    def __init__(self,
                 data_path = 'Data/Tgt_data/cmr/',
                 # gt_path = '/home/data/jzheng/acdc/train_gt/',
                 target_res = (256,256),
                 is_3d = False,
                 patient_index = None,
                 ):

        # parameter initialize
        self.d_p = os.path.join(data_path, 'Tr', '')
        self.gt_p = os.path.join(data_path, 'Gt', '')
        self.img_files = sorted(os.listdir(self.d_p))
        self.gt_files = sorted(os.listdir(self.gt_p))
        self.target_res_2d = target_res
        self.img_samples = []
        self.gt_samples = []
        self.p_id = []

        # None replaces the former mutable `[]` default; both mean "all".
        if patient_index is None or len(patient_index) == 0:
            self.p_indice = [x for x in range(1, 101)]
        else:
            self.p_indice = patient_index

        assert len(self.gt_files) == len(self.img_files)
        print(self.p_indice)

        # Group the file pairs by patient number once instead of re-parsing
        # every file name for every requested patient.
        pairs_by_patient = {}
        for gt_f, img_f in zip(self.gt_files, self.img_files):
            # The patient number is the last three characters of the text
            # before the first underscore of the label file name.
            patient = int(gt_f.split('_')[0][-3:])
            pairs_by_patient.setdefault(patient, []).append((gt_f, img_f))

        for i in self.p_indice:
            for gt_f, img_f in pairs_by_patient.get(i, ()):
                self._append_volume(i, img_f, gt_f)

    def _append_volume(self, patient, img_f, gt_f):
        """Slice one image/label volume pair and append the slices to the sample lists."""
        img_volume = nib.load(self.d_p + img_f).get_fdata()
        gt_volume = nib.load(self.gt_p + gt_f).get_fdata()
        assert img_volume.shape == gt_volume.shape

        for si in range(img_volume.shape[2]):
            img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
            img = (img - img.min()) / (img.max() - img.min() + EPS)

            gt = gt_volume[:, :, si]
            # One channel per label 1..4. Each channel keeps the label value
            # itself (not 1) where the label is present, before resizing.
            # NOTE(review): resize with anti_aliasing interpolates label
            # values at region borders - confirm this is intended.
            gt_channels = [
                resize(gt * (gt == label), self.target_res_2d, anti_aliasing=True, preserve_range=True)
                for label in (1, 2, 3, 4)
            ]

            self.img_samples.append(img[np.newaxis, :, :])
            self.gt_samples.append(np.array(gt_channels))
            self.p_id.append(patient)

    def __getitem__(self, item):
        """Return ``(image, label_channels, patient_id)`` for slice ``item``."""
        return self.img_samples[item], self.gt_samples[item], self.p_id[item]

    def __len__(self):
        """Return the number of cached slices."""
        assert len(self.img_samples) == len(self.gt_samples)
        return len(self.img_samples)
270
+
271
class acdc_seg(Dataset):
    """In-memory 2-D slice dataset built from paired ACDC image/label volumes.

    All volumes of the requested patients are loaded in ``__init__`` and cut
    into slices along the third axis. Every item is
    ``(image[1, H, W], labels[4, H, W], patient_id)``.

    Args:
        data_path: Folder with the image volumes.
        gt_path: Folder with the label volumes. Files in the two folders are
            paired by sorted position.
        target_res: 2-D size every slice is resized to.
        is_3d: Accepted for signature compatibility; unused.
        patient_index: Patient numbers to load. ``None`` or an empty sequence
            selects patients 1..100.
    """

    def __init__(self,
                 data_path = '/home/data/jzheng/acdc/train_images/',
                 gt_path = '/home/data/jzheng/acdc/train_gt/',
                 target_res = (256,256),
                 is_3d = False,
                 patient_index = None,
                 ):

        # parameter initialize
        self.d_p = data_path
        self.gt_p = gt_path
        self.img_files = sorted(os.listdir(self.d_p))
        self.gt_files = sorted(os.listdir(self.gt_p))
        self.target_res_2d = target_res
        self.img_samples = []
        self.gt_samples = []
        self.p_id = []

        # None replaces the former mutable `[]` default; both mean "all".
        if patient_index is None or len(patient_index) == 0:
            self.p_indice = [x for x in range(1, 101)]
        else:
            self.p_indice = patient_index

        assert len(self.gt_files) == len(self.img_files)
        print(self.p_indice)

        # Group the file pairs by patient number once instead of re-parsing
        # every file name for every requested patient.
        pairs_by_patient = {}
        for gt_f, img_f in zip(self.gt_files, self.img_files):
            # The patient number is the last three characters of the text
            # before the first underscore of the label file name.
            patient = int(gt_f.split('_')[0][-3:])
            pairs_by_patient.setdefault(patient, []).append((gt_f, img_f))

        for i in self.p_indice:
            for gt_f, img_f in pairs_by_patient.get(i, ()):
                self._append_volume(i, img_f, gt_f)

    def _append_volume(self, patient, img_f, gt_f):
        """Slice one image/label volume pair and append the slices to the sample lists."""
        # os.path.join tolerates folder arguments without a trailing slash.
        img_volume = nib.load(os.path.join(self.d_p, img_f)).get_fdata()
        gt_volume = nib.load(os.path.join(self.gt_p, gt_f)).get_fdata()
        assert img_volume.shape == gt_volume.shape

        for si in range(img_volume.shape[2]):
            img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
            img = (img - img.min()) / (img.max() - img.min() + EPS)

            gt = gt_volume[:, :, si]
            # One channel per label 1..4. Each channel keeps the label value
            # itself (not 1) where the label is present, before resizing.
            # NOTE(review): resize with anti_aliasing interpolates label
            # values at region borders - confirm this is intended.
            gt_channels = [
                resize(gt * (gt == label), self.target_res_2d, anti_aliasing=True, preserve_range=True)
                for label in (1, 2, 3, 4)
            ]

            self.img_samples.append(img[np.newaxis, :, :])
            self.gt_samples.append(np.array(gt_channels))
            self.p_id.append(patient)

    def __getitem__(self, item):
        """Return ``(image, label_channels, patient_id)`` for slice ``item``."""
        return self.img_samples[item], self.gt_samples[item], self.p_id[item]

    def __len__(self):
        """Return the number of cached slices."""
        assert len(self.img_samples) == len(self.gt_samples)
        return len(self.img_samples)
345
+
346
class acdc_gan(Dataset):
    """ACDC images for generative training, served as 2-D slices or 3-D volumes.

    In 2-D mode (``is_3d=False``) every volume is loaded in ``__init__`` and
    the slices between 10 % and 90 % of its depth are resized, min-max
    normalised and cached. In 3-D mode each volume is loaded on access and
    zoomed to ``target_res``.

    Args:
        train_path: Folder whose entries are all treated as NIfTI volumes.
            It is concatenated with file names, so it must end with a
            path separator.
        target_res: ``(depth, width, height)`` target; the last two values are
            also the 2-D slice size.
        is_3d: Select 3-D volumes instead of cached 2-D slices.
        transforms: Optional callable applied to 3-D volumes only.
    """

    def __init__(self,
                 train_path = '/home/data/jzheng/acdc/images/',
                 target_res = (32, 256, 256),
                 is_3d = False,
                 transforms = None
                 ):
        self.t_p = train_path
        # Sorted so that an index maps to the same file on every run.
        self.files = sorted(os.listdir(self.t_p))
        self.sample_list_2d = []
        self.is_3d = is_3d
        self.target_res = target_res
        self.res_2d = (target_res[1], target_res[2])
        self.transforms = transforms

        if not self.is_3d:
            for f in self.files:
                img = nib.load(self.t_p + f).get_fdata()
                depth = img.shape[2]
                # Keep only the slices between 10 % and 90 % of the depth.
                f_i = int(round(depth*0.1))
                b_i = int(round(depth*0.9))
                interval_slice = img[:, :, f_i:b_i]
                for ii in range(interval_slice.shape[2]):
                    single_slice = resize(interval_slice[:, :, ii], self.res_2d, anti_aliasing=True, preserve_range=True)
                    single_slice = (single_slice - single_slice.min()) / (single_slice.max() - single_slice.min() + EPS)
                    self.sample_list_2d.append(single_slice[None, :, :])

    def __len__(self):
        """Return the number of cached slices (2-D) or of files (3-D)."""
        if not self.is_3d:
            return len(self.sample_list_2d)
        return len(self.files)

    def __getitem__(self, index):
        """Return the sample twice, as an ``(input, target)`` pair."""
        if not self.is_3d:
            return self.sample_list_2d[index], self.sample_list_2d[index]

        # Load the volume that belongs to `index`. The previous code looped
        # over all files and never used `index`.
        img = nib.load(self.t_p + self.files[index]).get_fdata()
        target_d_ratio = self.target_res[0] / img.shape[2]
        target_w_ratio = self.target_res[1] / img.shape[0]
        target_h_ratio = self.target_res[2] / img.shape[1]

        resize_img = zoom(img, (target_w_ratio, target_h_ratio, target_d_ratio))

        # Move the depth axis to the front; same as the two swapaxes calls
        # (0,2) then (1,2) that were used before.
        resize_img = np.transpose(resize_img, (2, 0, 1))
        resize_img = (resize_img - resize_img.min()) / (resize_img.max() - resize_img.min() + EPS)
        # Test the instance attribute; the old check looked at the global
        # name `transforms` and could call a `None` transform.
        if self.transforms is not None:
            resize_img = self.transforms(resize_img)
        return resize_img, resize_img
398
+
399
class acdc_gan_single_slice(Dataset):
    """Serve the normalised middle slice of every volume in a folder."""

    def __init__(self, train_path = '/well/papiez/shared/ACDC/clean_training/images/'):
        # train_path is concatenated with file names, so it must end with '/'.
        self.t_p = train_path
        self.files = os.listdir(self.t_p)

    def __len__(self):
        """Return the number of volumes in the folder."""
        return len(self.files)

    def __getitem__(self, index):
        """Return the 128x128 centre slice of volume ``index`` twice."""
        volume = nib.load(self.t_p + self.files[index]).get_fdata()

        # Slice at half of the third-axis length.
        centre = int(volume.shape[2] / 2)
        plane = resize(volume[:, :, centre], (128, 128), anti_aliasing=True, preserve_range=True)

        # Min-max normalise; EPS keeps the division defined for flat slices.
        plane = (plane - plane.min()) / (plane.max() - plane.min() + EPS)

        return plane, plane
418
+
419
+
420
+
421
+
Dataloader/dataloader_tester.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataLoader import *
2
+ import torchvision.transforms as tf
3
+ import SimpleITK as sitk
4
+ import os
5
+
6
+
7
# Minimal transform pipeline: only converts an image to a tensor.
transform = tf.Compose([tf.ToTensor()])

# Dataset name -> mapping JSON used when testing the BERT-embedding dataset.
mapping_files_bert = {
    # 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
    # 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
    'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
}

if __name__ == "__main__":
    # Smoke test: build one dataset, pull a batch and print its second field.
    # Other dataset classes that were tried here:
    #   OminiDataset_v1(transform=None)
    #   OminiDataset_paired(transform=None)
    #   OminiDataset_paired_inf(transform=None)
    #   OminiDataset_inference_w_all(transform=None)
    #   OminiDataset_bertembd(transform=None, mapping_files=mapping_files_bert)
    dataset = OminiDataset(transform=None)

    # print(dataset.get_keys_dist())
    # print(len(dataset))
    # print(dataset.build_batch().shape)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # NOTE(review): the flattened source does not show whether exit() sat
    # inside the loop; it is kept inside, i.e. stop after the first batch.
    for data in dataloader:
        print(data[1])
        exit()
37
+
38
+
39
+ # print(dataset.get_ALLdata())
Dataloader/dataloader_utils.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ # from torch import nn, optim
4
+ # from torch.autograd.variable import Variable
5
+ # from torchvision import transforms, datasets
6
+ # from torchvision.utils import save_image
7
+ # import torch.nn.functional as F
8
+ # import scipy.ndimage as spimg
9
+ # import pyquaternion as quater
10
+ # import random
11
+ import numpy as np
12
+ from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, generate_binary_structure
13
+ import pydicom
14
+ from scipy.ndimage import zoom
15
+ from einops import rearrange, reduce, repeat
16
+
17
def get_sizeRange_dict(roi=''):
    """Look up the ``[min_size, max_size]`` range of a region of interest.

    Args:
        roi (str): Name of the region of interest, e.g. ``'brain'``.

    Returns:
        list or dict: The two-element range of ``roi`` when it is a known
        key; otherwise (including the default empty string) the complete
        ROI -> range dictionary.
    """
    # Values are [min_size, max_size]; per the original notes the sizes are
    # in mm for the minimum and maximum dimensions.
    ranges = {
        # multi-region scans
        'whole-body': [420, 2048],
        'neck-thorax-abdomen-pelvis-leg': [400, 2048],
        'neck-thorax-abdomen-pelvis': [380, 2048],
        'thorax-abdomen-pelvis-leg': [360, 2048],
        'neck-thorax-abdomen': [320, 1024],
        'head-neck-thorax-abdomen': [360, 2048],
        'head-neck-thorax': [340, 1024],
        'thorax-abdomen-pelvis': [340, 1024],
        'abdomen-pelvis-leg': [320, 1024],
        'neck-thorax': [220, 1024],
        'thorax-abdomen': [260, 1024],
        'abdomen-pelvis': [260, 1024],
        'pelvis-leg': [240, 1024],
        'head-neck': [240, 1024],
        # single regions
        'head': [150, 1024],
        'brain': [128, 1024],
        'neck': [140, 1024],
        'abdomen': [240, 1024],
        'pelvis': [220, 1024],
        'thorax': [220, 1024],
        'arm': [140, 1024],
        'hand': [140, 1024],
        'leg': [160, 1024],
        'skeleton': [130, 1024],
    }
    # Unknown names fall back to the whole table instead of raising.
    return ranges.get(roi, ranges)
60
+
61
+
62
def remove_background(img, replace_value=None, num_bin=256, dim_ch=0, sigma=None):
    """Overwrite the background of a multi-channel image with a fill value.

    A pixel counts as background when, in every channel, its value equals the
    index of that channel's fullest histogram bin (integer bin edges
    ``0 .. num_bin-1``), so integer-valued intensities are assumed.

    Args:
        img: NumPy array holding the channels along axis ``dim_ch``.
        replace_value: Per-channel fill values. When ``None`` each channel
            uses the minimum of its own non-background pixels.
        num_bin: Histogram edges are ``range(num_bin)``; with ``num_bin <= 0``
            no background detection runs and every pixel stays marked as
            background.
        dim_ch: Channel axis. ``None`` means ``img`` has no channel axis; one
            is inserted at position 0 and is still present in the result.
        sigma: When > 0, the mask is dilated and the filled image is blended
            with the original through a Gaussian-blurred foreground mask.

    Returns:
        Array with the channels stacked along ``dim_ch``.

    Note:
        In the non-smoothed path (and for the fill step of the smoothed one)
        pixels are written through views of ``img``, so the caller's array is
        modified as well.
    """
    if dim_ch is None:
        dim_ch = 0
        img = np.expand_dims(img, axis=dim_ch)
    channels = [
        np.squeeze(part, axis=dim_ch)
        for part in np.split(img, img.shape[dim_ch], axis=dim_ch)
    ]

    # 1 = background in every channel, 0 = foreground in at least one.
    bg_mask = np.ones_like(channels[0])
    if num_bin > 0:
        for chan in channels:
            hist, _ = np.histogram(chan.flatten(), bins=range(num_bin))
            most_common = np.argmax(hist)
            bg_mask[chan != most_common] = 0

    smooth = sigma is not None and sigma > 0
    if smooth:
        bg_mask = binary_dilation(bg_mask, iterations=int(sigma*4)).astype(float)
        fg_core = binary_erosion(1 - bg_mask, iterations=int(sigma*4)).astype(float)
        # NOTE(review): truncate=sigma//4 is 0 for sigma < 4, which reduces
        # the Gaussian kernel to a single tap - confirm this is intended.
        fg_blur = gaussian_filter(fg_core, sigma=sigma*4, truncate=sigma//4, mode='nearest')

    is_foreground = np.logical_not(bg_mask)
    for idx, chan in enumerate(channels):
        if replace_value is None:
            if not is_foreground.any():
                # No foreground pixel to derive a fill value from (e.g. a
                # uniform image). np.min on the empty selection used to raise
                # ValueError; leave the channel untouched instead.
                continue
            fill = np.min(chan[is_foreground])
        else:
            fill = replace_value[idx]

        if smooth:
            filled = chan
            filled[bg_mask == 1] = fill
            filled = gaussian_filter(filled, sigma=sigma*4, truncate=sigma//4, mode='nearest')
            chan = chan*fg_blur + filled*(1 - fg_blur)
        else:
            chan[bg_mask == 1] = fill
        channels[idx] = chan
    return np.stack(channels, axis=dim_ch)
106
+
107
def thresh_img(img, thresh=None, EPS=10**-7):
    """Subtract a threshold from a tensor (or list of tensors) and clamp it.

    Args:
        img: A torch tensor, or a list of torch tensors that is updated in
            place and returned.
        thresh: ``None`` returns ``img`` untouched. A scalar is subtracted
            and the result is clamped at 0 from below. A two-element
            list/tuple ``[lo, hi]`` draws the threshold uniformly from that
            interval and additionally clamps from above at
            ``1 - u - threshold``, with ``u`` a second uniform draw from the
            same interval.
        EPS: Unused; kept for signature compatibility.

    Returns:
        The clamped tensor, or the same list object holding clamped tensors.
    """
    # Only the random (list) mode defines an upper bound. The scalar mode
    # previously referenced `upbound` without assigning it and always raised.
    upbound = None
    if isinstance(thresh, (list, tuple)):
        threshold = np.random.uniform(thresh[0], thresh[1])
        upbound = 1 - np.random.uniform(thresh[0], thresh[1]) - threshold
    else:
        threshold = thresh

    if threshold is None:
        return img

    def _shift_and_clamp(tensor):
        # Bounds are created on the tensor's own device.
        bounds = {'min': torch.tensor(0.).to(tensor.device)}
        if upbound is not None:
            bounds['max'] = torch.tensor(upbound).to(tensor.device)
        return torch.clamp(tensor - threshold, **bounds)

    if isinstance(img, list):
        for i, tensor in enumerate(img):
            img[i] = _shift_and_clamp(tensor)
        return img
    return _shift_and_clamp(img)
128
+
129
def clamp_img_tensor(img, clamp = [None,None]):
    """Clamp a tensor to the optional ``[lower, upper]`` bounds in ``clamp``.

    A bound given as ``None`` is not applied. With both bounds ``None`` the
    input tensor is returned as is.
    """
    device = img.device
    lower, upper = clamp[0], clamp[1]

    # Collect only the bounds that are set, as tensors on the input's device.
    limits = {}
    if lower is not None:
        limits['min'] = torch.tensor(lower).to(device)
    if upper is not None:
        limits['max'] = torch.tensor(upper).to(device)

    if not limits:
        return img
    return torch.clamp(img, **limits)
139
+
140
def read_CT_volume(folder_path, target_res = 128):
    """Read a folder of DICOM slices into a resampled, cube-padded volume.

    The ``.dcm`` files are stacked in reverse file-name order, rescaled by
    the voxel spacing so that the longest physical extent maps to about
    ``target_res`` voxels, and zero-padded symmetrically so that all three
    axes have the length of the longest one.

    Args:
        folder_path: Directory that contains the ``.dcm`` slice files.
        target_res: Target number of voxels along the longest axis.

    Returns:
        tuple: ``(volume, slice_thickness, pixel_spacing)`` with the volume
        laid out as (h, w, z) and the two spacing values taken from the
        header of the first slice that was read.

    Raises:
        ValueError: If the folder contains no ``.dcm`` file.
    """
    slice_names = sorted(
        (name for name in os.listdir(folder_path) if name.endswith(".dcm")),
        reverse=True,
    )
    if not slice_names:
        raise ValueError(f"No .dcm files found in {folder_path!r}")

    # Keep the header of the first slice actually read. The previous code
    # re-read os.listdir(folder_path)[0], which need not be a DICOM file and
    # depends on directory order.
    first_dicom = None
    dicom_slices = []
    for name in slice_names:
        dicom_data = pydicom.dcmread(os.path.join(folder_path, name))
        if first_dicom is None:
            first_dicom = dicom_data
        dicom_slices.append(dicom_data.pixel_array)

    # Stack the (h, w) slices along a new last axis -> (h, w, z).
    dicome_volume = np.stack(dicom_slices, axis=-1)

    slice_thickness = first_dicom.SliceThickness
    pixel_spacing = first_dicom.PixelSpacing

    # Per-axis spacing, used as the relative scale of each axis.
    h_axis_ratio = pixel_spacing[0]
    w_axis_ratio = pixel_spacing[1]
    z_axis_ratio = slice_thickness

    # The longest extent (spacing * voxel count) defines the common factor
    # that maps it onto target_res voxels.
    longest_axis = max(
        h_axis_ratio*dicome_volume.shape[0],
        w_axis_ratio*dicome_volume.shape[1],
        z_axis_ratio*dicome_volume.shape[2],
    )
    c_factor = longest_axis/target_res
    resized_volume = zoom(dicome_volume, (h_axis_ratio/c_factor, w_axis_ratio/c_factor, z_axis_ratio/c_factor))

    # Pad every axis symmetrically up to the longest one; any odd remainder
    # goes to the trailing side.
    max_dim_size = max(resized_volume.shape)
    pad_widths = []
    for axis_len in resized_volume.shape:
        missing = max_dim_size - axis_len
        pad_widths.append((missing // 2, missing - missing // 2))

    padded_resized_volume = np.pad(resized_volume, pad_widths, mode='constant')

    return padded_resized_volume, slice_thickness, pixel_spacing
193
+
Dataloader/embding_gen.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ import json
4
+ import SimpleITK as sitk
5
+ import numpy as np
6
+ from skimage.transform import rescale, resize, downscale_local_mean
7
+ # from torchvision.transforms import v2
8
+ import sys
9
+ from bert_helper import *
10
+ sys.path.append('./')
11
+ from Dataloader.dataloader_utils import *
12
+ import random
13
+
14
+
15
+
16
# ---------------------------------------------------------------------------
# Per-dataset configuration for the embedding generation script.
# ---------------------------------------------------------------------------

_SRC_ROOT = '/home/data/Github/data/data_gen_def/DATASETS_processed'
_OUT_DIR_SHARED = '/home/data/Github/OmniMorph/Dataloader/nifty_mappings'
_OUT_DIR_USER = '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings'

# Dataset name -> input mapping JSON. Only the uncommented entries are
# processed when the script runs.
mapping_files = {
    # 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
    # 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
    # 'Kaggle_osic': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/Kaggle_osic_new/nifti_mappings.json',
    # 'CancerImageArchive': '/home/data/Github/data/data_gen_def/DATASETS_processed/CancerImageArchive_test/nifti_mappings.json',
    # 'MnMs': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MnMs/nifti_mappings.json',
    # 'Brats2019': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2019/nifti_mappings.json',
    # 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
    # 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
    # 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
    'OASIS_2': f'{_SRC_ROOT}/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
    # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
    # 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
    # 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
    # 'AbdomenCT1k':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenCT1k/nifti_mappings.json',
}

# Dataset name -> output JSON (input mapping plus the generated embedding).
# The first four datasets are written below _OUT_DIR_SHARED, the others
# below _OUT_DIR_USER; the file stem may differ from the dataset name.
save_paths = {
    **{
        name: f'{_OUT_DIR_SHARED}/{stem}_mappings.json'
        for name, stem in (
            ('MSD', 'MSD'),
            ('TotalSegmentor', 'TotalSegmentorCT_MRI'),
            ('Kaggle_osic', 'Kaggle_osic'),
            ('CancerImageArchive', 'CIA'),
        )
    },
    **{
        name: f'{_OUT_DIR_USER}/{stem}_mappings.json'
        for name, stem in (
            ('MnMs', 'MnMs'),
            ('Brats2019', 'Brats2019'),
            ('Brats2020', 'Brats2020'),
            ('Brats2021', 'Brats2021'),
            ('OASIS_1', 'OASIS_1'),
            ('OASIS_2', 'OASIS_2'),
            ('PSMA-FDG-PET-CT-LESION', 'PSMA-FDG-PET-CT-LESION'),
            ('PSMA-CT', 'PSMA-CT-Longitud'),
            ('AbdomenAtlas', 'AbdomenAtlas'),
            ('AbdomenCT1k', 'AbdomenCT1k'),
        )
    },
}

# Dataset name -> metadata keys copied into the caption text.
query = {
    'MSD': ['description'],
    'TotalSegmentor': ['age', 'gender'],
    'Kaggle_osic': ['Age', 'Sex', 'Smoke_Status', 'Weeks', 'FVC', 'Percent'],
    'CancerImageArchive': ['Series_Description', 'Study_Description', 'Manufacturer'],
    'MnMs': ['Age', 'Sex', 'Height', 'Weight'],
    # The three BraTS releases share the same keys; each gets its own list.
    **{
        f'Brats{year}': ['Age', 'Grade', 'Survival', 'ResectionStatus']
        for year in (2019, 2020, 2021)
    },
    'OASIS_1': ['Age', 'M/F', 'ASF', 'Educ', 'SES', 'MMSE', 'eTIV', 'CDR', 'nWBV'],
    'OASIS_2': ['Age', 'Group', 'M/F', 'ASF', 'Educ', 'SES', 'MMSE', 'eTIV', 'CDR', 'nWBV'],
    'PSMA-FDG-PET-CT-LESION': ['Study Description', 'diagnosis', 'age', 'sex', "pet_radionuclide", 'ct_contrast_agent'],
    'PSMA-CT': [],
    'AbdomenAtlas': [],
    'AbdomenCT1k': [],
}

# Dataset name -> fixed extra caption fields added to every record.
add_text = {
    'MSD': {},
    'TotalSegmentor': {},
    'Kaggle_osic': {'description': 'pulmonary fibrosis progression'},
    'CancerImageArchive': {},
    'MnMs': {},
    **{
        f'Brats{year}': {'description': 'could include brain tumor, glioma, glioblastoma, low grade glioma, high grade glioma'}
        for year in (2019, 2020, 2021)
    },
    'OASIS_1': {},
    'OASIS_2': {},
    'PSMA-CT': {'description': 'melanoma patients'},
    'PSMA-FDG-PET-CT-LESION': {'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
    'AbdomenAtlas': {},
    'AbdomenCT1k': {},
}
81
+
82
+
83
# BERT initialisation: runs at import time and loads the embedder and the
# tokenizer from a local model directory via the bert_helper function.
model_name = '/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased'
# Reduction mode forwarded to str2emb for every caption.
reduce_method = 'mean'
max_words_num = 32 # max number of words in the caption > 2
# max_words_num = 64 # max number of words in the caption > 2

embeder, tokenizer = get_frozen_embeder(model_name)
90
def embed_str_filter(str_input, filter_words=['segmentation', 'registration']):
    '''
    Strip every occurrence of the given words from a string.

    Matching is by plain substring and is case sensitive; the whitespace
    around a removed word is left in place.
    '''
    cleaned = str_input
    for unwanted in filter_words:
        cleaned = cleaned.replace(unwanted, '')
    return cleaned
97
+
98
# For every configured dataset: load its mapping JSON, build a caption per
# record, embed the caption and write the enriched mapping to its save path.
for dataset, jsn_path in mapping_files.items():
    with open(jsn_path, 'r') as f:
        embd_json = json.load(f)

    query_key = query[dataset]
    extra_text = add_text[dataset]

    for key, record in embd_json.items():
        # Caption fields: modality and ROI first, then the queried metadata.
        embd_json_temp = {
            'Modality': record['Modality'],
            'ROI': record['ROI'],
        }

        meta_data = record['Metadata']
        for q in query_key:
            # Missing metadata is spelled out instead of being dropped.
            embd_json_temp[q] = meta_data[q] if q in meta_data else 'N/A'

        # Dataset-wide extra text extends an existing field or adds a new one.
        for q, text in extra_text.items():
            if q in embd_json_temp:
                embd_json_temp[q] += ', ' + text
            else:
                embd_json_temp[q] = text

        # The dict repr without its braces, lower-cased, is the raw caption.
        raw_caption = str(embd_json_temp)[1:-1].lower()
        embd_str = replace_text(raw_caption, get_synonyms_dict(None))
        embd_str = embed_str_filter(embd_str)

        print(f'embd_json_temp: {str(embd_json_temp)}')
        print(f'embd_str: {embd_str}')
        print(f'words_num: {len(embd_str.split())}')
        assert(len(embd_str.split()) <= max_words_num), f'Too many words in the caption: {embd_str}'

        embd = str2emb(embd_str, max_words_num, embeder, tokenizer, reduce_method=reduce_method)
        print(embd)
        record['embd'] = embd.tolist()[0]
        record['embd_key'] = embd_str

        # exit()

    new_jsn_path = save_paths[dataset]
    with open(new_jsn_path, 'w') as f:
        json.dump(embd_json, f, indent=4)
143
+
144
+
145
+
146
+
147
+
148
+
149
+
Dataloader/nifty_mappings/AbdomenAtlas_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:303c3fb7388e7b3b01cb6f494c3ac3f542da98487039e5b2415786ac4af58ba0
3
+ size 179457573
Dataloader/nifty_mappings/AbdomenCT1k_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0abaaa1013fdafe3fae6d5544746a66d8b20892ceb3cf9141a125113984e8350
3
+ size 37315918
Dataloader/nifty_mappings/Brats2019_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c5b80fc861484d36d8d6e0f97c404e2c321ee965cc1556a868205f5937d24fe
3
+ size 12126490
Dataloader/nifty_mappings/Brats2020_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de345c6a66a4f33552aacbb961cd034ac488500ff5d48810579055f0543162dc
3
+ size 17743015
Dataloader/nifty_mappings/Brats2021_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4990a7031d6ac91e1c33e6db046dddf234f67dd8edecd07691675945b9d00af5
3
+ size 44722001
Dataloader/nifty_mappings/CIA_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98cbd21d3d5b7f5fb84091705fbbfcd0f8f26cb26ff4b34ffcf546cf1cedb48a
3
+ size 32744567
Dataloader/nifty_mappings/Kaggle_osic_mappings.json ADDED
The diff for this file is too large to render. See raw diff
 
Dataloader/nifty_mappings/MSD_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1ab13c61cd6829f088ee92bff4ce12a0f0e19fc9367682291fbd9717b149e83
3
+ size 92620864
Dataloader/nifty_mappings/MnMs_mappings.json ADDED
The diff for this file is too large to render. See raw diff
 
Dataloader/nifty_mappings/OASIS_1_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8784bff1bb5c9ba08fccc8ca9776f3f26c9b2993c1c446ef17d5ba1dd2bda490
3
+ size 15609846
Dataloader/nifty_mappings/OASIS_2_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f88910a0846e056b0d4caacd6e6ebfebde52b537828756e217d9a6c6343177c
3
+ size 13396017
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c3c8729df59b6e9771fa791c5fe1cd7636e83a3c17109613984cdce0d92eefdc
3
+ size 11700732
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:922363b739e1f14243731ea283ee730bc55724a27360d2f28f32b01b23ede5d9
3
+ size 48425273
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c36ba45053fea97244c259af0151ddb02e8281fce8c8f439cc88733bd71d668f
3
+ size 67962146
Diffusion/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import Diffusion
2
+ from . import diffuser
3
+ from . import networks
4
+ from . import losses
5
+
6
+ import sys
7
+ sys.path.append('./Diffusion')
8
+ sys.path.append('./')
Diffusion/diffuser.py ADDED
@@ -0,0 +1,531 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import numpy as np
4
+ from torch.nn.utils.stateless import functional_call
5
+
6
+ import Diffusion.utils_diff as utils
7
+ from Diffusion.networks import *
8
+ # from networks import *
9
+
10
+ import random
11
+
12
+ EPS = 1e-8
13
+
14
+
15
+
16
class DeformDDPM(nn.Module):
    """Deformation-based diffusion wrapper around a recovery network.

    The forward ("diffusion") process warps an image with a randomly
    generated dense displacement field (DDF) whose magnitude grows with the
    time step ``t``.  The wrapped ``network`` is asked to predict the
    velocity field (DVF) that undoes that deformation, optionally guided by
    a degraded "condition" image produced by :meth:`proc_cond_img`.

    Images are assumed to be square/cubic: only ``image_chw[1]`` is used as
    the spatial size for every axis.
    """

    def __init__(
        self,
        network,
        n_steps=50,
        beta_schedule_fn=None,
        device='cpu',
        image_chw=(1, 28, 28),
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=0.008/256,
        resample_mode=None,
    ):
        """Store the configuration and build the spatial transformers.

        Args:
            network: recovery network called as
                ``network(x=, y=, t=, rec_num=, text=)``.
            n_steps: number of diffusion steps (stored only).
            beta_schedule_fn: accepted for interface compatibility; unused.
            device: device on which random fields are generated.
            image_chw: ``(channels, size, size[, size])``; its length fixes
                the number of spatial dimensions.
            batch_size: batch size of the generated random fields.
            img_pad_mode: padding mode of the image/mask transformers.
            ddf_pad_mode: padding mode of the control-grid transformer.
            padding_mode: padding mode of the full-resolution DDF transformer.
            v_scale: standard-deviation scale of the random velocity fields.
            resample_mode: resampling mode forwarded to the image transformer.
        """
        super(DeformDDPM, self).__init__()
        self.rec_num = 2
        self.ndims = len(image_chw) - 1
        self.n_steps = n_steps
        self.v_scale = v_scale
        self.device = device
        self.msk_noise_scale = torch.tensor(0)
        self.num_device = torch.cuda.device_count()

        self.batch_size = batch_size
        self.img_pad_mode = img_pad_mode
        self.ddf_pad_mode = ddf_pad_mode
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        self.image_chw = image_chw
        self.network = network
        self.ddf_stn_full = STN(
            img_sz=self.image_chw[1],
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )
        self._DDF_Encoder_init()
        self.copy_opt = nn.Identity()

    def get_stn(self):
        """Return the image transformer and the full-resolution DDF transformer."""
        return self.img_stn, self.ddf_stn_full

    def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
        """(Re)build the transformers used to generate and apply random fields.

        Args:
            ctl_ratio: image size divided by control-grid size, used when
                ``ctl_sz`` is not given.
            ctl_sz: explicit control-grid size.
            resample_mode: accepted for interface compatibility; the value
                stored on the instance is the one actually used.
        """
        if ctl_sz is None:
            ctl_sz = self.image_chw[1] // ctl_ratio
        self.ctl_sz = ctl_sz
        self.img_sz = self.image_chw[1]
        self.ddf_stn_rec = STN(img_sz=ctl_sz, ndims=self.ndims, device=self.device,
                               padding_mode=self.ddf_pad_mode)
        self.img_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
                           padding_mode=self.img_pad_mode, resample_mode=self.resample_mode)
        self.msk_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
                           padding_mode=self.img_pad_mode, resample_mode='nearest')

    def _get_ddf_scale(self, t, divide_num=1, max_ddf_num=200):
        """Map time step(s) ``t`` to the number of field compositions.

        Returns:
            ``(rec_num, mul_num_ddf, mul_num_dvf)`` where ``rec_num`` is
            always 1, ``mul_num_ddf`` is ``floor(2 * t**1.3 / (3 * divide_num))``
            clamped to ``[1, max_ddf_num]`` and ``mul_num_dvf`` is
            ``floor(t**0.6 / divide_num)`` clamped to ``[0, max_ddf_num]``.
        """
        rec_num = 1
        mul_num_ddf = torch.floor_divide(2 * torch.pow(t, 1.3), 3 * divide_num).int()
        mul_num_dvf = torch.floor_divide(torch.pow(t, 0.6), divide_num).int()
        mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
        mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
        return rec_num, mul_num_ddf, mul_num_dvf

    def _get_random_ddf(self, img, t):
        """Warp ``img`` with a random field sized for time step ``t``.

        Returns:
            ``(warped_img, dvf_forward, ddf_forward)``.
        """
        rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
        ddf_forward, dvf_forward = self._random_ddf_generate(
            rec_num=rec_num, mul_num=[mul_num_ddf, mul_num_dvf])
        warped_img = self.img_stn(img, ddf_forward)
        return warped_img, dvf_forward, ddf_forward

    def _multiscale_dvf_generate(self, v_scale, ctl_szs=[4, 8, 16, 32, 64], rand_v_scale=True):
        """Sum Gaussian random velocity fields drawn at several grid sizes.

        Each component is drawn on a ``ctl_sz`` grid, rescaled to control-grid
        units and resampled to the control-grid resolution.  A grid size of 1
        additionally contributes a small random rotation field.

        Args:
            v_scale: (upper bound of the) standard deviation of each component.
            ctl_szs: grid sizes to draw from (only read, never modified).
            rand_v_scale: draw the per-component scale at random in
                ``[1e-8, v_scale]`` instead of using ``v_scale`` directly.
        """
        interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        dvf = 0
        if self.img_sz is None:
            self.img_sz = max(ctl_szs)
        if 1 in ctl_szs:
            dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims,
                                       img_sz=[self.ctl_sz] * self.ndims,
                                       range_gauss=0, rot_range=np.pi / 90)
            dvf = dvf + dvf_rot
        for ctl_sz in ctl_szs:
            if rand_v_scale:
                _v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2)
            else:
                _v_scale = v_scale
            # Very coarse grids act almost like global shifts; damp them.
            if ctl_sz <= 2:
                _v_scale = _v_scale / 2
            dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz] * self.ndims) * _v_scale
            dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz] * self.ndims,
                                     align_corners=False, mode=interp_mode)
            dvf = dvf + dvf_comp
        return dvf

    def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
        """Draw a value in ``[low, high]`` by ``order_num`` nested uniform draws.

        Every draw uses the previous result as its new lower bound, which
        biases the sample towards ``high``.
        """
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
        return sample_value

    def _upsample_center_crop(self, field, crop_rate=2):
        """Resample a control-grid field to ``crop_rate`` x image size and keep the centre."""
        field = F.interpolate(field * self.img_sz / self.ctl_sz, self.img_sz * crop_rate,
                              mode='bilinear' if self.ndims == 2 else 'trilinear')
        window = slice(self.img_sz // 2, self.img_sz * 3 // 2)
        return field[(Ellipsis,) + (window,) * self.ndims]

    def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])], ddf0=None,
                             keep_inverse=False, noise_ratio=0.08, select_num=4, flip_ratio=0.5):
        """Generate a random displacement field by repeated self-composition.

        Args:
            rec_num: number of independent velocity fields to compose.
            mul_num: two integer tensors (one entry per batch item) giving
                how many times each velocity field is composed into the
                noisy field and into the clean field respectively.
            ddf0: optional initial control-grid field for the noisy field.
            keep_inverse: accepted for interface compatibility; unused.
            noise_ratio: relative magnitude of the extra perturbation added
                to the velocity field used for the noisy field (0 disables).
            select_num: number of grid scales sampled per velocity field.
            flip_ratio: accepted for interface compatibility; unused.

        Returns:
            ``(ddf, dddf)`` at image resolution: the field composed from the
            perturbed velocity and the one composed from the clean velocity.
        """
        # Make the composition counts broadcastable against [B, C, *spatial].
        for _ in range(self.ndims + 1):
            mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
        ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        ddf = ddf0 if ddf0 is not None else torch.zeros(ctl_ddf_sz)
        dddf = torch.zeros(ctl_ddf_sz)
        scale_num = min(8, int(math.log2(self.ctl_sz)))  # allow affine
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

        for _ in range(rec_num):
            # Use a random subset of scales when enough are available,
            # otherwise fall back to all of them (previously this case left
            # ``ctl_szs`` undefined and raised NameError for small grids).
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf_noise = self._multiscale_dvf_generate(
                    self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False).to(self.device)
                dvf0 = dvf + self.ddf_stn_rec(dvf_noise, dvf)
            for j in range(torch.max(mul_num[0]).item()):
                # Per-sample switch: stop composing once a sample reached its count.
                flag = [(n > j).int().to(self.device) for n in mul_num]
                ddf = dvf0 * flag[0] + self.ddf_stn_rec(ddf, dvf0 * flag[0])
                dddf = dvf * flag[1] + self.ddf_stn_rec(dddf, dvf * flag[1])

        return self._upsample_center_crop(ddf), self._upsample_center_crop(dddf)

    def create_noise_map(self, img, noise_type='gaussian', noise_ratio=0.2):
        """Return a noise tensor shaped like ``img``.

        ``'gaussian'`` gives normal noise scaled by ``noise_ratio``,
        ``'uniform'`` gives values in [0, 1), ``'binary'`` gives a Bernoulli
        mask with random per-element probability; anything else gives zeros.
        """
        if noise_type == 'gaussian':
            noise_map = torch.randn_like(img) * noise_ratio
        elif noise_type == 'uniform':
            noise_map = torch.rand_like(img)
        elif noise_type == 'binary':
            noise_map = torch.bernoulli(torch.rand_like(img))
        else:
            noise_map = torch.zeros_like(img)
        return noise_map.to(img.device)

    def add_noise(self, img, noise_map=None, noise_ratio_range=[0., 1.]):
        """Blend ``img`` with ``noise_map`` using a random ratio.

        Returns:
            ``(blended, noise_ratio)`` with the ratio drawn uniformly from
            ``noise_ratio_range``.
        """
        noise_ratio = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
        return img * (1 - noise_ratio) + noise_map * noise_ratio, noise_ratio

    def apply_noise(self, img, noise_map=None, apply_mask=None):
        """Keep ``img`` where ``apply_mask`` is 1 and ``noise_map`` elsewhere."""
        return img * apply_mask + noise_map * (1 - apply_mask)

    def downsample(self, img, down_ratio_range=[1./32, 1]):
        """Degrade ``img`` by a random per-axis down/up-sampling round trip.

        Returns:
            ``(degraded, weight)`` where ``weight`` is the square root of the
            product of the per-axis down-sampling ratios.
        """
        interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        down_ratio = list(np.random.uniform(down_ratio_range[0], down_ratio_range[1], [self.ndims]))
        down_img = F.interpolate(img, scale_factor=down_ratio, mode=interp_mode)
        restored = F.interpolate(down_img, size=[self.image_chw[1]] * self.ndims,
                                 mode=interp_mode, align_corners=False)
        return restored, np.sqrt(np.prod(down_ratio))

    def get_slice_mask(self, img, slice_num_range=[0, 32]):
        """Build a binary mask that keeps a random set of slices along each axis.

        Args:
            img: tensor whose shape the mask copies.
            slice_num_range: inclusive ``[min, max]`` number of slices kept
                per axis; the maximum is clipped to the image size.  The
                argument is only read (it used to be overwritten in place,
                which silently changed the shared default and caller lists).

        Returns:
            ``(mask, sample_ratio)`` where ``sample_ratio`` is the mean over
            axes of ``sqrt(kept_slices / image_size)``.
        """
        size = self.image_chw[1]
        low, high = slice_num_range[0], min(slice_num_range[1], size)
        mask = torch.zeros_like(img)
        sample_ratio = 0
        # Rotating the spatial axes lets "last axis" indexing visit every axis in turn.
        transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
        for _ in range(self.ndims):
            slice_num = random.randint(low, high)
            for idx in random.sample(range(size), slice_num):
                mask[..., idx] = 1
            mask = mask.permute(*transpose_list)
            sample_ratio += np.sqrt(slice_num / size) / self.ndims
        return mask, sample_ratio

    def project(self, img):
        """Average ``img`` along a random subset of spatial axes.

        Returns:
            ``(projection, n_axes)`` where the projection is the mean of the
            per-axis mean images (broadcast back to the shape of ``img``) and
            ``n_axes`` is how many axes were selected.
        """
        proj_img = torch.zeros_like(img)
        rand_bourn = np.random.randint(0, 2, size=[self.ndims])
        proj_dim_num = np.sum(rand_bourn)
        for axis, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
            if pflag:
                proj_img += torch.mean(img, dim=axis, keepdim=True)
        return proj_img / (proj_dim_num + EPS), proj_dim_num

    def proc_cond_img(self, img, proc_type=None):
        """Degrade ``img`` into a condition image.

        Args:
            img: source image batch.
            proc_type: one of ``'adding'``, ``'independ'``, ``'downsample'``,
                ``'slice'``, ``'project'``, ``'none'``, ``'uncon'``.  When
                ``None`` a type is drawn at random (``'uncon'`` three times
                as likely as each other option; ``'project'`` is never drawn).

        Returns:
            ``(proc_img, mask, cond_ratio)``: the degraded image, the mask of
            retained voxels (scalar 1/0 when not applicable) and a scalar
            tensor describing how much of the condition is kept.
        """
        proc_img = img.clone().detach()
        if proc_type is None:
            proc_type = random.choices(
                ['adding', 'independ', 'downsample', 'slice', 'none', 'uncon'],
                weights=[1, 1, 1, 1, 1, 3], k=1
            )[0]
        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1., device=img.device)
        self.msk_noise_scale = torch.tensor(0, device=img.device)
        noise_type = random.choice(['gaussian', 'uniform', 'none'])
        noise_map = None
        if proc_type in ['none', None, '']:
            return proc_img, mask, cond_ratio

        if proc_type == 'uncon':
            # Unconditional: the condition is pure noise and carries no weight.
            noise_map = self.create_noise_map(img, noise_type=noise_type)
            return noise_map, torch.tensor(0, device=img.device), torch.tensor(0, device=img.device)

        if proc_type in ['adding', 'independ', 'slice']:
            noise_map = self.create_noise_map(img, noise_type=noise_type)

        if proc_type == 'adding':
            proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
            cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
        elif proc_type == 'independ':
            mask = self.create_noise_map(img, noise_type='binary')
            if self.msk_noise_scale == 0:
                proc_img = img * mask
            else:
                proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
            with torch.no_grad():
                cond_ratio = mask.float().mean()
        elif proc_type == 'downsample':
            proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
            cond_ratio = torch.tensor(down_ratio, device=img.device)
        elif proc_type == 'slice':
            # Two nested draws bias the slice budget towards small values.
            slice_num_max = random.randint(1, 64)
            slice_num_max = random.randint(1, slice_num_max)
            mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
            if self.msk_noise_scale == 0:
                proc_img = img * mask
            else:
                proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
            cond_ratio = torch.tensor(sample_ratio, device=img.device)
        elif proc_type == 'project':
            proc_img, proj_num = self.project(proc_img)
            cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
        return proc_img, mask, cond_ratio

    def diffuse(self, x_0, t):
        """Apply the forward deformation for time step ``t`` to ``x_0``.

        Returns:
            ``(warped_img, dvf_forward, ddf_forward)``.
        """
        t = torch.as_tensor(t)
        return self._get_random_ddf(img=x_0, t=t)

    def recover(self, x, y, t, rec_num=2, text=None):
        """Run the recovery network on a deformed image.

        Args:
            x: deformed image batch.
            y: condition image batch.
            t: time step, or a list of time steps; converted to tensor(s)
                on the device of ``x``.
            rec_num: recursion count forwarded to the network; ``None``
                selects the instance default.
            text: optional text conditioning forwarded to the network.
        """
        if isinstance(t, list):
            t = [torch.as_tensor(t0, device=x.device) for t0 in t]
        else:
            # Tensor.to() is not in-place: the result has to be kept,
            # otherwise ``t`` silently stays on its original device.
            t = torch.as_tensor(t, device=x.device)
        if rec_num is None:
            rec_num = self.rec_num
        return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)

    def recover_frozen_params_but_grad_input(self, x, y, t, rec_num=2, text=None):
        """Like :meth:`recover`, but with the network parameters detached.

        Gradients still flow to the inputs (no ``torch.no_grad``), while the
        parameters receive none because detached copies are substituted for
        this call.
        """
        if isinstance(t, list):
            t = [torch.tensor(t0, device=x.device) for t0 in t]
        else:
            t = torch.tensor(t, device=x.device)

        if rec_num is None:
            rec_num = self.rec_num

        # Detached parameters and untouched buffers are merged into a single
        # mapping because older PyTorch versions do not accept them separately.
        params_and_buffers = {k: v.detach() for k, v in self.network.named_parameters()}
        params_and_buffers.update(dict(self.network.named_buffers()))
        return functional_call(
            self.network,
            params_and_buffers,
            (),
            kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
        )

    def _single_step(self, x0, t, rec_num=2, proc_type=None, mask=None, cond_imgs=None, text=None):
        """One training step: deform ``x0`` at step ``t`` and predict the recovery.

        Returns:
            ``(network_output, dvf_I)`` where ``dvf_I`` is the clean forward
            field returned by :meth:`diffuse`.
        """
        if mask is None:
            mask = 1
        if cond_imgs is None:
            cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(x0, proc_type=proc_type)
        noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
        if isinstance(self.network, DefRec_MutAttnNet):
            t = [t] * 1
        return self.recover(x=noisy_imgs * mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I

    def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
        """Dispatch to :meth:`diff_recover` when ``T`` is given, else to :meth:`_single_step`."""
        if T is not None:
            return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
        return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)

    def diff_recover(self,
                     img_org,
                     msk_org=None,
                     T=[None, None],
                     ddf_rand=None,
                     v_scale=None,
                     t_save=None,
                     cond_imgs=None,
                     proc_type=None,
                     text=None,
                     ):
        """Deform an image, then iteratively recover it with the network.

        Args:
            img_org: source image batch.
            msk_org: optional mask warped alongside the image
                (nearest-neighbour transformer).
            T: ``[t_forward, t_backward]``.  ``t_forward`` of ``None``/0 skips
                the random deformation; ``t_backward`` is either the number of
                reverse steps or an explicit sequence of time steps.
            ddf_rand: optional precomputed forward displacement field.
            v_scale: optional override of the velocity scale used for the
                random forward field.
            t_save: optional collection of time steps whose intermediate
                results are stored.
            cond_imgs: condition image; defaults to a copy of ``img_org``.
                It is always passed through :meth:`proc_cond_img`.
            proc_type: degradation type for the condition image.
            text: optional text conditioning forwarded to the network.

        Returns:
            ``[ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
            [msk_rec, msk_diff, msk_save]``.
        """
        if cond_imgs is None:
            cond_imgs = img_org.clone().detach()
        cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(cond_imgs, proc_type=proc_type)

        if ddf_rand is None:
            if v_scale is not None:
                # NOTE(review): nesting of the re-initialisation was inferred
                # from a flattened listing - confirm it belongs to this branch.
                self.v_scale = v_scale
                self._DDF_Encoder_init()
            if T[0] is None or T[0] == 0:
                img_diff = img_org.clone().detach()
                ddf_rand = torch.zeros_like(img_diff)
            else:
                img_diff, _, ddf_rand = self._get_random_ddf(
                    img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
        else:
            img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)

        ddf_comp = ddf_rand.clone().detach()
        img_rec = img_diff.clone().detach()
        if msk_org is not None:
            msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
            msk_rec = msk_diff.clone().detach()
        else:
            msk_diff = None
            msk_rec = None
        img_save = []
        msk_save = []

        if isinstance(self.network, DefRec_MutAttnNet):
            # This network consumes the whole reverse schedule in one call.
            t_list = list(range(T[1] - 1, -1, -1))
            pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
            if msk_org is not None:
                msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
        else:
            if isinstance(T[-1], int):
                time_steps = range(T[-1] - 1, -1, -1)
            else:
                time_steps = T[-1]

            # Only the last k reverse steps keep gradients; earlier steps run
            # under no_grad to bound memory.
            k = 2
            trainable_iterations = time_steps[-1:-k - 1:-1]
            for i in time_steps:
                t = torch.tensor(np.array([i])).to(self.device)
                if i in trainable_iterations:
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
                else:
                    with torch.no_grad():
                        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)

                # Compose the predicted step into the running field, then
                # always resample from the original image to avoid
                # accumulating interpolation blur.
                ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
                img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
                if msk_org is not None:
                    msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
                if t_save is not None and i in t_save:
                    img_save.append(img_rec)
                    if msk_org is not None:
                        msk_save.append(msk_rec)

        return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]
508
+
509
if __name__ == "__main__":
    # Minimal smoke test: build a 2-D model on CPU and degrade one random
    # image into a condition image.
    H, W = 8, 8
    deformddpm = DeformDDPM(
        network=get_net(name="recmutattnnet")(n_steps=80, ndims=2, num_input_chn=1),
        image_chw=(1, H, W),
        device='cpu',
    )
    img = torch.randn([1, 1, H, W])
    t = 1
    rec_num = 2
    # Other supported values: 'adding', 'independ', 'downsample', 'project', 'none'.
    proc_type = 'slice'
    print(img)
    # proc_cond_img returns three values (image, mask, condition ratio);
    # unpacking only two raised ValueError.
    cond_imgs, mask_tgt, cond_ratio = deformddpm.proc_cond_img(img, proc_type=proc_type)
    print(cond_imgs)
530
+
531
+
Diffusion/losses.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ losses for DRDM
3
+ """
4
+
5
+ import numpy as np
6
+ import sys
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ EPS=1e-7
12
+
13
+ # eps_scale = 10e-5
14
+ # eps_scale = 10e-4
15
+ # eps_scale = 1e-4
16
+ eps_scale = 1e-5
17
+
18
+
19
+ class LMSE(torch.nn.Module):
20
+ """
21
+ Labeled Mean Square Error (LMSE)
22
+ """
23
+
24
+ def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
25
+ super(LMSE, self).__init__()
26
+ self.eps = eps
27
+ self.relate_eps = relate_eps
28
+ self.ndims = 3
29
+ self.smooth = smooth
30
+ self.win = win
31
+ # Set window size
32
+ if self.win is None:
33
+ self.win = [5] * self.ndims
34
+ if smooth:
35
+ self.kernels = self._build_kernel(std=0.0)
36
+
37
+ def _build_kernel(self, std=0.0):
38
+ if std == 0.0:
39
+ return torch.ones([1, 1, *self.win])
40
+ else:
41
+ tail = int(np.ceil(std)) * 3
42
+ k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
43
+ kernel = k / torch.sum(k)
44
+ kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
45
+ # print(kernel.item)
46
+ return kernel.unsqueeze(0).unsqueeze(0)
47
+
48
+ def forward(self, I, J, label=None):
49
+ """
50
+ Computes the labeled mean squared error between I and J (ref).
51
+ If label is provided, computes the MSE only over the labeled regions.
52
+ """
53
+ padding = [(w-1) // 2 for w in self.win]
54
+ if self.smooth:
55
+ I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=padding)
56
+ J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=padding)
57
+ mse = (I - J) ** 2
58
+ if self.relate_eps is not None:
59
+ mse = mse/((J**2) + self.relate_eps)
60
+ if label is not None:
61
+ label = label.float()
62
+ mse = mse * label
63
+ mse_sum = torch.sum(mse, dim=(2, 3, 4))
64
+ label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps
65
+ loss = torch.mean(mse_sum / label_sum)
66
+ else:
67
+ loss = torch.mean(mse)
68
+ return loss
69
+
70
+ class LNCC(torch.nn.Module):
71
+ """
72
+ Local (over window) normalized cross-correlation (LNCC)
73
+ """
74
+
75
+ def __init__(self, win=None, num_ch=1, eps=1e-6, central=True, smooth=True):
76
+ super(LNCC, self).__init__()
77
+ self.scale = 2e0
78
+ self.win = win
79
+ self.eps = eps
80
+ self.central = central
81
+ self.ndims = 3
82
+ self.strides = [1] * (self.ndims + 2)
83
+ self.smooth = smooth
84
+
85
+ # Set window size
86
+ if self.win is None:
87
+ self.win = [9] * self.ndims
88
+ self.padding = [(w-1) // 2 for w in self.win]
89
+
90
+ if smooth:
91
+ self.kernels = self._build_kernel(std=0.45)
92
+ self.sum_filt = self._build_kernel(std=0.0)
93
+
94
+ def _build_kernel(self, std=0.0):
95
+ if std == 0.0:
96
+ return torch.ones([1, 1, *self.win])/np.prod(self.win)
97
+ else:
98
+ self.tail = int(np.ceil(std)) * 2
99
+ k = torch.exp(-0.5 * (torch.arange(-self.tail, self.tail + 1, dtype=torch.float32) ** 2) / std ** 2)
100
+ kernel = k / torch.sum(k)
101
+ # print(kernel)
102
+ kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
103
+ # kernel = kernel * np.prod(self.win)
104
+ # print('Gaussian kernel created with std:', std)
105
+ # print('Kernel sum:', torch.sum(kernel))
106
+
107
+ return kernel.unsqueeze(0).unsqueeze(0)
108
+
109
+ def lncc(self, I, J, label=None):
110
+ self.sum_filt = self.sum_filt.to(I.device)
111
+
112
+ if self.smooth:
113
+ self.kernels = self.kernels.to(I.device)
114
+ I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=self.tail)
115
+ J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=self.tail)
116
+
117
+ # if self.central:
118
+ # I = I - torch.mean(I, dim=(2, 3, 4), keepdim=True)
119
+ # J = J - torch.mean(J, dim=(2, 3, 4), keepdim=True)
120
+ # Compute CC squares
121
+ I2 = I * I
122
+ J2 = J * J
123
+ IJ = I * J
124
+
125
+ if self.central:
126
+ # Compute local sums via convolution
127
+ I_sum = torch.nn.functional.conv3d(I, self.sum_filt, stride=1, padding=self.padding)
128
+ J_sum = torch.nn.functional.conv3d(J, self.sum_filt, stride=1, padding=self.padding)
129
+ I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=self.padding)
130
+ J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=self.padding)
131
+ IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=self.padding)
132
+
133
+ # Compute cross-correlation
134
+ win_size = np.prod(self.win)
135
+ # print('Window size:', win_size)
136
+ # u_I = I_sum / win_size
137
+ # u_J = J_sum / win_size
138
+ # cross = IJ_sum - ((I_sum * J_sum) / win_size)
139
+ # I_var = I2_sum - ((I_sum * I_sum) / win_size)
140
+ # J_var = J2_sum - ((J_sum * J_sum) / win_size)
141
+ cross = IJ_sum - (I_sum * J_sum)
142
+ I_var = I2_sum - (I_sum * I_sum)
143
+ J_var = J2_sum - (J_sum * J_sum)
144
+ else:
145
+ # if 1:
146
+ # Compute local sums via convolution
147
+ I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=self.padding)
148
+ J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=self.padding)
149
+ IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=self.padding)
150
+
151
+ cross = IJ_sum
152
+ I_var = I2_sum
153
+ J_var = J2_sum
154
+
155
+ # cc = (cross * cross) / (I_var * J_var + self.eps)
156
+ cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
157
+ if label is not None:
158
+ label = label.float()
159
+ cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
160
+
161
+ return torch.mean(cc)
162
+
163
+ def forward(self, I, J, label=None):
164
+ return -self.lncc(I*self.scale, J*self.scale, label=label)
165
+
166
+
167
+
168
+ class NCC(torch.nn.Module):
169
+ # def __init__(self, eps_scale=10e-7,img_sz=256):
170
+ def __init__(self, eps_scale=10e-5,img_sz=256):
171
+ super(NCC, self).__init__()
172
+ self.eps_scale=eps_scale#*img_sz/256
173
+ # self.scale=10e4
174
+ self.scale=1e2
175
+
176
+ def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
177
+ if ddf_stn is None:
178
+ trm_pred=pred
179
+ else:
180
+ trm_pred=-ddf_stn(pred, inv_lab)
181
+ trm_pred = self.scale * trm_pred
182
+ inv_lab = self.scale * inv_lab
183
+ if mask is None:
184
+ loss_gen = torch.mean(torch.sum(trm_pred*inv_lab,dim=1)/(torch.sqrt(torch.sum(torch.square(trm_pred),dim=1)*torch.sum(torch.square(inv_lab),dim=1)+self.eps_scale)))
185
+ else:
186
+ batch_size = inv_lab.shape[0]
187
+ loss_gen = torch.sum(torch.sum(trm_pred*inv_lab,dim=1)*mask/(torch.sqrt(torch.sum(torch.square(trm_pred),dim=1)*torch.sum(torch.square(inv_lab),dim=1)+self.eps_scale)))/torch.sum(mask)/batch_size
188
+ return loss_gen
189
+
190
+ class MRSE(torch.nn.Module):
191
+ def __init__(self, eps_scale=eps_scale,img_sz=256):
192
+ super(MRSE, self).__init__()
193
+ self.eps_scale=eps_scale#*img_sz/256
194
+ self.scale = 10e1
195
+
196
+ def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
197
+ if ddf_stn is None:
198
+ trm_pred=pred
199
+ else:
200
+ trm_pred=-ddf_stn(pred, inv_lab)
201
+ trm_pred = self.scale * trm_pred
202
+ inv_lab = self.scale * inv_lab
203
+ if mask is None:
204
+ loss_gen = torch.mean(
205
+ torch.sum(torch.square(trm_pred + inv_lab), dim=1)
206
+ / (torch.sum(torch.square(inv_lab), dim=1) + self.eps_scale)
207
+ )
208
+ else:
209
+ batch_size = inv_lab.shape[0]
210
+ loss_gen = torch.sum(
211
+ torch.sum(torch.square(trm_pred + inv_lab), dim=1) * mask
212
+ / (torch.sum(torch.square(inv_lab), dim=1) + self.eps_scale)
213
+ )/torch.sum(mask)/batch_size
214
+ return loss_gen/1
215
+
216
class RMSE(torch.nn.Module):
    """Per-sample relative mean squared error between two fields.

    For every sample the channel-summed squared difference is averaged over
    ``ndims`` axes and divided by the same average of the squared reference
    (plus ``eps_scale``); the ratios are then averaged.  No square root is
    taken, despite the class name.  ``img_sz`` is accepted but not used.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256,ndims=2):
        super(RMSE, self).__init__()
        self.eps_scale = eps_scale
        self.ndims = ndims

    def forward(self,pred,inv_lab=None,ddf_stn=None):
        """Return the relative error of pred (or ``-ddf_stn(pred, inv_lab)``) vs inv_lab."""
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        # After the channel sum the remaining axes are (batch, *spatial).
        reduce_axes = list(range(1, 1 + self.ndims))
        err = torch.mean(torch.sum(torch.square(warped - inv_lab), dim=1), dim=reduce_axes)
        ref = torch.mean(torch.sum(torch.square(inv_lab), dim=1), dim=reduce_axes)
        return torch.mean(err / (ref + self.eps_scale))
231
+ # loss_gen = torch.mean(torch.mean(torch.sum(torch.square(ddf_stn(pre_dvf_I, dvf_I) + dvf_I), dim=1),dim=list(range(1,1+ndims))) / (torch.mean(torch.sum(torch.square(dvf_I), dim=1),dim=list(range(1,1+ndims))) + EPS))
232
+
233
+
234
class Grad(torch.nn.Module):
    """
    N-D gradient loss

    Collection of regularisers selected by name through ``penalty``:
    'l1', 'l2', 'negdetj', 'range', 'param', 'detj' and 'std'.
    Tensors are indexed as (batch, channel, *spatial); the Jacobian code reads
    channel k as the k-th component of the field (assumes a dense displacement
    field with ``ndims`` channels -- TODO confirm against the callers).
    """

    def __init__(self, penalty=['l1'],ndims=2, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=2, apear_scale=4, dist=1, sign=1,waive_thresh=10**-5):
        """
        penalty:         names of the terms to accumulate (never mutated here).
        ndims:           number of spatial dimensions.
        outrange_weight: weight of the 'range' term.
        outrange_thresh: threshold used by the 'range' term.
        detj_weight:     weight of the 'negdetj' term.
        apear_scale:     exponent scale of the image-driven weight map.
        dist:            finite-difference step; an int, or a list of ints
                         (one per expected-value channel) for 'param'.
        sign:            multiplier of the 'std' term.
        waive_thresh:    differences below this magnitude are not penalised.
        """
        super(Grad, self).__init__()
        self.penalty = penalty
        self.eps = eps
        self.outrange_weight = outrange_weight
        self.detj_weight=detj_weight
        self.apear_scale = apear_scale
        self.ndims=ndims
        self.max_sz = torch.reshape(torch.tensor([outrange_thresh]*ndims, dtype=torch.float32) , [1]+[ndims]+[1]*(ndims))
        self.act = torch.nn.ReLU(inplace=False)
        self.dist=dist
        self.sign=sign
        self.waive_thresh=waive_thresh

    @staticmethod
    def _axis_perms(d, rank):
        """Return (to_front, back): permutations moving axis d to position 0 and back."""
        to_front = [d, *range(d), *range(d + 1, rank)]
        back = [*range(1, d + 1), 0, *range(d + 1, rank)]
        return to_front, back

    def _diffs(self, y,dist=None):
        """Forward differences of y along each spatial axis (output is `dist` shorter)."""
        if dist is None:
            dist=self.dist
        df = [None] * self.ndims
        for i in range(self.ndims):
            to_front, back = self._axis_perms(i + 2, self.ndims + 2)
            yp = y.permute(to_front)
            dfi = (yp[dist:, ...] - yp[:-dist, ...])/float(dist)
            df[i] = dfi.permute(back)
        return df

    def _eq_diffs(self, y,dist=None):
        """Forward differences zero-padded at the front so each keeps y's shape."""
        if dist is None:
            dist=self.dist
        ndims = len(y.size()[2:])
        # F.pad lists the last axis first, so the final pair pads axis 0 of
        # the permuted tensor (the axis being differentiated).
        pad = [0, 0] * (ndims + 1) +[dist, 0]
        df = [None] * ndims
        for i in range(ndims):
            to_front, back = self._axis_perms(i + 2, ndims + 2)
            yt = y.permute(to_front)
            dy=(yt[dist:, ...] - yt[:-dist, ...])/float(dist)
            df[i] = (F.pad(dy, pad,mode='constant',value=0)).permute(back)
        return df

    def _weighted_diffs_error(self, y,dist=None,w=None,expect=None,mean_dim=None):
        """Per axis: mean over mean_dim of (|y[k+dist]-y[k]| - expect) * w[k+dist]*w[k]."""
        if dist is None:
            dist=self.dist
        ndims = len(y.size()[2:])
        df = [None] * ndims
        for i in range(ndims):
            to_front, back = self._axis_perms(i + 2, ndims + 2)
            yt = y.permute(to_front)
            wt = w.permute(to_front)
            dy=(torch.abs(yt[dist:, ...] - yt[:-dist, ...])-expect.permute(to_front))*(wt[dist:, ...]*wt[:-dist, ...])
            df[i] = torch.mean((dy).permute(back),dim=mean_dim,keepdim=True)
        return df

    def _outl_dist(self, y,range_thresh=0.2):
        """Squared hinge on component i at the first/last slice of axis i, averaged over axes."""
        self.device = y.device
        self.max_sz=self.max_sz.to(self.device)
        act=torch.nn.ReLU(inplace=True)
        loss=0.
        for i in range(self.ndims):
            to_front, _ = self._axis_perms(i + 2, self.ndims + 2)
            yt = y.permute(to_front)
            loss += torch.mean(torch.square(act(-range_thresh-yt[0,:,i, ...])))+torch.mean(torch.square(act(yt[-1,:,i, ...]-range_thresh)))
        return loss/self.ndims

    def _center_dist(self, y):
        """Squared hinge on how far |y| at the central location exceeds the threshold."""
        self.device = y.device
        self.max_sz=self.max_sz.to(self.device)
        # Index the centre of the first `ndims` spatial axes for any ndims
        # (the previous 2-D/3-D branches returned None otherwise).
        centre = [s // 2 for s in y.size()[2:][:self.ndims]]
        centre_vals = y[(slice(None), slice(None), *centre)]
        return torch.mean(torch.square(self.act(torch.abs(centre_vals) - self.max_sz)))

    def _eval_detJ(self, disp, add_identity=True, spacing=1.0):
        """
        disp: list length ndims
              disp[i] is derivative wrt spatial dim i (forward diff),
              tensor shape [B, C=ndims, ...]
        add_identity: True if y_pred is displacement u and phi=x+u
        spacing: voxel spacing (or 1.0). If you care about physical units,
                 divide derivatives by spacing (and dist). Sign won't change.
        Raises ValueError when ndims is neither 2 nor 3.
        """
        # Optional scaling (won't affect sign as long as spacing>0)
        if spacing != 1.0:
            disp = [d / spacing for d in disp]
        if self.ndims not in (2, 3):
            raise ValueError(f"Unsupported ndims={self.ndims}")

        n = self.ndims
        # jac[r][c] = d u_r / d x_c, plus the identity on the diagonal if requested.
        jac = [[disp[c][:, r, ...] for c in range(n)] for r in range(n)]
        if add_identity:
            for k in range(n):
                jac[k][k] = 1.0 + jac[k][k]

        if n == 2:
            return jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]

        (j11, j12, j13), (j21, j22, j23), (j31, j32, j33) = jac
        return (
            j11 * (j22 * j33 - j23 * j32)
            - j12 * (j21 * j33 - j23 * j31)
            + j13 * (j21 * j32 - j22 * j31)
        )

    def forward(self, y_pred=None,x_in=None, img=None, msk=None):
        """
        Accumulate the terms named in ``self.penalty``.

        y_pred: field to regularise.
        x_in:   expected statistics, a tensor or list of tensors; 'detj' reads
                x_in[0], 'param' reads channel k of x_in[-1] for the k-th dist.
        img:    image used to build the weight map; optional for 'l1'/'l2',
                required for 'param', 'detj' and 'std'.
        msk:    optional mask multiplied into the 'l1'/'l2' weight only.
        Raises ValueError when 'param'/'detj'/'std' is requested without img.
        """
        reg_loss = 0
        act=torch.nn.ReLU(inplace=True)

        if 'l1' in self.penalty or 'l2' in self.penalty:
            # The appearance weight is only consumed by these two terms, so it
            # is built lazily; this keeps a list-valued `dist` (used by
            # 'param') from reaching _eq_diffs' default step.
            dg = 1
            if img is not None:
                dg = torch.exp(-self.apear_scale * sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img)]) / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True))
            if msk is not None:
                dg = dg * msk

            if 'l1' in self.penalty:
                df = [torch.mean(dg*F.relu(torch.abs(f) - self.waive_thresh,inplace=True)) for f in self._eq_diffs(y_pred)]
                reg_loss += sum(df) / len(df)

            if 'l2' in self.penalty:
                df = [torch.mean(dg*F.relu(f * f - self.waive_thresh**2,inplace=True)) for f in self._eq_diffs(y_pred)]
                reg_loss += torch.sqrt(sum(df) / len(df))

        if 'negdetj' in self.penalty:
            # Hinge on negative Jacobian determinants.
            df = self.detj_weight*torch.mean(act(-self._eval_detJ(self._eq_diffs(y_pred,dist=1))))
            reg_loss += 0.5*df
        if 'range' in self.penalty:
            reg_loss += self.outrange_weight * (self._center_dist(y_pred))
        if 'param' in self.penalty or 'detj' in self.penalty or 'std' in self.penalty:
            if img is None:
                raise ValueError("'param', 'detj' and 'std' penalties require img")
            mean_dim=list(range(1, self.ndims + 2))
            # Weight map for these terms, normalised to unit mean per sample.
            dg = torch.sum(torch.abs(img),dim=1,keepdim=True)* torch.exp(-self.apear_scale * torch.nn.ReLU(inplace=True)(.1-sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img,dist=3)]) / torch.sum(torch.square(.1 + img), dim=1, keepdim=True)))
            dg = dg/(EPS+torch.mean(dg,dim=mean_dim,keepdim=True))

            y_pred = torch.clamp(y_pred, min=-0.8, max=0.8)
            x_in = x_in if isinstance(x_in,list) else [x_in]
            if 'std' in self.penalty:
                # Was a bare `ndims` (undefined in this scope) and grad_std's
                # 2-D default; both now follow self.ndims.
                spatial = list(range(2, self.ndims + 2))
                centred = y_pred - torch.mean(y_pred, dim=spatial, keepdim=True)
                reg_loss += self.sign*torch.mean(torch.clamp(grad_std(centred*dg, ndims=self.ndims), max=.2, min=0))
            if 'param' in self.penalty:
                # Accept a single int as well as a list of steps.
                dists = self.dist if isinstance(self.dist, (list, tuple)) else [self.dist]
                for k, d in enumerate(dists):
                    df = torch.mean(torch.abs(sum(self._weighted_diffs_error(y_pred, dist=d, w=dg, expect=torch.abs(x_in[-1][:, k:k + 1, ...]),mean_dim=mean_dim))))
                    reg_loss += 1 * (df) / len(dists)

            if 'detj' in self.penalty:
                df = torch.mean(torch.abs(
                    torch.mean((torch.abs(self._eval_detJ(self._eq_diffs(y_pred, dist=1))) - torch.abs(x_in[0])) * dg, dim=mean_dim)))
                reg_loss += 0.5*df

        return reg_loss
469
+
470
+
471
def avg_std_skew_kurt(array,ndims=2):
    """Return [mean, std, skewness, excess kurtosis] over the spatial axes.

    array: tensor whose axes 2 .. ndims+1 are reduced.
    ndims: number of axes to reduce, starting at axis 2.
    Every returned tensor has the reduced axes removed.  The standard
    deviation is the population (biased) one and is not guarded against zero.
    """
    dim = list(range(2, ndims + 2))
    # keepdim=True so the statistics broadcast against `array`; without it the
    # reduced tensors were aligned with the trailing spatial axes instead.
    mean = torch.mean(array, dim=dim, keepdim=True)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0), dim=dim, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=dim)
    kurtoses = torch.mean(torch.pow(zscores, 4.0), dim=dim) - 3.0
    # Drop the kept singleton axes (last first so indices stay valid).
    for axis in reversed(dim):
        mean = mean.squeeze(axis)
        std = std.squeeze(axis)
    return [mean, std, skews, kurtoses]
481
+
482
def grad_std(array,ndims=2):
    """Per-sample standard deviation of `array` after clamping to [-0.8, 0.8].

    The mean is removed over axes 2 .. ndims+1 (per channel); the squared
    deviations are then averaged over axes 1 .. ndims+1, so the result keeps
    only axis 0 (plus any axes beyond ndims+1).
    """
    spatial_axes = list(range(2, ndims + 2))
    reduce_axes = list(range(1, ndims + 2))
    clipped = torch.clamp(array, min=-0.8, max=0.8)
    centred = clipped - torch.mean(clipped, dim=spatial_axes, keepdim=True)
    return torch.sqrt(torch.mean(torch.square(centred), dim=reduce_axes))
488
+
489
def avg_std(array,ndims=2):
    """Return [mean over the spatial axes, grad_std(array)] for `array`.

    array: tensor whose axes 2 .. ndims+1 are treated as spatial.
    ndims: number of spatial axes; forwarded to grad_std.
    """
    dim = list(range(2, ndims + 2))
    # grad_std takes `ndims`, not `dim`; the old keyword raised TypeError.
    return [torch.mean(array,dim=dim),grad_std(array,ndims=ndims)]
492
+
493
+
494
if __name__ == "__main__":
    # Smoke test: LNCC between a random volume and a copy shifted by
    # `translation` voxels along the last axis (the unshifted border of the
    # second volume keeps its own random values).
    size = 128
    translation = 2
    smooth = True

    img3d = torch.empty(1, 1, size, size, size).uniform_(0, 1)
    img3d_t = torch.empty(1, 1, size, size, size).uniform_(0, 1)
    img3d_t[..., translation:] = img3d[..., :size - translation]

    loss_ncc = LNCC(smooth=smooth, central=True)
    loss_sim = loss_ncc(img3d, img3d_t)
    print(loss_sim)
Diffusion/losses_ncc0.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ losses for DRDM
3
+ """
4
+
5
+ import numpy as np
6
+ import sys
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+
11
+ EPS=1e-7
12
+
13
+ # eps_scale = 10e-5
14
+ # eps_scale = 10e-4
15
+ # eps_scale = 1e-4
16
+ eps_scale = 1e-5
17
+
18
+
19
+
20
class LMSE(torch.nn.Module):
    """
    Labeled Mean Square Error (LMSE)

    Squared error between two single-channel 3-D volumes, optionally divided
    by ``J**2 + relate_eps`` and optionally averaged only where ``label`` is
    set.
    """

    def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
        """
        eps:        added to the label sum to avoid division by zero.
        relate_eps: when not None the error is divided by ``J**2 + relate_eps``.
        win:        window size per axis for the optional pre-filter
                    (defaults to 5 on each of the 3 axes).
        smooth:     when True both inputs are first convolved with an all-ones
                    window (an unnormalised box sum).
        """
        super(LMSE, self).__init__()
        self.eps = eps
        self.relate_eps = relate_eps
        self.ndims = 3
        self.smooth = smooth
        self.win = win
        # Set window size
        if self.win is None:
            self.win = [5] * self.ndims
        if smooth:
            self.kernels = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """Return a (1, 1, *k) kernel: all ones of size `win` if std == 0, else a separable Gaussian."""
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def forward(self, I, J, label=None):
        """
        Computes the labeled mean squared error between I and J (ref).
        If label is provided, computes the MSE only over the labeled regions.
        """
        if self.smooth:
            # The kernel is a plain attribute, so it does not follow the
            # module across devices; move it next to the input (LNCC does the
            # same for its sum filter).
            self.kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            padding = [(k - 1) // 2 for k in self.kernels.shape[2:]]
            I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=padding)
            J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=padding)
        mse = (I - J) ** 2
        if self.relate_eps is not None:
            mse = mse/((J**2) + self.relate_eps)
        if label is None:
            return torch.mean(mse)
        label = label.float()
        mse_sum = torch.sum(mse * label, dim=(2, 3, 4))
        label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps
        return torch.mean(mse_sum / label_sum)
69
+
70
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC)

    Works on single-channel 3-D volumes indexed as (batch, 1, D, H, W): the
    filters have one input and one output channel.  ``forward`` returns the
    negated mean squared correlation so that it can be minimised.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-7, central=True, smooth=False):
        """
        win:     window size per axis (defaults to 11 on each of the 3 axes).
        num_ch:  accepted but not used.
        eps:     added to the denominators.
        central: subtract the local means (True) or correlate raw values.
        smooth:  pre-filter both inputs with a Gaussian of std 0.5.
        """
        super(LNCC, self).__init__()
        self.win = win
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth

        # Set window size
        if self.win is None:
            self.win = [11] * self.ndims

        if smooth:
            self.kernels = self._build_kernel(std=0.5)
        self.sum_filt = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """Return a (1, 1, *k) kernel: all ones of size `win` if std == 0, else a separable Gaussian."""
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def lncc(self, I, J, label=None):
        """Mean squared local correlation of I and J, optionally averaged over `label` only."""
        self.sum_filt = self.sum_filt.to(I.device)
        padding = [(w-1) // 2 for w in self.win]

        def window_sum(t):
            return torch.nn.functional.conv3d(t, self.sum_filt, stride=1, padding=padding)

        if self.smooth:
            # Move the Gaussian next to the input and pad for *its* size: the
            # window padding made the smoothed volume larger than the input,
            # so it no longer matched `label`.
            self.kernels = self.kernels.to(I.device)
            smooth_pad = [(k - 1) // 2 for k in self.kernels.shape[2:]]
            I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=smooth_pad)
            J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=smooth_pad)

        # Local sums of the products.
        I2_sum = window_sum(I * I)
        J2_sum = window_sum(J * J)
        IJ_sum = window_sum(I * J)

        if self.central:
            # Remove the local means: sum(xy) - sum(x) * sum(y) / N.
            I_sum = window_sum(I)
            J_sum = window_sum(J)
            win_size = np.prod(self.win)
            cross = IJ_sum - (I_sum * J_sum) / win_size
            I_var = I2_sum - (I_sum * I_sum) / win_size
            J_var = J2_sum - (J_sum * J_sum) / win_size
        else:
            cross = IJ_sum
            I_var = I2_sum
            J_var = J2_sum

        cc = (cross * cross) / (I_var * J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return ``-lncc(I, J, label)``."""
        return -self.lncc(I, J, label=label)
147
+
148
+
149
+
150
class NCC(torch.nn.Module):
    """Normalised cross-correlation taken over the channel axis (dim=1).

    Both inputs are multiplied by ``self.scale`` before the correlation is
    formed; ``eps_scale`` is added under the square root of the denominator.
    ``img_sz`` is accepted but not used.
    """

    def __init__(self, eps_scale=10e-5,img_sz=256):
        super(NCC, self).__init__()
        self.eps_scale = eps_scale
        self.scale = 1e2

    def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
        """Return the (optionally masked) mean correlation of pred and inv_lab.

        pred:    predicted field.
        inv_lab: reference field; required despite the ``None`` default.
        ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                 replaces ``pred``.
        mask:    optional weight multiplied into the per-location correlation;
                 the sum is divided by ``mask.sum()`` and by the batch size.
        """
        if ddf_stn is None:
            moved = pred
        else:
            moved = -ddf_stn(pred, inv_lab)
        moved = self.scale * moved
        inv_lab = self.scale * inv_lab

        numerator = torch.sum(moved * inv_lab, dim=1)
        moved_energy = torch.sum(torch.square(moved), dim=1)
        ref_energy = torch.sum(torch.square(inv_lab), dim=1)
        denominator = torch.sqrt(moved_energy * ref_energy + self.eps_scale)

        if mask is not None:
            batch_size = inv_lab.shape[0]
            return torch.sum(numerator * mask / denominator) / torch.sum(mask) / batch_size
        return torch.mean(numerator / denominator)
171
+
172
class MRSE(torch.nn.Module):
    """Relative squared error of ``trm_pred + inv_lab``, normalised per location.

    The numerator is the channel-summed square of the *sum* of the two scaled
    fields; the denominator is the channel-summed square of the scaled
    reference plus ``eps_scale``.  ``img_sz`` is accepted but not used.
    NOTE(review): the sibling RMSE loss uses a difference where this one uses
    a sum -- confirm the sign convention with the callers.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256):
        super(MRSE, self).__init__()
        self.eps_scale = eps_scale
        self.scale = 10e1

    def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
        """Return the (optionally masked) mean relative squared error.

        pred:    predicted field.
        inv_lab: reference field; required despite the ``None`` default.
        ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                 replaces ``pred``.
        mask:    optional weight; the weighted sum is divided by
                 ``mask.sum()`` and by the batch size.
        """
        if ddf_stn is None:
            moved = pred
        else:
            moved = -ddf_stn(pred, inv_lab)
        moved = self.scale * moved
        inv_lab = self.scale * inv_lab

        sq_err = torch.sum(torch.square(moved + inv_lab), dim=1)
        ref_energy = torch.sum(torch.square(inv_lab), dim=1) + self.eps_scale

        if mask is not None:
            batch_size = inv_lab.shape[0]
            return torch.sum(sq_err * mask / ref_energy) / torch.sum(mask) / batch_size
        return torch.mean(sq_err / ref_energy)
197
+
198
class RMSE(torch.nn.Module):
    """Per-sample relative mean squared error between two fields.

    For every sample the channel-summed squared difference is averaged over
    ``ndims`` axes and divided by the same average of the squared reference
    (plus ``eps_scale``); the ratios are then averaged.  No square root is
    taken, despite the class name.  ``img_sz`` is accepted but not used.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256,ndims=2):
        super(RMSE, self).__init__()
        self.eps_scale = eps_scale
        self.ndims = ndims

    def forward(self,pred,inv_lab=None,ddf_stn=None):
        """Return the relative error of pred (or ``-ddf_stn(pred, inv_lab)``) vs inv_lab."""
        if ddf_stn is None:
            moved = pred
        else:
            moved = -ddf_stn(pred, inv_lab)
        # After the channel sum the remaining axes are (batch, *spatial).
        axes = list(range(1, 1 + self.ndims))
        sq_err = torch.mean(torch.sum(torch.square(moved - inv_lab), dim=1), dim=axes)
        ref_energy = torch.mean(torch.sum(torch.square(inv_lab), dim=1), dim=axes)
        return torch.mean(sq_err / (ref_energy + self.eps_scale))
213
+ # loss_gen = torch.mean(torch.mean(torch.sum(torch.square(ddf_stn(pre_dvf_I, dvf_I) + dvf_I), dim=1),dim=list(range(1,1+ndims))) / (torch.mean(torch.sum(torch.square(dvf_I), dim=1),dim=list(range(1,1+ndims))) + EPS))
214
+
215
class Grad(torch.nn.Module):
    """
    N-D gradient loss

    Collection of regularisers selected by name through ``penalty``:
    'l1', 'l2', 'negdetj', 'range', 'param', 'detj' and 'std'.
    Tensors are indexed as (batch, channel, *spatial); the Jacobian code reads
    channel k as the k-th component of the field (assumes a dense displacement
    field with ``ndims`` channels -- TODO confirm against the callers).
    """

    def __init__(self, penalty=['l1'],ndims=3, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=2, apear_scale=4, dist=1, sign=1,waive_thresh=10**-5):
        """
        penalty:         names of the terms to accumulate (never mutated here).
        ndims:           number of spatial dimensions.
        outrange_weight: weight of the 'range' term.
        outrange_thresh: threshold used by the 'range' term.
        detj_weight:     weight of the 'negdetj' term.
        apear_scale:     exponent scale of the image-driven weight map.
        dist:            finite-difference step; an int, or a list of ints
                         (one per expected-value channel) for 'param'.
        sign:            multiplier of the 'std' term.
        waive_thresh:    differences below this magnitude are not penalised.
        """
        super(Grad, self).__init__()
        self.penalty = penalty
        self.eps = eps
        self.outrange_weight = outrange_weight
        self.detj_weight=detj_weight
        self.apear_scale = apear_scale
        self.ndims=ndims
        self.max_sz = torch.reshape(torch.tensor([outrange_thresh]*ndims, dtype=torch.float32) , [1]+[ndims]+[1]*(ndims))
        self.act = torch.nn.ReLU(inplace=False)
        self.dist=dist
        self.sign=sign
        self.waive_thresh=waive_thresh

    @staticmethod
    def _axis_perms(d, rank):
        """Return (to_front, back): permutations moving axis d to position 0 and back."""
        to_front = [d, *range(d), *range(d + 1, rank)]
        back = [*range(1, d + 1), 0, *range(d + 1, rank)]
        return to_front, back

    def _diffs(self, y,dist=None):
        """Forward differences of y along each spatial axis (output is `dist` shorter)."""
        if dist is None:
            dist=self.dist
        df = [None] * self.ndims
        for i in range(self.ndims):
            to_front, back = self._axis_perms(i + 2, self.ndims + 2)
            yp = y.permute(to_front)
            dfi = (yp[dist:, ...] - yp[:-dist, ...])/float(dist)
            df[i] = dfi.permute(back)
        return df

    def _eq_diffs(self, y,dist=None):
        """Forward differences zero-padded at the front so each keeps y's shape."""
        if dist is None:
            dist=self.dist
        ndims = len(y.size()[2:])
        # F.pad lists the last axis first, so the final pair pads axis 0 of
        # the permuted tensor (the axis being differentiated).
        pad = [0, 0] * (ndims + 1) +[dist, 0]
        df = [None] * ndims
        for i in range(ndims):
            to_front, back = self._axis_perms(i + 2, ndims + 2)
            yt = y.permute(to_front)
            dy=(yt[dist:, ...] - yt[:-dist, ...])/float(dist)
            df[i] = (F.pad(dy, pad,mode='constant',value=0)).permute(back)
        return df

    def _weighted_diffs_error(self, y,dist=None,w=None,expect=None,mean_dim=None):
        """Per axis: mean over mean_dim of (|y[k+dist]-y[k]| - expect) * w[k+dist]*w[k]."""
        if dist is None:
            dist=self.dist
        ndims = len(y.size()[2:])
        df = [None] * ndims
        for i in range(ndims):
            to_front, back = self._axis_perms(i + 2, ndims + 2)
            yt = y.permute(to_front)
            wt = w.permute(to_front)
            dy=(torch.abs(yt[dist:, ...] - yt[:-dist, ...])-expect.permute(to_front))*(wt[dist:, ...]*wt[:-dist, ...])
            df[i] = torch.mean((dy).permute(back),dim=mean_dim,keepdim=True)
        return df

    def _outl_dist(self, y,range_thresh=0.2):
        """Squared hinge on component i at the first/last slice of axis i, averaged over axes."""
        self.device = y.device
        self.max_sz=self.max_sz.to(self.device)
        act=torch.nn.ReLU(inplace=True)
        loss=0.
        for i in range(self.ndims):
            to_front, _ = self._axis_perms(i + 2, self.ndims + 2)
            yt = y.permute(to_front)
            loss += torch.mean(torch.square(act(-range_thresh-yt[0,:,i, ...])))+torch.mean(torch.square(act(yt[-1,:,i, ...]-range_thresh)))
        return loss/self.ndims

    def _center_dist(self, y):
        """Squared hinge on how far |y| at the central location exceeds the threshold."""
        self.device = y.device
        self.max_sz=self.max_sz.to(self.device)
        # Index the centre of the first `ndims` spatial axes for any ndims
        # (the previous 2-D/3-D branches returned None otherwise).
        centre = [s // 2 for s in y.size()[2:][:self.ndims]]
        centre_vals = y[(slice(None), slice(None), *centre)]
        return torch.mean(torch.square(self.act(torch.abs(centre_vals) - self.max_sz)))

    def _eval_detJ(self, disp, add_identity=True, spacing=1.0):
        """
        disp: list length ndims
              disp[i] is derivative wrt spatial dim i (forward diff),
              tensor shape [B, C=ndims, ...]
        add_identity: True if y_pred is displacement u and phi=x+u
        spacing: voxel spacing (or 1.0). If you care about physical units,
                 divide derivatives by spacing (and dist). Sign won't change.
        Raises ValueError when ndims is neither 2 nor 3.
        """
        # Optional scaling (won't affect sign as long as spacing>0)
        if spacing != 1.0:
            disp = [d / spacing for d in disp]
        if self.ndims not in (2, 3):
            raise ValueError(f"Unsupported ndims={self.ndims}")

        n = self.ndims
        # jac[r][c] = d u_r / d x_c, plus the identity on the diagonal if requested.
        jac = [[disp[c][:, r, ...] for c in range(n)] for r in range(n)]
        if add_identity:
            for k in range(n):
                jac[k][k] = 1.0 + jac[k][k]

        if n == 2:
            return jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]

        (j11, j12, j13), (j21, j22, j23), (j31, j32, j33) = jac
        return (
            j11 * (j22 * j33 - j23 * j32)
            - j12 * (j21 * j33 - j23 * j31)
            + j13 * (j21 * j32 - j22 * j31)
        )

    def forward(self, y_pred=None,x_in=None, img=None, msk=None):
        """
        Accumulate the terms named in ``self.penalty``.

        y_pred: field to regularise.
        x_in:   expected statistics, a tensor or list of tensors; 'detj' reads
                x_in[0], 'param' reads channel k of x_in[-1] for the k-th dist.
        img:    image used to build the weight map; optional for 'l1'/'l2',
                required for 'param', 'detj' and 'std'.
        msk:    optional mask multiplied into the 'l1'/'l2' weight only.
        Raises ValueError when 'param'/'detj'/'std' is requested without img.
        """
        reg_loss = 0
        act=torch.nn.ReLU(inplace=True)

        if 'l1' in self.penalty or 'l2' in self.penalty:
            # The appearance weight is only consumed by these two terms, so it
            # is built lazily; this keeps a list-valued `dist` (used by
            # 'param') from reaching _eq_diffs' default step.
            dg = 1
            if img is not None:
                dg = torch.exp(-self.apear_scale * sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img)]) / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True))
            if msk is not None:
                dg = dg * msk

            if 'l1' in self.penalty:
                df = [torch.mean(dg*F.relu(torch.abs(f) - self.waive_thresh,inplace=True)) for f in self._eq_diffs(y_pred)]
                reg_loss += sum(df) / len(df)

            if 'l2' in self.penalty:
                df = [torch.mean(dg*F.relu(f * f - self.waive_thresh**2,inplace=True)) for f in self._eq_diffs(y_pred)]
                reg_loss += torch.sqrt(sum(df) / len(df))

        if 'negdetj' in self.penalty:
            # Hinge on negative Jacobian determinants.
            df = self.detj_weight*torch.mean(act(-self._eval_detJ(self._eq_diffs(y_pred,dist=1))))
            reg_loss += 0.5*df
        if 'range' in self.penalty:
            reg_loss += self.outrange_weight * (self._center_dist(y_pred))
        if 'param' in self.penalty or 'detj' in self.penalty or 'std' in self.penalty:
            if img is None:
                raise ValueError("'param', 'detj' and 'std' penalties require img")
            mean_dim=list(range(1, self.ndims + 2))
            # Weight map for these terms, normalised to unit mean per sample.
            dg = torch.sum(torch.abs(img),dim=1,keepdim=True)* torch.exp(-self.apear_scale * torch.nn.ReLU(inplace=True)(.1-sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img,dist=3)]) / torch.sum(torch.square(.1 + img), dim=1, keepdim=True)))
            dg = dg/(EPS+torch.mean(dg,dim=mean_dim,keepdim=True))

            y_pred = torch.clamp(y_pred, min=-0.8, max=0.8)
            x_in = x_in if isinstance(x_in,list) else [x_in]
            if 'std' in self.penalty:
                # Was a bare `ndims` (only defined by the __main__ demo) and
                # grad_std's 2-D default; both now follow self.ndims.
                spatial = list(range(2, self.ndims + 2))
                centred = y_pred - torch.mean(y_pred, dim=spatial, keepdim=True)
                reg_loss += self.sign*torch.mean(torch.clamp(grad_std(centred*dg, ndims=self.ndims), max=.2, min=0))
            if 'param' in self.penalty:
                # Accept a single int as well as a list of steps.
                dists = self.dist if isinstance(self.dist, (list, tuple)) else [self.dist]
                for k, d in enumerate(dists):
                    df = torch.mean(torch.abs(sum(self._weighted_diffs_error(y_pred, dist=d, w=dg, expect=torch.abs(x_in[-1][:, k:k + 1, ...]),mean_dim=mean_dim))))
                    reg_loss += 1 * (df) / len(dists)

            if 'detj' in self.penalty:
                df = torch.mean(torch.abs(
                    torch.mean((torch.abs(self._eval_detJ(self._eq_diffs(y_pred, dist=1))) - torch.abs(x_in[0])) * dg, dim=mean_dim)))
                reg_loss += 0.5*df

        return reg_loss
450
+
451
+
452
def avg_std_skew_kurt(array,ndims=2):
    """Return [mean, std, skewness, excess kurtosis] over the spatial axes.

    array: tensor whose axes 2 .. ndims+1 are reduced.
    ndims: number of axes to reduce, starting at axis 2.
    Every returned tensor has the reduced axes removed.  The standard
    deviation is the population (biased) one and is not guarded against zero.
    """
    dim = list(range(2, ndims + 2))
    # keepdim=True so the statistics broadcast against `array`; without it the
    # reduced tensors were aligned with the trailing spatial axes instead.
    mean = torch.mean(array, dim=dim, keepdim=True)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0), dim=dim, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=dim)
    kurtoses = torch.mean(torch.pow(zscores, 4.0), dim=dim) - 3.0
    # Drop the kept singleton axes (last first so indices stay valid).
    for axis in reversed(dim):
        mean = mean.squeeze(axis)
        std = std.squeeze(axis)
    return [mean, std, skews, kurtoses]
462
+
463
def grad_std(array,ndims=2):
    """Per-sample standard deviation of `array` after clamping to [-0.8, 0.8].

    The mean is removed over axes 2 .. ndims+1 (per channel); the squared
    deviations are then averaged over axes 1 .. ndims+1, so the result keeps
    only axis 0 (plus any axes beyond ndims+1).
    """
    clipped = torch.clamp(array, min=-0.8, max=0.8)
    per_channel_mean = torch.mean(clipped, dim=list(range(2, ndims + 2)), keepdim=True)
    sq_dev = torch.square(clipped - per_channel_mean)
    return torch.sqrt(torch.mean(sq_dev, dim=list(range(1, ndims + 2))))
469
+
470
def avg_std(array,ndims=2):
    """Return [mean over the spatial axes, grad_std(array)] for `array`.

    array: tensor whose axes 2 .. ndims+1 are treated as spatial.
    ndims: number of spatial axes; forwarded to grad_std.
    """
    dim = list(range(2, ndims + 2))
    # grad_std takes `ndims`, not `dim`; the old keyword raised TypeError.
    return [torch.mean(array,dim=dim),grad_std(array,ndims=ndims)]
473
+
474
+
475
+ if __name__ == "__main__":
476
+ ndims=2
477
+ dist=[16,32]
478
+ ddf = torch.rand(1,2,128,128)
479
+ # ddf[:,:,0,:]=ddf[:,:,0,:]-1
480
+ # ddf[:,:,1,:]=ddf[:,:,1,:]+1
481
+ # ddf[:,:,0,0]=ddf[:,:,0,0] -1
482
+ # ddf[:,:,1,1]=ddf[:,:,1,1] +1
483
+ # ddf[:,0,0,1]=ddf[:,0,0,1] +1
484
+ # ddf[:,1,0,1]=ddf[:,1,0,1] -1
485
+ # ddf[:,0,0,1]=ddf[:,0,0,1] -1
486
+ # ddf[:,1,0,1]=ddf[:,1,0,1] +1
487
+ # ddf[:,1,1,0]=ddf[:,1,1,0] -1
488
+ # ddf[:,0,1,0]=ddf[:,0,1,0] +1
489
+ ddf=ddf
490
+ img = torch.rand(1,1,128,128)
491
+ x_in=np.reshape([0.2,0.3],newshape=[1,ndims]+[1]*ndims)
492
+ x_in=[torch.tensor(x_in).type(torch.float32),0.]
493
+
494
+ Loss_detj = Grad(penalty=['detj'],ndims=ndims,dist=dist)
495
+ loss_detj = Loss_detj(ddf,x_in,img)
496
+ print(loss_detj)
Diffusion/networks.py ADDED
@@ -0,0 +1,1167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import math
6
+
7
def get_net(name="recresnet"):
    """Look up a network class by case-insensitive name.

    Returns the class object (not an instance), or ``None`` when the name is
    not recognised. Note that the default name ``"recresnet"`` is itself not
    one of the recognised names, so calling without arguments yields ``None``.
    """
    key = name.lower()
    # Guard-clause returns keep class names lazily resolved: a class is only
    # looked up when its name is actually requested.
    if key == "recresacnet":
        return RecResACNet
    if key == "recmutattnnet":
        return RecMutAttnNet
    if key == "recmutattnnet0":
        return RecMutAttnNet0
    if key == "recmutattnnet1":
        return RecMutAttnNet1
    if key == "defrecmutattnnet":
        return DefRec_MutAttnNet
    if key == "recmutattnnet_contrastive":
        return RecMutAttnNet_contrastive
    return None
24
+
25
+
26
+
27
def sinusoidal_embedding(n, d):
    """Build an ``(n, d)`` table of sinusoidal position/time-step embeddings.

    Row ``t`` holds ``sin(t * w)`` in the even columns and ``cos(t * w)`` in
    the odd columns, where ``w`` are the even-indexed entries of
    ``1 / 10000 ** (2 * j / d)`` for ``j in range(d)``.

    Args:
        n: number of positions (rows).
        d: embedding width (columns); may be even or odd.

    Returns:
        A float tensor of shape ``(n, d)``.
    """
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
    wk = wk.reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    angles = t * wk[:, ::2]
    embedding[:, ::2] = torch.sin(angles)
    # For odd d there is one fewer odd column than even columns, so the
    # cosine block is trimmed to fit; for even d the slice is a no-op.
    embedding[:, 1::2] = torch.cos(angles[:, : d // 2])
    return embedding
36
+
37
class AtrousBlock(nn.Module):
    """Residual block made of a chain of (optionally dilated) convolutions.

    An optional projection conv maps ``in_c`` to ``out_c`` channels and an
    optional affine instance norm follows. The result is then passed through
    one "activation -> conv" stage per entry of ``atrous_rates`` and added
    back to the stage input before a final activation.

    Args:
        shape: accepted for call-site compatibility; not used by the block.
        in_c: number of input channels.
        out_c: number of output channels.
        kernel_size: kernel size of the projection and dilated convolutions.
        stride: accepted but not used; every convolution uses stride 1.
        atrous_rates: one dilation rate per stage; a rate ``<= 0`` selects a
            1x1 convolution instead of a dilated one.
        ndims: spatial dimensionality, selects ``ConvNd`` / ``InstanceNormNd``.
        activation: activation module; ``LeakyReLU(1e-6)`` when ``None``.
        normalize: whether the instance norm is created and applied.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=[1,3], ndims=2, activation=None, normalize=True):
        super().__init__()
        conv_cls = getattr(nn, 'Conv%dd' % ndims)
        base_pad = (kernel_size - 1) // 2

        # Sub-modules are registered in the order ln, conv0, convs, activation
        # so parameter order and state-dict keys stay the same.
        if normalize:
            norm_cls = getattr(nn, 'InstanceNorm%dd' % ndims)
            self.ln = norm_cls(out_c, affine=True)
        else:
            self.ln = nn.Identity()

        # Channel projection is only needed when the width changes.
        self.conv0 = conv_cls(in_c, out_c, kernel_size, 1, base_pad * 1) if in_c != out_c else None

        stages = []
        for rate in atrous_rates:
            if rate > 0:
                # padding grows with the dilation so the spatial size is kept
                stages.append(conv_cls(out_c, out_c, kernel_size, 1, base_pad * rate, dilation=rate))
            else:
                stages.append(conv_cls(out_c, out_c, 1, 1, 0))
        self.convs = nn.ModuleList(stages)

        if activation is None:
            self.activation = nn.LeakyReLU(1e-6)
        else:
            self.activation = activation
        self.normalize = normalize

    def forward(self, x):
        """Apply projection, normalisation and the residual conv chain to ``x``."""
        skip = x if self.conv0 is None else self.conv0(x)
        if self.normalize:
            skip = self.ln(skip)
        out = skip
        for stage in self.convs:
            out = stage(self.activation(out))
        return self.activation(out + skip)
74
+
75
+ # ==============================================
76
+ # Unconditional Network
77
+ # ==============================================
78
+
79
class RecResACNet(nn.Module):
    """Unconditional, time-conditioned U-Net that predicts a displacement field.

    The encoder/decoder is built from ``AtrousBlock`` stacks with three
    stride-2 down-samplings and three transposed-conv up-samplings. The
    network is applied ``rec_num`` times: each pass predicts an incremental
    displacement field (DDF), which is composed with the running DDF and used
    to re-warp the original input for the next pass.

    Args:
        n_steps: number of rows in the frozen sinusoidal time-embedding table.
        time_emb_dim: width of the time embedding.
        ndims: spatial dimensionality (selects ``ConvNd`` / ``ConvTransposeNd``)
            and the number of output displacement channels.
        num_input_chn: number of input image channels.
        res: nominal input resolution; only forwarded as the ``shape``
            argument of ``AtrousBlock``, which does not use it.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0):
        super(RecResACNet, self).__init__()

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Sinusoidal embedding: table is overwritten with fixed values and frozen
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half (encoder). Each te* maps the time embedding to the channel
        # count of the tensor it is combined with in forward().
        # te1 has a single output channel and is broadcast over the input channels.
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = nn.Sequential(
            AtrousBlock([num_input_chn] + [res] * ndims, num_input_chn, 10, ndims=ndims),
            AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims),
            AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims),

        )
        # kernel 4 / stride 2 / padding 1: halves an even spatial size
        self.down1 = self.Conv(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = nn.Sequential(
            AtrousBlock([10] + [res // 2] * ndims, 10, 20, ndims=ndims),
            AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims),
            AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims)
        )
        self.down2 = self.Conv(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = nn.Sequential(
            AtrousBlock([20] + [res // 4] * ndims, 20, 40, ndims=ndims),
            AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims),
            AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims)
        )
        self.down3 = self.Conv(40, 40, 4, 2, 1)

        # Bottleneck (40 -> 20 -> 20 -> 40 channels)
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = nn.Sequential(
            AtrousBlock([40] + [res // 8] * ndims, 40, 20, ndims=ndims),
            AtrousBlock([20] + [res // 8] * ndims, 20, 20, ndims=ndims),
            AtrousBlock([20] + [res // 8] * ndims, 20, 40, ndims=ndims)
        )

        # Second half (decoder). Input widths are doubled by the skip
        # concatenation; decoder blocks are built without normalisation.
        self.up1 = self.ConvT(40, 40, 4, 2, 1)

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = nn.Sequential(
            AtrousBlock([80] + [res // 4] * ndims, 80, 40, ndims=ndims, normalize=False),
            AtrousBlock([40] + [res // 4] * ndims, 40, 20, ndims=ndims, normalize=False),
            AtrousBlock([20] + [res // 4] * ndims, 20, 20, ndims=ndims, normalize=False)
        )

        self.up2 = self.ConvT(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = nn.Sequential(
            AtrousBlock([40] + [res // 2] * ndims, 40, 20, ndims=ndims, normalize=False),
            AtrousBlock([20] + [res // 2] * ndims, 20, 10, ndims=ndims, normalize=False),
            AtrousBlock([10] + [res // 2] * ndims, 10, 10, ndims=ndims, normalize=False)
        )

        self.up3 = self.ConvT(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = nn.Sequential(
            AtrousBlock([20] + [res // 1] * ndims, 20, 10, ndims=ndims, normalize=False),
            AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False),
            AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False)
        )

        # One output channel per spatial axis: the displacement field.
        self.conv_out = self.Conv(10, ndims, 3, 1, 1)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to a range derived from ``max_sz``.

        Channel ``k`` is scaled by ``max_sz[k]``, clamped to
        ``[minus - max_sz[k] + plus, max_sz[k] - minus + plus]`` and scaled
        back, so with the defaults every value ends up within
        ``+/- (sz - 1) / sz``. Channels beyond ``len(max_sz)`` are dropped by
        the ``zip``.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        ``ddf`` is multiplied per channel by ``self.max_sz`` and added to the
        index grid ``ref``; the sum is moved to channel-last layout, divided
        by ``img_sz`` and shifted by -1 to give the normalised coordinates
        used with ``align_corners=True``. The coordinate axis is flipped
        because ``grid_sample`` expects the last spatial axis first.

        Relies on ``self.device``, ``self.max_sz`` and ``self.ref_grid``,
        which are set by ``forward``.
        """
        ref = self.ref_grid if ref is None else ref
        # NOTE(review): the fallback is the plain list self.max_sz; every call
        # in this class passes img_sz=self.img_sz (a tensor) explicitly.
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'  # if self.dimension==2 else 'trilinear'

        if True:
            return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
                np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
                [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
                                 align_corners=True)

    def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2):
        """Predict a displacement field for ``x`` at time step ``t``.

        Args:
            x: input image batch; axis 0 is the batch and axes 2+ are spatial.
            t: time-step indices fed to the embedding table.
            y: unused by this network.
            rec_num: number of recursive refinement passes.
            ndims: unused; the dimensionality given at construction is used.

        Returns:
            The composed displacement field after ``rec_num`` passes, with
            ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): the first spatial size is reused for every axis, which
        # presumably assumes square / cubic inputs — confirm.
        self.max_sz = [img_sz[0]] * self.dimension
        # shape used to broadcast a per-sample time vector over spatial axes
        ts_emb_shape=[n,-1]+[1]*self.dimension
        # (size - 1) / 2 per axis, laid out channel-last for grid normalisation
        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
        # integer index grid of shape (1, dimension, *img_sz)
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
                                      [1, self.dimension]+list(img_sz)).to(self.device)
        img = x

        t = self.time_embed(t)

        for rec_id in range(rec_num):
            # Encoder: channel widths 10 -> 20 -> 40, spatial size halved per level
            out1 = self.b1(img + self.te1(t).reshape(ts_emb_shape))  # 10 channels, full resolution
            out2 = self.b2(self.down1(out1) + self.te2(t).reshape(ts_emb_shape))  # 20 channels, 1/2 resolution
            out3 = self.b3(self.down2(out2) + self.te3(t).reshape(ts_emb_shape))  # 40 channels, 1/4 resolution

            # NOTE(review): the time embedding is multiplied in here, whereas
            # every other level adds it — confirm this is intentional.
            out_mid = self.b_mid(self.down3(out3) * self.te_mid(t).reshape(ts_emb_shape))  # 40 channels, 1/8 resolution

            # Decoder: up-sample, concatenate the matching encoder output, refine
            out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # 80 channels
            out4 = self.b4(out4 + self.te4(t).reshape(ts_emb_shape))  # 20 channels

            out5 = torch.cat((out2, self.up2(out4)), dim=1)  # 40 channels
            out5 = self.b5(out5 + self.te5(t).reshape(ts_emb_shape))  # 10 channels

            out = torch.cat((out1, self.up3(out5)), dim=1)  # 20 channels
            out = self.b_out(out + self.te_out(t).reshape(ts_emb_shape))  # 10 channels

            out = self.conv_out(out)

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # compose: warp the running field by the new increment, then add it
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # the next pass sees the original input warped by the field so far
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> ReLU -> Linear projection for a time embedding."""

        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            # nn.SiLU(),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
229
+
230
+ # ==============================================
231
+ # Conditional Network
232
+ # ==============================================
233
+
234
class cross_attn(nn.Module):
    """Single-head cross-attention built from caller-supplied projections.

    ``forward(x, y)`` computes ``softmax(q(x) @ k(y)^T) @ v(y)``, with the
    softmax taken over the last axis. No ``1 / sqrt(d)`` scaling is applied.

    Args:
        q: callable projecting the query input ``x``.
        k: callable projecting the key input ``y``.
        v: callable projecting the value input ``y``.
        ndims: spatial dimensionality used to pick ``ConvNd`` /
            ``ConvTransposeNd``; these classes are stored but not used here.
    """

    def __init__(self, q, k, v, ndims=2):
        # nn.Module must be initialised before any sub-module or parameter is
        # assigned; without this call construction raised AttributeError.
        super().__init__()
        self.q = q
        self.k = k
        self.v = v
        self.ndims = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.ndims)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims)
        self.softmax = nn.Softmax(dim=-1)
        # Learnable scalar kept for checkpoint compatibility; forward() does
        # not currently use it.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, y):
        """Attend from ``x`` (queries) to ``y`` (keys and values)."""
        q = self.q(x)
        k = self.k(y)
        v = self.v(y)
        attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
        out = torch.matmul(attn, v)
        return out
252
+
253
class DefRec_MutAttnNet(nn.Module):
    """Recursive, time-conditioned U-Net predicting a displacement field.

    When ``conditional_input`` is true, a second image ``y`` is encoded once
    by a parallel encoder, fused into the main encoder at every level,
    attended to at the bottleneck with multi-head attention, and an optional
    text feature is mixed in at the bottleneck. One forward pass is run for
    every entry of ``t``, the whole list being repeated ``rec_num`` times, and
    the incremental displacement fields are composed.

    Args:
        n_steps: number of rows in the frozen sinusoidal time-embedding table.
        time_emb_dim: width of the time embedding.
        ndims: spatial dimensionality and number of displacement channels.
        num_input_chn: number of input image channels.
        res: nominal input resolution; used for the initial reference grid
            and forwarded as the (unused) ``shape`` argument of ``AtrousBlock``.
        conditional_input: whether the conditioning branch is built and used.
        text_feat_chn: channel width of the text feature.
        num_heads: number of heads of the bottleneck attention.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
        super(DefRec_MutAttnNet, self).__init__()

        # channel width per level; index 0 is the input, index -1 the bottleneck
        self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        self.copy = nn.Identity()
        # Sinusoidal embedding: table is overwritten with fixed values and frozen
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()  # time projections for the encoder
        self.teu_layers = nn.ModuleList()  # time projections for the decoder
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        # NOTE(review): the source indentation was lost; everything that is
        # only used by the conditional path is assumed to live in this branch.
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            # self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            # created here but not referenced by forward() of this class
            self.global_maxpool = Global_Maxpool(1)
            # 1x1 convs mapping bottleneck features to/from the text feature width
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # all-zero text feature used when forward() receives text=None
            self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
        self.img_res = [res]*self.dimension
        # index grid for the nominal resolution; rebuilt in forward() on mismatch
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
                                      [1, self.dimension]+list(self.img_res))

        for i in range(1, self.hier_num + 1):
            # i walks the encoder from the shallowest level; j = -i walks the
            # decoder from the deepest level
            j=-i
            self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
            # decoder inputs are doubled by the skip concatenation
            self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
            self.block_down.append(nn.Sequential(
                AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
            ))
            if self.conditional_input:
                self.block_down_cond.append(nn.Sequential(
                    AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
                ))
                # 1x1 conv merging main and conditional features back to one width
                self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
            # The last decoder block keeps its own width (it feeds conv_out);
            # every other block narrows to the next shallower level's width.
            if i==self.hier_num:
                k=j
            else:
                k=j-1
            self.block_up.append(nn.Sequential(
                AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
            ))

        # Bottleneck
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = nn.Sequential(
            AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
        )

        # One output channel per spatial axis: the displacement field.
        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to a range derived from ``max_sz``.

        Channel ``k`` is scaled by ``max_sz[k]``, clamped to
        ``[minus - max_sz[k] + plus, max_sz[k] - minus + plus]`` and scaled
        back. Channels beyond ``len(max_sz)`` are dropped by the ``zip``.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        ``ddf`` is multiplied per channel by ``self.max_sz`` and added to the
        index grid ``ref``; the sum is moved to channel-last layout, divided
        by ``img_sz`` and shifted by -1 to give the normalised coordinates
        used with ``align_corners=True``. The coordinate axis is flipped
        because ``grid_sample`` expects the last spatial axis first.

        Relies on ``self.device`` and ``self.max_sz``, which ``forward`` sets.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'  # if self.dimension==2 else 'trilinear'

        return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
            np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
            [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
                             align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field that deforms ``x``.

        Args:
            x: input image batch; axis 0 is the batch and axes 2+ are spatial.
            y: conditioning image batch, used only when ``conditional_input``.
            t: iterable of time-step index tensors; one network pass is run
                per entry, and the whole sequence is repeated ``rec_num`` times.
            text: optional text feature added at the bottleneck; the stored
                all-zero feature is used when ``None``.
            rec_num: number of times the sequence ``t`` is repeated.
            ndims: unused; the dimensionality given at construction is used.

        Returns:
            The composed displacement field with ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): the first spatial size is reused for every axis, which
        # presumably assumes square / cubic inputs — confirm.
        self.max_sz = [img_sz[0]] * self.dimension
        # shape used to broadcast a per-sample time vector over spatial axes
        ts_emb_shape=[n,-1]+[1]*self.dimension

        # (size - 1) / 2 per axis, laid out channel-last for grid normalisation
        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
        if list(img_sz) != self.img_res:
            # Rebuild the reference grid when the input size differs from `res`.
            self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
                                          [1, self.dimension]+list(img_sz))
        self.ref_grid = self.ref_grid.to(self.device)

        img = x
        # NOTE(review): nesting reconstructed from a flattened source; the
        # conditional encoding below is assumed to sit inside this branch.
        if self.conditional_input:
            tgt = y
            # encode the conditional input once; it does not depend on the time step
            tgt_down_list = []
            for i in range(self.hier_num):
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt)
                    tgt_down_list.append(self.copy(tgt))
                    # the down-sampling convs are shared with the main encoder
                    tgt = self.down_layers[i](tgt)
            tgt_mid = self.copy(tgt)
            tgt_shape = tgt_mid.shape
            tgt_mid = tgt_mid.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)  # (N, C, H, W) -> (H*W, N, C)

        t = [t0.to(self.device) for t0 in t]
        # repeat the whole time-step sequence rec_num times
        t = [t0 for _ in range(rec_num) for t0 in t]
        for rec_id,time in enumerate(t):
            t_emb = self.time_embed(time)

            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
                if self.conditional_input:
                    # fuse the pre-computed conditional features of this level
                    out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)

            out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape))
            if self.conditional_input:
                # cross-attention: bottleneck features query the conditional features
                out_shape = out.shape
                out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt_mid, tgt_mid)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)  # (H*W, N, C) -> (N, C, H, W)
                out = out + out_attn

            if self.conditional_input:
                if text is None:
                    text = self.text
                text = text.to(self.device)
                # project to text width, add the text feature, project back (residual)
                out_txt = self.img2txt(out) + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                # up-sample, then concatenate the encoder output of the matching level
                out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape))

            # fixed down-scaling of the raw prediction
            out = self.conv_out(out)/128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # compose: warp the running field by the new increment, then add it
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # the next pass sees the original input warped by the field so far
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> ReLU -> Linear projection for a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
442
+
443
class RecMutAttnNet1(nn.Module):
    """Recursive, time-conditioned U-Net with mutual image/condition fusion.

    When ``conditional_input`` is true, the conditioning image ``y`` is
    re-encoded on every recursion by a parallel encoder. At each level the
    main and conditional features are fused in both directions with 1x1
    convolutions, and at the bottleneck the main features attend to the
    conditional features. The network is applied ``rec_num`` times and the
    incremental displacement fields are composed.

    Args:
        n_steps: number of rows in the frozen sinusoidal time-embedding table.
        time_emb_dim: width of the time embedding.
        ndims: spatial dimensionality and number of displacement channels.
        num_input_chn: number of input image channels.
        res: nominal input resolution; only forwarded as the (unused)
            ``shape`` argument of ``AtrousBlock``.
        conditional_input: whether the conditioning branch is built and used.
        text_feat_chn: stored on the instance; not otherwise used by this class.
        num_heads: number of heads of the bottleneck attention.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet1, self).__init__()

        # channel width per level; index 0 is the input, index -1 the bottleneck
        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Sinusoidal embedding: table is overwritten with fixed values and frozen
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()  # time projections for the encoder
        self.teu_layers = nn.ModuleList()  # time projections for the decoder
        self.block_down = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()  # condition -> main fusion
            self.fuse_conv1 = nn.ModuleList()  # main -> condition fusion
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)

        self.block_up = nn.ModuleList()

        for i in range(1, self.hier_num + 1):
            # i walks the encoder from the shallowest level; j = -i walks the
            # decoder from the deepest level
            j=-i
            self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
            # decoder inputs are doubled by the skip concatenation
            self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
            self.block_down.append(nn.Sequential(
                AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
            ))
            if self.conditional_input:
                self.block_down_cond.append(nn.Sequential(
                    AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
                ))
                self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
            # The last decoder block keeps its own width (it feeds conv_out);
            # every other block narrows to the next shallower level's width.
            if i==self.hier_num:
                k=j
            else:
                k=j-1
            self.block_up.append(nn.Sequential(
                AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
            ))

        # Bottleneck
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = nn.Sequential(
            AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
        )

        # One output channel per spatial axis: the displacement field.
        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to a range derived from ``max_sz``.

        Channel ``k`` is scaled by ``max_sz[k]``, clamped to
        ``[minus - max_sz[k] + plus, max_sz[k] - minus + plus]`` and scaled
        back. Channels beyond ``len(max_sz)`` are dropped by the ``zip``.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        ``ddf`` is multiplied per channel by ``self.max_sz`` and added to the
        index grid ``ref``; the sum is moved to channel-last layout, divided
        by ``img_sz`` and shifted by -1 to give the normalised coordinates
        used with ``align_corners=True``. The coordinate axis is flipped
        because ``grid_sample`` expects the last spatial axis first.

        Relies on ``self.device``, ``self.max_sz`` and ``self.ref_grid``,
        which are set by ``forward``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'  # if self.dimension==2 else 'trilinear'

        return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
            np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
            [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
                             align_corners=True)

    def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
        """Predict a displacement field that deforms ``x``.

        Args:
            x: input image batch; axis 0 is the batch and axes 2+ are spatial.
            y: conditioning image batch, used only when ``conditional_input``.
            t: time-step indices fed to the embedding table.
            rec_num: number of recursive refinement passes.
            ndims: unused; the dimensionality given at construction is used.

        Returns:
            The composed displacement field with ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): the first spatial size is reused for every axis, which
        # presumably assumes square / cubic inputs — confirm.
        self.max_sz = [img_sz[0]] * self.dimension
        # shape used to broadcast a per-sample time vector over spatial axes
        ts_emb_shape=[n,-1]+[1]*self.dimension
        # (size - 1) / 2 per axis, laid out channel-last for grid normalisation
        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
        # integer index grid of shape (1, dimension, *img_sz)
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
                                      [1, self.dimension]+list(img_sz)).to(self.device)
        img = x
        t = self.time_embed(t)

        for rec_id in range(rec_num):
            if self.conditional_input:
                # the conditional branch restarts from y on every recursion
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt)
                    # mutual fusion: first condition -> main, then the fused
                    # main features -> condition
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    # the down-sampling convs are shared by both branches
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # cross-attention: bottleneck features query the conditional features
                out_shape = out.shape
                tgt_shape = tgt.shape
                tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)  # (N, C, H, W) -> (H*W, N, C)
                out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)  # (H*W, N, C) -> (N, C, H, W)
                out = out + out_attn

            for i in range(self.hier_num):
                # up-sample, then concatenate the encoder output of the matching level
                out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            # fixed down-scaling of the raw prediction
            out = self.conv_out(out)/128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # compose: warp the running field by the new increment, then add it
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # the next pass sees the original input warped by the field so far
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> ReLU -> Linear projection for a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
589
+
590
+ class RecMutAttnNet(nn.Module):
591
+ def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
592
+ super(RecMutAttnNet, self).__init__()
593
+
594
+ # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
595
+ self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
596
+ self.conditional_input = conditional_input
597
+ self.num_heads = num_heads
598
+ self.text_feat_chn = text_feat_chn
599
+
600
+ self.dimension = ndims
601
+ self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
602
+ self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
603
+
604
+ # Sinusoidal embedding
605
+ self.time_embed = nn.Embedding(n_steps, time_emb_dim)
606
+ self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
607
+ self.time_embed.requires_grad_(False)
608
+ self.hier_num = len(self.feat_channels) - 1
609
+ self.down_layers = nn.ModuleList()
610
+ self.up_layers = nn.ModuleList()
611
+ self.ted_layers = nn.ModuleList()
612
+ self.teu_layers = nn.ModuleList()
613
+ self.block_down = nn.ModuleList()
614
+ self.block_up = nn.ModuleList()
615
+ if self.conditional_input:
616
+ self.block_down_cond = nn.ModuleList()
617
+ self.fuse_conv0 = nn.ModuleList()
618
+ self.fuse_conv1 = nn.ModuleList()
619
+ self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
620
+ Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
621
+ self.global_maxpool = Global_Maxpool(1)
622
+ self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
623
+ self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
624
+ self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
625
+ self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
626
+ self.img_res = [res]*self.dimension
627
+ self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
628
+ [1, self.dimension]+list(self.img_res))
629
+
630
+ for i in range(1, self.hier_num + 1):
631
+ j=-i
632
+ self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
633
+ self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
634
+ self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
635
+ self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
636
+ self.block_down.append(nn.Sequential(
637
+ AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
638
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
639
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
640
+ ))
641
+ if self.conditional_input:
642
+ self.block_down_cond.append(nn.Sequential(
643
+ AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
644
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
645
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
646
+ ))
647
+ self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
648
+ self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
649
+ if i==self.hier_num:
650
+ k=j
651
+ else:
652
+ k=j-1
653
+ self.block_up.append(nn.Sequential(
654
+ AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
655
+ AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
656
+ AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
657
+ ))
658
+
659
+ # Bottleneck
660
+ self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
661
+ self.b_mid = nn.Sequential(
662
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
663
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
664
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
665
+ )
666
+
667
+ self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
668
+
669
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
    """Clamp a size-normalised field channel by channel.

    Each channel along dim 1 is scaled by its entry of ``max_sz``, clamped to
    ``[minus - sz + plus, sz - minus + plus]`` and scaled back.

    Args:
        sample_coords0: tensor whose dim 1 indexes the channels to clamp.
        max_sz: iterable with one size per channel.
        plus: offset added to both clamp bounds.
        minus: margin kept away from ``+/- sz``.

    Returns:
        Tensor with the clamped channels concatenated along dim 1.
    """
    limited = []
    for channel, sz in zip(torch.split(sample_coords0, 1, dim=1), max_sz):
        lower = minus - 1 * sz + plus
        upper = 1 * sz - minus + plus
        limited.append(torch.clamp(channel * sz, min=lower, max=upper) / sz)
    return torch.cat(limited, 1)
673
+
674
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
    """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

    The displacement is multiplied by ``self.max_sz``, added to the reference
    grid, moved to channels-last layout, divided by ``img_sz`` and shifted by
    -1, and finally has its last axis reversed before sampling.

    Args:
        vol: tensor to be sampled.
        ddf: displacement field, channels on dim 1.
        ref: reference grid; ``self.ref_grid`` when omitted.
        img_sz: divisor applied to the channels-last coordinates. Callers in
            this class pass ``self.img_sz``.
        padding_mode: forwarded to ``F.grid_sample``.

    Returns:
        The resampled tensor.
    """
    if ref is None:
        ref = self.ref_grid
    if img_sz is None:
        # NOTE(review): self.max_sz is a plain Python list here, so this
        # fallback cannot divide a tensor; pass img_sz explicitly.
        img_sz = self.max_sz
    nd = self.dimension

    scale = torch.Tensor(
        np.reshape(np.array(self.max_sz), [1, nd] + [1] * nd)).to(self.device)
    coords = ddf * scale + ref

    channels_last = [0] + list(range(2, 2 + nd)) + [1]
    grid = coords.permute(channels_last) / img_sz - 1
    grid = torch.flip(grid, dims=[-1])

    # 'bilinear' is used for every dimensionality, as in the original code.
    return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode,
                         align_corners=True)
683
+
684
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
    """Predict a displacement field by running the network ``rec_num`` times.

    Every pass encodes the current (warped) image, optionally fuses it with
    features of ``y`` and with a text feature, decodes a displacement, composes
    it with the displacement accumulated so far and re-warps ``x``.

    Args:
        x: image to warp; batch on dim 0, spatial dims from dim 2 onwards.
        y: conditioning image, read only when ``self.conditional_input`` is set.
        t: timestep indices looked up in ``self.time_embed``.
        text: optional text feature, reshaped to
            ``(-1, text_feat_chn, 1, ..., 1)``; ``self.text`` when omitted.
        rec_num: number of recursive passes.
        ndims: not used by this method; kept for interface compatibility.

    Returns:
        The composed displacement field after the last pass.
    """
    self.device = x.device
    spatial = x.size()[2:]
    batch = x.size()[0]
    # The first spatial extent is used as the displacement scale on every axis.
    self.max_sz = [spatial[0]] * self.dimension
    emb_shape = [batch, -1] + [1] * self.dimension

    half_extent = [(s - 1) / 2 for s in spatial]
    self.img_sz = torch.reshape(torch.tensor(half_extent, device=self.device),
                                [1] * (self.dimension + 1) + [self.dimension])
    if list(spatial) != self.img_res:
        # Rebuild the reference grid when the input size differs from the
        # resolution the model was constructed with.
        axes = [torch.arange(end=s) for s in spatial]
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid(axes), 0),
                                      [1, self.dimension] + list(spatial))
    self.ref_grid = self.ref_grid.to(self.device)

    moving = x
    t = self.time_embed(t)

    for rec_id in range(rec_num):
        if self.conditional_input:
            tgt = y
        skips = []
        out = moving

        # Encoder; the conditioning branch shares the down-sampling convs.
        for level in range(self.hier_num):
            time_bias = self.ted_layers[level](t).reshape(emb_shape)
            out = self.block_down[level](out + time_bias)
            if self.conditional_input:
                tgt = self.block_down_cond[level](tgt)
                out = self.fuse_conv0[level](torch.cat([out, tgt], axis=1))
                tgt = self.fuse_conv1[level](torch.cat([tgt, out], axis=1))
            skips.append(out)
            out = self.down_layers[level](out)
            if self.conditional_input:
                tgt = self.down_layers[level](tgt)

        out = self.b_mid(out + self.tmid(t).reshape(emb_shape))

        if self.conditional_input:
            # Cross-attention: flatten the spatial dims into a sequence,
            # (N, C, *spatial) -> (prod(spatial), N, C).
            feat_shape = out.shape
            query = out.view(feat_shape[0], feat_shape[1], -1).permute(2, 0, 1)
            memory = tgt.view(tgt.shape[0], tgt.shape[1], -1).permute(2, 0, 1)
            attended, _ = self.attn_layer(query, memory, memory)
            out = out + attended.permute(1, 2, 0).contiguous().view(feat_shape)

            # Text conditioning in the text feature space, added back residually.
            if text is None:
                text = self.text
            text = text.to(self.device)
            text = text.view(-1, self.text_feat_chn, *([1] * self.dimension))
            txt_feat = self.txt_proc(self.img2txt(out) + text)
            out = out + self.txt2img(txt_feat)

        # Decoder with skip connections taken in reverse order.
        for level in range(self.hier_num):
            out = torch.cat((self.up_layers[level](out), skips[-level - 1]), dim=1)
            time_bias = self.teu_layers[level](t).reshape(emb_shape)
            out = self.block_up[level](out + time_bias)

        out = self.conv_out(out) / 128

        ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
        if rec_id == 0:
            ddf = ddf_one
        else:
            # Compose: warp the accumulated field by the new one, then add.
            ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz,
                                          padding_mode="border")
        moving = self.resample(x, ddf=ddf, img_sz=self.img_sz)

    return ddf
754
+
755
+ def _make_te(self, dim_in, dim_out):
756
+ return nn.Sequential(
757
+ nn.Linear(dim_in, dim_out),
758
+ nn.ReLU(),
759
+ nn.Linear(dim_out, dim_out)
760
+ )
761
+
762
+ class RecMutAttnNet_contrastive(nn.Module):
763
+ def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
764
+ super(RecMutAttnNet_contrastive, self).__init__()
765
+
766
+ # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
767
+ self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
768
+ self.conditional_input = conditional_input
769
+ self.num_heads = num_heads
770
+ self.text_feat_chn = text_feat_chn
771
+
772
+ self.dimension = ndims
773
+ self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
774
+ self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
775
+
776
+ # Sinusoidal embedding
777
+ self.time_embed = nn.Embedding(n_steps, time_emb_dim)
778
+ self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
779
+ self.time_embed.requires_grad_(False)
780
+ self.hier_num = len(self.feat_channels) - 1
781
+ self.down_layers = nn.ModuleList()
782
+ self.up_layers = nn.ModuleList()
783
+ self.ted_layers = nn.ModuleList()
784
+ self.teu_layers = nn.ModuleList()
785
+ self.block_down = nn.ModuleList()
786
+ self.block_up = nn.ModuleList()
787
+ if self.conditional_input:
788
+ self.block_down_cond = nn.ModuleList()
789
+ self.fuse_conv0 = nn.ModuleList()
790
+ self.fuse_conv1 = nn.ModuleList()
791
+ self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
792
+ Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
793
+ self.global_maxpool = Global_Maxpool(1)
794
+ self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
795
+ self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
796
+ self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
797
+ self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
798
+ self.img_res = [res]*self.dimension
799
+ self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
800
+ [1, self.dimension]+list(self.img_res))
801
+
802
+ for i in range(1, self.hier_num + 1):
803
+ j=-i
804
+ self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
805
+ self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
806
+ self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
807
+ self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
808
+ self.block_down.append(nn.Sequential(
809
+ AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
810
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
811
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
812
+ ))
813
+ if self.conditional_input:
814
+ self.block_down_cond.append(nn.Sequential(
815
+ AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
816
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
817
+ AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
818
+ ))
819
+ self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
820
+ self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
821
+ if i==self.hier_num:
822
+ k=j
823
+ else:
824
+ k=j-1
825
+ self.block_up.append(nn.Sequential(
826
+ AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
827
+ AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
828
+ AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
829
+ ))
830
+
831
+ # Bottleneck
832
+ self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
833
+ self.b_mid = nn.Sequential(
834
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
835
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
836
+ AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
837
+ )
838
+
839
+ self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
840
+
841
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
    """Limit each channel of a size-normalised field to a symmetric range.

    A channel is multiplied by its size ``sz`` from ``max_sz``, clamped to
    ``[minus - sz + plus, sz - minus + plus]`` and divided by ``sz`` again.

    Args:
        sample_coords0: tensor with the channels to limit on dim 1.
        max_sz: iterable holding one size per channel.
        plus: shift applied to both bounds.
        minus: margin subtracted from the size.

    Returns:
        The limited channels concatenated along dim 1.
    """
    pieces = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
    clamped = []
    for piece, extent in zip(pieces, max_sz):
        low = minus - 1 * extent + plus
        high = 1 * extent - minus + plus
        clamped.append(torch.clamp(piece * extent, min=low, max=high) / extent)
    return torch.cat(clamped, 1)
845
+
846
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
    """Sample ``vol`` at the locations described by ``ddf``.

    Coordinates are built as ``ddf * self.max_sz + ref``, rearranged to
    channels-last, mapped with ``/ img_sz - 1``, reversed along the last axis
    and handed to ``F.grid_sample`` (``align_corners=True``).

    Args:
        vol: tensor to be sampled.
        ddf: displacement field with channels on dim 1.
        ref: reference grid, defaulting to ``self.ref_grid``.
        img_sz: divisor for the channels-last coordinates; this class always
            passes ``self.img_sz``.
        padding_mode: forwarded to ``F.grid_sample``.

    Returns:
        The resampled tensor.
    """
    grid_ref = self.ref_grid if ref is None else ref
    # NOTE(review): the fallback below is a Python list and cannot divide a
    # tensor, so callers are expected to supply img_sz.
    divisor = self.max_sz if img_sz is None else img_sz

    dims = self.dimension
    field_scale = torch.Tensor(
        np.reshape(np.array(self.max_sz), [1, dims] + [1] * dims)).to(self.device)
    order = [0] + list(range(2, 2 + dims)) + [1]

    absolute = (ddf * field_scale + grid_ref).permute(order)
    sampling_grid = torch.flip(absolute / divisor - 1, dims=[-1])

    # The mode is fixed to 'bilinear' regardless of dimensionality.
    return F.grid_sample(vol, sampling_grid, mode='bilinear',
                         padding_mode=padding_mode, align_corners=True)
855
+
856
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
    """Predict a displacement field and an image embedding.

    The network is applied ``rec_num`` times. Each pass encodes the current
    (warped) image, optionally fuses it with features of ``y`` and a text
    feature, decodes a displacement, composes it with the displacement
    accumulated so far and re-warps ``x``.

    Args:
        x: image to warp; batch on dim 0, spatial dims from dim 2 onwards.
        y: conditioning image, read only when ``self.conditional_input`` is set.
        t: timestep indices looked up in ``self.time_embed``.
        text: optional text feature, reshaped to
            ``(-1, text_feat_chn, 1, ..., 1)``; ``self.text`` when omitted.
        rec_num: number of recursive passes.
        ndims: not used by this method; kept for interface compatibility.

    Returns:
        ``(ddf, img_embd)``. ``img_embd`` is the globally max-pooled output of
        ``self.img2txt`` from the last pass, reshaped to ``(batch, -1)``. It is
        ``None`` when ``self.conditional_input`` is False; previously that
        configuration failed at the return statement with an unbound variable.
    """
    self.device = x.device
    img_sz = x.size()[2:]
    n = x.size()[0]
    # The first spatial extent is used as the displacement scale on every axis.
    self.max_sz = [img_sz[0]] * self.dimension
    ts_emb_shape = [n, -1] + [1] * self.dimension

    self.img_sz = torch.reshape(
        torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
        [1] * (self.dimension + 1) + [self.dimension])
    if list(img_sz) != self.img_res:
        # Rebuild the reference grid when the input size differs from the
        # resolution the model was constructed with.
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
            [1, self.dimension] + list(img_sz))
    self.ref_grid = self.ref_grid.to(self.device)

    img = x
    t = self.time_embed(t)
    # Defined up front so the return value exists without conditioning.
    img_embd = None

    for rec_id in range(rec_num):
        if self.conditional_input:
            tgt = y
        enc_list = []
        out = img

        # Encoder; the conditioning branch shares the down-sampling convs.
        for i in range(self.hier_num):
            out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
            if self.conditional_input:
                tgt = self.block_down_cond[i](tgt)
                out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
            enc_list.append(out)
            out = self.down_layers[i](out)
            if self.conditional_input:
                tgt = self.down_layers[i](tgt)

        out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))

        if self.conditional_input:
            # Cross-attention: flatten the spatial dims into a sequence,
            # (N, C, *spatial) -> (prod(spatial), N, C).
            out_shape = out.shape
            tgt_shape = tgt.shape
            query = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
            memory = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
            out_attn, _ = self.attn_layer(query, memory, memory)
            out = out + out_attn.permute(1, 2, 0).contiguous().view(out_shape)

            if text is None:
                text = self.text
            text = text.to(self.device)
            text = text.view(-1, self.text_feat_chn, *([1] * self.dimension))

            # Project once and reuse for both the embedding and the text
            # fusion (the projection used to be evaluated twice).
            img_txt = self.img2txt(out)
            img_embd = self.global_maxpool(img_txt).view(n, -1)
            out_txt = self.txt_proc(img_txt + text)
            out_txt = self.txt2img(out_txt)
            out = out + out_txt

        # Decoder with skip connections taken in reverse order.
        for i in range(self.hier_num):
            out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
            out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

        out = self.conv_out(out) / 128

        ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
        if rec_id == 0:
            ddf = ddf_one
        else:
            # Compose: warp the accumulated field by the new one, then add.
            ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz,
                                          padding_mode="border")
        img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

    return ddf, img_embd
927
+
928
+ def _make_te(self, dim_in, dim_out):
929
+ return nn.Sequential(
930
+ nn.Linear(dim_in, dim_out),
931
+ nn.ReLU(),
932
+ nn.Linear(dim_out, dim_out)
933
+ )
934
+ # class RecMutAttnNet(nn.Module):
935
+ # def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
936
+ # super(RecMutAttnNet, self).__init__()
937
+
938
+ # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
939
+ # self.conditional_input = conditional_input
940
+
941
+ # self.dimension = ndims
942
+ # self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
943
+ # self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
944
+
945
+ # # Sinusoidal embedding
946
+ # self.time_embed = nn.Embedding(n_steps, time_emb_dim)
947
+ # self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
948
+ # self.time_embed.requires_grad_(False)
949
+ # self.hier_num = len(self.feat_channels) - 1
950
+ # self.down_layers = nn.ModuleList()
951
+ # self.up_layers = nn.ModuleList()
952
+ # self.ted_layers = nn.ModuleList()
953
+ # self.teu_layers = nn.ModuleList()
954
+ # self.block_down = nn.ModuleList()
955
+ # if self.conditional_input:
956
+ # self.block_down_cond = nn.ModuleList()
957
+ # self.fuse_conv0 = nn.ModuleList()
958
+ # self.fuse_conv1 = nn.ModuleList()
959
+ # self.block_up = nn.ModuleList()
960
+
961
+ # for i in range(1, self.hier_num + 1):
962
+ # j=-i
963
+ # self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
964
+ # self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
965
+ # self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
966
+ # self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
967
+ # self.block_down.append(nn.Sequential(
968
+ # AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
969
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
970
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
971
+ # ))
972
+ # if self.conditional_input:
973
+ # self.block_down_cond.append(nn.Sequential(
974
+ # AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
975
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
976
+ # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
977
+ # ))
978
+ # self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
979
+ # self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
980
+ # if i==self.hier_num:
981
+ # k=j
982
+ # else:
983
+ # k=j-1
984
+ # self.block_up.append(nn.Sequential(
985
+ # AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
986
+ # AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
987
+ # AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
988
+ # ))
989
+
990
+ # # Bottleneck
991
+ # self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
992
+ # self.b_mid = nn.Sequential(
993
+ # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
994
+ # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
995
+ # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
996
+ # )
997
+
998
+ # self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
999
+
1000
+ # def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
1001
+ # sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
1002
+ # return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
1003
+ # zip(sample_coords, max_sz)], 1)
1004
+
1005
+ # def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
1006
+ # ref = self.ref_grid if ref is None else ref
1007
+ # img_sz = self.max_sz if img_sz is None else img_sz
1008
+ # resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
1009
+
1010
+ # return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
1011
+ # np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
1012
+ # [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
1013
+ # align_corners=True)
1014
+
1015
+ # def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
1016
+ # self.device = x.device
1017
+ # img_sz = x.size()[2:]
1018
+ # n = x.size()[0]
1019
+ # self.max_sz = [img_sz[0]] * self.dimension
1020
+ # ts_emb_shape=[n,-1]+[1]*self.dimension
1021
+ # self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
1022
+ # self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
1023
+ # [1, self.dimension]+list(img_sz)).to(self.device)
1024
+ # img = x
1025
+ # t = self.time_embed(t)
1026
+
1027
+ # for rec_id in range(rec_num):
1028
+ # if self.conditional_input:
1029
+ # tgt = y
1030
+ # enc_list = []
1031
+ # out = img
1032
+ # for i in range(self.hier_num):
1033
+ # out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
1034
+ # if self.conditional_input:
1035
+ # tgt = self.block_down_cond[i](tgt)
1036
+ # out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
1037
+ # tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
1038
+ # enc_list.append(out)
1039
+ # out = self.down_layers[i](out)
1040
+ # if self.conditional_input:
1041
+ # tgt = self.down_layers[i](tgt)
1042
+
1043
+ # out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
1044
+ # if self.conditional_input:
1045
+ # out = out + tgt
1046
+
1047
+ # for i in range(self.hier_num):
1048
+ # out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
1049
+ # out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
1050
+
1051
+ # out = self.conv_out(out)/128
1052
+
1053
+ # ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
1054
+ # if rec_id == 0:
1055
+ # ddf = ddf_one
1056
+ # else:
1057
+ # ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
1058
+ # img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
1059
+
1060
+ # return ddf
1061
+
1062
+ # def _make_te(self, dim_in, dim_out):
1063
+ # return nn.Sequential(
1064
+ # nn.Linear(dim_in, dim_out),
1065
+ # nn.ReLU(),
1066
+ # nn.Linear(dim_out, dim_out)
1067
+ # )
1068
+ # ==============================================
1069
+ # Layers
1070
+ # ==============================================
1071
+
1072
+
1073
def ddf_multiplier(dvf, mul_num=10, stn=None):
    """Repeatedly compose a displacement field with itself.

    Starting from ``ddf = dvf``, applies ``ddf = dvf + stn(ddf, dvf)``
    ``mul_num`` times.

    Args:
        dvf: displacement field with channels on dim 1 and spatial dims from
            dim 2 onwards.
        mul_num: number of composition steps.
        stn: callable ``stn(field, displacement)`` that warps ``field``. When
            omitted, an ``STN`` with border padding is built from the shape of
            ``dvf`` (previously ``None`` was called directly and raised
            ``TypeError``).

    Returns:
        The composed displacement field.
    """
    if stn is None:
        # NOTE(review): the size is taken from the first spatial axis and
        # applied to every axis - assumes an isotropic grid; confirm.
        stn = STN(ndims=dvf.dim() - 2, img_sz=dvf.shape[2], device=dvf.device,
                  padding_mode="border")
    ddf = dvf
    for _ in range(mul_num):
        ddf = dvf + stn(ddf, dvf)
    return ddf
1078
+
1079
+
1080
def composite(ddfs, stn=None):
    """Compose a sequence of displacement fields into one.

    Starting from ``ddfs[0]``, each further field ``d`` updates the result as
    ``d + stn(result, d)``.

    Args:
        ddfs: non-empty sequence of displacement fields (channels on dim 1,
            spatial dims from dim 2 onwards).
        stn: callable ``stn(field, displacement)`` that warps ``field``. When
            omitted, an ``STN`` with border padding is built from the first
            field's shape and device. The old default passed no image size,
            which failed inside ``STN.__init__``, and always assumed 2-D.

    Returns:
        The composed displacement field.
    """
    first = ddfs[0]
    if stn is None:
        # NOTE(review): the size is taken from the first spatial axis and
        # applied to every axis - assumes an isotropic grid; confirm.
        stn = STN(ndims=first.dim() - 2, img_sz=first.shape[2],
                  device=first.device, padding_mode="border")
    comp_ddf = first
    for ddf in ddfs[1:]:
        comp_ddf = ddf + stn(comp_ddf, ddf)
    return comp_ddf
1087
+
1088
class STN(nn.Module):
    """Spatial transformer that warps a tensor with a dense displacement field.

    Displacements are multiplied by ``max_sz`` and added to an integer pixel
    reference grid; the result is normalised with ``(size - 1) / 2`` and
    sampled by ``F.grid_sample`` with ``align_corners=True``.

    Args:
        ndims: number of spatial dimensions.
        img_sz: spatial size. A scalar is repeated for every axis, a sequence
            gives one size per axis, and ``None`` defers the setup to the
            first ``forward``/``resample`` call, which reads the size from the
            input tensor.
        max_sz: accepted for backward compatibility only; as before, the
            displacement scale is always derived from ``img_sz``.
        device: device for the scale tensor and reference grid. When ``None``
            it is taken from the first input passed to ``forward``.
        padding_mode: ``grid_sample`` padding mode used by ``forward``.
        resample_mode: ``grid_sample`` interpolation mode; ``'bilinear'`` when
            ``None``.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        super(STN, self).__init__()
        self.ndims = ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        # Filled in by _configure(), either here or lazily on first use.
        self.img_sz = None
        self.max_sz = None
        self.ref_grid = None
        if img_sz is not None:
            self._configure(img_sz)

    def _configure(self, img_sz):
        """Store the spatial size and build the scale tensor and reference grid."""
        if np.ndim(img_sz) == 0:
            sizes = [img_sz] * self.ndims
        else:
            sizes = [int(s) for s in img_sz]
            if len(sizes) != self.ndims:
                raise ValueError(
                    "img_sz has %d entries but ndims is %d" % (len(sizes), self.ndims))
        max_sz = torch.Tensor(
            np.reshape(np.array(sizes), [1, self.ndims] + [1] * self.ndims))
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid([torch.arange(end=s) for s in sizes]), 0),
            [1, self.ndims] + sizes)
        if self.device is not None:
            max_sz = max_sz.to(self.device)
            ref_grid = ref_grid.to(self.device)
        self.img_sz = sizes
        self.max_sz = max_sz
        self.ref_grid = ref_grid

    def _axis_sizes(self):
        """Return the per-axis displacement scale as a list of floats."""
        if self.max_sz is None:
            raise RuntimeError(
                "STN size is not set; pass img_sz or call forward() first")
        return self.max_sz.flatten().tolist()

    def max_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp each displacement channel to ``[minus - sz + plus, sz - minus + plus]``.

        Channels (dim 1) are scaled by their axis size, clamped and scaled
        back. The sizes are now paired with the channels one by one; zipping
        the scale tensor directly iterated its leading length-1 axis, so only
        channel 0 was processed.
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat(
            [torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
             for x, sz in zip(channels, self._axis_sizes())], 1)

    def boundary_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp the absolute sampling position of each displacement channel.

        For every channel the position ``x * sz + ref`` is clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and converted back to a
        size-normalised displacement.
        """
        # NOTE(review): the lower bound is negative although positions start
        # at 0; kept exactly as written - confirm the intended range.
        sizes = self._axis_sizes()
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        refs = torch.split(self.ref_grid.to(sample_coords0.device), 1, dim=1)
        return torch.cat(
            [(torch.clamp(x * sz + ref, min=minus - 1 * sz + plus,
                          max=1 * sz - minus + plus) - ref) / sz
             for x, sz, ref in zip(channels, sizes, refs)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with ``ddf`` using ``F.grid_sample``.

        Args:
            vol: tensor to sample; moved to the device of ``ddf``.
            ddf: displacement field with channels on dim 1.
            ref: reference grid; the module's own grid when omitted.
            img_sz: per-axis sizes used to normalise the coordinates; the
                module's stored size when omitted.
            padding_mode: forwarded to ``F.grid_sample``.

        Returns:
            The resampled tensor.
        """
        device = ddf.device
        if self.img_sz is None:
            self._configure(list(vol.size()[2:]))

        ref = self.ref_grid if ref is None else ref
        sizes = self.img_sz if img_sz is None else img_sz
        # Half extents in channels-last layout map pixel indices to [-1, 1].
        half = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in sizes], device=device),
            [1] + [1] * self.ndims + [self.ndims])

        if self.resample_mode is None:
            resample_mode = 'bilinear'
        else:
            resample_mode = self.resample_mode

        coords = ddf * self.max_sz.to(device) + ref.to(device)
        channels_last = [0] + list(range(2, 2 + self.ndims)) + [1]
        grid = torch.flip(coords.permute(channels_last) / half - 1, dims=[-1])
        return F.grid_sample(vol.to(device), grid, mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` with ``ddf`` using the module's padding mode."""
        self.device = x.device if self.device is None else self.device
        if self.img_sz is None:
            self._configure(list(x.size()[2:]))
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)
1150
+
1151
if __name__ == '__main__':
    # Smoke test: build the conditional 3-D model, run one forward pass on
    # random inputs and report the output shape and parameter counts.
    ndims = 3
    res = 128
    x = torch.rand([1, 1] + [res]*ndims)
    t = torch.randint(0, 1000, (1,))
    text = torch.rand([1, 1024] + [1]*ndims)
    model = RecMutAttnNet(n_steps=1000, time_emb_dim=100, ndims=ndims, num_input_chn=1, res=res, conditional_input=True)
    # No gradients are needed for a shape check; skipping autograd avoids
    # keeping the activations of a 128^3 volume alive.
    with torch.no_grad():
        y = model(x, x, t, text=text)
    print("Output shape", y.shape)

    # Total parameters
    total_params = sum(p.numel() for p in model.parameters())
    # Trainable parameters only
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
Diffusion/utils_diff.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch import nn, optim
5
+ from torch.autograd.variable import Variable
6
+ from torchvision import transforms, datasets
7
+ from torchvision.utils import save_image
8
+ import torch.nn.functional as F
9
+ import scipy.ndimage as spimg
10
+ import pyquaternion as quater
11
+ import random
12
+ import numpy as np
13
+ import math
14
+ from typing import Optional, Tuple, List
15
+ # from data_loader.acdc_dataloader import acdc_gan
16
+
17
+ # from Adaptive_Motion_Generator.Dataloader.Archive.acdc_dataloader import *
18
+
19
def get_barcode(index=None, header=('Patient', 'Slice', 'AugImg', 'NoiseStep'), digit=(4, 6, 4, 4), split='_'):
    """Build an identifier string such as ``Patient0001_Slice000001_AugImg0001_NoiseStep0070``.

    Each header is followed by the matching entry of ``index``, zero-padded to
    the matching width in ``digit``; the parts are joined with ``split``.

    When fewer than three indices are given, the third and fourth headers
    become ``'ORG'`` and ``'NA'`` with empty values, e.g.
    ``Patient0001_Slice000002_ORG_NA``.

    Args:
        index: values to encode, one per header. The caller's sequence is no
            longer modified.
        header: field names; at least four entries are required for the
            short-index case.
        digit: zero-padding width per field.
        split: separator placed between fields (any length).

    Returns:
        The assembled barcode string.

    Raises:
        IndexError: if ``index`` (after the two-entry padding of the short
            case) or ``digit`` has fewer entries than ``header``.
    """
    # Work on copies: the defaults and the caller's arguments stay untouched.
    index = [] if index is None else list(index)
    header = list(header)
    digit = list(digit)

    if len(index) < 3:
        header[2] = 'ORG'
        header[3] = 'NA'
        digit[2] = 0
        digit[3] = 0
        index += ['', '']

    return split.join(
        name + str(index[pos]).zfill(digit[pos]) for pos, name in enumerate(header))
34
+
35
class RandomResizedCrop3D(nn.Module):
    """Crop a random sub-volume of a batched tensor and resize it to a fixed size.

    The crop is obtained by scaling every spatial axis of the input by one
    random factor drawn uniformly from ``scale`` and placing the resulting box
    at a uniformly random position inside the volume.

    Args:
        size (tuple of int): Output spatial size passed to ``F.interpolate``.
        scale (tuple of float): Lower and upper bound of the per-axis scale
            factor applied to the input size before resizing.
        ratio (tuple of float): Stored for API compatibility; the current
            sampling scheme does not use it.
        interpolation (str): Interpolation mode for ``F.interpolate``
            (e.g. ``'trilinear'`` or ``'nearest'``).
    """

    def __init__(
            self,
            size: Tuple[int, int, int],
            scale=(0.6, 1.0),
            ratio=(0.5, 1.5),
            interpolation='trilinear'
    ):
        super().__init__()
        self.size = size
        self.scale = scale
        self.ratio = ratio
        self.interpolation = interpolation

    @staticmethod
    def get_params(img: torch.Tensor, rand_scale: float, scale: List[float], ratio: List[float]) -> List[int]:
        """Sample the crop box for ``img``.

        Args:
            img (Tensor): Input whose dimensions after the first two are
                treated as spatial axes.
            rand_scale (float): Factor applied to every spatial extent.
            scale (list): Unused; kept for signature compatibility.
            ratio (list): Unused; kept for signature compatibility.

        Returns:
            list: Start indices followed by crop sizes, one of each per
            spatial axis (``[i, j, k, d, h, w]`` for a 3D volume).
        """
        img_sz = np.array(list(img.size())[2:])
        crop_sz = (img_sz * rand_scale).astype(np.int32)
        # Keep at least one voxel per axis (truncation can give 0 on thin
        # axes) and never exceed the image, so that the random start range
        # below is always valid and the crop is never empty.
        crop_sz = np.minimum(np.maximum(crop_sz, 1), img_sz)
        start_id = np.random.randint(0, img_sz - crop_sz + 1, size=(img_sz.size,))
        return start_id.tolist() + crop_sz.tolist()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Crop a random box from ``img`` and resize it to ``self.size``.

        Args:
            img (Tensor): Input indexed as ``img[:, :, d, h, w]``.

        Returns:
            Tensor: Cropped and resized tensor (batch and channel axes kept).
        """
        rand_scale = np.random.uniform(self.scale[0], self.scale[1])
        i, j, k, d, h, w = self.get_params(img, rand_scale, self.scale, self.ratio)
        img_cropped = img[:, :, i:i + d, j:j + h, k:k + w]
        # ``align_corners`` is only passed for the trilinear mode.
        align = False if self.interpolation == 'trilinear' else None
        return F.interpolate(img_cropped, size=self.size, mode=self.interpolation, align_corners=align)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, scale={self.scale}, ratio={self.ratio}, interpolation={self.interpolation})"
116
+
117
def random_permute(X, select_dims=[-1, -2], include_flip=True):
    """Randomly shuffle (and optionally flip) selected axes of a list of tensors.

    The same permutation and the same flips are applied to every tensor in
    ``X``; the axis count is taken from ``X[0]``.

    Args:
        X: List of torch tensors with the same number of dimensions.
        select_dims: Axes that take part in the shuffle.
        include_flip: If true, each selected axis is additionally flipped with
            probability one half.

    Returns:
        List of permuted tensors, in the same order as ``X``.
    """
    order = list(range(X[0].ndim))
    shuffled = [order[d] for d in select_dims]
    random.shuffle(shuffled)
    for target, source in zip(select_dims, shuffled):
        order[target] = source
        # Coin toss per selected axis; the flip acts before the final permute.
        if include_flip and random.choice([True, False]):
            X = [torch.flip(t, [target]) for t in X]
    return [t.permute(order) for t in X]
128
+
129
+ # def thresh_img(img,thresh = None,EPS = 10**-7):
130
+ # threshold0 = np.random.uniform(thresh[0], thresh[1])
131
+ # threshold1 = np.random.uniform(thresh[0], thresh[1])
132
+ # scale =
133
+ # if threshold is not None:
134
+ # # img=img-threshold
135
+ # # img=np.where(img>=0,img,0)
136
+ # # img = np.maximum(img-threshold,0)
137
+ # img = torch.maximum(img - threshold,torch.tensor(0.))
138
+ # # return (img - img.min()) / (img.max() - img.min() + EPS)
139
+ # return img
140
+
141
def get_transformer(degrees=180, translate=0.125, ndims=2, prob=0.8, fill=0., img_sz=None):
    """Build the augmentation pipeline for 2D images or 3D volumes.

    Args:
        degrees: Rotation range forwarded to ``RandomAffine`` (2D pipeline).
        translate: Translation fraction, repeated ``ndims`` times (2D pipeline).
        ndims: Number of entries in the translation list.
        prob: Probability of applying the random affine transform.
        fill: Fill value for areas outside the affine-transformed image.
        img_sz: Target image size. ``None`` or a length-2 size selects the 2D
            pipeline; any other length selects the 3D random-resized crop.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    # ``is None`` instead of ``== None``: an array-like ``img_sz`` would turn
    # the equality test into an element-wise comparison and fail.
    if img_sz is None or len(img_sz) == 2:
        affine = torchvision.transforms.RandomAffine(
            degrees=degrees, translate=[translate] * ndims, fill=fill,
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        return torchvision.transforms.Compose([
            torchvision.transforms.RandomApply([affine], prob),
            torchvision.transforms.RandomVerticalFlip(p=0.5),
            torchvision.transforms.RandomAutocontrast(p=0.5),
        ])

    # Only non-2D sizes reach this point, so the crop is always the 3D variant
    # and it is applied with the fixed probability used for a given size.
    prob_crop = 0.8
    return torchvision.transforms.Compose([
        torchvision.transforms.RandomApply([RandomResizedCrop3D(size=img_sz)], prob_crop),
    ])
164
+
165
+
166
def get_random_affine_transformer(degrees=180, translate=0.125, ndims=2):
    """Return a bilinear ``RandomAffine`` transform.

    Args:
        degrees: Rotation range forwarded to ``RandomAffine``.
        translate: Translation fraction, repeated ``ndims`` times.
        ndims: Number of entries in the translation list.
    """
    shift_fractions = [translate] * ndims
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    return torchvision.transforms.RandomAffine(degrees=degrees, translate=shift_fractions, interpolation=bilinear)
168
+
169
def channel_merge_acdc(img):
    """Sum all channels of a ``(C, H, W)`` array into a single ``(H, W)`` map.

    The accumulation starts from a NumPy zero array of shape ``(H, W)`` and
    adds the channels one after another.
    """
    merged = np.zeros((img.shape[1], img.shape[2]))
    for channel in img:
        merged = merged + channel
    return merged
177
+
178
def img_crop(img, crop_rate=2, img_sz=[256, 256]):
    """Cut a randomly positioned window out of the trailing axes of ``img``.

    Along each axis a total of ``img_sz[axis] // crop_rate`` samples is
    removed, split randomly between the leading and the trailing side. The
    slice bounds are computed from ``img_sz``, not from the shape of ``img``.

    Args:
        img: Array or tensor whose last two (or three) axes are cropped.
        crop_rate: Divisor determining how much of each axis is removed.
        img_sz: Nominal spatial size; its length selects 2D or 3D cropping.

    Returns:
        The cropped view of ``img``.
    """
    ndims = len(img_sz)
    lead = [np.random.randint(0. * extent, 1. * extent) // crop_rate for extent in img_sz]
    trail = [1 * extent // crop_rate - cut for extent, cut in zip(img_sz, lead)]
    window = [slice(lo, extent - hi) for extent, lo, hi in zip(img_sz, lead, trail)]
    if ndims == 2:
        return img[..., window[0], window[1]]
    return img[..., window[0], window[1], window[2]]
186
+
187
+
188
def boundary_limit(sample_coords0, max_sz, plus=0., minus=1.):
    """Scale each coordinate channel by its extent and clamp it symmetrically.

    Channel ``c`` (taken along dim 1) is multiplied by ``max_sz[c]`` and then
    clamped to ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]``.

    Args:
        sample_coords0: Tensor with the coordinate channels on dim 1.
        max_sz: One extent per coordinate channel.
        plus: Offset added to both clamp bounds.
        minus: Margin subtracted from the extent.

    Returns:
        Tensor with the clamped channels concatenated along dim 1.
    """
    channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
    limited = []
    for coord, extent in zip(channels, max_sz):
        lower = minus - 1 * extent + plus
        upper = 1 * extent - minus + plus
        limited.append(torch.clamp(coord * extent, min=lower, max=upper))
    return torch.cat(limited, 1)
193
+
194
+
195
def resample(vol, ddf, ref=None, img_sz=None, max_sz=[128, 128], ndims=2):
    """Warp ``vol`` with a voxel-unit displacement field using ``F.grid_sample``.

    Args:
        vol: Tensor of shape ``(B, C, *spatial)`` with two or three spatial axes.
        ddf: Displacement field with the component axis last, broadcastable
            to ``(B, *spatial, ndims)``. Component ``i`` displaces along
            spatial axis ``i`` and is expressed in voxels.
        ref, img_sz, max_sz, ndims: Unused; kept so existing callers keep
            working. The spatial size and dimensionality are always read
            from ``vol``.

    Returns:
        The resampled tensor (linear interpolation, zero padding,
        ``align_corners=True``).

    Raises:
        ValueError: If ``vol`` does not have two or three spatial axes.
    """
    device = vol.device
    spatial = list(vol.size()[2:])
    ndims = len(spatial)
    if ndims not in (2, 3):
        # Previously this fell through to an UnboundLocalError further down.
        raise ValueError(f"resample expects 2 or 3 spatial dimensions, got {ndims}")

    # Half extents map voxel indices [0, s - 1] onto the [-1, 1] range that
    # grid_sample expects when align_corners=True.
    half_extent = torch.reshape(
        torch.tensor([(s - 1) / 2. for s in spatial], device=device),
        [1] * (ndims + 1) + [ndims])
    # Identity sampling grid in voxel units, shape (1, ndims, *spatial).
    ref_grid = torch.reshape(
        torch.stack(torch.meshgrid([torch.arange(end=s) for s in spatial]), 0),
        [1, ndims] + spatial)

    channel_last = [0] + list(range(2, ndims + 2)) + [1]
    coords = (ddf + ref_grid.permute(channel_last).to(device)) / half_extent - 1
    # grid_sample wants the last component to address the first spatial axis,
    # hence the flip of the component axis.
    grid = torch.flip(coords, dims=[-1]).type(torch.float32).to(device)
    return F.grid_sample(vol, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
223
+ #
224
+ # return F.grid_sample(vol, torch.flip(
225
+ # torch.permute(ddf * torch.Tensor(np.reshape(np.array(max_sz), [1, 1, 1, ndims])) + ref_grid,
226
+ # [0, 2, 3, 1]) / img_shape - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
227
+ # align_corners=True)
228
+
229
def random_resample(vol, deform_scale=32.):
    """Warp ``vol`` with a randomly generated displacement field.

    The field is the third value returned by ``random_ddf``. For 2D inputs a
    temporary third axis of length 16 is appended to the grid and the slice
    at index 8 with the first two components is used.

    Args:
        vol: Tensor of shape ``(B, C, *spatial)``.
        deform_scale: Passed to ``random_ddf`` as ``range_gauss``.

    Returns:
        The output of ``resample`` for ``vol`` and the sampled field.
    """
    shape = vol.size()
    device = vol.device
    ndims = len(shape) - 2
    planar = ndims == 2
    grid_size = list(shape[2:])
    if planar:
        grid_size = grid_size + [16]
    _, _, ddf = random_ddf(shape[0], grid_size, ndims=ndims, range_gauss=deform_scale)
    ddf = Variable(torch.tensor(ddf, dtype=torch.float32)).to(device)
    field = ddf[..., 8, :ndims] if planar else ddf[..., :ndims]
    return resample(vol, field)
243
+
244
def get_random_deformed_mask(msk_shape, deform_scale=32., apply_possibility=0.75):
    """Create an all-ones mask and, with some probability, deform it randomly.

    Args:
        msk_shape: Spatial shape of the mask.
        deform_scale: Forwarded to ``random_resample``.
        apply_possibility: Probability of applying the random deformation.

    Returns:
        Float32 tensor of shape ``(1, 1, *msk_shape)``; either the untouched
        ones mask or its randomly resampled version.
    """
    msk = torch.ones([1, 1] + list(msk_shape), dtype=torch.float32)
    draw = random.uniform(0, 1)
    if not draw < apply_possibility:
        return msk
    return random_resample(msk, deform_scale=deform_scale)
250
+
251
+ # grid option
252
def get_tranf_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]], transl=[[0, 0, 0]]):
    """Concatenate flattened rotation matrices with translations.

    Args:
        grid_size: Forwarded to ``get_rot_mat``.
        vec: Rotation axes, one per batch entry.
        ang: Rotation angles, one per batch entry.
        transl: Translations, one row per batch entry.

    Returns:
        Array with the output of ``get_rot_mat`` followed by ``transl`` along
        the last axis.
    """
    rotation = get_rot_mat(grid_size, vec=vec, ang=ang)
    return np.concatenate([rotation, transl], -1)
254
+
255
+
256
def get_rot_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]], ndims=3):
    """Return axis-angle rotations as flattened matrices.

    Args:
        grid_size: Unused.
        vec: Rotation axes, one per batch entry.
        ang: Rotation angles; the length of its first axis is the batch size.
        ndims: Matrix side length used for flattening.

    Returns:
        Array of shape ``(batch, ndims * ndims)``.
    """
    axes = np.array(vec)
    angles = np.array(ang)
    batch_num = angles.shape[0]
    rotations = vecang2rotmats(axes, angles)
    return np.reshape(rotations, [batch_num] + [ndims * (ndims)])
261
+
262
def random_mat(batch_sz, img_sz, num_class=2, pn_spline=20, pn_gauss=10, range_spline=2., range_gauss=48, spread_range=[5., 24.],
               transl_range=32., rot_range=np.pi / 2):
    """Sample random rigid transforms, one per (batch entry, class) pair.

    For axis, angle and translation a value shared by the batch entry is
    drawn first and a smaller per-class perturbation is added to it.

    Args:
        batch_sz: Number of batch entries.
        img_sz: Forwarded to ``get_rot_mat``.
        num_class: Number of classes per batch entry.
        transl_range: Half-width of the shared translation range.
        rot_range: Half-width of the shared angle range.
        pn_spline, pn_gauss, range_spline, range_gauss, spread_range: Unused.

    Returns:
        Array of shape ``(batch_sz, num_class, 4, 3)`` holding the flattened
        rotation matrix followed by the translation.
    """
    scale = 4
    ndims = 3
    total = batch_sz * num_class

    # Draw order below (shared value, then per-class jitter) is kept so that
    # seeded runs reproduce the same numbers.
    shared_axis = np.random.uniform(-1., 1., [batch_sz, 1, ndims])
    class_axis = np.random.uniform(-.1, .1, [batch_sz, num_class, ndims])
    vec = np.reshape(shared_axis + class_axis, [total, ndims])

    shared_ang = np.random.uniform(-rot_range, rot_range, [batch_sz, 1])
    class_ang = np.random.uniform(-rot_range / scale, rot_range / scale, [batch_sz, num_class])
    ang = np.reshape(shared_ang + class_ang, [total])

    shared_shift = np.random.uniform(-transl_range, transl_range, [batch_sz, 1, ndims])
    class_shift = np.random.uniform(-transl_range / scale, transl_range / scale, [batch_sz, num_class, ndims])
    transl = np.reshape(shared_shift + class_shift, [total, ndims])

    rigid = np.concatenate([get_rot_mat(img_sz, vec=vec, ang=ang), transl], -1)
    return np.reshape(rigid, [batch_sz, num_class, 4, 3])
270
+
271
+ # return np.reshape(get_tranf_mat(img_sz, vec=np.random.uniform(-1., 1., [batch_sz*num_class, 3]), ang=np.random.uniform(-rot_range, rot_range, [batch_sz*num_class]),transl=np.random.uniform(-transl_range, transl_range, [batch_sz*num_class,3])),[batch_sz,num_class,4,3])
272
+
273
def random_ddf(batch_sz, img_sz, pn_spline=20, pn_gauss=10, range_spline=1., range_gauss=16., spread_range=[16., 64.],
               transl_range=0., rot_range=np.pi / 1, ndims=3):
    """Generate a random dense displacement field on a 3D grid.

    The field is a random rotation about the grid centre plus, when
    ``range_gauss > 0``, one smoothed random field and one random translation
    that are shared by all batch entries. Target coordinates are clipped to
    the grid extended by a 5-voxel margin.

    Args:
        batch_sz: Number of fields to generate.
        img_sz: Three-entry grid size (a list, as required by ``get_rot_ddf``).
        pn_gauss: Number of random control points of the smooth component.
        range_gauss: Magnitude of the smooth component; ``<= 0`` disables it.
        spread_range: Range of the Gaussian smoothing sigma.
        transl_range: Half-width of the random translation.
        rot_range: Half-width of the random rotation angle.
        ndims: ``3`` draws a random rotation axis; any other value rotates
            about the third axis.
        pn_spline, range_spline: Unused.

    Returns:
        Tuple ``(ddf, inside, rot_df)``: the clipped displacement field, a
        mask (with a trailing singleton axis) that is 1 where the unclipped
        target lies inside the grid, and the rotation-only displacement field.
    """
    rand_ang = np.random.uniform(-rot_range, rot_range, [batch_sz])

    if ndims == 3:
        rot_axis = np.random.uniform(-1., 1., [batch_sz, 3])
    else:
        rot_axis = np.concatenate([np.zeros([batch_sz, 2]), np.ones([batch_sz, 1])], -1)
    rot_df = get_rot_ddf(img_sz, vec=rot_axis, ang=rand_ang)
    # The field itself is always three-dimensional from here on.
    ndims = 3

    if range_gauss > 0:
        smooth = generate_random_gaussian_ddf(img_sz, pn_gauss, range_sz=range_gauss, spread_std=spread_range)
        shift = np.random.uniform(-transl_range, transl_range, [ndims])
        ddf0 = np.tile([smooth + shift], [batch_sz, 1, 1, 1, 1]) + rot_df
    else:
        ddf0 = rot_df

    padding = 5
    ref = get_reference_grid(img_sz)
    target = ddf0 + ref
    axes = range(len(img_sz))
    # Clip targets to the grid plus margin ...
    clipped = np.stack(
        [np.maximum(np.minimum(target[..., a], img_sz[a] - 1 + padding), 0 - padding) for a in axes],
        axis=-1)
    # ... and flag the voxels whose unclipped target is inside the grid.
    inside = np.prod(
        [((target[..., a] < img_sz[a]) * (target[..., a] >= 0)) for a in axes],
        axis=0)
    return clipped - ref, np.expand_dims(inside, -1), rot_df
308
+
309
+
310
def generate_random_gaussian_ddf(img_sz, pn=30, range_sz=5, spread_std=[0.1, 1.]):
    """Create a smooth random displacement field from scattered impulses.

    ``pn`` random voxels receive random 3-vectors in ``[-range_sz, range_sz]``
    and the resulting volume is blurred with a Gaussian filter whose sigma is
    drawn from ``spread_std``.

    Args:
        img_sz: Grid size; the first three entries are used.
        pn: Number of impulse locations.
        range_sz: Magnitude bound of the impulses; half of it is also the
            margin kept from the grid border when placing them.
        spread_std: Lower and upper bound of the smoothing sigma.

    Returns:
        Array of shape ``(img_sz[0], img_sz[1], img_sz[2], 3)``.
    """
    margin = range_sz / 2.
    # One (1, pn) index array per axis, drawn in axis order.
    seeds = [
        np.floor(np.random.uniform(margin, img_sz[axis] - margin, [1, pn])).astype('int')
        for axis in range(3)
    ]
    impulses = np.random.uniform(-range_sz, range_sz, [pn, 3])
    field = np.zeros([img_sz[0], img_sz[1], img_sz[2], 3])
    field[seeds[0], seeds[1], seeds[2]] = impulses

    sigma = np.random.uniform(spread_std[0], spread_std[1])
    return spimg.gaussian_filter(field, sigma)
320
+
321
+
322
def get_rot_ddf(grid_size, vec=[[0., 0., 1.]], ang=[[0.]]):
    """Return the displacement field of rotations about the grid centre.

    Args:
        grid_size: Three-entry list with the grid size.
        vec: Rotation axes, one per batch entry.
        ang: Rotation angles; the length of its first axis is the batch size.

    Returns:
        Array of shape ``(batch, *grid_size, 3)``: rotated centred grid
        coordinates minus the centred grid coordinates.
    """
    axes = np.array(vec)
    angles = np.array(ang)
    batch_num = angles.shape[0]
    # Grid coordinates measured from the grid centre (bias_scale=1).
    centred = get_reference_grid(grid_size, bias_scale=1.)
    points = np.reshape(np.tile(centred, [batch_num, 1, 1, 1, 1]), [batch_num, -1, 3])
    rotated = np.matmul(points, vecang2rotmats(axes, angles))
    return np.reshape(rotated, [batch_num] + grid_size + [3]) - centred
331
+
332
+
333
def get_reference_grid(grid_size, bias_scale=0.):
    """Return the voxel coordinates of a 3D grid.

    Args:
        grid_size: Grid size; the first three entries define the grid.
        bias_scale: Fraction of the grid-centre offset ``(size - 1) / 2`` that
            is subtracted; ``1`` centres the coordinates, ``0`` leaves them as
            plain indices.

    Returns:
        Float array of shape ``(*grid_size, 3)`` in ``'ij'`` index order.
    """
    ticks = [list(range(grid_size[axis])) for axis in range(3)]
    mesh = np.meshgrid(*ticks, indexing='ij')
    coords = np.stack(mesh, axis=-1).astype('float')
    offset = bias_scale * (np.array(grid_size) - 1) / 2.
    return coords - offset
339
+
340
+
341
def resample_linear(inputs, ddf=None, sample_coords=None, random_boundary=True):
    """Trilinearly resample a NumPy volume at displaced coordinates.

    Args:
        inputs: Array indexed as ``inputs[batch, x, y, z, ...]``.
        ddf: Displacement field added to the reference grid when
            ``sample_coords`` is not given.
        sample_coords: Absolute sampling coordinates with the component axis
            last; takes precedence over ``ddf``.
        random_boundary: If true, the six boundary faces of ``inputs`` are
            blended towards the global minimum by one random factor.
            NOTE: this blending is written into ``inputs`` in place.

    Returns:
        The interpolated samples, one per sampling coordinate.
    """
    if random_boundary:
        blend = np.random.uniform(0., 1.)
        floor_val = np.min(inputs)
        # Faces are processed axis by axis, low side first; edges and corners
        # that belong to several faces are therefore blended repeatedly.
        for axis in (1, 2, 3):
            for side in (0, -1):
                face = [slice(None)] * 4
                face[axis] = side
                face = tuple(face)
                inputs[face] = floor_val * blend + (1 - blend) * inputs[face]

    input_size = inputs.shape[1:4]
    if sample_coords is None:
        sample_coords = get_reference_grid(input_size) + ddf
    spatial_rank = 3
    coords = [sample_coords[..., c] for c in range(sample_coords.shape[-1])]
    floors = [np.floor(c) for c in coords]

    def clamp(values, upper, lower):
        return np.maximum(np.minimum(values, upper), lower)

    # Continuous coordinates limited to [0, size - 1].
    coords = [clamp(c.astype('float32'), input_size[a] - 1 + 0., 0 + 0.) for a, c in enumerate(coords)]
    # Lower corner limited to [0, size - 2], upper corner to [1, size - 1].
    lower_idx = [clamp(f.astype('int32'), input_size[a] - 2 + 0, 0 + 0) for a, f in enumerate(floors)]
    upper_idx = [clamp((f + 1).astype('int32'), input_size[a] - 2 + 1, 0 + 1) for a, f in enumerate(floors)]

    w_upper = [np.expand_dims(c - lo.astype('float32'), -1) for c, lo in zip(coords, lower_idx)]
    w_lower = [np.expand_dims(hi.astype('float32') - c, -1) for c, hi in zip(coords, upper_idx)]

    sz = list(lower_idx[0].shape)
    batch_coords = np.tile(np.reshape(range(sz[0]), [sz[0]] + [1] * (len(sz) - 1)), [1] + sz[1:])
    corners = (lower_idx, upper_idx)

    # Gather the 8 cell corners; bit k of the corner code selects the lower (0)
    # or upper (1) index along axis k, most significant bit first.
    samples = []
    for code in range(2 ** spatial_rank):
        bits = [int(ch) for ch in format(code, '0%ib' % spatial_rank)]
        samples.append(inputs[batch_coords, corners[bits[0]][0], corners[bits[1]][1], corners[bits[2]][2], ...])

    def blend_corners(values, upper_w, lower_w):
        # Recursively interpolate along the last remaining axis.
        if len(upper_w) == 1:
            return values[0] * lower_w[0] + values[1] * upper_w[0]
        low_part = blend_corners(values[::2], upper_w[:-1], lower_w[:-1])
        high_part = blend_corners(values[1::2], upper_w[:-1], lower_w[:-1])
        return low_part * lower_w[-1] + high_part * upper_w[-1]

    return blend_corners(samples, w_upper, w_lower)
391
+
392
+
393
def vecang2rotmats(vec, ang):
    """Convert a batch of axis-angle pairs to stacked 3x3 rotation matrices.

    Args:
        vec: Array whose first axis enumerates the rotation axes.
        ang: Array whose first axis enumerates the matching angles.

    Returns:
        Array of shape ``(len(vec), 3, 3)``.
    """
    matrices = []
    for i in range(len(vec)):
        matrices.append(np.reshape(vecang2rotmat(vec[i, ...], ang[i, ...]), [3, 3]))
    return np.stack(matrices, 0)
395
+
396
+
397
def vecang2rotmat(vec, ang):
    """Return the rotation matrix of a quaternion built from an axis and an angle."""
    return quater.Quaternion(axis=vec, angle=ang).rotation_matrix
400
+
401
+
402
def images_to_vectors(images):
    """Flatten each image of a batch into a 16384-element vector on ``device``.

    ``device`` is a module-level name that is not defined in this function.
    """
    batch = images.size(0)
    flat = images.view(batch, 16384)
    return flat.to(device)
404
+
405
def vectors_to_images(vectors):
    """Reshape each vector of a batch into a ``1 x 128 x 128`` image on ``device``.

    ``device`` is a module-level name that is not defined in this function.
    """
    batch = vectors.size(0)
    images = vectors.view(batch, 1, 128, 128)
    return images.to(device)
407
+
408
def noise(size):
    """Return ``size`` standard-normal vectors of length 100 on ``device``.

    ``device`` is a module-level name that is not defined in this function.
    """
    return Variable(torch.randn(size, 100)).to(device)
411
+
412
def ones_target(size):
    """Return a ``(size, 1)`` tensor of ones on ``device``.

    ``device`` is a module-level name that is not defined in this function.
    """
    return Variable(torch.ones(size, 1)).to(device)
415
+
416
def zeros_target(size):
    """Return a ``(size, 1)`` tensor of zeros on ``device``.

    ``device`` is a module-level name that is not defined in this function.
    """
    return Variable(torch.zeros(size, 1)).to(device)
419
+
420
+
421
def eval_detJ_lab(disp=None, vol1=None, vol2=None, thresh=0.5):
    """Count voxels where the deformation has a negative Jacobian determinant.

    Args:
        disp: NumPy displacement field of shape ``(B, ndims, *spatial)``.
        vol1: Optional volume used to build a mask; voxels above ``thresh``
            whose Laplacian response is below 0.1 are kept and the mask is
            subsampled by a factor of two along its last three axes.
            NOTE(review): the subsampled mask is multiplied with a determinant
            map computed at the resolution of ``disp`` — confirm that callers
            pass ``disp`` at half the resolution of ``vol1``.
        vol2: Unused.
        thresh: Mask threshold; ``None`` disables masking.

    Returns:
        Number of (masked) voxels with a negative Jacobian determinant.
    """
    ndims = disp.ndim - 2
    # Identity checks instead of ``== None``: comparing an array with ``==``
    # is element-wise and made this condition raise for any array ``vol1``.
    if vol1 is None or thresh is None:
        label = 1
    else:
        label = vol1 > thresh
        label = label * (spimg.laplace(label) < 0.1)
        rescale_factor = 2
        label = label[..., ::rescale_factor, ::rescale_factor, ::rescale_factor]

    # Move the component axis last: (B, *spatial, ndims).
    disp = np.transpose(disp, [0, *range(2, ndims + 2), 1])
    # Jacob[..., i, j] is the derivative of component i along spatial axis j.
    Jacob = np.stack(np.gradient(disp, axis=tuple(range(1, ndims + 1))), -1)
    # Add the identity to turn the displacement gradient into the Jacobian of
    # the full mapping x + disp(x).
    for ii in range(ndims):
        Jacob[..., ii, ii] = Jacob[..., ii, ii] + 1
    return np.sum((np.linalg.det(Jacob) < 0) * label)
442
+
443
def eval_def_mag(disp=None, vol1=None, vol2=None, thresh=0.5):
    """Summarise the magnitude of a displacement field.

    Args:
        disp: Array with the batch on axis 0 and the vector components on
            axis 1.
        vol1, vol2, thresh: Unused.

    Returns:
        ``[avg_mag, max_mag]``: the mean magnitude over everything and the
        per-sample maximum magnitude averaged over the batch.
    """
    magnitude = np.sqrt(np.sum(np.square(disp), axis=1))
    per_sample = np.reshape(magnitude, [magnitude.shape[0], -1])
    peak = np.mean(np.max(per_sample, axis=-1))
    average = np.mean(magnitude)
    return [average, peak]
457
+
458
+
459
+
460
def print_memory_usage(tag=""):
    """Print the CUDA memory reported by torch, in units of 1e9 bytes.

    Args:
        tag: Label printed in square brackets in front of the numbers.
    """
    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] Allocated: {allocated:.2f} GB | Cached: {reserved:.2f} GB")
462
+
463
+
464
if __name__ == "__main__":
    # Smoke test: deform and crop a random 2D batch, then print a random mask.
    vol_shape = [4, 1, 64, 64]

    random_values = np.random.uniform(-1, 1, vol_shape)
    vol = Variable(torch.tensor(random_values, dtype=torch.float32))
    vol_res = random_resample(vol)
    vol_crop = img_crop(vol_res)

    mask = get_random_deformed_mask(vol.shape[2:])

    print(mask)
475
+
476
+ # print(vol.tolist())
477
+ # print(vol_res.tolist())
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
OM_aug.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision.utils import save_image
5
+ from torch.utils.data import DataLoader
6
+ from torch.optim import Adam
7
+ from torchvision.utils import make_grid
8
+ from Diffusion.diffuser import DeformDDPM
9
+ from Diffusion.networks import get_net, STN
10
+ from torchvision.transforms import Lambda
11
+ import random
12
+ import os
13
+ import utils
14
+ from Dataloader.dataloader0 import get_dataloader
15
+ from Dataloader.dataLoader import *
16
+
17
+ from torchvision.utils import save_image
18
+ from einops import rearrange, reduce, repeat
19
+ # import matplotlib.image
20
+ import numpy as np
21
+ import nibabel as nib
22
+ from tqdm import tqdm
23
+ import yaml
24
+ import argparse
25
+
26
# Small positive constant (note: 10e-8 == 1e-7). Not referenced elsewhere in this script.
EPS = 10e-8

parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    required=False,
)
args = parser.parse_args()

# Load the YAML file into a dictionary of hyper-parameters.
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Samples are augmented one at a time, regardless of the configured batch size.
hyp_parameters['batchsize'] = 1
device = hyp_parameters["device"]
ndims = hyp_parameters["ndims"]

# =======================================================================================================================
# Dataset selection (edit here to augment another label / database).
min_crop_ratio = 0.9
label_keys = ['heart']
database = ['MnMs']
subtype = "es"  # 'ed' or 'es' for MnMs
select_channels_dict = {"ImgDict": [subtype]}

# The output locations depend on the selected subtype and deliberately override
# whatever the config file specifies. The directories are created only once the
# final paths are known, so no unused config directories are left behind.
hyp_parameters["aug_img_savepath"] = f"Data/Aug_data/mnms_{subtype}/img/"
hyp_parameters["aug_msk_savepath"] = f"Data/Aug_data/mnms_{subtype}/msk/"
hyp_parameters["aug_ddf_savepath"] = f"Data/Aug_data/mnms_{subtype}/ddf/"
for _path_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_path_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
    select_channels_dict=select_channels_dict,
)
Infer_Loader = DataLoader(
    dataset,
    batch_size=hyp_parameters['batchsize'],
    shuffle=False,
)
# =======================================================================================================================

# Checkpoint location: Models/<data>_<net>/<model_id>_<data>_<net>.pth
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(model_save_path, str(epoch) + '.pth')

Net = get_net(hyp_parameters["net_name"])

Deformddpm = DeformDDPM(
    network=Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=ndims,
        num_input_chn=hyp_parameters["num_input_chn"],
        res=hyp_parameters['img_size'],
    ),
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]] * ndims,
    device=device,
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
)
Deformddpm.to(device)

ddf_stn = STN(
    img_sz=hyp_parameters["img_size"],
    ndims=ndims,
    padding_mode=hyp_parameters['padding_mode'],
    device=device,
)
ddf_stn.to(device)

print("Loading model from:", model_save_path)
# map_location lets a checkpoint saved on one device be restored on the configured one
# (e.g. a CUDA checkpoint on a CPU-only machine).
checkpoint = torch.load(model_save_path, map_location=device)
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

# Optional cap on the number of processed samples, for quick tests. The previous
# unconditional `exit()` after the first sample is replaced by this opt-in limit;
# when the key is absent every sample of the loader is augmented.
max_aug_samples = hyp_parameters.get("max_aug_samples")

print("total num of image:", len(Infer_Loader))
for e, d in tqdm(enumerate(Infer_Loader), total=len(Infer_Loader)):
    if max_aug_samples is not None and e >= max_aug_samples:
        break

    img = d['img'].type(torch.float32).to(device)
    mask = d['labels'].type(torch.float32).to(device)

    # The loader yields no patient id here, so the running index doubles as one.
    pid = e
    print('Processing to patient:', pid, ' image:', e)

    image_original = img.cpu().detach().numpy()
    mask_original = mask.cpu().detach().numpy()

    # Save the original (undeformed) image and its mask.
    nifti_img = utils.converet_to_nibabel(image_original, ndims=ndims)
    nifti_mask = utils.converet_to_nibabel(mask_original, ndims=ndims)
    nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'], utils.get_barcode([pid, e]) + '.nii.gz'))
    nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'], utils.get_barcode([pid, e]) + '_GT.nii.gz'))

    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(hyp_parameters["aug_coe"]):
            print('Generating - >', 'Subject-', pid, ', Scan-', e, ' (', im, '/', hyp_parameters["aug_coe"], ')', end='\r')

            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = Deformddpm.diff_recover(
                img_org=img,
                msk_org=mask,
                T=[noise_step, hyp_parameters["timesteps"]],
                v_scale=hyp_parameters["v_scale"],
                t_save=None,
                proc_type=hyp_parameters["condition_type"],
            )

            denoise_imgs = img_rec.cpu().detach().numpy()
            denoise_msks = msk_rec.cpu().detach().numpy()
            noisy_imgs_np = img_diff.cpu().detach().numpy()
            noisy_msks_np = msk_diff.cpu().detach().numpy()

            nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=ndims)
            nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=ndims)
            nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=ndims)
            nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=ndims)

            # Recovered (denoised) image / mask, named with the default barcode header.
            barcode_fields = [pid, e, im, noise_step]
            nib.save(nifti_img_aug, os.path.join(hyp_parameters['aug_img_savepath'], utils.get_barcode(barcode_fields) + '.nii.gz'))
            nib.save(nifti_mask_aug, os.path.join(hyp_parameters['aug_msk_savepath'], utils.get_barcode(barcode_fields) + '_GT.nii.gz'))

            # Noisy (diffused) image / mask, named with an explicit Patient/Slice/NoiseImg/NoiseStep header.
            noisy_header = ['Patient', 'Slice', 'NoiseImg', 'NoiseStep']
            nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'], utils.get_barcode(barcode_fields, header=noisy_header) + '.nii.gz'))
            nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'], utils.get_barcode(barcode_fields, header=noisy_header) + '_GT.nii.gz'))

            # The noise level is raised whenever (im - start_noise_step) is even.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]
241
+
242
+
243
+
244
+
245
+
246
+
247
+
248
+
249
+
250
+
251
+
252
+
253
+
254
+
OM_aug_highres.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision.utils import save_image
5
+ from torch.utils.data import DataLoader
6
+ from torch.optim import Adam
7
+ from torchvision.utils import make_grid
8
+ from Diffusion.diffuser import DeformDDPM
9
+ from Diffusion.networks import get_net, STN
10
+ from torchvision.transforms import Lambda
11
+ import random
12
+ import os
13
+ import utils
14
+ from Dataloader.dataloader0 import get_dataloader
15
+ from Dataloader.dataLoader import *
16
+
17
+ from torchvision.utils import save_image
18
+ from einops import rearrange, reduce, repeat
19
+ # import matplotlib.image
20
+ import numpy as np
21
+ import nibabel as nib
22
+ from tqdm import tqdm
23
+ import yaml
24
+ import argparse
25
+
26
# Small positive constant (note: 10e-8 == 1e-7). Not referenced elsewhere in this script.
EPS = 10e-8


def _to_nifti(batch, ndims, keep_channels=False):
    """Wrap the first sample of a batched array in a NIfTI image with an identity affine.

    Args:
        batch: numpy array whose first axis is the batch and second the channel axis.
        ndims: number of spatial dimensions; only 2 and 3 are supported.
        keep_channels: when True and ``ndims == 2`` the channel axis is kept
            (used for 2-D masks); otherwise only channel 0 is saved.

    Raises:
        ValueError: if ``ndims`` is neither 2 nor 3.
    """
    if ndims not in (2, 3):
        raise ValueError(f"Unsupported ndims: {ndims!r} (expected 2 or 3)")
    volume = batch[0] if (ndims == 2 and keep_channels) else batch[0, 0]
    return nib.Nifti1Image(volume, np.eye(4))


parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    required=False,
)
args = parser.parse_args()

# Load the YAML file into a dictionary of hyper-parameters.
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Make sure every output directory exists before anything is written.
for _path_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_path_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Samples are augmented one at a time, regardless of the configured batch size.
hyp_parameters['batchsize'] = 1
device = hyp_parameters["device"]
ndims = hyp_parameters["ndims"]

# =======================================================================================================================
# Dataset selection (edit here to augment another label / database).
min_crop_ratio = 0.9
label_keys = ['pancreas']
database = ['MSD']

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
)
Infer_Loader = DataLoader(
    dataset,
    batch_size=hyp_parameters['batchsize'],
    shuffle=False,
)
# =======================================================================================================================

# Checkpoint location: Models/<data>_<net>/<model_id>_<data>_<net>.pth
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(model_save_path, str(epoch) + '.pth')

Net = get_net(hyp_parameters["net_name"])

Deformddpm = DeformDDPM(
    network=Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=ndims,
        num_input_chn=hyp_parameters["num_input_chn"],
        res=hyp_parameters['img_size'],
    ),
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]] * ndims,
    device=device,
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
)
Deformddpm.to(device)

ddf_stn = STN(
    img_sz=hyp_parameters["img_size"],
    ndims=ndims,
    padding_mode=hyp_parameters['padding_mode'],
    device=device,
)
ddf_stn.to(device)

print("Loading model from:", model_save_path)
# map_location lets a checkpoint saved on one device be restored on the configured one
# (e.g. a CUDA checkpoint on a CPU-only machine).
checkpoint = torch.load(model_save_path, map_location=device)
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

print("total num of image:", len(Infer_Loader))
for e, d in tqdm(enumerate(Infer_Loader), total=len(Infer_Loader)):
    img = d['img'].type(torch.float32).to(device)
    mask = d['labels'].type(torch.float32).to(device)

    # The loader yields no patient id here, so the running index doubles as one.
    pid = e
    print('Processing to patient:', pid, ' image:', e)

    image_original = img.cpu().detach().numpy()
    mask_original = mask.cpu().detach().numpy()

    # Save the original (undeformed) image and its mask.
    nifti_img = _to_nifti(image_original, ndims)
    nifti_mask = _to_nifti(mask_original, ndims, keep_channels=True)
    nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'], utils.get_barcode([pid, e]) + '.nii.gz'))
    nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'], utils.get_barcode([pid, e]) + '_GT.nii.gz'))

    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(hyp_parameters["aug_coe"]):
            print('Generating - >', 'Subject-', pid, ', Scan-', e, ' (', im, '/', hyp_parameters["aug_coe"], ')', end='\r')

            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = Deformddpm.diff_recover(
                img_org=img,
                msk_org=mask,
                T=[noise_step, hyp_parameters["timesteps"]],
                v_scale=hyp_parameters["v_scale"],
                t_save=None,
                proc_type=hyp_parameters["condition_type"],
            )

            denoise_imgs = img_rec.cpu().detach().numpy()
            denoise_msks = msk_rec.cpu().detach().numpy()
            noisy_imgs_np = img_diff.cpu().detach().numpy()
            noisy_msks_np = msk_diff.cpu().detach().numpy()

            nifti_img_aug = _to_nifti(denoise_imgs, ndims)
            nifti_mask_aug = _to_nifti(denoise_msks, ndims, keep_channels=True)
            nifti_img = _to_nifti(noisy_imgs_np, ndims)
            nifti_mask = _to_nifti(noisy_msks_np, ndims, keep_channels=True)

            # Recovered (denoised) image / mask, named with the default barcode header.
            barcode_fields = [pid, e, im, noise_step]
            nib.save(nifti_img_aug, os.path.join(hyp_parameters['aug_img_savepath'], utils.get_barcode(barcode_fields) + '.nii.gz'))
            nib.save(nifti_mask_aug, os.path.join(hyp_parameters['aug_msk_savepath'], utils.get_barcode(barcode_fields) + '_GT.nii.gz'))

            # Noisy (diffused) image / mask, named with an explicit Patient/Slice/NoiseImg/NoiseStep header.
            noisy_header = ['Patient', 'Slice', 'NoiseImg', 'NoiseStep']
            nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'], utils.get_barcode(barcode_fields, header=noisy_header) + '.nii.gz'))
            nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'], utils.get_barcode(barcode_fields, header=noisy_header) + '_GT.nii.gz'))

            # The noise level is raised whenever (im - start_noise_step) is even.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]
217
+ # break # for testing
218
+ # if e > 5:
219
+ # break
220
+
221
+
222
+
223
+
224
+
225
+
226
+
227
+
228
+
229
+
230
+
231
+
232
+
233
+
OM_contrastive.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.optim import Adam
4
+ from torch.utils.data import DataLoader
5
+ from Diffusion.networks import get_net
6
+ from Dataloader.dataLoader import *
7
+ import argparse
8
+ import yaml
9
+ import os
10
+ import time
11
+ import swanlab
12
+
13
# Command line: only the path of the YAML config is configurable.
parser = argparse.ArgumentParser()
parser.add_argument("--config", "-C", type=str, default="Config/config_om_contrastive.yaml")
args = parser.parse_args()

with open(args.config, 'r') as config_file:
    hyp = yaml.safe_load(config_file)

# Setup: fall back to CPU when CUDA is unavailable.
use_configured_device = torch.cuda.is_available()
device = torch.device(hyp['device'] if use_configured_device else 'cpu')
data_name, net_name = hyp['data_name'], hyp['net_name']
ndims, img_size = hyp['ndims'], hyp['img_size']
model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
os.makedirs(model_save_path, exist_ok=True)

# Experiment tracking (SwanLab)
swanlab.init(project="OM", config=hyp)

# Model and optimiser
Net = get_net(net_name)
model = Net(
    n_steps=hyp['timesteps'],
    ndims=ndims,
    num_input_chn=hyp['num_input_chn'],
    res=img_size,
)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=hyp['lr'])

# Data: shuffled batches, the last incomplete batch is dropped.
dataset = OMDataset_indiv(out_sz=img_size, transform=None)
train_loader = DataLoader(dataset, batch_size=hyp['batchsize'], shuffle=True, drop_last=True)
num_batches = max(len(train_loader), 1)  # guards the per-epoch average against an empty loader

# Training
print('start training...')
for epoch in range(hyp['epoch']):
    epoch_loss = 0.0

    for volume, embd in train_loader:
        batch_start = time.time()

        volume = volume.float().to(device)
        embd = embd.to(device)  # ground-truth text embedding (original note: [B, 1024] - TODO confirm)
        # One random timestep per sample of the batch.
        t = torch.randint(0, hyp['timesteps'], (volume.shape[0],)).to(device)

        # The network returns two values; only the second (image embedding) is used here.
        _, img_embd = model(x=volume, y=volume, t=t)

        # Cosine-similarity loss: pull the image embedding towards the text embedding.
        similarity = F.cosine_similarity(img_embd, embd, dim=-1)
        loss = 1 - similarity.mean()
        loss_value = loss.item()
        swanlab.log({"loss": loss_value})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss_value

        elapsed = time.time() - batch_start
        swanlab.log({"Time(mins)/batch": elapsed/60})

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch:04d} | Loss: {avg_loss:.6f}")
    swanlab.log({"Avg Loss/epoch": avg_loss})
68
+
69
+ # if epoch % hyp['epoch_per_save'] == 0:
70
+ # save_path = model_save_path + str(epoch).rjust(6, '0') + f'_{data_name}_{net_name}.pth'
71
+ # torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, save_path)
72
+ # print(f"Saved: {save_path}")
OM_reg.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn
4
+ from torchvision.utils import save_image
5
+ from torch.utils.data import DataLoader
6
+ from torch.optim import Adam
7
+ from torchvision.utils import make_grid
8
+ from Diffusion.diffuser import DeformDDPM
9
+ from Diffusion.networks import get_net, STN
10
+ from torchvision.transforms import Lambda
11
+ import random
12
+ import os
13
+ import utils
14
+ from Dataloader.dataloader0 import get_dataloader
15
+ from Dataloader.dataLoader import *
16
+
17
+ from torchvision.utils import save_image
18
+ from einops import rearrange, reduce, repeat
19
+ # import matplotlib.image
20
+ import numpy as np
21
+ import nibabel as nib
22
+ from tqdm import tqdm
23
+ import yaml
24
+ import argparse
25
+
26
# Small positive constant (note: 10e-8 == 1e-7). Not referenced elsewhere in this script.
EPS = 10e-8


def _to_nifti(batch, ndims, keep_channels=False):
    """Wrap the first sample of a batched array in a NIfTI image with an identity affine.

    Args:
        batch: numpy array whose first axis is the batch and second the channel axis.
        ndims: number of spatial dimensions; only 2 and 3 are supported.
        keep_channels: when True and ``ndims == 2`` the channel axis is kept
            (used for 2-D masks); otherwise only channel 0 is saved.

    Raises:
        ValueError: if ``ndims`` is neither 2 nor 3.
    """
    if ndims not in (2, 3):
        raise ValueError(f"Unsupported ndims: {ndims!r} (expected 2 or 3)")
    volume = batch[0] if (ndims == 2 and keep_channels) else batch[0, 0]
    return nib.Nifti1Image(volume, np.eye(4))


parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    required=False,
)
args = parser.parse_args()

# Load the YAML file into a dictionary of hyper-parameters.
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# The augmentation directories are still created, as in the original script,
# although this script only writes to the reg_* directories below.
for _path_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    os.makedirs(hyp_parameters[_path_key], exist_ok=True)
print(hyp_parameters["aug_img_savepath"])

# Samples are registered one at a time, regardless of the configured batch size.
hyp_parameters['batchsize'] = 1
device = hyp_parameters["device"]
ndims = hyp_parameters["ndims"]

# =======================================================================================================================
# Dataset selection (edit here to register another label / database).
min_crop_ratio = 0.9
label_keys = ['brain']
database = ['MSD']

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
)
Infer_Loader = DataLoader(
    dataset,
    batch_size=hyp_parameters['batchsize'],
    shuffle=False,
)
# =======================================================================================================================

# Checkpoint location: Models/<data>_<net>/<model_id>_<data>_<net>.pth
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(model_save_path, str(epoch) + '.pth')

Net = get_net(hyp_parameters["net_name"])

Deformddpm = DeformDDPM(
    network=Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=ndims,
        num_input_chn=hyp_parameters["num_input_chn"],
        res=hyp_parameters['img_size'],
    ),
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]] * ndims,
    device=device,
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
)
Deformddpm.to(device)

ddf_stn = STN(
    img_sz=hyp_parameters["img_size"],
    ndims=ndims,
    padding_mode=hyp_parameters['padding_mode'],
    device=device,
)
ddf_stn.to(device)

print("Loading model from:", model_save_path)
# map_location lets a checkpoint saved on one device be restored on the configured one
# (e.g. a CUDA checkpoint on a CPU-only machine).
checkpoint = torch.load(model_save_path, map_location=device)
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

for _path_key in ("reg_img_savepath", "reg_msk_savepath", "reg_ddf_savepath"):
    os.makedirs(hyp_parameters[_path_key], exist_ok=True)

print("total num of image:", len(Infer_Loader))
for e, d in tqdm(enumerate(Infer_Loader), total=len(Infer_Loader)):
    img = d['img'].to(device).type(torch.float32)
    mask = d['labels'].to(device).type(torch.float32)

    # The loader yields no patient id here, so the running index doubles as one.
    pid = e
    print('Processing to patient:', pid, ' image:', e)

    image_original = img.cpu().detach().numpy()
    mask_original = mask.cpu().detach().numpy()

    if e <= 0:
        # The first image of the loader is the conditioning target for every later sample.
        target_img = img.clone().detach()

    # Save the original (undeformed) image and its mask.
    nifti_img = _to_nifti(image_original, ndims)
    nifti_mask = _to_nifti(mask_original, ndims, keep_channels=True)
    nib.save(nifti_img, os.path.join(hyp_parameters['reg_img_savepath'], utils.get_barcode([pid, e]) + '.nii.gz'))
    # Fix: the ground-truth file must contain the mask (the image was written here before).
    nib.save(nifti_mask, os.path.join(hyp_parameters['reg_msk_savepath'], utils.get_barcode([pid, e]) + '_GT.nii.gz'))

    # noise_step only feeds the output file names here; diff_recover gets T=[None, timesteps].
    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print('Generating - >', 'Subject-', pid, ', Scan-', e, ' (', im, '/', hyp_parameters["aug_coe"], ')', end='\r')

            [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = Deformddpm.diff_recover(
                img_org=img,
                cond_imgs=target_img.clone().detach(),
                msk_org=mask,
                T=[None, hyp_parameters["timesteps"]],
                v_scale=hyp_parameters["v_scale"],
                t_save=None,
                proc_type=hyp_parameters["condition_type"],
            )

            denoise_imgs = img_rec.cpu().detach().numpy()
            denoise_msks = msk_rec.cpu().detach().numpy()
            noisy_imgs_np = img_diff.cpu().detach().numpy()
            noisy_msks_np = msk_diff.cpu().detach().numpy()

            nifti_img_aug = _to_nifti(denoise_imgs, ndims)
            nifti_mask_aug = _to_nifti(denoise_msks, ndims, keep_channels=True)
            nifti_img = _to_nifti(noisy_imgs_np, ndims)
            nifti_mask = _to_nifti(noisy_msks_np, ndims, keep_channels=True)

            # Recovered image / mask, named with the default barcode header.
            barcode_fields = [pid, e, im, noise_step]
            nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'], utils.get_barcode(barcode_fields) + '.nii.gz'))
            nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'], utils.get_barcode(barcode_fields) + '_GT.nii.gz'))

            # Diffused image / mask, named with an explicit Patient/Slice/NoiseImg/NoiseStep header.
            noisy_header = ['Patient', 'Slice', 'NoiseImg', 'NoiseStep']
            nib.save(nifti_img, os.path.join(hyp_parameters['reg_img_savepath'], utils.get_barcode(barcode_fields, header=noisy_header) + '.nii.gz'))
            nib.save(nifti_mask, os.path.join(hyp_parameters['reg_msk_savepath'], utils.get_barcode(barcode_fields, header=noisy_header) + '_GT.nii.gz'))

            # The noise level is raised whenever (im - start_noise_step) is even.
            if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # Only the first few samples are registered (stops once the index exceeds 5).
    if e > 5:
        break
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
+
235
+
236
+
237
+
238
+
239
+
240
+
OM_train.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import torchvision
5
+ from torch import nn
6
+ from torchvision.utils import save_image
7
+ from torch.utils.data import DataLoader
8
+
9
+ from torch.optim import Adam, SGD
10
+ from Diffusion.diffuser import DeformDDPM
11
+ from Diffusion.networks import get_net, STN
12
+ from torchvision.transforms import Lambda
13
+ import Diffusion.losses as losses
14
+ import random
15
+ import glob
16
+ import numpy as np
17
+ import utils
18
+
19
+ from Dataloader.dataloader0 import get_dataloader
20
+ from Dataloader.dataLoader import *
21
+
22
+ from Dataloader.dataloader_utils import thresh_img
23
+ import yaml
24
+ import argparse
25
+
26
+ ####################
27
+ import torch.multiprocessing as mp
28
+ from torch.utils.data.distributed import DistributedSampler
29
+ from torch.nn.parallel import DistributedDataParallel as DDP
30
+ import torch.distributed as dist
31
+ # from torch.distributed import init_process_group
32
+ ###############
33
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one worker process.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index for this process.
        world_size: Total number of processes.
    """
    # Only fill in the rendezvous address/port when the caller has not already
    # provided them, so a different port can be selected from the environment
    # (e.g. to run two jobs on the same host) without editing this file.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Make `rank` the default CUDA device of this process.
    torch.cuda.set_device(rank)
43
+
44
# Module-level switch read by main_train() and by the __main__ entry point:
# True  -> one worker per visible GPU is spawned and the model is wrapped in DDP,
# False -> main_train(0, 1) runs in the current process.
use_distributed = True
# use_distributed = False

# Not referenced anywhere else in this script.
EPS = 1e-5

# Command-line interface: a single optional --config / -C argument.
parser = argparse.ArgumentParser()

# config_file_path = 'Config/config_cmr.yaml'
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    # default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    # NOTE(review): no Config/config_all.yaml appears among the files added in
    # this commit -- confirm the default exists locally or pass -C explicitly.
    default="Config/config_all.yaml",
    required=False,
)
# Parsed at import time; main_train() reads args.config.
args = parser.parse_args()
63
+ #=======================================================================================================================
64
+
65
+
66
+
67
def main_train(rank=0,world_size=1):
    """Train a DeformDDPM model on ``OminiDataset_v1`` and save checkpoints.

    The hyper-parameters are read from the YAML file given by ``args.config``.
    If ``Models/<data_name>_<net_name>/`` already contains ``*.pth`` files, the
    lexicographically last one is used to resume training.

    Args:
        rank: Index of this worker; used as the CUDA device index and as the
            process rank when ``use_distributed`` is True.
        world_size: Total number of workers.
    """
    if use_distributed:
        ddp_setup(rank,world_size)
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    noise_scale = hyp_parameters['noise_scale']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims*[img_size])

    # NOTE(review): no DistributedSampler is used, so in distributed mode every
    # rank iterates the whole dataset with its own shuffle -- confirm intended.
    dataset = OminiDataset_v1(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size]*ndims,
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=ndims)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # check for existing models
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Forward the module-level flag: the loader's own default is True, which
        # would take the DDP code path even for a single-process run.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1],
            use_distributed=use_distributed,
        )
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ',len(dataset))

    for epoch in range(initial_epoch,hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        # Set model inside to train model
        Deformddpm.train()

        for step, batch in enumerate(train_loader):
            x0 = batch.to(device).type(torch.float32)  # for omni dataset
            n = x0.size()[0]  # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(device)

            # random deformation + rotation
            if ndims > 2:
                if np.random.uniform(0,1) < 0.6:
                    x0 = utils.random_resample(x0, deform_scale=0)
            # NOTE(review): applied for every ndims here -- confirm against the
            # original layout of this script.
            x0 = transformer(x0)
            if noise_scale > 0:
                x0 = thresh_img(x0, [0, 2*noise_scale])
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # pick up a seq of rand number from 0 to 'timestep'
            t = torch.randint(0, timesteps, (n,)).to(device)

            pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, mask=blind_mask)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)

            loss_tot = 0
            loss_tot += 1. * loss_gen_d + 1. * loss_gen_a
            loss_tot += 1.0 * loss_ddf

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            # Running sums weighted by batch size over the dataset length.
            weight = len(x0) / len(train_loader.dataset)
            epoch_loss_tot += loss_tot.item() * weight
            epoch_loss_gen_d += loss_gen_d.item() * weight
            epoch_loss_gen_a += loss_gen_a.item() * weight
            epoch_loss_reg += loss_ddf.item() * weight

        if gpu_id == 0:
            print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # In distributed mode only rank 0 writes, and the DDP wrapper is
            # removed so the stored keys carry no 'module.' prefix.
            if not use_distributed or gpu_id == 0:
                core_model = Deformddpm.module if use_distributed else Deformddpm
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': core_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Release the process group created by ddp_setup() once training is done.
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
264
+
265
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
    """Restore model and optimizer state from ``model_file`` and return the resume epoch.

    Args:
        gpu_id: Rank / device index of the calling process.
        Deformddpm: The model, either plain or wrapped in DistributedDataParallel.
        optimizer: Optimizer whose state is restored from the checkpoint.
        model_file: Path of a checkpoint holding ``'model_state_dict'`` and
            ``'optimizer_state_dict'``; its file name must start with the
            zero-padded six-digit epoch number.
        use_distributed: Whether the caller runs under torch.distributed. It is
            only honoured when a process group has actually been initialised.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the epoch encoded in the file name plus one.
    """
    # Never touch collectives unless a process group exists; the default
    # argument is True even for single-process callers.
    distributed = bool(use_distributed) and dist.is_available() and dist.is_initialized()

    # Every rank restores its own copy when distributed. Loading on rank 0 only
    # and broadcasting the parameters would leave the optimizer state (e.g. the
    # Adam moments) un-restored on the other ranks.
    if distributed or gpu_id == 0:
        if gpu_id == 0:
            utils.print_memory_usage("Before Loading Model")
        gc.collect()
        torch.cuda.empty_cache()
        # Deserialise to CPU so that no rank allocates on the GPU the checkpoint
        # was saved from; load_state_dict copies into the existing tensors.
        checkpoint = torch.load(model_file, map_location="cpu")
        core_model = Deformddpm.module if isinstance(Deformddpm, DDP) else Deformddpm
        core_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        del checkpoint
        if gpu_id == 0:
            utils.print_memory_usage("After Loading Checkpoint on GPU")

    if distributed:
        # Keep the ranks in step before training resumes.
        dist.barrier()

    # get the epoch number from the filename and add 1 to set as initial_epoch
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1

    return initial_epoch, Deformddpm, optimizer
300
+
301
+
302
+
303
if __name__ == "__main__":
    if use_distributed:
        # One worker process per visible CUDA device.
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            # Spawning zero workers would return without training anything.
            raise RuntimeError(
                "use_distributed is True but no CUDA device is visible; "
                "set use_distributed = False or make a GPU available."
            )
        # mp.spawn passes the process index as the first argument (rank).
        mp.spawn(main_train,args = (world_size,),nprocs = world_size)
    else:
        main_train(0,1)
OM_train_2modes.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import torchvision
5
+ from torch import nn
6
+ from torchvision.utils import save_image
7
+ from torch.utils.data import DataLoader
8
+
9
+ from torch.optim import Adam, SGD
10
+ from Diffusion.diffuser import DeformDDPM
11
+ from Diffusion.networks import get_net, STN
12
+ from torchvision.transforms import Lambda
13
+ import Diffusion.losses as losses
14
+ import random
15
+ import glob
16
+ import numpy as np
17
+ import utils
18
+ from tqdm import tqdm
19
+
20
+ from Dataloader.dataloader0 import get_dataloader
21
+ from Dataloader.dataLoader import *
22
+
23
+ from Dataloader.dataloader_utils import thresh_img
24
+ import yaml
25
+ import argparse
26
+
27
+ ####################
28
+ import torch.multiprocessing as mp
29
+ from torch.utils.data.distributed import DistributedSampler
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ import torch.distributed as dist
32
+ # from torch.distributed import init_process_group
33
+ ###############
34
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one worker process.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index for this process.
        world_size: Total number of processes.
    """
    # Only fill in the rendezvous address/port when the caller has not already
    # provided them, so a different port can be selected from the environment
    # (e.g. to run two jobs on the same host) without editing this file.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Make `rank` the default CUDA device of this process.
    torch.cuda.set_device(rank)
44
+
45
# Module-level switch read by main_train() and by the __main__ entry point:
# True  -> one worker per visible GPU is spawned and the model is wrapped in DDP,
# False -> main_train(0, 1) runs in the current process.
use_distributed = True
# use_distributed = False

# Not referenced anywhere else in the visible part of this script.
EPS = 1e-5
# Small offset added inside the sqrt(cond_ratio) loss scalings in main_train().
MSK_EPS = 0.01
# Probability of passing the text embedding to the model (otherwise None).
TEXT_EMBED_PROB = 0.7
# Probability of random resampling for 3-D inputs; also reused as the
# probability gate for the intensity thresholding step.
AUG_RESAMPLE_PROB = 0.6
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 30]     # [ang, dist, reg]
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0]  # [imgsim, imgmse, ddf]
# LOSS_WEIGHTS_REGIST = [10.0, 1.0, 1.0]  # [imgsim, imgmse, ddf]
# LOSS_WEIGHTS_REGIST = [2.0, 0.1, 1e3]  # [imgsim, imgmse, ddf]
LOSS_WEIGHTS_REGIST = [2.0, 0.1, 256]  # [imgsim, imgmse, ddf]

# AUG_PERMUTE_PROB = 0.35

# Command-line interface: a single optional --config / -C argument.
parser = argparse.ArgumentParser()

# config_file_path = 'Config/config_cmr.yaml'
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    # default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    # NOTE(review): no Config/config_all.yaml appears among the files added in
    # this commit -- confirm the default exists locally or pass -C explicitly.
    default="Config/config_all.yaml",
    required=False,
)
# Parsed at import time; main_train() reads args.config.
args = parser.parse_args()
74
+ #=======================================================================================================================
75
+
76
+
77
+
78
def _save_nifti_image(tensor, filename):
    """Write ``tensor`` (squeezed, moved to CPU) to ``filename`` as NIfTI with an identity affine."""
    import nibabel as nib  # only needed on the debug path
    array = tensor.squeeze().cpu().detach().numpy()
    nib.save(nib.Nifti1Image(array, affine=np.eye(4)), filename)


def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
    """Jointly train DeformDDPM in diffusion mode and in paired-registration mode.

    Every step runs one diffusion update on a batch from ``OMDataset_indiv``;
    every ``train_mode_ratio``-th step additionally runs one registration
    update on a batch from ``OMDataset_pair``. Hyper-parameters come from the
    YAML file given by ``args.config``; training resumes from the
    lexicographically last ``*.pth`` in ``Models/<data_name>_<net_name>/``.

    Args:
        rank: Index of this worker; used as the CUDA device index and as the
            process rank when ``use_distributed`` is True.
        world_size: Total number of workers.
        train_mode_ratio: A registration update is run on steps where
            ``step % train_mode_ratio == 0``.
        thresh_imgsim: Intensity threshold on the target image that defines the
            label mask passed to the image-similarity loss.

    Raises:
        ValueError: After more than five NaN/Inf diffusion losses in one epoch.
        RuntimeError: When the image-similarity loss is (almost) zero; the
            offending volumes are written to ``Log/error_files`` first.
    """
    if use_distributed:
        ddp_setup(rank,world_size)

    if torch.distributed.is_initialized():
        print(f"World size: {torch.distributed.get_world_size()}")
        print(f"Communication backend: {torch.distributed.get_backend()}")
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    noise_scale = hyp_parameters['noise_scale']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims*[img_size])

    dataset = OMDataset_indiv(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    # Paired loader uses half the batch size of the single-image loader.
    datasetp = OMDataset_pair(transform=None)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=hyp_parameters['batchsize']//2,
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size]*ndims,
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,outrange_thresh=0.2,outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,outrange_thresh=0.6,outrange_weight=1e3)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim = losses.LNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # check for existing models
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Forward the module-level flag: the loader's own default is True, which
        # would take the DDP code path even for a single-process run.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1],
            use_distributed=use_distributed,
        )
    else:
        initial_epoch = 0

    # proc_cond_img lives on the bare model; the DDP wrapper only exposes it
    # through `.module`, and a non-wrapped model has no `.module` at all.
    core_model = Deformddpm.module if use_distributed else Deformddpm

    if gpu_id == 0:
        print('len_train_data: ',len(dataset))

    # Training loop
    for epoch in range(initial_epoch,hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        # Set model inside to train model
        Deformddpm.train()

        loss_nan_step = 0  # yu: count the number of nan loss steps

        total = min(len(train_loader), len(train_loader_p))
        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):

            # ==========================================================================
            # diffusion train on single image

            [x0,embd] = batch  # for om dataset
            x0 = x0.to(device).type(torch.float32)
            # Randomly drop the text embedding.
            if np.random.uniform(0,1) < TEXT_EMBED_PROB:
                embd = embd.to(device).type(torch.float32)
            else:
                embd = None

            n = x0.size()[0]  # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(device)

            # random deformation + rotation
            if ndims > 2:
                if np.random.uniform(0,1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
            # NOTE(review): applied for every ndims here -- confirm against the
            # original layout of this script.
            x0 = transformer(x0)
            if noise_scale > 0:
                if np.random.uniform(0,1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 1*noise_scale])
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # pick up a seq of rand number from 0 to 'timestep'
            t = torch.randint(0, timesteps, (n,)).to(device)

            proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
            cond_img, _, cond_ratio = core_model.proc_cond_img(x0,proc_type=proc_type)

            pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I,img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)

            loss_tot = 0
            loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot

            # >> JZ: print nan in x0
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            # >> JZ: print loss of ddf
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            # yu: check if loss_tot==nan or inf
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Check the limit here, on the bad step itself; placed after the
                # `continue` it would only be reached once a finite loss shows
                # up again, i.e. never if every following loss is NaN.
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                # NOTE(review): under DDP a rank that skips backward() here may
                # fall out of step with the other ranks -- verify.
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            # Running sums weighted by batch size over the dataset length.
            weight = len(x0) / len(train_loader.dataset)
            epoch_loss_tot += loss_tot.item() * weight
            epoch_loss_gen_d += loss_gen_d.item() * weight
            epoch_loss_gen_a += loss_gen_a.item() * weight
            epoch_loss_reg += loss_ddf.item() * weight

            if step%train_mode_ratio == 0:
                # ==========================================================================
                # registration train on paired images
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0,1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)
                n = x1.size()[0]  # batch_size -> n

                x1 = transformer(x1)
                y1 = transformer(y1)
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
                if noise_scale > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2*noise_scale])
                    # The same random intensity scale/shift is applied to both images.
                    random_scale = np.random.normal(1, noise_scale * 1)
                    random_shift = np.random.normal(0, noise_scale * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                # 16 distinct timesteps drawn from the upper part of the
                # schedule, in descending order, replicated per sample.
                scale_regist = np.random.uniform(0.05,0.7)
                T_regist = sorted(random.sample(range(int(timesteps * scale_regist),timesteps), 16), reverse=True)
                T_regist = [[t for _ in range(hyp_parameters["batchsize"]//2)] for t in T_regist]

                proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'none'])
                y1_proc, msk_tgt, cond_ratio = core_model.proc_cond_img(y1,proc_type=proc_type)
                [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y)  # forward diffusion process
                loss_sim = loss_imgsim(img_rec, y1, label=(y1>thresh_imgsim))  # calculate loss for the registration process
                loss_mse = loss_imgmse(img_rec, y1, label=(y1>0.0))  # calculate loss for the registration process
                loss_ddf1 = loss_reg1(ddf_comp,img=y1)  # calculate loss for the registration process

                loss_regist = 0
                loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # >> JZ: print nan in the paired inputs (these are the images
                # actually used by this branch, not x0)
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in input images x1/y1 at epoch {epoch}, step {step}.")

                # >> JZ: print loss of ddf
                if loss_ddf1 > 0.001:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
                optimizer.step()

                epoch_loss_regist += loss_regist.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgsim += loss_sim.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)

                print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
                print(f'   loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
                # >> JZ: if loss_imgsim is zero
                if loss_sim.item() > -0.001:
                    print(f"*** Zero image similarity loss at epoch {epoch}, step {step}.")
                    # save the x1 and y1 images for debugging (path relative to
                    # the working directory rather than a machine-specific one)
                    save_path = os.path.join('Log', 'error_files', f"debug_images_epoch{epoch}_step{step}/")
                    os.makedirs(save_path, exist_ok=True)
                    _save_nifti_image(img_rec, os.path.join(save_path, 'img_rec.nii.gz'))
                    _save_nifti_image(x1, os.path.join(save_path, 'x1.nii.gz'))
                    _save_nifti_image(y1, os.path.join(save_path, 'y1.nii.gz'))
                    _save_nifti_image(y1_proc, os.path.join(save_path, 'y1_proc.nii.gz'))
                    # Raise instead of exit(): a spawned worker that exits with
                    # status 0 is treated as finished normally while the other
                    # ranks keep waiting for it.
                    raise RuntimeError(
                        f"Zero image similarity loss at epoch {epoch}, step {step}; "
                        f"debug images saved in {save_path}"
                    )

        print('==================')
        print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
        print(f'   loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
        print('==================')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # In distributed mode only rank 0 writes, and the DDP wrapper is
            # removed so the stored keys carry no 'module.' prefix.
            if not use_distributed or gpu_id == 0:
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': core_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
483
+
484
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
    """Restore model and optimizer state from ``model_file`` and return the resume epoch.

    Args:
        gpu_id: Rank / device index of the calling process.
        Deformddpm: The model, either plain or wrapped in DistributedDataParallel.
        optimizer: Optimizer whose state is restored from the checkpoint.
        model_file: Path of a checkpoint holding ``'model_state_dict'`` and
            ``'optimizer_state_dict'``; its file name must start with the
            zero-padded six-digit epoch number.
        use_distributed: Whether the caller runs under torch.distributed. It is
            only honoured when a process group has actually been initialised.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the epoch encoded in the file name plus one.
    """
    # Never touch collectives unless a process group exists; the default
    # argument is True even for single-process callers.
    distributed = bool(use_distributed) and dist.is_available() and dist.is_initialized()

    # Every rank restores its own copy when distributed. Loading on rank 0 only
    # and broadcasting the parameters would leave the optimizer state (e.g. the
    # Adam moments) un-restored on the other ranks.
    if distributed or gpu_id == 0:
        if gpu_id == 0:
            utils.print_memory_usage("Before Loading Model")
        gc.collect()
        torch.cuda.empty_cache()
        # Deserialise to CPU so that no rank allocates on the GPU the checkpoint
        # was saved from; load_state_dict copies into the existing tensors.
        checkpoint = torch.load(model_file, map_location="cpu")
        core_model = Deformddpm.module if isinstance(Deformddpm, DDP) else Deformddpm
        core_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        del checkpoint
        if gpu_id == 0:
            utils.print_memory_usage("After Loading Checkpoint on GPU")

    if distributed:
        # Keep the ranks in step before training resumes.
        dist.barrier()

    # get the epoch number from the filename and add 1 to set as initial_epoch
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1

    return initial_epoch, Deformddpm, optimizer
519
+
520
+
521
+
522
if __name__ == "__main__":
    if use_distributed:
        # One worker process per visible CUDA device.
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            # Spawning zero workers would return without training anything.
            raise RuntimeError(
                "use_distributed is True but no CUDA device is visible; "
                "set use_distributed = False or make a GPU available."
            )
        # mp.spawn passes the process index as the first argument (rank).
        mp.spawn(main_train,args = (world_size,),nprocs = world_size)
    else:
        main_train(0,1)
OM_train_3modes.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gc
3
+ import torch
4
+ import torchvision
5
+ from torch import nn
6
+ from torchvision.utils import save_image
7
+ from torch.utils.data import DataLoader
8
+
9
+ from torch.optim import Adam, SGD
10
+ from Diffusion.diffuser import DeformDDPM
11
+ from Diffusion.networks import get_net, STN
12
+ from torchvision.transforms import Lambda
13
+ import Diffusion.losses as losses
14
+ import random
15
+ import glob
16
+ import numpy as np
17
+ import utils
18
+ from tqdm import tqdm
19
+
20
+ from Dataloader.dataloader0 import get_dataloader
21
+ from Dataloader.dataLoader import *
22
+
23
+ from Dataloader.dataloader_utils import thresh_img
24
+ import yaml
25
+ import argparse
26
+
27
+ ####################
28
+ import torch.multiprocessing as mp
29
+ from torch.utils.data.distributed import DistributedSampler
30
+ from torch.nn.parallel import DistributedDataParallel as DDP
31
+ import torch.distributed as dist
32
+ # from torch.distributed import init_process_group
33
+ ###############
34
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one worker and bind it to its GPU.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index of this worker.
        world_size: Total number of processes
    """
    # setdefault keeps the single-node defaults but no longer overwrites an
    # address/port supplied by the launcher, so two jobs on the same host can
    # be given different ports instead of colliding on 12355.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Select the GPU before creating the NCCL communicator so the process
    # group is set up on this worker's own device.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
44
+
45
# Run one DDP worker per GPU (True) or a single plain process (False).
use_distributed = True

# Small numeric constants.
EPS = 1e-5
MSK_EPS = 0.01            # added to masks / condition ratios in the loss scaling

# Augmentation / conditioning probabilities.
TEXT_EMBED_PROB = 0.7     # chance that the text embedding is passed to the model
AUG_RESAMPLE_PROB = 0.6   # chance of the resample / intensity augmentations

# Loss weights.
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 3.0]    # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.2, 1e3]  # [imgsim, imgmse, ddf]

# Command-line interface: only the path of the YAML configuration file.
# NOTE(review): the default "Config/config_all.yaml" is not among the Config
# files added in this commit -- confirm it exists or always pass -C.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    default="Config/config_all.yaml",
    help="Path for the config file",
)
args = parser.parse_args()
72
+ #=======================================================================================================================
73
+
74
+
75
+
76
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
    """Train DeformDDPM with interleaved diffusion and registration steps.

    Every iteration runs a diffusion step on a batch of individual images.
    On every step whose index is a multiple of ``train_mode_ratio`` it also
    runs a registration step on a freshly drawn batch of paired images.
    The newest ``*.pth`` file in the model directory is resumed
    automatically; a checkpoint (model, optimizer, epoch) is written every
    ``epoch_per_save`` epochs.

    Args:
        rank: Process rank; also used as the GPU index in distributed mode.
        world_size: Total number of processes.
        train_mode_ratio: Run the registration step on steps where
            ``step % train_mode_ratio == 0``.
        thresh_imgsim: Intensity threshold; target voxels not above it get
            zero weight in the mask passed to the image-similarity loss.

    Raises:
        ValueError: If more than five NaN/Inf diffusion losses occur within
            one epoch.
    """
    # NOTE(review): block nesting was reconstructed from a de-indented diff;
    # places where the nesting is genuinely ambiguous are marked below.
    if use_distributed:
        ddp_setup(rank,world_size)

    if torch.distributed.is_initialized():
        print(f"World size: {torch.distributed.get_world_size()}")
        print(f"Communication backend: {torch.distributed.get_backend()}")
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    batchsize = hyp_parameters["batchsize"]
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims*[img_size])

    # Individual images for the diffusion step.
    # NOTE(review): no DistributedSampler is used, so every rank iterates the
    # full dataset -- confirm this is intended.
    dataset = OMDataset_indiv(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        drop_last=True,
    )

    # Image pairs for the registration step (half the batch size).
    datasetp = OMDataset_pair(transform=None)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=batchsize//2,
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn = hyp_parameters["num_input_chn"],
            res = img_size
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size]*ndims,
        device=device,
        batch_size=batchsize,
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)
    # The bare model: helper methods and state_dict() are read from it so the
    # same code works with and without the DDP wrapper.
    core_model = Deformddpm.module if use_distributed else Deformddpm

    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,outrange_thresh=0.2,outrange_weight=1e2)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,outrange_thresh=0.6,outrange_weight=1e2)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim = losses.LNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # check for existing models and resume from the newest one
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Forward the module-level flag: the helper defaults to the DDP code
        # path and would fail on an unwrapped model otherwise.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1],
            use_distributed=use_distributed,
        )
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ',len(dataset))

    # Training loop
    for epoch in range(initial_epoch,hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        # Set model inside to train model
        Deformddpm.train()

        loss_nan_step = 0  # number of NaN/Inf diffusion losses in this epoch

        for step, batch in tqdm(enumerate(train_loader)):

            # ==========================================================================
            # diffusion train on single image
            [x0,embd] = batch
            x0 = x0.to(device).type(torch.float32)
            # Randomly drop the text embedding for this batch.
            if np.random.uniform(0,1)<TEXT_EMBED_PROB:
                embd = embd.to(device).type(torch.float32)
            else:
                embd = None

            n = x0.size()[0]  # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(device)

            # random resample or axis permutation (3D only)
            if ndims>2:
                if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
            x0 = transformer(x0)
            if hyp_parameters['noise_scale']>0:
                # NOTE(review): the source diff lost its indentation; the
                # intensity jitter is assumed to belong to this random branch
                # together with the thresholding -- confirm.
                if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 1*hyp_parameters['noise_scale']])
                    x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)

            # one random timestep per image in the batch
            t = torch.randint(0, timesteps, (n,)).to(device)

            proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
            cond_img, _, cond_ratio = core_model.proc_cond_img(x0,proc_type=proc_type)

            pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I,img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)

            loss_tot = LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot = loss_tot + LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot

            # diagnostics
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf>0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Check the limit right away: testing it only on a later
                # finite step would never abort an epoch in which every
                # loss is NaN.
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                # NOTE(review): under DDP a skip taken by only some ranks
                # leaves the others waiting in gradient all-reduce -- confirm
                # this cannot hang.
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

            if step%train_mode_ratio == 0:
                # ==========================================================================
                # registration train on paired images
                # NOTE(review): a new iterator is created per step, so each
                # step takes the first batch of a fresh shuffle.
                [x1, y1, _, embd_y] = next(iter(train_loader_p))
                if np.random.uniform(0,1)<TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)
                n = x1.size()[0]  # batch_size -> n

                x1 = transformer(x1)
                y1 = transformer(y1)
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
                if hyp_parameters['noise_scale']>0:
                    x1 = thresh_img(x1, [0, 2*hyp_parameters['noise_scale']])
                    x1 = x1 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
                    y1 = thresh_img(y1, [0, 2*hyp_parameters['noise_scale']])
                    y1 = y1 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)

                # 16 distinct timesteps in descending order, drawn from the
                # first 60-100% of the schedule and repeated for every sample.
                scale_regist = np.random.uniform(0.6,1.)
                T_regist = sorted(random.sample(range(0, int(timesteps * scale_regist) + 1), 16), reverse=True)
                T_regist = [[t_reg for _ in range(n)] for t_reg in T_regist]

                proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'none'])
                y1, msk_tgt, cond_ratio = core_model.proc_cond_img(y1,proc_type=proc_type)
                msk_tgt = msk_tgt + MSK_EPS
                [ddf_comp,_],[img_rec,_,_],_ = Deformddpm(img_org=x1, cond_imgs=y1, T=[None, T_regist], proc_type=[],text=embd_y)
                loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=msk_tgt)
                loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>0.0))

                loss_regist = LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist = loss_regist + LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist = loss_regist + LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # diagnostics: inspect the paired inputs of THIS stage (the
                # diffusion-stage image x0 was already checked above)
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in paired input images x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1>0.001:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
                optimizer.step()

                # Weighted like the diffusion losses (by the x0 batch) so the
                # per-epoch values of both stages are on the same scale.
                epoch_loss_regist += loss_regist.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgsim += loss_sim.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)

        # Epoch summary (printed by every rank).
        print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
        print(f'      loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')

        if 0 == epoch % epoch_per_save:
            save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Only one process writes the file; the bare model's state_dict
            # is stored so the keys carry no DDP "module." prefix.
            if not use_distributed or gpu_id == 0:
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': core_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
445
+
446
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
    """Restore model/optimizer state from a checkpoint and return the resume epoch.

    Args:
        gpu_id: Rank of the calling process; only rank 0 prints the
            "before loading" memory report.
        Deformddpm: Model to restore. When ``use_distributed`` is true it must
            expose the underlying model as ``.module`` (e.g. a DDP wrapper).
        optimizer: Optimizer whose state is restored from the checkpoint.
        model_file: Path to a checkpoint dict with the keys
            ``model_state_dict`` and ``optimizer_state_dict``. The file name
            must start with the zero-padded (6 digit) epoch number.
        use_distributed: Whether a process group is active; enables the
            ``.module`` unwrap and the cross-rank weight broadcast.

    Returns:
        Tuple ``(initial_epoch, Deformddpm, optimizer)`` where
        ``initial_epoch`` is the epoch encoded in the file name plus one.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    gc.collect()
    torch.cuda.empty_cache()

    # Deserialize straight onto the device this process's model lives on.
    # Without map_location every rank restores the tensors onto the device
    # recorded in the file (typically cuda:0), piling all copies on one GPU.
    first_param = next(Deformddpm.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(model_file, map_location=map_location)

    target = Deformddpm.module if use_distributed else Deformddpm
    target.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Drop the extra copy of the weights before training allocates memory.
    del checkpoint
    utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Broadcast model weights from rank 0 so all ranks start identical.
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)
        dist.barrier()
        # Gradients are intentionally not broadcast: none exist right after
        # loading, and a collective guarded by a per-rank "grad is not None"
        # test could deadlock if ranks disagreed.

    # The epoch is taken from the file name ("000123_<data>_<net>.pth").
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1

    return initial_epoch, Deformddpm, optimizer
481
+
482
+
483
+
484
if __name__ == "__main__":
    if use_distributed:
        # One worker process per visible GPU.
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            # Spawning zero processes would start no worker at all, so the
            # script would exit without training and without any error.
            raise RuntimeError(
                "use_distributed is enabled but no CUDA device is visible."
            )
        mp.spawn(main_train,args = (world_size,),nprocs = world_size)
    else:
        # Single-process run: rank 0 in a world of size 1.
        main_train(0,1)
OM_train_uncon.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch import nn
5
+ from torchvision.utils import save_image
6
+ from torch.utils.data import DataLoader
7
+ from torch.optim import Adam, SGD
8
+ from Diffusion.diffuser import DeformDDPM
9
+ from Diffusion.networks import get_net, STN
10
+ from torchvision.transforms import Lambda
11
+ import Diffusion.losses as losses
12
+ import random
13
+ import glob
14
+ import numpy as np
15
+ import utils
16
+
17
+ from Dataloader.dataloader0 import get_dataloader
18
+
19
+ from Dataloader.dataloader_utils import thresh_img
20
+ import yaml
21
+ import argparse
22
+
23
+ ####################
24
+ import torch.multiprocessing as mp
25
+ from torch.utils.data.distributed import DistributedSampler
26
+ from torch.nn.parallel import DistributedDataParallel as DDP
27
+ from torch.distributed import init_process_group, destroy_process_group
28
+ ###############
29
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one worker and bind it to its GPU.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index of this worker.
        world_size: Total number of processes
    """
    # setdefault keeps the single-node defaults but no longer overwrites an
    # address/port supplied by the launcher, so two jobs on the same host can
    # be given different ports instead of colliding on 12355.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Select the GPU before creating the NCCL communicator so the process
    # group is set up on this worker's own device.
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
39
+
40
# Execution-mode flags read by the training code in this script.
use_parallel = False
use_distributed = False

# Small numeric constant.
EPS = 1e-5

# Command-line interface: only the path of the YAML configuration file.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    required=False,
    default="Config/config_cmr.yaml",
    help="Path for the config file",
)
args = parser.parse_args()
58
+ #=======================================================================================================================
59
+
60
+
61
+
62
def main_train(rank,world_size):
    """Train the unconditional DeformDDPM on one DDP worker.

    The newest ``*.pth`` file in the model directory is resumed
    automatically and a checkpoint of the bare model's ``state_dict`` is
    written by rank 0 every ``epoch_per_save`` epochs.

    Args:
        rank: Process rank; also used as the GPU index.
        world_size: Total number of spawned processes.
    """
    # NOTE(review): block nesting was reconstructed from a de-indented diff.
    ddp_setup(rank,world_size)
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims*[img_size])
    Data_Loader = get_dataloader(data_name=hyp_parameters['data_name'], mode='train')

    dataset = Data_Loader(target_res = [img_size]*ndims, transforms=None, noise_scale=hyp_parameters['noise_scale'])
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(n_steps=timesteps, ndims=ndims, num_input_chn=1),
        n_steps=timesteps,
        image_chw=[1] + [img_size]*ndims,
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    Deformddpm.to(rank)
    Deformddpm = DDP(Deformddpm, device_ids=[rank])
    ddf_stn.to(rank)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=ndims)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # check for existing models
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    print(model_files)
    if model_files:
        # if there are any model files, load the most recent one
        latest_model_file = model_files[-1]
        state = torch.load(latest_model_file, map_location=f"cuda:{rank}")
        # Accept the dict-style checkpoints written by the other training
        # scripts as well as a plain state_dict.
        if isinstance(state, dict) and 'model_state_dict' in state:
            state = state['model_state_dict']
        # Checkpoints are written from the bare model (no "module." prefix),
        # so they must be loaded into the bare model too. Loading them into
        # the DDP wrapper with strict=False matched no key and silently
        # restored nothing. A leading "module." is stripped so wrapper-style
        # files keep working.
        state = {
            (key[len("module."):] if key.startswith("module.") else key): value
            for key, value in state.items()
        }
        Deformddpm.module.load_state_dict(state, strict=False)
        # get the epoch number from the filename and add 1 to set as initial_epoch
        initial_epoch = int(os.path.basename(latest_model_file).split('.')[0][:6]) + 1
    else:
        initial_epoch = 0
    print('len_train_data: ',len(dataset))

    for epoch in range(initial_epoch,hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        # Set model inside to train model
        Deformddpm.train()

        for step, batch in enumerate(train_loader):
            x0, _, _ = batch
            x0 = x0.to(device).type(torch.float32)

            n = x0.size()[0]  # batch_size -> n
            # random resample (3D only)
            if ndims>2:
                if np.random.uniform(0,1)<0.6:
                    x0 = utils.random_resample(x0, deform_scale=0)
            x0 = transformer(x0)
            if hyp_parameters['noise_scale']>0:
                x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
                x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)

            # one random timestep per image in the batch
            t = torch.randint(0, timesteps, (n,)).to(device)

            # Always call through the DDP wrapper and keep both outputs: the
            # former `use_parallel` branch called `.module` directly and
            # discarded the second output, leaving `dvf_I` undefined below.
            pre_dvf_I,dvf_I = Deformddpm(x0, t)

            loss_ddf = loss_reg(pre_dvf_I)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None)

            loss_tot = 1.0 * loss_gen_d + 1.0 * loss_gen_a
            loss_tot = loss_tot + 10 * loss_ddf
            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

        if gpu_id == 0:
            print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')

        if 0 == epoch % epoch_per_save and gpu_id == 0:
            # Only rank 0 touches the file system: this avoids concurrent
            # makedirs / writes, and the message now follows a real save.
            save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            torch.save(Deformddpm.module.state_dict(), save_dir)
            print(f"saved in {save_dir}")

    # Release the process group created by ddp_setup.
    destroy_process_group()
253
+
254
+
255
if __name__ == "__main__":
    # One worker process per visible GPU.
    world_size = torch.cuda.device_count()
    print(f"world size = {world_size}")
    if world_size < 1:
        # Spawning zero processes would start no worker at all, so the
        # script would exit without training and without any error.
        raise RuntimeError("No CUDA device is visible; cannot start DDP training.")
    mp.spawn(main_train,args = (world_size,),nprocs = world_size)
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # OmniMorph: Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on conditional Deformation-Recovery Diffusion Model
2
+
3
+ ## Environment
4
+ ```
5
+ conda activate torch
6
+ conda deactivate
7
+ ```
8
+ To use the project's virtual environment instead, activate it with `source /home/data/Github/OmniMorph/ominenv/bin/activate`.
9
+
10
+ ## Masking CUDA
11
+ Restrict which GPUs a run can see by prefixing the command with `CUDA_VISIBLE_DEVICES`, e.g. `CUDA_VISIBLE_DEVICES=0,1,3 python ...`
bash_infer.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ source /home/data/jzheng/Adaptive_Motion_Generator-master/pipenv/bin/activate
3
+
4
+ export CUDA_VISIBLE_DEVICES=2
5
+ # export CUDA_VISIBLE_DEVICES=0
6
+
7
+ # python -u OM_aug.py -C Config/config_om.yaml
8
+ # python -u OM_reg.py -C Config/config_om.yaml
9
+ nohup python -u OM_aug.py -C Config/config_om.yaml > aug_log.txt 2>&1 &
bash_train.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ source /home/data/jzheng/Adaptive_Motion_Generator-master/pipenv/bin/activate
3
+
4
+ export CUDA_VISIBLE_DEVICES=3
5
+ # export CUDA_VISIBLE_DEVICES=1,3
6
+ # export CUDA_VISIBLE_DEVICES=1,2,3
7
+ # # python -u OM_train.py -C Config/config_lct.yaml
8
+ # nohup python -u OM_train.py -C Config/config_lct.yaml > train_log.txt 2>&1 &
9
+
10
+ # python -u OM_train_2modes.py -C Config/config_om.yaml
11
+ nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
12
+ # nohup python -u OM_train.py -C Config/config_om.yaml > train_log.txt 2>&1 &
dataloader_tester.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch import nn
5
+ from torchvision.utils import save_image
6
+ from torch.utils.data import DataLoader
7
+ from torch.optim import Adam, SGD
8
+ from Diffusion.diffuser import DeformDDPM
9
+ from Diffusion.networks import get_net, STN
10
+ from torchvision.transforms import Lambda
11
+ import Diffusion.losses as losses
12
+ import random
13
+ import glob
14
+ import numpy as np
15
+ import utils
16
+
17
+ from Dataloader.dataloader0 import get_dataloader
18
+ from Dataloader.dataLoader import *
19
+ from Dataloader.dataloader_utils import thresh_img
20
+ import yaml
21
+ import argparse
22
+
23
# Smoke-test script: builds several datasets/loaders and prints the shape,
# dtype and value range of the first batch of the last loader.

# ToTensor-only transform pipeline. NOTE(review): `tsfm` is not referenced
# again in this script.
tsfm = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
    ]
)
# get_dataloader returns a dataset class (it is instantiated just below).
Data_Loader=get_dataloader(data_name = 'lct', mode='train')

dataset = Data_Loader(
    target_res=[128] * 3,
    transforms=None,
    noise_scale=4.0e-05,
)
train_loader = DataLoader(
    dataset,
    batch_size=32,
    # shuffle=False,
    shuffle=True,
    drop_last=True,
)


# Second dataset/loader pair; constructed but not iterated in this script.
dataset2 = OminiDataset_v1(transform=None)
train_loader2 = DataLoader(dataset2, batch_size=32, shuffle=True)


# NOTE: these two assignments rebind `dataset` and `train_loader`, so the
# 'lct' loader created above is never iterated; only this paired loader is.
dataset = OminiDataset_paired(transform=None, ROIs = ['leg'])
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# print(dataset.get_all_ROI())
# print(dataset.getitem())
# print(dataset.get_ALLdata())
# print(dataset.getitem(idx=11))
# exit()



# Inspect only the first batch, then stop.
for i, batch in enumerate(train_loader):
    x0, x1 = batch
    print(x0.shape,x1.shape)
    print(x0.dtype,x1.dtype)
    print(x0.min(),x0.max())
    break
# NOTE(review): indentation was lost in the listing this was recovered from;
# exit() is assumed to sit at module level, after the loop -- confirm.
exit()
65
+
requirements.txt ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ certifi==2022.12.7
2
+ charset-normalizer==2.1.1
3
+ contourpy==1.1.1
4
+ cycler==0.12.1
5
+ einops==0.3.2
6
+ elasticdeform==0.5.0
7
+ filelock==3.16.1
8
+ fonttools==4.49.0
9
+ fsspec==2025.3.0
10
+ hausdorff==0.2.6
11
+ huggingface-hub==0.29.3
12
+ idna==3.4
13
+ imageio==2.34.0
14
+ importlib_metadata==7.1.0
15
+ importlib_resources==6.1.2
16
+ joblib==1.4.0
17
+ kiwisolver==1.4.5
18
+ lazy_loader==0.3
19
+ llvmlite==0.41.1
20
+ matplotlib==3.7.5
21
+ networkx==3.1
22
+ nibabel==5.1.0
23
+ nptyping==2.5.0
24
+ numba==0.58.1
25
+ numpy==1.24.1
26
+ opencv-python==4.9.0.80
27
+ packaging==23.2
28
+ pandas==2.0.3
29
+ pillow==10.2.0
30
+ pydicom==2.4.4
31
+ pynrrd==1.0.0
32
+ pyparsing==3.1.1
33
+ pyquaternion==0.9.9
34
+ python-dateutil==2.8.2
35
+ pytz==2025.2
36
+ PyWavelets==1.4.1
37
+ PyYAML==6.0.2
38
+ regex==2024.11.6
39
+ requests==2.28.1
40
+ safetensors==0.5.3
41
+ scikit-image==0.21.0
42
+ scikit-learn==1.3.2
43
+ scipy==1.9.3
44
+ SimpleITK==2.3.1
45
+ six==1.16.0
46
+ threadpoolctl==3.5.0
47
+ tifffile==2023.7.10
48
+ tokenizers==0.20.3
49
+ torch==1.12.1+cu113
50
+ torchaudio==0.12.1+cu113
51
+ torchvision==0.13.1+cu113
52
+ tqdm==4.66.2
53
+ transformers==4.46.3
54
+ typing_extensions==4.8.0
55
+ tzdata==2025.2
56
+ urllib3==1.26.13
57
+ zipp==3.17.0
utils.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ from torch import nn, optim
5
+ from torch.autograd.variable import Variable
6
+ from torchvision import transforms, datasets
7
+ from torchvision.utils import save_image
8
+ import torch.nn.functional as F
9
+ import scipy.ndimage as spimg
10
+ import pyquaternion as quater
11
+ import random
12
+ import numpy as np
13
+ import math
14
+ from typing import Optional, Tuple, List
15
+ import nibabel as nib
16
+ # from data_loader.acdc_dataloader import acdc_gan
17
+
18
+ # from Adaptive_Motion_Generator.Dataloader.Archive.acdc_dataloader import *
19
+
20
def get_barcode(index=[], header=['Patient', 'Slice', 'AugImg', 'NoiseStep'], digit=[4, 6, 4, 4], split='_'):
    """Build an identifier such as ``Patient0001_Slice000002_AugImg0003_NoiseStep0070``.

    Each header name is followed by the matching ``index`` entry, converted
    with ``str`` and zero-padded to the matching ``digit`` width; the pieces
    are joined with ``split``.

    When fewer than three indices are given, the third and fourth fields are
    emitted as the bare tags ``ORG`` and ``NA`` (width 0), and the missing
    index entries are treated as empty strings.

    Args:
        index: Values to embed, one per header entry. Never modified.
        header: Field names. Never modified.
        digit: Zero-padding width for each field. Never modified.
        split: Separator placed between fields (any length).

    Returns:
        The assembled barcode string.

    Raises:
        IndexError: If ``index`` has three or more entries but fewer than
            ``header`` has.
    """
    # Work on copies: the previous in-place ``index += ['', '']`` leaked into
    # the caller's list and into the shared default argument.
    fields = list(index)
    names = list(header)
    widths = list(digit)
    if len(fields) < 3:
        names[2] = 'ORG'
        names[3] = 'NA'
        widths[2] = 0
        widths[3] = 0
        # Pad up to the header length so 0- and 1-element inputs work too.
        fields += [''] * (len(names) - len(fields))

    # join() instead of "append separator, then drop the last character",
    # which was only correct for single-character separators.
    return split.join(
        name + str(fields[pos]).zfill(widths[pos])
        for pos, name in enumerate(names)
    )
35
+
36
class RandomResizedCrop3D(nn.Module):
    """Crop a random sub-volume of a batched volume and resize it to a fixed size.

    One scale factor is drawn uniformly from ``scale`` per call and applied to
    every spatial axis, so the crop keeps the aspect ratio of the input.

    Args:
        size (tuple of int): Output size of the resized crop, one entry per
            spatial dimension.
        scale (tuple of float): Lower and upper bound of the per-axis linear
            scale factor of the crop relative to the input size.
        ratio (tuple of float): Stored for API compatibility; it is not used
            when sampling the crop.
        interpolation (str): Mode passed to ``F.interpolate``
            (e.g. 'trilinear' or 'nearest').
    """

    def __init__(
            self,
            size: Tuple[int, int, int],
            scale=(0.6, 1.0),
            ratio=(0.5, 1.5),
            interpolation='trilinear'
    ):
        super().__init__()
        self.size = size
        self.scale = scale
        self.ratio = ratio
        self.interpolation = interpolation

    @staticmethod
    def get_params(img: torch.Tensor, rand_scale: float, scale: List[float], ratio: List[float]) -> List[int]:
        """Sample the start indices and the size of a random crop.

        Args:
            img (Tensor): Input whose dimensions from index 2 onward are spatial.
            rand_scale (float): Linear scale factor applied to every spatial axis.
            scale (list): Unused; kept for signature compatibility.
            ratio (list): Unused; kept for signature compatibility.

        Returns:
            list: Start indices for each spatial axis followed by the crop
            size for each spatial axis (``[i, j, k, d, h, w]`` for 3-D input).
        """
        img_sz = np.array(list(img.size())[2:])
        # Truncation towards zero could yield an empty crop for small scale
        # factors; keep at least one voxel along every axis.
        crop_sz = np.maximum((img_sz * rand_scale).astype(np.int32), 1)
        start_id = np.random.randint(0, img_sz - crop_sz + 1, size=(img_sz.size,))
        return start_id.tolist() + crop_sz.tolist()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Apply the random crop followed by the resize.

        Args:
            img (Tensor): 5-D input; the last three dimensions are cropped.

        Returns:
            Tensor: Cropped volume resized to ``self.size``.
        """
        rand_scale = np.random.uniform(self.scale[0], self.scale[1])
        [i, j, k, d, h, w] = self.get_params(img, rand_scale, self.scale, self.ratio)
        img_cropped = img[:, :, i:i + d, j:j + h, k:k + w]
        # align_corners is only passed as a boolean for trilinear mode.
        img_resized = F.interpolate(img_cropped, size=self.size, mode=self.interpolation,
                                    align_corners=False if self.interpolation == 'trilinear' else None)
        return img_resized

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, scale={self.scale}, ratio={self.ratio}, interpolation={self.interpolation})"
117
+
118
def random_permute(X, select_dims=[-1,-2],include_flip=True):
    """Apply one shared random axis permutation (and random flips) to every tensor in X.

    The axes named in ``select_dims`` are shuffled among themselves; all other
    axes stay in place. For each selected dim a coin is tossed (only when
    ``include_flip`` is true) and, on heads, every tensor is flipped along
    that dim before the final permutation.

    NOTE(review): the flip is placed inside the per-dim loop; the listing this
    was recovered from had lost its indentation -- confirm.

    Args:
        X: Sequence of torch tensors; the rank is taken from ``X[0]``.
        select_dims: Dims that may be exchanged with each other.
        include_flip: Whether random flips are applied.

    Returns:
        List of permuted (and possibly flipped) tensors, in input order.
    """
    order = list(range(X[0].ndim))
    # Resolve the (possibly negative) dims to axis numbers, then shuffle them.
    shuffled = [order[d] for d in select_dims]
    random.shuffle(shuffled)
    for dim, new_axis in zip(select_dims, shuffled):
        order[dim] = new_axis
        if include_flip and random.choice([True,False]):
            X = [torch.flip(x, [dim]) for x in X]
    return [x.permute(order) for x in X]
129
+
130
+ # def thresh_img(img,thresh = None,EPS = 10**-7):
131
+ # threshold0 = np.random.uniform(thresh[0], thresh[1])
132
+ # threshold1 = np.random.uniform(thresh[0], thresh[1])
133
+ # scale =
134
+ # if threshold is not None:
135
+ # # img=img-threshold
136
+ # # img=np.where(img>=0,img,0)
137
+ # # img = np.maximum(img-threshold,0)
138
+ # img = torch.maximum(img - threshold,torch.tensor(0.))
139
+ # # return (img - img.min()) / (img.max() - img.min() + EPS)
140
+ # return img
141
+
142
def get_transformer(degrees=180,translate=0.125,ndims=2,prob=0.8,fill=0.,img_sz=None):
    """Build the random augmentation pipeline.

    Args:
        degrees: Rotation range passed to ``RandomAffine``.
        translate: Translation fraction, repeated ``ndims`` times.
        ndims: Number of entries in the translation list.
        prob: Probability of applying the affine transform.
        fill: Fill value for areas outside the transformed image.
        img_sz: Target image size. ``None`` or a length-2 size selects the
            2-D pipeline; any other length selects the 3-D crop pipeline.

    Returns:
        A ``torchvision.transforms.Compose``:
        * 2-D: random affine (applied with probability ``prob``), random
          vertical flip and random autocontrast (each with p=0.5);
        * otherwise: ``RandomResizedCrop3D(size=img_sz)`` applied with
          probability 0.8.
    """
    # `is None` rather than `== None`: the latter is element-wise (and then
    # ambiguous in a boolean context) when img_sz is an array.
    if img_sz is None or len(img_sz) == 2:
        return torchvision.transforms.Compose([
            torchvision.transforms.RandomApply([
                torchvision.transforms.RandomAffine(degrees=degrees, translate=[translate] * ndims, fill=fill,
                                                    interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
            ], prob),
            torchvision.transforms.RandomVerticalFlip(p=0.5),
            torchvision.transforms.RandomAutocontrast(p=0.5),
        ])

    # Reaching this point means img_sz is set and is not 2-D, so the crop
    # probability is always 0.8 and the crop is always the 3-D one (the old
    # `len(img_sz) == 2` alternative here could never be selected).
    prob_crop = 0.8
    return torchvision.transforms.Compose([
        torchvision.transforms.RandomApply([
            RandomResizedCrop3D(size=img_sz),
        ], prob_crop),
    ])
165
+
166
+
167
def get_random_affine_transformer(degrees=180,translate=0.125,ndims=2):
    """Return a bilinear ``RandomAffine`` with the same translation fraction on ``ndims`` axes."""
    per_axis_translate = [translate] * ndims
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    return torchvision.transforms.RandomAffine(
        degrees=degrees,
        translate=per_axis_translate,
        interpolation=bilinear,
    )
169
+
170
def channel_merge_acdc(img):
    """Sum all channels of a (C, H, W) image into a single (H, W) map.

    The accumulation starts from a float64 NumPy zero array of shape
    ``(img.shape[1], img.shape[2])`` and adds ``img[0]``, ``img[1]``, ... in order.
    """
    start = np.zeros((img.shape[1], img.shape[2]))
    # Built-in sum() with a start value performs the same left-to-right fold.
    return sum((img[idx] for idx in range(img.shape[0])), start)
178
+
179
def img_crop(img, crop_rate=2, img_sz=[256,256]):
    """Cut a randomly positioned window out of the trailing axes of ``img``.

    Along every axis listed in ``img_sz`` a total of ``extent // crop_rate``
    elements is removed; a random share of that is removed at the start and
    the remainder at the end, so the window length is
    ``extent - extent // crop_rate``.

    Args:
        img: Array or tensor whose last ``len(img_sz)`` axes are cropped.
        crop_rate: Divisor that determines how much is removed per axis.
        img_sz: Extent of each trailing axis to crop. Any number of axes is
            accepted (previously only 2 or 3). Never modified.

    Returns:
        The cropped view of ``img``.
    """
    window = []
    for extent in img_sz:
        # One random draw per axis, in axis order, with integer bounds.
        lead_cut = np.random.randint(0, extent) // crop_rate
        tail_cut = extent // crop_rate - lead_cut
        window.append(slice(lead_cut, extent - tail_cut))
    return img[(Ellipsis, *window)]
187
+
188
+
189
def boundary_limit(sample_coords0, max_sz, plus=0., minus=1.):
    """Scale each coordinate channel by its size and clamp it to a symmetric range.

    Channel ``c`` (taken along dim 1) is multiplied by ``max_sz[c]`` and
    clamped to ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]``.
    Channels beyond ``len(max_sz)`` are dropped, because the channels are
    paired with ``max_sz`` via ``zip``.

    Returns:
        Tensor with the processed channels concatenated along dim 1.
    """
    channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
    limited = []
    for channel, extent in zip(channels, max_sz):
        lower = minus - 1 * extent + plus
        upper = 1 * extent - minus + plus
        limited.append(torch.clamp(channel * extent, min=lower, max=upper))
    return torch.cat(limited, 1)
194
+
195
+
196
def resample(vol, ddf, ref=None, img_sz=None,max_sz=[128,128],ndims=2):
    """Warp ``vol`` with a dense displacement field given in voxel units.

    A voxel-index reference grid is built from the spatial size of ``vol``,
    ``ddf`` is added to it, the result is normalised to the [-1, 1] range
    used by ``F.grid_sample`` (``align_corners=True``), and the coordinate
    order is reversed to match ``grid_sample``'s convention. Sampling uses
    'bilinear' mode with zero padding outside the volume.

    Args:
        vol: Tensor of shape (N, C, *spatial) with 2 or 3 spatial dims.
        ddf: Displacement field laid out channels-last, i.e. broadcastable
            to (N, *spatial, ndims); component ``i`` displaces along
            spatial axis ``i``.
        ref, img_sz, max_sz, ndims: Ignored; the spatial size and rank are
            always taken from ``vol``. Kept for signature compatibility.

    Returns:
        The resampled tensor, same shape as ``vol``.

    Raises:
        ValueError: If ``vol`` does not have 2 or 3 spatial dimensions.
    """
    device = vol.device
    spatial = list(vol.size()[2:])
    ndims = len(spatial)
    if ndims not in (2, 3):
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(f"resample supports 2 or 3 spatial dims, got {ndims}")

    # Half-extent (size - 1) / 2 per axis, shaped to broadcast over the
    # channels-last grid.
    half_extent = torch.reshape(
        torch.tensor([(s - 1) / 2. for s in spatial], device=device),
        [1] * (ndims + 1) + [ndims])
    # Explicit 'ij' indexing is what the old implicit default did; passing it
    # avoids the torch.meshgrid deprecation warning.
    ref_grid = torch.reshape(
        torch.stack(torch.meshgrid([torch.arange(end=s) for s in spatial], indexing='ij'), 0),
        [1, ndims] + spatial)

    # (1, ndims, *spatial) -> (1, *spatial, ndims)
    channels_last = [0] + list(range(2, ndims + 2)) + [1]
    grid = (ddf + ref_grid.permute(channels_last).to(device)) / half_extent - 1
    # grid_sample expects the fastest-varying axis first, hence the flip.
    grid = torch.flip(grid, dims=[-1]).type(torch.float32).to(device)
    return F.grid_sample(vol, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
229
+
230
def random_resample(vol,deform_scale=32.):
    """Warp ``vol`` with a randomly generated displacement field.

    The field is the third value returned by ``random_ddf``. For 2-D input a
    dummy third axis of length 16 is appended to the size before generating
    the field, and slice 8 of that axis is then used.

    Args:
        vol: Tensor of shape (N, C, *spatial).
        deform_scale: Forwarded to ``random_ddf`` as ``range_gauss``.

    Returns:
        The result of ``resample(vol, field)``.
    """
    full_shape = vol.size()
    ndims = len(full_shape) - 2
    is_2d = ndims == 2
    spatial = [s for s in full_shape[2:]]
    if is_2d:
        spatial = spatial + [16]
    _, _, field = random_ddf(full_shape[0], spatial, ndims=ndims, range_gauss=deform_scale)
    field = Variable(torch.tensor(field, dtype=torch.float32)).to(vol.device)
    if is_2d:
        # Take one slice of the dummy axis and keep the first two components.
        return resample(vol, field[..., 8, :ndims])
    return resample(vol, field[..., :ndims])
244
+
245
def get_random_deformed_mask(msk_shape, deform_scale=32.,apply_possibility=0.75):
    """Return an all-ones mask of shape ``[1, 1, *msk_shape]``, randomly warped.

    A uniform draw from [0, 1] decides whether the mask is passed through
    ``random_resample``; otherwise the untouched float32 ones tensor is returned.

    Args:
        msk_shape: Spatial shape of the mask.
        deform_scale: Forwarded to ``random_resample``.
        apply_possibility: Threshold the uniform draw is compared against.
    """
    msk = torch.ones([1, 1] + list(msk_shape), dtype=torch.float32)
    if random.uniform(0, 1) >= apply_possibility:
        return msk
    return random_resample(msk, deform_scale=deform_scale)
251
+
252
+ # grid option
253
def get_tranf_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]],transl=[[0,0,0]]):
    """Concatenate the flattened rotation from ``get_rot_mat`` with ``transl`` along the last axis."""
    rot_flat = get_rot_mat(grid_size, vec=vec, ang=ang)
    return np.concatenate([rot_flat, transl], -1)
255
+
256
+
257
def get_rot_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]],ndims=3):
    """Return one flattened rotation matrix per (axis, angle) pair.

    Args:
        grid_size: Not used by this function.
        vec: Rotation axes, one per batch entry.
        ang: Rotation angles; the batch size is ``ang.shape[0]``.
        ndims: Side length of each matrix before flattening.

    Returns:
        Array of shape (batch, ndims * ndims).
    """
    axes = np.array(vec)
    angles = np.array(ang)
    rotations = vecang2rotmats(axes, angles)
    return np.reshape(rotations, [angles.shape[0], ndims * ndims])
262
+
263
def random_mat(batch_sz, img_sz, num_class=2,pn_spline=20, pn_gauss=10, range_spline=2., range_gauss=48, spread_range=[5., 24.],
               transl_range=32., rot_range=np.pi / 2):
    """Draw random rotation-plus-translation parameters per batch entry and class.

    For every batch entry a shared axis, angle and translation are drawn, and
    a smaller per-class perturbation is added to each (axis: +-0.1; angle and
    translation: a quarter of the respective range).

    ``pn_spline``, ``pn_gauss``, ``range_spline``, ``range_gauss`` and
    ``spread_range`` are not used here.

    Returns:
        Array of shape (batch_sz, num_class, 4, 3) obtained by reshaping the
        9 flattened rotation entries followed by the 3 translation entries.
    """
    scale = 4
    ndims = 3
    n_total = batch_sz * num_class

    # Draw order is kept: shared part first, per-class jitter second.
    shared_axis = np.random.uniform(-1., 1., [batch_sz, 1, ndims])
    axis_jitter = np.random.uniform(-.1, .1, [batch_sz, num_class, ndims])
    vec = np.reshape(shared_axis + axis_jitter, [n_total, ndims])

    shared_ang = np.random.uniform(-rot_range, rot_range, [batch_sz, 1])
    ang_jitter = np.random.uniform(-rot_range / scale, rot_range / scale, [batch_sz, num_class])
    ang = np.reshape(shared_ang + ang_jitter, [n_total])

    shared_transl = np.random.uniform(-transl_range, transl_range, [batch_sz, 1, ndims])
    transl_jitter = np.random.uniform(-transl_range / scale, transl_range / scale, [batch_sz, num_class, ndims])
    transl = np.reshape(shared_transl + transl_jitter, [n_total, ndims])

    params = np.concatenate([get_rot_mat(img_sz, vec=vec, ang=ang), transl], -1)
    return np.reshape(params, [batch_sz, num_class, 4, 3])
271
+
272
+ # return np.reshape(get_tranf_mat(img_sz, vec=np.random.uniform(-1., 1., [batch_sz*num_class, 3]), ang=np.random.uniform(-rot_range, rot_range, [batch_sz*num_class]),transl=np.random.uniform(-transl_range, transl_range, [batch_sz*num_class,3])),[batch_sz,num_class,4,3])
273
+
274
def random_ddf(batch_sz, img_sz, pn_spline=20, pn_gauss=10, range_spline=1., range_gauss=16., spread_range=[16., 64.],
               transl_range=0., rot_range=np.pi / 1,ndims=3):
    """Generate a random dense displacement field (rotation plus optional smooth part).

    A random rotation field is built per batch entry (random axis for
    ``ndims == 3``, fixed axis (0, 0, 1) otherwise). If ``range_gauss > 0``, a
    single Gaussian-smoothed random field plus a random global shift is
    generated and the same copy is added to every batch entry. The target
    coordinates are then clipped to the volume extended by 5 voxels on each side.

    ``pn_spline`` and ``range_spline`` are not used.

    Returns:
        Tuple of
        * the clipped displacement field,
        * an indicator (trailing singleton axis) that is non-zero where the
          unclipped target coordinate lies inside the volume,
        * the rotation-only displacement field.
    """
    padding = 5
    rand_ang = np.random.uniform(-rot_range, rot_range, [batch_sz])

    if ndims == 3:
        rot_axis = np.random.uniform(-1., 1., [batch_sz, 3])
    else:
        rot_axis = np.concatenate([np.zeros([batch_sz, 2]), np.ones([batch_sz, 1])], -1)
    rot_df = get_rot_ddf(img_sz, vec=rot_axis, ang=rand_ang)
    # The field itself is always 3-component from here on.
    ndims = 3

    if range_gauss > 0:
        smooth = generate_random_gaussian_ddf(img_sz, pn_gauss, range_sz=range_gauss, spread_std=spread_range)
        shift = np.random.uniform(-transl_range, transl_range, [ndims])
        ddf0 = np.tile([smooth + shift], [batch_sz, 1, 1, 1, 1]) + rot_df
    else:
        ddf0 = rot_df

    ref = get_reference_grid(img_sz)
    target = ddf0 + ref
    n_axes = len(img_sz)
    clipped = np.stack(
        [np.maximum(np.minimum(target[..., i], img_sz[i] - 1 + padding), 0 - padding) for i in range(n_axes)],
        axis=-1)
    inside = np.prod(
        [((target[..., i] < img_sz[i]) * (target[..., i] >= 0)) for i in range(n_axes)],
        axis=0)
    return clipped - ref, np.expand_dims(inside, -1), rot_df
309
+
310
+
311
def generate_random_gaussian_ddf(img_sz, pn=30, range_sz=5, spread_std=[0.1, 1.]):
    """Create a smooth random 3-D displacement field from a few random control points.

    ``pn`` voxel positions are drawn (at least ``range_sz / 2`` from the
    volume border), each receives a random 3-vector in
    ``[-range_sz, range_sz]``, and the sparse volume is smoothed with a
    Gaussian whose sigma is drawn from ``spread_std``.

    NOTE(review): the Gaussian filter is applied to the full 4-D array, so it
    also smooths across the 3-component axis -- confirm this is intended.

    Returns:
        Array of shape (img_sz[0], img_sz[1], img_sz[2], 3).
    """
    margin = range_sz / 2.
    # Draw order (axis 0, 1, 2, then vectors, then sigma) is preserved.
    centres = [
        np.floor(np.random.uniform(margin, img_sz[axis] - margin, [1, pn])).astype('int')
        for axis in range(3)
    ]
    vectors = np.random.uniform(-range_sz, range_sz, [pn, 3])
    sparse = np.zeros([img_sz[0], img_sz[1], img_sz[2], 3])
    sparse[centres[0], centres[1], centres[2]] = vectors
    sigma = np.random.uniform(spread_std[0], spread_std[1])
    return spimg.gaussian_filter(sparse, sigma)
321
+
322
+
323
def get_rot_ddf(grid_size, vec=[[0., 0., 1.]], ang=[[0.]]):
    """Return the displacement field of a rotation about the grid centre.

    The centred reference grid is flattened to row vectors, right-multiplied
    by one rotation matrix per batch entry, reshaped back, and the centred
    grid is subtracted.

    Args:
        grid_size: List with the three grid extents (must be a list, as it is
            concatenated into the output shape).
        vec: Rotation axes, one per batch entry.
        ang: Rotation angles; the batch size is ``ang.shape[0]``.

    Returns:
        Array of shape ``[batch] + grid_size + [3]``.
    """
    axes = np.array(vec)
    angles = np.array(ang)
    batch_num = angles.shape[0]
    centred = get_reference_grid(grid_size, bias_scale=1.)
    points = np.reshape(np.tile(centred, [batch_num, 1, 1, 1, 1]), [batch_num, -1, 3])
    rotated = np.matmul(points, vecang2rotmats(axes, angles))
    return np.reshape(rotated, [batch_num] + grid_size + [3]) - centred
332
+
333
+
334
def get_reference_grid(grid_size, bias_scale=0.):
    """Return the voxel-index grid of a 3-D volume as floats.

    Args:
        grid_size: The three grid extents.
        bias_scale: Fraction of the half-extent ``(size - 1) / 2`` subtracted
            from every coordinate; 1.0 centres the grid on the origin.

    Returns:
        Array of shape (grid_size[0], grid_size[1], grid_size[2], 3) whose
        last axis holds the (i, j, k) index of each voxel, minus the bias.
    """
    axis_ticks = [np.arange(grid_size[axis]) for axis in range(3)]
    grid = np.stack(np.meshgrid(*axis_ticks, indexing='ij'), axis=-1).astype('float')
    centre_offset = bias_scale * (np.array(grid_size) - 1) / 2.
    return grid - centre_offset
340
+
341
+
342
def resample_linear(inputs, ddf=None, sample_coords=None,random_boundary=True):
    """Trilinearly resample a batched NumPy volume at the given coordinates.

    Args:
        inputs: Array indexed as (batch, x, y, z, ...). When
            ``random_boundary`` is true it is MODIFIED IN PLACE: the six
            boundary faces are blended towards the global minimum with one
            random factor.
        ddf: Displacement added to the reference grid; only used when
            ``sample_coords`` is None.
        sample_coords: Absolute sampling coordinates whose last axis holds
            the three coordinates.
        random_boundary: Whether to apply the boundary blending.

    Returns:
        The interpolated samples. Coordinates are clamped to the volume, so
        out-of-range positions take the value at the nearest border.
    """
    if random_boundary:
        random_factor = np.random.uniform(0., 1.)
        min_val = np.min(inputs)
        # First and last slice of each of the three spatial axes, in order.
        for axis in (1, 2, 3):
            for edge in (0, -1):
                face = [slice(None)] * 4
                face[axis] = edge
                face = tuple(face)
                inputs[face] = min_val * random_factor + (1 - random_factor) * inputs[face]

    input_size = inputs.shape[1:4]
    if sample_coords is None:
        sample_coords = get_reference_grid(input_size) + ddf

    raw = [sample_coords[..., i] for i in range(sample_coords.shape[-1])]
    floors = [np.floor(c) for c in raw]

    def clamp(values, low, high):
        return np.maximum(np.minimum(values, high), low)

    # Continuous coordinates limited to [0, size - 1].
    coords = [clamp(c.astype('float32'), 0 + 0., input_size[i] - 1 + 0.) for i, c in enumerate(raw)]
    # Lower corner limited to [0, size - 2], upper corner to [1, size - 1].
    lower = [clamp(f.astype('int32'), 0, input_size[i] - 2) for i, f in enumerate(floors)]
    upper = [clamp((f + 1).astype('int32'), 1, input_size[i] - 1) for i, f in enumerate(floors)]

    # Weight of the upper corner and of the lower corner along each axis.
    weight = [np.expand_dims(c - lo.astype('float32'), -1) for c, lo in zip(coords, lower)]
    weight_c = [np.expand_dims(hi.astype('float32') - c, -1) for c, hi in zip(coords, upper)]

    sz = list(lower[0].shape)
    batch_coords = np.tile(np.reshape(range(sz[0]), [sz[0]] + [1] * (len(sz) - 1)), [1] + sz[1:])

    # The eight cell corners, ordered like 3-bit binary codes 000 ... 111
    # (first axis is the most significant bit).
    corners = (lower, upper)
    level = [
        inputs[batch_coords, corners[a][0], corners[b][1], corners[c][2], ...]
        for a in (0, 1) for b in (0, 1) for c in (0, 1)
    ]

    # Collapse one axis per pass, first axis first: the first half of the
    # list holds the lower corner of that axis, the second half the upper.
    for dim in range(len(weight)):
        half = len(level) // 2
        level = [level[k] * weight_c[dim] + level[k + half] * weight[dim] for k in range(half)]
    return level[0]
392
+
393
+
394
def vecang2rotmats(vec, ang):
    """Stack one 3x3 rotation matrix per (axis, angle) pair into a (len(vec), 3, 3) array."""
    matrices = []
    for idx in range(len(vec)):
        single = vecang2rotmat(vec[idx, ...], ang[idx, ...])
        matrices.append(np.reshape(single, [3, 3]))
    return np.stack(matrices, 0)
396
+
397
+
398
def vecang2rotmat(vec, ang):
    """Return the ``rotation_matrix`` of the pyquaternion ``Quaternion`` built from axis ``vec`` and angle ``ang``."""
    return quater.Quaternion(axis=vec, angle=ang).rotation_matrix
401
+
402
+
403
def _resolve_device(device=None):
    """Pick the target device for the tensor helpers below.

    Order of precedence: the explicit argument, then a module-level
    ``device`` variable if some other code has set one, else ``None``
    (meaning: leave the tensor where it is).
    """
    if device is not None:
        return device
    # This module never defines `device` itself; the bare global reference
    # used previously raised NameError on every call.
    return globals().get('device')


def _place(tensor, device=None):
    """Move ``tensor`` to the resolved device, or return it unchanged if there is none."""
    target = _resolve_device(device)
    return tensor if target is None else tensor.to(target)


def images_to_vectors(images, device=None):
    """Flatten a batch of images to shape (batch, 16384)."""
    return _place(images.view(images.size(0), 16384), device)


def vectors_to_images(vectors, device=None):
    """Reshape a batch of flat vectors to shape (batch, 1, 128, 128)."""
    return _place(vectors.view(vectors.size(0), 1, 128, 128), device)


def noise(size, device=None):
    """Return a (size, 100) tensor of standard-normal samples."""
    return _place(torch.randn(size, 100), device)


def ones_target(size, device=None):
    """Return a (size, 1) tensor of ones."""
    return _place(torch.ones(size, 1), device)


def zeros_target(size, device=None):
    """Return a (size, 1) tensor of zeros."""
    return _place(torch.zeros(size, 1), device)
420
+
421
+
422
def eval_detJ_lab(disp=None,vol1=None,vol2=None,thresh=0.5):
    """Count the voxels where the deformation has a negative Jacobian determinant.

    Args:
        disp: NumPy displacement field of shape (N, ndims, *spatial).
        vol1: Optional volume used to build a mask. When given together with
            ``thresh``, the mask is ``vol1 > thresh`` multiplied by
            ``laplace(mask) < 0.1`` and then subsampled by a factor of 2
            along the last three axes.
            NOTE(review): the subsampled mask must broadcast against the
            determinant map computed from ``disp`` -- confirm the resolutions
            callers pass.
        vol2: Unused.
        thresh: Threshold for ``vol1``; ``None`` disables the mask.

    Returns:
        Sum over ``(det(J) < 0) * mask`` where ``J = I + grad(disp)``.
    """
    ndims = disp.ndim - 2
    # Identity test instead of `== None`: comparing an array with `==` is
    # element-wise and made the `or` raise ValueError.
    if vol1 is None or thresh is None:
        label = 1
    else:
        label = vol1 > thresh
        label = label * (spimg.laplace(label) < 0.1)
        rescale_factor = 2
        label = label[..., ::rescale_factor, ::rescale_factor, ::rescale_factor]

    # (N, ndims, *spatial) -> (N, *spatial, ndims)
    disp = np.transpose(disp, [0, *range(2, ndims + 2), 1])
    # Last two axes: [displacement component, derivative direction].
    Jacob = np.stack(np.gradient(disp, axis=[*range(1, ndims + 1)]), -1)
    # Add the identity so the matrix describes the full map x + disp(x).
    for ii in range(ndims):
        Jacob[..., ii, ii] = Jacob[..., ii, ii] + 1
    return np.sum((np.linalg.det(Jacob) < 0) * label)
443
+
444
def eval_def_mag(disp=None,vol1=None,vol2=None,thresh=0.5):
    """Return the average and the peak magnitude of a displacement field.

    Args:
        disp: NumPy array with the vector components on axis 1 and the batch
            on axis 0.
        vol1, vol2, thresh: Unused.

    Returns:
        ``[avg_mag, max_mag]``: the mean magnitude over everything, and the
        per-sample maximum magnitude averaged over the batch.
    """
    magnitude = np.sqrt(np.sum(np.square(disp), axis=1))
    per_sample = np.reshape(magnitude, [magnitude.shape[0], -1])
    peak = np.mean(np.max(per_sample, axis=-1))
    average = np.mean(magnitude)
    return [average, peak]
458
+
459
+
460
def converet_to_nibabel(vol_tensor,ndims=3):
    """Wrap a single-sample volume in a ``nib.Nifti1Image``.

    Args:
        vol_tensor: NumPy array, or an object with ``.cpu().numpy()``, whose
            leading axis is a batch axis of length 1.
        ndims: 3 uses an identity affine; 2 uses an identity affine with
            entry [2, 2] set to 0. Other values leave the affine undefined.

    Returns:
        ``nib.Nifti1Image``. A single-channel input is stored without the
        channel axis; a multi-channel input is stored with the channel axis
        moved last (axes reordered as [1, 2, 3, 0]).
    """
    vol_np = vol_tensor if isinstance(vol_tensor, np.ndarray) else vol_tensor.cpu().numpy()
    vol_np = vol_np.squeeze(0)

    if ndims == 3:
        map_eyes = np.eye(4)
    elif ndims == 2:
        map_eyes = np.eye(4)
        map_eyes[2, 2] = 0

    n_channels = vol_np.shape[0]
    if n_channels == 1:
        vol_np = vol_np.squeeze(0)
    elif n_channels > 1:
        # Channels go last so the result is saved as a 4-D volume.
        vol_np = np.transpose(vol_np, [1, 2, 3, 0])

    return nib.Nifti1Image(vol_np, affine=map_eyes)
480
+
481
def print_memory_usage(tag=""):
    """Print the CUDA memory currently allocated and reserved, in GB, prefixed by ``tag``."""
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] Allocated: {allocated_gb:.2f} GB | Cached: {reserved_gb:.2f} GB")
483
+
484
+
485
if __name__ == "__main__":
    # Manual smoke test: warp a random 2-D batch, crop it, and print a
    # randomly deformed mask.
    vol_shape=[4,1,64,64]

    vol=np.random.uniform(-1,1,vol_shape)
    vol=Variable(torch.tensor(vol,dtype=torch.float32))
    vol_res=random_resample(vol)
    # NOTE: img_crop is called with its default img_sz=[256, 256] although
    # the volume here is 64x64.
    vol_crop=img_crop(vol_res)

    mask = get_random_deformed_mask(vol.shape[2:])

    print(mask)
496
+
497
+ # print(vol.tolist())
498
+ # print(vol_res.tolist())