Initial upload: OmniMorph codebase
Browse files- .gitattributes +13 -0
- .gitignore +29 -0
- Config/config_cmr.yaml +29 -0
- Config/config_lct.yaml +31 -0
- Config/config_om.yaml +53 -0
- Config/config_om_contrastive.yaml +51 -0
- Dataloader/PSMA-CT_mappings.json +3 -0
- Dataloader/bert_helper.py +258 -0
- Dataloader/dataLoader.py +1473 -0
- Dataloader/dataloader0.py +421 -0
- Dataloader/dataloader_tester.py +39 -0
- Dataloader/dataloader_utils.py +193 -0
- Dataloader/embding_gen.py +149 -0
- Dataloader/nifty_mappings/AbdomenAtlas_mappings.json +3 -0
- Dataloader/nifty_mappings/AbdomenCT1k_mappings.json +3 -0
- Dataloader/nifty_mappings/Brats2019_mappings.json +3 -0
- Dataloader/nifty_mappings/Brats2020_mappings.json +3 -0
- Dataloader/nifty_mappings/Brats2021_mappings.json +3 -0
- Dataloader/nifty_mappings/CIA_mappings.json +3 -0
- Dataloader/nifty_mappings/Kaggle_osic_mappings.json +0 -0
- Dataloader/nifty_mappings/MSD_mappings.json +3 -0
- Dataloader/nifty_mappings/MnMs_mappings.json +0 -0
- Dataloader/nifty_mappings/OASIS_1_mappings.json +3 -0
- Dataloader/nifty_mappings/OASIS_2_mappings.json +3 -0
- Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json +3 -0
- Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json +3 -0
- Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json +3 -0
- Diffusion/__init__.py +8 -0
- Diffusion/diffuser.py +531 -0
- Diffusion/losses.py +534 -0
- Diffusion/losses_ncc0.py +496 -0
- Diffusion/networks.py +1167 -0
- Diffusion/utils_diff.py +477 -0
- LICENSE +201 -0
- OM_aug.py +254 -0
- OM_aug_highres.py +233 -0
- OM_contrastive.py +72 -0
- OM_reg.py +240 -0
- OM_train.py +309 -0
- OM_train_2modes.py +528 -0
- OM_train_3modes.py +490 -0
- OM_train_uncon.py +258 -0
- README.md +11 -0
- bash_infer.sh +9 -0
- bash_train.sh +12 -0
- dataloader_tester.py +65 -0
- requirements.txt +57 -0
- utils.py +498 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Dataloader/PSMA-CT_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Dataloader/nifty_mappings/AbdomenAtlas_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
Dataloader/nifty_mappings/AbdomenCT1k_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Dataloader/nifty_mappings/Brats2019_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
Dataloader/nifty_mappings/Brats2020_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
Dataloader/nifty_mappings/Brats2021_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
Dataloader/nifty_mappings/CIA_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
Dataloader/nifty_mappings/MSD_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
Dataloader/nifty_mappings/OASIS_1_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
Dataloader/nifty_mappings/OASIS_2_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model checkpoints
|
| 2 |
+
Models/
|
| 3 |
+
|
| 4 |
+
# Data files
|
| 5 |
+
Data/
|
| 6 |
+
|
| 7 |
+
# Python cache
|
| 8 |
+
__pycache__/
|
| 9 |
+
|
| 10 |
+
# Virtual environment
|
| 11 |
+
ominenv/
|
| 12 |
+
|
| 13 |
+
# External libraries
|
| 14 |
+
External/
|
| 15 |
+
|
| 16 |
+
# Logs
|
| 17 |
+
Log/
|
| 18 |
+
swanlog/
|
| 19 |
+
train_log.txt
|
| 20 |
+
aug_log.txt
|
| 21 |
+
|
| 22 |
+
# Reference implementation
|
| 23 |
+
def_diff_rec/
|
| 24 |
+
|
| 25 |
+
# IDE
|
| 26 |
+
.vscode/
|
| 27 |
+
|
| 28 |
+
# Misc
|
| 29 |
+
CLAUDE.md
|
Config/config_cmr.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_name: cmr
|
| 2 |
+
net_name: recresacnet
|
| 3 |
+
ndims: 2
|
| 4 |
+
img_size: 256
|
| 5 |
+
batchsize: 1
|
| 6 |
+
ddf_pad_mode: border
|
| 7 |
+
device: cuda
|
| 8 |
+
img_pad_mode: zeros
|
| 9 |
+
num_input_chn: 1
|
| 10 |
+
padding_mode: zeros
|
| 11 |
+
resample_mode: bicubic
|
| 12 |
+
timesteps: 80
|
| 13 |
+
v_scale: 4.0e-05
|
| 14 |
+
# =========================
|
| 15 |
+
# TRAINING SETTING
|
| 16 |
+
epoch: 10000
|
| 17 |
+
epoch_per_save: 1
|
| 18 |
+
lr: 0.0001
|
| 19 |
+
noise_scale: 0.1
|
| 20 |
+
# =========================
|
| 21 |
+
# AUGMENTATION SETTING
|
| 22 |
+
patients_list: []
|
| 23 |
+
model_id_str: '000000'
|
| 24 |
+
start_noise_step: 48
|
| 25 |
+
noise_step: 2
|
| 26 |
+
aug_coe: 32 # how many times each sample will be augmented
|
| 27 |
+
aug_img_savepath: Data/Aug_data/cmr/img/
|
| 28 |
+
aug_msk_savepath: Data/Aug_data/cmr/msk/
|
| 29 |
+
aug_ddf_savepath: Data/Aug_data/cmr/ddf/
|
Config/config_lct.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_name: lct
|
| 2 |
+
net_name: recmutattnnet
|
| 3 |
+
# net_name: recresacnet
|
| 4 |
+
ndims: 3
|
| 5 |
+
img_size: 128 #was 128
|
| 6 |
+
batchsize: 2
|
| 7 |
+
ddf_pad_mode: border
|
| 8 |
+
device: cuda
|
| 9 |
+
img_pad_mode: zeros
|
| 10 |
+
num_input_chn: 1
|
| 11 |
+
padding_mode: border
|
| 12 |
+
resample_mode: bilinear
|
| 13 |
+
timesteps: 80
|
| 14 |
+
v_scale: 4.0e-05
|
| 15 |
+
# =========================
|
| 16 |
+
# TRAINING SETTING
|
| 17 |
+
epoch: 10000
|
| 18 |
+
epoch_per_save: 1
|
| 19 |
+
lr: 0.00001
|
| 20 |
+
noise_scale: 0.1
|
| 21 |
+
# =========================
|
| 22 |
+
# AUGMENTATION SETTING
|
| 23 |
+
patients_list: []
|
| 24 |
+
model_id_str: '001157'
|
| 25 |
+
start_noise_step: 64
|
| 26 |
+
noise_step: 1
|
| 27 |
+
aug_coe: 32 # how many times each sample will be augmented
|
| 28 |
+
condition_type: 'project' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
|
| 29 |
+
aug_img_savepath: Data/Aug_data/lct/img/
|
| 30 |
+
aug_msk_savepath: Data/Aug_data/lct/msk/
|
| 31 |
+
aug_ddf_savepath: Data/Aug_data/lct/ddf/
|
Config/config_om.yaml
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_name: all
|
| 2 |
+
# net_name: recresacnet
|
| 3 |
+
net_name: recmutattnnet
|
| 4 |
+
# net_name: recmutattnnet1
|
| 5 |
+
# net_name: defrecmutattnnet
|
| 6 |
+
ndims: 3
|
| 7 |
+
img_size: 128
|
| 8 |
+
batchsize: 2
|
| 9 |
+
ddf_pad_mode: border
|
| 10 |
+
device: cuda
|
| 11 |
+
img_pad_mode: zeros
|
| 12 |
+
num_input_chn: 1
|
| 13 |
+
padding_mode: border
|
| 14 |
+
resample_mode: bilinear
|
| 15 |
+
timesteps: 80
|
| 16 |
+
v_scale: 5.0e-05
|
| 17 |
+
# =========================
|
| 18 |
+
# TRAINING SETTING
|
| 19 |
+
epoch: 10000
|
| 20 |
+
epoch_per_save: 1
|
| 21 |
+
lr: 0.00001
|
| 22 |
+
noise_scale: 0.1
|
| 23 |
+
# =========================
|
| 24 |
+
# AUGMENTATION SETTING
|
| 25 |
+
patients_list: []
|
| 26 |
+
# model_id_str: '000000'
|
| 27 |
+
# model_id_str: '000180' # before registration training
|
| 28 |
+
# model_id_str: '000353' # good augmentation results on msd
|
| 29 |
+
model_id_str: '000354' #
|
| 30 |
+
# model_id_str: '000157'
|
| 31 |
+
# model_id_str: '000171'
|
| 32 |
+
start_noise_step: 48 # starting from which noise step to add noise
|
| 33 |
+
noise_step: 1
|
| 34 |
+
aug_coe: 64 # how many times each sample will be augmented
|
| 35 |
+
# start_noise_step: 56 # starting from which noise step to add noise
|
| 36 |
+
# noise_step: 4
|
| 37 |
+
# aug_coe: 4 # how many times each sample will be augmented
|
| 38 |
+
condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
|
| 39 |
+
# aug_img_savepath: Data/Aug_data/totseg/img/
|
| 40 |
+
# aug_msk_savepath: Data/Aug_data/totseg/msk/
|
| 41 |
+
# aug_ddf_savepath: Data/Aug_data/totseg/ddf/
|
| 42 |
+
# aug_img_savepath: Data/Aug_data/om/img/
|
| 43 |
+
# aug_msk_savepath: Data/Aug_data/om/msk/
|
| 44 |
+
# aug_ddf_savepath: Data/Aug_data/om/ddf/
|
| 45 |
+
reg_img_savepath: Data/Reg_data/om/img/
|
| 46 |
+
reg_msk_savepath: Data/Reg_data/om/msk/
|
| 47 |
+
reg_ddf_savepath: Data/Reg_data/om/ddf/
|
| 48 |
+
# aug_img_savepath: Data/Aug_data/msd/img/
|
| 49 |
+
# aug_msk_savepath: Data/Aug_data/msd/msk/
|
| 50 |
+
# aug_ddf_savepath: Data/Aug_data/msd/ddf/
|
| 51 |
+
aug_img_savepath: Data/Aug_data/mnms/img/
|
| 52 |
+
aug_msk_savepath: Data/Aug_data/mnms/msk/
|
| 53 |
+
aug_ddf_savepath: Data/Aug_data/mnms/ddf/
|
Config/config_om_contrastive.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_name: all
|
| 2 |
+
# net_name: recresacnet
|
| 3 |
+
# net_name: recmutattnnet
|
| 4 |
+
net_name: recmutattnnet_contrastive
|
| 5 |
+
# net_name: recmutattnnet1
|
| 6 |
+
# net_name: defrecmutattnnet
|
| 7 |
+
ndims: 3
|
| 8 |
+
img_size: 128
|
| 9 |
+
batchsize: 1 #1 for testing
|
| 10 |
+
ddf_pad_mode: border
|
| 11 |
+
device: cuda
|
| 12 |
+
img_pad_mode: zeros
|
| 13 |
+
num_input_chn: 1
|
| 14 |
+
padding_mode: border
|
| 15 |
+
resample_mode: bilinear
|
| 16 |
+
timesteps: 80
|
| 17 |
+
v_scale: 5.0e-05
|
| 18 |
+
# =========================
|
| 19 |
+
# TRAINING SETTING
|
| 20 |
+
epoch: 10000
|
| 21 |
+
epoch_per_save: 1
|
| 22 |
+
lr: 0.00001
|
| 23 |
+
noise_scale: 0.1
|
| 24 |
+
# =========================
|
| 25 |
+
# AUGMENTATION SETTING
|
| 26 |
+
patients_list: []
|
| 27 |
+
# model_id_str: '000000'
|
| 28 |
+
# model_id_str: '000180' # before registration training
|
| 29 |
+
# model_id_str: '000353' # good augmentation results on msd
|
| 30 |
+
model_id_str: '000354' #
|
| 31 |
+
# model_id_str: '000157'
|
| 32 |
+
# model_id_str: '000171'
|
| 33 |
+
start_noise_step: 48 # starting from which noise step to add noise
|
| 34 |
+
noise_step: 1
|
| 35 |
+
aug_coe: 64 # how many times each sample will be augmented
|
| 36 |
+
# start_noise_step: 56 # starting from which noise step to add noise
|
| 37 |
+
# noise_step: 4
|
| 38 |
+
# aug_coe: 4 # how many times each sample will be augmented
|
| 39 |
+
condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
|
| 40 |
+
# aug_img_savepath: Data/Aug_data/totseg/img/
|
| 41 |
+
# aug_msk_savepath: Data/Aug_data/totseg/msk/
|
| 42 |
+
# aug_ddf_savepath: Data/Aug_data/totseg/ddf/
|
| 43 |
+
# aug_img_savepath: Data/Aug_data/om/img/
|
| 44 |
+
# aug_msk_savepath: Data/Aug_data/om/msk/
|
| 45 |
+
# aug_ddf_savepath: Data/Aug_data/om/ddf/
|
| 46 |
+
reg_img_savepath: Data/Reg_data/om/img/
|
| 47 |
+
reg_msk_savepath: Data/Reg_data/om/msk/
|
| 48 |
+
reg_ddf_savepath: Data/Reg_data/om/ddf/
|
| 49 |
+
aug_img_savepath: Data/Aug_data/msd/img/
|
| 50 |
+
aug_msk_savepath: Data/Aug_data/msd/msk/
|
| 51 |
+
aug_ddf_savepath: Data/Aug_data/msd/ddf/
|
Dataloader/PSMA-CT_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4fbbdc9b4b48688a37c4f828eea2823820a1ee27f954d5987d8cbf3b67d6d9bf
|
| 3 |
+
size 179285490
|
Dataloader/bert_helper.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import math
|
| 4 |
+
from torch.nn import Tanh, BatchNorm1d
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch
|
| 8 |
+
from transformers import BertModel, BertForSequenceClassification
|
| 9 |
+
from transformers import BertTokenizer
|
| 10 |
+
from transformers import AutoTokenizer, AutoModel
|
| 11 |
+
|
| 12 |
+
from torch.utils.data import Dataset as Dataset_n
|
| 13 |
+
from torch.utils.data import DataLoader as DataLoader_n
|
| 14 |
+
from torch.utils.data import WeightedRandomSampler
|
| 15 |
+
|
| 16 |
+
def _freeze_bert(
    bert_model: "BertModel", freeze_bert=True, freeze_layer_count=-1
):
    """Freeze parameters of a BERT-style model in place.

    Args:
        bert_model: Model exposing ``parameters()``, ``embeddings`` and
            ``encoder.layer`` (e.g. a HuggingFace ``BertModel``).
        freeze_bert: If True, every parameter of the model is frozen and
            ``freeze_layer_count`` is ignored.
        freeze_layer_count: Only used when ``freeze_bert`` is False (the
            embeddings are always frozen in that case).
            ``-1``, ``0`` or ``None`` freeze no encoder layer; a positive
            ``n`` freezes the first ``n`` encoder layers; a negative ``n``
            (other than ``-1``) freezes the last ``abs(n)`` encoder layers.

    Returns:
        None. The model is modified in place.
    """
    if freeze_bert:
        # Freeze the entire model.
        for param in bert_model.parameters():
            param.requires_grad = False
        return None

    # Partial freezing: the embeddings are always frozen.
    for param in bert_model.embeddings.parameters():
        param.requires_grad = False

    # ``None`` is passed by get_frozen_embeder; treat it like the ``-1``
    # sentinel instead of failing on ``None > 0``.
    if freeze_layer_count is None or freeze_layer_count in (-1, 0):
        return None

    encoder_layers = bert_model.encoder.layer
    if freeze_layer_count > 0:
        layers_to_freeze = encoder_layers[:freeze_layer_count]
    else:
        layers_to_freeze = encoder_layers[freeze_layer_count:]

    for layer in layers_to_freeze:
        for param in layer.parameters():
            param.requires_grad = False
    return None
|
| 48 |
+
|
| 49 |
+
def get_frozen_embeder(key_word="bert-large-uncased"):
    """Load a pretrained encoder together with its tokenizer and freeze it.

    Args:
        key_word: Model identifier or local path handed to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple; every parameter of ``model`` has
        ``requires_grad`` set to False.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(key_word, do_lower_case=False)
    encoder = AutoModel.from_pretrained(key_word)

    # Whole-model freeze; the layer count is irrelevant in this mode.
    _freeze_bert(encoder, True, None)
    return encoder, text_tokenizer
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def str2emb(string, max_words_num=100, embeder=None, tokenizer=None, reduce_method='mean'):
    """Embed a string with the given tokenizer and encoder.

    The string is lower-cased, tokenized to exactly ``max_words_num`` tokens
    (padded or truncated), and passed through ``embeder``.

    Args:
        string: Text to embed.
        max_words_num: Token length passed to the tokenizer as ``max_length``.
        embeder: Callable taking the tokenizer output as keyword arguments and
            returning an object with a ``last_hidden_state`` tensor.
        tokenizer: Callable producing the keyword arguments for ``embeder``.
        reduce_method: ``'mean'`` or ``'max'`` reduce ``last_hidden_state``
            over dim 1; any other value returns it unreduced.

    Returns:
        The (optionally reduced) ``last_hidden_state`` tensor.
    """
    encoded = tokenizer(
        string.lower(),
        return_tensors='pt',
        max_length=max_words_num,
        padding='max_length',
        truncation=True,
    )
    hidden = embeder(**encoded).last_hidden_state

    if reduce_method == 'mean':
        return torch.mean(hidden, dim=1)
    if reduce_method == 'max':
        # torch.max over a dim returns (values, indices); keep the values.
        return torch.max(hidden, dim=1)[0]
    return hidden
|
| 69 |
+
|
| 70 |
+
def get_synonyms_dict(dict_type=None):
    '''
    Get the dictionary of synonyms for the specified dictionary type

    Args:
        dict_type: One of 'ROI', 'Label_tissue', 'Task' or 'Modality'. Any
            other value (including None) selects the default dictionary,
            whose terms are wrapped in single quotes.

    Returns:
        A dict mapping each standard term to the list of synonyms that
        should be normalised to it. Dict order matters to callers that stop
        at the first match, so longer/more specific regions come first.
    '''
    if dict_type == 'ROI':
        dict_synonyms = {
            'whole-body': ['whole-body', 'whole body', 'wholebody', 'whole body', 'whole-body', 'whole body', 'wholebody','polytrauma','head-neck-thorax-abdomen-pelvis-leg','head-neck-thorax-abdomen-pelvis'],
            'neck-thorax-abdomen-pelvis-leg': ['neck-thorax-abdomen-pelvis-leg','neck-thx-abd-pelvis-leg', 'angiography neck-thx-abd-pelvis-leg', 'neck thorax abdomen pelvis leg', 'neck and thorax and abdomen and pelvis and leg', 'neck, thorax, abdomen, pelvis & leg', 'neck/thorax/abdomen/pelvis/leg', 'neck, thorax, abdomen, pelvis and leg', 'neck thorax abdomen pelvis leg'],
            'neck-thorax-abdomen-pelvis': ['neck-thorax-abdomen-pelvis', 'neck-thx-abd-pelvis', 'neck thorax abdomen pelvis', 'neck and thorax and abdomen and pelvis', 'neck, thorax, abdomen & pelvis', 'neck/thorax/abdomen/pelvis', 'neck, thorax, abdomen and pelvis', 'neck thorax abdomen & pelvis'],
            'thorax-abdomen-pelvis-leg': ['thorax-abdomen-pelvis-leg','thx-abd-pelvis-leg', 'angiography thx-abd-pelvis-leg', 'thorax abdomen pelvis leg', 'thorax and abdomen and pelvis and leg', 'thorax, abdomen, pelvis & leg', 'thorax/abdomen/pelvis/leg', 'thorax, abdomen, pelvis and leg', 'thorax abdomen pelvis leg'],
            'neck-thorax-abdomen': ['neck-thorax-abdomen', 'neck-thorax-abdomen', 'neck thorax abdomen', 'neck and thorax and abdomen', 'neck, thorax, abdomen', 'neck/thorax/abdomen', 'neck, thorax, abdomen', 'neck thorax abdomen'],
            'head-neck-thorax-abdomen': ['head-neck-thorax-abdomen', 'head-neck-thorax-abdomen', 'head neck thorax abdomen', 'head and neck and thorax and abdomen', 'head, neck, thorax, abdomen', 'head/thorax/abdomen', 'head, thorax, abdomen', 'head thorax abdomen'],
            'head-neck-thorax': ['head-neck-thorax', 'head neck thorax', 'head and neck and thorax', 'head, neck, thorax', 'head/thorax', 'head, thorax', 'head thorax'],
            'thorax-abdomen-pelvis': ['thorax-abdomen-pelvis', 'thx-abd-pelvis', 'polytrauma', 'thorax abdomen pelvis', 'thorax and abdomen and pelvis', 'thorax, abdomen & pelvis', 'thorax/abdomen/pelvis', 'thorax, abdomen and pelvis', 'thorax abdomen & pelvis'],
            'abdomen-pelvis-leg': ['abdomen-pelvis-leg', 'angiography abdomen-pelvis-leg', 'abd-pelvis-leg', 'abdomen pelvis leg', 'abdomen and pelvis and leg', 'abdomen, pelvis & leg', 'abdomen/pelvis/leg', 'abdomen, pelvis, leg', 'abdomen pelvis leg'],
            'neck-thorax': ['neck-thorax', 'neck thorax', 'neck and thorax', 'neck, thorax', 'thorax-neck', 'thorax neck', 'thorax and neck', 'thorax, neck','thorax/neck'],
            'thorax-abdomen': ['thorax-abdomen', 'thorax abdomen', 'thorax and abdomen', 'thorax, abdomen', 'aortic valve'],
            'abdomen-pelvis': ['abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis', 'abdomen-pelvis', 'abdomen pelvis', 'abdomen and pelvis', 'abdomen & pelvis', 'abdomen/pelvis'],
            'pelvis-leg': ['pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg', 'pelvis-leg', 'pelvis leg', 'pelvis and leg', 'pelvis, leg', 'pelvis/leg'],
            'head-neck': ['head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck', 'head-neck', 'head neck', 'head and neck', 'head, neck', 'head/neck'],
            'abdomen': ['abdomen', 'abdominal', 'belly', 'stomach', 'tummy', 'gut', 'guts', 'viscera', 'bowels', 'intestines', 'gastrointestinal', 'digestive', 'peritoneum','gastric', 'liver', 'spleen', 'pancreas','kidney','lumbar','renal','hepatic','splenic','pancreatic','intervention'],
            # NOTE(review): 'caeiothoracic' looks like a typo of 'cardiothoracic';
            # kept as-is because it may mirror spellings in the source metadata.
            'thorax': ['chest', 'thorax', 'breast', 'lung', 'heart','heart-thorakale aorta', 'heart-thorakale', 'mediastinum', 'pleura', 'bronchus', 'bronchi', 'trachea', 'esophagus', 'diaphragm', 'rib', 'sternum', 'clavicle', 'scapula', 'axilla', 'armpit','breast biopsy','thoracic','mammary','caeiothoracic','mediastinal','pleural','bronchial','bronchial tree','tracheal','esophageal','diaphragmatic','costal','sternal','clavicular','scapular','axillary','axillar','cardiac','pericardial','pericardiac','pericardium'],
            'head': ['head', 'headbasis', 'brain', 'skull', 'face','nose','ear','eye','mouth','jaw','cheek','chin','forehead','temporal','parietal','occipital','frontal','mandible','maxilla','mandibular','maxillary','nasal','orbital','orbita','ocular','auricular','otic','oral','buccal','labial','lingual','palatal'],
            'neck': ['neck', 'throat', 'cervical', 'thyroid', 'trachea', 'larynx', 'pharynx', 'esophagus','pharyngeal','laryngeal','cervical','thyroid','trachea','esophagus','carotid','jugular'],
            'hand': ['hand', 'finger', 'thumb', 'palm', 'wrist', 'knuckle', 'fingernail', 'phalanx', 'metacarpal', 'carpal', 'radius'],
            # 'armpit' and 'clavicle' were previously fused into the bogus
            # entry 'armpitclavicle' by a missing comma.
            'arm': ['arm', 'forearm', 'upper arm', 'bicep', 'tricep', 'brachium', 'brachial', 'humerus', 'radius', 'ulna', 'elbow', 'shoulder', 'armpit', 'clavicle', 'scapula', 'acromion', 'acromioclavicular'],
            'leg': ['leg', 'felsenleg','thigh', 'calf', 'shin', 'knee', 'foot', 'ankle', 'toe', 'heel', 'sole', 'arch', 'instep', 'metatarsal', 'phalanx', 'tibia', 'fibula', 'femur', 'patella', 'kneecap','achilles tendon','achilles'],
            'pelvis': ['pelvis', 'hip', 'groin', 'buttock', 'gluteus', 'gluteal', 'ischium', 'pubis', 'sacrum', 'coccyx', 'acetabulum', 'iliac', 'iliac crest', 'iliac spine', 'iliac wing', 'sacroiliac', 'sacroiliac joint', 'sacroiliac ligament', 'sacroiliac spine', 'ureter', 'bladder', 'urethra', 'prostate', 'testicle', 'ovary', 'uterus',],
            'skeleton': ['skeleton','bone','spine', 'back', 'vertebra', 'sacrum', 'coccyx'],
        }
    elif dict_type == 'Label_tissue':
        dict_synonyms = {
            'liver': ['liver','hepatic'],
            'spleen': ['spleen','splenic'],
            'kidney': ['kidney','renal'],
            'pancreas': ['pancreas','pancreatic'],
            'stomach': ['stomach','gastric'],
            'intestine': ['large intestine', 'small intestine','large bowel','small bowel'],
            'gallbladder': ['gallbladder'],
            'adrenal_gland': ['adrenal_gland','adrenal gland'],
            'bladder': ['bladder'],
            'prostate': ['prostate'],
            'uterus': ['uterus'],
            'ovary': ['ovary'],
            'testicle': ['testicle'],
            'lymph_node': ['lymph_node','lymph node'],
            'bone': ['bone'],
            'lung': ['lung'],
            'heart': ['heart'],
            'esophagus': ['esophagus'],
            'muscle': ['muscle'],
            'fat': ['fat'],
            'skin': ['skin'],
            'vessel': ['vessel'],
            'tumor': ['tumor'],
            'other': ['other']
        }
    elif dict_type == 'Task':
        dict_synonyms = {
            'segmentation': ['segmentation', 'seg', 'mask'],
            'classification': ['classification', 'class', 'diagnosis','identify','identification'],
            'localization': ['localization', 'locate', 'location', 'position'],
            'registration': ['registration', 'register', 'align', 'alignment'],
            'detection': ['detection', 'detect', 'find', 'locate'],
            'quantification': ['quantification', 'quantify', 'measure', 'measurement'],
        }
    elif dict_type == 'Modality':
        dict_synonyms = {
            'CT': ['CT', 'computed tomography'],
            'MRI': ['MRI', 'MR', 'magnetic resonance imaging'],
            'PET': ['PET', 'positron emission tomography'],
            'US': ['US', 'ultrasound'],
            'X-ray': ['X-ray', 'radiography'],
            # Spelling fixed from 'tomlogy'.
            'SPECT': ['SPECT', 'single-photon emission computed tomography'],
        }
    else:
        # Default dictionary: terms carry their surrounding single quotes.
        dict_synonyms = {
            '\'gender\'': ['\'gender\'', '\'sex\'', '\'M/F\'', '\'m/f\''],
            '\'modality\'': ['\'modality\'', '\'modal\''],
            '\'male\'': ['\'male\'', '\'m\''],
            '\'female\'': ['\'female\'', '\'f\'','\'woman\''],
            '\'high-grade glioma\'': ['\'high-grade glioma\'', '\'high grade glioma\'', '\'HGG\''],
            '\'low-grade glioma\'': ['\'low-grade glioma\'', '\'low grade glioma\'', '\'LGG\''],
            '\'atlas scaling factor\'': ['\'atlas scaling factor\'', '\'asf\''],
            '\'age\'': ['\'age\'', '\'years\'', '\'year\'', '\'y/o\'', '\'y.o.\''],
            '\'education\'': ['\'educ\'', '\'educat\'', '\'education\''],
            '\'roi\'': ['\'roi\'', '\'region of interest\'', '\'region\''],
            '\'mini-mental state examination\'': ['\'mini-mental state examination\'', '\'mmse\''],
            '\'clinical dementia rating\'': ['\'clinical dementia rating\'', '\'cdr\''],
            '\'socio-economic status\'': ['\'socio-economic status\'', '\'ses\''],
            '\'unknown\'': ['\'unknown\'', '\'unkn\'', '\'not available\'', '\'nan\'', '\'n/a\'', '\'none\'', '\'n.a.\'', '\'not applicable\'','\'not specified\'', '\'unspecified\'', '\'not given\'', '\'null\''],
            '': [' segmentation', '\'seg\'', '\'registration\''],
        }
    return dict_synonyms
|
| 164 |
+
|
| 165 |
+
def replace_text(text, dict_synonyms):
    '''
    Replace every synonym occurring in the text with its standard term

    Args:
        text: A string, a list (processed element-wise) or a dict (string
            keys and all values are processed; the dict is updated in place).
            Any other type is returned unchanged.
        dict_synonyms: Mapping of standard term -> list of synonyms.

    Returns:
        The processed text; for dict input this is the same dict object.
        If two keys of a dict normalise to the same term, the later one wins.
    '''
    if isinstance(text, str):
        for key, value in dict_synonyms.items():
            for v in value:
                # NOTE(review): the membership check ignores case but the
                # replacement is case-sensitive, so a synonym that differs
                # only in case is left untouched. Kept as in the original.
                if v.lower() in text.lower():
                    text = text.replace(v, key)
        return text

    if isinstance(text, list):
        return [replace_text(t, dict_synonyms) for t in text]

    if isinstance(text, dict):
        # Snapshot the items first: renaming keys while iterating the live
        # dict raises RuntimeError, and the previous code also tried to use
        # the (unhashable) synonym list as the new key.
        items = list(text.items())
        text.clear()
        for key, value in items:
            new_key = replace_text(key, dict_synonyms) if isinstance(key, str) else key
            text[new_key] = replace_text(value, dict_synonyms)
    return text
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _canonical_term(text, dict_synonyms):
    """Return the first standard term with a synonym contained in text, else None."""
    lowered = text.lower()
    for key, value in dict_synonyms.items():
        if any(v.lower() in lowered for v in value):
            return key
    return None


def replace_synonyms(text, dict_synonyms):
    '''
    Replace the synonyms in the text with the standard term

    Unlike replace_text, a matching string is replaced as a whole by the
    first standard term (in dict order) one of whose synonyms it contains,
    compared case-insensitively.

    Args:
        text: A string, a list (processed element-wise) or a dict (string
            keys and all values are processed; the dict is updated in place).
            Any other type is returned unchanged.
        dict_synonyms: Mapping of standard term -> list of synonyms.

    Returns:
        The standard term for a matching string; an unmatched string is
        returned unchanged after a UserWarning is issued. Lists yield a new
        list; dicts yield the same dict object. If two keys of a dict map to
        the same standard term, the later one wins.
    '''
    import warnings

    if isinstance(text, str):
        match = _canonical_term(text, dict_synonyms)
        # Compare against None explicitly: '' is a valid standard term.
        if match is not None:
            return match
        # The original built a Warning instance without emitting it.
        warnings.warn(f"Value {text} is not in the correct format", stacklevel=2)
        return text

    if isinstance(text, list):
        return [replace_synonyms(t, dict_synonyms) for t in text]

    if isinstance(text, dict):
        # Snapshot the items: the previous code popped keys while iterating
        # and used the unhashable synonym list as the new key.
        items = list(text.items())
        text.clear()
        for key, value in items:
            new_key = key
            if isinstance(key, str):
                match = _canonical_term(key, dict_synonyms)
                if match is not None:
                    new_key = match
            text[new_key] = replace_synonyms(value, dict_synonyms)
    return text
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
    # Manual smoke test: embed two metadata captions with a frozen encoder
    # and print how far apart they are in embedding space.
    #
    # Alternative encoders that can be swapped in:
    # model_name = "bert-base-uncased"
    # model_name = "bert-large-uncased"
    # NOTE(review): machine-specific absolute path; this demo only runs where
    # the checkpoint exists at this location.
    model_name = "/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased"
    # model_name = "Rostlab/prot_bert"
    # model_name = "fspanda/Medical-Bio-BERT2"
    # model_name = "GerMedBERT/medbert-512"

    reduce_method = 'mean'
    max_words_num = 32 # passed to the tokenizer as max_length (pad/truncate length)

    embeder, tokenizer = get_frozen_embeder(model_name)

    # string1 = ["mri", "female"]
    string1 = "modality: ct, gender: female, age: 51, roi: abdomen"
    # string1 = "modality: Magnetic Resonance, gender: female"
    embeder_output1 = str2emb(string1, max_words_num, embeder, tokenizer, reduce_method=reduce_method)

    # string2 = "Hello world!"
    # string2 = ["ct", "male"]
    # string2 = "modality: mri, gender: female, roi: head"
    string2 = "modality: ct, gender: female, age: 50, roi: head"
    # string2 = "modality: ct, gender: male, roi: head"
    embeder_output2 = str2emb(string2, max_words_num, embeder, tokenizer, reduce_method=reduce_method)

    input_size = embeder.config.vocab_size
    in_size = embeder.config.hidden_size

    print(embeder, input_size, in_size)
    print(tokenizer)

    # With reduce_method='mean', str2emb reduces last_hidden_state over dim 1,
    # so the result is presumably [batch_size, hidden_size]; it would only be
    # [batch_size, max_words_num, hidden_size] without a reduction.
    print(embeder_output1)
    print(embeder_output1.shape) # expected [1, hidden_size] for 'mean' - TODO confirm

    # Same shape expectations as for embeder_output1.
    print(embeder_output2)
    print(embeder_output2.shape) # expected [1, hidden_size] for 'mean' - TODO confirm

    # check the difference between the two sentences in embedding space
    # (per-token variant, only meaningful when reduce_method is neither 'mean' nor 'max'):
    # error = torch.max(torch.abs(embeder_output1[0, :, :] - embeder_output2[0, :, :]), dim=-1)
    # Element-wise absolute difference of the two embeddings.
    error = torch.abs(embeder_output1 - embeder_output2)
    print(error)
    print("Embedding distance between the two sentences: ")
    print(f"String1: {string1}")
    print(f"String2: {string2}")
    # Mean absolute difference over all embedding elements.
    print(torch.mean(error))
    exit()
|
Dataloader/dataLoader.py
ADDED
|
@@ -0,0 +1,1473 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import json
|
| 4 |
+
import SimpleITK as sitk
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage.transform import rescale, resize, downscale_local_mean
|
| 7 |
+
# from torchvision.transforms import v2
|
| 8 |
+
import sys
|
| 9 |
+
sys.path.append('./')
|
| 10 |
+
from Dataloader.dataloader_utils import *
|
| 11 |
+
import random
|
| 12 |
+
|
| 13 |
+
# add your mapping files here
|
| 14 |
+
# mapping_files = {
|
| 15 |
+
# 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
|
| 16 |
+
# 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
|
| 17 |
+
# # 'CancerImageArchive': '/home/data/Github/data/data_gen_def/DATASETS_processed/CancerImageArchive_1/nifti_mappings.json',
|
| 18 |
+
# }
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Mapping JSONs stored under the /home/data checkout.
_SHARED_MAPPING_FILES = {
    'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json',
    'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
    'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json',
    'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
}

# Mapping JSONs stored under the /home/jachin/data checkout.
_USER_MAPPING_FILES = {
    'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json',
    # 'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json',
    'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json',
    'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json',
    'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json',
    'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json',
    'PSMA-FDG-PET-CT-LESION': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
    'PSMA-CT': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
    'AbdomenAtlas': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
    'AbdomenCT1k': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
}

# Dataset name -> path of its mapping JSON. The insertion order (shared
# entries first) is the order in which the dataset classes merge the files.
mapping_files = {**_SHARED_MAPPING_FILES, **_USER_MAPPING_FILES}

# Default [min, max] intensity clamp; the dataset classes in this file apply
# it only to volumes whose "Modality" is "CT".
CLAMP_RANGE = [-400, 400]

# Names of the individual (non-combined) regions of interest.
indivi_ROI_list = [
    'abdomen', 'arm', 'brain', 'hand', 'head',
    'leg', 'neck', 'pelvis', 'skeleton', 'thorax',
]
|
| 41 |
+
|
| 42 |
+
def reverse_axis_order(arr):
    """Return a C-contiguous copy of ``arr`` with its axes reversed.

    Used to convert between the SimpleITK and NumPy axis conventions, e.g. an
    array of shape ``(a, b, c)`` becomes one of shape ``(c, b, a)``.

    Args:
        arr (np.ndarray): Input array of any rank.

    Returns:
        np.ndarray: Array with reversed axis order. The transpose itself is a
        view, but ``np.ascontiguousarray`` materialises it, so for arrays with
        more than one non-trivial axis the result is a new array, not a view.
    """
    # np.transpose with no axes argument reverses the axis order.
    return np.ascontiguousarray(np.transpose(arr))
|
| 46 |
+
|
| 47 |
+
def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high'):
    """Sample a value from ``[low, high]`` by iterated uniform sampling.

    Each of the ``order_num`` rounds draws uniformly from an interval whose
    one end is the previous draw, so more rounds push the result further
    towards one bound.

    Args:
        high (float): Upper bound of the sampling interval.
        low (float): Lower bound of the sampling interval.
        order_num (int): Number of successive draws.
        type (str): ``'high'`` to use the previous draw as the new lower bound
            (result skewed towards ``high``), ``'low'`` to use it as the new
            upper bound (result skewed towards ``low``).

    Returns:
        float: The value after ``order_num`` draws.

    Raises:
        ValueError: If ``type`` is neither ``'high'`` nor ``'low'``.

    Notes:
        - If order_num is 0 no draw happens: ``low`` is returned for
          ``type='high'`` and ``high`` for ``type='low'``.
        - If order_num is 1 the result is a single uniform draw.
    """
    if type == 'high':
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
    elif type == 'low':
        sample_value = high
        for _ in range(order_num):
            sample_value = np.random.uniform(low, high=sample_value)
    else:
        # Previously an unknown type fell through to an unbound local.
        raise ValueError(f"type must be 'high' or 'low', got {type!r}")
    return sample_value
|
| 76 |
+
|
| 77 |
+
class OminiDataset(object):
    """Base class for OmniMorph datasets.

    Bundles the helpers shared by the dataset variants: merging the mapping
    JSON files, filtering entries by size / ROI, intensity normalisation and
    random 3D cropping.

    The initialiser is the plain method ``init`` (not ``__init__``), so it has
    to be called explicitly. ``get_all_ROI`` and ``get_filter_ROIs`` also read
    ``self.ALLdata_filtered`` (and ``self.ROIs``), which ``init`` does not
    set; they must be assigned before those methods are used.
    """

    def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
        """Load all mapping files and store the dataset configuration.

        Args:
            out_sz (int): Edge length of the cubic output volume.
            transform: Optional callable stored as ``self.transform``.
            clamp_range: Intensity clamp range stored as ``self.clamp_range``.
            min_crop_ratio: Lower bound for the random crop ratio.
            ROIs: Accepted but not stored by this method.
            modality: Accepted but not stored by this method.
            reverse_axis_order (bool): Whether ``get_3D_volume`` reverses the
                axis order of loaded volumes.
            min_dim: Stored as ``self.min_dim``.
            mapping_files (dict): Dataset name -> mapping JSON path.
        """
        self.ALLdata = self.combine_data(mappings = mapping_files)
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_dim = min_dim
        self.clamp_range = clamp_range
        self.min_crop_ratio = min_crop_ratio
        self.transform = transform
        self.ndims = 3

    def get_ALLdata(self):
        """Return the unfiltered, merged mapping dictionary."""
        return self.ALLdata

    def get_all_ROI(self):
        """Return the set of distinct ROI labels among the filtered entries."""
        return {self.ALLdata[k]['ROI'] for k in self.ALLdata_filtered}

    def get_filter_ROIs(self,keep_single_roi=False):
        """Return the filtered entries whose ROI is listed in ``self.ROIs``.

        Args:
            keep_single_roi (bool): Currently unused.

        Returns:
            dict: New dictionary; ``self.ALLdata_filtered`` is not modified.
        """
        # Build a new dict instead of deleting from the one being iterated,
        # which raises "dictionary changed size during iteration".
        return {
            k: v for k, v in self.ALLdata_filtered.items()
            if v['ROI'] in self.ROIs
        }

    def combine_data(self, mappings = mapping_files):
        """Merge every mapping JSON into a single dictionary.

        Args:
            mappings (dict): Dataset name -> path of a JSON file holding a
                dictionary. On duplicate keys, later files win.

        Returns:
            dict: The merged entries.
        """
        ALLdata = {}
        for path in mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def get_3D_volume(self, volume, select_channel = None):
        """Reduce a possibly 4D array to 3D by picking one channel.

        Args:
            volume (np.ndarray): 3D or 4D array.
            select_channel (int, optional): Index along the last axis to keep
                for 4D input; a random one is drawn when omitted.

        Returns:
            np.ndarray: 3D input is returned as is (after the optional axis
            reversal); 4D input is sliced along its last axis.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # so the last channel is reachable and 1 channel is valid.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_mindim(self):
        """Return the entries whose smallest spatial size is >= ``out_sz/2``.

        Only the first ``self.ndims`` values of each entry's ``'Size'`` are
        considered. ``self.ALLdata`` is left untouched.
        """
        threshold = self.out_sz/2
        return {
            k: v for k, v in self.ALLdata.items()
            if min(v['Size'][:self.ndims]) >= threshold
        }

    def normalize(self, volume, eps=1e-7):
        """Min-max scale ``volume`` to roughly [0, 1] as float64.

        ``eps`` keeps the division finite for constant volumes.
        """
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of edge ``crop_size`` from a 3D volume.

        Axes shorter than ``crop_size`` are zero-padded first, with the
        padding split randomly between both sides, so the result always has
        shape ``(crop_size, crop_size, crop_size)``.

        Args:
            volume (np.ndarray): 3D array.
            crop_size (int, optional): Edge length; defaults to ``out_sz``.

        Returns:
            np.ndarray: The cropped (and, if needed, padded) cube.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz

        # Only pad if needed (avoid np.pad if not necessary)
        deficits = [max(0, crop_size - dim) for dim in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: before + after must equal the deficit,
                # otherwise the padded axis can stay shorter than crop_size.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        # Crop indices
        start_d, start_h, start_w = [
            np.random.randint(0, dim - crop_size + 1) if dim > crop_size else 0
            for dim in volume.shape
        ]
        return volume[
            start_d:start_d + crop_size,
            start_h:start_h + crop_size,
            start_w:start_w + crop_size,
        ]
|
| 175 |
+
|
| 176 |
+
class OminiDataset_v1(Dataset):
    """Dataset yielding randomly cropped, normalised 3D volumes.

    Entries come from the module-level ``mapping_files``; each key of the
    merged mapping is the image path passed to ``sitk.ReadImage``.
    """

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.2, reverse_axis_order = False):
        """Load the mappings and drop entries that are too small.

        Args:
            out_sz (int): Edge length of the cubic output volume.
            transform: Optional callable applied to the final volume.
            clamp_range: ``[min, max]`` clip applied to CT volumes, or None.
            min_crop_ratio (float): Lower bound of the random crop ratio;
                None crops ``out_sz`` directly without resizing.
            reverse_axis_order (bool): Reverse the axis order after loading.
        """
        self.mappings = mapping_files
        self.ALLdata = self.combine_data()
        self.out_sz = out_sz
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3
        # Start your filtering here
        self.ALLdata_filtered = self.get_filter_mindim()

    def find_min_dim(self):
        """Return the smallest ``'Size'`` value over all entries (max 100000)."""
        min_dim = 100000
        for value in self.ALLdata.values():
            min_dim = min(min_dim, min(value['Size']))
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of edge ``crop_size`` from a 3D volume.

        Axes shorter than ``crop_size`` are zero-padded first, with the
        padding split randomly between both sides, so the result always has
        shape ``(crop_size, crop_size, crop_size)``.

        Args:
            volume (np.ndarray): 3D array.
            crop_size (int, optional): Edge length; defaults to ``out_sz``.

        Returns:
            np.ndarray: The cropped (and, if needed, padded) cube.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz

        # Only pad if needed (avoid np.pad if not necessary)
        deficits = [max(0, crop_size - dim) for dim in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: before + after must equal the deficit,
                # otherwise the padded axis can stay shorter than crop_size.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        # Crop indices
        start_d, start_h, start_w = [
            np.random.randint(0, dim - crop_size + 1) if dim > crop_size else 0
            for dim in volume.shape
        ]
        return volume[
            start_d:start_d + crop_size,
            start_h:start_h + crop_size,
            start_w:start_w + crop_size,
        ]

    def get_ALLdata(self):
        """Return the unfiltered, merged mapping dictionary."""
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel = None):
        """Reduce a possibly 4D array to 3D by picking one channel.

        Args:
            volume (np.ndarray): 3D or 4D array.
            select_channel (int, optional): Index along the last axis to keep
                for 4D input; a random one is drawn when omitted.

        Returns:
            np.ndarray: The 3D volume.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # so the last channel is reachable and 1 channel is valid.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        """Return the entries whose ``"ROI"`` field contains ``key_word``.

        ``self.ALLdata`` is left untouched.
        """
        # Look the ROI up in the entry; the dict key itself is a plain string.
        return {
            k: v for k, v in self.ALLdata.items()
            if key_word in v["ROI"]
        }

    def get_filter_mindim(self):
        """Return the entries whose smallest spatial size is >= ``out_sz/2``.

        Only the first ``self.ndims`` values of each entry's ``'Size'`` are
        considered. ``self.ALLdata`` is left untouched.
        """
        threshold = self.out_sz/2
        return {
            k: v for k, v in self.ALLdata.items()
            if min(v['Size'][:self.ndims]) >= threshold
        }

    def combine_data(self):
        """Merge every JSON in ``self.mappings`` into a single dictionary.

        On duplicate keys, entries from later files win.
        """
        ALLdata = {}
        for path in self.mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def __len__(self):
        """Return the number of entries that survived filtering."""
        return len(self.ALLdata_filtered)

    def normalize(self, volume, eps=1e-7):
        """Min-max scale ``volume`` to roughly [0, 1] as float64.

        ``eps`` keeps the division finite for constant volumes.
        """
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def __getitem__(self, idx):
        """Load, clamp, normalise and crop the ``idx``-th filtered volume.

        Returns:
            The volume with a leading channel axis added, i.e. shape
            ``(1, out_sz, out_sz, out_sz)``, or ``self.transform(volume)``
            when a transform is set.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        volume = sitk.GetArrayFromImage(sitk.ReadImage(key))
        # NOTE(review): for 4D input a channel is taken from the LAST axis;
        # confirm this matches the array layout when reverse_axis_order is off.
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            # Clamping is applied to CT volumes only.
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            # Random crop whose edge is a fraction of the longest axis,
            # then resample to the fixed output size.
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return volume
|
| 311 |
+
|
| 312 |
+
class OMDataset_indiv(Dataset):
    """Dataset yielding a cropped 3D volume together with its embedding.

    Entries come from the module-level ``mapping_files``; each key of the
    merged mapping is the image path passed to ``sitk.ReadImage`` and each
    entry provides an ``'embd'`` field returned alongside the volume.
    """

    def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, reverse_axis_order = False):
        """Load the mappings and drop entries that are too small.

        Args:
            out_sz (int): Edge length of the cubic output volume.
            transform: Optional callable applied to the final volume.
            clamp_range: ``[min, max]`` clip applied to CT volumes, or None.
            min_crop_ratio (float): Lower bound of the random crop ratio;
                None crops ``out_sz`` directly without resizing.
            reverse_axis_order (bool): Reverse the axis order after loading.
        """
        self.ALLdata = self.combine_data(mappings=mapping_files)
        self.out_sz = out_sz
        # Upper limit for the random crop edge before resizing.
        self.max_sz = out_sz*8
        self.reverse_axis_order = reverse_axis_order
        self.min_crop_ratio = min_crop_ratio
        self.crop_ratio_sample_order = 2
        self.transform = transform
        self.clamp_range = clamp_range
        self.ndims = 3

        # Start your filtering here
        print(f"Diffusion mode: Total data size before filtering: {len(self.ALLdata)}")
        self.ALLdata_filtered = self.get_filter_mindim()
        print(f"Diffusion mode: Filtered data size: {len(self.ALLdata_filtered)}")

    def find_min_dim(self):
        """Return the smallest ``'Size'`` value over all entries (max 100000)."""
        min_dim = 100000
        for value in self.ALLdata.values():
            min_dim = min(min_dim, min(value['Size']))
        return min_dim

    def random_crop_3d(self, volume, crop_size=None):
        """Randomly crop a cube of edge ``crop_size`` from a 3D volume.

        Axes shorter than ``crop_size`` are zero-padded first, with the
        padding split randomly between both sides, so the result always has
        shape ``(crop_size, crop_size, crop_size)``.

        Args:
            volume (np.ndarray): 3D array.
            crop_size (int, optional): Edge length; defaults to ``out_sz``.

        Returns:
            np.ndarray: The cropped (and, if needed, padded) cube.
        """
        d, h, w = volume.shape
        if crop_size is None:
            crop_size = self.out_sz

        # Only pad if needed (avoid np.pad if not necessary)
        deficits = [max(0, crop_size - dim) for dim in (d, h, w)]
        if any(deficits):
            pad_width = []
            for deficit in deficits:
                # One draw per axis: before + after must equal the deficit,
                # otherwise the padded axis can stay shorter than crop_size.
                before = np.random.randint(0, deficit + 1)
                pad_width.append((before, deficit - before))
            volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

        # Crop indices
        start_d, start_h, start_w = [
            np.random.randint(0, dim - crop_size + 1) if dim > crop_size else 0
            for dim in volume.shape
        ]
        return volume[
            start_d:start_d + crop_size,
            start_h:start_h + crop_size,
            start_w:start_w + crop_size,
        ]

    def get_ALLdata(self):
        """Return the unfiltered, merged mapping dictionary."""
        return self.ALLdata

    def get_3D_volume(self, volume, select_channel = None):
        """Reduce a possibly 4D array to 3D by picking one channel.

        Args:
            volume (np.ndarray): 3D or 4D array.
            select_channel (int, optional): Index along the last axis to keep
                for 4D input; a random one is drawn when omitted.

        Returns:
            np.ndarray: The 3D volume.
        """
        if self.reverse_axis_order:
            volume = reverse_axis_order(volume)
        if volume.ndim == 4:
            if select_channel is None:
                # randint's upper bound is exclusive: pass the channel count
                # so the last channel is reachable and 1 channel is valid.
                select_channel = np.random.randint(0, volume.shape[3])
            volume = volume[:, :, :, select_channel]
        return volume

    def get_filter_ROI(self, key_word):
        """Return the entries whose ``"ROI"`` field contains ``key_word``.

        ``self.ALLdata`` is left untouched.
        """
        # Look the ROI up in the entry; the dict key itself is a plain string.
        return {
            k: v for k, v in self.ALLdata.items()
            if key_word in v["ROI"]
        }

    def get_filter_mindim(self):
        """Return the entries whose smallest spatial size is >= ``out_sz/2``.

        Only the first ``self.ndims`` values of each entry's ``'Size'`` are
        considered. ``self.ALLdata`` is left untouched.
        """
        threshold = self.out_sz/2
        return {
            k: v for k, v in self.ALLdata.items()
            if min(v['Size'][:self.ndims]) >= threshold
        }

    def combine_data(self, mappings = mapping_files):
        """Merge every mapping JSON into a single dictionary.

        Args:
            mappings (dict): Dataset name -> path of a JSON file holding a
                dictionary. On duplicate keys, later files win.

        Returns:
            dict: The merged entries.
        """
        ALLdata = {}
        for path in mappings.values():
            with open(path, 'r') as f:
                ALLdata.update(json.load(f))
        return ALLdata

    def __len__(self):
        """Return the number of entries that survived filtering."""
        return len(self.ALLdata_filtered)

    def normalize(self, volume, eps=1e-7):
        """Min-max scale ``volume`` to roughly [0, 1] as float64.

        ``eps`` keeps the division finite for constant volumes.
        """
        volume = volume.astype(np.float64)
        return (volume - np.min(volume)) / (np.ptp(volume) + eps)

    def __getitem__(self, idx):
        """Load the ``idx``-th filtered volume and its embedding.

        Returns:
            ``[volume, embd]`` where ``volume`` has shape
            ``(1, out_sz, out_sz, out_sz)`` and ``embd`` is a float32 array,
            or ``self.transform(volume)`` alone when a transform is set.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = np.array(self.ALLdata_filtered[key]['embd'], dtype=np.float32)

        volume = sitk.GetArrayFromImage(sitk.ReadImage(key))
        # NOTE(review): for 4D input a channel is taken from the LAST axis;
        # confirm this matches the array layout when reverse_axis_order is off.
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            # Clamping is applied to CT volumes only.
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            # Random crop whose edge is a fraction of the longest axis
            # (capped at max_sz), then resample to the fixed output size.
            crop_ratio = sample_random_uniform_multi_order(high=1., low=self.min_crop_ratio, order_num=self.crop_ratio_sample_order, type='high')
            crop_size = int(max(volume.shape) * crop_ratio)
            crop_size = min(crop_size, self.max_sz)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            # NOTE(review): this path drops ``embd``; confirm callers that
            # pass a transform do not expect the [volume, embd] pair.
            return self.transform(volume)

        return [volume, embd]
|
| 456 |
+
|
| 457 |
+
class OminiDataset_paired(Dataset):
|
| 458 |
+
def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.9, ROIs = None, modality = None, reverse_axis_order = False):
|
| 459 |
+
# self.mappings = mapping_files
|
| 460 |
+
self.ALLdata = self.combine_data(mappings=mapping_files)
|
| 461 |
+
self.out_sz = out_sz
|
| 462 |
+
self.sz_range = get_sizeRange_dict()
|
| 463 |
+
self.min_dim_ratio = 0.5
|
| 464 |
+
self.reverse_axis_order = reverse_axis_order
|
| 465 |
+
self.min_crop_ratio = min_crop_ratio
|
| 466 |
+
self.transform = transform
|
| 467 |
+
self.clamp_range = clamp_range
|
| 468 |
+
self.ndims = 3
|
| 469 |
+
# Start you filtering here
|
| 470 |
+
# print(f"Number of images before filtering: {len(self.ALLdata.keys())}")
|
| 471 |
+
self.ALLdata_filtered = self.get_filter_mindim()
|
| 472 |
+
# print(f"Number of images after filtering: {len(self.ALLdata_filtered.keys())}")
|
| 473 |
+
self.ALLdata_filtered = self.get_filter_modality(modality)
|
| 474 |
+
# print(f"Number of images after modality filtering: {len(self.ALLdata_filtered.keys())}")
|
| 475 |
+
if ROIs is None:# if no ROIs are provided, get all the ROIs from filtered data
|
| 476 |
+
self.ROIs = self.get_all_ROI()
|
| 477 |
+
else:
|
| 478 |
+
self.ROIs = ROIs
|
| 479 |
+
self.ALLdata_filtered = self.get_filter_ROIs()
|
| 480 |
+
# print(f"Number of images after ROI filtering: {len(self.ALLdata_filtered.keys())}")
|
| 481 |
+
# filtering ends here
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def combine_data(self, mappings = mapping_files):
    """Merge the JSON mapping files into a single dictionary.

    Args:
        mappings: Dict whose values are paths to JSON files, each holding
            one dictionary of image records.

    Returns:
        One dict containing the entries of every file; on duplicate keys
        the file listed later wins.
    """
    merged = {}
    for json_path in mappings.values():
        with open(json_path, 'r') as handle:
            merged.update(json.load(handle))
    return merged
|
| 492 |
+
|
| 493 |
+
def normalize(self, volume, eps=1e-7):
    """Min-max rescale ``volume`` to the range [0, 1).

    Args:
        volume: NumPy array of intensities.
        eps: Small constant added to the value range so that a constant
            volume does not divide by zero.

    Returns:
        A float64 array of the same shape.
    """
    as_float = volume.astype(np.float64)
    value_range = np.ptp(as_float) + eps
    return (as_float - np.min(as_float)) / value_range
|
| 498 |
+
|
| 499 |
+
def random_crop_3d(self, volume, crop_size=None):
    """Return a random cubic crop, zero-padding axes shorter than the crop.

    Args:
        volume: 3-D NumPy array.
        crop_size: Edge length of the cube; defaults to ``self.out_sz``.

    Returns:
        Array of shape ``(crop_size, crop_size, crop_size)``.

    Raises:
        ValueError: If ``volume`` is not three-dimensional.
    """
    d, h, w = volume.shape
    if crop_size is None:
        crop_size = self.out_sz

    # Pad only the axes that are too short (np.pad is skipped otherwise).
    deficits = [max(0, crop_size - extent) for extent in (d, h, w)]
    if any(deficits):
        pad_width = []
        for deficit in deficits:
            # Draw the leading pad once and derive the trailing pad from it.
            # Two independent draws would not sum to the deficit and could
            # leave the axis shorter than the requested crop.
            before = np.random.randint(0, deficit + 1)
            pad_width.append((before, deficit - before))
        volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        d, h, w = volume.shape

    # Random crop origin along each axis; 0 when the axis fits exactly.
    start_d = np.random.randint(0, d - crop_size + 1) if d > crop_size else 0
    start_h = np.random.randint(0, h - crop_size + 1) if h > crop_size else 0
    start_w = np.random.randint(0, w - crop_size + 1) if w > crop_size else 0

    return volume[
        start_d:start_d + crop_size,
        start_h:start_h + crop_size,
        start_w:start_w + crop_size,
    ]
|
| 526 |
+
|
| 527 |
+
# def random_crop_3d(self, volume, crop_size=None):
|
| 528 |
+
# # Randomly crop the image
|
| 529 |
+
# d, h, w = volume.shape
|
| 530 |
+
# if crop_size is None:
|
| 531 |
+
# crop_size = self.out_sz
|
| 532 |
+
# crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
|
| 533 |
+
|
| 534 |
+
# if crop_d > d or crop_h > h or crop_w > w:
|
| 535 |
+
# raise ValueError("Crop size must be smaller than the original array size")
|
| 536 |
+
|
| 537 |
+
# start_d = np.random.randint(0, d - crop_d + 1)
|
| 538 |
+
# start_h = np.random.randint(0, h - crop_h + 1)
|
| 539 |
+
# start_w = np.random.randint(0, w - crop_w + 1)
|
| 540 |
+
|
| 541 |
+
# cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
|
| 542 |
+
|
| 543 |
+
# return cropped_array
|
| 544 |
+
|
| 545 |
+
def get_all_ROI(self):
    """Return the set of distinct ROI labels of the currently filtered images.

    Keys are taken from ``self.ALLdata_filtered``; the labels themselves are
    looked up in ``self.ALLdata``.
    """
    return {self.ALLdata[path]['ROI'] for path in self.ALLdata_filtered.keys()}
|
| 552 |
+
|
| 553 |
+
def find_min_dim(self):
    """Return the smallest entry of any record's ``'Size'`` in ``self.ALLdata``.

    The result is capped at 100000, which is also what an empty dataset
    yields.
    """
    smallest = 100000
    for record in self.ALLdata.values():
        smallest = min(smallest, min(record['Size']))
    return smallest
|
| 561 |
+
|
| 562 |
+
def get_ALLdata(self):
    """Return the unfiltered metadata dictionary (the object itself, not a copy)."""
    return self.ALLdata
|
| 565 |
+
|
| 566 |
+
def get_filter_modality(self, key_words=None):
    """Return a copy of the filtered records restricted to given modalities.

    Args:
        key_words: Container of accepted ``"Modality"`` values; ``None``
            disables the filter.

    Returns:
        A new dict; ``self.ALLdata_filtered`` is left untouched.
    """
    if key_words is None:
        return self.ALLdata_filtered.copy()
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if record["Modality"] in key_words
    }
|
| 573 |
+
|
| 574 |
+
def get_filter_ROI(self, key_word):
    """Return a copy of the filtered records whose ROI contains ``key_word``.

    Args:
        key_word: Value tested with ``in`` against each record's ``"ROI"``
            entry (a substring test when the ROI is a string).

    Returns:
        A new dict; ``self.ALLdata_filtered`` is left untouched.
    """
    # The ROI lives in the record, not in the key: indexing the key string
    # with "ROI" raised TypeError for every entry.
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if key_word in record["ROI"]
    }
|
| 581 |
+
|
| 582 |
+
def get_key_by_ROI(self, key_word):
    """List the keys of all filtered records whose ROI equals ``key_word``."""
    return [
        path
        for path, record in self.ALLdata_filtered.items()
        if key_word == record["ROI"]
    ]
|
| 589 |
+
|
| 590 |
+
def get_filter_ROIs(self):
    """Return a copy of the filtered records whose ROI is in ``self.ROIs``."""
    allowed = self.ROIs
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if record['ROI'] in allowed
    }
|
| 596 |
+
|
| 597 |
+
def get_3D_volume(self, volume, select_channel = None):
    """Reduce a loaded array to a single 3-D volume.

    Args:
        volume: 3-D or 4-D NumPy array; for 4-D input the last axis is
            treated as the channel axis.
        select_channel: Channel index to keep for 4-D input; a random
            channel is drawn when ``None``.

    Returns:
        The (optionally axis-reversed) 3-D volume.
    """
    if self.reverse_axis_order:
        volume = reverse_axis_order(volume)
    if volume.ndim == 4:
        if select_channel is None:
            # randint's upper bound is exclusive, so pass the channel count
            # itself: "count - 1" never picked the last channel and raised
            # ValueError for single-channel volumes.
            select_channel = np.random.randint(0, volume.shape[3])
        volume = volume[:, :, :, select_channel]
    return volume
|
| 605 |
+
|
| 606 |
+
def get_filter_mindim(self):
    """Return a copy of ``self.ALLdata`` without images that are too small.

    A record is dropped when, over its first ``self.ndims`` size entries,

    * the shortest axis has fewer than ``self.out_sz`` samples, or
    * shortest axis * ``'Spacing_mm'`` is below the first entry of the
      size range registered for the record's ROI, or
    * the shortest/longest axis ratio is below ``self.min_dim_ratio``.

    The checks short-circuit in that order, so the ROI size range is only
    looked up for records that pass the first test.
    """
    kept = {}
    for path, record in self.ALLdata.items():
        dims = record['Size'][:self.ndims]
        shortest = min(dims)
        rejected = (
            shortest < self.out_sz
            or (shortest * record['Spacing_mm']) < self.sz_range[record['ROI']][0]
            or (shortest / max(dims) < self.min_dim_ratio)
        )
        if not rejected:
            kept[path] = record
    return kept
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
def __getitem__(self,idx):
    """Return a random pair of preprocessed volumes that share an ROI.

    Volume A is the ``idx``-th filtered image; volume B is drawn at random
    from the filtered images with the same ROI (possibly A itself). Both
    are reduced to 3-D, optionally clipped, min-max normalised, randomly
    cropped and given a leading channel axis.

    Args:
        idx: Position in the filtered dataset.

    Returns:
        ``(volume_A, volume_B)``, each passed through ``self.transform``
        when one is set.
    """
    records = self.ALLdata_filtered
    key = list(records.keys())[idx]
    volume_A = sitk.GetArrayFromImage(sitk.ReadImage(key))

    paired_key = random.choice(self.get_key_by_ROI(records[key]['ROI']))
    volume_B = sitk.GetArrayFromImage(sitk.ReadImage(paired_key))

    volume_A = self.get_3D_volume(volume_A)
    volume_B = self.get_3D_volume(volume_B)

    if self.clamp_range is not None:
        low, high = self.clamp_range[0], self.clamp_range[1]
        # Pairs are matched on ROI only, so the two modalities may differ:
        # decide per volume instead of applying A's modality to both.
        if records[key].get("Modality", None) == "CT":
            volume_A = np.clip(volume_A, low, high)
        if records[paired_key].get("Modality", None) == "CT":
            volume_B = np.clip(volume_B, low, high)
    volume_A = self.normalize(volume_A)
    volume_B = self.normalize(volume_B)

    if self.min_crop_ratio is not None:
        # One ratio is shared by both volumes; the cube edge is relative to
        # each volume's shortest axis, then resampled to out_sz.
        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        crop_size_A = int(min(volume_A.shape) * crop_ratio)
        crop_size_B = int(min(volume_B.shape) * crop_ratio)
        volume_A = self.random_crop_3d(volume_A, crop_size_A)
        volume_B = self.random_crop_3d(volume_B, crop_size_B)
        volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
    else:
        volume_A = self.random_crop_3d(volume_A, self.out_sz)
        volume_B = self.random_crop_3d(volume_B, self.out_sz)

    # Add the leading channel axis.
    volume_A = volume_A[None, :, :, :]
    volume_B = volume_B[None, :, :, :]

    if self.transform is not None:
        return self.transform(volume_A), self.transform(volume_B)
    return volume_A, volume_B
|
| 672 |
+
|
| 673 |
+
def __len__(self):
    """Return the number of images that survived filtering."""
    return len(self.ALLdata_filtered)
|
| 675 |
+
|
| 676 |
+
class OMDataset_pair(Dataset):
|
| 677 |
+
def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = indivi_ROI_list, modality = None, reverse_axis_order = False):
    """Load every mapping file and narrow the records down to usable images.

    Args:
        out_sz: Edge length of the cubic volumes produced by ``__getitem__``.
        transform: Optional callable applied to each returned volume.
        clamp_range: ``(low, high)`` clipping window used for CT volumes, or
            ``None`` to disable clipping.
        min_crop_ratio: Lower bound of the random crop ratio; ``None``
            selects a fixed ``out_sz`` crop instead.
        ROIs: ROI labels to keep; ``None`` keeps every ROI that survives the
            size and modality filters.
        modality: Modalities to keep; ``None`` keeps all of them.
        reverse_axis_order: Whether loaded arrays have their axis order
            reversed before use.
    """
    # Metadata of every known image, keyed by image path.
    self.ALLdata = self.combine_data(mappings=mapping_files)

    # Sampling configuration; crops are capped at eight times the output size.
    self.out_sz = out_sz
    self.max_sz = out_sz*8
    self.sz_range = get_sizeRange_dict()
    self.min_dim_ratio = 0.7
    self.reverse_axis_order = reverse_axis_order
    self.min_crop_ratio = min_crop_ratio
    self.transform = transform
    self.clamp_range = clamp_range
    self.ndims = 3

    print(f"Registration mode: Total data size before filtering: {len(self.ALLdata)}")

    # Filtering pipeline: size -> modality -> ROI. Each stage reads
    # self.ALLdata_filtered, so the order of these calls matters.
    self.ALLdata_filtered = self.get_filter_mindim()
    self.ALLdata_filtered = self.get_filter_modality(modality)
    # Without an explicit ROI list, fall back to every ROI still present.
    self.ROIs = self.get_all_ROI() if ROIs is None else ROIs
    self.ALLdata_filtered = self.get_filter_ROIs()
    print(f"Registration mode: Number of images after filtering: {len(self.ALLdata_filtered.keys())}")
|
| 704 |
+
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def combine_data(self, mappings = mapping_files):
    """Merge the JSON mapping files into a single dictionary.

    Args:
        mappings: Dict whose values are paths to JSON files, each holding
            one dictionary of image records.

    Returns:
        One dict containing the entries of every file; on duplicate keys
        the file listed later wins.
    """
    merged = {}
    for json_path in mappings.values():
        with open(json_path, 'r') as handle:
            merged.update(json.load(handle))
    return merged
|
| 714 |
+
|
| 715 |
+
def normalize(self, volume, eps=1e-7):
    """Min-max rescale ``volume`` to the range [0, 1).

    Args:
        volume: NumPy array of intensities.
        eps: Small constant added to the value range so that a constant
            volume does not divide by zero.

    Returns:
        A float64 array of the same shape.
    """
    as_float = volume.astype(np.float64)
    value_range = np.ptp(as_float) + eps
    return (as_float - np.min(as_float)) / value_range
|
| 720 |
+
|
| 721 |
+
def random_crop_3d(self, volume, crop_size=None):
    """Return a random cubic crop, zero-padding axes shorter than the crop.

    Args:
        volume: 3-D NumPy array.
        crop_size: Edge length of the cube; defaults to ``self.out_sz``.

    Returns:
        Array of shape ``(crop_size, crop_size, crop_size)``.

    Raises:
        ValueError: If ``volume`` is not three-dimensional.
    """
    d, h, w = volume.shape
    if crop_size is None:
        crop_size = self.out_sz

    # Pad only the axes that are too short (np.pad is skipped otherwise).
    deficits = [max(0, crop_size - extent) for extent in (d, h, w)]
    if any(deficits):
        pad_width = []
        for deficit in deficits:
            # Draw the leading pad once and derive the trailing pad from it.
            # Two independent draws would not sum to the deficit and could
            # leave the axis shorter than the requested crop.
            before = np.random.randint(0, deficit + 1)
            pad_width.append((before, deficit - before))
        volume = np.pad(volume, pad_width, mode='constant', constant_values=0)
        d, h, w = volume.shape

    # Random crop origin along each axis; 0 when the axis fits exactly.
    start_d = np.random.randint(0, d - crop_size + 1) if d > crop_size else 0
    start_h = np.random.randint(0, h - crop_size + 1) if h > crop_size else 0
    start_w = np.random.randint(0, w - crop_size + 1) if w > crop_size else 0

    return volume[
        start_d:start_d + crop_size,
        start_h:start_h + crop_size,
        start_w:start_w + crop_size,
    ]
|
| 748 |
+
|
| 749 |
+
# def random_crop_3d(self, volume, crop_size=None):
|
| 750 |
+
# # Randomly crop the image
|
| 751 |
+
# d, h, w = volume.shape
|
| 752 |
+
# if crop_size is None:
|
| 753 |
+
# crop_size = self.out_sz
|
| 754 |
+
# crop_d, crop_h, crop_w = crop_size, crop_size, crop_size
|
| 755 |
+
|
| 756 |
+
# if crop_d > d or crop_h > h or crop_w > w:
|
| 757 |
+
# raise ValueError("Crop size must be smaller than the original array size")
|
| 758 |
+
|
| 759 |
+
# start_d = np.random.randint(0, d - crop_d + 1)
|
| 760 |
+
# start_h = np.random.randint(0, h - crop_h + 1)
|
| 761 |
+
# start_w = np.random.randint(0, w - crop_w + 1)
|
| 762 |
+
|
| 763 |
+
# cropped_array = volume[start_d:start_d + crop_d, start_h:start_h + crop_h, start_w:start_w + crop_w]
|
| 764 |
+
|
| 765 |
+
# return cropped_array
|
| 766 |
+
|
| 767 |
+
def get_all_ROI(self):
    """Return the set of distinct ROI labels of the currently filtered images.

    Keys are taken from ``self.ALLdata_filtered``; the labels themselves are
    looked up in ``self.ALLdata``.
    """
    return {self.ALLdata[path]['ROI'] for path in self.ALLdata_filtered.keys()}
|
| 774 |
+
|
| 775 |
+
def find_min_dim(self):
    """Return the smallest entry of any record's ``'Size'`` in ``self.ALLdata``.

    The result is capped at 100000, which is also what an empty dataset
    yields.
    """
    smallest = 100000
    for record in self.ALLdata.values():
        smallest = min(smallest, min(record['Size']))
    return smallest
|
| 783 |
+
|
| 784 |
+
def get_ALLdata(self):
    """Return the unfiltered metadata dictionary (the object itself, not a copy)."""
    return self.ALLdata
|
| 787 |
+
|
| 788 |
+
def get_filter_modality(self, key_words=None):
    """Return a copy of the filtered records restricted to given modalities.

    Args:
        key_words: Container of accepted ``"Modality"`` values; ``None``
            disables the filter.

    Returns:
        A new dict; ``self.ALLdata_filtered`` is left untouched.
    """
    if key_words is None:
        return self.ALLdata_filtered.copy()
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if record["Modality"] in key_words
    }
|
| 795 |
+
|
| 796 |
+
def get_filter_ROI(self, key_word):
    """Return a copy of the filtered records whose ROI contains ``key_word``.

    Args:
        key_word: Value tested with ``in`` against each record's ``"ROI"``
            entry (a substring test when the ROI is a string).

    Returns:
        A new dict; ``self.ALLdata_filtered`` is left untouched.
    """
    # The ROI lives in the record, not in the key: indexing the key string
    # with "ROI" raised TypeError for every entry.
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if key_word in record["ROI"]
    }
|
| 803 |
+
|
| 804 |
+
def get_key_by_ROI(self, key_word):
    """List the keys of all filtered records whose ROI equals ``key_word``."""
    return [
        path
        for path, record in self.ALLdata_filtered.items()
        if key_word == record["ROI"]
    ]
|
| 811 |
+
|
| 812 |
+
def filter_keys_by_xx(self, key_word, keys=None, term="ROI"):
    """Select the keys whose metadata field ``term`` equals ``key_word``.

    Args:
        key_word: Value the field must equal.
        keys: Candidate keys of ``self.ALLdata_filtered``; defaults to all
            of them.
        term: Name of the metadata field to compare.

    Returns:
        List of matching keys in candidate order. Records where the field
        is missing or ``None`` never match.
    """
    candidates = self.ALLdata_filtered.keys() if keys is None else keys
    matches = []
    for path in candidates:
        field = self.ALLdata_filtered[path].get(term, None)
        if field is None:
            continue
        if key_word == field:
            matches.append(path)
    return matches
|
| 822 |
+
|
| 823 |
+
def get_filter_ROIs(self):
    """Return a copy of the filtered records whose ROI is in ``self.ROIs``."""
    allowed = self.ROIs
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if record['ROI'] in allowed
    }
|
| 829 |
+
|
| 830 |
+
def get_3D_volume(self, volume, select_channel = None):
    """Reduce a loaded array to a single 3-D volume.

    Args:
        volume: 3-D or 4-D NumPy array; for 4-D input the last axis is
            treated as the channel axis.
        select_channel: Channel index to keep for 4-D input; a random
            channel is drawn when ``None``.

    Returns:
        The (optionally axis-reversed) 3-D volume.
    """
    if self.reverse_axis_order:
        volume = reverse_axis_order(volume)
    if volume.ndim == 4:
        if select_channel is None:
            # randint's upper bound is exclusive, so pass the channel count
            # itself: "count - 1" never picked the last channel and raised
            # ValueError for single-channel volumes.
            select_channel = np.random.randint(0, volume.shape[3])
        volume = volume[:, :, :, select_channel]
    return volume
|
| 838 |
+
|
| 839 |
+
def get_filter_mindim(self):
    """Return a copy of ``self.ALLdata`` without images that are too small.

    A record is dropped when, over its first ``self.ndims`` size entries,

    * the shortest axis has fewer than ``self.out_sz`` samples, or
    * shortest axis * ``'Spacing_mm'`` is below the first entry of the
      size range registered for the record's ROI, or
    * the shortest/longest axis ratio is below ``self.min_dim_ratio``.

    The checks short-circuit in that order, so the ROI size range is only
    looked up for records that pass the first test.
    """
    kept = {}
    for path, record in self.ALLdata.items():
        dims = record['Size'][:self.ndims]
        shortest = min(dims)
        rejected = (
            shortest < self.out_sz
            or (shortest * record['Spacing_mm']) < self.sz_range[record['ROI']][0]
            or (shortest / max(dims) < self.min_dim_ratio)
        )
        if not rejected:
            kept[path] = record
    return kept
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
|
| 858 |
+
def __getitem__(self,idx):
    """Return a random same-ROI, same-modality pair of preprocessed volumes.

    Volume A is the ``idx``-th filtered image; volume B is drawn at random
    from the filtered images sharing A's ROI and modality (possibly A
    itself). Both are reduced to 3-D, optionally clipped, min-max
    normalised, randomly cropped and given a leading channel axis.

    Args:
        idx: Position in the filtered dataset.

    Returns:
        ``[volume_A, volume_B, embd_A, embd_B]`` when no transform is set;
        otherwise only the tuple of the two transformed volumes (the
        embeddings are not returned in that case).
    """
    records = self.ALLdata_filtered
    key = list(records.keys())[idx]
    volume_A = sitk.GetArrayFromImage(sitk.ReadImage(key))
    embd_A = np.array(records[key]['embd'], dtype=np.float32)

    # Candidate partners: same ROI first, then same modality.
    same_roi = self.filter_keys_by_xx(records[key]['ROI'], list(records.keys()), term="ROI")
    paired_keys = self.filter_keys_by_xx(records[key]['Modality'], same_roi, term="Modality")
    paired_key = random.choice(paired_keys)

    print(f"Key: {key}, Paired Key: {paired_key}")
    print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")

    volume_B = sitk.GetArrayFromImage(sitk.ReadImage(paired_key))
    embd_B = np.array(records[paired_key]['embd'], dtype=np.float32)

    volume_A = self.get_3D_volume(volume_A)
    volume_B = self.get_3D_volume(volume_B)

    # Intensity clipping applies to CT only; both volumes share a modality.
    if self.clamp_range is not None and records[key].get("Modality", None) == "CT":
        volume_A = np.clip(volume_A, self.clamp_range[0], self.clamp_range[1])
        volume_B = np.clip(volume_B, self.clamp_range[0], self.clamp_range[1])
    volume_A = self.normalize(volume_A)
    volume_B = self.normalize(volume_B)

    if self.min_crop_ratio is None:
        volume_A = self.random_crop_3d(volume_A, self.out_sz)
        volume_B = self.random_crop_3d(volume_B, self.out_sz)
    else:
        # One ratio is shared by both volumes; the cube edge is relative to
        # each volume's longest axis and capped at max_sz.
        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        crop_size_A = min(int(max(volume_A.shape) * crop_ratio), self.max_sz)
        crop_size_B = min(int(max(volume_B.shape) * crop_ratio), self.max_sz)
        volume_A = self.random_crop_3d(volume_A, crop_size_A)
        volume_B = self.random_crop_3d(volume_B, crop_size_B)
        volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)

    # Add the leading channel axis.
    volume_A = volume_A[None, :, :, :]
    volume_B = volume_B[None, :, :, :]

    if self.transform is not None:
        return self.transform(volume_A), self.transform(volume_B)
    return [volume_A, volume_B, embd_A, embd_B]
|
| 922 |
+
|
| 923 |
+
def __len__(self):
    """Return the number of images that survived filtering."""
    return len(self.ALLdata_filtered)
|
| 925 |
+
|
| 926 |
+
class OminiDataset_paired_inf(object):
|
| 927 |
+
def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.3, ROIs = None):
    """Load every mapping file, filter the records and index them by ROI.

    Args:
        out_sz: Edge length of the cubic volumes that are produced.
        transform: Optional callable applied to each returned volume.
        clamp_range: ``(low, high)`` clipping window used for CT volumes, or
            ``None`` to disable clipping.
        min_crop_ratio: Lower bound of the random crop ratio; ``None``
            selects a fixed ``out_sz`` crop instead.
        ROIs: ROI labels to keep; ``None`` keeps every ROI that survives the
            size filter.
    """
    # Metadata of every known image, keyed by image path.
    self.ALLdata = self.combine_data(mappings=mapping_files)

    self.out_sz = out_sz
    self.min_crop_ratio = min_crop_ratio
    self.transform = transform
    self.clamp_range = clamp_range
    self.ndims = 3

    # Filtering pipeline: size first, then ROI.
    self.ALLdata_filtered = self.get_filter_mindim()
    self.ROIs = self.get_all_ROI() if ROIs is None else ROIs
    self.ALLdata_filtered = self.get_filter_ROIs()

    # Lookup tables used for sampling pairs by ROI.
    self.roi_scan_mapping = self.build_ROI_scan_mapping()
    self.keys_dist, self.total = self.get_keys_dist()
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
def get_all_ROI(self):
    """Return the set of distinct ROI labels of the currently filtered images.

    Keys are taken from ``self.ALLdata_filtered``; the labels themselves are
    looked up in ``self.ALLdata``.
    """
    return {self.ALLdata[path]['ROI'] for path in self.ALLdata_filtered.keys()}
|
| 959 |
+
|
| 960 |
+
def get_ALLdata(self):
    """Return the unfiltered metadata dictionary (the object itself, not a copy)."""
    return self.ALLdata
|
| 963 |
+
|
| 964 |
+
def combine_data(self, mappings = mapping_files):
    """Merge the JSON mapping files into a single dictionary.

    Args:
        mappings: Dict whose values are paths to JSON files, each holding
            one dictionary of image records.

    Returns:
        One dict containing the entries of every file; on duplicate keys
        the file listed later wins.
    """
    merged = {}
    for json_path in mappings.values():
        with open(json_path, 'r') as handle:
            merged.update(json.load(handle))
    return merged
|
| 971 |
+
|
| 972 |
+
def __len__(self):
    """Return the number of images that survived filtering."""
    return len(self.ALLdata_filtered)
|
| 974 |
+
|
| 975 |
+
def random_crop_3d(self, volume, crop_size=None):
    """Return a random cubic crop of a 3-D volume (no padding).

    Args:
        volume: 3-D NumPy array.
        crop_size: Edge length of the cube; defaults to ``self.out_sz``.

    Returns:
        A ``crop_size``-cubed slice of ``volume``.

    Raises:
        ValueError: If the cube is larger than the volume along any axis
            (a cube exactly as large as an axis is accepted).
    """
    d, h, w = volume.shape
    edge = self.out_sz if crop_size is None else crop_size

    if edge > min(d, h, w):
        raise ValueError("Crop size must be smaller than the original array size")

    # Draw the crop origin axis by axis (depth, height, width).
    z0 = np.random.randint(0, d - edge + 1)
    y0 = np.random.randint(0, h - edge + 1)
    x0 = np.random.randint(0, w - edge + 1)

    return volume[z0:z0 + edge, y0:y0 + edge, x0:x0 + edge]
|
| 992 |
+
|
| 993 |
+
def normalize(self, volume, eps=1e-7):
    """Min-max rescale ``volume`` to the range [0, 1).

    Args:
        volume: NumPy array of intensities.
        eps: Small constant added to the value range so that a constant
            volume does not divide by zero.

    Returns:
        A float64 array of the same shape.
    """
    as_float = volume.astype(np.float64)
    value_range = np.ptp(as_float) + eps
    return (as_float - np.min(as_float)) / value_range
|
| 998 |
+
|
| 999 |
+
def get_3D_volume(self, volume, select_channel = None):
    """Reverse the axis order and reduce the array to a single 3-D volume.

    Args:
        volume: 3-D or 4-D NumPy array; after the axis reversal the last
            axis of a 4-D array is treated as the channel axis.
        select_channel: Channel index to keep for 4-D input; a random
            channel is drawn when ``None``.

    Returns:
        The axis-reversed 3-D volume.
    """
    volume = reverse_axis_order(volume)
    if volume.ndim == 4:
        if select_channel is None:
            # randint's upper bound is exclusive, so pass the channel count
            # itself: "count - 1" never picked the last channel and raised
            # ValueError for single-channel volumes.
            select_channel = np.random.randint(0, volume.shape[3])
        volume = volume[:, :, :, select_channel]
    return volume
|
| 1006 |
+
|
| 1007 |
+
def get_filter_mindim(self):
    """Return a copy of ``self.ALLdata`` without images that are too small.

    A record is dropped when the smallest of its first ``self.ndims`` size
    entries is below half of ``self.out_sz``.
    """
    threshold = self.out_sz/2
    return {
        path: record
        for path, record in self.ALLdata.items()
        if not min(record['Size'][:self.ndims]) < threshold
    }
|
| 1015 |
+
|
| 1016 |
+
def get_filter_ROI(self, key_word):
    """Return a copy of the filtered records whose ROI contains ``key_word``.

    Args:
        key_word: Value tested with ``in`` against each record's ``"ROI"``
            entry (a substring test when the ROI is a string).

    Returns:
        A new dict; ``self.ALLdata_filtered`` is left untouched.
    """
    # The ROI lives in the record, not in the key: indexing the key string
    # with "ROI" raised TypeError for every entry.
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if key_word in record["ROI"]
    }
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
def get_filter_ROIs(self):
    """Return a copy of the filtered records whose ROI is in ``self.ROIs``."""
    allowed = self.ROIs
    return {
        path: record
        for path, record in self.ALLdata_filtered.items()
        if record['ROI'] in allowed
    }
|
| 1031 |
+
|
| 1032 |
+
def get_keys_dist(self):
    """Count the filtered images per ROI.

    Returns:
        ``(keys_dist, total)``: a dict mapping each ROI label to its number
        of images, and the overall number of images counted.
    """
    keys_dist = {}
    for record in self.ALLdata_filtered.values():
        roi = record['ROI']
        keys_dist[roi] = keys_dist.get(roi, 0) + 1
    # The total used to be initialised to 0 and never updated; report the
    # real number of counted images instead.
    total = sum(keys_dist.values())
    return keys_dist, total
|
| 1042 |
+
|
| 1043 |
+
def build_ROI_scan_mapping(self):
    """Group the keys of the filtered records by their ROI label.

    Returns:
        Dict mapping each ROI label to the list of keys carrying it, in
        dataset order.
    """
    scans_by_roi = {}
    for path, record in self.ALLdata_filtered.items():
        scans_by_roi.setdefault(record['ROI'], []).append(path)
    return scans_by_roi
|
| 1051 |
+
|
| 1052 |
+
|
| 1053 |
+
def get_random_2_items(self, mode = 'uniform'):
    """Draw two preprocessed volumes that share a randomly chosen ROI.

    In ``'uniform'`` mode an ROI is picked uniformly among the ROIs present,
    then two scans of that ROI are drawn independently (they may be the same
    scan). Both are optionally clipped, min-max normalised, randomly
    cropped and given a leading channel axis.

    Args:
        mode: Sampling strategy. Only ``'uniform'`` is implemented; any
            other value (including ``'original'``) returns ``None``.

    Returns:
        ``(volume_A, volume_B)``, each passed through ``self.transform``
        when one is set, or ``None`` for an unimplemented mode.
    """
    if mode != 'uniform':
        # 'original' was left as a stub; keep returning None for it.
        return None

    roi_names = list(self.keys_dist.keys())
    roi = roi_names[random.randint(0, len(roi_names) - 1)]
    path_1 = random.choice(self.roi_scan_mapping[roi])
    path_2 = random.choice(self.roi_scan_mapping[roi])

    volume_A = sitk.GetArrayFromImage(sitk.ReadImage(path_1))
    volume_B = sitk.GetArrayFromImage(sitk.ReadImage(path_2))

    if self.clamp_range is not None:
        low, high = self.clamp_range[0], self.clamp_range[1]
        # Records are keyed by scan path, not by ROI label (looking the ROI
        # up raised KeyError); check each scan's own modality.
        if self.ALLdata_filtered[path_1].get("Modality", None) == "CT":
            volume_A = np.clip(volume_A, low, high)
        if self.ALLdata_filtered[path_2].get("Modality", None) == "CT":
            volume_B = np.clip(volume_B, low, high)
    volume_A = self.normalize(volume_A)
    volume_B = self.normalize(volume_B)

    if self.min_crop_ratio is not None:
        # One ratio is shared by both volumes; the cube edge is relative to
        # each volume's shortest axis, then resampled to out_sz.
        crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
        crop_size_A = int(min(volume_A.shape) * crop_ratio)
        crop_size_B = int(min(volume_B.shape) * crop_ratio)
        volume_A = self.random_crop_3d(volume_A, crop_size_A)
        volume_B = self.random_crop_3d(volume_B, crop_size_B)
        volume_A = resize(volume_A, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        volume_B = resize(volume_B, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
    else:
        # Was misspelled "radndom_crop_3d", which raised AttributeError.
        volume_A = self.random_crop_3d(volume_A, self.out_sz)
        volume_B = self.random_crop_3d(volume_B, self.out_sz)

    # Add the leading channel axis.
    volume_A = volume_A[None, :, :, :]
    volume_B = volume_B[None, :, :, :]

    if self.transform is not None:
        return self.transform(volume_A), self.transform(volume_B)
    return volume_A, volume_B
|
| 1094 |
+
|
| 1095 |
+
def build_batch(self, batch_size = 2):
    """Stack ``batch_size`` random volume pairs into two arrays.

    Args:
        batch_size: Number of pairs to draw with ``get_random_2_items``.

    Returns:
        ``(batch_1, batch_2)``: NumPy arrays holding the first and the
        second volume of every pair, in draw order.
    """
    pairs = [self.get_random_2_items() for _ in range(batch_size)]
    firsts = [first for first, _ in pairs]
    seconds = [second for _, second in pairs]
    return np.array(firsts), np.array(seconds)
|
| 1103 |
+
|
| 1104 |
+
class OminiDataset_inference_w_all(object):
|
| 1105 |
+
def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.75, ROIs = None, label_key = ['brain'], task_key = 'segmentation', database = None, select_channels_dict = {}):
    """Load the selected mapping files, filter the records and index them by ROI.

    Args:
        out_sz: Edge length of the cubic volumes that are produced.
        transform: Optional callable applied to each returned volume.
        clamp_range: ``(low, high)`` clipping window, or ``None``.
        min_crop_ratio: Lower bound of the random crop ratio.
        ROIs: ROI labels to keep; ``None`` keeps every ROI that survives the
            size filter.
        label_key: Label names handed to ``get_filter_labels``.
        task_key: Task name handed to ``get_filter_labels``.
        database: Names of the mapping files to load, e.g.
            ``['MSD', 'TotalSegmentor']``; unknown names are ignored and
            ``None`` loads every mapping file.
        select_channels_dict: Stored as-is, e.g. ``{"ImgDict": ["ed", "es"]}``.
    """
    # NOTE: the list/dict defaults are shared between instances, so they
    # must never be mutated in place.
    if database is None:
        self.mappings = mapping_files
    else:
        self.mappings = {db: mapping_files[db] for db in database if db in mapping_files}
    self.select_channels_dict = select_channels_dict

    # Metadata of every selected image, keyed by image path.
    self.ALLdata = self.combine_data(mappings=self.mappings)

    self.out_sz = out_sz
    self.label_key = label_key
    self.min_crop_ratio = min_crop_ratio
    self.transform = transform
    self.clamp_range = clamp_range
    self.ndims = 3
    # For inference the axis order is always reversed (NIfTI axis order is
    # the reverse of NumPy's).
    self.is_reverse_axis_order = True

    # Filtering pipeline: size -> ROI -> labels.
    self.ALLdata_filtered = self.get_filter_mindim()
    self.ROIs = self.get_all_ROI() if ROIs is None else ROIs
    self.ALLdata_filtered = self.get_filter_ROIs()
    self.ALLdata_filtered = self.get_filter_labels(task_key=task_key,label_keys=label_key)

    # Lookup tables used for sampling by ROI.
    self.roi_scan_mapping = self.build_ROI_scan_mapping()
    self.keys_dist, self.total = self.get_keys_dist()
|
| 1138 |
+
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
def get_all_ROI(self):
    """Return the set of distinct ROI labels of the currently filtered images.

    Keys are taken from ``self.ALLdata_filtered``; the labels themselves are
    looked up in ``self.ALLdata``.
    """
    return {self.ALLdata[path]['ROI'] for path in self.ALLdata_filtered.keys()}
|
| 1148 |
+
|
| 1149 |
+
def get_keys_dist(self):
    """Count the filtered images per ROI.

    Returns:
        ``(keys_dist, total)``: a dict mapping each ROI label to its number
        of images, and the overall number of images counted.
    """
    keys_dist = {}
    for record in self.ALLdata_filtered.values():
        roi = record['ROI']
        keys_dist[roi] = keys_dist.get(roi, 0) + 1
    # The total used to be initialised to 0 and never updated; report the
    # real number of counted images instead.
    total = sum(keys_dist.values())
    return keys_dist, total
|
| 1159 |
+
|
| 1160 |
+
def build_ROI_scan_mapping(self):
    """Group the filtered scan keys by their ROI.

    Returns:
        dict: ROI value -> list of keys of ``self.ALLdata_filtered`` having
        that ROI, in iteration order.
    """
    scans_by_roi = {}
    for scan, meta in self.ALLdata_filtered.items():
        # setdefault creates the bucket the first time an ROI is seen.
        scans_by_roi.setdefault(meta['ROI'], []).append(scan)
    return scans_by_roi
|
| 1168 |
+
|
| 1169 |
+
def get_3D_volume(self, volume, select_channel = None):
    """Return a 3-D volume, picking one channel when the input is 4-D.

    Args:
        volume: array to process. When ``self.is_reverse_axis_order`` is set it
            is first passed through ``reverse_axis_order``.
        select_channel: index along the 4th axis to keep. When ``None`` a
            channel is drawn uniformly at random. Ignored unless the (possibly
            axis-reversed) volume has 4 dimensions.

    Returns:
        The selected 3-D slice for 4-D input, otherwise the volume itself.
    """
    volume = reverse_axis_order(volume) if self.is_reverse_axis_order else volume
    if volume.ndim == 4:
        if select_channel is None:
            # Fixed: the upper bound of np.random.randint is exclusive, so the
            # former `shape[3] - 1` never selected the last channel and raised
            # ValueError for single-channel volumes.
            select_channel = np.random.randint(0, volume.shape[3])
        volume = volume[:, :, :, select_channel]
    return volume
|
| 1177 |
+
|
| 1178 |
+
def get_filter_mindim(self):
    """Return the entries of ``self.ALLdata`` that are large enough.

    An entry is dropped when the smallest of its first ``self.ndims`` values
    in ``'Size'`` is below ``self.out_sz / 2``. ``self.ALLdata`` itself is
    left untouched; a new dict is returned.
    """
    threshold = self.out_sz/2
    return {
        scan: meta
        for scan, meta in self.ALLdata.items()
        if not (min(meta['Size'][:self.ndims]) < threshold)
    }
|
| 1186 |
+
|
| 1187 |
+
def find_min_dim(self):
    """Return the smallest ``'Size'`` value across all entries of ``self.ALLdata``.

    The search starts from 100000, which is therefore also the result when no
    entry has a smaller dimension (or when there are no entries).
    """
    smallest = 100000
    for meta in self.ALLdata.values():
        smallest = min(smallest, min(meta['Size']))
    return smallest
|
| 1195 |
+
|
| 1196 |
+
# def combine_data(self):
|
| 1197 |
+
# ALLdata = {}
|
| 1198 |
+
# for j in self.mappings.keys():
|
| 1199 |
+
# with open(self.mappings[j], 'r') as f:
|
| 1200 |
+
# mappings = json.load(f)
|
| 1201 |
+
# ALLdata.update(mappings)
|
| 1202 |
+
# return ALLdata
|
| 1203 |
+
|
| 1204 |
+
def combine_data(self, mappings = mapping_files):
    """Merge several JSON mapping files into one dictionary.

    Args:
        mappings: dict whose values are paths to JSON files. Each file is
            loaded and merged with ``dict.update``, so on key clashes the
            file visited later wins.

    Returns:
        dict: the union of all loaded mappings.
    """
    merged = {}
    for json_path in mappings.values():
        with open(json_path, 'r') as handle:
            loaded = json.load(handle)
        merged.update(loaded)
    return merged
|
| 1211 |
+
|
| 1212 |
+
def normalize(self, volume, eps=1e-7):
    """Min-max scale ``volume`` to roughly the 0-1 range as float64.

    Args:
        volume: numpy array to scale.
        eps: added to the value range so a constant volume does not divide
            by zero.
    """
    as_float = volume.astype(np.float64)
    shifted = as_float - np.min(as_float)
    value_range = np.ptp(as_float) + eps
    return shifted / value_range
|
| 1217 |
+
|
| 1218 |
+
def get_key_by_ROI(self, key_word):
    """Return the keys of ``self.ALLdata_filtered`` whose ``"ROI"`` equals ``key_word``."""
    return [
        scan
        for scan, meta in self.ALLdata_filtered.items()
        if key_word == meta["ROI"]
    ]
|
| 1225 |
+
|
| 1226 |
+
def get_filter_task(self, task_key = 'segmentation'):
    """Return the filtered entries that carry labels for ``task_key``.

    An entry is kept only when it has a ``'Label_path'`` item that contains
    ``task_key``. A warning is emitted for every entry that is dropped.
    ``self.ALLdata_filtered`` itself is not modified.

    Args:
        task_key: task name to look for inside ``'Label_path'``.

    Returns:
        dict: a new dict with the surviving entries.
    """
    # Local import keeps this method self-contained.
    import warnings

    ALLdata_filtered = self.ALLdata_filtered.copy()
    for k in self.ALLdata_filtered.keys():
        if 'Label_path' not in self.ALLdata_filtered[k] or task_key not in self.ALLdata_filtered[k]['Label_path']:
            del ALLdata_filtered[k]
            # Fixed: `Warning(...)` only built an exception object and threw
            # it away, so the message was never shown.
            warnings.warn(f"Label path not found for {k} with task key {task_key}. This image will be removed from the dataset.")
    return ALLdata_filtered
|
| 1234 |
+
|
| 1235 |
+
def get_filter_labels(self, task_key='segmentation', label_keys=['heart']):
    """Keep only the entries that offer at least one of the requested labels.

    Args:
        task_key: task name looked up inside each entry's ``'Label_path'``.
        label_keys: label names; an entry survives when any label name stored
            under ``Label_path[task_key]`` is contained in it. The default
            list is only read, never modified.

    Returns:
        dict: a new dict with the surviving entries of
        ``self.ALLdata_filtered`` (which is left unchanged).
    """
    kept = {}
    for scan, meta in self.ALLdata_filtered.items():
        # Missing 'Label_path' or missing task simply yields no labels.
        available = meta.get('Label_path', {}).get(task_key, {})
        if any((name in label_keys) for name in available.keys()):
            kept[scan] = meta
    return kept
|
| 1252 |
+
|
| 1253 |
+
def get_random_pad_crop_params(self, volume_shape, crop_size=None, random=True):
    """Compute padding and cubic-crop parameters for a volume shape.

    Args:
        volume_shape: shape whose first three entries are used.
        crop_size: edge length of the cubic crop; defaults to ``self.out_sz``.
        random: when true the crop origin is drawn at random, otherwise the
            crop is centred. The split of any required padding between the
            two sides of an axis is always random.

    Returns:
        tuple: ``(pad_width, crop_slices)``. ``pad_width`` is a list of three
        ``(before, after)`` pairs; ``crop_slices`` is
        ``(start_d, start_h, start_w, crop_d, crop_h, crop_w)``.
    """
    d, h, w = volume_shape[:3]
    extents = (d, h, w)
    edge = self.out_sz if crop_size is None else crop_size

    # Axes shorter than the crop get padded up to exactly the crop length.
    pad_width = []
    for extent in extents:
        if edge > extent:
            deficit = edge - extent
            before = np.random.randint(0, deficit + 1)
            pad_width.append((before, deficit - before))
        else:
            pad_width.append((0, 0))

    # Axis lengths once the padding has been applied.
    padded = [extent + lo + hi for extent, (lo, hi) in zip(extents, pad_width)]

    if random:
        starts = [
            np.random.randint(0, length - edge + 1) if length > edge else 0
            for length in padded
        ]
    else:
        starts = [max((length - edge) // 2, 0) for length in padded]

    crop_slices = (starts[0], starts[1], starts[2], edge, edge, edge)
    return pad_width, crop_slices
|
| 1287 |
+
|
| 1288 |
+
def apply_pad_crop(self, volume, pad_width, crop_slices):
    """Zero-pad ``volume`` if requested, then cut out the crop window.

    Args:
        volume: numpy array to process.
        pad_width: per-axis ``(before, after)`` pairs handed to ``np.pad``;
            padding is skipped when every pair is ``(0, 0)``.
        crop_slices: ``(start_d, start_h, start_w, crop_d, crop_h, crop_w)``.

    Returns:
        The cropped array; only the first three axes are sliced.
    """
    needs_padding = False
    for pad in pad_width:
        if pad != (0, 0):
            needs_padding = True
            break
    if needs_padding:
        volume = np.pad(volume, pad_width, mode='constant', constant_values=0)

    start_d, start_h, start_w, crop_d, crop_h, crop_w = crop_slices
    window = (
        slice(start_d, start_d + crop_d),
        slice(start_h, start_h + crop_h),
        slice(start_w, start_w + crop_w),
    )
    return volume[window]
|
| 1295 |
+
|
| 1296 |
+
def get_filter_ROIs(self):
    """Return the filtered entries whose ``'ROI'`` is contained in ``self.ROIs``.

    ``self.ALLdata_filtered`` is left unchanged; a new dict is returned.
    """
    return {
        scan: meta
        for scan, meta in self.ALLdata_filtered.items()
        if meta['ROI'] in self.ROIs
    }
|
| 1302 |
+
|
| 1303 |
+
def get_channel_ids(self, key):
    """
    Look up the channel indices of the selected channel names for one scan.

    The scan's ``"ImgDict"`` maps an index (converted with ``int``) to a
    channel name; ``self.select_channels_dict["ImgDict"]`` lists the wanted
    names (e.g., 'ed', 'es').

    Returns:
        list: integer indices in the order of the wanted names; names the
        scan does not provide are skipped.
    """
    available = self.ALLdata_filtered[key].get("ImgDict", {})
    wanted = self.select_channels_dict.get("ImgDict", [])

    # Invert the scan's mapping: channel name -> integer index.
    index_of = {}
    for idx, name in available.items():
        index_of[name] = int(idx)

    matches = []
    for name in wanted:
        if name in index_of:
            matches.append(index_of[name])
    return matches
|
| 1322 |
+
|
| 1323 |
+
def __len__(self):
    """Return the number of scans that survived filtering."""
    return len(self.ALLdata_filtered)
|
| 1325 |
+
|
| 1326 |
+
def __getitem__(self, idx):
    """Load one scan with its segmentation labels, cropped and resized.

    The scan is read with SimpleITK, reduced to 3-D via ``get_3D_volume``,
    clipped to ``self.clamp_range`` when its ``"Modality"`` is ``"CT"``,
    min-max normalised, randomly padded/cropped and finally resized to
    ``self.out_sz`` per axis. Every label in ``self.label_key`` goes through
    the same pad/crop parameters and is resized with nearest-neighbour
    interpolation; labels that are missing are filled with -1.

    Args:
        idx: position of the scan in ``self.ALLdata_filtered``.

    Returns:
        dict: ``'img'`` (volume with a leading channel axis), ``'labels'``
        (all label arrays concatenated) and ``'label_channels'`` (the channel
        names from ``self.select_channels_dict["ImgDict"]``).
    """
    # Local import keeps this method self-contained.
    import warnings

    key = list(self.ALLdata_filtered.keys())[idx]
    return_dict = dict()

    print(f"Processing key: {key}")

    volume = sitk.ReadImage(key)
    volume = sitk.GetArrayFromImage(volume)

    # Fixed: both names used to be bound only for 4-D scans, so a 3-D scan
    # hit an unbound local further down. These defaults mean "no channel
    # selection".
    channel_ids = []
    channel_id = None
    if volume.ndim == 4:
        channel_ids = self.get_channel_ids(key)
        if len(channel_ids) == 0:
            # Fixed: `Warning(...)` only constructed an object; the message
            # is now actually emitted. get_3D_volume then picks a random channel.
            warnings.warn(f"No matching channels found for key: {key} with ImgDict: {self.ALLdata_filtered[key].get('ImgDict', {})} and selected channels: {self.select_channels_dict.get('ImgDict', [])}. Using random channel.")
        else:
            channel_id = channel_ids[0]

    volume = self.get_3D_volume(volume, select_channel = channel_id)

    if self.clamp_range is not None:
        modality = self.ALLdata_filtered[key].get("Modality", None)
        if modality == "CT":
            volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
    volume = self.normalize(volume)

    crop_ratio = np.random.uniform(self.min_crop_ratio, 1)

    # The crop is a cube whose edge is a fraction of the longest axis.
    crop_size = int(max(volume.shape) * crop_ratio)
    pad_width, crop_slices = self.get_random_pad_crop_params(volume.shape, crop_size)
    volume = self.apply_pad_crop(volume, pad_width, crop_slices)

    label_dict = dict()
    if 'Label_path' in self.ALLdata_filtered[key]:
        for lk in self.label_key:
            if lk in self.ALLdata_filtered[key]['Label_path']['segmentation'].keys():
                label = sitk.ReadImage(self.ALLdata_filtered[key]['Label_path']['segmentation'][lk])
                label = sitk.GetArrayFromImage(label)
                label = reverse_axis_order(label) if self.is_reverse_axis_order else label

                if label.ndim > self.ndims:
                    if len(channel_ids) != 0:
                        label = label[...,channel_ids] # assuming channel last
                    # Extra (channel) axes are never padded.
                    pad_width_lab = pad_width + [(0,0)]*(label.ndim - self.ndims)
                else:
                    pad_width_lab = pad_width
                label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
                # order=0 (nearest neighbour) keeps label values discrete.
                label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
                if label.ndim > self.ndims:
                    if self.ndims==3:
                        label_dict[lk] = np.transpose(label_dict[lk], (3,0,1,2)) # assuming channel last
                    elif self.ndims==4:
                        label_dict[lk] = np.transpose(label_dict[lk], (4,0,1,2,3)) # assuming channel last
            else:
                label_dict[lk] = np.full([self.out_sz]*self.ndims, -1)
                warnings.warn(f"Label path not found for {key} with label key {lk}.")
            label_dict[lk] = label_dict[lk][None, :, :, :] if label_dict[lk].ndim == 3 else label_dict[lk]
    else:
        for lk in self.label_key:
            label_dict[lk] = np.full([self.out_sz]*self.ndims, -1)
            warnings.warn(f"Label path not found for {key} with label key {lk}.")
            label_dict[lk] = label_dict[lk][None, :, :, :]

    volume =resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
    # NOTE(review): the label arrays carry their channel axis first, yet they
    # are joined along axis 1 - confirm this is the layout consumers expect.
    return_dict['labels'] = np.concatenate([v for v in label_dict.values()], axis=1)

    return_dict['img'] = volume[None, :, :, :]
    return_dict['label_channels'] = list(self.select_channels_dict.get("ImgDict", []))
    return return_dict
|
| 1402 |
+
|
| 1403 |
+
|
| 1404 |
+
class OminiDataset_bertembd(OminiDataset):
    """OminiDataset variant whose items pair a volume with a stored embedding.

    Each item is read from the scan path used as key in
    ``self.ALLdata_filtered``; the embedding comes from that entry's
    ``'embd'`` field.
    """

    def __init__(self,
                 out_sz = 128,
                 transform=None,
                 clamp_range = CLAMP_RANGE,
                 min_crop_ratio = 0.85,
                 ROIs = None,
                 modality = None,
                 reverse_axis_order = False,
                 min_dim = 3,
                 mapping_files = mapping_files):
        """Initialise the base dataset, then rebuild the filtered scan table.

        All arguments are forwarded unchanged to ``OminiDataset.__init__``.
        ``ROIs`` additionally restricts the scans kept here; ``None`` keeps
        every ROI found after the minimum-size filter.
        """
        # Fixed: this used to call `super().init(...)`, which is not the
        # constructor and fails with AttributeError unless the base class
        # happens to define a method called `init`.
        super().__init__(out_sz = out_sz,
                         transform = transform,
                         clamp_range = clamp_range,
                         min_crop_ratio = min_crop_ratio,
                         ROIs = ROIs,
                         modality = modality,
                         reverse_axis_order = reverse_axis_order,
                         min_dim = min_dim,
                         mapping_files=mapping_files)
        # start your filtering here
        self.ALLdata_filtered = self.get_filter_mindim()
        if ROIs is None:
            # if no ROIs are provided, get all the ROIs from filtered data
            self.ROIs = self.get_all_ROI()
        else:
            self.ROIs = ROIs
        self.ALLdata_filtered = self.get_filter_ROIs()
        # self.ALLdata_filtered = self.filter_embd()
        # self.ALLdata_filtered = self.get_filter_labels(task_key=task_key,label_keys=label_key)
        # end your filtering here

    def __getitem__(self, idx):
        """Return ``(volume, embedding)`` for the scan at position ``idx``.

        The volume is reduced to 3-D, clipped for CT scans when a clamp range
        is set, normalised and cropped. With ``self.min_crop_ratio`` set, a
        random cubic crop is taken and resized to ``self.out_sz``; otherwise
        a crop of edge ``self.out_sz`` is taken directly. A leading channel
        axis is added at the end.

        NOTE(review): when ``self.transform`` is set, only the transformed
        volume is returned and the embedding is dropped - confirm that this
        is intended.
        """
        key = list(self.ALLdata_filtered.keys())[idx]
        embd = self.ALLdata_filtered[key]['embd']

        volume = sitk.ReadImage(key)
        volume = sitk.GetArrayFromImage(volume)
        volume = self.get_3D_volume(volume)

        if self.clamp_range is not None:
            modality = self.ALLdata_filtered[key].get("Modality", None)
            if modality == "CT":
                volume = np.clip(volume, self.clamp_range[0], self.clamp_range[1])
        volume = self.normalize(volume)

        if self.min_crop_ratio is not None:
            crop_ratio = np.random.uniform(self.min_crop_ratio, 1)
            crop_size = int(max(volume.shape) * crop_ratio)
            volume = self.random_crop_3d(volume, crop_size)
            volume = resize(volume, [self.out_sz]*self.ndims, anti_aliasing = True, preserve_range = True)
        else:
            volume = self.random_crop_3d(volume, self.out_sz)
        volume = volume[None, :, :, :]

        if self.transform is not None:
            return self.transform(volume)

        return volume,np.array(embd)

    def __len__(self):
        """Return the number of scans that survived filtering."""
        return len(self.ALLdata_filtered.keys())

    def filter_embd(self):
        """Drop, in place, every scan whose ``'Metadata'`` lacks ``'BERT_embedding_keys'``.

        Returns:
            dict: ``self.ALLdata_filtered`` after the removal.
        """
        # Fixed: entries were deleted while iterating over the same dict,
        # which raises "dictionary changed size during iteration". The keys
        # to drop are now collected first.
        missing = [
            k for k in self.ALLdata_filtered.keys()
            if 'BERT_embedding_keys' not in self.ALLdata_filtered[k]['Metadata']
        ]
        for k in missing:
            del self.ALLdata_filtered[k]
        return self.ALLdata_filtered
|
| 1473 |
+
|
Dataloader/dataloader0.py
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
from torchvision import datasets, transforms
|
| 6 |
+
import nibabel as nib
|
| 7 |
+
from skimage.transform import rescale, resize, downscale_local_mean
|
| 8 |
+
from scipy.ndimage import zoom
|
| 9 |
+
import numpy as np
|
| 10 |
+
# import SimpleITK as sitk
|
| 11 |
+
|
| 12 |
+
# print(os.getcwd())
|
| 13 |
+
import sys
|
| 14 |
+
sys.path.append('./')
|
| 15 |
+
from Dataloader.dataloader_utils import *
|
| 16 |
+
|
| 17 |
+
# Small constant added to the denominator of the min-max normalisations in this
# file so that a constant-valued image does not cause a division by zero.
EPS = 1e-7
|
| 18 |
+
|
| 19 |
+
def get_dataloader(data_name='cmr',mode='train'):
    """Return the dataset class for a data set name and usage mode.

    Args:
        data_name: ``'cmr'`` or ``'lct'``.
        mode: ``'train'`` for the source-data loader or ``'aug'`` for the
            target-data loader.

    Returns:
        The dataset class itself (not an instance): ``CMR_loader``,
        ``CMR_tgt_loader``, ``LCT_loader`` or ``LCT_tgt_loader``.

    Raises:
        ValueError: if ``data_name`` or ``mode`` is not one of the supported
            values. Previously the function printed a message and then failed
            with ``UnboundLocalError`` on the return statement.
    """
    if data_name not in ('cmr', 'lct'):
        raise ValueError(f"dataloader not exist: data_name={data_name!r}, expected 'cmr' or 'lct'")
    if mode not in ('train', 'aug'):
        raise ValueError(f"mode not exist: mode={mode!r}, expected 'train' or 'aug'")

    if data_name == 'cmr':
        return CMR_loader if mode == 'train' else CMR_tgt_loader
    return LCT_loader if mode == 'train' else LCT_tgt_loader
|
| 37 |
+
|
| 38 |
+
class LCT_loader(Dataset):
    """Dataset over the ``.npy`` volumes found in one directory.

    ``target_res`` and ``patient_index`` are accepted but not used by this
    class.
    """

    def __init__(self, data_root_path = 'Data/Src_data/CTLung_processed/', target_res = (256, 256),transforms = None, noise_scale=0.0, patient_index = None):
        # def __init__(self, data_root_path = '/home/data/jzheng/CTLung_processed/', target_res = (256, 256),transforms = None, noise_scale=0.0, patient_index = None):
        # os.path.join instead of string concatenation, so the directory may
        # be given with or without a trailing separator.
        self.files = [os.path.join(data_root_path, f) for f in os.listdir(data_root_path) if f.endswith('.npy')]
        self.transforms = transforms
        self.noise_scale=noise_scale
        self.d_p = data_root_path

    def __getitem__(self, item):
        """Return ``(array, array, item)`` for the file at index ``item``.

        The array is min-max normalised unless the data directory path
        contains the substring ``'process'``, then gets a leading channel
        axis and, if set, is passed through ``self.transforms``.
        """
        array = np.load(self.files[item])
        if 'process' not in self.d_p:
            array = (array - array.min()) / (array.max() - array.min() + EPS) # Normalize to 0 to 1
        array = array[None,:,:,:] # add a leading channel axis in front of the three spatial axes
        if self.transforms is not None:
            array = self.transforms(array)
        return array, array, item

    def __len__(self):
        """Return the number of ``.npy`` files found."""
        return len(self.files)
|
| 59 |
+
|
| 60 |
+
class LCT_tgt_loader(Dataset):
    """Dataset of image/mask pairs read from the ``Tr`` and ``Gt`` sub-folders.

    Both file lists are sorted, so the i-th image is paired with the i-th
    mask. ``patient_index`` is accepted but not used by this class.
    """

    def __init__(self, data_root_path = "Data/Tgt_data/lct/",noise_scale=0.0, patient_index = None, transforms = None):
        """
        Args:
            data_root_path: directory containing the ``Gt`` and ``Tr`` folders.
            noise_scale: stored on the instance; not used by this class.
            patient_index: unused.
            transforms: stored on the instance; ``__getitem__`` does not apply
                it. Added because the attribute used to be bound to the
                imported ``transforms`` module by accident.
        """
        gt_dir = os.path.join(data_root_path, "Gt")
        tr_dir = os.path.join(data_root_path, "Tr")
        self.files_gt = [os.path.join(gt_dir, f) for f in os.listdir(gt_dir)]
        self.files_tr = [os.path.join(tr_dir, f) for f in os.listdir(tr_dir)]

        self.files_tr.sort()
        self.files_gt.sort()

        self.transforms = transforms
        self.noise_scale=noise_scale

    def __getitem__(self, item):
        """Return ``(image, mask, item)`` with a leading channel axis on both arrays."""
        img_nib = nib.load(self.files_tr[item])
        mask_nib = nib.load(self.files_gt[item])

        image = img_nib.get_fdata()
        mask = mask_nib.get_fdata()

        image = image[None,:,:,:]
        mask = mask[None,:,:,:]

        print(self.files_tr[item],self.files_gt[item])

        return image, mask, item

    def __len__(self):
        """Return the number of pairs; both folders must hold equally many files."""
        assert len(self.files_gt) == len(self.files_tr)
        return len(self.files_gt)
|
| 90 |
+
|
| 91 |
+
class LCT_seg(Dataset):
    """Dataset of image/mask pairs read from the ``Tr`` and ``Gt`` sub-folders.

    Both file lists are sorted, so the i-th image is paired with the i-th
    mask. ``patient_index`` is accepted but not used by this class.
    """

    def __init__(self, data_root_path = "/home/data/jzheng/CTLung_processed/testset/modality_0001/",noise_scale=0.0, patient_index = None, transforms = None):
        """
        Args:
            data_root_path: directory containing the ``Gt`` and ``Tr`` folders.
            noise_scale: stored on the instance; not used by this class.
            patient_index: unused.
            transforms: stored on the instance; ``__getitem__`` does not apply
                it. Added because the attribute used to be bound to the
                imported ``transforms`` module by accident.
        """
        gt_dir = os.path.join(data_root_path, "Gt")
        tr_dir = os.path.join(data_root_path, "Tr")
        self.files_gt = [os.path.join(gt_dir, f) for f in os.listdir(gt_dir)]
        self.files_tr = [os.path.join(tr_dir, f) for f in os.listdir(tr_dir)]

        self.files_tr.sort()
        self.files_gt.sort()

        self.transforms = transforms
        self.noise_scale=noise_scale

    def __getitem__(self, item):
        """Return ``(image, mask, item)`` with a leading channel axis on both arrays."""
        img_nib = nib.load(self.files_tr[item])
        mask_nib = nib.load(self.files_gt[item])

        image = img_nib.get_fdata()
        mask = mask_nib.get_fdata()

        image = image[None,:,:,:]
        mask = mask[None,:,:,:]

        print(self.files_tr[item],self.files_gt[item])

        return image, mask, item

    def __len__(self):
        """Return the number of pairs; both folders must hold equally many files."""
        assert len(self.files_gt) == len(self.files_tr)
        return len(self.files_gt)
|
| 121 |
+
|
| 122 |
+
class CMR_loader_preprocess(Dataset):
    """Dataset used for offline pre-processing of CMR images (not for training).

    NOTE(review): the default ``data_path`` points at a ``CTLung_processed``
    folder although the class is meant for CMR data - confirm the default.
    """

    def __init__(self, data_path = 'Data/CTLung_processed/', target_res = (256, 256), transforms = None, noise_scale=0.0):
        # def __init__(self, data_path = '/home/data/jzheng/CMR_processed/', target_res = (256, 256), transforms = None, noise_scale=0.0):
        self.d_p = data_path
        self.target_res = target_res
        # os.path.join instead of string concatenation, so the directory may
        # be given with or without a trailing separator.
        self.files = [os.path.join(self.d_p, x) for x in os.listdir(self.d_p)]
        self.transforms = transforms
        self.noise_scale=noise_scale

    def __getitem__(self, item):
        """Return ``(array, file_path)`` for the image at index ``item``.

        The image is resized to ``self.target_res``, given a leading channel
        axis, passed through ``remove_background`` and min-max normalised.
        With a positive ``noise_scale`` it is additionally thresholded with
        ``thresh_img`` and multiplied by one random factor drawn around 1.
        """
        array = nib.load(self.files[item]).get_fdata()
        array = resize(array, self.target_res, anti_aliasing = True, preserve_range = True)
        array = array[None, :, :]
        array = remove_background(array) # jzheng 20240228
        array = (array - array.min()) / (array.max() - array.min() + EPS)

        if self.noise_scale > 0:
            array = thresh_img(array,[0,self.noise_scale])
            array = array * (np.random.normal(1, self.noise_scale*2))

        if self.transforms is not None:
            array = self.transforms(array)
        return array, self.files[item]

    def __len__(self):
        """Return the number of files in the data directory."""
        return len(self.files)
|
| 149 |
+
|
| 150 |
+
class CMR_loader(Dataset):
    """Dataset over 2-D CMR images stored as NIfTI files of size (H, W).

    A data path containing ``'resize'`` (such as the default
    ``CMR_processed_rmbg_resize``) marks images that already went through
    background removal and resizing offline; those two steps are then
    skipped at load time to make training faster.
    """

    def __init__(self, data_path = 'Data/Src_data/CMR_processed_rmbg_resize/', target_res = (256, 256), transforms = None, noise_scale=0.0):
        # def __init__(self, data_path = '/home/data/jzheng/CMR_processed_rmbg_resize/', target_res = (256, 256), transforms = None, noise_scale=0.0):
        self.d_p = data_path
        self.ndims = 2
        self.target_res = target_res
        # os.path.join instead of string concatenation, so the directory may
        # be given with or without a trailing separator.
        self.files = [os.path.join(self.d_p, x) for x in os.listdir(self.d_p)]
        self.transforms = transforms
        # self.get_transform()
        self.noise_scale=noise_scale
        self.preprocessed='resize' in data_path

    def __getitem__(self, item):
        """Return ``(array, array, item)`` for the image at index ``item``.

        Images that were not pre-processed are resized and passed through
        ``remove_background``; every image gets a leading channel axis and is
        min-max normalised before the optional ``self.transforms``.
        """
        array = nib.load(self.files[item]).get_fdata()
        if not self.preprocessed:
            array = resize(array, self.target_res, anti_aliasing = True, preserve_range = True)
        array = array[None, :, :]
        if not self.preprocessed:
            array = remove_background(array) # jzheng 20240228
        array = (array - array.min()) / (array.max() - array.min() + EPS)

        # if self.noise_scale > 0:
        #     array = thresh_img(array,[0,self.noise_scale])
        #     array = array * (np.random.normal(1, self.noise_scale*2)) + np.random.normal(0, self.noise_scale*2)

        if self.transforms is not None:
            array = self.transforms(array)
        return array, array, item

    def __len__(self):
        """Return the number of files in the data directory."""
        return len(self.files)

    def get_transform(self,degrees=np.pi,translate=0.125):
        """Install a ToTensor + RandomAffine pipeline as ``self.transforms``.

        NOTE(review): ``degrees`` defaults to ``np.pi``; check against the
        torchvision RandomAffine documentation whether the value is meant in
        degrees or radians.
        """
        # self.transforms = torchvision.transforms.RandomAffine(degrees=degrees,translate=[translate]*self.ndims,interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        self.transforms = torchvision.transforms.Compose([
            # torchvision.transforms.Resize((hyp_parameters['img_size'], hyp_parameters['img_size'])),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.RandomAffine(degrees=degrees,translate=[translate]*self.ndims,interpolation=torchvision.transforms.InterpolationMode.BILINEAR),
            # torchvision.transforms.ToTensor(),
            # torchvision.transforms.Normalize(0.5, 0.5)
            # Lambda(lambda x: (x - 0.5) * 2)
        ])
        return
|
| 195 |
+
|
| 196 |
+
class CMR_tgt_loader(Dataset):
    """In-memory dataset of 2-D CMR slices with four per-class masks.

    All volumes of the requested patients are read from the ``Tr`` (images)
    and ``Gt`` (ground truth) sub-folders of ``data_path`` at construction
    time and split into slices along the third axis.
    """

    def __init__(self,
                 data_path = 'Data/Tgt_data/cmr/',
                 # gt_path = '/home/data/jzheng/acdc/train_gt/',
                 target_res = (256,256),
                 is_3d = False,
                 patient_index = [],
                 ):
        """
        Args:
            data_path: directory holding the ``Tr`` and ``Gt`` folders.
            target_res: 2-D size every slice and mask is resized to.
            is_3d: accepted but not used by this class.
            patient_index: patient ids to load; an empty sequence selects
                ids 1 to 100. The default list is never modified.
        """
        # parameter initialize
        self.d_p = os.path.join(data_path,'Tr','')
        self.gt_p = os.path.join(data_path,'Gt','')
        self.img_files = sorted(os.listdir(self.d_p))
        self.gt_files = sorted(os.listdir(self.gt_p))
        self.p_indice = patient_index
        self.target_res_2d = target_res
        self.img_samples = []
        self.gt_samples = []
        self.p_id = []

        if len(self.p_indice) == 0:
            self.p_indice = list(range(1,101))
        # The sorted file lists are paired position by position.
        assert len(self.gt_files) == len(self.img_files)
        print(self.p_indice)
        file_pairs = list(zip(self.gt_files, self.img_files))
        for patient in self.p_indice:
            for gt_name, img_name in file_pairs:
                # The patient id is the last three characters of the text
                # before the first underscore of the ground-truth file name.
                if patient != int(gt_name.split('_')[0][-3:]):
                    continue
                img_volume = nib.load(self.d_p + img_name).get_fdata()
                gt_volume = nib.load(self.gt_p + gt_name).get_fdata()
                assert img_volume.shape == gt_volume.shape
                for si in range(img_volume.shape[2]):
                    img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
                    lowest = img.min()
                    img = (img - lowest) / (img.max() - lowest + EPS)

                    gt = gt_volume[:, :, si]
                    # One resized mask per label value 1..4; pixels of the
                    # class keep the label value, all others become 0.
                    masks = [
                        resize(gt * (gt == label), self.target_res_2d, anti_aliasing=True, preserve_range=True)
                        for label in (1, 2, 3, 4)
                    ]

                    self.img_samples.append(img[np.newaxis, :, :])
                    self.gt_samples.append(np.array(masks))
                    self.p_id.append(patient)

    def __getitem__(self, item):
        """Return ``(image, masks, patient_id)`` for the slice at index ``item``."""
        return self.img_samples[item], self.gt_samples[item], self.p_id[item]

    def __len__(self):
        """Return the number of slices held in memory."""
        assert len(self.img_samples) == len(self.gt_samples)
        return len(self.img_samples)
|
| 270 |
+
|
| 271 |
+
class acdc_seg(Dataset):
|
| 272 |
+
def __init__(self,
|
| 273 |
+
data_path = '/home/data/jzheng/acdc/train_images/',
|
| 274 |
+
gt_path = '/home/data/jzheng/acdc/train_gt/',
|
| 275 |
+
target_res = (256,256),
|
| 276 |
+
is_3d = False,
|
| 277 |
+
patient_index = [],
|
| 278 |
+
):
|
| 279 |
+
|
| 280 |
+
# parameter initialize
|
| 281 |
+
self.d_p = data_path
|
| 282 |
+
self.gt_p = gt_path
|
| 283 |
+
self.img_files = os.listdir(self.d_p)
|
| 284 |
+
self.gt_files = os.listdir(self.gt_p)
|
| 285 |
+
self.p_indice = patient_index
|
| 286 |
+
self.target_res_2d = target_res
|
| 287 |
+
self.img_files.sort()
|
| 288 |
+
self.gt_files.sort()
|
| 289 |
+
self.img_samples = []
|
| 290 |
+
self.gt_samples = []
|
| 291 |
+
self.p_id = []
|
| 292 |
+
|
| 293 |
+
if len(self.p_indice) == 0:
|
| 294 |
+
self.p_indice = [x for x in range(1,101)]
|
| 295 |
+
# build patient-to-file correspondence
|
| 296 |
+
p2f = {}
|
| 297 |
+
assert len(self.gt_files) == len(self.img_files)
|
| 298 |
+
print(self.p_indice)
|
| 299 |
+
for i in self.p_indice:
|
| 300 |
+
for gt_f, img_f in zip(self.gt_files, self.img_files):
|
| 301 |
+
pf_id = gt_f.split('_')[0]
|
| 302 |
+
pf_id = pf_id[-3:]
|
| 303 |
+
if i == int(pf_id):
|
| 304 |
+
img_volume = nib.load(self.d_p + img_f).get_fdata()
|
| 305 |
+
gt_volume = nib.load(self.gt_p + gt_f).get_fdata()
|
| 306 |
+
assert img_volume.shape == gt_volume.shape
|
| 307 |
+
depth = img_volume.shape[2]
|
| 308 |
+
for si in range(depth):
|
| 309 |
+
img = resize(img_volume[:, :, si], self.target_res_2d, anti_aliasing=True, preserve_range=True)
|
| 310 |
+
img = (img - img.min()) / (img.max() - img.min() + EPS)
|
| 311 |
+
|
| 312 |
+
gt = gt_volume[:, :, si]
|
| 313 |
+
|
| 314 |
+
gt_1_index = gt == 1
|
| 315 |
+
gt_2_index = gt == 2
|
| 316 |
+
gt_3_index = gt == 3
|
| 317 |
+
gt_4_index = gt == 4
|
| 318 |
+
|
| 319 |
+
gt_1 = gt * gt_1_index
|
| 320 |
+
gt_2 = gt * gt_2_index
|
| 321 |
+
gt_3 = gt * gt_3_index
|
| 322 |
+
gt_4 = gt * gt_4_index
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
gt_1 = resize(gt_1, self.target_res_2d, anti_aliasing=True, preserve_range=True)
|
| 326 |
+
gt_2 = resize(gt_2, self.target_res_2d, anti_aliasing=True, preserve_range=True)
|
| 327 |
+
gt_3 = resize(gt_3, self.target_res_2d, anti_aliasing=True, preserve_range=True)
|
| 328 |
+
gt_4 = resize(gt_4, self.target_res_2d, anti_aliasing=True, preserve_range=True)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
self.img_samples.append(img[np.newaxis, :, :])
|
| 332 |
+
self.gt_samples.append(np.array([gt_1, gt_2, gt_3, gt_4]))
|
| 333 |
+
self.p_id.append(i)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def __getitem__(self, item):
    """Return the ``(image, label, patient_id)`` triple stored at ``item``.

    The three values come from the parallel lists ``img_samples``,
    ``gt_samples`` and ``p_id`` that ``__init__`` fills slice by slice.
    """
    image = self.img_samples[item]
    label = self.gt_samples[item]
    patient = self.p_id[item]
    return image, label, patient
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def __len__(self):
    """Return the number of stored slices.

    Images and labels are appended in lock-step, so a length mismatch means
    the sample lists were corrupted; the assertion guards that invariant.
    """
    n_images = len(self.img_samples)
    assert n_images == len(self.gt_samples)
    return n_images
|
| 345 |
+
|
| 346 |
+
class acdc_gan(Dataset):
    """ACDC image dataset yielding ``(image, image)`` pairs.

    Two modes are supported:

    * 2D (``is_3d=False``): every volume under ``train_path`` is loaded in
      ``__init__``; the slices between 10% and 90% of the depth axis are
      resized to ``target_res[1:]``, min-max normalised and cached in memory
      with a leading channel axis.  ``transforms`` is not applied in this mode.
    * 3D (``is_3d=True``): volumes are loaded lazily, one file per index,
      resampled to ``target_res`` and returned with the depth axis first.

    Args:
        train_path: Directory containing the image volumes readable by
            ``nib.load``.
        target_res: ``(depth, size0, size1)`` target resolution; ``size0`` and
            ``size1`` apply to the first and second axes of the loaded volume.
        is_3d: Select the lazy 3D mode instead of the cached 2D mode.
        transforms: Optional callable applied to the normalised 3D volume.
    """

    def __init__(self,
                 train_path='/home/data/jzheng/acdc/images/',
                 target_res=(32, 256, 256),
                 is_3d=False,
                 transforms=None):
        self.t_p = train_path
        # os.listdir order is arbitrary; sort so an index always maps to the
        # same file across runs.
        self.files = sorted(os.listdir(self.t_p))
        self.sample_list_2d = []
        self.is_3d = is_3d
        self.target_res = target_res
        self.res_2d = (target_res[1], target_res[2])
        self.transforms = transforms

        if not self.is_3d:
            for f in self.files:
                img = nib.load(os.path.join(self.t_p, f)).get_fdata()
                depth = img.shape[2]
                # Drop the first and last 10% of slices along the depth axis.
                f_i = int(round(depth * 0.1))
                b_i = int(round(depth * 0.9))
                interval_slice = img[:, :, f_i:b_i]
                for ii in range(interval_slice.shape[2]):
                    single_slice = resize(interval_slice[:, :, ii], self.res_2d,
                                          anti_aliasing=True, preserve_range=True)
                    single_slice = (single_slice - single_slice.min()) / (
                        single_slice.max() - single_slice.min() + EPS)
                    self.sample_list_2d.append(single_slice[None, :, :])

    def __len__(self):
        """Return the number of cached slices (2D) or of volume files (3D)."""
        if not self.is_3d:
            return len(self.sample_list_2d)
        return len(self.files)

    def __getitem__(self, index):
        """Return the sample at ``index`` twice, as an ``(input, target)`` pair."""
        if not self.is_3d:
            sample = self.sample_list_2d[index]
            return sample, sample

        # Load the file that belongs to this index (the previous loop always
        # returned the first file, whatever the index was).
        img = nib.load(os.path.join(self.t_p, self.files[index])).get_fdata()
        target_d_ratio = self.target_res[0] / img.shape[2]
        target_w_ratio = self.target_res[1] / img.shape[0]
        target_h_ratio = self.target_res[2] / img.shape[1]

        resize_img = zoom(img, (target_w_ratio, target_h_ratio, target_d_ratio))

        # Move the depth axis to the front: (a0, a1, depth) -> (depth, a0, a1).
        resize_img = np.moveaxis(resize_img, 2, 0)
        resize_img = (resize_img - resize_img.min()) / (resize_img.max() - resize_img.min() + EPS)
        if self.transforms is not None:
            resize_img = self.transforms(resize_img)
        return resize_img, resize_img
|
| 398 |
+
|
| 399 |
+
class acdc_gan_single_slice(Dataset):
    """Dataset returning the normalised middle slice of each volume.

    Every file in ``train_path`` is one sample.  The slice at half depth is
    resized to 128x128, min-max normalised and returned twice as an
    ``(input, target)`` pair.
    """

    def __init__(self, train_path = '/well/papiez/shared/ACDC/clean_training/images/'):
        # The path is joined by plain concatenation below, so it must end with
        # a path separator.
        self.t_p = train_path
        self.files = os.listdir(self.t_p)

    def __len__(self):
        """Return the number of files found in ``train_path``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load volume ``index`` and return its middle slice twice."""
        volume = nib.load(self.t_p + self.files[index]).get_fdata()
        centre = int(volume.shape[2] / 2)
        plane = resize(volume[:, :, centre], (128, 128), anti_aliasing=True, preserve_range=True)
        lowest, highest = plane.min(), plane.max()
        # EPS keeps the division finite for constant slices.
        plane = (plane - lowest) / (highest - lowest + EPS)
        return plane, plane
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
|
Dataloader/dataloader_tester.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataLoader import *
|
| 2 |
+
import torchvision.transforms as tf
|
| 3 |
+
import SimpleITK as sitk
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
transform = tf.Compose([
    tf.ToTensor(),  # Convert image to tensor
])

# Mapping files for the BERT-embedding dataset variant; uncomment entries to
# include more datasets.
mapping_files_bert = {
    # 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
    # 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
    'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
}

if __name__ == "__main__":
    # Alternative dataset classes that can be smoke-tested the same way:
    # dataset = OminiDataset_v1(transform=None)
    # datasetp = OminiDataset_paired(transform=None)
    # dataset = OminiDataset_paired_inf(transform=None)
    # dataset = OminiDataset_inference_w_all(transform=None)
    # dataset = OminiDataset_bertembd(transform=None,mapping_files=mapping_files_bert)
    dataset = OminiDataset(transform=None)

    # print(dataset.get_keys_dist())
    # print(len(dataset))
    # print(dataset.build_batch().shape)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Smoke test: print the second element of the first batch, then stop.
    # ``break`` replaces the interactive ``exit()`` helper, which is only
    # guaranteed to exist when the ``site`` module is loaded.
    for data in dataloader:
        print(data[1])
        break

    # print(dataset.get_ALLdata())
|
Dataloader/dataloader_utils.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
# from torch import nn, optim
|
| 4 |
+
# from torch.autograd.variable import Variable
|
| 5 |
+
# from torchvision import transforms, datasets
|
| 6 |
+
# from torchvision.utils import save_image
|
| 7 |
+
# import torch.nn.functional as F
|
| 8 |
+
# import scipy.ndimage as spimg
|
| 9 |
+
# import pyquaternion as quater
|
| 10 |
+
# import random
|
| 11 |
+
import numpy as np
|
| 12 |
+
from scipy.ndimage import gaussian_filter, binary_dilation, binary_erosion, generate_binary_structure
|
| 13 |
+
import pydicom
|
| 14 |
+
from scipy.ndimage import zoom
|
| 15 |
+
from einops import rearrange, reduce, repeat
|
| 16 |
+
|
| 17 |
+
def get_sizeRange_dict(roi=''):
    """Look up the ``[min_size, max_size]`` range for a region of interest.

    Args:
        roi (str): Name of the region of interest, e.g. ``'brain'``.

    Returns:
        list or dict: The two-element range for ``roi`` when it is a known
        key; otherwise the complete mapping of every ROI to its range.  A
        fresh mapping is built on each call.
    """
    # Each value is [min_size, max_size]; the sizes are in mm for the minimum
    # and maximum dimensions.
    size_ranges = {
        'whole-body': [420, 2048],
        'neck-thorax-abdomen-pelvis-leg': [400, 2048],
        'neck-thorax-abdomen-pelvis': [380, 2048],
        'thorax-abdomen-pelvis-leg': [360, 2048],
        'neck-thorax-abdomen': [320, 1024],
        'head-neck-thorax-abdomen': [360, 2048],
        'head-neck-thorax': [340, 1024],
        'thorax-abdomen-pelvis': [340, 1024],
        'abdomen-pelvis-leg': [320, 1024],
        'neck-thorax': [220, 1024],
        'thorax-abdomen': [260, 1024],
        'abdomen-pelvis': [260, 1024],
        'pelvis-leg': [240, 1024],
        'head-neck': [240, 1024],
        'head': [150, 1024],
        'brain': [128, 1024],
        'neck': [140, 1024],
        'abdomen': [240, 1024],
        'pelvis': [220, 1024],
        'thorax': [220, 1024],
        'arm': [140, 1024],
        'hand': [140, 1024],
        'leg': [160, 1024],
        'skeleton': [130, 1024],
    }
    # Unknown ROI names (including the default '') fall back to the full table.
    return size_ranges.get(roi, size_ranges)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def remove_background(img,replace_value=None,num_bin=256,dim_ch=0,sigma=None):
    """Overwrite the most common intensity (taken as background) in ``img``.

    The image is split into channels along ``dim_ch``.  For each channel the
    most populated histogram bin over the integer edges ``range(num_bin)`` is
    found; a pixel counts as background only if it equals that value in every
    channel.  Background pixels are then filled with a replacement value,
    optionally blended through a Gaussian-blurred mask.

    Args:
        img: Array with a channel axis at ``dim_ch``.  The channels are views
            of ``img``, so the fill is written into the caller's array.
        replace_value: Per-channel fill values indexed by channel; when
            ``None`` the minimum of the non-background pixels is used.
        num_bin: Upper end of the integer histogram edges; values ``<= 0``
            skip the histogram and leave every pixel marked as background.
        dim_ch: Channel axis.  ``None`` means ``img`` has no channel axis; one
            is added in front and kept in the result.
        sigma: When positive, the mask is dilated/eroded by ``int(sigma*4)``
            iterations and the result blended using Gaussian filters.

    Returns:
        Array of the processed channels stacked along ``dim_ch``.
    """
    if dim_ch is None:
        dim_ch = 0
        img = np.expand_dims(img, axis=dim_ch)

    channels = [
        np.squeeze(part, axis=dim_ch)
        for part in np.split(img, img.shape[dim_ch], axis=dim_ch)
    ]

    # Start with everything flagged as background and clear every pixel that
    # differs from the dominant value in any channel.
    bg_mask = np.ones_like(channels[0])
    for channel in channels:
        if num_bin > 0:
            counts, _ = np.histogram(channel.flatten(), bins=range(num_bin))
            dominant = np.argmax(counts)
            bg_mask[channel != dominant] = 0

    smooth = sigma is not None and sigma > 0
    if smooth:
        steps = int(sigma * 4)
        bg_mask = binary_dilation(bg_mask, iterations=steps).astype(float)
        fg_core = binary_erosion(1 - bg_mask, iterations=steps).astype(float)
        fg_weight = gaussian_filter(fg_core, sigma=sigma * 4, truncate=sigma // 4, mode='nearest')

    for idx, channel in enumerate(channels):
        if replace_value is None:
            fill = np.min(channel[np.logical_not(bg_mask)])
        else:
            fill = replace_value[idx]

        # Both modes first overwrite the background in place.
        channel[bg_mask == 1] = fill
        if smooth:
            blurred = gaussian_filter(channel, sigma=sigma * 4, truncate=sigma // 4, mode='nearest')
            channel = channel * fg_weight + blurred * (1 - fg_weight)
        channels[idx] = channel

    return np.stack(channels, axis=dim_ch)
|
| 106 |
+
|
| 107 |
+
def thresh_img(img,thresh = None,EPS = 10**-7):
    """Subtract a threshold from an image tensor and clip the result.

    Args:
        img: A torch tensor, or a list of torch tensors (the list is updated
            in place and returned).
        thresh: ``None`` to return ``img`` untouched; a number to use as a
            fixed threshold; or a two-element list/tuple ``[low, high]`` from
            which the threshold and a random upper bound are drawn uniformly.
        EPS: Unused; kept so existing callers that pass it keep working.

    Returns:
        ``img`` shifted down by the threshold and clamped to ``[0, upbound]``.
        With a range, ``upbound = 1 - u - threshold`` for a second uniform
        draw ``u``; with a fixed threshold only the lower bound of 0 applies.
    """
    if isinstance(thresh, (list, tuple)):
        threshold = float(np.random.uniform(thresh[0], thresh[1]))
        upbound = float(1 - np.random.uniform(thresh[0], thresh[1]) - threshold)
    else:
        threshold = thresh
        # A fixed threshold has no random upper bound.  This name used to be
        # left undefined here, which made every scalar threshold crash.
        upbound = None

    if threshold is None:
        return img

    if isinstance(img, list):
        for i, single in enumerate(img):
            img[i] = torch.clamp(single - threshold, min=0.0, max=upbound)
        return img
    return torch.clamp(img - threshold, min=0.0, max=upbound)
|
| 128 |
+
|
| 129 |
+
def clamp_img_tensor(img,clamp = (None, None)):
    """Clamp a tensor to optional lower and upper bounds.

    Args:
        img: Torch tensor to clamp.
        clamp: Two-element sequence ``(low, high)``; either entry may be
            ``None`` to leave that side unbounded.

    Returns:
        The clamped tensor, or ``img`` itself when both bounds are ``None``.
    """
    device = img.device
    bounds = {}
    # Only pass the bounds that are set; torch.clamp needs at least one.
    if clamp[0] is not None:
        bounds['min'] = torch.tensor(clamp[0]).to(device)
    if clamp[1] is not None:
        bounds['max'] = torch.tensor(clamp[1]).to(device)
    if bounds:
        img = torch.clamp(img, **bounds)
    return img
|
| 139 |
+
|
| 140 |
+
def read_CT_volume(folder_path,target_res = 128):
    """Read a folder of DICOM slices into an isotropically scaled, padded cube.

    The ``.dcm`` files are read in reverse-sorted filename order and stacked
    along the last axis.  The volume is rescaled with the slice thickness and
    pixel spacing so that its longest physical extent spans about
    ``target_res`` voxels, then zero-padded symmetrically so all three axes
    have the length of the longest one.

    Args:
        folder_path: Directory containing the ``.dcm`` slices of one series.
        target_res: Approximate edge length of the resulting cube in voxels.

    Returns:
        tuple: ``(volume, slice_thickness, pixel_spacing)`` where ``volume`` is
        the padded array in ``(h, w, z)`` order and the other two values are
        taken from the header of the first slice read.

    Raises:
        FileNotFoundError: If ``folder_path`` contains no ``.dcm`` file.
    """
    dicom_names = sorted(
        (name for name in os.listdir(folder_path) if name.endswith(".dcm")),
        reverse=True,
    )
    if not dicom_names:
        raise FileNotFoundError(f"no .dcm files found in {folder_path!r}")

    dicom_slices = []
    first_dicom = None
    for filename in dicom_names:
        dicom_data = pydicom.dcmread(os.path.join(folder_path, filename))
        if first_dicom is None:
            # Keep a header from an actual slice for the spacing lookup; the
            # first raw os.listdir entry is arbitrary and need not be DICOM.
            first_dicom = dicom_data
        dicom_slices.append(dicom_data.pixel_array)

    dicom_slices = np.array(dicom_slices)
    dicome_volume = rearrange(dicom_slices, 'z h w -> h w z')

    # Spatial information, assumed constant across the series.
    slice_thickness = first_dicom.SliceThickness
    pixel_spacing = first_dicom.PixelSpacing

    # Physical size of one voxel along each axis.
    h_axis_ratio = pixel_spacing[0]
    w_axis_ratio = pixel_spacing[1]
    z_axis_ratio = slice_thickness

    # Scale so the longest physical extent maps onto target_res voxels.
    longest_axis = max(
        h_axis_ratio * dicome_volume.shape[0],
        w_axis_ratio * dicome_volume.shape[1],
        z_axis_ratio * dicome_volume.shape[2],
    )
    c_factor = longest_axis / target_res
    resized_volume = zoom(
        dicome_volume,
        (h_axis_ratio / c_factor, w_axis_ratio / c_factor, z_axis_ratio / c_factor),
    )

    # Pad the shorter axes symmetrically up to the longest one; an odd
    # remainder goes to the trailing side.
    max_dim_size = max(resized_volume.shape)
    pad_spec = []
    for axis_len in resized_volume.shape:
        missing = max_dim_size - axis_len
        pad_spec.append((missing // 2, missing - missing // 2))

    padded_resized_volume = np.pad(resized_volume, pad_spec, mode='constant')

    return padded_resized_volume, slice_thickness, pixel_spacing
|
| 193 |
+
|
Dataloader/embding_gen.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch.utils.data import Dataset, DataLoader
|
| 3 |
+
import json
|
| 4 |
+
import SimpleITK as sitk
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage.transform import rescale, resize, downscale_local_mean
|
| 7 |
+
# from torchvision.transforms import v2
|
| 8 |
+
import sys
|
| 9 |
+
from bert_helper import *
|
| 10 |
+
sys.path.append('./')
|
| 11 |
+
from Dataloader.dataloader_utils import *
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Source mapping files to embed; only the uncommented datasets are processed.
mapping_files = {
    # 'MSD': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MSD_processed/nifti_mappings_updated.json',
    # 'TotalSegmentor': '/home/data/Github/data/data_gen_def/DATASETS_processed/TotalSegmentorCT_MRI/nifti_mappings.json',
    # 'Kaggle_osic': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/Kaggle_osic_new/nifti_mappings.json',
    # 'CancerImageArchive': '/home/data/Github/data/data_gen_def/DATASETS_processed/CancerImageArchive_test/nifti_mappings.json',
    # 'MnMs': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/MnMs/nifti_mappings.json',
    # 'Brats2019': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2019/nifti_mappings.json',
    # 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
    # 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
    # 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
    'OASIS_2': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
    # 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
    # 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
    # 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
    # 'AbdomenCT1k':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenCT1k/nifti_mappings.json',
}

# Output locations. NOTE(review): two different roots are in use
# ('/home/data/...' and '/home/jachin/data/...') -- confirm both are intended.
_OUT_DIR_DATA = '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/'
_OUT_DIR_JACHIN = '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/'
save_paths = {
    'MSD': _OUT_DIR_DATA + 'MSD_mappings.json',
    'TotalSegmentor': _OUT_DIR_DATA + 'TotalSegmentorCT_MRI_mappings.json',
    'Kaggle_osic': _OUT_DIR_DATA + 'Kaggle_osic_mappings.json',
    'CancerImageArchive': _OUT_DIR_DATA + 'CIA_mappings.json',
    'MnMs': _OUT_DIR_JACHIN + 'MnMs_mappings.json',
    'Brats2019': _OUT_DIR_JACHIN + 'Brats2019_mappings.json',
    'Brats2020': _OUT_DIR_JACHIN + 'Brats2020_mappings.json',
    'Brats2021': _OUT_DIR_JACHIN + 'Brats2021_mappings.json',
    'OASIS_1': _OUT_DIR_JACHIN + 'OASIS_1_mappings.json',
    'OASIS_2': _OUT_DIR_JACHIN + 'OASIS_2_mappings.json',
    'PSMA-FDG-PET-CT-LESION': _OUT_DIR_JACHIN + 'PSMA-FDG-PET-CT-LESION_mappings.json',
    'PSMA-CT': _OUT_DIR_JACHIN + 'PSMA-CT-Longitud_mappings.json',
    'AbdomenAtlas': _OUT_DIR_JACHIN + 'AbdomenAtlas_mappings.json',
    'AbdomenCT1k': _OUT_DIR_JACHIN + 'AbdomenCT1k_mappings.json',
}

# Metadata fields copied into each caption, per dataset.
query = {
    'MSD': ['description'],
    'TotalSegmentor': ['age', 'gender'],
    'Kaggle_osic': ['Age', 'Sex', 'Smoke_Status', 'Weeks', 'FVC', 'Percent'],
    'CancerImageArchive': ['Series_Description', 'Study_Description', 'Manufacturer'],
    'MnMs': ['Age', 'Sex', 'Height', 'Weight'],
    'Brats2019': ['Age', 'Grade', 'Survival', 'ResectionStatus'],
    'Brats2020': ['Age', 'Grade', 'Survival', 'ResectionStatus'],
    'Brats2021': ['Age', 'Grade', 'Survival', 'ResectionStatus'],
    'OASIS_1': ['Age', 'M/F', 'ASF', 'Educ', 'SES', 'MMSE', 'eTIV', 'CDR', 'nWBV'],
    'OASIS_2': ['Age', 'Group', 'M/F', 'ASF', 'Educ', 'SES', 'MMSE', 'eTIV', 'CDR', 'nWBV'],
    'PSMA-FDG-PET-CT-LESION': ['Study Description', 'diagnosis', 'age', 'sex', 'pet_radionuclide', 'ct_contrast_agent'],
    'PSMA-CT': [],
    'AbdomenAtlas': [],
    'AbdomenCT1k': [],
}

# Fixed text added to (or appended to) caption fields, per dataset.
_BRATS_DESCRIPTION = 'could include brain tumor, glioma, glioblastoma, low grade glioma, high grade glioma'
add_text = {
    'MSD': {},
    'TotalSegmentor': {},
    'Kaggle_osic': {'description': 'pulmonary fibrosis progression'},
    'CancerImageArchive': {},
    'MnMs': {},
    'Brats2019': {'description': _BRATS_DESCRIPTION},
    'Brats2020': {'description': _BRATS_DESCRIPTION},
    'Brats2021': {'description': _BRATS_DESCRIPTION},
    'OASIS_1': {},
    'OASIS_2': {},
    'PSMA-CT': {'description': 'melanoma patients'},
    'PSMA-FDG-PET-CT-LESION': {'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
    'AbdomenAtlas': {},
    'AbdomenCT1k': {},
}


# BERT initialisation
model_name = '/home/jachin/data/Github/OmniMorph/External/Models/bert_large_uncased'
reduce_method = 'mean'
max_words_num = 32  # max number of words in the caption > 2
# max_words_num = 64  # max number of words in the caption > 2

embeder, tokenizer = get_frozen_embeder(model_name)
|
| 90 |
+
def embed_str_filter(str_input, filter_words=('segmentation', 'registration')):
    """Remove every occurrence of the given words from ``str_input``.

    Args:
        str_input: Text to filter.
        filter_words: Iterable of substrings to delete.  Matching is plain
            substring replacement, so surrounding whitespace is left as is.

    Returns:
        The filtered string.
    """
    for word in filter_words:
        str_input = str_input.replace(word, '')
    return str_input
|
| 97 |
+
|
| 98 |
+
# For every configured dataset: build a text caption per entry, embed it with
# the frozen BERT model and store the embedding next to the entry.
for dataset, jsn_path in mapping_files.items():
    with open(jsn_path, 'r') as f:
        embd_json = json.load(f)

    query_key = query[dataset]
    extra_text = add_text[dataset]

    for key, entry in embd_json.items():
        embd_json_temp = {
            'Modality': entry['Modality'],
            'ROI': entry['ROI'],
        }

        # Copy the queried metadata fields, marking missing ones explicitly.
        meta_data = entry['Metadata']
        for q in query_key:
            embd_json_temp[q] = meta_data[q] if q in meta_data else 'N/A'

        # Append the dataset-level text to an existing field or add it as new.
        for q, text in extra_text.items():
            if q in embd_json_temp:
                # str() keeps this working for non-string metadata values.
                embd_json_temp[q] = str(embd_json_temp[q]) + ', ' + text
            else:
                embd_json_temp[q] = text

        # Caption = dict repr without the surrounding braces, lower-cased.
        raw_str = str(embd_json_temp)[1:-1].lower()
        embd_str = replace_text(raw_str, get_synonyms_dict(None))
        embd_str = embed_str_filter(embd_str)

        print(f'embd_json_temp: {str(embd_json_temp)}')
        print(f'embd_str: {embd_str}')
        print(f'words_num: {len(embd_str.split())}')
        # Explicit check instead of assert so it still runs under ``python -O``.
        if len(embd_str.split()) > max_words_num:
            raise ValueError(f'Too many words in the caption: {embd_str}')

        embd = str2emb(embd_str, max_words_num, embeder, tokenizer, reduce_method=reduce_method)
        print(embd)
        entry['embd'] = embd.tolist()[0]
        entry['embd_key'] = embd_str

    new_jsn_path = save_paths[dataset]
    with open(new_jsn_path, 'w') as f:
        json.dump(embd_json, f, indent=4)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
Dataloader/nifty_mappings/AbdomenAtlas_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:303c3fb7388e7b3b01cb6f494c3ac3f542da98487039e5b2415786ac4af58ba0
|
| 3 |
+
size 179457573
|
Dataloader/nifty_mappings/AbdomenCT1k_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0abaaa1013fdafe3fae6d5544746a66d8b20892ceb3cf9141a125113984e8350
|
| 3 |
+
size 37315918
|
Dataloader/nifty_mappings/Brats2019_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c5b80fc861484d36d8d6e0f97c404e2c321ee965cc1556a868205f5937d24fe
|
| 3 |
+
size 12126490
|
Dataloader/nifty_mappings/Brats2020_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:de345c6a66a4f33552aacbb961cd034ac488500ff5d48810579055f0543162dc
|
| 3 |
+
size 17743015
|
Dataloader/nifty_mappings/Brats2021_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4990a7031d6ac91e1c33e6db046dddf234f67dd8edecd07691675945b9d00af5
|
| 3 |
+
size 44722001
|
Dataloader/nifty_mappings/CIA_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:98cbd21d3d5b7f5fb84091705fbbfcd0f8f26cb26ff4b34ffcf546cf1cedb48a
|
| 3 |
+
size 32744567
|
Dataloader/nifty_mappings/Kaggle_osic_mappings.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Dataloader/nifty_mappings/MSD_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a1ab13c61cd6829f088ee92bff4ce12a0f0e19fc9367682291fbd9717b149e83
|
| 3 |
+
size 92620864
|
Dataloader/nifty_mappings/MnMs_mappings.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Dataloader/nifty_mappings/OASIS_1_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8784bff1bb5c9ba08fccc8ca9776f3f26c9b2993c1c446ef17d5ba1dd2bda490
|
| 3 |
+
size 15609846
|
Dataloader/nifty_mappings/OASIS_2_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f88910a0846e056b0d4caacd6e6ebfebde52b537828756e217d9a6c6343177c
|
| 3 |
+
size 13396017
|
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c3c8729df59b6e9771fa791c5fe1cd7636e83a3c17109613984cdce0d92eefdc
|
| 3 |
+
size 11700732
|
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:922363b739e1f14243731ea283ee730bc55724a27360d2f28f32b01b23ede5d9
|
| 3 |
+
size 48425273
|
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c36ba45053fea97244c259af0151ddb02e8281fce8c8f439cc88733bd71d668f
|
| 3 |
+
size 67962146
|
Diffusion/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import Diffusion
|
| 2 |
+
from . import diffuser
|
| 3 |
+
from . import networks
|
| 4 |
+
from . import losses
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
sys.path.append('./Diffusion')
|
| 8 |
+
sys.path.append('./')
|
Diffusion/diffuser.py
ADDED
|
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch.nn.utils.stateless import functional_call
|
| 5 |
+
|
| 6 |
+
import Diffusion.utils_diff as utils
|
| 7 |
+
from Diffusion.networks import *
|
| 8 |
+
# from networks import *
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
EPS = 1e-8
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DeformDDPM(nn.Module):
    """Deformation diffusion model: random deformation forward pass + learned recovery.

    The forward ("diffusion") process warps an image with a randomly generated
    dense displacement field (DDF) whose magnitude grows with the time step.
    The wrapped ``network`` is then asked to predict a recovering field from
    the warped image and a (possibly degraded) conditioning image.

    Relies on names provided by the surrounding module: ``STN``,
    ``DefRec_MutAttnNet``, ``F``, ``math``, ``utils``, ``functional_call`` and
    the module-level constant ``EPS``.

    NOTE(review): all spatial sizes are derived from ``image_chw[1]`` only, so
    the code appears to assume square / cubic images -- confirm with callers.
    """

    def __init__(
        self,
        network,
        n_steps=50,
        beta_schedule_fn=None,
        device='cpu',
        image_chw=(1, 28, 28),
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=0.008/256,
        resample_mode=None,
    ):
        """Store the configuration and build the spatial transformers.

        Args:
            network: recovery network, called as
                ``network(x=, y=, t=, rec_num=, text=)``.
            n_steps: number of diffusion steps (stored only).
            beta_schedule_fn: accepted for interface compatibility; unused here.
            device: device handed to the ``STN`` modules and used for the
                generated fields.
            image_chw: ``(channels, size, ...)``; ``len(image_chw) - 1`` is the
                number of spatial dimensions.
            batch_size: batch size of the generated random fields.
            img_pad_mode: padding mode of the image / mask transformers.
            ddf_pad_mode: padding mode of the control-grid field transformer.
            padding_mode: padding mode of the full-resolution field transformer.
            v_scale: base magnitude of the random velocity fields.
            resample_mode: resampling mode forwarded to the image transformer.
        """
        super(DeformDDPM, self).__init__()
        self.rec_num = 2
        self.ndims = len(image_chw) - 1
        self.n_steps = n_steps
        self.v_scale = v_scale
        self.device = device
        self.msk_noise_scale = torch.tensor(0)

        self.num_device = torch.cuda.device_count()

        self.batch_size = batch_size
        self.img_pad_mode = img_pad_mode
        self.ddf_pad_mode = ddf_pad_mode
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        self.image_chw = image_chw
        self.network = network
        # Transformer used to compose full-resolution displacement fields.
        self.ddf_stn_full = STN(
            img_sz=self.image_chw[1],
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )
        self._DDF_Encoder_init()
        self.copy_opt = nn.Identity()
        return

    def get_stn(self):
        """Return ``(image transformer, full-resolution field transformer)``."""
        return self.img_stn, self.ddf_stn_full

    def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
        """(Re)build the control-grid, image and mask transformers.

        Args:
            ctl_ratio: image size divided by this gives the control-grid size
                when ``ctl_sz`` is not given.
            ctl_sz: explicit control-grid size.
            resample_mode: accepted but unused; ``self.resample_mode`` is used.
        """
        if ctl_sz is None:
            ctl_sz = self.image_chw[1] // ctl_ratio
        self.ctl_sz = ctl_sz
        self.img_sz = self.image_chw[1]
        self.ddf_stn_rec = STN(img_sz=ctl_sz, ndims=self.ndims, device=self.device, padding_mode=self.ddf_pad_mode)
        self.img_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device, padding_mode=self.img_pad_mode, resample_mode=self.resample_mode)
        # Masks are resampled with nearest-neighbour interpolation.
        self.msk_stn = STN(img_sz=self.img_sz, ndims=self.ndims, device=self.device, padding_mode=self.img_pad_mode, resample_mode='nearest')

    def _get_ddf_scale(self, t, divide_num=1, max_ddf_num=200):
        """Map time step ``t`` (tensor) to composition counts.

        Returns:
            ``(rec_num, mul_num_ddf, mul_num_dvf)`` where ``rec_num`` is always
            1, ``mul_num_ddf = clamp(floor(2*t**1.3 / (3*divide_num)), 1, max)``
            and ``mul_num_dvf = clamp(floor(t**0.6 / divide_num), 0, max)``.
        """
        rec_num = 1
        mul_num_ddf = torch.floor_divide(2*torch.pow(t, 1.3), 3*divide_num).int()
        mul_num_dvf = torch.floor_divide(torch.pow(t, 0.6), divide_num).int()
        mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
        mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
        return rec_num, mul_num_ddf, mul_num_dvf

    def _get_random_ddf(self, img, t):
        """Warp ``img`` with a random field scaled for time step ``t``.

        Returns:
            ``(warped_img, dvf_forward, ddf_forward)``.
        """
        rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
        ddf_forward, dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf, mul_num_dvf])
        warped_img = self.img_stn(img, ddf_forward)
        return warped_img, dvf_forward, ddf_forward

    def _multiscale_dvf_generate(self, v_scale, ctl_szs=(4, 8, 16, 32, 64), rand_v_scale=True):
        """Sum Gaussian random fields drawn at several control-grid sizes.

        Each component is sampled at resolution ``ctl_sz``, rescaled by
        ``self.ctl_sz / ctl_sz`` and interpolated to ``self.ctl_sz``.

        Args:
            v_scale: upper bound (or exact value) of the per-scale magnitude.
            ctl_szs: control-grid sizes to combine; a size of 1 additionally
                adds a field from ``utils.random_ddf``.
            rand_v_scale: draw the magnitude randomly in ``[1e-8, v_scale]``
                instead of using ``v_scale`` directly.
        """
        dvf = 0
        if self.img_sz is None:
            self.img_sz = max(ctl_szs)
        if 1 in ctl_szs:
            dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
            dvf = dvf + dvf_rot
        interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        for ctl_sz in ctl_szs:
            _v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2) if rand_v_scale else v_scale
            # Temporary damping of the coarsest scales (marked "temp" by the author).
            if ctl_sz <= 2:
                _v_scale = _v_scale/2
            dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz]*self.ndims) * _v_scale
            dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz]*self.ndims, align_corners=False, mode=interp_mode)
            dvf = dvf + dvf_comp
        return dvf

    def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
        """Nested uniform sampling: repeatedly draw from ``[previous, high]``.

        Each of the ``order_num`` draws uses the previous sample as the new
        lower bound, which biases the result towards ``high``.
        """
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
        return sample_value

    def _random_ddf_generate(self, rec_num=3, mul_num=None, ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=4, flip_ratio=0.5):
        """Generate a random displacement field by repeated self-composition.

        Args:
            rec_num: number of independent velocity fields to draw and compose.
            mul_num: ``[ddf_count, dvf_count]`` tensors giving how many times
                the velocity field is composed; defaults to 5 for both.
            ddf0: optional initial control-grid field to continue from.
            keep_inverse: unused.
            noise_ratio: relative magnitude of the perturbation added to the
                velocity field used for the accumulated field; 0 disables it.
            select_num: maximum number of control-grid scales used per draw.
            flip_ratio: unused.

        Returns:
            ``(ddf, dddf)`` at image resolution: the field accumulated from the
            perturbed velocity and the one accumulated from the clean velocity.
        """
        if mul_num is None:
            mul_num = [torch.tensor([5]), torch.tensor([5])]
        crop_rate = 2
        # Append singleton axes so the counts broadcast against [B, ndims, *spatial].
        for _ in range(self.ndims+1):
            mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
        ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        if ddf0 is not None:
            ddf = ddf0
        else:
            ddf = torch.zeros(ctl_ddf_sz)
        dddf = torch.zeros(ctl_ddf_sz)
        scale_num = min(8, int(math.log2(self.ctl_sz)))  # allow affine
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

        for i in range(rec_num):
            # Use a random subset of scales when more than `select_num` exist,
            # otherwise use all of them (previously `ctl_szs` was left unbound
            # in that case).
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf0 = dvf+self.ddf_stn_rec(self._multiscale_dvf_generate(self.v_scale*noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False).to(self.device), dvf)
            for j in range(torch.max(mul_num[0]).item()):
                # A sample stops accumulating once `j` reaches its own count.
                flag = [(n > j).int().to(self.device) for n in mul_num]
                ddf = dvf0*flag[0] + self.ddf_stn_rec(ddf, dvf0*flag[0])
                dddf = dvf*flag[1] + self.ddf_stn_rec(dddf, dvf*flag[1])

        # Upsample to `crop_rate` x the image size, then keep the central
        # image-sized window along every spatial axis.
        interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        lo, hi = self.img_sz // 2, self.img_sz * 3 // 2
        crop = (Ellipsis,) + (slice(lo, hi),) * self.ndims
        ddf = F.interpolate(ddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode=interp_mode)
        ddf = ddf[crop]
        dddf = F.interpolate(dddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode=interp_mode)
        dddf = dddf[crop]
        return ddf, dddf

    def create_noise_map(self, img, noise_type='gaussian', noise_ratio=0.2):
        """Return a noise tensor shaped like ``img``.

        ``'gaussian'`` is scaled by ``noise_ratio``; ``'uniform'`` lies in
        [0, 1); ``'binary'`` is Bernoulli with a per-element random
        probability; any other value yields zeros.
        """
        if noise_type == 'gaussian':
            noise_map = torch.randn_like(img) * noise_ratio
        elif noise_type == 'uniform':
            noise_map = torch.rand_like(img)  # 0-1
        elif noise_type == 'binary':
            noise_map = torch.bernoulli(torch.rand_like(img))
        else:
            noise_map = torch.zeros_like(img)
        noise_map = noise_map.to(img.device)
        return noise_map

    def add_noise(self, img, noise_map=None, noise_ratio_range=(0., 1.)):
        """Blend ``img`` with ``noise_map`` using a random ratio.

        Returns:
            ``(img * (1 - r) + noise_map * r, r)`` with ``r`` drawn uniformly
            from ``noise_ratio_range``.
        """
        noise_ratio = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
        return img * (1-noise_ratio) + noise_map * noise_ratio, noise_ratio

    def apply_noise(self, img, noise_map=None, apply_mask=None):
        """Keep ``img`` where ``apply_mask`` is 1 and ``noise_map`` where it is 0."""
        return img * apply_mask + noise_map * (1-apply_mask)

    def downsample(self, img, down_ratio_range=(1./32, 1)):
        """Downsample by a random per-axis ratio and resize back.

        Returns:
            ``(resampled image at size image_chw[1], sqrt(prod(ratios)))``; the
            second value is used by the caller as a conditioning weight.
        """
        down_ratio = list(np.random.uniform(down_ratio_range[0], down_ratio_range[1], [self.ndims]))
        interp_mode = 'bilinear' if self.ndims == 2 else 'trilinear'
        down_img = F.interpolate(img, scale_factor=down_ratio, mode=interp_mode)
        restored = F.interpolate(down_img, size=[self.image_chw[1]]*self.ndims, mode=interp_mode, align_corners=False)
        return restored, np.sqrt(np.prod(down_ratio))  # jzheng: cond weight based on entropy

    def get_slice_mask(self, img, slice_num_range=(0, 32)):
        """Build a binary mask selecting random slices along every spatial axis.

        Args:
            img: tensor whose shape the mask copies.
            slice_num_range: ``(min, max)`` number of slices per axis; the
                maximum is clipped to ``image_chw[1]``. The argument itself is
                not modified.

        Returns:
            ``(mask, sample_ratio)`` where ``sample_ratio`` is the mean over
            axes of ``sqrt(selected_slices / image_chw[1])``.
        """
        size = self.image_chw[1]
        low = slice_num_range[0]
        high = min(slice_num_range[1], size)
        mask = torch.zeros_like(img)
        sample_ratio = 0
        # Rotates the spatial axes by one; after `ndims` rotations the layout
        # is back to the original order.
        transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
        for _ in range(self.ndims):
            slice_num = random.randint(low, high)
            slice_idx = random.sample(range(size), slice_num)
            for idx in slice_idx:
                mask[..., idx] = 1
            mask = mask.permute(*transpose_list)
            sample_ratio += np.sqrt(slice_num / size) / self.ndims  # jzheng: cond weight based on entropy
        return mask, sample_ratio

    def project(self, img):
        """Average ``img`` along a random subset of spatial axes.

        Returns:
            ``(mean of the selected projections, number of selected axes)``.
            With no axis selected the result is all zeros.
        """
        proj_img = torch.zeros_like(img)
        rand_bourn = np.random.randint(0, 2, size=[self.ndims])
        proj_dim_num = np.sum(rand_bourn)
        for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
            if pflag:
                proj_img += torch.mean(img, dim=i, keepdim=True)
        return proj_img/(proj_dim_num+EPS), proj_dim_num

    def proc_cond_img(self, img, proc_type=None):
        """Degrade the conditioning image.

        Args:
            img: conditioning image.
            proc_type: one of ``'adding'``, ``'independ'``, ``'downsample'``,
                ``'slice'``, ``'project'``, ``'none'``, ``'uncon'``. ``None``
                picks one at random (``'project'`` excluded, ``'uncon'``
                weighted 3x).

        Returns:
            ``(proc_img, mask, cond_ratio)``: the processed image, the mask of
            retained voxels (scalar 1 / 0 when not applicable) and a scalar
            tensor weighting how much conditioning information remains.
        """
        proc_img = img.clone().detach()
        if proc_type is None:
            # Heavily biased towards 'uncon'.
            proc_type = random.choices(
                ['adding', 'independ', 'downsample', 'slice', 'none', 'uncon'],
                weights=[1, 1, 1, 1, 1, 3], k=1
            )[0]
        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1., device=img.device)
        self.msk_noise_scale = torch.tensor(0, device=img.device)
        noise_type = random.choice(['gaussian', 'uniform', 'none'])
        # The noise map is only created for the modes that need it.
        noise_map = None
        if proc_type not in ['none', None, '']:
            if proc_type == 'uncon':
                # Unconditional: the condition is replaced by pure noise.
                noise_map = self.create_noise_map(img, noise_type=noise_type)
                proc_img = noise_map
                mask = torch.tensor(0, device=img.device)
                cond_ratio = torch.tensor(0, device=img.device)
                return proc_img, mask, cond_ratio
            if proc_type in ['adding', 'independ', 'slice']:
                noise_map = self.create_noise_map(img, noise_type=noise_type)
            if proc_type == 'adding':
                proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
                cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
            elif proc_type == 'independ':
                mask = self.create_noise_map(img, noise_type='binary')
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
                with torch.no_grad():
                    cond_ratio = mask.float().mean()
            elif proc_type == 'downsample':
                proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
                cond_ratio = torch.tensor(down_ratio, device=img.device)
            elif proc_type == 'slice':
                # Two nested draws bias the slice budget towards small values.
                slice_num_max = random.randint(1, 64)
                slice_num_max = random.randint(1, slice_num_max)
                mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
                if self.msk_noise_scale == 0:
                    proc_img = img * mask
                else:
                    proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
                cond_ratio = torch.tensor(sample_ratio, device=img.device)
            elif proc_type == 'project':
                proc_img, proj_num = self.project(proc_img)
                cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
        return proc_img, mask, cond_ratio

    def diffuse(self, x_0, t):
        """Apply the forward deformation at step ``t``.

        Returns:
            ``(warped image, dvf_forward, ddf_forward)``.
        """
        if not torch.is_tensor(t):
            t = torch.tensor(t)
        return self._get_random_ddf(img=x_0, t=t)

    @staticmethod
    def _step_to_tensor(t, device):
        """Return time step ``t`` as a tensor located on ``device``."""
        if torch.is_tensor(t):
            return t.to(device)
        return torch.tensor(t, device=device)

    def recover(self, x, y, t, rec_num=2, text=None):
        """Run the recovery network on ``x`` conditioned on ``y`` at step(s) ``t``.

        Args:
            x: deformed input image.
            y: conditioning image.
            t: time step, or list of time steps; converted to tensor(s) on
                ``x.device``.
            rec_num: forwarded to the network; ``None`` uses ``self.rec_num``.
            text: optional text conditioning forwarded to the network.
        """
        if isinstance(t, list):
            t = [self._step_to_tensor(t0, x.device) for t0 in t]
        else:
            # The result of `.to()` was previously discarded, leaving `t` on
            # its original device.
            t = self._step_to_tensor(t, x.device)
        if rec_num is None:
            rec_num = self.rec_num
        return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)

    def recover_frozen_params_but_grad_input(self, x, y, t, rec_num=2, text=None):
        """Like :meth:`recover`, but with the network parameters detached.

        Gradients still flow to the inputs (no ``torch.no_grad`` is used);
        only the parameters are excluded from the graph.
        """
        if isinstance(t, list):
            t = [self._step_to_tensor(t0, x.device) for t0 in t]
        else:
            t = self._step_to_tensor(t, x.device)

        if rec_num is None:
            rec_num = self.rec_num

        # Detached parameters and unchanged buffers are merged into one dict
        # because older PyTorch versions do not accept them separately.
        params_and_buffers = {k: v.detach() for k, v in self.network.named_parameters()}
        params_and_buffers.update(dict(self.network.named_buffers()))
        return functional_call(
            self.network,
            params_and_buffers,
            (),
            kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
        )

    def _single_step(self, x0, t, rec_num=2, proc_type=None, mask=None, cond_imgs=None, text=None):
        """One training step: deform ``x0`` at step ``t`` and predict the recovery.

        Returns:
            ``(network output, dvf_I)`` where ``dvf_I`` is the second value
            returned by :meth:`diffuse`.
        """
        if mask is None:
            mask = 1
        if cond_imgs is None:
            cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(x0, proc_type=proc_type)
        noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
        if isinstance(self.network, DefRec_MutAttnNet):
            # This network type receives its time steps as a list.
            t = [t] * 1
        return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I

    def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
        """Dispatch to :meth:`diff_recover` when ``T`` is given, else to :meth:`_single_step`."""
        if T is not None:
            return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
        return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)

    def diff_recover(self,
                     img_org,
                     msk_org=None,
                     T=(None, None),
                     ddf_rand=None,
                     v_scale=None,
                     t_save=None,
                     cond_imgs=None,
                     proc_type=None,
                     text=None,
                     ):
        """Deform ``img_org`` and iteratively recover it with the network.

        Args:
            img_org: source image.
            msk_org: optional mask warped alongside the image.
            T: ``[t_forward, t_reverse]``. ``T[0]`` is the forward deformation
                step (``None`` / 0 skips it); ``T[-1]`` is either the number of
                reverse steps or an explicit sequence of steps.
            ddf_rand: optional precomputed forward field used instead of a
                random one.
            v_scale: if given, overrides ``self.v_scale``.
            t_save: optional collection of steps whose intermediate results
                are collected.
            cond_imgs: conditioning image; defaults to a copy of ``img_org``.
            proc_type: conditioning degradation, see :meth:`proc_cond_img`.
            text: optional text conditioning forwarded to the network.

        Returns:
            ``[ddf_comp, ddf_rand], [img_rec, img_diff, img_save],
            [msk_rec, msk_diff, msk_save]``.
        """
        if cond_imgs is None:
            cond_imgs = img_org.clone().detach()
        # NOTE(review): conditioning is processed even when `cond_imgs` is
        # supplied; pass proc_type='none' to use it unchanged.
        cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(cond_imgs, proc_type=proc_type)
        if ddf_rand is None:
            if v_scale is not None:
                self.v_scale = v_scale
            # NOTE(review): this re-init is run on every call without
            # `ddf_rand`, not only when `v_scale` is given -- confirm this
            # nesting against the original file.
            self._DDF_Encoder_init()
            if T[0] is None or T[0] == 0:
                img_diff = img_org.clone().detach()
                # NOTE(review): this zero field has the image's channel count,
                # not `ndims` channels -- confirm this is intended.
                ddf_rand = torch.zeros_like(img_diff)
            else:
                img_diff, _, ddf_rand = self._get_random_ddf(img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
        else:
            img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
        ddf_comp = ddf_rand.clone().detach()
        img_rec = img_diff.clone().detach()
        if msk_org is not None:
            msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
        else:
            msk_diff = None
        msk_rec = msk_diff.clone().detach() if msk_org is not None else None
        img_save = []
        msk_save = []

        if isinstance(self.network, DefRec_MutAttnNet):
            # This network consumes the whole list of steps in a single call.
            t_list = list(range(T[1]-1, -1, -1))
            pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
            ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
            img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
            if msk_org is not None:
                msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
        else:
            # Step-by-step denoising.
            if isinstance(T[-1], (int, np.integer)):
                time_steps = range(int(T[-1]) - 1, -1, -1)
            else:
                time_steps = T[-1]

            # Only the last `k` steps keep gradients; earlier ones run under
            # no_grad.
            k = 2
            trainable_iterations = time_steps[-1:-k-1:-1]
            for i in time_steps:
                t = torch.tensor(np.array([i])).to(self.device)

                if i in trainable_iterations:
                    pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
                else:
                    with torch.no_grad():
                        pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)

                # Compose the predicted field with the accumulated one.
                ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
                # Always resample from the original image.
                img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
                if msk_org is not None:
                    msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
                if t_save is not None:
                    if i in t_save:
                        img_save.append(img_rec)
                        if msk_org is not None:
                            msk_save.append(msk_rec)

        return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]
|
| 508 |
+
|
| 509 |
+
if __name__ == "__main__":
    # Smoke test: degrade a random 2-D image with one conditioning mode and
    # print the input and the processed conditioning image.
    H, W = 8, 8
    deformddpm = DeformDDPM(network=get_net(name="recmutattnnet")(n_steps=80, ndims=2, num_input_chn=1), image_chw=(1, H, W), device='cpu')
    img = torch.randn([1, 1, H, W])
    t = 1
    rec_num = 2
    # Other modes: 'adding', 'independ', 'downsample', 'project', 'none'
    proc_type = 'slice'
    print(img)
    # proc_cond_img returns (image, mask, cond_ratio); unpacking only two
    # values raised ValueError.
    cond_imgs, mask_tgt, cond_ratio = deformddpm.proc_cond_img(img, proc_type=proc_type)
    print(cond_imgs)
    # img_rec, dvf_I = deformddpm.forward(img, t, rec_num=rec_num, proc_type=proc_type)
    # print(img_rec.shape, dvf_I.shape)

    # proc_type = 'adding'
    # ddf_comp, ddf_rand = deformddpm.diff_recover(img, T=[1,1], proc_type=proc_type)
|
| 530 |
+
|
| 531 |
+
|
Diffusion/losses.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
losses for DRDM
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import sys
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
EPS=1e-7
|
| 12 |
+
|
| 13 |
+
# eps_scale = 10e-5
|
| 14 |
+
# eps_scale = 10e-4
|
| 15 |
+
# eps_scale = 1e-4
|
| 16 |
+
eps_scale = 1e-5
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LMSE(torch.nn.Module):
    """
    Labeled Mean Square Error (LMSE)

    Squared error between two volumes, optionally normalised by the squared
    reference intensity and optionally restricted to a labeled region.
    Inputs are expected to be 5-D (the reductions run over dims 2, 3, 4).
    """

    def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
        """
        Args:
            eps: added to the label volume to avoid division by zero.
            relate_eps: if not None, the squared error is divided by
                ``J**2 + relate_eps`` (error relative to the reference).
            win: smoothing window size per axis; defaults to ``[5, 5, 5]``.
            smooth: filter both inputs with the window before comparing.
        """
        super(LMSE, self).__init__()
        self.eps = eps
        self.relate_eps = relate_eps
        self.ndims = 3
        self.smooth = smooth
        self.win = win
        # Set window size
        if self.win is None:
            self.win = [5] * self.ndims
        if smooth:
            self.kernels = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """Return a ``[1, 1, d, h, w]`` filter.

        ``std == 0`` gives an all-ones (un-normalised) box of size ``win``;
        otherwise a separable Gaussian, normalised per axis, with a tail of
        ``3 * ceil(std)`` on each side.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win])
        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def forward(self, I, J, label=None):
        """
        Computes the labeled mean squared error between I and J (ref).
        If label is provided, computes the MSE only over the labeled regions.

        Returns:
            Scalar tensor. With ``label``, the per-(batch, channel) sum of the
            masked error is divided by the label volume and then averaged.
        """
        if self.smooth:
            padding = [(w-1) // 2 for w in self.win]
            # Keep the filter on the inputs' device and dtype; conv3d
            # otherwise fails on CUDA or non-float32 inputs.
            self.kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=padding)
            J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=padding)
        mse = (I - J) ** 2
        if self.relate_eps is not None:
            mse = mse/((J**2) + self.relate_eps)
        if label is None:
            return torch.mean(mse)
        label = label.float()
        mse_sum = torch.sum(mse * label, dim=(2, 3, 4))
        label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps
        return torch.mean(mse_sum / label_sum)
|
| 69 |
+
|
| 70 |
+
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC).

    ``forward`` scales both inputs by ``self.scale`` and returns the negated
    mean of the squared local correlation, so a better match gives a lower
    value.  Local statistics are computed with a normalised box filter of
    size ``win``; with ``smooth=True`` the inputs are first blurred with a
    small Gaussian kernel.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-6, central=True, smooth=True):
        """
        Args:
            win: window size per spatial dim; defaults to ``[9, 9, 9]``.
            num_ch: accepted for interface compatibility; not used here.
            eps: added to each local variance in the denominator.
            central: subtract local means (covariance/variance) when True,
                otherwise use raw second moments.
            smooth: apply Gaussian pre-smoothing (std 0.45) when True.
        """
        super(LNCC, self).__init__()
        self.scale = 2e0
        self.win = win
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth

        # Set window size
        if self.win is None:
            self.win = [9] * self.ndims
        self.padding = [(w - 1) // 2 for w in self.win]

        if smooth:
            self.kernels = self._build_kernel(std=0.45)
        self.sum_filt = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """
        Build a (1, 1, k, k, k) filter.

        ``std == 0`` yields a box filter over ``self.win`` normalised to sum
        to one (a local mean).  Otherwise a separable Gaussian with
        half-width ``self.tail = 2 * ceil(std)`` is returned; ``self.tail``
        is stored because ``lncc`` uses it as the smoothing padding.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win]) / float(np.prod(self.win))

        self.tail = int(np.ceil(std)) * 2
        offsets = torch.arange(-self.tail, self.tail + 1, dtype=torch.float32)
        k = torch.exp(-0.5 * (offsets ** 2) / std ** 2)
        kernel = k / torch.sum(k)
        # Outer product of the 1-D profile along the three spatial axes.
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def lncc(self, I, J, label=None):
        """
        Return the mean squared local correlation between I and J.

        Args:
            I, J: tensors accepted by ``conv3d`` with a single-channel kernel
                (assumed (B, 1, D, H, W) -- TODO confirm).
            label: optional weights; when given, the weighted sum over dims
                (2, 3, 4) is divided by ``sum(label) + self.eps``.
        """
        conv3d = torch.nn.functional.conv3d
        # Follow the input's device AND dtype so float64/half inputs work too.
        self.sum_filt = self.sum_filt.to(device=I.device, dtype=I.dtype)

        if self.smooth:
            self.kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            I = conv3d(I, self.kernels, stride=1, padding=self.tail)
            J = conv3d(J, self.kernels, stride=1, padding=self.tail)

        def local_mean(x):
            return conv3d(x, self.sum_filt, stride=1, padding=self.padding)

        # sum_filt is normalised, so these are local means, not sums.
        I2_mean = local_mean(I * I)
        J2_mean = local_mean(J * J)
        IJ_mean = local_mean(I * J)

        if self.central:
            I_mean = local_mean(I)
            J_mean = local_mean(J)
            cross = IJ_mean - (I_mean * J_mean)
            I_var = I2_mean - (I_mean * I_mean)
            J_var = J2_mean - (J_mean * J_mean)
        else:
            cross = IJ_mean
            I_var = I2_mean
            J_var = J2_mean

        # E[x^2] - E[x]^2 can dip below zero through floating-point
        # cancellation in flat regions; clamp so the denominators stay >= eps.
        I_var = torch.clamp(I_var, min=0.0)
        J_var = torch.clamp(J_var, min=0.0)

        cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the negated LNCC of the inputs scaled by ``self.scale``."""
        return -self.lncc(I * self.scale, J * self.scale, label=label)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class NCC(torch.nn.Module):
    """
    Channel-wise normalized cross-correlation between two vector fields.

    Both fields are multiplied by ``self.scale``; their dot product over
    dim 1 is divided by ``sqrt(|a|^2 * |b|^2 + eps_scale)`` and averaged.
    """

    def __init__(self, eps_scale=10e-5,img_sz=256):
        super().__init__()
        # img_sz is accepted but not used.
        self.eps_scale = eps_scale
        self.scale = 1e2

    def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
        """
        Args:
            pred: predicted field.
            inv_lab: reference field (required despite the ``None`` default).
            ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                replaces ``pred``.
            mask: optional weights applied to the per-location correlation;
                the weighted sum is divided by ``sum(mask)`` and by the
                batch size.
        """
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        warped = self.scale * warped
        target = self.scale * inv_lab

        dot = torch.sum(warped * target, dim=1)
        norm = torch.sqrt(
            torch.sum(torch.square(warped), dim=1) * torch.sum(torch.square(target), dim=1)
            + self.eps_scale
        )

        if mask is None:
            return torch.mean(dot / norm)

        batch_size = target.shape[0]
        return torch.sum(dot * mask / norm) / torch.sum(mask) / batch_size
|
| 189 |
+
|
| 190 |
+
class MRSE(torch.nn.Module):
    """
    Mean relative squared error between a predicted and a reference field.

    The error term is ``pred + inv_lab`` (the two are summed, not
    subtracted), squared and summed over dim 1, then divided by the squared
    magnitude of the reference plus ``eps_scale``.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256):
        super().__init__()
        # img_sz is accepted but not used.
        self.eps_scale = eps_scale
        self.scale = 10e1

    def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
        """
        Args:
            pred: predicted field.
            inv_lab: reference field (required despite the ``None`` default).
            ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                replaces ``pred``.
            mask: optional weights; the weighted sum is divided by
                ``sum(mask)`` and by the batch size.
        """
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        warped = self.scale * warped
        target = self.scale * inv_lab

        sq_err = torch.sum(torch.square(warped + target), dim=1)
        sq_ref = torch.sum(torch.square(target), dim=1) + self.eps_scale

        if mask is None:
            return torch.mean(sq_err / sq_ref)

        batch_size = target.shape[0]
        return torch.sum(sq_err * mask / sq_ref) / torch.sum(mask) / batch_size
|
| 215 |
+
|
| 216 |
+
class RMSE(torch.nn.Module):
    """
    Relative mean squared error between a predicted and a reference field.

    The squared difference (summed over dim 1) is averaged over the next
    ``ndims`` axes and divided by the equally averaged squared reference
    plus ``eps_scale``; the ratio is then averaged over what remains.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256,ndims=2):
        super().__init__()
        # img_sz is accepted but not used.
        self.eps_scale = eps_scale
        self.ndims = ndims

    def forward(self,pred,inv_lab=None,ddf_stn=None):
        """
        Args:
            pred: predicted field.
            inv_lab: reference field (required despite the ``None`` default).
            ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                replaces ``pred``.
        """
        warped = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        # After the channel sum, dims 1..ndims are the remaining leading axes.
        axes = list(range(1, 1 + self.ndims))
        err = torch.mean(torch.sum(torch.square(warped - inv_lab), dim=1), dim=axes)
        ref = torch.mean(torch.sum(torch.square(inv_lab), dim=1), dim=axes)
        return torch.mean(err / (ref + self.eps_scale))
|
| 231 |
+
# loss_gen = torch.mean(torch.mean(torch.sum(torch.square(ddf_stn(pre_dvf_I, dvf_I) + dvf_I), dim=1),dim=list(range(1,1+ndims))) / (torch.mean(torch.sum(torch.square(dvf_I), dim=1),dim=list(range(1,1+ndims))) + EPS))
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class Grad(torch.nn.Module):
    """
    N-D gradient loss.

    ``forward`` adds up the terms whose names appear in ``penalty``:

    * ``'l1'`` / ``'l2'``: thresholded finite differences of ``y_pred``,
      optionally weighted by an image-edge term and a mask.
    * ``'negdetj'``: ReLU of the negated Jacobian determinant of
      ``identity + y_pred``.
    * ``'range'``: squared excess of ``|y_pred|`` at the volume centre over
      ``outrange_thresh``, scaled by ``outrange_weight``.
    * ``'std'``, ``'param'``, ``'detj'``: image-weighted statistics of the
      field; these need ``img`` (and ``x_in`` for ``'param'``/``'detj'``).
    """

    def __init__(self, penalty=['l1'],ndims=2, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=2, apear_scale=4, dist=1, sign=1,waive_thresh=10**-5):
        super(Grad, self).__init__()
        # Copy list arguments so the shared default list is never aliased.
        self.penalty = list(penalty) if isinstance(penalty, list) else penalty
        self.eps = eps
        self.outrange_weight = outrange_weight
        self.detj_weight = detj_weight
        self.apear_scale = apear_scale
        self.ndims = ndims
        self.max_sz = torch.reshape(
            torch.tensor([outrange_thresh] * ndims, dtype=torch.float32),
            [1] + [ndims] + [1] * ndims,
        )
        self.act = torch.nn.ReLU(inplace=False)
        self.dist = dist
        self.sign = sign
        self.waive_thresh = waive_thresh

    @staticmethod
    def _shifted_pair(t, axis, dist):
        """Return ``(t[dist:], t[:-dist])`` sliced along ``axis``."""
        upper = [slice(None)] * t.dim()
        lower = [slice(None)] * t.dim()
        upper[axis] = slice(dist, None)
        lower[axis] = slice(None, -dist)
        return t[tuple(upper)], t[tuple(lower)]

    def _diffs(self, y, dist=None):
        """Finite differences of ``y`` along each of the ``self.ndims`` spatial axes.

        Each result is ``dist`` entries shorter than ``y`` along its axis.
        """
        if dist is None:
            dist = self.dist
        df = []
        for axis in range(2, self.ndims + 2):
            hi, lo = self._shifted_pair(y, axis, dist)
            df.append((hi - lo) / float(dist))
        return df

    def _eq_diffs(self, y, dist=None):
        """Like ``_diffs`` but zero-padded at the front so every result keeps ``y``'s shape.

        The number of spatial axes is taken from ``y`` itself (dims 2 onward).
        """
        if dist is None:
            dist = self.dist
        n_axes = y.dim()
        df = []
        for axis in range(2, n_axes):
            hi, lo = self._shifted_pair(y, axis, dist)
            dy = (hi - lo) / float(dist)
            # F.pad lists pads starting from the last dim; skip the trailing
            # axes and prepend ``dist`` zeros along ``axis``.
            pad = [0, 0] * (n_axes - 1 - axis) + [dist, 0]
            df.append(F.pad(dy, pad, mode='constant', value=0))
        return df

    def _weighted_diffs_error(self, y, dist=None, w=None, expect=None, mean_dim=None):
        """Per-axis mean of ``(|y[d:] - y[:-d]| - expect) * (w[d:] * w[:-d])``.

        ``w`` and ``expect`` are required despite their ``None`` defaults.
        The mean is taken over ``mean_dim`` with ``keepdim=True``.
        """
        if dist is None:
            dist = self.dist
        df = []
        for axis in range(2, y.dim()):
            y_hi, y_lo = self._shifted_pair(y, axis, dist)
            w_hi, w_lo = self._shifted_pair(w, axis, dist)
            dy = (torch.abs(y_hi - y_lo) - expect) * (w_hi * w_lo)
            df.append(torch.mean(dy, dim=mean_dim, keepdim=True))
        return df

    def _outl_dist(self, y, range_thresh=0.2):
        """Penalise channel ``i`` leaving ``[-range_thresh, range_thresh]`` at the two ends of spatial axis ``i``."""
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        loss = 0.
        for i in range(self.ndims):
            axis = i + 2
            first = y.select(axis, 0)[:, i, ...]
            last = y.select(axis, -1)[:, i, ...]
            loss += (torch.mean(torch.square(F.relu(-range_thresh - first)))
                     + torch.mean(torch.square(F.relu(last - range_thresh))))
        return loss / self.ndims

    def _center_dist(self, y):
        """Mean squared excess of ``|y|`` over ``self.max_sz`` at the central location of ``y``."""
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        centre = tuple(s // 2 for s in y.size()[2:])
        centre_vec = y[(slice(None), slice(None)) + centre]
        return torch.mean(torch.square(self.act(torch.abs(centre_vec) - self.max_sz)))

    def _eval_detJ(self, disp, add_identity=True, spacing=1.0):
        """
        Jacobian determinant from per-axis derivative tensors.

        disp: list length ndims
            disp[i] is derivative wrt spatial dim i (forward diff),
            tensor shape [B, C=ndims, ...]
        add_identity: True if y_pred is displacement u and phi=x+u
        spacing: voxel spacing (or 1.0). If you care about physical units,
            divide derivatives by spacing (and dist). Sign won't change.

        Raises:
            ValueError: if ``self.ndims`` is neither 2 nor 3.
        """
        # Optional scaling (won't affect sign as long as spacing>0)
        if spacing != 1.0:
            disp = [d / spacing for d in disp]
        diag = 1.0 if add_identity else 0.0

        if self.ndims == 2:
            j11 = diag + disp[0][:, 0, ...]
            j21 = disp[0][:, 1, ...]
            j12 = disp[1][:, 0, ...]
            j22 = diag + disp[1][:, 1, ...]
            return j11 * j22 - j12 * j21

        if self.ndims == 3:
            j11 = diag + disp[0][:, 0, ...]
            j21 = disp[0][:, 1, ...]
            j31 = disp[0][:, 2, ...]

            j12 = disp[1][:, 0, ...]
            j22 = diag + disp[1][:, 1, ...]
            j32 = disp[1][:, 2, ...]

            j13 = disp[2][:, 0, ...]
            j23 = disp[2][:, 1, ...]
            j33 = diag + disp[2][:, 2, ...]

            return (
                j11 * (j22 * j33 - j23 * j32)
                - j12 * (j21 * j33 - j23 * j31)
                + j13 * (j21 * j32 - j22 * j31)
            )

        raise ValueError(f"Unsupported ndims={self.ndims}")

    def _edge_weight(self, img, msk):
        """Weight for the 'l1'/'l2' terms: ``exp(-apear_scale * |grad img|^2 / (0.2 + img)^2)`` times ``msk``.

        Returns the plain number 1 when neither ``img`` nor ``msk`` is given.
        """
        dg = 1
        if img is not None:
            grad_sq = sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img)])
            dg = torch.exp(-self.apear_scale * grad_sq / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True))
        if msk is not None:
            dg = dg * msk
        return dg

    def forward(self, y_pred=None, x_in=None, img=None, msk=None):
        """
        Args:
            y_pred: vector field, channels on dim 1 and spatial axes after it.
            x_in: expected values used by 'param' (last element) and 'detj'
                (first element); a single tensor is wrapped in a list.
            img: optional image for the 'l1'/'l2' edge weight; required for
                'std', 'param' and 'detj'.
            msk: optional mask multiplied into the 'l1'/'l2' weight.

        Returns:
            Sum of the requested terms (the int 0 when none matched).

        Raises:
            ValueError: if 'std', 'param' or 'detj' is requested without ``img``.
        """
        reg_loss = 0
        wants = self.penalty

        if 'l1' in wants or 'l2' in wants:
            # Built only when needed: it uses self.dist, which may be a list
            # when the object is configured for the 'param' term.
            dg = self._edge_weight(img, msk)
            grads = self._eq_diffs(y_pred)
            if 'l1' in wants:
                df = [torch.mean(dg * F.relu(torch.abs(f) - self.waive_thresh)) for f in grads]
                reg_loss += sum(df) / len(df)
            if 'l2' in wants:
                df = [torch.mean(dg * F.relu(f * f - self.waive_thresh ** 2)) for f in grads]
                reg_loss += torch.sqrt(sum(df) / len(df))

        if 'negdetj' in wants:
            detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
            df = self.detj_weight * torch.mean(F.relu(-detj))
            reg_loss += 0.5 * df

        if 'range' in wants:
            reg_loss += self.outrange_weight * (self._center_dist(y_pred))

        if 'param' in wants or 'detj' in wants or 'std' in wants:
            if img is None:
                raise ValueError("penalties 'std', 'param' and 'detj' require img")
            mean_dim = list(range(1, self.ndims + 2))
            grad_sq = sum([torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img, dist=3)])
            gate = F.relu(.1 - grad_sq / torch.sum(torch.square(.1 + img), dim=1, keepdim=True))
            dg = torch.sum(torch.abs(img), dim=1, keepdim=True) * torch.exp(-self.apear_scale * gate)
            # Normalise the weight to unit mean per sample.
            dg = dg / (EPS + torch.mean(dg, dim=mean_dim, keepdim=True))

            y_pred = torch.clamp(y_pred, min=-0.8, max=0.8)
            x_in = x_in if isinstance(x_in, list) else [x_in]

            if 'std' in wants:
                # Use this instance's dimensionality for both reductions.
                spatial = list(range(2, self.ndims + 2))
                centred = y_pred - torch.mean(y_pred, dim=spatial, keepdim=True)
                reg_loss += self.sign * torch.mean(
                    torch.clamp(grad_std(centred * dg, ndims=self.ndims), max=.2, min=0))

            if 'param' in wants:
                # Accept a scalar dist as a one-element list.
                dists = self.dist if isinstance(self.dist, (list, tuple)) else [self.dist]
                for idx, d in enumerate(dists):
                    expect = torch.abs(x_in[-1][:, idx:idx + 1, ...])
                    df = torch.mean(torch.abs(sum(self._weighted_diffs_error(
                        y_pred, dist=d, w=dg, expect=expect, mean_dim=mean_dim))))
                    reg_loss += 1 * (df) / len(dists)

            if 'detj' in wants:
                detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
                # NOTE(review): detj has no channel axis while dg keeps one, so
                # this product relies on broadcasting -- confirm intended shapes.
                df = torch.mean(torch.abs(
                    torch.mean((torch.abs(detj) - torch.abs(x_in[0])) * dg, dim=mean_dim)))
                reg_loss += 0.5 * df

        return reg_loss
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def avg_std_skew_kurt(array,ndims=2):
    """
    Return ``[mean, std, skewness, excess kurtosis]`` of ``array``.

    Statistics are taken over dims ``2 .. ndims + 1``; each result has those
    dims removed.  The standard deviation is the population (biased) one.
    A constant slice has zero std and therefore yields NaN skew/kurtosis.
    """
    dim = list(range(2, ndims + 2))

    def drop_reduced(t):
        # Remove the singleton axes kept for broadcasting.
        for d in reversed(dim):
            t = t.squeeze(d)
        return t

    # keepdim=True so the statistics broadcast against ``array``.
    mean = torch.mean(array, dim=dim, keepdim=True)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0), dim=dim, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=dim)
    kurtoses = torch.mean(torch.pow(zscores, 4.0), dim=dim) - 3.0
    return [drop_reduced(mean), drop_reduced(std), skews, kurtoses]
|
| 481 |
+
|
| 482 |
+
def grad_std(array,ndims=2):
    """
    Standard deviation of ``array`` after clipping it to [-0.8, 0.8].

    Each channel is centred on its own mean over dims ``2 .. ndims + 1``;
    the squared deviations are then averaged over dims ``1 .. ndims + 1``
    (channels included) before the square root.
    """
    clipped = torch.clamp(array, min=-0.8, max=0.8)
    spatial_axes = list(range(2, ndims + 2))
    reduce_axes = list(range(1, ndims + 2))
    centred = clipped - torch.mean(clipped, dim=spatial_axes, keepdim=True)
    return torch.sqrt(torch.mean(torch.square(centred), dim=reduce_axes))
|
| 488 |
+
|
| 489 |
+
def avg_std(array,ndims=2):
    """
    Return ``[mean, std]`` of ``array``.

    The mean is taken over dims ``2 .. ndims + 1``; the second entry is
    whatever ``grad_std`` returns for the same ``ndims``.
    """
    dim = list(range(2, ndims + 2))
    # grad_std takes ``ndims`` (it has no ``dim`` keyword).
    return [torch.mean(array, dim=dim), grad_std(array, ndims=ndims)]
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
if __name__ == "__main__":
    # Smoke test: LNCC between a random volume and a second random volume
    # whose interior is a copy of the first shifted along the last axis.
    size = 128
    translation = 2
    smooth = True

    img3d = torch.empty(1, 1, size, size, size).uniform_(0, 1)
    img3d_t = torch.empty(1, 1, size, size, size).uniform_(0, 1)
    img3d_t[:, :, :, :, translation:] = img3d[:, :, :, :, :size - translation]

    loss_ncc = LNCC(smooth=smooth, central=True)
    loss_sim = loss_ncc(img3d, img3d_t)
    print(loss_sim)
|
Diffusion/losses_ncc0.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
losses for DRDM
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import sys
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Small constant; presumably a denominator guard for code further down the
# file -- confirm against its uses.
EPS=1e-7

# Default value of the ``eps_scale`` argument of MRSE and RMSE below.
# Values tried previously:
# eps_scale = 10e-5
# eps_scale = 10e-4
# eps_scale = 1e-4
eps_scale = 1e-5
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LMSE(torch.nn.Module):
    """
    Labeled Mean Square Error (LMSE).

    Squared error between an input and a reference, optionally made
    relative to the reference magnitude and optionally restricted to a
    label/weight map.
    """

    def __init__(self, eps=1e-7, relate_eps=5e-1, win=None, smooth=False):
        """
        Args:
            eps: added to the label sum to avoid division by zero.
            relate_eps: if not None, the squared error is divided by
                ``J**2 + relate_eps``.
            win: filter size per spatial dim; defaults to ``[5, 5, 5]``.
            smooth: filter both inputs with a box kernel before comparing.
        """
        super(LMSE, self).__init__()
        self.eps = eps
        self.relate_eps = relate_eps
        self.ndims = 3
        self.smooth = smooth
        self.win = win
        # Set window size
        if self.win is None:
            self.win = [5] * self.ndims
        if smooth:
            self.kernels = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """
        Build a (1, 1, k, k, k) filter.

        ``std == 0`` gives an all-ones box over ``self.win`` (a local sum,
        not normalised); otherwise a separable Gaussian of half-width
        ``3 * ceil(std)`` normalised to sum to one.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win])

        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def forward(self, I, J, label=None):
        """
        Computes the labeled mean squared error between I and J (ref).
        If label is provided, computes the MSE only over the labeled regions.

        Args:
            I: input tensor; must be accepted by ``conv3d`` when smoothing
                is on (assumed (B, 1, D, H, W) -- TODO confirm).
            J: reference tensor with the same shape as ``I``.
            label: optional weights; the weighted sum over dims (2, 3, 4) is
                divided by ``sum(label) + self.eps``.

        Returns:
            Scalar tensor holding the loss.
        """
        if self.smooth:
            # The filter is created on the CPU in float32; follow the input
            # so the convolution also works on other devices and dtypes.
            kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            # Pad by the half-width of the kernel actually applied.
            padding = [(k - 1) // 2 for k in kernels.shape[2:]]
            I = torch.nn.functional.conv3d(I, kernels, stride=1, padding=padding)
            J = torch.nn.functional.conv3d(J, kernels, stride=1, padding=padding)

        mse = (I - J) ** 2
        if self.relate_eps is not None:
            mse = mse / ((J ** 2) + self.relate_eps)

        if label is None:
            return torch.mean(mse)

        label = label.float()
        mse_sum = torch.sum(mse * label, dim=(2, 3, 4))
        label_sum = torch.sum(label, dim=(2, 3, 4)) + self.eps
        return torch.mean(mse_sum / label_sum)
|
| 69 |
+
|
| 70 |
+
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC).

    ``forward`` returns the negated mean of the squared local correlation,
    so a better match gives a lower value.  Local statistics are window sums
    obtained with an all-ones filter of size ``win``.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-7, central=True, smooth=False):
        """
        Args:
            win: window size per spatial dim; defaults to ``[11, 11, 11]``.
            num_ch: accepted for interface compatibility; not used here.
            eps: added to the product of local variances in the denominator.
            central: subtract local means when True, otherwise use raw sums.
            smooth: apply Gaussian pre-smoothing (std 0.5) when True.
        """
        super(LNCC, self).__init__()
        self.win = win
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth

        # Set window size
        if self.win is None:
            self.win = [11] * self.ndims

        if smooth:
            self.kernels = self._build_kernel(std=0.5)
        self.sum_filt = self._build_kernel(std=0.0)

    def _build_kernel(self, std=0.0):
        """
        Build a (1, 1, k, k, k) filter.

        ``std == 0`` gives an all-ones box over ``self.win``; otherwise a
        separable Gaussian of half-width ``3 * ceil(std)`` normalised to sum
        to one.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win])

        tail = int(np.ceil(std)) * 3
        k = torch.exp(-0.5 * torch.arange(-tail, tail + 1, dtype=torch.float32) ** 2 / std ** 2)
        kernel = k / torch.sum(k)
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def lncc(self, I, J, label=None):
        """
        Return the mean squared local correlation between I and J.

        Args:
            I, J: tensors accepted by ``conv3d`` with a single-channel kernel
                (assumed (B, 1, D, H, W) -- TODO confirm).
            label: optional weights; when given, the weighted sum over dims
                (2, 3, 4) is divided by ``sum(label) + self.eps``.
        """
        conv3d = torch.nn.functional.conv3d
        # Follow the input's device and dtype for every filter.
        self.sum_filt = self.sum_filt.to(device=I.device, dtype=I.dtype)
        padding = [(w - 1) // 2 for w in self.win]

        if self.smooth:
            self.kernels = self.kernels.to(device=I.device, dtype=I.dtype)
            # The Gaussian has its own size; pad by its half-width so the
            # smoothed volumes keep the input shape.
            smooth_pad = [(k - 1) // 2 for k in self.kernels.shape[2:]]
            I = conv3d(I, self.kernels, stride=1, padding=smooth_pad)
            J = conv3d(J, self.kernels, stride=1, padding=smooth_pad)

        def local_sum(x):
            return conv3d(x, self.sum_filt, stride=1, padding=padding)

        I2_sum = local_sum(I * I)
        J2_sum = local_sum(J * J)
        IJ_sum = local_sum(I * J)

        if self.central:
            I_sum = local_sum(I)
            J_sum = local_sum(J)
            win_size = float(np.prod(self.win))
            cross = IJ_sum - (I_sum * J_sum) / win_size
            I_var = I2_sum - (I_sum * I_sum) / win_size
            J_var = J2_sum - (J_sum * J_sum) / win_size
        else:
            cross = IJ_sum
            I_var = I2_sum
            J_var = J2_sum

        # Sum-of-squares minus squared-sum can dip below zero through
        # floating-point cancellation; clamp so the denominator stays >= eps.
        I_var = torch.clamp(I_var, min=0.0)
        J_var = torch.clamp(J_var, min=0.0)

        cc = (cross * cross) / (I_var * J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the negated LNCC of I and J."""
        return -self.lncc(I, J, label=label)
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class NCC(torch.nn.Module):
    """
    Normalized cross-correlation of two vector fields along dim 1.

    Inputs are multiplied by ``self.scale``; the per-location dot product is
    divided by ``sqrt(|a|^2 * |b|^2 + eps_scale)`` and then averaged.
    """

    def __init__(self, eps_scale=10e-5,img_sz=256):
        super().__init__()
        # img_sz is accepted but not used.
        self.eps_scale = eps_scale
        self.scale = 1e2

    def _prepare(self, pred, inv_lab, ddf_stn):
        """Return the scaled (prediction, reference) pair."""
        moved = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        return self.scale * moved, self.scale * inv_lab

    def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
        """
        Args:
            pred: predicted field.
            inv_lab: reference field (required despite the ``None`` default).
            ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                replaces ``pred``.
            mask: optional weights; the weighted sum is divided by
                ``sum(mask)`` and by the batch size.
        """
        moved, ref = self._prepare(pred, inv_lab, ddf_stn)
        numer = torch.sum(moved * ref, dim=1)
        denom = torch.sqrt(
            torch.sum(torch.square(moved), dim=1) * torch.sum(torch.square(ref), dim=1)
            + self.eps_scale
        )
        if mask is not None:
            batch_size = ref.shape[0]
            return torch.sum(numer * mask / denom) / torch.sum(mask) / batch_size
        return torch.mean(numer / denom)
|
| 171 |
+
|
| 172 |
+
class MRSE(torch.nn.Module):
    """
    Mean relative squared error between a predicted and a reference field.

    The residual is ``pred + inv_lab`` (a sum, not a difference), squared
    and summed over dim 1, divided by the squared reference magnitude plus
    ``eps_scale``.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256):
        super().__init__()
        # img_sz is accepted but not used.
        self.eps_scale = eps_scale
        self.scale = 10e1

    def forward(self,pred,inv_lab=None,ddf_stn=None,mask=None):
        """
        Args:
            pred: predicted field.
            inv_lab: reference field (required despite the ``None`` default).
            ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                replaces ``pred``.
            mask: optional weights; the weighted sum is divided by
                ``sum(mask)`` and by the batch size.
        """
        moved = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        moved = self.scale * moved
        ref = self.scale * inv_lab

        residual = torch.sum(torch.square(moved + ref), dim=1)
        ref_energy = torch.sum(torch.square(ref), dim=1) + self.eps_scale

        if mask is not None:
            batch_size = ref.shape[0]
            return torch.sum(residual * mask / ref_energy) / torch.sum(mask) / batch_size
        return torch.mean(residual / ref_energy)
|
| 197 |
+
|
| 198 |
+
class RMSE(torch.nn.Module):
    """
    Relative mean squared error between a predicted and a reference field.

    Squared differences (summed over dim 1) are averaged over the next
    ``ndims`` axes and divided by the equally averaged squared reference
    plus ``eps_scale``; the ratio is averaged over what remains.
    """

    def __init__(self, eps_scale=eps_scale,img_sz=256,ndims=2):
        super().__init__()
        # img_sz is accepted but not used.
        self.eps_scale = eps_scale
        self.ndims = ndims

    def forward(self,pred,inv_lab=None,ddf_stn=None):
        """
        Args:
            pred: predicted field.
            inv_lab: reference field (required despite the ``None`` default).
            ddf_stn: optional callable; when given, ``-ddf_stn(pred, inv_lab)``
                replaces ``pred``.
        """
        moved = pred if ddf_stn is None else -ddf_stn(pred, inv_lab)
        # After the channel sum, dims 1..ndims are the remaining leading axes.
        reduce_axes = list(range(1, 1 + self.ndims))
        residual = torch.mean(torch.sum(torch.square(moved - inv_lab), dim=1), dim=reduce_axes)
        ref_energy = torch.mean(torch.sum(torch.square(inv_lab), dim=1), dim=reduce_axes)
        return torch.mean(residual / (ref_energy + self.eps_scale))
|
| 213 |
+
# loss_gen = torch.mean(torch.mean(torch.sum(torch.square(ddf_stn(pre_dvf_I, dvf_I) + dvf_I), dim=1),dim=list(range(1,1+ndims))) / (torch.mean(torch.sum(torch.square(dvf_I), dim=1),dim=list(range(1,1+ndims))) + EPS))
|
| 214 |
+
|
| 215 |
+
class Grad(torch.nn.Module):
    """
    N-D gradient loss

    Sums a configurable set of regularisation terms on a field ``y_pred`` laid
    out as ``[B, C=ndims, *spatial]``. ``penalty`` is a collection of names:

    * ``'l1'``      - mean thresholded absolute finite difference.
    * ``'l2'``      - square root of the mean thresholded squared difference.
    * ``'negdetj'`` - penalises negative Jacobian determinants of ``x + u``.
    * ``'range'``   - penalises the centre sample exceeding ``outrange_thresh``.
    * ``'std'``, ``'param'``, ``'detj'`` - statistics-matching terms that need
      ``img`` and ``x_in`` and rely on the module-level ``EPS`` and
      ``grad_std``.
    """

    def __init__(self, penalty=None, ndims=3, eps=1e-8, outrange_weight=1e4, outrange_thresh=0.5,
                 detj_weight=2, apear_scale=4, dist=1, sign=1, waive_thresh=10**-5):
        """
        Args:
            penalty: names of the active terms; ``None`` means ``['l1']``.
            ndims: number of spatial dimensions of the field.
            eps: stored on the instance; not read by the methods of this class.
            outrange_weight: weight of the ``'range'`` term.
            outrange_thresh: magnitude allowed by the ``'range'`` term.
            detj_weight: weight of the ``'negdetj'`` term.
            apear_scale: decay rate of the image-derived weights.
            dist: finite-difference step; the ``'param'`` term iterates over
                it, so a list of steps may be given for that term.
            sign: multiplier of the ``'std'`` term.
            waive_thresh: differences below this magnitude are not penalised
                by ``'l1'`` / ``'l2'``.
        """
        super().__init__()
        # A fresh list per instance: a shared default list would leak any
        # later mutation of `penalty` into every other instance.
        self.penalty = ['l1'] if penalty is None else penalty
        self.eps = eps
        self.outrange_weight = outrange_weight
        self.detj_weight = detj_weight
        self.apear_scale = apear_scale
        self.ndims = ndims
        self.max_sz = torch.reshape(
            torch.tensor([outrange_thresh] * ndims, dtype=torch.float32), [1] + [ndims] + [1] * ndims)
        self.act = torch.nn.ReLU(inplace=False)
        self.dist = dist
        self.sign = sign
        self.waive_thresh = waive_thresh

    @staticmethod
    def _shifted_pair(y, axis, dist):
        """Return ``(y[dist:], y[:-dist])`` sliced along ``axis``."""
        upper = [slice(None)] * y.dim()
        lower = [slice(None)] * y.dim()
        upper[axis] = slice(dist, None)
        lower[axis] = slice(None, -dist)
        return y[tuple(upper)], y[tuple(lower)]

    def _diffs(self, y, dist=None):
        """Forward differences along each of the ``self.ndims`` spatial axes.

        Each returned tensor is ``dist`` samples shorter than ``y`` along its
        own axis and is divided by ``dist``.
        """
        if dist is None:
            dist = self.dist
        diffs = []
        for axis in range(2, self.ndims + 2):
            upper, lower = self._shifted_pair(y, axis, dist)
            diffs.append((upper - lower) / float(dist))
        return diffs

    def _eq_diffs(self, y, dist=None):
        """Like ``_diffs`` but zero-padded so every result has the shape of ``y``.

        The number of spatial axes is taken from ``y`` itself. ``dist`` zeros
        are inserted at the start of the differentiated axis.
        """
        if dist is None:
            dist = self.dist
        rank = y.dim()
        diffs = []
        for axis in range(2, rank):
            upper, lower = self._shifted_pair(y, axis, dist)
            dy = (upper - lower) / float(dist)
            # F.pad consumes (before, after) pairs starting at the LAST
            # dimension, so skip the axes that come after `axis`.
            pad = [0, 0] * (rank - 1 - axis) + [dist, 0]
            diffs.append(F.pad(dy, pad, mode='constant', value=0))
        return diffs

    def _weighted_diffs_error(self, y, dist=None, w=None, expect=None, mean_dim=None):
        """Weighted gap between absolute differences of ``y`` and ``expect``.

        For every spatial axis computes
        ``(|y[dist:] - y[:-dist]| - expect) * (w[dist:] * w[:-dist])`` and
        averages it over ``mean_dim`` (keeping those dimensions).
        """
        if dist is None:
            dist = self.dist
        errors = []
        for axis in range(2, y.dim()):
            y_upper, y_lower = self._shifted_pair(y, axis, dist)
            w_upper, w_lower = self._shifted_pair(w, axis, dist)
            dy = (torch.abs(y_upper - y_lower) - expect) * (w_upper * w_lower)
            errors.append(torch.mean(dy, dim=mean_dim, keepdim=True))
        return errors

    def _outl_dist(self, y, range_thresh=0.2):
        """Penalise boundary samples whose own-axis component leaves the range.

        On the first slice of spatial axis ``i`` channel ``i`` is penalised
        for falling below ``-range_thresh``; on the last slice for exceeding
        ``+range_thresh``. The result is averaged over the axes.
        """
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        loss = 0.
        for i in range(self.ndims):
            axis = i + 2
            first = y.select(axis, 0)[:, i]
            last = y.select(axis, -1)[:, i]
            loss += (torch.mean(torch.square(F.relu(-range_thresh - first)))
                     + torch.mean(torch.square(F.relu(last - range_thresh))))
        return loss / self.ndims

    def _center_dist(self, y):
        """Mean squared excess of the centre sample's magnitude over the threshold."""
        self.device = y.device
        self.max_sz = self.max_sz.to(self.device)
        centre = tuple(s // 2 for s in y.size()[2:])
        centre_value = y[(slice(None), slice(None)) + centre]
        # Every entry of max_sz equals outrange_thresh, so broadcasting
        # against it acts as a scalar threshold once the mean is taken.
        return torch.mean(torch.square(self.act(torch.abs(centre_value) - self.max_sz)))

    def _eval_detJ(self, disp, add_identity=True, spacing=1.0):
        """
        disp: list length ndims
              disp[i] is derivative wrt spatial dim i (forward diff),
              tensor shape [B, C=ndims, ...]
        add_identity: True if y_pred is displacement u and phi=x+u
        spacing: voxel spacing (or 1.0). If you care about physical units,
                 divide derivatives by spacing (and dist). Sign won't change.

        Returns the Jacobian determinant with the channel dimension removed.
        Raises ValueError when ``self.ndims`` is neither 2 nor 3.
        """
        # Optional scaling (won't affect sign as long as spacing>0)
        if spacing != 1.0:
            disp = [d / spacing for d in disp]

        identity = 1.0 if add_identity else 0.0

        if self.ndims == 2:
            j11 = identity + disp[0][:, 0, ...]
            j21 = disp[0][:, 1, ...]
            j12 = disp[1][:, 0, ...]
            j22 = identity + disp[1][:, 1, ...]
            return j11 * j22 - j12 * j21

        if self.ndims == 3:
            j11 = identity + disp[0][:, 0, ...]
            j21 = disp[0][:, 1, ...]
            j31 = disp[0][:, 2, ...]

            j12 = disp[1][:, 0, ...]
            j22 = identity + disp[1][:, 1, ...]
            j32 = disp[1][:, 2, ...]

            j13 = disp[2][:, 0, ...]
            j23 = disp[2][:, 1, ...]
            j33 = identity + disp[2][:, 2, ...]

            return (
                j11 * (j22 * j33 - j23 * j32)
                - j12 * (j21 * j33 - j23 * j31)
                + j13 * (j21 * j32 - j22 * j31)
            )

        raise ValueError(f"Unsupported ndims={self.ndims}")

    def forward(self, y_pred=None, x_in=None, img=None, msk=None):
        """Return the sum of all terms listed in ``self.penalty``.

        Args:
            y_pred: field to regularise, ``[B, ndims, *spatial]``.
            x_in: target statistics for ``'param'`` / ``'detj'``; a single
                value is wrapped into a one-element list.
            img: image used to derive spatial weights; optional for
                ``'l1'`` / ``'l2'``, required for ``'std'``, ``'param'`` and
                ``'detj'``.
            msk: optional mask multiplied into the ``'l1'`` / ``'l2'`` weight.

        Raises:
            ValueError: if a statistics term is active but ``img`` is ``None``.
        """
        reg_loss = 0

        if 'l1' in self.penalty or 'l2' in self.penalty:
            # The edge-aware weight is consumed only by these two terms, so it
            # is not built otherwise; building it with a list-valued
            # `self.dist` (valid for 'param') would fail.
            dg = 1
            if img is not None:
                edge_energy = sum(torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img))
                dg = torch.exp(-self.apear_scale * edge_energy
                               / torch.sum(torch.square(0.2 + img), dim=1, keepdim=True))
            if msk is not None:
                dg = dg * msk
            smooth_diffs = self._eq_diffs(y_pred)

            if 'l1' in self.penalty:
                df = [torch.mean(dg * F.relu(torch.abs(f) - self.waive_thresh, inplace=True))
                      for f in smooth_diffs]
                reg_loss += sum(df) / len(df)

            if 'l2' in self.penalty:
                df = [torch.mean(dg * F.relu(f * f - self.waive_thresh ** 2, inplace=True))
                      for f in smooth_diffs]
                reg_loss += torch.sqrt(sum(df) / len(df))

        if 'negdetj' in self.penalty:
            detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
            reg_loss += 0.5 * self.detj_weight * torch.mean(F.relu(-detj))

        if 'range' in self.penalty:
            reg_loss += self.outrange_weight * self._center_dist(y_pred)

        if 'param' in self.penalty or 'detj' in self.penalty or 'std' in self.penalty:
            if img is None:
                raise ValueError("the 'param', 'detj' and 'std' penalties require `img`")

            mean_dim = list(range(1, self.ndims + 2))
            edge_energy = sum(torch.sum(g * g, dim=1, keepdim=True) for g in self._eq_diffs(img, dist=3))
            flatness = F.relu(.1 - edge_energy / torch.sum(torch.square(.1 + img), dim=1, keepdim=True))
            dg = torch.sum(torch.abs(img), dim=1, keepdim=True) * torch.exp(-self.apear_scale * flatness)
            # Normalise the weight to (approximately) unit mean per sample.
            dg = dg / (EPS + torch.mean(dg, dim=mean_dim, keepdim=True))

            y_pred = torch.clamp(y_pred, min=-0.8, max=0.8)
            x_in = x_in if isinstance(x_in, list) else [x_in]

            if 'std' in self.penalty:
                # Use the instance's dimensionality rather than a global name.
                spatial = list(range(2, self.ndims + 2))
                centred = y_pred - torch.mean(y_pred, dim=spatial, keepdim=True)
                spread = grad_std(centred * dg, ndims=self.ndims)
                reg_loss += self.sign * torch.mean(torch.clamp(spread, max=.2, min=0))

            if 'param' in self.penalty:
                try:
                    dists = list(self.dist)
                except TypeError:  # a single scalar step
                    dists = [self.dist]
                for idx, d in enumerate(dists):
                    expect = torch.abs(x_in[-1][:, idx:idx + 1, ...])
                    err = sum(self._weighted_diffs_error(
                        y_pred, dist=d, w=dg, expect=expect, mean_dim=mean_dim))
                    reg_loss += torch.mean(torch.abs(err)) / len(dists)

            if 'detj' in self.penalty:
                detj = self._eval_detJ(self._eq_diffs(y_pred, dist=1))
                gap = (torch.abs(detj) - torch.abs(x_in[0])) * dg
                reg_loss += 0.5 * torch.mean(torch.abs(torch.mean(gap, dim=mean_dim)))

        return reg_loss
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def avg_std_skew_kurt(array, ndims=2):
    """Return ``[mean, std, skewness, excess kurtosis]`` over the spatial axes.

    The statistics are taken over dimensions ``2 .. ndims + 1`` and each
    returned tensor has those dimensions removed. The variance is the
    population (biased) variance; a constant slice gives a zero standard
    deviation and therefore non-finite skewness and kurtosis.
    """
    dim = list(range(2, ndims + 2))
    # keepdim=True so the statistics broadcast back against `array`; without
    # it the reduced tensors are aligned to the wrong (trailing) axes.
    mean = torch.mean(array, dim=dim, keepdim=True)
    diffs = array - mean
    var = torch.mean(torch.pow(diffs, 2.0), dim=dim, keepdim=True)
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0), dim=dim)
    kurtoses = torch.mean(torch.pow(zscores, 4.0), dim=dim) - 3.0
    return [mean.reshape(skews.shape), std.reshape(skews.shape), skews, kurtoses]
|
| 462 |
+
|
| 463 |
+
def grad_std(array, ndims=2):
    """Return one pooled standard deviation per leading-dimension entry.

    ``array`` is first clamped to ``[-0.8, 0.8]``. Each slice is centred on
    its own mean over dimensions ``2 .. ndims + 1``, then the squared
    deviations are averaged over dimensions ``1 .. ndims + 1`` and the square
    root is taken.
    """
    spatial = list(range(2, ndims + 2))
    pooled = list(range(1, ndims + 2))
    clamped = torch.clamp(array, min=-0.8, max=0.8)
    centred = clamped - torch.mean(clamped, dim=spatial, keepdim=True)
    return torch.sqrt(torch.mean(torch.square(centred), dim=pooled))
|
| 469 |
+
|
| 470 |
+
def avg_std(array, ndims=2):
    """Return ``[spatial mean, grad_std(array)]`` for ``array``.

    The mean is taken over dimensions ``2 .. ndims + 1``; the second entry is
    whatever ``grad_std`` returns for the same ``ndims``.
    """
    dim = list(range(2, ndims + 2))
    # grad_std takes `ndims`, not `dim`; passing `dim=` raised a TypeError.
    return [torch.mean(array, dim=dim), grad_std(array, ndims=ndims)]
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
if __name__ == "__main__":
    # Manual smoke test: evaluate the 'detj' penalty on a random 2-D field.
    ndims = 2
    dist = [16, 32]
    ddf = torch.rand(1, 2, 128, 128)
    img = torch.rand(1, 1, 128, 128)
    # First entry: target values shaped [1, ndims, 1, 1] so they broadcast
    # over the image. Built with torch directly, which avoids NumPy's
    # deprecated `newshape=` keyword.
    x_in = [torch.tensor([0.2, 0.3], dtype=torch.float32).reshape([1, ndims] + [1] * ndims), 0.]

    Loss_detj = Grad(penalty=['detj'], ndims=ndims, dist=dist)
    loss_detj = Loss_detj(ddf, x_in, img)
    print(loss_detj)
|
Diffusion/networks.py
ADDED
|
@@ -0,0 +1,1167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
def get_net(name="recresnet"):
    """Return the network class registered under ``name`` (case-insensitive).

    Returns ``None`` for an unknown name.

    NOTE(review): the default ``"recresnet"`` matches none of the names below,
    so calling ``get_net()`` without arguments returns ``None`` - confirm
    whether ``"recresacnet"`` was intended.
    """
    key = name.lower()
    if key == "recresacnet":
        return RecResACNet
    if key == "recmutattnnet":
        return RecMutAttnNet
    if key == "recmutattnnet0":
        return RecMutAttnNet0
    if key == "recmutattnnet1":
        return RecMutAttnNet1
    if key == "defrecmutattnnet":
        return DefRec_MutAttnNet
    if key == "recmutattnnet_contrastive":
        return RecMutAttnNet_contrastive
    return None
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def sinusoidal_embedding(n, d):
    """Return an ``(n, d)`` table of sinusoidal position embeddings.

    Even columns hold ``sin(t * w)`` and odd columns ``cos(t * w)``, where
    ``t`` is the row index and ``w`` is ``1 / 10000 ** (2 * j / d)`` evaluated
    at the even column index ``j``. Each sine/cosine pair shares one frequency.
    """
    freqs = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)]).reshape((1, d))
    positions = torch.arange(n).reshape((n, 1))
    angles = positions * freqs[:, ::2]

    embedding = torch.zeros(n, d)
    embedding[:, ::2] = torch.sin(angles)
    # For odd d there is one fewer odd column than even columns; trim so the
    # assignment shapes agree instead of raising.
    embedding[:, 1::2] = torch.cos(angles[:, : d // 2])
    return embedding
|
| 36 |
+
|
| 37 |
+
class AtrousBlock(nn.Module):
    """Residual block of activation -> (dilated) convolution stages.

    An optional stride-1 convolution first maps ``in_c`` to ``out_c`` channels
    (only when they differ) and optional instance normalisation follows. The
    result goes through one convolution per entry of ``atrous_rates`` and is
    added back before a final activation. All convolutions use stride 1 with
    padding that preserves the spatial size for odd kernel sizes.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=(1, 3), ndims=2,
                 activation=None, normalize=True):
        """
        Args:
            shape: unused; kept for interface compatibility (it sized the
                LayerNorm that was replaced by InstanceNorm).
            in_c: input channel count.
            out_c: output channel count.
            kernel_size: kernel size of the non-pointwise convolutions.
            stride: unused; every convolution is built with stride 1.
            atrous_rates: one convolution per entry; a rate ``> 0`` is the
                dilation, any other value yields a 1x1 convolution.
            ndims: spatial dimensionality, selects ``ConvNd`` / ``InstanceNormNd``.
            activation: activation module; ``None`` selects ``LeakyReLU(1e-6)``.
            normalize: whether to apply instance normalisation.
        """
        super().__init__()
        if normalize:
            # self.ln = nn.LayerNorm(shape)  # jzheng 15/03/2024
            norm = getattr(nn, 'InstanceNorm%dd' % ndims)  # jzheng 15/03/2024
            self.ln = norm(out_c, affine=True)
        else:
            self.ln = nn.Identity()

        Conv = getattr(nn, 'Conv%dd' % ndims)
        if in_c != out_c:
            self.conv0 = Conv(in_c, out_c, kernel_size, 1, (kernel_size - 1) // 2)
        else:
            self.conv0 = None

        self.convs = nn.ModuleList([
            Conv(out_c, out_c, kernel_size, 1, (kernel_size - 1) // 2 * ar, dilation=ar)
            if ar > 0 else Conv(out_c, out_c, 1, 1, 0)
            for ar in atrous_rates
        ])
        self.activation = nn.LeakyReLU(1e-6) if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        """Apply the block; the output has ``out_c`` channels."""
        if self.conv0 is not None:
            x = self.conv0(x)
        if self.normalize:
            x = self.ln(x)  # jzheng 15/03/2024

        out = x
        for conv in self.convs:
            out = conv(self.activation(out))
        # Residual connection around the whole convolution stack.
        return self.activation(out + x)
|
| 74 |
+
|
| 75 |
+
# ==============================================
|
| 76 |
+
# Unconditional Network
|
| 77 |
+
# ==============================================
|
| 78 |
+
|
| 79 |
+
class RecResACNet(nn.Module):
    """Unconditional, time-conditioned U-Net that predicts a displacement field.

    Three stride-2 down-sampling stages, a bottleneck and three up-sampling
    stages with skip connections end in a convolution with ``ndims`` output
    channels. ``forward`` runs this network ``rec_num`` times, re-warping the
    input with the accumulated field between passes.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0):
        """
        Args:
            n_steps: number of time steps in the embedding table.
            time_emb_dim: width of the sinusoidal time embedding.
            ndims: spatial dimensionality (selects ``ConvNd``).
            num_input_chn: channel count of the input image.
            res: forwarded as the ``shape`` entries of every ``AtrousBlock``.
        """
        super().__init__()

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Sinusoidal embedding (frozen)
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # First half
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = self._stage(res, (num_input_chn, 10, 10, 10))
        self.down1 = self.Conv(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = self._stage(res // 2, (10, 20, 20, 20))
        self.down2 = self.Conv(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = self._stage(res // 4, (20, 40, 40, 40))
        self.down3 = self.Conv(40, 40, 4, 2, 1)

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = self._stage(res // 8, (40, 20, 20, 40))

        # Second half
        self.up1 = self.ConvT(40, 40, 4, 2, 1)

        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = self._stage(res // 4, (80, 40, 20, 20), normalize=False)

        self.up2 = self.ConvT(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = self._stage(res // 2, (40, 20, 10, 10), normalize=False)

        self.up3 = self.ConvT(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = self._stage(res // 1, (20, 10, 10, 10), normalize=False)

        self.conv_out = self.Conv(10, ndims, 3, 1, 1)

    def _stage(self, res, channels, normalize=True):
        """Chain ``AtrousBlock``s whose widths follow ``channels`` (in, c1, c2, ...)."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [res] * self.dimension, c_in, c_out,
                        ndims=self.dimension, normalize=normalize)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of a normalised field.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and scaled back.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat(
            [torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
             for x, sz in zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the normalised displacement field ``ddf``.

        ``ddf`` is scaled by ``self.max_sz``, added to the reference grid and
        mapped to sampling coordinates as ``coord / img_sz - 1``; the
        coordinate order is reversed on the last axis before being passed to
        ``F.grid_sample``. ``self.max_sz``, ``self.device`` and (by default)
        ``self.ref_grid`` are set by ``forward``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'

        scale = torch.Tensor(
            np.reshape(np.array(self.max_sz), [1, self.dimension] + [1] * self.dimension)).to(self.device)
        # Move the channel axis last, as grid_sample expects.
        channels_last = [0] + list(range(2, 2 + self.dimension)) + [1]
        grid = (ddf * scale + ref).permute(channels_last) / img_sz - 1
        return F.grid_sample(vol, torch.flip(grid, dims=[-1]), mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2):
        """Predict the accumulated displacement field for ``x`` at time ``t``.

        Args:
            x: input image batch ``[N, C, *spatial]``; it passes through three
                stride-2 convolutions and is re-concatenated with the
                up-sampled features.
            t: integer time-step indices looked up in the embedding table.
            y: unused.
            rec_num: number of recursive refinement passes (the method
                assumes at least one).
            ndims: unused; the dimensionality fixed at construction is used.

        Returns:
            The composed displacement field with ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): every axis is scaled by the FIRST spatial size -
        # confirm this is intended for non-cubic inputs.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension
        self.img_sz = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        axes = [torch.arange(end=imsz) for imsz in img_sz]
        # 'ij' is the behaviour the original call relied on implicitly.
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(*axes, indexing="ij"), 0),
            [1, self.dimension] + list(img_sz)).to(self.device)
        img = x

        t = self.time_embed(t)

        for rec_id in range(rec_num):
            out1 = self.b1(img + self.te1(t).reshape(ts_emb_shape))  # 10 channels
            out2 = self.b2(self.down1(out1) + self.te2(t).reshape(ts_emb_shape))  # 20 channels
            out3 = self.b3(self.down2(out2) + self.te3(t).reshape(ts_emb_shape))  # 40 channels

            # NOTE(review): the bottleneck MULTIPLIES by its time embedding
            # while every other stage adds it - left unchanged.
            out_mid = self.b_mid(self.down3(out3) * self.te_mid(t).reshape(ts_emb_shape))  # 40 channels

            out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # 80 channels
            out4 = self.b4(out4 + self.te4(t).reshape(ts_emb_shape))  # 20 channels

            out5 = torch.cat((out2, self.up2(out4)), dim=1)  # 40 channels
            out5 = self.b5(out5 + self.te5(t).reshape(ts_emb_shape))  # 10 channels

            out = torch.cat((out1, self.up3(out5)), dim=1)  # 20 channels
            out = self.b_out(out + self.te_out(t).reshape(ts_emb_shape))  # 10 channels

            out = self.conv_out(out)

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the running field by the new increment, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the two-layer MLP that projects the time embedding to ``dim_out``."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            # nn.SiLU(),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
| 229 |
+
|
| 230 |
+
# ==============================================
|
| 231 |
+
# Conditional Network
|
| 232 |
+
# ==============================================
|
| 233 |
+
|
| 234 |
+
class cross_attn(nn.Module):
    """Cross-attention wrapper around caller-supplied projections.

    ``forward(x, y)`` computes ``softmax(q(x) @ k(y)^T) @ v(y)`` with the
    softmax over the last dimension. No scaling is applied to the scores.
    """

    def __init__(self, q, k, v, ndims=2):
        """
        Args:
            q: callable producing queries from ``x``.
            k: callable producing keys from ``y``.
            v: callable producing values from ``y``.
            ndims: spatial dimensionality used to select ``ConvNd`` classes.
        """
        # Must run before any sub-module or parameter is assigned; without it
        # nn.Module rejects the assignments below.
        super().__init__()
        self.q = q
        self.k = k
        self.v = v
        self.ndims = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.ndims)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims)
        self.softmax = nn.Softmax(dim=-1)
        # Registered but not used by forward().
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, y):
        """Attend from ``x`` (queries) to ``y`` (keys and values)."""
        q = self.q(x)
        k = self.k(y)
        v = self.v(y)
        attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
        return torch.matmul(attn, v)
|
| 252 |
+
|
| 253 |
+
class DefRec_MutAttnNet(nn.Module):
    """Time-conditioned encoder/decoder that predicts a dense displacement field.

    The network embeds each timestep, runs an encoder/bottleneck/decoder with
    skip connections, and emits ``ndims`` displacement channels. When
    ``conditional_input`` is true, the conditional image ``y`` is encoded once
    per call, fused into every encoder level, attended to at the bottleneck
    with multi-head attention, and an optional text feature is mixed in.

    Displacements are stored as a fraction of ``self.max_sz`` (``resample``
    multiplies them by ``max_sz`` before adding the pixel grid).

    Args:
        n_steps: Number of rows in the (frozen) time-embedding table.
        time_emb_dim: Width of each time-embedding row.
        ndims: Number of spatial dimensions (selects ``nn.Conv{n}d``).
        num_input_chn: Channel count of the input image.
        res: Nominal spatial resolution used to size the blocks and the
            initial reference grid.
        conditional_input: Whether to build and use the conditional branch.
        text_feat_chn: Channel count of the text feature.
        num_heads: Number of heads of the bottleneck attention layer.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,
                 text_feat_chn=1024, num_heads=4):
        super(DefRec_MutAttnNet, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        self.copy = nn.Identity()

        # Sinusoidal embedding table, frozen so it is never trained.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        # NOTE(review): the conditional-only members below (including the text
        # path) are assumed to live inside this branch because forward() only
        # touches them when conditional_input is true - confirm.
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn,
                                        ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Default (all-zero) text feature used when forward() gets text=None.
            self.text = torch.zeros(1, self.text_feat_chn, *([1] * self.dimension))
        self.img_res = [res] * self.dimension
        self.ref_grid = self._build_ref_grid(self.img_res)

        for i in range(1, self.hier_num + 1):
            j = -i
            chn_prev = self.feat_channels[i - 1]
            chn_dn = self.feat_channels[i]
            chn_up = self.feat_channels[j]
            res_dn = res // (2 ** (i - 1))

            self.down_layers.append(self.Conv(chn_dn, chn_dn, 4, 2, 1))
            self.up_layers.append(self.ConvT(chn_up, chn_up, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, chn_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * chn_up))
            self.block_down.append(self._atrous_stack([chn_prev, chn_dn, chn_dn, chn_dn], res_dn))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack([chn_prev, chn_dn, chn_dn, chn_dn], res_dn))
                self.fuse_conv0.append(self.Conv(2 * chn_dn, chn_dn, 1, 1, 0))

            # The last decoder stage keeps its width; earlier ones narrow to
            # the width of the next (shallower) level.
            k = j if i == self.hier_num else j - 1
            # NOTE(review): for i == hier_num the exponent is -1, so this is
            # res // 0.5 (a float). Kept as in the original; verify against
            # how AtrousBlock uses the shape when normalize=False.
            res_up = res // (2 ** (self.hier_num - i - 1))
            self.block_up.append(
                self._atrous_stack([2 * chn_up, chn_up, chn_up, self.feat_channels[k]], res_up, normalize=False))

        # Bottleneck
        chn_mid = self.feat_channels[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = self._atrous_stack([chn_mid] * 4, res // (2 ** self.hier_num))

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    def _atrous_stack(self, chain, size, **block_kwargs):
        """Chain AtrousBlocks so that block n maps chain[n] -> chain[n+1] channels."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [size] * self.dimension, c_in, c_out, ndims=self.dimension, **block_kwargs)
            for c_in, c_out in zip(chain[:-1], chain[1:])
        ])

    def _build_ref_grid(self, img_sz, device=None):
        """Return the integer pixel-index grid of shape (1, ndims, *img_sz)."""
        axes = [torch.arange(end=imsz, device=device) for imsz in img_sz]
        # indexing="ij" is the behaviour of the legacy positional call, made explicit.
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), 0)
        return torch.reshape(grid, [1, self.dimension] + list(img_sz))

    def _sync_ref_grid(self, img_sz, device):
        """Make self.ref_grid match the current input size and device.

        The grid's own shape is compared (not the constructor ``res``), so a
        grid rebuilt for another size is never reused for a later input.
        """
        if tuple(self.ref_grid.shape[2:]) != tuple(img_sz):
            self.ref_grid = self._build_ref_grid(img_sz, device=device)
        self.ref_grid = self.ref_grid.to(device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to a per-axis bound.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and scaled back.
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        limited = [
            torch.clamp(chn * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
            for chn, sz in zip(channels, max_sz)
        ]
        return torch.cat(limited, 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        Args:
            vol: Tensor to sample from.
            ddf: Displacement field, as a fraction of ``self.max_sz``.
            ref: Pixel-index grid; defaults to ``self.ref_grid``.
            img_sz: Per-axis ``(size - 1) / 2`` normaliser, broadcastable to a
                channel-last grid; defaults to ``self.img_sz`` (the previous
                default, a Python list, could not be divided into a tensor).
            padding_mode: Forwarded to ``F.grid_sample``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.img_sz if img_sz is None else img_sz
        scale = torch.tensor(self.max_sz, dtype=torch.float32, device=ddf.device).reshape(
            [1, self.dimension] + [1] * self.dimension)
        # Absolute pixel coordinates, moved to channel-last layout.
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + self.dimension)) + [1])
        # Normalise to [-1, 1] and reverse the axis order for grid_sample.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        # 'bilinear' is also the mode name grid_sample expects for 5-D inputs.
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict and compose displacement fields over a schedule of timesteps.

        Args:
            x: Input image, indexed as (batch, channel, *spatial).
            y: Conditional image (used only when ``conditional_input``).
            t: Iterable of timestep-index tensors; the whole sequence is
                repeated ``rec_num`` times and one field is predicted per entry.
            text: Optional text feature; reshaped to
                ``(-1, text_feat_chn, 1, ...)``. Zeros are used when ``None``.
            rec_num: Number of repetitions of the ``t`` sequence.
            ndims: Unused; kept for interface compatibility.

        Returns:
            The composed displacement field with ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): the first spatial extent is used for every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
                                    [1] * (self.dimension + 1) + [self.dimension])
        self._sync_ref_grid(img_sz, self.device)

        img = x
        if self.conditional_input:
            # Encode the conditional input once; it does not depend on t.
            tgt = y
            tgt_down_list = []
            for i in range(self.hier_num):
                tgt = self.block_down_cond[i](tgt)
                tgt_down_list.append(self.copy(tgt))
                tgt = self.down_layers[i](tgt)
            # (N, C, *spatial) -> (L, N, C), the layout nn.MultiheadAttention expects by default.
            tgt_mid = self.copy(tgt).flatten(2).permute(2, 0, 1)

            if text is None:
                text = self.text
            text = text.to(self.device).view(-1, self.text_feat_chn, *([1] * self.dimension))

        steps = [t0.to(self.device) for t0 in t]
        steps = [t0 for _ in range(rec_num) for t0 in steps]
        for rec_id, step in enumerate(steps):
            t_emb = self.time_embed(step)

            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
                if self.conditional_input:
                    out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)

            out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                out_attn, _ = self.attn_layer(out.flatten(2).permute(2, 0, 1), tgt_mid, tgt_mid)
                # (L, N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

                out_txt = self.img2txt(out) + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape))

            # Fixed down-scaling of the raw prediction.
            out = self.conv_out(out) / 128

            ddf_one = self.boundary_limit(out, max_sz=self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> ReLU -> Linear projection of a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
| 442 |
+
|
| 443 |
+
class RecMutAttnNet1(nn.Module):
    """Recursive, time-conditioned encoder/decoder predicting a displacement field.

    Each recursion runs the encoder/bottleneck/decoder on the currently warped
    image, emits ``ndims`` displacement channels, and composes them with the
    field accumulated so far. When ``conditional_input`` is true, the
    conditional image ``y`` is encoded alongside the input, the two streams are
    mutually fused at every encoder level, and the bottleneck attends to the
    conditional features with multi-head attention.

    Displacements are stored as a fraction of ``self.max_sz`` (``resample``
    multiplies them by ``max_sz`` before adding the pixel grid).

    Args:
        n_steps: Number of rows in the (frozen) time-embedding table.
        time_emb_dim: Width of each time-embedding row.
        ndims: Number of spatial dimensions (selects ``nn.Conv{n}d``).
        num_input_chn: Channel count of the input image.
        res: Nominal spatial resolution used to size the blocks.
        conditional_input: Whether to build and use the conditional branch.
        text_feat_chn: Stored on the instance; not otherwise used by this class.
        num_heads: Number of heads of the bottleneck attention layer.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,
                 text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet1, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Sinusoidal embedding table, frozen so it is never trained.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)

        self.block_up = nn.ModuleList()

        for i in range(1, self.hier_num + 1):
            j = -i
            chn_prev = self.feat_channels[i - 1]
            chn_dn = self.feat_channels[i]
            chn_up = self.feat_channels[j]
            res_dn = res // (2 ** (i - 1))

            self.down_layers.append(self.Conv(chn_dn, chn_dn, 4, 2, 1))
            self.up_layers.append(self.ConvT(chn_up, chn_up, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, chn_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * chn_up))
            self.block_down.append(self._atrous_stack([chn_prev, chn_dn, chn_dn, chn_dn], res_dn))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack([chn_prev, chn_dn, chn_dn, chn_dn], res_dn))
                self.fuse_conv0.append(self.Conv(2 * chn_dn, chn_dn, 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * chn_dn, chn_dn, 1, 1, 0))

            # The last decoder stage keeps its width; earlier ones narrow to
            # the width of the next (shallower) level.
            k = j if i == self.hier_num else j - 1
            # NOTE(review): for i == hier_num the exponent is -1, so this is
            # res // 0.5 (a float). Kept as in the original; verify against
            # how AtrousBlock uses the shape when normalize=False.
            res_up = res // (2 ** (self.hier_num - i - 1))
            self.block_up.append(
                self._atrous_stack([2 * chn_up, chn_up, chn_up, self.feat_channels[k]], res_up, normalize=False))

        # Bottleneck
        chn_mid = self.feat_channels[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = self._atrous_stack([chn_mid] * 4, res // (2 ** self.hier_num))

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    def _atrous_stack(self, chain, size, **block_kwargs):
        """Chain AtrousBlocks so that block n maps chain[n] -> chain[n+1] channels."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [size] * self.dimension, c_in, c_out, ndims=self.dimension, **block_kwargs)
            for c_in, c_out in zip(chain[:-1], chain[1:])
        ])

    def _build_ref_grid(self, img_sz, device=None):
        """Return the integer pixel-index grid of shape (1, ndims, *img_sz)."""
        axes = [torch.arange(end=imsz, device=device) for imsz in img_sz]
        # indexing="ij" is the behaviour of the legacy positional call, made explicit.
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), 0)
        return torch.reshape(grid, [1, self.dimension] + list(img_sz))

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to a per-axis bound.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and scaled back.
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        limited = [
            torch.clamp(chn * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
            for chn, sz in zip(channels, max_sz)
        ]
        return torch.cat(limited, 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        Args:
            vol: Tensor to sample from.
            ddf: Displacement field, as a fraction of ``self.max_sz``.
            ref: Pixel-index grid; defaults to ``self.ref_grid``.
            img_sz: Per-axis ``(size - 1) / 2`` normaliser, broadcastable to a
                channel-last grid; defaults to ``self.img_sz`` (the previous
                default, a Python list, could not be divided into a tensor).
            padding_mode: Forwarded to ``F.grid_sample``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.img_sz if img_sz is None else img_sz
        scale = torch.tensor(self.max_sz, dtype=torch.float32, device=ddf.device).reshape(
            [1, self.dimension] + [1] * self.dimension)
        # Absolute pixel coordinates, moved to channel-last layout.
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + self.dimension)) + [1])
        # Normalise to [-1, 1] and reverse the axis order for grid_sample.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        # 'bilinear' is also the mode name grid_sample expects for 5-D inputs.
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
        """Recursively predict and compose ``rec_num`` displacement fields.

        Args:
            x: Input image, indexed as (batch, channel, *spatial).
            y: Conditional image (used only when ``conditional_input``).
            t: Timestep indices looked up in the time-embedding table; the
                embedding is reshaped to (batch, -1, 1, ...) before use.
            rec_num: Number of recursive refinement passes.
            ndims: Unused; kept for interface compatibility.

        Returns:
            The composed displacement field with ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): the first spatial extent is used for every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension
        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
                                    [1] * (self.dimension + 1) + [self.dimension])
        # The grid is rebuilt on every call so it always matches the input.
        self.ref_grid = self._build_ref_grid(img_sz).to(self.device)
        img = x
        t = self.time_embed(t)

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    # Mutual fusion: each stream is refreshed from the other.
                    tgt = self.block_down_cond[i](tgt)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                # (N, C, *spatial) -> (L, N, C), the layout nn.MultiheadAttention expects by default.
                tgt = tgt.flatten(2).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(out.flatten(2).permute(2, 0, 1), tgt, tgt)
                # (L, N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            # Fixed down-scaling of the raw prediction.
            out = self.conv_out(out) / 128

            ddf_one = self.boundary_limit(out, max_sz=self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> ReLU -> Linear projection of a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
| 589 |
+
|
| 590 |
+
class RecMutAttnNet(nn.Module):
    """Recursive, time- and text-conditioned network predicting a displacement field.

    Each recursion runs the encoder/bottleneck/decoder on the currently warped
    image, emits ``ndims`` displacement channels, and composes them with the
    field accumulated so far. When ``conditional_input`` is true, the
    conditional image ``y`` is encoded alongside the input with mutual fusion
    at every encoder level, the bottleneck attends to the conditional features
    with multi-head attention, and a text feature is mixed in.

    Displacements are stored as a fraction of ``self.max_sz`` (``resample``
    multiplies them by ``max_sz`` before adding the pixel grid).

    Args:
        n_steps: Number of rows in the (frozen) time-embedding table.
        time_emb_dim: Width of each time-embedding row.
        ndims: Number of spatial dimensions (selects ``nn.Conv{n}d``).
        num_input_chn: Channel count of the input image.
        res: Nominal spatial resolution used to size the blocks and the
            initial reference grid.
        conditional_input: Whether to build and use the conditional branch.
        text_feat_chn: Channel count of the text feature.
        num_heads: Number of heads of the bottleneck attention layer.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,
                 text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Sinusoidal embedding table, frozen so it is never trained.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        # NOTE(review): the conditional-only members below (including the text
        # path) are assumed to live inside this branch because forward() only
        # touches them when conditional_input is true - confirm.
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn,
                                        ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Default (all-zero) text feature used when forward() gets text=None.
            self.text = torch.zeros(1, self.text_feat_chn, *([1] * self.dimension))
        self.img_res = [res] * self.dimension
        self.ref_grid = self._build_ref_grid(self.img_res)

        for i in range(1, self.hier_num + 1):
            j = -i
            chn_prev = self.feat_channels[i - 1]
            chn_dn = self.feat_channels[i]
            chn_up = self.feat_channels[j]
            res_dn = res // (2 ** (i - 1))

            self.down_layers.append(self.Conv(chn_dn, chn_dn, 4, 2, 1))
            self.up_layers.append(self.ConvT(chn_up, chn_up, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, chn_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * chn_up))
            self.block_down.append(self._atrous_stack([chn_prev, chn_dn, chn_dn, chn_dn], res_dn))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack([chn_prev, chn_dn, chn_dn, chn_dn], res_dn))
                self.fuse_conv0.append(self.Conv(2 * chn_dn, chn_dn, 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * chn_dn, chn_dn, 1, 1, 0))

            # The last decoder stage keeps its width; earlier ones narrow to
            # the width of the next (shallower) level.
            k = j if i == self.hier_num else j - 1
            # NOTE(review): for i == hier_num the exponent is -1, so this is
            # res // 0.5 (a float). Kept as in the original; verify against
            # how AtrousBlock uses the shape when normalize=False.
            res_up = res // (2 ** (self.hier_num - i - 1))
            self.block_up.append(
                self._atrous_stack([2 * chn_up, chn_up, chn_up, self.feat_channels[k]], res_up, normalize=False))

        # Bottleneck
        chn_mid = self.feat_channels[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        self.b_mid = self._atrous_stack([chn_mid] * 4, res // (2 ** self.hier_num))

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    def _atrous_stack(self, chain, size, **block_kwargs):
        """Chain AtrousBlocks so that block n maps chain[n] -> chain[n+1] channels."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [size] * self.dimension, c_in, c_out, ndims=self.dimension, **block_kwargs)
            for c_in, c_out in zip(chain[:-1], chain[1:])
        ])

    def _build_ref_grid(self, img_sz, device=None):
        """Return the integer pixel-index grid of shape (1, ndims, *img_sz)."""
        axes = [torch.arange(end=imsz, device=device) for imsz in img_sz]
        # indexing="ij" is the behaviour of the legacy positional call, made explicit.
        grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), 0)
        return torch.reshape(grid, [1, self.dimension] + list(img_sz))

    def _sync_ref_grid(self, img_sz, device):
        """Make self.ref_grid match the current input size and device.

        The grid's own shape is compared (not the constructor ``res``), so a
        grid rebuilt for another size is never reused for a later input.
        """
        if tuple(self.ref_grid.shape[2:]) != tuple(img_sz):
            self.ref_grid = self._build_ref_grid(img_sz, device=device)
        self.ref_grid = self.ref_grid.to(device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to a per-axis bound.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and scaled back.
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        limited = [
            torch.clamp(chn * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
            for chn, sz in zip(channels, max_sz)
        ]
        return torch.cat(limited, 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        Args:
            vol: Tensor to sample from.
            ddf: Displacement field, as a fraction of ``self.max_sz``.
            ref: Pixel-index grid; defaults to ``self.ref_grid``.
            img_sz: Per-axis ``(size - 1) / 2`` normaliser, broadcastable to a
                channel-last grid; defaults to ``self.img_sz`` (the previous
                default, a Python list, could not be divided into a tensor).
            padding_mode: Forwarded to ``F.grid_sample``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.img_sz if img_sz is None else img_sz
        scale = torch.tensor(self.max_sz, dtype=torch.float32, device=ddf.device).reshape(
            [1, self.dimension] + [1] * self.dimension)
        # Absolute pixel coordinates, moved to channel-last layout.
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + self.dimension)) + [1])
        # Normalise to [-1, 1] and reverse the axis order for grid_sample.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        # 'bilinear' is also the mode name grid_sample expects for 5-D inputs.
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Recursively predict and compose ``rec_num`` displacement fields.

        Args:
            x: Input image, indexed as (batch, channel, *spatial).
            y: Conditional image (used only when ``conditional_input``).
            t: Timestep indices looked up in the time-embedding table; the
                embedding is reshaped to (batch, -1, 1, ...) before use.
            text: Optional text feature; reshaped to
                ``(-1, text_feat_chn, 1, ...)``. Zeros are used when ``None``.
            rec_num: Number of recursive refinement passes.
            ndims: Unused; kept for interface compatibility.

        Returns:
            The composed displacement field with ``self.dimension`` channels.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE(review): the first spatial extent is used for every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
                                    [1] * (self.dimension + 1) + [self.dimension])
        self._sync_ref_grid(img_sz, self.device)

        img = x
        t = self.time_embed(t)

        if self.conditional_input:
            # The text feature does not change between recursions; prepare it once.
            if text is None:
                text = self.text
            text = text.to(self.device).view(-1, self.text_feat_chn, *([1] * self.dimension))

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    # Mutual fusion: each stream is refreshed from the other.
                    tgt = self.block_down_cond[i](tgt)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                # (N, C, *spatial) -> (L, N, C), the layout nn.MultiheadAttention expects by default.
                tgt = tgt.flatten(2).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(out.flatten(2).permute(2, 0, 1), tgt, tgt)
                # (L, N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

                out_txt = self.img2txt(out) + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            # Fixed down-scaling of the raw prediction.
            out = self.conv_out(out) / 128

            ddf_one = self.boundary_limit(out, max_sz=self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> ReLU -> Linear projection of a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
| 761 |
+
|
| 762 |
+
class RecMutAttnNet_contrastive(nn.Module):
    """Recursive U-Net that predicts a displacement field (DDF) and an image embedding.

    The network encodes the (progressively warped) input ``x`` together with an
    optional conditioning image ``y``, fuses both streams at every encoder
    level, applies cross attention and a text-feature fusion at the bottleneck,
    and decodes a displacement field.  The predicted field is composed with
    the field from the previous recursion, ``rec_num`` times in total.

    Args:
        n_steps: Number of rows in the (frozen) timestep embedding table.
        time_emb_dim: Width of the timestep embedding.
        ndims: Number of spatial dimensions; selects ``Conv{ndims}d`` etc.
        num_input_chn: Channel count of the input image.
        res: Nominal spatial resolution, used for the initial reference grid
            and passed to ``AtrousBlock`` as part of its shape argument.
        conditional_input: When true, build and use the conditioning branch
            (conditioning encoder, attention and text fusion).
        text_feat_chn: Channel count of the text feature vector.
        num_heads: Number of heads of the bottleneck attention layer.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet_contrastive, self).__init__()

        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Sinusoidal embedding: the table is filled once and frozen.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        # NOTE(review): the source indentation was lost; every module that
        # forward() only touches under `conditional_input` is assumed to be
        # created under the same condition -- confirm against the original file.
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Fallback text feature used when forward() receives text=None.
            self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
        self.img_res = [res]*self.dimension
        self.ref_grid = self._make_ref_grid(self.img_res)

        chn = self.feat_channels
        for i in range(1, self.hier_num + 1):
            j = -i
            self.down_layers.append(self.Conv(chn[i], chn[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(chn[j], chn[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, chn[i-1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2*chn[j]))
            down_res = res // (2 ** (i-1))
            self.block_down.append(self._atrous_stack(down_res, chn[i-1], chn[i], chn[i]))
            if self.conditional_input:
                self.block_down_cond.append(self._atrous_stack(down_res, chn[i-1], chn[i], chn[i]))
                self.fuse_conv0.append(self.Conv(2*chn[i], chn[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2*chn[i], chn[i], 1, 1, 0))
            # The last decoder stage keeps its width; the others narrow to the
            # width of the next (shallower) level.
            k = j if i == self.hier_num else j - 1
            up_res = res // (2 ** (self.hier_num-i-1))
            self.block_up.append(self._atrous_stack(up_res, 2*chn[j], chn[j], chn[k], normalize=False))

        # Bottleneck
        self.tmid = self._make_te(time_emb_dim, chn[-1])
        mid_chn = chn[self.hier_num]
        self.b_mid = self._atrous_stack(res // (2**self.hier_num), mid_chn, mid_chn, mid_chn)

        self.conv_out = self.Conv(chn[1], ndims, 3, 1, 1)

    def _atrous_stack(self, spatial, c_in, c_mid, c_out, **block_kwargs):
        """Return three chained AtrousBlocks mapping c_in -> c_mid -> c_mid -> c_out."""
        ndims = self.dimension
        return nn.Sequential(
            AtrousBlock([c_in] + [spatial] * ndims, c_in, c_mid, ndims=ndims, **block_kwargs),
            AtrousBlock([c_mid] + [spatial] * ndims, c_mid, c_mid, ndims=ndims, **block_kwargs),
            AtrousBlock([c_mid] + [spatial] * ndims, c_mid, c_out, ndims=ndims, **block_kwargs),
        )

    def _make_ref_grid(self, sizes):
        """Return the integer index grid of shape [1, ndims, *sizes]."""
        sizes = list(sizes)
        axes = [torch.arange(end=imsz) for imsz in sizes]
        return torch.reshape(torch.stack(torch.meshgrid(axes), 0), [1, self.dimension] + sizes)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each displacement channel to ``[-(sz - minus), sz - minus] + plus`` (in units of ``sz``).

        Args:
            sample_coords0: Tensor whose dim 1 holds one channel per axis.
            max_sz: Per-axis scale, one entry per channel.
            plus: Offset added to both clamp bounds.
            minus: Margin subtracted from the scale to obtain the bound.

        Returns:
            Tensor of the same shape with every channel clamped.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` via ``F.grid_sample``.

        Relies on ``self.max_sz``, ``self.device`` and (by default)
        ``self.ref_grid`` / ``self.img_sz``, which forward() sets.

        Args:
            vol: Tensor to sample from.
            ddf: Displacement field, scaled by ``self.max_sz`` before use.
            ref: Reference index grid; defaults to ``self.ref_grid``.
            img_sz: Half-extent ``(size - 1) / 2`` per axis, broadcastable
                against the channels-last grid; defaults to ``self.img_sz``.
            padding_mode: Forwarded to ``F.grid_sample``.
        """
        ref = self.ref_grid if ref is None else ref
        # The previous default (`self.max_sz`, a Python list) could not be
        # divided into a tensor; the half-extent tensor is the usable default.
        img_sz = self.img_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'  # grid_sample uses the same name for 3-D inputs

        scale = torch.tensor(self.max_sz, dtype=torch.float32).reshape(
            [1, self.dimension] + [1] * self.dimension).to(self.device)
        # Channels-first -> channels-last, normalise to [-1, 1], then reverse the
        # coordinate order as grid_sample expects (last axis first).
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + self.dimension)) + [1])
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode=resample_mode, padding_mode=padding_mode,
                             align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict the composed displacement field and the bottleneck image embedding.

        Args:
            x: Input image batch; its device and spatial size drive the grids.
            y: Conditioning image batch (used when ``conditional_input``).
            t: Timestep indices looked up in the embedding table.
            text: Optional text feature; ``self.text`` is used when ``None``.
            rec_num: Number of recursive refinement passes (must be >= 1).
            ndims: Unused; kept for interface compatibility.

        Returns:
            ``(ddf, img_embd)``.  ``img_embd`` is the max-pooled ``img2txt``
            feature of the last recursion, or ``None`` when the model was
            built with ``conditional_input=False``.

        Raises:
            ValueError: If ``rec_num`` is smaller than 1.
        """
        if rec_num < 1:
            raise ValueError("rec_num must be >= 1, got %r" % (rec_num,))

        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
        if list(img_sz) != self.img_res:
            # Rebuild the reference grid for the new input size and remember
            # the size so the grid is not rebuilt on every call.
            self.ref_grid = self._make_ref_grid(img_sz)
            self.img_res = list(img_sz)
        self.ref_grid = self.ref_grid.to(self.device)

        img = x
        t = self.time_embed(t)
        # Stays None when the conditioning branch is disabled (previously an
        # UnboundLocalError at the return statement).
        img_embd = None

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                tgt_shape = tgt.shape
                # (N, C, *spatial) -> (prod(spatial), N, C) for nn.MultiheadAttention
                query = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(query, tgt, tgt)
                # (prod(spatial), N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

                if text is None:
                    text = self.text
                text = text.to(self.device)
                text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))
                # img2txt is evaluated once and shared by the embedding and the
                # text-fusion path.
                txt_feat = self.img2txt(out)
                img_embd = self.global_maxpool(txt_feat).view(n, -1)  # [N, text_feat_chn]
                out_txt = txt_feat + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i-1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out)/128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # The warped input feeds the next recursion.
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf, img_embd

    def _make_te(self, dim_in, dim_out):
        """Return the Linear -> ReLU -> Linear projection for a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
| 934 |
+
# class RecMutAttnNet(nn.Module):
|
| 935 |
+
# def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
|
| 936 |
+
# super(RecMutAttnNet, self).__init__()
|
| 937 |
+
|
| 938 |
+
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 939 |
+
# self.conditional_input = conditional_input
|
| 940 |
+
|
| 941 |
+
# self.dimension = ndims
|
| 942 |
+
# self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 943 |
+
# self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 944 |
+
|
| 945 |
+
# # Sinusoidal embedding
|
| 946 |
+
# self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 947 |
+
# self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 948 |
+
# self.time_embed.requires_grad_(False)
|
| 949 |
+
# self.hier_num = len(self.feat_channels) - 1
|
| 950 |
+
# self.down_layers = nn.ModuleList()
|
| 951 |
+
# self.up_layers = nn.ModuleList()
|
| 952 |
+
# self.ted_layers = nn.ModuleList()
|
| 953 |
+
# self.teu_layers = nn.ModuleList()
|
| 954 |
+
# self.block_down = nn.ModuleList()
|
| 955 |
+
# if self.conditional_input:
|
| 956 |
+
# self.block_down_cond = nn.ModuleList()
|
| 957 |
+
# self.fuse_conv0 = nn.ModuleList()
|
| 958 |
+
# self.fuse_conv1 = nn.ModuleList()
|
| 959 |
+
# self.block_up = nn.ModuleList()
|
| 960 |
+
|
| 961 |
+
# for i in range(1, self.hier_num + 1):
|
| 962 |
+
# j=-i
|
| 963 |
+
# self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 964 |
+
# self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 965 |
+
# self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 966 |
+
# self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 967 |
+
# self.block_down.append(nn.Sequential(
|
| 968 |
+
# AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 969 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 970 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 971 |
+
# ))
|
| 972 |
+
# if self.conditional_input:
|
| 973 |
+
# self.block_down_cond.append(nn.Sequential(
|
| 974 |
+
# AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 975 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 976 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 977 |
+
# ))
|
| 978 |
+
# self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 979 |
+
# self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 980 |
+
# if i==self.hier_num:
|
| 981 |
+
# k=j
|
| 982 |
+
# else:
|
| 983 |
+
# k=j-1
|
| 984 |
+
# self.block_up.append(nn.Sequential(
|
| 985 |
+
# AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 986 |
+
# AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 987 |
+
# AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 988 |
+
# ))
|
| 989 |
+
|
| 990 |
+
# # Bottleneck
|
| 991 |
+
# self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 992 |
+
# self.b_mid = nn.Sequential(
|
| 993 |
+
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 994 |
+
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 995 |
+
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 996 |
+
# )
|
| 997 |
+
|
| 998 |
+
# self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 999 |
+
|
| 1000 |
+
# def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 1001 |
+
# sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 1002 |
+
# return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 1003 |
+
# zip(sample_coords, max_sz)], 1)
|
| 1004 |
+
|
| 1005 |
+
# def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 1006 |
+
# ref = self.ref_grid if ref is None else ref
|
| 1007 |
+
# img_sz = self.max_sz if img_sz is None else img_sz
|
| 1008 |
+
# resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 1009 |
+
|
| 1010 |
+
# return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 1011 |
+
# np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 1012 |
+
# [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 1013 |
+
# align_corners=True)
|
| 1014 |
+
|
| 1015 |
+
# def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
|
| 1016 |
+
# self.device = x.device
|
| 1017 |
+
# img_sz = x.size()[2:]
|
| 1018 |
+
# n = x.size()[0]
|
| 1019 |
+
# self.max_sz = [img_sz[0]] * self.dimension
|
| 1020 |
+
# ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 1021 |
+
# self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 1022 |
+
# self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 1023 |
+
# [1, self.dimension]+list(img_sz)).to(self.device)
|
| 1024 |
+
# img = x
|
| 1025 |
+
# t = self.time_embed(t)
|
| 1026 |
+
|
| 1027 |
+
# for rec_id in range(rec_num):
|
| 1028 |
+
# if self.conditional_input:
|
| 1029 |
+
# tgt = y
|
| 1030 |
+
# enc_list = []
|
| 1031 |
+
# out = img
|
| 1032 |
+
# for i in range(self.hier_num):
|
| 1033 |
+
# out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 1034 |
+
# if self.conditional_input:
|
| 1035 |
+
# tgt = self.block_down_cond[i](tgt)
|
| 1036 |
+
# out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 1037 |
+
# tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 1038 |
+
# enc_list.append(out)
|
| 1039 |
+
# out = self.down_layers[i](out)
|
| 1040 |
+
# if self.conditional_input:
|
| 1041 |
+
# tgt = self.down_layers[i](tgt)
|
| 1042 |
+
|
| 1043 |
+
# out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 1044 |
+
# if self.conditional_input:
|
| 1045 |
+
# out = out + tgt
|
| 1046 |
+
|
| 1047 |
+
# for i in range(self.hier_num):
|
| 1048 |
+
# out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 1049 |
+
# out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 1050 |
+
|
| 1051 |
+
# out = self.conv_out(out)/128
|
| 1052 |
+
|
| 1053 |
+
# ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 1054 |
+
# if rec_id == 0:
|
| 1055 |
+
# ddf = ddf_one
|
| 1056 |
+
# else:
|
| 1057 |
+
# ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 1058 |
+
# img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 1059 |
+
|
| 1060 |
+
# return ddf
|
| 1061 |
+
|
| 1062 |
+
# def _make_te(self, dim_in, dim_out):
|
| 1063 |
+
# return nn.Sequential(
|
| 1064 |
+
# nn.Linear(dim_in, dim_out),
|
| 1065 |
+
# nn.ReLU(),
|
| 1066 |
+
# nn.Linear(dim_out, dim_out)
|
| 1067 |
+
# )
|
| 1068 |
+
# ==============================================
|
| 1069 |
+
# Layers
|
| 1070 |
+
# ==============================================
|
| 1071 |
+
|
| 1072 |
+
|
| 1073 |
+
def ddf_multiplier(dvf,mul_num=10,stn=None):
    """Compose the field ``dvf`` with itself ``mul_num`` times.

    Each step computes ``ddf = dvf + stn(ddf, dvf)``.

    Args:
        dvf: Field to be composed repeatedly.
        mul_num: Number of composition steps; ``0`` returns ``dvf`` unchanged.
        stn: Callable ``stn(moving, field)`` that resamples ``moving`` by
            ``field``.  When ``None`` a border-padded ``STN`` is built from
            ``dvf`` (dimensionality from its rank, size from its first spatial
            axis, i.e. isotropic input is assumed); pass an explicit ``stn``
            for anything else.

    Returns:
        The composed field.
    """
    if stn is None:
        # Previously `None` was called directly and raised TypeError; mirror
        # the default that composite() provides.
        stn = STN(ndims=dvf.dim() - 2, img_sz=dvf.shape[2], device=dvf.device, padding_mode="border")
    ddf = dvf
    for _ in range(mul_num):
        ddf = dvf + stn(ddf, dvf)
    return ddf
|
| 1078 |
+
|
| 1079 |
+
|
| 1080 |
+
def composite(ddfs,stn=None):
    """Compose a sequence of fields into a single one.

    Starting from ``ddfs[0]``, every following field ``f`` updates the result
    as ``f + stn(result, f)``.

    Args:
        ddfs: Non-empty sequence of fields.
        stn: Callable ``stn(moving, field)`` that resamples ``moving`` by
            ``field``.  When ``None`` a border-padded ``STN`` is built from
            ``ddfs[0]`` (dimensionality from its rank, size from its first
            spatial axis, i.e. isotropic input is assumed).

    Returns:
        The composed field.
    """
    if stn is None:
        first = ddfs[0]
        # Pass the dimensionality and size explicitly: an STN constructed
        # without img_sz cannot build its sampling grid.
        stn = STN(ndims=first.dim() - 2, img_sz=first.shape[2], device=first.device, padding_mode="border")
    comp_ddf = ddfs[0]
    for field in ddfs[1:]:
        comp_ddf = field + stn(comp_ddf, field)
    return comp_ddf
|
| 1087 |
+
|
| 1088 |
+
class STN(nn.Module):
    """Spatial transformer that warps a tensor with a displacement field.

    Displacements are multiplied by ``max_sz`` and added to an integer
    reference grid, normalised to ``[-1, 1]`` and handed to ``F.grid_sample``
    (``align_corners=True``).

    Args:
        ndims: Number of spatial dimensions.
        img_sz: Edge length used for every spatial axis.  When ``None`` the
            grid is built on the first ``forward`` call from the input's
            spatial shape (and ``ndims`` is taken from the input's rank).
        max_sz: Accepted for backward compatibility; not used.
        device: Device for the internal tensors; defaults to the device of the
            first input seen by ``forward``.
        padding_mode: ``padding_mode`` used by ``forward``.
        resample_mode: ``mode`` for ``F.grid_sample``; ``'bilinear'`` if ``None``.
    """

    def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None):
        super(STN, self).__init__()
        self.ndims = ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        # Filled by _init_grid(); they stay None until a size is known so that
        # the lazy path in forward() can actually trigger.
        self.img_sz = None
        self.max_sz = None
        self.ref_grid = None
        if img_sz is not None:
            self._init_grid([img_sz] * ndims)

    def _init_grid(self, sizes):
        """Set ``img_sz``, ``max_sz`` and ``ref_grid`` for the given per-axis sizes."""
        self.img_sz = [int(s) for s in sizes]
        self.ndims = len(self.img_sz)
        # One common displacement scale (the first axis length) for all axes;
        # for the constructor path this equals [img_sz] * ndims.
        scale = [float(self.img_sz[0])] * self.ndims
        self.max_sz = torch.tensor(scale, dtype=torch.float32).reshape(
            [1, self.ndims] + [1] * self.ndims).to(self.device)
        axes = [torch.arange(end=s) for s in self.img_sz]
        self.ref_grid = torch.reshape(torch.stack(torch.meshgrid(axes), 0),
                                      [1, self.ndims] + self.img_sz).to(self.device)

    def max_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp every displacement channel to ``[-(sz - minus), sz - minus] + plus`` (in units of ``sz``).

        Each channel is paired with its own entry of ``max_sz``; previously the
        zip ran over the leading size-1 axis and only channel 0 was processed.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        scales = self.max_sz.flatten().tolist()
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, scales)], 1)

    def boundary_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp the absolute sample position (displacement + grid) per channel.

        Returns the clamped displacement, again in units of ``max_sz``.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        scales = self.max_sz.flatten().tolist()
        refs = torch.split(self.ref_grid, split_size_or_sections=1, dim=1)
        # NOTE(review): the clamp bounds are kept as written
        # ([minus - sz, sz - minus]); confirm they are the intended limits for
        # an absolute position.
        return torch.cat([(torch.clamp(x * sz+ref, min=minus - 1 * sz + plus, max=1 * sz - minus + plus)-ref) / sz for x, sz,ref in
                          zip(sample_coords, scales, refs)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None,padding_mode = "zeros"):
        """Sample ``vol`` at the grid positions displaced by ``ddf``.

        Args:
            vol: Tensor to sample from; moved to ``ddf``'s device.
            ddf: Displacement field in units of ``max_sz``.
            ref: Reference grid; defaults to ``self.ref_grid``.
            img_sz: Per-axis sizes used for normalisation; defaults to
                ``self.img_sz``.
            padding_mode: Forwarded to ``F.grid_sample``.
        """
        device = ddf.device

        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            # The old default (self.max_sz) had a channels-first shape and was
            # not a half-extent; fall back to the stored sizes instead.
            img_sz = self.img_sz
        half_extent = torch.reshape(torch.tensor([(s - 1) / 2. for s in img_sz], device=device), [1]+[1]*self.ndims+[self.ndims])
        resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode

        # Channels-first -> channels-last, normalise to [-1, 1], then reverse
        # the coordinate order as grid_sample expects.
        coords = (ddf * self.max_sz.to(device) + ref.to(device)).permute(
            [0] + list(range(2, 2 + self.ndims)) + [1])
        grid = torch.flip(coords / half_extent - 1, dims=[-1])
        return F.grid_sample(vol.to(device), grid, mode=resample_mode,
                             padding_mode=padding_mode,
                             align_corners=True)

    def forward(self,x,ddf):
        """Warp ``x`` by ``ddf`` using the configured padding mode."""
        self.device = x.device if self.device is None else self.device
        if self.img_sz is None:
            # Lazy initialisation from the first input.
            self._init_grid(x.size()[2:])
        resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
        return resampled_x
|
| 1150 |
+
|
| 1151 |
+
if __name__ == '__main__':
    # Smoke test: push one random volume through RecMutAttnNet and report the
    # output shape and the parameter counts.
    ndims = 3
    res = 128
    spatial_shape = [res] * ndims
    x = torch.rand([1, 1] + spatial_shape)
    t = torch.randint(0, 1000, (1,))
    text = torch.rand([1, 1024] + [1] * ndims)
    model = RecMutAttnNet(
        n_steps=1000,
        time_emb_dim=100,
        ndims=ndims,
        num_input_chn=1,
        res=res,
        conditional_input=True,
    )
    # The input doubles as the conditioning image.
    y = model(x, x, t, text=text)
    print("Ouput shape", y.shape)

    # Count all parameters and, separately, those that receive gradients.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")
|
Diffusion/utils_diff.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision
|
| 4 |
+
from torch import nn, optim
|
| 5 |
+
from torch.autograd.variable import Variable
|
| 6 |
+
from torchvision import transforms, datasets
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import scipy.ndimage as spimg
|
| 10 |
+
import pyquaternion as quater
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np
|
| 13 |
+
import math
|
| 14 |
+
from typing import Optional, Tuple, List
|
| 15 |
+
# from data_loader.acdc_dataloader import acdc_gan
|
| 16 |
+
|
| 17 |
+
# from Adaptive_Motion_Generator.Dataloader.Archive.acdc_dataloader import *
|
| 18 |
+
|
| 19 |
+
def get_barcode(index=None, header=('Patient', 'Slice', 'AugImg', 'NoiseStep'), digit=(4, 6, 4, 4), split='_'):
    """Build an identifier such as ``Patient0001_Slice000001_AugImg0001_NoiseStep0070``.

    Every header is followed by its zero-padded index and the pieces are
    joined with ``split``.  When fewer than three indices are given the sample
    is treated as an original (non-augmented) one: the third and fourth
    headers become ``ORG`` and ``NA`` and carry no number.

    Args:
        index: One value per header (two values suffice for the short form).
            The sequence is never modified.
        header: Field labels; at least four entries are needed for the short
            form.
        digit: Zero-padding width for each field.
        split: Separator placed between fields.

    Returns:
        The assembled barcode string.

    Raises:
        IndexError: If ``index`` or ``digit`` has fewer entries than needed.
    """
    # Work on private copies so neither the caller's lists nor the defaults
    # are ever mutated (the old `index += ['', '']` changed the caller's list).
    values = [] if index is None else list(index)
    labels = list(header)
    widths = list(digit)

    if len(values) < 3:
        labels[2] = 'ORG'
        labels[3] = 'NA'
        widths[2] = 0
        widths[3] = 0
        values += ['', '']

    if len(values) < len(labels):
        raise IndexError(
            "get_barcode needs %d index values, got %d" % (len(labels), len(values))
        )

    # str.join avoids the trailing-separator trim, which was only correct for
    # single-character separators.
    return split.join(
        label + str(values[pos]).zfill(widths[pos]) for pos, label in enumerate(labels)
    )
|
| 34 |
+
|
| 35 |
+
class RandomResizedCrop3D(nn.Module):
    """Cut a random sub-volume out of a batched 3-D tensor and resize it.

    A single scale factor is drawn uniformly from ``scale`` on every call and
    applied to all three spatial axes; the crop origin is drawn uniformly
    among the positions that keep the crop inside the volume.

    Args:
        size (tuple of int): Output size (D, H, W) after resizing.
        scale (tuple of float): Lower and upper bound of the per-axis scale
            factor of the crop.
        ratio (tuple of float): Stored and reported by ``repr`` only; the
            current sampling does not use it.
        interpolation (str): Mode passed to ``F.interpolate``
            ('trilinear' or 'nearest').
    """

    def __init__(
            self,
            size: Tuple[int, int, int],
            scale=(0.6, 1.0),
            ratio=(0.5, 1.5),
            interpolation='trilinear'
    ):
        super().__init__()
        self.interpolation = interpolation
        self.ratio = ratio
        self.scale = scale
        self.size = size

    @staticmethod
    def get_params(img: torch.Tensor, rand_scale: float, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int, int, int]:
        """Draw the crop window for ``img``.

        Args:
            img (Tensor): Input whose dimensions from index 2 on are spatial.
            rand_scale (float): Factor applied to every spatial extent.
            scale (list): Unused; kept for interface compatibility.
            ratio (list): Unused; kept for interface compatibility.

        Returns:
            list: Start index per spatial axis followed by the crop extent per
            spatial axis, i.e. ``[i, j, k, d, h, w]`` for a 3-D volume.
        """
        full_extent = np.array(list(img.size())[2:])
        # Truncate towards zero, exactly like an int() cast per axis.
        crop_extent = (full_extent * rand_scale).astype(np.int32)
        highest_start = full_extent - crop_extent + 1
        origin = np.random.randint(0, highest_start, size=(full_extent.size,))
        return origin.tolist() + crop_extent.tolist()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Crop ``img`` at a random window and resize the crop to ``self.size``.

        Args:
            img (Tensor): Batched 3-D volume (two leading non-spatial dims).

        Returns:
            Tensor: The cropped and resized volume.
        """
        rand_scale = np.random.uniform(self.scale[0], self.scale[1])
        i, j, k, d, h, w = self.get_params(img, rand_scale, self.scale, self.ratio)
        window = (
            slice(None),
            slice(None),
            slice(i, i + d),
            slice(j, j + h),
            slice(k, k + w),
        )
        patch = img[window]
        # align_corners is only meaningful for the linear mode.
        if self.interpolation == 'trilinear':
            corner_flag = False
        else:
            corner_flag = None
        return F.interpolate(patch, size=self.size, mode=self.interpolation,
                             align_corners=corner_flag)

    def __repr__(self) -> str:
        return "{}(size={}, scale={}, ratio={}, interpolation={})".format(
            self.__class__.__name__, self.size, self.scale, self.ratio, self.interpolation
        )
|
| 116 |
+
|
| 117 |
+
def random_permute(X, select_dims=(-1, -2), include_flip=True):
    """Apply one shared random axis permutation (and optional flips) to tensors.

    Args:
        X: Sequence of torch tensors. ``X[0].ndim`` fixes the number of axes
            and every tensor receives the same flips and the same permutation.
        select_dims: Axes that are shuffled among themselves; all other axes
            keep their position. The default is an immutable tuple so it can
            never be altered between calls.
        include_flip: When true, each selected axis is additionally reversed
            with probability 0.5 (drawn with ``random.choice``).

    Returns:
        list: The transformed tensors, in the order of ``X``.
    """
    axes = list(range(X[0].ndim))
    shuffled = [axes[d] for d in select_dims]
    random.shuffle(shuffled)
    for dim, source_axis in zip(select_dims, shuffled):
        axes[dim] = source_axis
        # The flip decision is drawn once per selected axis and is applied to
        # the un-permuted tensors, before the final permute below.
        if include_flip and random.choice([True, False]):
            X = [torch.flip(x, [dim]) for x in X]
    return [x.permute(axes) for x in X]
|
| 128 |
+
|
| 129 |
+
# def thresh_img(img,thresh = None,EPS = 10**-7):
|
| 130 |
+
# threshold0 = np.random.uniform(thresh[0], thresh[1])
|
| 131 |
+
# threshold1 = np.random.uniform(thresh[0], thresh[1])
|
| 132 |
+
# scale =
|
| 133 |
+
# if threshold is not None:
|
| 134 |
+
# # img=img-threshold
|
| 135 |
+
# # img=np.where(img>=0,img,0)
|
| 136 |
+
# # img = np.maximum(img-threshold,0)
|
| 137 |
+
# img = torch.maximum(img - threshold,torch.tensor(0.))
|
| 138 |
+
# # return (img - img.min()) / (img.max() - img.min() + EPS)
|
| 139 |
+
# return img
|
| 140 |
+
|
| 141 |
+
def get_transformer(degrees=180, translate=0.125, ndims=2, prob=0.8, fill=0., img_sz=None):
    """Build the random augmentation pipeline for 2D images or larger-rank volumes.

    Args:
        degrees: Rotation range forwarded to ``RandomAffine``.
        translate: Translation fraction, repeated ``ndims`` times.
        ndims: Number of entries in the ``translate`` list.
        prob: Probability of applying the random affine (2D pipeline only).
        fill: Fill value for areas exposed by the affine transform.
        img_sz: Target spatial size. ``None`` or a length-2 size selects the
            2D pipeline; any other length selects the 3D crop pipeline.

    Returns:
        torchvision.transforms.Compose: The assembled pipeline.
    """
    transforms = torchvision.transforms

    # ``is None`` instead of ``== None``: the equality form is evaluated
    # element-wise for array-like sizes and then cannot be used as a condition.
    if img_sz is None or len(img_sz) == 2:
        affine = transforms.RandomAffine(
            degrees=degrees,
            translate=[translate] * ndims,
            fill=fill,
            interpolation=transforms.InterpolationMode.BILINEAR,
        )
        # NOTE: a 2D RandomResizedCrop step is intentionally not part of this
        # pipeline (it was disabled in the original code).
        return transforms.Compose([
            transforms.RandomApply([affine], prob),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAutocontrast(p=0.5),
        ])

    # Reaching this point means img_sz is given and is not 2D, so the crop is
    # always the 3D variant and its application probability is always 0.8.
    crop_probability = 0.8
    return transforms.Compose([
        transforms.RandomApply([RandomResizedCrop3D(size=img_sz)], crop_probability),
    ])
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def get_random_affine_transformer(degrees=180,translate=0.125,ndims=2):
    """Return a bilinear ``RandomAffine`` whose translate list repeats ``translate`` ``ndims`` times."""
    shift_fractions = [translate] * ndims
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    return torchvision.transforms.RandomAffine(degrees=degrees, translate=shift_fractions,
                                               interpolation=bilinear)
|
| 168 |
+
|
| 169 |
+
def channel_merge_acdc(img):
    """Sum all channels of a (C, H, W) input into a single (H, W) map.

    The accumulator starts as a NumPy zero array of shape ``(H, W)`` and each
    channel ``img[c]`` is added to it in turn.
    (The original note describes the input as a torch tensor.)
    """
    n_channels = img.shape[0]
    merged = np.zeros((img.shape[1], img.shape[2]))
    for channel_idx in range(n_channels):
        merged = merged + img[channel_idx]
    return merged
|
| 177 |
+
|
| 178 |
+
def img_crop(img, crop_rate=2, img_sz=(256, 256)):
    """Randomly crop the trailing spatial axes of ``img``.

    For every axis of nominal length ``s`` in ``img_sz`` a window of
    ``s - s // crop_rate`` elements is kept. Its start is drawn as
    ``randint(0, s) // crop_rate``, so the window always lies inside
    ``[0, s)``.

    Args:
        img: Array or tensor whose last ``len(img_sz)`` axes are cropped.
        crop_rate: Integer divisor controlling how much is removed per axis.
        img_sz: Nominal spatial size, one entry per cropped axis. Any number of
            axes is supported (the original handled only two or three).

    Returns:
        The cropped view of ``img``; leading axes are untouched.
    """
    windows = []
    for size in img_sz:
        size = int(size)  # randint expects integer bounds
        start = np.random.randint(0, size) // crop_rate
        # Whatever part of the removed budget is not cut at the front is cut
        # at the back.
        trailing = size // crop_rate - start
        windows.append(slice(start, size - trailing))
    return img[(Ellipsis, *windows)]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def boundary_limit(sample_coords0, max_sz, plus=0., minus=1.):
    """Scale each coordinate channel by its size and clamp it symmetrically.

    Channel ``c`` along dim 1 is multiplied by ``max_sz[c]`` and clamped to
    ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]``. The channels are
    then concatenated back along dim 1.
    """
    channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
    limited = []
    for coords, extent in zip(channels, max_sz):
        lower = minus - 1 * extent + plus
        upper = 1 * extent - minus + plus
        limited.append(torch.clamp(coords * extent, min=lower, max=upper))
    return torch.cat(limited, 1)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def resample(vol, ddf, ref=None, img_sz=None, max_sz=(128, 128), ndims=2):
    """Warp ``vol`` with a dense displacement field using ``F.grid_sample``.

    Args:
        vol: Tensor of shape ``(B, C, *spatial)`` with two or three spatial axes.
        ddf: Displacement field, channel-last, broadcastable to
            ``(B, *spatial, n_spatial)``. It is added to the voxel index grid,
            so displacements are expressed in voxels, ordered like the spatial
            axes of ``vol``.
        ref, img_sz, max_sz, ndims: Unused; kept for call compatibility. The
            spatial size and rank are always taken from ``vol``.

    Returns:
        Tensor: ``vol`` sampled at ``index + ddf`` with bilinear/trilinear
        interpolation, zero padding and ``align_corners=True``.

    Raises:
        ValueError: If ``vol`` does not have two or three spatial axes.
    """
    spatial = tuple(vol.shape[2:])
    rank = len(spatial)
    if rank not in (2, 3):
        raise ValueError(
            f"resample expects 2 or 3 spatial dimensions, got {rank} (vol shape {tuple(vol.shape)})")

    device = vol.device
    # Half extents map voxel indices to the [-1, 1] range used by grid_sample.
    half_extent = torch.tensor([(s - 1) / 2. for s in spatial], device=device)
    half_extent = half_extent.reshape([1] * (rank + 1) + [rank])

    # Voxel index grid, channel-last: ref_grid[0, i, j(, k)] == (i, j(, k)).
    index_axes = [torch.arange(s, device=device) for s in spatial]
    ref_grid = torch.stack(torch.meshgrid(*index_axes, indexing='ij'), dim=-1).unsqueeze(0)

    normalised = (ddf + ref_grid) / half_extent - 1
    # grid_sample wants the last axis ordered (x, y[, z]), i.e. the reverse of
    # the tensor's spatial axis order.
    grid = torch.flip(normalised, dims=[-1]).type(torch.float32).to(device)
    return F.grid_sample(vol, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
|
| 228 |
+
|
| 229 |
+
def random_resample(vol, deform_scale=32.):
    """Warp ``vol`` with a random field obtained from ``random_ddf``.

    Args:
        vol: Tensor of shape ``(B, C, *spatial)``.
        deform_scale: Forwarded to ``random_ddf`` as ``range_gauss``.

    Returns:
        Tensor: ``vol`` resampled through ``resample`` with the random field.
    """
    # 2D inputs are given a temporary third axis because random_ddf builds
    # three-axis fields; the middle slice of that axis is used afterwards.
    pseudo_depth = 16
    spatial = list(vol.shape[2:])
    ndims = len(spatial)
    is_2d = ndims == 2
    field_size = spatial + [pseudo_depth] if is_2d else spatial

    # random_ddf returns three values; only the third one is used here.
    _, _, ddf = random_ddf(vol.shape[0], field_size, ndims=ndims, range_gauss=deform_scale)
    # A plain tensor suffices; the torch.autograd.Variable wrapper is a no-op.
    ddf = torch.tensor(ddf, dtype=torch.float32).to(vol.device)

    if is_2d:
        return resample(vol, ddf[..., pseudo_depth // 2, :ndims])
    return resample(vol, ddf[..., :ndims])
|
| 243 |
+
|
| 244 |
+
def get_random_deformed_mask(msk_shape, deform_scale=32.,apply_possibility=0.75):
    """Create an all-ones ``(1, 1, *msk_shape)`` mask and randomly warp it.

    With probability ``apply_possibility`` the mask is passed through
    ``random_resample``; otherwise the untouched mask of ones is returned.
    """
    full_shape = [1, 1] + list(msk_shape)
    msk = torch.ones(full_shape, dtype=torch.float32)
    apply_warp = random.uniform(0, 1) < apply_possibility
    if not apply_warp:
        return msk
    return random_resample(msk, deform_scale=deform_scale)
|
| 250 |
+
|
| 251 |
+
# grid option
|
| 252 |
+
def get_tranf_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]],transl=[[0,0,0]]):
    """Append the translations to the flattened rotation matrices along the last axis."""
    flat_rotations = get_rot_mat(grid_size, vec=vec, ang=ang)
    return np.concatenate([flat_rotations, transl], -1)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def get_rot_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]],ndims=3):
    """Return one flattened rotation matrix per batch entry.

    The batch size is taken from the first axis of ``ang``; the result has
    shape ``(batch, ndims * ndims)``. ``grid_size`` is not used.
    """
    axis_vectors = np.array(vec)
    angles = np.array(ang)
    n_batch = angles.shape[0]
    rotations = vecang2rotmats(axis_vectors, angles)
    return np.reshape(rotations, [n_batch, ndims * ndims])
|
| 261 |
+
|
| 262 |
+
def random_mat(batch_sz, img_sz, num_class=2,pn_spline=20, pn_gauss=10, range_spline=2., range_gauss=48, spread_range=[5., 24.],
               transl_range=32., rot_range=np.pi / 2):
    """Draw random rigid parameters per batch item and class.

    Every quantity (rotation axis, angle, translation) is a value shared by
    all classes of a batch item plus a smaller per-class perturbation. The
    flattened rotation matrix (9 values) and the translation (3 values) are
    concatenated and reshaped to ``(batch_sz, num_class, 4, 3)``.

    ``pn_spline``, ``pn_gauss``, ``range_spline``, ``range_gauss`` and
    ``spread_range`` are accepted but not used.
    """
    scale = 4
    ndims = 3
    n_total = batch_sz * num_class

    def _shared_plus_jitter(limit, jitter_limit, tail):
        # The shared draw happens before the per-class draw (RNG order matters).
        shared = np.random.uniform(-limit, limit, [batch_sz, 1] + tail)
        jitter = np.random.uniform(-jitter_limit, jitter_limit, [batch_sz, num_class] + tail)
        return np.reshape(shared + jitter, [n_total] + tail)

    vec = _shared_plus_jitter(1., .1, [ndims])
    ang = _shared_plus_jitter(rot_range, rot_range / scale, [])
    transl = _shared_plus_jitter(transl_range, transl_range / scale, [ndims])

    flat_rotations = get_rot_mat(img_sz, vec=vec, ang=ang)
    params = np.concatenate([flat_rotations, transl], -1)
    return np.reshape(params, [batch_sz, num_class, 4, 3])
|
| 270 |
+
|
| 271 |
+
# return np.reshape(get_tranf_mat(img_sz, vec=np.random.uniform(-1., 1., [batch_sz*num_class, 3]), ang=np.random.uniform(-rot_range, rot_range, [batch_sz*num_class]),transl=np.random.uniform(-transl_range, transl_range, [batch_sz*num_class,3])),[batch_sz,num_class,4,3])
|
| 272 |
+
|
| 273 |
+
def random_ddf(batch_sz, img_sz, pn_spline=20, pn_gauss=10, range_spline=1., range_gauss=16., spread_range=[16., 64.],
               transl_range=0., rot_range=np.pi / 1,ndims=3):
    """Generate a random displacement field: rotation plus optional smooth part.

    A rotation field is drawn per batch item. When ``range_gauss > 0`` a single
    Gaussian-smoothed random field plus a random translation is generated
    once, tiled over the batch and added to the rotation field.

    Returns:
        tuple: ``(ddf, inside, rot_df)`` where ``ddf`` is the combined field
        after the sampling positions were limited to the volume plus a
        5-voxel padding, ``inside`` (trailing axis of size 1) marks positions
        that fall inside the volume, and ``rot_df`` is the rotation-only field.

    ``pn_spline`` and ``range_spline`` are accepted but not used.
    """
    def _clip_with_padding(coords, size, padding=5):
        axis_ids = range(len(size))
        clipped = [np.maximum(np.minimum(coords[..., a], size[a] - 1 + padding), 0 - padding)
                   for a in axis_ids]
        inside = [(coords[..., a] < size[a]) * (coords[..., a] >= 0) for a in axis_ids]
        return np.stack(clipped, axis=-1), np.prod(inside, axis=0)

    rand_ang = np.random.uniform(-rot_range, rot_range, [batch_sz])
    if ndims == 3:
        rotation_axis = np.random.uniform(-1., 1., [batch_sz, 3])
    else:
        # Non-3D case: rotate about the fixed axis (0, 0, 1); the field itself
        # is still built with three components.
        rotation_axis = np.concatenate([np.zeros([batch_sz, 2]), np.ones([batch_sz, 1])], -1)
        ndims = 3
    rot_df = get_rot_ddf(img_sz, vec=rotation_axis, ang=rand_ang)

    if range_gauss > 0:
        smooth_field = generate_random_gaussian_ddf(img_sz, pn_gauss, range_sz=range_gauss,
                                                    spread_std=spread_range)
        shift = np.random.uniform(-transl_range, transl_range, [ndims])
        # One smooth field is shared by (tiled over) every batch item.
        ddf0 = np.tile([smooth_field + shift], [batch_sz, 1, 1, 1, 1]) + rot_df
    else:
        ddf0 = rot_df

    ref = get_reference_grid(img_sz)
    cf1, ind = _clip_with_padding(ddf0 + ref, img_sz)
    return cf1 - ref, np.expand_dims(ind, -1), rot_df
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def generate_random_gaussian_ddf(img_sz, pn=30, range_sz=5, spread_std=[0.1, 1.]):
    """Scatter ``pn`` random displacement vectors in a volume and blur them.

    Random integer positions are drawn per axis between ``range_sz / 2`` and
    ``img_sz[axis] - range_sz / 2``; each receives a random 3-vector in
    ``[-range_sz, range_sz)``. The sparse ``(X, Y, Z, 3)`` array is then
    smoothed with a Gaussian whose sigma is drawn from ``spread_std``.

    Returns:
        numpy.ndarray: Smoothed field of shape ``(img_sz[0], img_sz[1], img_sz[2], 3)``.
    """
    margin = range_sz / 2.
    seed_indices = []
    for axis in range(3):
        positions = np.random.uniform(margin, img_sz[axis] - margin, [1, pn])
        seed_indices.append(np.floor(positions).astype('int'))
    x, y, z = seed_indices

    odf = np.random.uniform(-range_sz, range_sz, [pn, 3])
    vol = np.zeros([img_sz[0], img_sz[1], img_sz[2], 3])
    vol[x, y, z] = odf

    sigma = np.random.uniform(spread_std[0], spread_std[1])
    # NOTE(review): a scalar sigma makes gaussian_filter smooth along all four
    # axes, including the last (vector-component) axis - confirm this is intended.
    return spimg.gaussian_filter(vol, sigma)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def get_rot_ddf(grid_size, vec=[[0., 0., 1.]], ang=[[0.]]):
    """Return the displacement field of a rotation, one per batch entry.

    The reference grid is requested with ``bias_scale=1.``, its points are
    multiplied by each batch entry's rotation matrix, and the unrotated grid
    is subtracted. ``grid_size`` must be a list because it is concatenated
    with other lists to form the output shape.

    Returns:
        numpy.ndarray: Array of shape ``(batch, *grid_size, 3)``.
    """
    axis_vectors = np.array(vec)
    angles = np.array(ang)
    n_batch = angles.shape[0]

    ref_grids = get_reference_grid(grid_size, bias_scale=1.)
    points = np.reshape(np.tile(ref_grids, [n_batch, 1, 1, 1, 1]), [n_batch, -1, 3])
    rotated = np.matmul(points, vecang2rotmats(axis_vectors, angles))
    return np.reshape(rotated, [n_batch] + grid_size + [3]) - ref_grids
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def get_reference_grid(grid_size, bias_scale=0.):
    """Return the voxel index grid of a three-axis volume as floats.

    The result has shape ``(g0, g1, g2, 3)`` and entry ``[i, j, k]`` equals
    ``(i, j, k) - bias_scale * (grid_size - 1) / 2``; with ``bias_scale=1.``
    the grid is therefore centred on the middle of the volume.
    """
    index_lists = [list(range(grid_size[axis])) for axis in (0, 1, 2)]
    grid = np.stack(np.meshgrid(*index_lists, indexing='ij'), axis=-1).astype('float')
    offset = bias_scale * (np.array(grid_size) - 1) / 2.
    return grid - offset
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def resample_linear(inputs, ddf=None, sample_coords=None,random_boundary=True):
    """Trilinearly resample a NumPy volume at displaced voxel coordinates.

    Args:
        inputs: Array indexed as ``inputs[batch, x, y, z, ...]``. When
            ``random_boundary`` is true its six boundary faces are modified
            in place.
        ddf: Displacement added to the reference grid; only used when
            ``sample_coords`` is ``None``.
        sample_coords: Absolute sampling coordinates with a trailing axis of
            size 3. Takes precedence over ``ddf``.
        random_boundary: Blend the boundary faces towards the global minimum
            of ``inputs`` with a random factor before sampling.

    Returns:
        numpy.ndarray: The interpolated samples, one per sampling coordinate.
    """
    if random_boundary:
        random_factor = np.random.uniform(0., 1.)
        min_val = np.min(inputs)
        # Faces are processed in the order axis 1, 2, 3 and low edge before
        # high edge; the order matters because faces overlap on edges.
        for axis in (1, 2, 3):
            for edge in (0, -1):
                face = [slice(None)] * 4
                face[axis] = edge
                face = tuple(face)
                inputs[face] = min_val * random_factor + (1 - random_factor) * inputs[face]

    input_size = inputs.shape[1:4]
    if sample_coords is None:
        sample_coords = get_reference_grid(input_size) + ddf

    # One coordinate array per spatial axis, plus its integer floor.
    xy = [sample_coords[..., axis] for axis in range(sample_coords.shape[-1])]
    index_voxel_coords = [np.floor(x) for x in xy]

    def _clip_index(index, size, plus=0):
        # Lower corners live in [0, size - 2], upper corners in [1, size - 1].
        return np.maximum(np.minimum(index, size - 2 + plus), 0 + plus)

    def _clip_position(position, size, plus=0.):
        return np.maximum(np.minimum(position, size - 1 + plus), 0 + plus)

    xy = [_clip_position(x.astype('float32'), input_size[axis]) for axis, x in enumerate(xy)]
    lower = [_clip_index(x.astype('int32'), input_size[axis])
             for axis, x in enumerate(index_voxel_coords)]
    upper = [_clip_index((x + 1).astype('int32'), input_size[axis], 1)
             for axis, x in enumerate(index_voxel_coords)]

    # weight: distance from the lower corner; weight_c: distance to the upper one.
    weight = [np.expand_dims(x - lo.astype('float32'), -1) for x, lo in zip(xy, lower)]
    weight_c = [np.expand_dims(up.astype('float32') - x, -1) for x, up in zip(xy, upper)]

    sz = list(lower[0].shape)
    batch_coords = np.tile(np.reshape(np.arange(sz[0]), [sz[0]] + [1] * (len(sz) - 1)),
                           [1] + sz[1:])

    # Gather the eight cell corners, ordered like 3-bit binary codes (x is the
    # most significant bit, z the least significant one).
    corners = (lower, upper)
    samples = [inputs[batch_coords, corners[bx][0], corners[by][1], corners[bz][2], ...]
               for bx in (0, 1) for by in (0, 1) for bz in (0, 1)]

    # Collapse one axis at a time (x, then y, then z): at every step the first
    # half of the list holds the lower-corner samples of the current axis.
    level = samples
    for w, w_c in zip(weight, weight_c):
        half = len(level) // 2
        level = [level[i] * w_c + level[i + half] * w for i in range(half)]
    return level[0]
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def vecang2rotmats(vec, ang):
    """Stack one 3x3 rotation matrix per (axis vector, angle) pair along axis 0."""
    matrices = []
    for idx in range(len(vec)):
        rotation = vecang2rotmat(vec[idx, ...], ang[idx, ...])
        matrices.append(np.reshape(rotation, [3, 3]))
    return np.stack(matrices, 0)
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def vecang2rotmat(vec, ang):
    """Return the rotation matrix of the quaternion built from an axis and an angle."""
    # NOTE(review): `quater` is presumably the pyquaternion module - confirm
    # against the file's imports.
    return quater.Quaternion(axis=vec, angle=ang).rotation_matrix
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def images_to_vectors(images):
    """Flatten each batch item to a 16384-element row and move it to the module-level ``device``."""
    batch = images.size(0)
    flat = images.view(batch, 16384)
    return flat.to(device)
|
| 404 |
+
|
| 405 |
+
def vectors_to_images(vectors):
    """Reshape each row to a ``(1, 128, 128)`` image and move it to the module-level ``device``."""
    batch = vectors.size(0)
    images = vectors.view(batch, 1, 128, 128)
    return images.to(device)
|
| 407 |
+
|
| 408 |
+
def noise(size):
    """Return a ``(size, 100)`` standard-normal tensor on the module-level ``device``."""
    samples = torch.randn(size, 100)
    return Variable(samples).to(device)
|
| 411 |
+
|
| 412 |
+
def ones_target(size):
    """Return a ``(size, 1)`` tensor of ones on the module-level ``device``."""
    targets = torch.ones(size, 1)
    return Variable(targets).to(device)
|
| 415 |
+
|
| 416 |
+
def zeros_target(size):
    """Return a ``(size, 1)`` tensor of zeros on the module-level ``device``."""
    targets = torch.zeros(size, 1)
    return Variable(targets).to(device)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def eval_detJ_lab(disp=None, vol1=None, vol2=None, thresh=0.5):
    """Count locations where the deformation's Jacobian determinant is negative.

    Args:
        disp: NumPy displacement field of shape ``(B, ndims, *spatial)``.
        vol1: Optional volume used to restrict the count. It is thresholded,
            reduced to locations whose Laplacian (of the boolean mask) is
            below 0.1, and subsampled with stride 2 on its last three axes.
            NOTE(review): the stride-2 subsampling suggests ``vol1`` is at
            twice the resolution of ``disp`` - confirm with the callers.
        vol2: Unused.
        thresh: Threshold applied to ``vol1``; ``None`` disables the mask.

    Returns:
        The number of (masked) locations with ``det(I + grad(disp)) < 0``.
    """
    ndims = disp.ndim - 2

    # Identity checks: ``vol1 == None`` is evaluated element-wise for arrays
    # and made the original condition fail for any array-valued ``vol1``.
    if vol1 is None or thresh is None:
        label = 1
    else:
        label = vol1 > thresh
        # The Laplacian is taken on the boolean mask itself (not a float copy).
        label = label * (spimg.laplace(label) < 0.1)
        rescale_factor = 2
        label = label[..., ::rescale_factor, ::rescale_factor, ::rescale_factor]

    # Channel-last layout: (B, *spatial, ndims).
    disp = np.transpose(disp, [0, *range(2, ndims + 2), 1])
    # Jacob[..., c, a] = derivative of displacement component c along spatial axis a.
    Jacob = np.stack(np.gradient(disp, axis=tuple(range(1, ndims + 1))), -1)
    for ii in range(ndims):
        # Add the identity to turn the displacement gradient into the
        # Jacobian of the full transform.
        Jacob[..., ii, ii] = Jacob[..., ii, ii] + 1
    return np.sum((np.linalg.det(Jacob) < 0) * label)
|
| 442 |
+
|
| 443 |
+
def eval_def_mag(disp=None,vol1=None,vol2=None,thresh=0.5):
    """Summarise the magnitude of a displacement field.

    The magnitude is the Euclidean norm over axis 1 of ``disp``.
    ``vol1``, ``vol2`` and ``thresh`` are accepted but not used.

    Returns:
        list: ``[avg_mag, max_mag]`` - the mean magnitude over everything, and
        the per-sample maximum magnitude averaged over the batch.
    """
    squared = np.square(disp)
    mag = np.sqrt(np.sum(squared, axis=1))
    n_batch = mag.shape[0]
    per_sample_peak = np.max(np.reshape(mag, [n_batch, -1]), axis=-1)
    return [np.mean(mag), np.mean(per_sample_peak)]
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def print_memory_usage(tag=""):
    """Print the CUDA memory currently allocated and reserved, in GB, prefixed with ``tag``."""
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] Allocated: {allocated_gb:.2f} GB | Cached: {reserved_gb:.2f} GB")
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
if __name__ == "__main__":
    # Smoke test: warp a random 2D batch, crop it, and print a random mask.
    vol_shape = [4, 1, 64, 64]

    vol = torch.tensor(np.random.uniform(-1, 1, vol_shape), dtype=torch.float32)
    vol_res = random_resample(vol)
    # Crop relative to the actual spatial size: img_crop defaults to a
    # 256 x 256 image, which can select an empty window on a 64 x 64 input.
    vol_crop = img_crop(vol_res, img_sz=vol_shape[2:])

    mask = get_random_deformed_mask(vol.shape[2:])

    print(mask)
|
| 475 |
+
|
| 476 |
+
# print(vol.tolist())
|
| 477 |
+
# print(vol_res.tolist())
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
OM_aug.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchvision.utils import save_image
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.optim import Adam
|
| 7 |
+
from torchvision.utils import make_grid
|
| 8 |
+
from Diffusion.diffuser import DeformDDPM
|
| 9 |
+
from Diffusion.networks import get_net, STN
|
| 10 |
+
from torchvision.transforms import Lambda
|
| 11 |
+
import random
|
| 12 |
+
import os
|
| 13 |
+
import utils
|
| 14 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 15 |
+
from Dataloader.dataLoader import *
|
| 16 |
+
|
| 17 |
+
from torchvision.utils import save_image
|
| 18 |
+
from einops import rearrange, reduce, repeat
|
| 19 |
+
# import matplotlib.image
|
| 20 |
+
import numpy as np
|
| 21 |
+
import nibabel as nib
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
import yaml
|
| 24 |
+
import argparse
|
| 25 |
+
|
| 26 |
+
EPS = 10e-8
|
| 27 |
+
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--config",
|
| 32 |
+
"-C",
|
| 33 |
+
help="Path for the config file",
|
| 34 |
+
type=str,
|
| 35 |
+
default="Config/config_cmr.yaml",
|
| 36 |
+
# default="Config/config_lct.yaml",
|
| 37 |
+
required=False,
|
| 38 |
+
)
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
#=======================================================================================================================
|
| 41 |
+
|
| 42 |
+
# config_path = 'Config/config_cmr.yaml'
|
| 43 |
+
# config_path = 'Config/config_lct.yaml'
|
| 44 |
+
|
| 45 |
+
# Load the YAML file into a dictionary
|
| 46 |
+
with open(args.config, 'r') as file:
|
| 47 |
+
hyp_parameters = yaml.safe_load(file)
|
| 48 |
+
print(hyp_parameters)
|
| 49 |
+
# hyp_parameters["aug_img_savepath"] = os.path.join(hyp_parameters["aug_img_savepath"],hyp_parameters["data_name"],'')
|
| 50 |
+
if not os.path.exists(hyp_parameters["aug_img_savepath"]):
|
| 51 |
+
os.makedirs(hyp_parameters["aug_img_savepath"])
|
| 52 |
+
if not os.path.exists(hyp_parameters["aug_msk_savepath"]):
|
| 53 |
+
os.makedirs(hyp_parameters["aug_msk_savepath"])
|
| 54 |
+
if not os.path.exists(hyp_parameters["aug_ddf_savepath"]):
|
| 55 |
+
os.makedirs(hyp_parameters["aug_ddf_savepath"])
|
| 56 |
+
print(hyp_parameters["aug_img_savepath"])
|
| 57 |
+
|
| 58 |
+
hyp_parameters['batchsize'] = 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# =======================================================================================================================
|
| 62 |
+
select_channels_dict={}
|
| 63 |
+
# min_crop_ratio = 0.5
|
| 64 |
+
min_crop_ratio = 0.9
|
| 65 |
+
|
| 66 |
+
# label_keys = ['heart']
|
| 67 |
+
# label_keys = ['brain']
|
| 68 |
+
# label_keys = ['pancreas']
|
| 69 |
+
# label_keys = ['spleen']
|
| 70 |
+
# label_keys = ['liver']
|
| 71 |
+
# database = ['MSD']
|
| 72 |
+
label_keys = ['heart']
|
| 73 |
+
database = ['MnMs']
|
| 74 |
+
# subtype = "ed" # 'ed' or 'es' for MnMs
|
| 75 |
+
subtype = "es" # 'ed' or 'es' for MnMs
|
| 76 |
+
hyp_parameters["aug_img_savepath"]=f"Data/Aug_data/mnms_{subtype}/img/"
|
| 77 |
+
hyp_parameters["aug_msk_savepath"]=f"Data/Aug_data/mnms_{subtype}/msk/"
|
| 78 |
+
hyp_parameters["aug_ddf_savepath"]=f"Data/Aug_data/mnms_{subtype}/ddf/"
|
| 79 |
+
select_channels_dict={
|
| 80 |
+
"ImgDict":[subtype]
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
# dataset = OminiDataset_v1(transform=None,min_crop_ratio=min_crop_ratio)
|
| 84 |
+
dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database, select_channels_dict=select_channels_dict)
|
| 85 |
+
Infer_Loader = DataLoader(
|
| 86 |
+
dataset,
|
| 87 |
+
batch_size=hyp_parameters['batchsize'],
|
| 88 |
+
shuffle=False
|
| 89 |
+
)
|
| 90 |
+
# =======================================================================================================================
|
| 91 |
+
|
| 92 |
+
# Data_Loader=get_dataloader(hyp_parameters['data_name'],mode='aug')
|
| 93 |
+
# transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 94 |
+
# dataset = Data_Loader(patient_index = hyp_parameters["patients_list"])
|
| 95 |
+
# train_loader = DataLoader(dataset, batch_size = hyp_parameters['batchsize'], shuffle = False)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
epoch=f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
|
| 100 |
+
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
|
| 101 |
+
model_save_path = os.path.join(model_save_path, str(epoch)+'.pth')
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
Net = get_net(hyp_parameters["net_name"])
|
| 106 |
+
|
| 107 |
+
Deformddpm = DeformDDPM(
|
| 108 |
+
network=Net(n_steps = hyp_parameters["timesteps"],
|
| 109 |
+
ndims = hyp_parameters["ndims"],
|
| 110 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 111 |
+
res = hyp_parameters['img_size']
|
| 112 |
+
),
|
| 113 |
+
n_steps = hyp_parameters["timesteps"],
|
| 114 |
+
image_chw = [hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 115 |
+
device = hyp_parameters["device"],
|
| 116 |
+
batch_size = hyp_parameters["batchsize"],
|
| 117 |
+
img_pad_mode = hyp_parameters["img_pad_mode"],
|
| 118 |
+
ddf_pad_mode = hyp_parameters["ddf_pad_mode"],
|
| 119 |
+
padding_mode = hyp_parameters["padding_mode"],
|
| 120 |
+
v_scale = hyp_parameters["v_scale"],
|
| 121 |
+
resample_mode = hyp_parameters["resample_mode"],
|
| 122 |
+
)
|
| 123 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 124 |
+
|
| 125 |
+
ddf_stn = STN(
|
| 126 |
+
img_sz = hyp_parameters["img_size"],
|
| 127 |
+
ndims = hyp_parameters["ndims"],
|
| 128 |
+
padding_mode = hyp_parameters['padding_mode'],
|
| 129 |
+
device = hyp_parameters["device"],
|
| 130 |
+
)
|
| 131 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 132 |
+
|
| 133 |
+
print("Loading model from:", model_save_path)
|
| 134 |
+
# Deformddpm.load_state_dict(torch.load(model_save_path))
|
| 135 |
+
checkpoint = torch.load(model_save_path)
|
| 136 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
|
| 137 |
+
Deformddpm.eval()
|
| 138 |
+
|
| 139 |
+
os.makedirs(hyp_parameters['aug_img_savepath'], exist_ok=True)
|
| 140 |
+
os.makedirs(hyp_parameters['aug_msk_savepath'], exist_ok=True)
|
| 141 |
+
os.makedirs(hyp_parameters['aug_ddf_savepath'], exist_ok=True)
|
| 142 |
+
|
| 143 |
+
print("total num of image:", len(Infer_Loader))
|
| 144 |
+
for e, d in tqdm(enumerate(Infer_Loader)):
|
| 145 |
+
# if e<1:
|
| 146 |
+
# continue
|
| 147 |
+
# img, mask, pid = d
|
| 148 |
+
# img = d
|
| 149 |
+
# mask = d
|
| 150 |
+
img = d['img']
|
| 151 |
+
mask = d['labels']
|
| 152 |
+
label_str = str(d['label_channels'])
|
| 153 |
+
# mask = np.concatenate([v for v in d['labels'].values()], axis=1)
|
| 154 |
+
# print('img shape:', img.shape, 'mask shape:', mask.shape)
|
| 155 |
+
|
| 156 |
+
# pid = pid.cpu().detach().numpy()
|
| 157 |
+
# pid = pid[0]
|
| 158 |
+
pid = e
|
| 159 |
+
|
| 160 |
+
print('Processing to patient:', pid, ' image:',e)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
img = img.type(torch.float32)
|
| 164 |
+
img = img.to(hyp_parameters["device"])
|
| 165 |
+
image_original = img.cpu().detach().numpy()
|
| 166 |
+
|
| 167 |
+
mask = mask.type(torch.float32)
|
| 168 |
+
mask = mask.to(hyp_parameters["device"])
|
| 169 |
+
mask_original = mask.cpu().detach().numpy()
|
| 170 |
+
# print(pid, image_original.shape, mask_original.max())
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# if hyp_parameters["ndims"] == 2:
|
| 174 |
+
# nifti_img = nib.Nifti1Image(image_original[0,0,:,:], np.eye(4))
|
| 175 |
+
# nifti_mask = nib.Nifti1Image(mask_original[0,:,:,:], np.eye(4))
|
| 176 |
+
# elif hyp_parameters["ndims"] == 3:
|
| 177 |
+
# nifti_img = nib.Nifti1Image(image_original[0,0,:,:,:], np.eye(4))
|
| 178 |
+
# nifti_mask = nib.Nifti1Image(mask_original[0,0,:,:,:], np.eye(4))
|
| 179 |
+
nifti_img = utils.converet_to_nibabel(image_original,ndims=hyp_parameters["ndims"])
|
| 180 |
+
nifti_mask = utils.converet_to_nibabel(mask_original,ndims=hyp_parameters["ndims"])
|
| 181 |
+
|
| 182 |
+
# Saving original (undeformed image)
|
| 183 |
+
# CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
|
| 184 |
+
# Lung CT: Patient0001_Slice0001_ORG_NA.nii.gz
|
| 185 |
+
nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'],utils.get_barcode([pid,e])+'.nii.gz'))
|
| 186 |
+
|
| 187 |
+
# Saving original (undeformed) mask
|
| 188 |
+
# CMR: format: Patient0001_Slice0001_ORG_NA_GT.nii.gz
|
| 189 |
+
# Lung CT: ...
|
| 190 |
+
nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'],utils.get_barcode([pid,e])+'_GT.nii.gz'))
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
noise_step = hyp_parameters["start_noise_step"]
|
| 194 |
+
with torch.no_grad():
|
| 195 |
+
for im in range(hyp_parameters["aug_coe"]):
|
| 196 |
+
# # Permute
|
| 197 |
+
# if hyp_parameters["ndims"] == 2:
|
| 198 |
+
# [img, mask] = utils.random_permute([img, mask], select_dims=[-1, -2]) # add random rotation to image
|
| 199 |
+
# elif hyp_parameters["ndims"] == 3:
|
| 200 |
+
# [img, mask] = utils.random_permute([img, mask], select_dims=[-1, -2, -3]) # add random rotation to image
|
| 201 |
+
|
| 202 |
+
print('Generating - >', 'Subject-',pid,', Scan-',e,' (',im,'/',hyp_parameters["aug_coe"],')', end='\r')
|
| 203 |
+
|
| 204 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save] = Deformddpm.diff_recover(img_org=img,msk_org=mask,T=[noise_step,hyp_parameters["timesteps"]],v_scale=hyp_parameters["v_scale"],t_save=None,proc_type=hyp_parameters["condition_type"])
|
| 205 |
+
|
| 206 |
+
denoise_imgs = img_rec.cpu().detach().numpy()
|
| 207 |
+
denoise_msks = msk_rec.cpu().detach().numpy()
|
| 208 |
+
noisy_imgs_np = img_diff.cpu().detach().numpy()
|
| 209 |
+
noisy_msks_np = msk_diff.cpu().detach().numpy()
|
| 210 |
+
|
| 211 |
+
# if hyp_parameters["ndims"] == 2:
|
| 212 |
+
# nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:], np.eye(4))
|
| 213 |
+
# nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,:,:,:], np.eye(4))
|
| 214 |
+
# nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:], np.eye(4))
|
| 215 |
+
# nifti_mask = nib.Nifti1Image(noisy_msks_np[0, :, :, :], np.eye(4))
|
| 216 |
+
# elif hyp_parameters["ndims"] == 3:
|
| 217 |
+
# nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:,:], np.eye(4))
|
| 218 |
+
# nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,0,:,:,:], np.eye(4))
|
| 219 |
+
# nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:,:], np.eye(4))
|
| 220 |
+
# nifti_mask = nib.Nifti1Image(noisy_msks_np[0, 0, :, :], np.eye(4)) ###
|
| 221 |
+
nifti_img_aug = utils.converet_to_nibabel(denoise_imgs,ndims=hyp_parameters["ndims"])
|
| 222 |
+
nifti_mask_aug = utils.converet_to_nibabel(denoise_msks,ndims=hyp_parameters["ndims"])
|
| 223 |
+
nifti_img = utils.converet_to_nibabel(noisy_imgs_np,ndims=hyp_parameters["ndims"])
|
| 224 |
+
nifti_mask = utils.converet_to_nibabel(noisy_msks_np,ndims=hyp_parameters["ndims"])
|
| 225 |
+
|
| 226 |
+
nib.save(nifti_img_aug, os.path.join(hyp_parameters['aug_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
|
| 227 |
+
nib.save(nifti_mask_aug, os.path.join(hyp_parameters['aug_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
|
| 228 |
+
|
| 229 |
+
# Saving noisy image to nifti
|
| 230 |
+
# CMR: format: Patient0001_Slice0001_NoiseImg0001_NoiseStep0070.nii.gz
|
| 231 |
+
# Lung CT: ...
|
| 232 |
+
nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'.nii.gz'))
|
| 233 |
+
nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'_GT.nii.gz'))
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if (im - hyp_parameters["start_noise_step"])%2 == 0:
|
| 237 |
+
noise_step = noise_step + hyp_parameters["noise_step"]
|
| 238 |
+
# break # for testing
|
| 239 |
+
if e >= 0:
|
| 240 |
+
exit()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
|
OM_aug_highres.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchvision.utils import save_image
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.optim import Adam
|
| 7 |
+
from torchvision.utils import make_grid
|
| 8 |
+
from Diffusion.diffuser import DeformDDPM
|
| 9 |
+
from Diffusion.networks import get_net, STN
|
| 10 |
+
from torchvision.transforms import Lambda
|
| 11 |
+
import random
|
| 12 |
+
import os
|
| 13 |
+
import utils
|
| 14 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 15 |
+
from Dataloader.dataLoader import *
|
| 16 |
+
|
| 17 |
+
from torchvision.utils import save_image
|
| 18 |
+
from einops import rearrange, reduce, repeat
|
| 19 |
+
# import matplotlib.image
|
| 20 |
+
import numpy as np
|
| 21 |
+
import nibabel as nib
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
import yaml
|
| 24 |
+
import argparse
|
| 25 |
+
|
| 26 |
+
EPS = 10e-8
|
| 27 |
+
|
| 28 |
+
parser = argparse.ArgumentParser()
|
| 29 |
+
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--config",
|
| 32 |
+
"-C",
|
| 33 |
+
help="Path for the config file",
|
| 34 |
+
type=str,
|
| 35 |
+
default="Config/config_cmr.yaml",
|
| 36 |
+
# default="Config/config_lct.yaml",
|
| 37 |
+
required=False,
|
| 38 |
+
)
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
#=======================================================================================================================
|
| 41 |
+
|
| 42 |
+
# config_path = 'Config/config_cmr.yaml'
|
| 43 |
+
# config_path = 'Config/config_lct.yaml'
|
| 44 |
+
|
| 45 |
+
# Load the YAML file into a dictionary
|
| 46 |
+
with open(args.config, 'r') as file:
|
| 47 |
+
hyp_parameters = yaml.safe_load(file)
|
| 48 |
+
print(hyp_parameters)
|
| 49 |
+
# hyp_parameters["aug_img_savepath"] = os.path.join(hyp_parameters["aug_img_savepath"],hyp_parameters["data_name"],'')
|
| 50 |
+
if not os.path.exists(hyp_parameters["aug_img_savepath"]):
|
| 51 |
+
os.makedirs(hyp_parameters["aug_img_savepath"])
|
| 52 |
+
if not os.path.exists(hyp_parameters["aug_msk_savepath"]):
|
| 53 |
+
os.makedirs(hyp_parameters["aug_msk_savepath"])
|
| 54 |
+
if not os.path.exists(hyp_parameters["aug_ddf_savepath"]):
|
| 55 |
+
os.makedirs(hyp_parameters["aug_ddf_savepath"])
|
| 56 |
+
print(hyp_parameters["aug_img_savepath"])
|
| 57 |
+
|
| 58 |
+
hyp_parameters['batchsize'] = 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# =======================================================================================================================
|
| 62 |
+
# min_crop_ratio = 0.5
|
| 63 |
+
min_crop_ratio = 0.9
|
| 64 |
+
|
| 65 |
+
# label_keys = ['heart']
|
| 66 |
+
# label_keys = ['brain']
|
| 67 |
+
label_keys = ['pancreas']
|
| 68 |
+
database = ['MSD']
|
| 69 |
+
|
| 70 |
+
# dataset = OminiDataset_v1(transform=None,min_crop_ratio=min_crop_ratio)
|
| 71 |
+
dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database)
|
| 72 |
+
Infer_Loader = DataLoader(
|
| 73 |
+
dataset,
|
| 74 |
+
batch_size=hyp_parameters['batchsize'],
|
| 75 |
+
shuffle=False
|
| 76 |
+
)
|
| 77 |
+
# =======================================================================================================================
|
| 78 |
+
|
| 79 |
+
# Data_Loader=get_dataloader(hyp_parameters['data_name'],mode='aug')
|
| 80 |
+
# transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 81 |
+
# dataset = Data_Loader(patient_index = hyp_parameters["patients_list"])
|
| 82 |
+
# train_loader = DataLoader(dataset, batch_size = hyp_parameters['batchsize'], shuffle = False)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
epoch=f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
|
| 87 |
+
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
|
| 88 |
+
model_save_path = os.path.join(model_save_path, str(epoch)+'.pth')
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
Net = get_net(hyp_parameters["net_name"])
|
| 93 |
+
|
| 94 |
+
Deformddpm = DeformDDPM(
|
| 95 |
+
network=Net(n_steps = hyp_parameters["timesteps"],
|
| 96 |
+
ndims = hyp_parameters["ndims"],
|
| 97 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 98 |
+
res = hyp_parameters['img_size']
|
| 99 |
+
),
|
| 100 |
+
n_steps = hyp_parameters["timesteps"],
|
| 101 |
+
image_chw = [hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 102 |
+
device = hyp_parameters["device"],
|
| 103 |
+
batch_size = hyp_parameters["batchsize"],
|
| 104 |
+
img_pad_mode = hyp_parameters["img_pad_mode"],
|
| 105 |
+
ddf_pad_mode = hyp_parameters["ddf_pad_mode"],
|
| 106 |
+
padding_mode = hyp_parameters["padding_mode"],
|
| 107 |
+
v_scale = hyp_parameters["v_scale"],
|
| 108 |
+
resample_mode = hyp_parameters["resample_mode"],
|
| 109 |
+
)
|
| 110 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 111 |
+
|
| 112 |
+
ddf_stn = STN(
|
| 113 |
+
img_sz = hyp_parameters["img_size"],
|
| 114 |
+
ndims = hyp_parameters["ndims"],
|
| 115 |
+
padding_mode = hyp_parameters['padding_mode'],
|
| 116 |
+
device = hyp_parameters["device"],
|
| 117 |
+
)
|
| 118 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 119 |
+
|
| 120 |
+
print("Loading model from:", model_save_path)
|
| 121 |
+
# Deformddpm.load_state_dict(torch.load(model_save_path))
|
| 122 |
+
checkpoint = torch.load(model_save_path)
|
| 123 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
|
| 124 |
+
Deformddpm.eval()
|
| 125 |
+
|
| 126 |
+
os.makedirs(hyp_parameters['aug_img_savepath'], exist_ok=True)
|
| 127 |
+
os.makedirs(hyp_parameters['aug_msk_savepath'], exist_ok=True)
|
| 128 |
+
os.makedirs(hyp_parameters['aug_ddf_savepath'], exist_ok=True)
|
| 129 |
+
|
| 130 |
+
print("total num of image:", len(Infer_Loader))
|
| 131 |
+
for e, d in tqdm(enumerate(Infer_Loader)):
|
| 132 |
+
|
| 133 |
+
# img, mask, pid = d
|
| 134 |
+
# img = d
|
| 135 |
+
# mask = d
|
| 136 |
+
img = d['img']
|
| 137 |
+
mask = d['labels']
|
| 138 |
+
# mask = np.concatenate([v for v in d['labels'].values()], axis=1)
|
| 139 |
+
# print('img shape:', img.shape, 'mask shape:', mask.shape)
|
| 140 |
+
|
| 141 |
+
# pid = pid.cpu().detach().numpy()
|
| 142 |
+
# pid = pid[0]
|
| 143 |
+
pid = e
|
| 144 |
+
|
| 145 |
+
print('Processing to patient:', pid, ' image:',e)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
img = img.type(torch.float32)
|
| 149 |
+
img = img.to(hyp_parameters["device"])
|
| 150 |
+
image_original = img.cpu().detach().numpy()
|
| 151 |
+
|
| 152 |
+
mask = mask.type(torch.float32)
|
| 153 |
+
mask = mask.to(hyp_parameters["device"])
|
| 154 |
+
mask_original = mask.cpu().detach().numpy()
|
| 155 |
+
# print(pid, image_original.shape, mask_original.max())
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if hyp_parameters["ndims"] == 2:
|
| 159 |
+
nifti_img = nib.Nifti1Image(image_original[0,0,:,:], np.eye(4))
|
| 160 |
+
nifti_mask = nib.Nifti1Image(mask_original[0,:,:,:], np.eye(4))
|
| 161 |
+
elif hyp_parameters["ndims"] == 3:
|
| 162 |
+
nifti_img = nib.Nifti1Image(image_original[0,0,:,:,:], np.eye(4))
|
| 163 |
+
nifti_mask = nib.Nifti1Image(mask_original[0,0,:,:,:], np.eye(4))
|
| 164 |
+
|
| 165 |
+
# Saving original (undeformed image)
|
| 166 |
+
# CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
|
| 167 |
+
# Lung CT: Patient0001_Slice0001_ORG_NA.nii.gz
|
| 168 |
+
nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'],utils.get_barcode([pid,e])+'.nii.gz'))
|
| 169 |
+
|
| 170 |
+
# Saving original (undeformed) mask
|
| 171 |
+
# CMR: format: Patient0001_Slice0001_ORG_NA_GT.nii.gz
|
| 172 |
+
# Lung CT: ...
|
| 173 |
+
nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'],utils.get_barcode([pid,e])+'_GT.nii.gz'))
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
noise_step = hyp_parameters["start_noise_step"]
|
| 177 |
+
with torch.no_grad():
|
| 178 |
+
for im in range(hyp_parameters["aug_coe"]):
|
| 179 |
+
# # Permute
|
| 180 |
+
# if hyp_parameters["ndims"] == 2:
|
| 181 |
+
# [img, mask] = utils.random_permute([img, mask], select_dims=[-1, -2]) # add random rotation to image
|
| 182 |
+
# elif hyp_parameters["ndims"] == 3:
|
| 183 |
+
# [img, mask] = utils.random_permute([img, mask], select_dims=[-1, -2, -3]) # add random rotation to image
|
| 184 |
+
|
| 185 |
+
print('Generating - >', 'Subject-',pid,', Scan-',e,' (',im,'/',hyp_parameters["aug_coe"],')', end='\r')
|
| 186 |
+
|
| 187 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save] = Deformddpm.diff_recover(img_org=img,msk_org=mask,T=[noise_step,hyp_parameters["timesteps"]],v_scale=hyp_parameters["v_scale"],t_save=None,proc_type=hyp_parameters["condition_type"])
|
| 188 |
+
|
| 189 |
+
denoise_imgs = img_rec.cpu().detach().numpy()
|
| 190 |
+
denoise_msks = msk_rec.cpu().detach().numpy()
|
| 191 |
+
noisy_imgs_np = img_diff.cpu().detach().numpy()
|
| 192 |
+
noisy_msks_np = msk_diff.cpu().detach().numpy()
|
| 193 |
+
|
| 194 |
+
if hyp_parameters["ndims"] == 2:
|
| 195 |
+
nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:], np.eye(4))
|
| 196 |
+
nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,:,:,:], np.eye(4))
|
| 197 |
+
nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:], np.eye(4))
|
| 198 |
+
nifti_mask = nib.Nifti1Image(noisy_msks_np[0, :, :, :], np.eye(4))
|
| 199 |
+
elif hyp_parameters["ndims"] == 3:
|
| 200 |
+
nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:,:], np.eye(4))
|
| 201 |
+
nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,0,:,:,:], np.eye(4))
|
| 202 |
+
nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:,:], np.eye(4))
|
| 203 |
+
nifti_mask = nib.Nifti1Image(noisy_msks_np[0, 0, :, :], np.eye(4))
|
| 204 |
+
|
| 205 |
+
nib.save(nifti_img_aug, os.path.join(hyp_parameters['aug_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
|
| 206 |
+
nib.save(nifti_mask_aug, os.path.join(hyp_parameters['aug_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
|
| 207 |
+
|
| 208 |
+
# Saving noisy image to nifti
|
| 209 |
+
# CMR: format: Patient0001_Slice0001_NoiseImg0001_NoiseStep0070.nii.gz
|
| 210 |
+
# Lung CT: ...
|
| 211 |
+
nib.save(nifti_img, os.path.join(hyp_parameters['aug_img_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'.nii.gz'))
|
| 212 |
+
nib.save(nifti_mask, os.path.join(hyp_parameters['aug_msk_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'_GT.nii.gz'))
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
if (im - hyp_parameters["start_noise_step"])%2 == 0:
|
| 216 |
+
noise_step = noise_step + hyp_parameters["noise_step"]
|
| 217 |
+
# break # for testing
|
| 218 |
+
# if e > 5:
|
| 219 |
+
# break
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
OM_contrastive.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch.optim import Adam
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from Diffusion.networks import get_net
|
| 6 |
+
from Dataloader.dataLoader import *
|
| 7 |
+
import argparse
|
| 8 |
+
import yaml
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import swanlab
|
| 12 |
+
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument("--config", "-C", type=str, default="Config/config_om_contrastive.yaml")
|
| 15 |
+
args = parser.parse_args()
|
| 16 |
+
|
| 17 |
+
with open(args.config, 'r') as file:
|
| 18 |
+
hyp = yaml.safe_load(file)
|
| 19 |
+
|
| 20 |
+
# Setup
|
| 21 |
+
device = torch.device(hyp['device'] if torch.cuda.is_available() else 'cpu')
|
| 22 |
+
data_name = hyp['data_name']
|
| 23 |
+
net_name = hyp['net_name']
|
| 24 |
+
ndims = hyp['ndims']
|
| 25 |
+
img_size = hyp['img_size']
|
| 26 |
+
model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
|
| 27 |
+
os.makedirs(model_save_path, exist_ok=True)
|
| 28 |
+
|
| 29 |
+
# SwanLab
|
| 30 |
+
swanlab.init(project="OM", config=hyp)
|
| 31 |
+
|
| 32 |
+
# Model
|
| 33 |
+
Net = get_net(net_name)
|
| 34 |
+
model = Net(n_steps=hyp['timesteps'], ndims=ndims, num_input_chn=hyp['num_input_chn'], res=img_size).to(device)
|
| 35 |
+
optimizer = Adam(model.parameters(), lr=hyp['lr'])
|
| 36 |
+
|
| 37 |
+
# Data
|
| 38 |
+
dataset = OMDataset_indiv(out_sz=img_size, transform=None)
|
| 39 |
+
train_loader = DataLoader(dataset, batch_size=hyp['batchsize'], shuffle=True, drop_last=True)
|
| 40 |
+
|
| 41 |
+
# Training
|
| 42 |
+
print('start training...')
|
| 43 |
+
for epoch in range(hyp['epoch']):
|
| 44 |
+
epoch_loss = 0.0
|
| 45 |
+
|
| 46 |
+
for i, (volume, embd) in enumerate(train_loader):
|
| 47 |
+
t0 = time.time()
|
| 48 |
+
volume = volume.float().to(device)
|
| 49 |
+
embd = embd.to(device) # [B, 1024] GT text embedding
|
| 50 |
+
t = torch.randint(0, hyp['timesteps'], (volume.shape[0],)).to(device)
|
| 51 |
+
|
| 52 |
+
_, img_embd = model(x=volume, y=volume, t=t) # img_embd: [B, 1024]
|
| 53 |
+
|
| 54 |
+
# Cosine similarity loss: align img_embd with GT text embedding
|
| 55 |
+
loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()
|
| 56 |
+
swanlab.log({"loss": loss.item()})
|
| 57 |
+
|
| 58 |
+
optimizer.zero_grad()
|
| 59 |
+
loss.backward()
|
| 60 |
+
optimizer.step()
|
| 61 |
+
epoch_loss += loss.item()
|
| 62 |
+
t1 = time.time()
|
| 63 |
+
dt = t1 - t0
|
| 64 |
+
swanlab.log({"Time(mins)/batch": dt/60})
|
| 65 |
+
avg_loss = epoch_loss / max(len(train_loader), 1)
|
| 66 |
+
print(f"Epoch {epoch:04d} | Loss: {avg_loss:.6f}")
|
| 67 |
+
swanlab.log({"Avg Loss/epoch": avg_loss})
|
| 68 |
+
|
| 69 |
+
# if epoch % hyp['epoch_per_save'] == 0:
|
| 70 |
+
# save_path = model_save_path + str(epoch).rjust(6, '0') + f'_{data_name}_{net_name}.pth'
|
| 71 |
+
# torch.save({'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, save_path)
|
| 72 |
+
# print(f"Saved: {save_path}")
|
OM_reg.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchvision.utils import save_image
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.optim import Adam
|
| 7 |
+
from torchvision.utils import make_grid
|
| 8 |
+
from Diffusion.diffuser import DeformDDPM
|
| 9 |
+
from Diffusion.networks import get_net, STN
|
| 10 |
+
from torchvision.transforms import Lambda
|
| 11 |
+
import random
|
| 12 |
+
import os
|
| 13 |
+
import utils
|
| 14 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 15 |
+
from Dataloader.dataLoader import *
|
| 16 |
+
|
| 17 |
+
from torchvision.utils import save_image
|
| 18 |
+
from einops import rearrange, reduce, repeat
|
| 19 |
+
# import matplotlib.image
|
| 20 |
+
import numpy as np
|
| 21 |
+
import nibabel as nib
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
import yaml
|
| 24 |
+
import argparse
|
| 25 |
+
|
| 26 |
+
# Small constant; note that 10e-8 evaluates to 1e-7.
EPS = 10e-8

# ----------------------------------------------------------------------------
# Command line: a single optional argument selecting the YAML config file.
# ----------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    default="Config/config_cmr.yaml",
    required=False,
    help="Path for the config file",
)
args = parser.parse_args()

# Load the YAML file into a dictionary of hyper-parameters.
with open(args.config, 'r') as file:
    hyp_parameters = yaml.safe_load(file)
print(hyp_parameters)

# Make sure the three augmentation output directories exist.
for _savepath_key in ("aug_img_savepath", "aug_msk_savepath", "aug_ddf_savepath"):
    if not os.path.exists(hyp_parameters[_savepath_key]):
        os.makedirs(hyp_parameters[_savepath_key])
print(hyp_parameters["aug_img_savepath"])

# This script always runs with a batch size of one, whatever the config says.
hyp_parameters['batchsize'] = 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# =======================================================================================================================
|
| 62 |
+
# min_crop_ratio = 0.5
|
| 63 |
+
# Dataset selection for inference. Earlier alternatives tried in this script
# were min_crop_ratio = 0.5 and label keys 'heart' / 'pancreas'.
min_crop_ratio = 0.9
label_keys = ['brain']
database = ['MSD']

dataset = OminiDataset_inference_w_all(
    transform=None,
    min_crop_ratio=min_crop_ratio,
    label_key=label_keys,
    database=database,
)

# Samples are visited in dataset order (no shuffling).
Infer_Loader = DataLoader(dataset, batch_size=hyp_parameters['batchsize'], shuffle=False)
|
| 83 |
+
# =======================================================================================================================
|
| 84 |
+
|
| 85 |
+
# Data_Loader=get_dataloader(hyp_parameters['data_name'],mode='aug')
|
| 86 |
+
# transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 87 |
+
# dataset = Data_Loader(patient_index = hyp_parameters["patients_list"])
|
| 88 |
+
# train_loader = DataLoader(dataset, batch_size = hyp_parameters['batchsize'], shuffle = False)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Checkpoint location: Models/<data_name>_<net_name>/<model_id_str>_<data_name>_<net_name>.pth
# (`epoch` is the checkpoint file stem, not an integer epoch counter.)
epoch=f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
model_save_path = os.path.join(model_save_path, str(epoch)+'.pth')

Net = get_net(hyp_parameters["net_name"])

# Diffusion model wrapping the selected network; every setting comes from the config.
Deformddpm = DeformDDPM(
    network=Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=hyp_parameters["ndims"],
        num_input_chn=hyp_parameters["num_input_chn"],
        res=hyp_parameters['img_size'],
    ),
    n_steps=hyp_parameters["timesteps"],
    image_chw=[hyp_parameters["num_input_chn"]] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
    device=hyp_parameters["device"],
    batch_size=hyp_parameters["batchsize"],
    img_pad_mode=hyp_parameters["img_pad_mode"],
    ddf_pad_mode=hyp_parameters["ddf_pad_mode"],
    padding_mode=hyp_parameters["padding_mode"],
    v_scale=hyp_parameters["v_scale"],
    resample_mode=hyp_parameters["resample_mode"],
)
Deformddpm.to(hyp_parameters["device"])

# Spatial transformer built with the same image size / dimensionality as the model.
ddf_stn = STN(
    img_sz=hyp_parameters["img_size"],
    ndims=hyp_parameters["ndims"],
    padding_mode=hyp_parameters['padding_mode'],
    device=hyp_parameters["device"],
)
ddf_stn.to(hyp_parameters["device"])

print("Loading model from:", model_save_path)
# Deserialize onto the CPU first: the training scripts save from a CUDA rank, and
# without map_location torch.load would try to restore tensors onto that same
# device, which fails on hosts with a different device layout. load_state_dict
# then copies the values into the parameters already placed on the target device.
checkpoint = torch.load(model_save_path, map_location="cpu")
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
Deformddpm.eval()

# Output directories for the registration results.
os.makedirs(hyp_parameters['reg_img_savepath'], exist_ok=True)
os.makedirs(hyp_parameters['reg_msk_savepath'], exist_ok=True)
os.makedirs(hyp_parameters['reg_ddf_savepath'], exist_ok=True)
|
| 135 |
+
|
| 136 |
+
# Main inference loop: for every sample, save the original image/mask, run the
# diffusion-recovery model conditioned on the first image of the loader, and save
# the recovered and the diffused (noisy) image/mask pairs as NIfTI files.
print("total num of image:", len(Infer_Loader))
for e, d in tqdm(enumerate(Infer_Loader)):
    img = d['img']
    mask = d['labels']

    # No patient identifier is read from the batch; the loop index doubles as the id.
    pid = e

    print('Processing to patient:', pid, ' image:',e)

    img = img.to(hyp_parameters["device"])
    img = img.type(torch.float32)
    image_original = img.cpu().detach().numpy()

    if e <= 0:
        # Keep the first image as the target image used for conditioning.
        target_img = img.clone().detach()

    mask = mask.to(hyp_parameters["device"])
    mask = mask.type(torch.float32)
    mask_original = mask.cpu().detach().numpy()

    if hyp_parameters["ndims"] == 2:
        nifti_img = nib.Nifti1Image(image_original[0,0,:,:], np.eye(4))
        nifti_mask = nib.Nifti1Image(mask_original[0,:,:,:], np.eye(4))
    elif hyp_parameters["ndims"] == 3:
        nifti_img = nib.Nifti1Image(image_original[0,0,:,:,:], np.eye(4))
        nifti_mask = nib.Nifti1Image(mask_original[0,0,:,:,:], np.eye(4))
    else:
        raise ValueError(f'Unsupported ndims: {hyp_parameters["ndims"]} (expected 2 or 3)')

    # Save the original (undeformed) image.
    nib.save(nifti_img, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e])+'.nii.gz'))

    # Save the original (undeformed) mask next to it. The previous code wrote
    # `nifti_img` here, so the "_GT" file contained the image instead of the mask.
    nib.save(nifti_mask, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e])+'_GT.nii.gz'))

    noise_step = hyp_parameters["start_noise_step"]
    with torch.no_grad():
        for im in range(1):
            print('Generating - >', 'Subject-',pid,', Scan-',e,' (',im,'/',hyp_parameters["aug_coe"],')', end='\r')

            [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save] = Deformddpm.diff_recover(
                img_org=img,
                cond_imgs=target_img.clone().detach(),
                msk_org=mask,
                T=[None,hyp_parameters["timesteps"]],
                v_scale=hyp_parameters["v_scale"],
                t_save=None,
                proc_type=hyp_parameters["condition_type"],
            )

            denoise_imgs = img_rec.cpu().detach().numpy()
            denoise_msks = msk_rec.cpu().detach().numpy()
            noisy_imgs_np = img_diff.cpu().detach().numpy()
            noisy_msks_np = msk_diff.cpu().detach().numpy()

            if hyp_parameters["ndims"] == 2:
                nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:], np.eye(4))
                nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,:,:,:], np.eye(4))
                nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:], np.eye(4))
                nifti_mask = nib.Nifti1Image(noisy_msks_np[0, :, :, :], np.eye(4))
            elif hyp_parameters["ndims"] == 3:
                nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:,:], np.eye(4))
                nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,0,:,:,:], np.eye(4))
                nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:,:], np.eye(4))
                nifti_mask = nib.Nifti1Image(noisy_msks_np[0, 0, :, :], np.eye(4))

            # Recovered image and mask.
            nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
            nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))

            # Diffused (noisy) image and mask, saved with an explicit barcode header.
            # NOTE(review): if utils.get_barcode's default header equals the one passed
            # here, these two files overwrite the recovered pair above - confirm.
            nib.save(nifti_img, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'.nii.gz'))
            nib.save(nifti_mask, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step],header=['Patient','Slice','NoiseImg','NoiseStep'])+'_GT.nii.gz'))

            if (im - hyp_parameters["start_noise_step"])%2 == 0:
                noise_step = noise_step + hyp_parameters["noise_step"]

    # Only the first seven samples (e = 0..6) are processed.
    # NOTE(review): presumably a debugging cap - confirm before a full run.
    if e > 5:
        break
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
|
OM_train.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchvision.utils import save_image
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from torch.optim import Adam, SGD
|
| 10 |
+
from Diffusion.diffuser import DeformDDPM
|
| 11 |
+
from Diffusion.networks import get_net, STN
|
| 12 |
+
from torchvision.transforms import Lambda
|
| 13 |
+
import Diffusion.losses as losses
|
| 14 |
+
import random
|
| 15 |
+
import glob
|
| 16 |
+
import numpy as np
|
| 17 |
+
import utils
|
| 18 |
+
|
| 19 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 20 |
+
from Dataloader.dataLoader import *
|
| 21 |
+
|
| 22 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 23 |
+
import yaml
|
| 24 |
+
import argparse
|
| 25 |
+
|
| 26 |
+
####################
|
| 27 |
+
import torch.multiprocessing as mp
|
| 28 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 29 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 30 |
+
import torch.distributed as dist
|
| 31 |
+
# from torch.distributed import init_process_group
|
| 32 |
+
###############
|
| 33 |
+
def ddp_setup(rank, world_size, master_addr="localhost", master_port="12355"):
    """Initialise the NCCL process group for one worker and bind it to its GPU.

    Args:
        rank: Unique identifier of each process; also used as the CUDA device index.
        world_size: Total number of processes.
        master_addr: Rendezvous host written to ``MASTER_ADDR``. Defaults to the
            previously hard-coded ``"localhost"``.
        master_port: Rendezvous port written to ``MASTER_PORT``. Defaults to the
            previously hard-coded ``"12355"``; pass another value to run a second
            job on the same host without a port clash.
    """
    # Environment variables must be strings, so coerce e.g. an integer port.
    os.environ["MASTER_ADDR"] = str(master_addr)
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
| 43 |
+
|
| 44 |
+
# Global switch: True spawns one DDP worker per visible GPU, False trains in-process.
use_distributed = True

EPS = 1e-5

# Command line: a single optional argument selecting the YAML config file.
# NOTE(review): the commit's file listing shows config_cmr/lct/om/om_contrastive
# under Config/ but no config_all.yaml - confirm the default path exists.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    default="Config/config_all.yaml",
    required=False,
    help="Path for the config file",
)
args = parser.parse_args()
|
| 63 |
+
#=======================================================================================================================
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main_train(rank=0,world_size=1):
    """Run the training loop, optionally as one DDP worker process.

    Reads the hyper-parameters from the YAML file named by ``args.config``, builds
    the ``DeformDDPM`` model and the ``STN`` module, resumes from the last ``*.pth``
    file (in sorted filename order) found in ``Models/<data_name>_<net_name>/`` when
    there is one, then trains up to ``hyp_parameters["epoch"]`` epochs and writes a
    checkpoint whenever ``epoch % hyp_parameters["epoch_per_save"] == 0``.

    Args:
        rank: Index of this process; used as the CUDA device index for the model
            when ``use_distributed`` is true and to restrict logging/saving to rank 0.
        world_size: Total number of processes; only forwarded to ``ddp_setup``.
    """
    if use_distributed:
        ddp_setup(rank, world_size)
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])

    # NOTE(review): the loader shuffles without a DistributedSampler, so under DDP
    # every rank iterates the whole dataset in its own order - confirm intended.
    dataset = OminiDataset_v1(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    device = hyp_parameters["device"]

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=hyp_parameters["timesteps"],
            ndims=hyp_parameters["ndims"],
            num_input_chn=hyp_parameters["num_input_chn"],
            res=hyp_parameters['img_size'],
        ),
        n_steps=hyp_parameters["timesteps"],
        image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=hyp_parameters["img_size"],
        ndims=hyp_parameters["ndims"],
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
    loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
    loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # Resume from the last checkpoint (sorted by filename) if any exists.
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Forward the actual mode: ddp_load_dict defaults to use_distributed=True,
        # which dereferences `.module` and calls dist.barrier() - both fail when the
        # model is not wrapped in DDP / no process group was initialised.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed
        )
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ',len(dataset))

    for epoch in range(initial_epoch, hyp_parameters["epoch"]):
        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0

        Deformddpm.train()

        for step, batch in enumerate(train_loader):
            # The omni dataset yields the image tensor directly (no label tuple).
            x0 = batch.to(device).type(torch.float32)
            n = x0.size()[0]  # batch size

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(device)

            # Random resampling (3D data only, with probability 0.6), then the transformer.
            # NOTE(review): nesting of the transformer call is reconstructed - confirm
            # it is meant to run for every batch.
            if hyp_parameters["ndims"]>2:
                if np.random.uniform(0,1)<0.6:
                    x0 = utils.random_resample(x0, deform_scale=0)
            x0 = transformer(x0)

            # Optional intensity jitter: random gain around 1 and random offset around 0.
            if hyp_parameters['noise_scale']>0:
                x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
                x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)

            # One random timestep in [0, timesteps) per image of the batch.
            t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(device)

            pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, mask=blind_mask)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)

            loss_tot = 1. * loss_gen_d + 1. * loss_gen_a
            loss_tot = loss_tot + 1.0 * loss_ddf

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            # Running epoch averages, weighted by batch size over dataset size.
            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

        if gpu_id == 0:
            print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Single-process runs always save; under DDP only rank 0 writes, and the
            # wrapped module is unwrapped so the checkpoint keys carry no "module." prefix.
            if not use_distributed or gpu_id == 0:
                model_to_save = Deformddpm.module if use_distributed else Deformddpm
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': model_to_save.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)
|
| 264 |
+
|
| 265 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
    """Restore model and optimizer state from ``model_file`` and return the resume epoch.

    The checkpoint is deserialized onto the CPU by every process and then copied
    into that process's own model and optimizer. This keeps the optimizer state
    identical on all ranks and avoids materialising the checkpoint on the CUDA
    device it was saved from.

    Args:
        gpu_id: Rank of the calling process; rank 0 additionally prints memory usage.
        Deformddpm: The model, wrapped in ``DistributedDataParallel`` when
            ``use_distributed`` is true.
        optimizer: Optimizer whose state is restored from the checkpoint.
        model_file: Path of the checkpoint; its basename must start with the
            zero-padded six-digit epoch number.
        use_distributed: Whether ``Deformddpm`` is DDP-wrapped and a process group
            is initialised.

    Returns:
        Tuple ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch`` is
        the epoch parsed from the file name plus one.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    gc.collect()
    torch.cuda.empty_cache()

    # map_location="cpu": load_state_dict copies the values onto whatever device the
    # parameters of this process already live on.
    checkpoint = torch.load(model_file, map_location="cpu")
    target_model = Deformddpm.module if use_distributed else Deformddpm
    target_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    del checkpoint

    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Belt and braces: make every replica bit-identical to rank 0. Gradients are
        # not broadcast; a broadcast conditional on `param.grad is not None` can
        # mismatch across ranks.
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)
        dist.barrier()

    # The epoch number comes from the filename ("000010_<data>_<net>.pth" -> 10).
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1

    return initial_epoch, Deformddpm, optimizer
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
if __name__ == "__main__":
    # Entry point: either train in this process, or start one worker per visible GPU.
    if not use_distributed:
        main_train(0,1)
    else:
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        # mp.spawn supplies the rank as the first positional argument of main_train.
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
|
OM_train_2modes.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchvision.utils import save_image
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from torch.optim import Adam, SGD
|
| 10 |
+
from Diffusion.diffuser import DeformDDPM
|
| 11 |
+
from Diffusion.networks import get_net, STN
|
| 12 |
+
from torchvision.transforms import Lambda
|
| 13 |
+
import Diffusion.losses as losses
|
| 14 |
+
import random
|
| 15 |
+
import glob
|
| 16 |
+
import numpy as np
|
| 17 |
+
import utils
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 21 |
+
from Dataloader.dataLoader import *
|
| 22 |
+
|
| 23 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 24 |
+
import yaml
|
| 25 |
+
import argparse
|
| 26 |
+
|
| 27 |
+
####################
|
| 28 |
+
import torch.multiprocessing as mp
|
| 29 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 30 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 31 |
+
import torch.distributed as dist
|
| 32 |
+
# from torch.distributed import init_process_group
|
| 33 |
+
###############
|
| 34 |
+
def ddp_setup(rank, world_size, master_addr="localhost", master_port="12355"):
    """Initialise the NCCL process group for one worker and bind it to its GPU.

    Args:
        rank: Unique identifier of each process; also used as the CUDA device index.
        world_size: Total number of processes.
        master_addr: Rendezvous host written to ``MASTER_ADDR``. Defaults to the
            previously hard-coded ``"localhost"``.
        master_port: Rendezvous port written to ``MASTER_PORT``. Defaults to the
            previously hard-coded ``"12355"``; pass another value to run a second
            job on the same host without a port clash.
    """
    # Environment variables must be strings, so coerce e.g. an integer port.
    os.environ["MASTER_ADDR"] = str(master_addr)
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
|
| 44 |
+
|
| 45 |
+
# Global switch: True spawns one DDP worker per visible GPU, False trains in-process.
use_distributed = True

# Numeric constants and probabilities used by this training script.
EPS = 1e-5
MSK_EPS = 0.01
TEXT_EMBED_PROB = 0.7
AUG_RESAMPLE_PROB = 0.6

# Loss weights for the two training modes.
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 30]  # [ang, dist, reg]
# Earlier settings tried for the registration weights:
#   [9.0, 1.0, 16.0], [10.0, 1.0, 1.0], [2.0, 0.1, 1e3]
LOSS_WEIGHTS_REGIST = [2.0, 0.1, 256]  # [imgsim, imgmse, ddf]

# Command line: a single optional argument selecting the YAML config file.
# NOTE(review): the commit's file listing shows config_cmr/lct/om/om_contrastive
# under Config/ but no config_all.yaml - confirm the default path exists.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    type=str,
    default="Config/config_all.yaml",
    required=False,
    help="Path for the config file",
)
args = parser.parse_args()
|
| 74 |
+
#=======================================================================================================================
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _save_nifti_image(tensor, filename):
    """Write *tensor* to *filename* as a NIfTI volume with an identity affine."""
    # Imported lazily: nibabel is only needed on the debug-dump path.
    import nibabel as nib

    array = tensor.squeeze().cpu().detach().numpy()
    nib.save(nib.Nifti1Image(array, affine=np.eye(4)), filename)


def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
    """Train DeformDDPM by alternating a diffusion update and a registration update.

    Every step runs a diffusion update on a batch of single images; on steps
    whose index is a multiple of ``train_mode_ratio`` it additionally runs a
    registration update on a batch of image pairs. Checkpoints are written to
    ``Models/<data_name>_<net_name>/`` every ``epoch_per_save`` epochs and the
    newest one found there is resumed from.

    Args:
        rank: Index of this process; used as the CUDA device index when
            ``use_distributed`` is set.
        world_size: Total number of processes.
        train_mode_ratio: Run the registration update every this many steps.
        thresh_imgsim: Target voxels above this intensity form the label
            passed to the image-similarity loss.

    Raises:
        ValueError: If more than five steps of an epoch produce a NaN/Inf
            diffusion loss.
        RuntimeError: If the image-similarity loss collapses to ~0 (debug
            volumes are dumped first).

    NOTE(review): the source this was restored from had lost all indentation;
    block nesting was reconstructed from the statements themselves. Spots
    where more than one nesting is plausible are marked below.
    """
    if use_distributed:
        ddp_setup(rank, world_size)

    if dist.is_initialized():
        print(f"World size: {dist.get_world_size()}")
        print(f"Communication backend: {dist.get_backend()}")
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters['img_size']
    timesteps = hyp_parameters["timesteps"]
    batch_size = hyp_parameters['batchsize']
    noise_scale = hyp_parameters['noise_scale']
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims * [img_size])

    # NOTE(review): no DistributedSampler is used, so under DDP every rank
    # iterates the full dataset with its own shuffle - confirm this is intended.
    dataset = OMDataset_indiv(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

    # Paired images use half the batch size of the single-image loader.
    datasetp = OMDataset_pair(transform=None)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=batch_size // 2,
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size] * ndims,
        device=device,
        batch_size=batch_size,
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim = losses.LNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # Resume from the newest checkpoint, if any (file names start with the
    # zero-padded epoch, so a lexical sort is chronological).
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Forward the mode flag: the loader's default assumes a DDP-wrapped model.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed
        )
    else:
        initial_epoch = 0

    # Custom methods (proc_cond_img) and the state dict to save live on the
    # underlying model, not on the DDP wrapper; a plain model has no `.module`.
    core_model = Deformddpm.module if use_distributed else Deformddpm

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # Training loop
    for epoch in range(initial_epoch, hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        # Set model inside to train model
        Deformddpm.train()

        loss_nan_step = 0  # count the number of nan loss steps

        total = min(len(train_loader), len(train_loader_p))
        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):

            # ==========================================================================
            # diffusion train on single image

            [x0, embd] = batch  # for om dataset
            x0 = x0.to(device).type(torch.float32)
            # Randomly drop the text embedding so the model also trains without it.
            if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                embd = embd.to(device).type(torch.float32)
            else:
                embd = None

            n = x0.size()[0]  # batch_size -> n

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(device)

            # random deformation + rotation
            if ndims > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
            x0 = transformer(x0)
            if noise_scale > 0:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 1 * noise_scale])
                # NOTE(review): nesting ambiguous in the source - the intensity
                # jitter is applied whenever noise_scale > 0 here; confirm it
                # was not meant to sit under the random threshold branch.
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # Pick a random timestep for each image in the batch.
            t = torch.randint(0, timesteps, (n,)).to(device)

            # 'uncon' is listed three times to weight the draw towards it.
            proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
            cond_img, _, cond_ratio = core_model.proc_cond_img(x0, proc_type=proc_type)

            pre_dvf_I, dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd)  # forward diffusion process

            loss_ddf = loss_reg(pre_dvf_I, img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)

            loss_tot = LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot = loss_tot + LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            # Skip non-finite losses; give up once they keep recurring.
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                continue
            if loss_nan_step > 5:
                print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

            if step % train_mode_ratio == 0:
                # ==========================================================================
                # registration train on paired images
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)
                x1 = transformer(x1)
                y1 = transformer(y1)
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if noise_scale > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2 * noise_scale])
                    # One scale/shift draw shared by both images of the pair.
                    random_scale = np.random.normal(1, noise_scale * 1)
                    random_shift = np.random.normal(0, noise_scale * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                # 16 distinct timesteps, descending, drawn from the upper part
                # of the schedule; each is repeated once per sample in the batch.
                scale_regist = np.random.uniform(0.05, 0.7)
                T_regist = sorted(random.sample(range(int(timesteps * scale_regist), timesteps), 16), reverse=True)
                T_regist = [[t_step for _ in range(batch_size // 2)] for t_step in T_regist]

                proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'none'])
                y1_proc, _, cond_ratio = core_model.proc_cond_img(y1, proc_type=proc_type)
                [ddf_comp, _], [img_rec, _, _], _ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[], text=embd_y)  # forward diffusion process
                loss_sim = loss_imgsim(img_rec, y1, label=(y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=(y1 > 0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)

                loss_regist = LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist = loss_regist + LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist = loss_regist + LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # Check the tensors this phase actually consumes (the original
                # re-checked x0 here, which was already checked above).
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in paired input images x1/y1 at epoch {epoch}, step {step}.")

                if loss_ddf1 > 0.001:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
                optimizer.step()

                # Weighted by the single-image batch and dataset sizes, as in the original.
                epoch_loss_regist += loss_regist.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgsim += loss_sim.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)

                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
                # A similarity loss of ~0 is treated as a failure: dump the
                # volumes involved for inspection and abort.
                if loss_sim.item() > -0.001:
                    print(f"*** Zero image similarity loss at epoch {epoch}, step {step}.")
                    # NOTE(review): absolute, machine-specific dump location.
                    save_path = os.path.join('/home/data/Github/OmniMorph/Log/error_files', f"debug_images_epoch{epoch}_step{step}/")
                    os.makedirs(save_path, exist_ok=True)
                    _save_nifti_image(img_rec, os.path.join(save_path, 'img_rec.nii.gz'))
                    _save_nifti_image(x1, os.path.join(save_path, 'x1.nii.gz'))
                    _save_nifti_image(y1, os.path.join(save_path, 'y1.nii.gz'))
                    _save_nifti_image(y1_proc, os.path.join(save_path, 'y1_proc.nii.gz'))
                    # Raise instead of exit(): a clean exit of one spawned rank
                    # would leave the remaining ranks blocked in collectives.
                    raise RuntimeError(
                        f"Zero image similarity loss at epoch {epoch}, step {step}; debug volumes saved in {save_path}"
                    )

        print('==================')
        print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg, ' (ang+dist+regul)')
        print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
        print('==================')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Only one process writes the checkpoint; always save the unwrapped
            # model so the keys carry no "module." prefix.
            if (not use_distributed) or gpu_id == 0:
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': core_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
|
| 483 |
+
|
| 484 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
    """Restore model and optimizer state from a checkpoint and report the epoch to resume at.

    Args:
        gpu_id: Rank / device index of the calling process; rank 0 prints the
            pre-load memory report.
        Deformddpm: The model, either plain or wrapped in DistributedDataParallel.
        optimizer: Optimizer whose state is restored in place.
        model_file: Path to a checkpoint holding ``model_state_dict`` and
            ``optimizer_state_dict``; its file name must start with the
            six-digit epoch number.
        use_distributed: When true (and a process group exists), parameters
            are re-synchronised from rank 0 after loading.

    Returns:
        Tuple ``(initial_epoch, Deformddpm, optimizer)`` where ``initial_epoch``
        is the epoch encoded in the file name plus one.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    gc.collect()
    torch.cuda.empty_cache()
    # Deserialize onto the CPU: without map_location every rank would restore
    # the tensors onto the device they were saved from. load_state_dict then
    # copies the values onto each parameter's own device.
    checkpoint = torch.load(model_file, map_location="cpu")
    # The checkpoint stores un-prefixed keys, so load into the wrapped module
    # when there is one. Decide from the object itself rather than the flag,
    # so a plain model works even when the caller leaves the default flag.
    target = Deformddpm.module if isinstance(Deformddpm, DDP) else Deformddpm
    target.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    utils.print_memory_usage("After Loading Checkpoint on GPU")

    # Collectives are only valid once a process group exists.
    if use_distributed and dist.is_initialized():
        # Broadcast model weights from rank 0 to all other GPUs
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)  # Synchronize model across ranks
        dist.barrier()
        for param_group in optimizer.param_groups:
            for param in param_group['params']:
                if param.grad is not None:
                    dist.broadcast(param.grad, src=0)  # Sync optimizer gradients

    # The epoch is taken from the file name (first six characters), not from
    # the 'epoch' entry stored in the checkpoint.
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1

    return initial_epoch, Deformddpm, optimizer
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
if __name__ == "__main__":
    if use_distributed:
        # One worker process per visible CUDA device.
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            # mp.spawn with nprocs=0 starts nothing and returns, so the
            # script would otherwise finish without training and without error.
            raise RuntimeError(
                "use_distributed is enabled but no CUDA device is visible; "
                "set use_distributed = False or make a GPU available."
            )
        # mp.spawn passes the process index as the first argument (rank).
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
    else:
        main_train(0, 1)
|
OM_train_3modes.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gc
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torchvision.utils import save_image
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from torch.optim import Adam, SGD
|
| 10 |
+
from Diffusion.diffuser import DeformDDPM
|
| 11 |
+
from Diffusion.networks import get_net, STN
|
| 12 |
+
from torchvision.transforms import Lambda
|
| 13 |
+
import Diffusion.losses as losses
|
| 14 |
+
import random
|
| 15 |
+
import glob
|
| 16 |
+
import numpy as np
|
| 17 |
+
import utils
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 21 |
+
from Dataloader.dataLoader import *
|
| 22 |
+
|
| 23 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 24 |
+
import yaml
|
| 25 |
+
import argparse
|
| 26 |
+
|
| 27 |
+
####################
|
| 28 |
+
import torch.multiprocessing as mp
|
| 29 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 30 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 31 |
+
import torch.distributed as dist
|
| 32 |
+
# from torch.distributed import init_process_group
|
| 33 |
+
###############
|
| 34 |
+
def ddp_setup(rank, world_size):
    """
    Initialise the NCCL process group for one worker process.

    Args:
        rank: Unique identifier of each process (also used as the CUDA device index)
        world_size: Total number of processes
    """
    # Keep a rendezvous address/port supplied by the launcher or the user, so
    # two jobs on one machine can be given different ports; otherwise fall
    # back to the values this script has always used.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Bind this process to its GPU before the NCCL communicator is created,
    # matching the order used in the PyTorch DDP examples.
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 44 |
+
|
| 45 |
+
# Toggle between multi-process DistributedDataParallel training (True) and a
# single-process run (False).
use_distributed = True
# use_distributed = False

# NOTE(review): how these constants are consumed is not visible in this part
# of the file; the bracketed labels are the original author's.
EPS = 1e-5
MSK_EPS = 0.01
TEXT_EMBED_PROB = 0.7
AUG_RESAMPLE_PROB = 0.6
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 3.0]  # [ang, dist, reg]
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0]  # [imgsim, imgmse, ddf]
LOSS_WEIGHTS_REGIST = [1.0, 0.2, 1e3]  # [imgsim, imgmse, ddf]

# AUG_PERMUTE_PROB = 0.35

parser = argparse.ArgumentParser()

# NOTE(review): no Config/config_all.yaml appears among the files added by
# this commit (config_cmr / config_lct / config_om / config_om_contrastive
# are listed) - confirm the default exists or always pass --config.
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    # default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    default="Config/config_all.yaml",
    required=False,
)
# Parsed at import time; the resulting `args` is a module-level global.
args = parser.parse_args()
|
| 72 |
+
#=======================================================================================================================
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def main_train(rank=0, world_size=1, train_mode_ratio=1, thresh_imgsim=0.01):
    """Train DeformDDPM by alternating a diffusion step and a registration step.

    Every batch of single images drives one diffusion update; on every step
    whose index is a multiple of ``train_mode_ratio`` a batch of paired images
    additionally drives one registration update with the same optimizer.

    Args:
        rank: Process index. Also used as the CUDA device index in
            distributed mode; rank 0 writes the checkpoints.
        world_size: Total number of processes, forwarded to ``ddp_setup``.
        train_mode_ratio: Run the registration update every this many steps.
        thresh_imgsim: Target-image intensity above which a voxel is included
            in the mask of the image-similarity loss.

    Raises:
        ValueError: If more than five non-finite diffusion losses are seen
            within a single epoch.

    NOTE(review): the source this was recovered from had lost its indentation;
    nesting was reconstructed from syntax. The one spot where two readings are
    plausible is marked below.
    """
    if use_distributed:
        ddp_setup(rank, world_size)

    if dist.is_initialized():
        print(f"World size: {dist.get_world_size()}")
        print(f"Communication backend: {dist.get_backend()}")
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    noise_scale = hyp_parameters["noise_scale"]
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims * [img_size])

    # NOTE(review): DistributedSampler is imported by this file but not used,
    # so in distributed mode every rank iterates the full dataset with its
    # own shuffle - confirm this is intended.
    dataset = OMDataset_indiv(transform=None)
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    datasetp = OMDataset_pair(transform=None)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=hyp_parameters['batchsize'] // 2,
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size] * ndims,
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.2, outrange_weight=1e2)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.6, outrange_weight=1e2)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim = losses.LNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # Resume from the newest checkpoint in model_dir, if any.
    if not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)
    model_files = glob.glob(os.path.join(model_dir, "*.pth"))
    model_files.sort()
    if model_files:
        if gpu_id == 0:
            print(model_files)
        # Forward the mode flag: the helper defaults to use_distributed=True
        # and would dereference `.module` on a model that is not DDP-wrapped.
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed
        )
    else:
        initial_epoch = 0

    # The underlying DeformDDPM: the DDP wrapper only forwards __call__, so
    # helper methods and the state dict are taken from the wrapped module.
    core_model = Deformddpm.module if use_distributed else Deformddpm

    # Keep one iterator over the paired loader alive across steps instead of
    # rebuilding (and reshuffling) the loader for every registration update.
    pair_iter = iter(train_loader_p)

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # Training loop
    for epoch in range(initial_epoch, hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        Deformddpm.train()

        loss_nan_step = 0  # number of non-finite diffusion losses this epoch

        for step, batch in tqdm(enumerate(train_loader)):

            # ==================================================================
            # diffusion train on single image
            [x0, embd] = batch
            x0 = x0.to(device).type(torch.float32)
            # Randomly drop the text embedding so the model also trains
            # without text input.
            if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                embd = embd.to(device).type(torch.float32)
            else:
                embd = None

            n = x0.size()[0]  # batch size

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(device)

            # random resampling or axis permutation (volumes only)
            if ndims > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
            x0 = transformer(x0)
            if noise_scale > 0:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 1 * noise_scale])
                # NOTE(review): indentation was lost in the source; the
                # intensity jitter is placed at the `noise_scale > 0` level
                # (as in the registration branch) - confirm it was not meant
                # to sit inside the random threshold branch above.
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # one random timestep per image in the batch
            t = torch.randint(0, timesteps, (n,)).to(device)

            proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
            cond_img, _, cond_ratio = core_model.proc_cond_img(x0, proc_type=proc_type)

            pre_dvf_I, dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd)

            loss_ddf = loss_reg(pre_dvf_I, img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)

            loss_tot = LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Abort as soon as the limit is exceeded; previously this
                # check followed `continue` and so ran late or not at all.
                if loss_nan_step > 5:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                # NOTE(review): under DDP a rank that skips backward() while
                # the others do not will desynchronise gradient all-reduce;
                # consider agreeing on the skip across ranks.
                continue

            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

            if step % train_mode_ratio == 0:
                # ==============================================================
                # registration train on paired images
                try:
                    [x1, y1, _, embd_y] = next(pair_iter)
                except StopIteration:
                    # paired loader exhausted: start a new (reshuffled) pass
                    pair_iter = iter(train_loader_p)
                    [x1, y1, _, embd_y] = next(pair_iter)

                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)

                x1 = transformer(x1)
                y1 = transformer(y1)
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if noise_scale > 0:
                    x1 = thresh_img(x1, [0, 2 * noise_scale])
                    x1 = x1 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)
                    y1 = thresh_img(y1, [0, 2 * noise_scale])
                    y1 = y1 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

                # Descending schedule of up to 16 distinct timesteps drawn
                # from [0, timesteps * scale_regist]; capped at the pool size
                # so short schedules do not make random.sample raise.
                scale_regist = np.random.uniform(0.6, 1.)
                t_pool = range(0, int(timesteps * scale_regist) + 1)
                T_regist = sorted(random.sample(t_pool, min(16, len(t_pool))), reverse=True)
                # repeat every timestep once per sample of the paired batch
                T_regist = [[t_i for _ in range(hyp_parameters["batchsize"] // 2)] for t_i in T_regist]

                proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'none'])
                y1, msk_tgt, cond_ratio = core_model.proc_cond_img(y1, proc_type=proc_type)
                msk_tgt = msk_tgt + MSK_EPS
                [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], _ = Deformddpm(img_org=x1, cond_imgs=y1, T=[None, T_regist], proc_type=[], text=embd_y)
                loss_ddf1 = loss_reg1(ddf_comp, img=y1, msk=msk_tgt)
                loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt * (y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt * (y1 > 0.0))

                loss_regist = LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # Check the tensors this branch actually consumes (the
                # original re-checked x0 from the diffusion branch here).
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in paired input images x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1 > 0.001:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
                optimizer.zero_grad()
                loss_regist.backward()

                torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
                optimizer.step()

                # Logged with the same normalisation as the diffusion losses
                # (single-image batch size over single-image dataset size).
                epoch_loss_regist += loss_regist.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgsim += loss_sim.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
                epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)

        # per-epoch summary (printed by every process)
        print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg, ' (ang+dist+regul)')
        print(f'      loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')

        if 0 == epoch % epoch_per_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            # Single process always saves; in distributed mode only rank 0.
            if not use_distributed or gpu_id == 0:
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': core_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch
                }, save_dir)

    # Resource cleanup at the end of training
    torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
|
| 445 |
+
|
| 446 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file, use_distributed=True):
    """Restore model and optimizer state from a checkpoint file.

    Every process loads the checkpoint itself, so model weights and optimizer
    state are restored identically on all ranks.

    Args:
        gpu_id: Rank of the calling process; rank 0 prints memory usage.
        Deformddpm: The model; expected to be DDP-wrapped when
            ``use_distributed`` is true (its ``.module`` is loaded into).
        optimizer: Optimizer whose state is restored in place.
        model_file: Path of a checkpoint holding ``'model_state_dict'`` and
            ``'optimizer_state_dict'``; its file name must start with the
            zero-padded six-digit epoch number.
        use_distributed: Whether ``Deformddpm`` is DDP-wrapped.

    Returns:
        Tuple ``(initial_epoch, Deformddpm, optimizer)`` where
        ``initial_epoch`` is the checkpoint's epoch plus one.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    gc.collect()
    torch.cuda.empty_cache()

    # Load onto the CPU: without map_location the tensors are restored onto
    # the device they were saved from, so every rank would allocate on that
    # one GPU. load_state_dict then copies into each rank's own parameters.
    checkpoint = torch.load(model_file, map_location="cpu")
    target = Deformddpm.module if use_distributed else Deformddpm
    target.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    del checkpoint

    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    # The former broadcast of `param.grad` was dropped: gradients are None
    # right after loading, so it never transferred any optimizer state.
    if use_distributed and dist.is_initialized():
        # make sure every rank has finished loading before training resumes
        dist.barrier()

    # get the epoch number from the filename and add 1 to set as initial_epoch
    initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1

    return initial_epoch, Deformddpm, optimizer
|
| 481 |
+
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
if __name__ == "__main__":
    if use_distributed:
        # one training process per visible GPU
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        if world_size < 1:
            # With no visible GPU, spawning zero processes would exit
            # silently without training anything.
            raise RuntimeError("use_distributed is enabled but no CUDA device is visible.")
        # mp.spawn passes the process index as the first argument (rank)
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)
    else:
        main_train(0, 1)
|
OM_train_uncon.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchvision.utils import save_image
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torch.optim import Adam, SGD
|
| 8 |
+
from Diffusion.diffuser import DeformDDPM
|
| 9 |
+
from Diffusion.networks import get_net, STN
|
| 10 |
+
from torchvision.transforms import Lambda
|
| 11 |
+
import Diffusion.losses as losses
|
| 12 |
+
import random
|
| 13 |
+
import glob
|
| 14 |
+
import numpy as np
|
| 15 |
+
import utils
|
| 16 |
+
|
| 17 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 18 |
+
|
| 19 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 20 |
+
import yaml
|
| 21 |
+
import argparse
|
| 22 |
+
|
| 23 |
+
####################
|
| 24 |
+
import torch.multiprocessing as mp
|
| 25 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 26 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 27 |
+
from torch.distributed import init_process_group, destroy_process_group
|
| 28 |
+
###############
|
| 29 |
+
def ddp_setup(rank, world_size):
    """Initialise the NCCL process group for one training process.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index for this process.
        world_size: Total number of processes.
    """
    # Only fill in the rendezvous address/port when the caller has not
    # already provided them, so several jobs can share a host by exporting
    # a different MASTER_PORT instead of colliding on 12355.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    # Bind this process to its GPU before creating the NCCL group so the
    # communicator is set up on the intended device.
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 39 |
+
|
| 40 |
+
# Mode flags. NOTE(review): main_train in this script always calls ddp_setup
# and wraps the model in DDP; `use_distributed` is not consulted by it.
use_parallel = False
use_distributed = False

EPS = 1e-5

# Command-line interface: only the path of the YAML config is configurable.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    "-C",
    help="Path for the config file",
    type=str,
    default="Config/config_cmr.yaml",
    # default="Config/config_lct.yaml",
    required=False,
)
args = parser.parse_args()
|
| 58 |
+
#=======================================================================================================================
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def main_train(rank, world_size):
    """Train the unconditional DeformDDPM with DistributedDataParallel.

    Args:
        rank: Process index; also used as the CUDA device index. Rank 0
            prints the epoch summary and writes checkpoints.
        world_size: Total number of processes, forwarded to ``ddp_setup``.

    NOTE(review): the source this was recovered from had lost its
    indentation; nesting was reconstructed from syntax.
    """
    ddp_setup(rank, world_size)
    gpu_id = rank

    # Load the YAML file into a dictionary
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    print(hyp_parameters)

    device = hyp_parameters["device"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters["img_size"]
    timesteps = hyp_parameters["timesteps"]
    noise_scale = hyp_parameters["noise_scale"]
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']

    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path
    transformer = utils.get_transformer(img_sz=ndims * [img_size])
    Data_Loader = get_dataloader(data_name=data_name, mode='train')

    dataset = Data_Loader(target_res=[img_size] * ndims, transforms=None, noise_scale=noise_scale)
    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=True,
        drop_last=True,
    )

    Deformddpm = DeformDDPM(
        network=Net(n_steps=timesteps, ndims=ndims, num_input_chn=1),
        n_steps=timesteps,
        image_chw=[1] + [img_size] * ndims,
        device=device,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )

    Deformddpm.to(rank)
    Deformddpm = DDP(Deformddpm, device_ids=[rank])
    ddf_stn.to(rank)

    loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=ndims)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # check for existing models
    os.makedirs(model_dir, exist_ok=True)
    model_files = glob.glob(os.path.join(model_dir, "*.pth"))
    model_files.sort()
    if gpu_id == 0:
        print(model_files)
    if model_files:
        # if there are any model files, load the most recent one
        latest_model_file = model_files[-1]
        # Checkpoints are written from `Deformddpm.module.state_dict()`, so
        # they must be loaded into `.module` as well. Loading them into the
        # DDP wrapper with strict=False mismatches every key (no "module."
        # prefix) and silently restores nothing.
        Deformddpm.module.load_state_dict(
            torch.load(latest_model_file, map_location="cpu"), strict=False
        )
        # get the epoch number from the filename and add 1 to set as initial_epoch
        initial_epoch = int(os.path.basename(latest_model_file).split('.')[0][:6]) + 1
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    for epoch in range(initial_epoch, hyp_parameters["epoch"]):

        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        Deformddpm.train()

        for step, batch in enumerate(train_loader):
            x0, _, _ = batch
            x0 = x0.to(device).type(torch.float32)

            n = x0.size()[0]  # batch size
            # random resampling (volumes only)
            if ndims > 2:
                if np.random.uniform(0, 1) < 0.6:
                    x0 = utils.random_resample(x0, deform_scale=0)
            x0 = transformer(x0)
            if noise_scale > 0:
                x0 = thresh_img(x0, [0, 2 * noise_scale])
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            # one random timestep per image in the batch
            t = torch.randint(0, timesteps, (n,)).to(device)

            # Always call through the DDP wrapper so gradients are
            # synchronised. The former `use_parallel` branch bypassed it and
            # left `dvf_I` undefined for the loss computation below.
            pre_dvf_I, dvf_I = Deformddpm(x0, t)

            loss_ddf = loss_reg(pre_dvf_I)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None)
            loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None)

            loss_tot = 1.0 * loss_gen_d + 1.0 * loss_gen_a
            loss_tot += 10 * loss_ddf
            optimizer.zero_grad()
            loss_tot.backward()
            optimizer.step()

            epoch_loss_tot += loss_tot.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_d += loss_gen_d.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_gen_a += loss_gen_a.item() * len(x0) / len(train_loader.dataset)
            epoch_loss_reg += loss_ddf.item() * len(x0) / len(train_loader.dataset)

        if gpu_id == 0:
            print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg, ' (ang+dist+regul)')

        if 0 == epoch % epoch_per_save and gpu_id == 0:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            # exist_ok avoids failing when the directory is already there
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            torch.save(Deformddpm.module.state_dict(), save_dir)
            # report only after the file has actually been written
            print(f"saved in {save_dir}")

    # release the process group created by ddp_setup
    destroy_process_group()
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
    # one training process per visible GPU
    world_size = torch.cuda.device_count()
    print(f"world size = {world_size}")
    if world_size < 1:
        # With no visible GPU, spawning zero processes would exit silently
        # without training anything.
        raise RuntimeError("No CUDA device is visible; distributed training cannot start.")
    # mp.spawn passes the process index as the first argument (rank)
    mp.spawn(main_train, args=(world_size,), nprocs=world_size)
|
README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OmniMorph: Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on conditional Deformation-Recovery Diffusion Model
|
| 2 |
+
|
| 3 |
+
## Environment
|
| 4 |
+
```
|
| 5 |
+
conda activate torch
|
| 6 |
+
conda deactivate
|
| 7 |
+
```
|
| 8 |
+
source /home/data/Github/OmniMorph/ominenv/bin/activate
|
| 9 |
+
|
| 10 |
+
## Masking CUDA
|
| 11 |
+
CUDA_VISIBLE_DEVICES=0,1,3 python ...
|
bash_infer.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Activate the Python virtual environment used for inference.
source /home/data/jzheng/Adaptive_Motion_Generator-master/pipenv/bin/activate

# Restrict the run to a single GPU (alternative kept for reference).
export CUDA_VISIBLE_DEVICES=2
# export CUDA_VISIBLE_DEVICES=0

# Foreground variants, kept for reference:
# python -u OM_aug.py -C Config/config_om.yaml
# python -u OM_reg.py -C Config/config_om.yaml

# Run the augmentation script detached from the terminal; stdout and stderr
# are both written to aug_log.txt.
nohup python -u OM_aug.py -C Config/config_om.yaml > aug_log.txt 2>&1 &
|
bash_train.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Activate the Python virtual environment used for training.
source /home/data/jzheng/Adaptive_Motion_Generator-master/pipenv/bin/activate

# Select the GPU(s) visible to the training process (alternatives kept for reference).
export CUDA_VISIBLE_DEVICES=3
# export CUDA_VISIBLE_DEVICES=1,3
# export CUDA_VISIBLE_DEVICES=1,2,3

# Alternative launches, kept for reference:
# # python -u OM_train.py -C Config/config_lct.yaml
# nohup python -u OM_train.py -C Config/config_lct.yaml > train_log.txt 2>&1 &
# python -u OM_train_2modes.py -C Config/config_om.yaml
# nohup python -u OM_train.py -C Config/config_om.yaml > train_log.txt 2>&1 &

# Run the two-mode training script detached from the terminal; stdout and
# stderr are both written to train_log.txt.
nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
|
dataloader_tester.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torchvision.utils import save_image
|
| 6 |
+
from torch.utils.data import DataLoader
|
| 7 |
+
from torch.optim import Adam, SGD
|
| 8 |
+
from Diffusion.diffuser import DeformDDPM
|
| 9 |
+
from Diffusion.networks import get_net, STN
|
| 10 |
+
from torchvision.transforms import Lambda
|
| 11 |
+
import Diffusion.losses as losses
|
| 12 |
+
import random
|
| 13 |
+
import glob
|
| 14 |
+
import numpy as np
|
| 15 |
+
import utils
|
| 16 |
+
|
| 17 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 18 |
+
from Dataloader.dataLoader import *
|
| 19 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 20 |
+
import yaml
|
| 21 |
+
import argparse
|
| 22 |
+
|
| 23 |
+
# Smoke-test script: builds the project's datasets/loaders and prints the
# shape, dtype and value range of the first batch of the last loader.

# Kept for parity with the training scripts; not passed to any dataset below.
tsfm = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Dataset class looked up by name through the legacy loader factory.
Data_Loader = get_dataloader(data_name='lct', mode='train')
dataset = Data_Loader(target_res=[128] * 3, transforms=None, noise_scale=4.0e-05)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

# NOTE(review): OminiDataset_v1 / OminiDataset_paired presumably come from the
# star import of Dataloader.dataLoader -- confirm.
dataset2 = OminiDataset_v1(transform=None)
train_loader2 = DataLoader(dataset2, batch_size=32, shuffle=True)

# `dataset` and `train_loader` are rebound here; only this paired loader is iterated.
dataset = OminiDataset_paired(transform=None, ROIs=['leg'])
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Other inspection calls that were tried on the dataset:
# print(dataset.get_all_ROI())
# print(dataset.getitem())
# print(dataset.get_ALLdata())
# print(dataset.getitem(idx=11))
# exit()

# Inspect only the first batch.
for i, batch in enumerate(train_loader):
    x0, x1 = batch
    print(x0.shape, x1.shape)
    print(x0.dtype, x1.dtype)
    print(x0.min(), x0.max())
    break
exit()
|
| 65 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
certifi==2022.12.7
|
| 2 |
+
charset-normalizer==2.1.1
|
| 3 |
+
contourpy==1.1.1
|
| 4 |
+
cycler==0.12.1
|
| 5 |
+
einops==0.3.2
|
| 6 |
+
elasticdeform==0.5.0
|
| 7 |
+
filelock==3.16.1
|
| 8 |
+
fonttools==4.49.0
|
| 9 |
+
fsspec==2025.3.0
|
| 10 |
+
hausdorff==0.2.6
|
| 11 |
+
huggingface-hub==0.29.3
|
| 12 |
+
idna==3.4
|
| 13 |
+
imageio==2.34.0
|
| 14 |
+
importlib_metadata==7.1.0
|
| 15 |
+
importlib_resources==6.1.2
|
| 16 |
+
joblib==1.4.0
|
| 17 |
+
kiwisolver==1.4.5
|
| 18 |
+
lazy_loader==0.3
|
| 19 |
+
llvmlite==0.41.1
|
| 20 |
+
matplotlib==3.7.5
|
| 21 |
+
networkx==3.1
|
| 22 |
+
nibabel==5.1.0
|
| 23 |
+
nptyping==2.5.0
|
| 24 |
+
numba==0.58.1
|
| 25 |
+
numpy==1.24.1
|
| 26 |
+
opencv-python==4.9.0.80
|
| 27 |
+
packaging==23.2
|
| 28 |
+
pandas==2.0.3
|
| 29 |
+
pillow==10.2.0
|
| 30 |
+
pydicom==2.4.4
|
| 31 |
+
pynrrd==1.0.0
|
| 32 |
+
pyparsing==3.1.1
|
| 33 |
+
pyquaternion==0.9.9
|
| 34 |
+
python-dateutil==2.8.2
|
| 35 |
+
pytz==2025.2
|
| 36 |
+
PyWavelets==1.4.1
|
| 37 |
+
PyYAML==6.0.2
|
| 38 |
+
regex==2024.11.6
|
| 39 |
+
requests==2.28.1
|
| 40 |
+
safetensors==0.5.3
|
| 41 |
+
scikit-image==0.21.0
|
| 42 |
+
scikit-learn==1.3.2
|
| 43 |
+
scipy==1.9.3
|
| 44 |
+
SimpleITK==2.3.1
|
| 45 |
+
six==1.16.0
|
| 46 |
+
threadpoolctl==3.5.0
|
| 47 |
+
tifffile==2023.7.10
|
| 48 |
+
tokenizers==0.20.3
|
| 49 |
+
torch==1.12.1+cu113
|
| 50 |
+
torchaudio==0.12.1+cu113
|
| 51 |
+
torchvision==0.13.1+cu113
|
| 52 |
+
tqdm==4.66.2
|
| 53 |
+
transformers==4.46.3
|
| 54 |
+
typing_extensions==4.8.0
|
| 55 |
+
tzdata==2025.2
|
| 56 |
+
urllib3==1.26.13
|
| 57 |
+
zipp==3.17.0
|
utils.py
ADDED
|
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import torchvision
|
| 4 |
+
from torch import nn, optim
|
| 5 |
+
from torch.autograd.variable import Variable
|
| 6 |
+
from torchvision import transforms, datasets
|
| 7 |
+
from torchvision.utils import save_image
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import scipy.ndimage as spimg
|
| 10 |
+
import pyquaternion as quater
|
| 11 |
+
import random
|
| 12 |
+
import numpy as np
|
| 13 |
+
import math
|
| 14 |
+
from typing import Optional, Tuple, List
|
| 15 |
+
import nibabel as nib
|
| 16 |
+
# from data_loader.acdc_dataloader import acdc_gan
|
| 17 |
+
|
| 18 |
+
# from Adaptive_Motion_Generator.Dataloader.Archive.acdc_dataloader import *
|
| 19 |
+
|
| 20 |
+
def get_barcode(index=(), header=('Patient', 'Slice', 'AugImg', 'NoiseStep'), digit=(4, 6, 4, 4), split='_'):
    """Build an identifier string such as ``Patient0001_Slice000002_AugImg0003_NoiseStep0070``.

    Each field is ``header[k] + str(index[k]).zfill(digit[k])`` and the fields
    are joined with ``split``.

    When fewer than three indices are given, the third and fourth headers
    become ``'ORG'`` and ``'NA'`` with no digits, and the missing indices are
    filled with empty strings.

    Args:
        index: Sequence of per-field indices. It is never modified.
        header: Sequence of field name prefixes (at least four entries are
            needed when ``index`` has fewer than three items).
        digit: Zero-padding width of each field.
        split: Separator placed between fields.

    Returns:
        str: The assembled barcode.

    Raises:
        IndexError: If ``index`` has three or more items but fewer than ``header``.
    """
    # Work on private copies: the previous in-place ``index += [...]`` leaked
    # into the caller's list (and into the shared default argument).
    index = list(index)
    header = list(header)
    digit = list(digit)
    if len(index) < 3:
        header[2] = 'ORG'
        header[3] = 'NA'
        digit[2] = 0
        digit[3] = 0
        index += [''] * max(0, len(header) - len(index))

    fields = [h + str(index[pos]).zfill(digit[pos]) for pos, h in enumerate(header)]
    # join() is correct for any separator length, unlike stripping one trailing character.
    return split.join(fields)
|
| 35 |
+
|
| 36 |
+
class RandomResizedCrop3D(nn.Module):
    """Randomly crop a sub-volume of a batched tensor and resize it to a fixed size.

    The input is indexed as ``img[:, :, d, h, w]``, i.e. two leading
    (batch, channel) axes followed by three spatial axes.

    Args:
        size (tuple of int): Output size ``(D, H, W)`` after resizing.
        scale (tuple of float): Lower and upper bound of the random factor that
            multiplies every spatial extent to obtain the crop extent.
        ratio (tuple of float): Stored and shown in ``repr`` but not used when
            sampling the crop.
        interpolation (str): Mode forwarded to ``F.interpolate``
            (``'trilinear'`` or ``'nearest'``).
    """

    def __init__(
            self,
            size: Tuple[int, int, int],
            scale=(0.6, 1.0),
            ratio=(0.5, 1.5),
            interpolation='trilinear'
    ):
        super().__init__()
        self.size, self.scale = size, scale
        self.ratio, self.interpolation = ratio, interpolation

    @staticmethod
    def get_params(img: torch.Tensor, rand_scale: float, scale: List[float], ratio: List[float]) -> Tuple[int, int, int, int, int, int]:
        """Sample the origin and extent of a crop.

        Args:
            img (Tensor): Input whose axes from index 2 onwards are spatial.
            rand_scale (float): Factor applied to each spatial extent.
            scale (list): Unused here; kept for signature compatibility.
            ratio (list): Unused here; kept for signature compatibility.

        Returns:
            list: ``[i, j, k, d, h, w]`` -- crop start indices followed by crop extents.
        """
        spatial_extent = np.array(list(img.size())[2:])
        crop_extent = (spatial_extent * rand_scale).astype(np.int32)
        # Any origin that keeps the crop fully inside the volume is allowed.
        origin = np.random.randint(0, spatial_extent - crop_extent + 1, size=(spatial_extent.size,))
        return origin.tolist() + crop_extent.tolist()

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Crop a random sub-volume and resize it to ``self.size``.

        Args:
            img (Tensor): Batched 3D input indexed as ``[:, :, d, h, w]``.

        Returns:
            Tensor: Cropped and resized volume (batch/channel axes are kept).
        """
        lo, hi = self.scale[0], self.scale[1]
        rand_scale = np.random.uniform(lo, hi)
        i, j, k, d, h, w = self.get_params(img, rand_scale, self.scale, self.ratio)
        patch = img[:, :, i:i + d, j:j + h, k:k + w]
        # align_corners is only meaningful for the interpolating mode.
        corners = False if self.interpolation == 'trilinear' else None
        return F.interpolate(patch, size=self.size, mode=self.interpolation, align_corners=corners)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, scale={self.scale}, ratio={self.ratio}, interpolation={self.interpolation})"
|
| 117 |
+
|
| 118 |
+
def random_permute(X, select_dims=[-1,-2],include_flip=True):
    """Apply the same random axis shuffle (and optional flips) to every tensor in ``X``.

    Args:
        X: List of torch tensors; the axis count is taken from ``X[0]``.
        select_dims: Axes that may be exchanged with each other (and flipped).
        include_flip: When True, each selected axis is flipped with probability 0.5.

    Returns:
        list: The transformed tensors, in the same order as ``X``.
    """
    order = list(range(X[0].ndim))
    shuffled = [order[d] for d in select_dims]
    random.shuffle(shuffled)
    for dim, source_axis in zip(select_dims, shuffled):
        order[dim] = source_axis
        # The coin is only tossed when flipping is enabled.
        if include_flip and random.choice([True, False]):
            X = [torch.flip(item, [dim]) for item in X]
    return [item.permute(order) for item in X]
|
| 129 |
+
|
| 130 |
+
# def thresh_img(img,thresh = None,EPS = 10**-7):
|
| 131 |
+
# threshold0 = np.random.uniform(thresh[0], thresh[1])
|
| 132 |
+
# threshold1 = np.random.uniform(thresh[0], thresh[1])
|
| 133 |
+
# scale =
|
| 134 |
+
# if threshold is not None:
|
| 135 |
+
# # img=img-threshold
|
| 136 |
+
# # img=np.where(img>=0,img,0)
|
| 137 |
+
# # img = np.maximum(img-threshold,0)
|
| 138 |
+
# img = torch.maximum(img - threshold,torch.tensor(0.))
|
| 139 |
+
# # return (img - img.min()) / (img.max() - img.min() + EPS)
|
| 140 |
+
# return img
|
| 141 |
+
|
| 142 |
+
def get_transformer(degrees=180,translate=0.125,ndims=2,prob=0.8,fill=0.,img_sz=None):
    """Build the random augmentation pipeline.

    Args:
        degrees: Rotation range forwarded to ``RandomAffine``.
        translate: Maximum translation fraction, repeated ``ndims`` times.
        ndims: Number of entries in the ``translate`` list.
        prob: Probability of applying the random affine transform.
        fill: Fill value for areas outside the transformed image.
        img_sz: ``None`` or a length-2 size selects the 2D pipeline; any other
            size selects the 3D random-resized-crop pipeline with that output size.

    Returns:
        torchvision.transforms.Compose: The composed transform.
    """
    # Identity test instead of ``== None``: an ndarray ``img_sz`` would turn the
    # equality into an element-wise comparison whose truth value is ambiguous.
    if img_sz is None or len(img_sz) == 2:
        affine = torchvision.transforms.RandomAffine(
            degrees=degrees, translate=[translate] * ndims, fill=fill,
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        return torchvision.transforms.Compose([
            torchvision.transforms.RandomApply([affine], prob),
            torchvision.transforms.RandomVerticalFlip(p=0.5),
            torchvision.transforms.RandomAutocontrast(p=0.5),
        ])

    # Only non-2D sizes reach this point, so the crop is always the 3D variant
    # (the former 2D alternative in this branch could never be selected).
    prob_crop = 0.8
    return torchvision.transforms.Compose([
        torchvision.transforms.RandomApply([RandomResizedCrop3D(size=img_sz)], prob_crop),
    ])
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def get_random_affine_transformer(degrees=180,translate=0.125,ndims=2):
    """Return a bilinear ``RandomAffine`` transform.

    Args:
        degrees: Rotation range forwarded to ``RandomAffine``.
        translate: Maximum translation fraction, repeated ``ndims`` times.
        ndims: Number of entries in the ``translate`` list.
    """
    max_shift = [translate] * ndims
    bilinear = torchvision.transforms.InterpolationMode.BILINEAR
    return torchvision.transforms.RandomAffine(degrees=degrees, translate=max_shift, interpolation=bilinear)
|
| 169 |
+
|
| 170 |
+
def channel_merge_acdc(img):
    """Sum all channels of a ``(C, H, W)`` image into one ``(H, W)`` map.

    Args:
        img: Array-like indexed as ``img[channel]`` with a 3-entry ``shape``.

    Returns:
        The element-wise sum over the first axis, accumulated onto a NumPy
        zero array of shape ``(H, W)``.
    """
    n_channels, height, width = img.shape[0], img.shape[1], img.shape[2]
    merged = np.zeros((height, width))
    channel = 0
    while channel < n_channels:
        merged = merged + img[channel]
        channel += 1
    return merged
|
| 178 |
+
|
| 179 |
+
def img_crop(img, crop_rate=2, img_sz=[256,256]):
    """Cut a randomly positioned window out of the trailing spatial axes of ``img``.

    Along each axis of extent ``s`` a total of ``s // crop_rate`` voxels is
    removed, split at random between the leading and the trailing side.

    Args:
        img: Array/tensor whose last 2 (or 3) axes are cropped.
        crop_rate: Divisor that determines how much of each axis is removed.
        img_sz: Spatial extents used for the computation; two entries select
            2D cropping, otherwise the first three entries are used.

    Returns:
        The cropped view of ``img``.
    """
    # One random draw per axis, in axis order.
    leading = [np.random.randint(0.*extent, 1. * extent)//crop_rate for extent in img_sz]
    trailing = [1 * extent//crop_rate - cut for extent, cut in zip(img_sz, leading)]
    window = [slice(lo, extent - hi) for lo, extent, hi in zip(leading, img_sz, trailing)]
    if len(img_sz) == 2:
        return img[..., window[0], window[1]]
    return img[..., window[0], window[1], window[2]]
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def boundary_limit(sample_coords0, max_sz, plus=0., minus=1.):
    """Scale each coordinate channel by its size and clamp it to a symmetric range.

    Channel ``c`` (taken along dim 1) is multiplied by ``max_sz[c]`` and
    clamped to ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]``.

    Args:
        sample_coords0: Tensor whose dim 1 enumerates coordinate channels.
        max_sz: Per-channel scale; channels beyond ``len(max_sz)`` are dropped.
        plus: Offset added to both clamp bounds.
        minus: Margin subtracted from the upper bound and added to the lower bound.

    Returns:
        Tensor: The clamped channels concatenated along dim 1.
    """
    channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
    limited = []
    for channel, sz in zip(channels, max_sz):
        lower = minus - 1 * sz + plus
        upper = 1 * sz - minus + plus
        limited.append(torch.clamp(channel * sz, min=lower, max=upper))
    return torch.cat(limited, 1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def resample(vol, ddf, ref=None, img_sz=None,max_sz=[128,128],ndims=2):
    """Warp ``vol`` with a dense displacement field given in voxel units.

    The displacement is added to an identity voxel grid, normalised to the
    ``[-1, 1]`` range used by ``F.grid_sample`` (``align_corners=True``), and
    its last axis is reversed before sampling.
    Samples falling outside the volume read as zero.

    Args:
        vol: Tensor of shape ``(N, C, *spatial)`` with 2 or 3 spatial axes.
        ddf: Displacement in voxels, channels-last: ``(N, *spatial, ndims)``
            (or broadcastable to it), component ``k`` acting on spatial axis ``k``.
        ref, img_sz, max_sz, ndims: Unused; the spatial size and rank are
            always derived from ``vol``. Kept for signature compatibility.

    Returns:
        Tensor: The resampled volume, same shape as ``vol``.

    Raises:
        ValueError: If ``vol`` does not have 2 or 3 spatial axes.
    """
    device = vol.device
    spatial = tuple(vol.size()[2:])
    ndims = len(spatial)
    if ndims not in (2, 3):
        raise ValueError(f"resample expects 2 or 3 spatial dimensions, got {ndims}")

    # Half extents map voxel indices 0..s-1 onto [-1, 1]; shaped to broadcast
    # against a channels-last grid.
    half_extent = torch.reshape(
        torch.tensor([(s - 1) / 2. for s in spatial], device=device),
        [1] * (ndims + 1) + [ndims])
    # Explicit 'ij' indexing keeps the historical behaviour and avoids the
    # deprecation warning emitted when `indexing` is omitted.
    axes = [torch.arange(end=s) for s in spatial]
    ref_grid = torch.stack(torch.meshgrid(*axes, indexing='ij'), 0).unsqueeze(0)

    to_channels_last = [0] + list(range(2, ndims + 2)) + [1]
    grid = (ddf + ref_grid.permute(to_channels_last).to(device)) / half_extent - 1
    # grid_sample wants the last spatial axis first (x, y[, z]).
    grid = torch.flip(grid, dims=[-1]).type(torch.float32).to(device)
    return F.grid_sample(vol, grid, mode='bilinear', padding_mode="zeros", align_corners=True)
|
| 229 |
+
|
| 230 |
+
def random_resample(vol,deform_scale=32.):
    """Warp ``vol`` with the third output of ``random_ddf``.

    For 2D inputs a helper grid with an extra axis of 16 slices is generated
    and slice 8 of it is used.

    Args:
        vol: Tensor of shape ``(N, C, *spatial)`` with 2 or 3 spatial axes.
        deform_scale: Forwarded to ``random_ddf`` as ``range_gauss``.

    Returns:
        Tensor: The output of ``resample`` applied to ``vol``.
    """
    full_shape = vol.size()
    target_device = vol.device
    ndims = len(full_shape) - 2
    grid_shape = list(full_shape[2:])
    if ndims == 2:
        # random_ddf works on 3D grids; pad a dummy third axis.
        grid_shape = grid_shape + [16]
    # Only the third returned field is used here.
    _, _, ddf = random_ddf(full_shape[0], grid_shape, ndims=ndims, range_gauss=deform_scale)
    ddf = Variable(torch.tensor(ddf, dtype=torch.float32)).to(target_device)
    if ndims == 2:
        return resample(vol, ddf[..., 8, :ndims])
    return resample(vol, ddf[..., :ndims])
|
| 244 |
+
|
| 245 |
+
def get_random_deformed_mask(msk_shape, deform_scale=32.,apply_possibility=0.75):
    """Create an all-ones mask and, with some probability, warp it randomly.

    Args:
        msk_shape: Spatial shape of the mask.
        deform_scale: Forwarded to ``random_resample``.
        apply_possibility: Probability of warping the mask.

    Returns:
        Tensor: Float32 mask of shape ``(1, 1, *msk_shape)``.
    """
    msk = torch.ones([1, 1] + list(msk_shape), dtype=torch.float32)
    if not random.uniform(0, 1) < apply_possibility:
        return msk
    return random_resample(msk, deform_scale=deform_scale)
|
| 251 |
+
|
| 252 |
+
# grid option
|
| 253 |
+
def get_tranf_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]],transl=[[0,0,0]]):
    """Concatenate flattened rotation matrices with translations.

    Args:
        grid_size: Forwarded to ``get_rot_mat``.
        vec: Per-sample rotation axes.
        ang: Per-sample rotation angles.
        transl: Per-sample translations, appended along the last axis.

    Returns:
        ndarray: ``get_rot_mat(...)`` and ``transl`` joined along the last axis.
    """
    rotation = get_rot_mat(grid_size, vec=vec, ang=ang)
    return np.concatenate([rotation, transl], -1)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def get_rot_mat(grid_size, vec=[[0., 0., 1.]], ang=[[0.]],ndims=3):
    """Return one flattened rotation matrix per sample.

    Args:
        grid_size: Unused; kept for signature compatibility.
        vec: Per-sample rotation axes.
        ang: Per-sample rotation angles; its first axis gives the batch size.
        ndims: Side length of the matrix (each row has ``ndims * ndims`` entries).

    Returns:
        ndarray: Array of shape ``(batch, ndims * ndims)``.
    """
    axes = np.array(vec)
    angles = np.array(ang)
    n_samples = angles.shape[0]
    matrices = vecang2rotmats(axes, angles)
    return np.reshape(matrices, [n_samples] + [ndims*(ndims)])
|
| 262 |
+
|
| 263 |
+
def random_mat(batch_sz, img_sz, num_class=2,pn_spline=20, pn_gauss=10, range_spline=2., range_gauss=48, spread_range=[5., 24.],
               transl_range=32., rot_range=np.pi / 2):
    """Sample random rigid transforms, one per (sample, class) pair.

    Each quantity is a draw shared by all classes of a sample plus a smaller
    per-class perturbation.

    Args:
        batch_sz: Number of samples.
        img_sz: Forwarded to ``get_rot_mat``.
        num_class: Number of classes per sample.
        transl_range: Half-width of the shared translation draw.
        rot_range: Half-width of the shared angle draw.
        pn_spline, pn_gauss, range_spline, range_gauss, spread_range: Unused.

    Returns:
        ndarray: Array of shape ``(batch_sz, num_class, 4, 3)`` holding, per
        pair, the 9 rotation entries followed by the 3 translation entries.
    """
    scale = 4   # per-class angle/translation jitter is 1/scale of the shared range
    ndims = 3
    n_total = batch_sz * num_class

    # Draw order (axis, angle, translation; shared before per-class) is kept
    # so that seeded runs reproduce the same transforms.
    vec = np.random.uniform(-1., 1., [batch_sz,1, ndims])+np.random.uniform(-.1, .1, [batch_sz,num_class, ndims])
    vec = np.reshape(vec, [n_total, ndims])

    ang = np.random.uniform(-rot_range, rot_range, [batch_sz,1])+np.random.uniform(-rot_range/scale, rot_range/scale, [batch_sz,num_class])
    ang = np.reshape(ang, [n_total])

    transl = np.random.uniform(-transl_range, transl_range, [batch_sz,1,ndims])+np.random.uniform(-transl_range/scale, transl_range/scale, [batch_sz,num_class,ndims])
    transl = np.reshape(transl, [n_total, ndims])

    packed = np.concatenate([get_rot_mat(img_sz, vec=vec, ang=ang), transl], -1)
    return np.reshape(packed, [batch_sz, num_class, 4, 3])
|
| 271 |
+
|
| 272 |
+
# return np.reshape(get_tranf_mat(img_sz, vec=np.random.uniform(-1., 1., [batch_sz*num_class, 3]), ang=np.random.uniform(-rot_range, rot_range, [batch_sz*num_class]),transl=np.random.uniform(-transl_range, transl_range, [batch_sz*num_class,3])),[batch_sz,num_class,4,3])
|
| 273 |
+
|
| 274 |
+
def random_ddf(batch_sz, img_sz, pn_spline=20, pn_gauss=10, range_spline=1., range_gauss=16., spread_range=[16., 64.],
               transl_range=0., rot_range=np.pi / 1,ndims=3):
    """Generate a random dense displacement field (rotation plus smooth elastic part).

    Args:
        batch_sz: Number of fields to generate.
        img_sz: Three-entry grid size.
        pn_gauss: Number of control points of the elastic part.
        range_gauss: Magnitude range of the elastic part; ``<= 0`` disables it.
        spread_range: Range of the Gaussian smoothing sigma.
        transl_range: Half-width of the random global translation.
        rot_range: Half-width of the random rotation angle.
        ndims: ``3`` draws a random rotation axis; any other value rotates
            about the fixed axis ``(0, 0, 1)``.
        pn_spline, range_spline: Unused.

    Returns:
        tuple: ``(ddf, inside, rot_df)`` where ``ddf`` is the total field after
        clamping the sampling positions to the grid extended by 5 voxels,
        ``inside`` (trailing axis of length 1) flags positions that fall
        inside the grid, and ``rot_df`` is the rotation-only field.
    """
    padding = 5
    rand_ang = np.random.uniform(-rot_range, rot_range, [batch_sz])

    if ndims == 3:
        rot_axis = np.random.uniform(-1., 1., [batch_sz, 3])
    else:
        rot_axis = np.concatenate([np.zeros([batch_sz, 2]), np.ones([batch_sz, 1])], -1)
        ndims = 3
    rot_df = get_rot_ddf(img_sz, vec=rot_axis, ang=rand_ang)

    if range_gauss > 0:
        # One elastic field and one translation are drawn and shared by the whole batch.
        elastic = generate_random_gaussian_ddf(img_sz, pn_gauss, range_sz=range_gauss, spread_std=spread_range)
        translation = np.random.uniform(-transl_range, transl_range, [ndims])
        ddf0 = np.tile([elastic + translation], [batch_sz, 1, 1, 1, 1]) + rot_df
    else:
        ddf0 = rot_df

    ref = get_reference_grid(img_sz)
    coords = ddf0 + ref
    n_axes = len(img_sz)
    clamped = np.stack(
        [np.maximum(np.minimum(coords[..., a], img_sz[a] - 1 + padding), 0 - padding) for a in range(n_axes)],
        axis=-1)
    inside = np.prod(
        [((coords[..., a] < img_sz[a]) * (coords[..., a] >= 0)) for a in range(n_axes)],
        axis=0)
    return clamped - ref, np.expand_dims(inside, -1), rot_df
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def generate_random_gaussian_ddf(img_sz, pn=30, range_sz=5, spread_std=[0.1, 1.]):
    """Create a smooth random displacement field from sparse random impulses.

    ``pn`` random voxels receive random 3-vectors in ``[-range_sz, range_sz]``;
    the resulting ``(X, Y, Z, 3)`` array is then Gaussian-filtered with one
    random sigma (the filter is applied to the whole 4D array).

    Args:
        img_sz: Grid size; the first three entries are used.
        pn: Number of impulses.
        range_sz: Impulse magnitude range; also sets the margin kept from the border.
        spread_std: ``[min, max]`` of the random smoothing sigma.

    Returns:
        ndarray: Field of shape ``(img_sz[0], img_sz[1], img_sz[2], 3)``.
    """
    margin = range_sz / 2.
    # Draw order (x, y, z, vectors, sigma) is kept for seeded reproducibility.
    seeds = [np.floor(np.random.uniform(margin, img_sz[axis] - margin, [1, pn])).astype('int')
             for axis in range(3)]
    impulses = np.random.uniform(-range_sz, range_sz, [pn, 3])

    field = np.zeros([img_sz[0], img_sz[1], img_sz[2], 3])
    field[seeds[0], seeds[1], seeds[2]] = impulses

    sigma = np.random.uniform(spread_std[0], spread_std[1])
    return spimg.gaussian_filter(field, sigma)
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def get_rot_ddf(grid_size, vec=[[0., 0., 1.]], ang=[[0.]]):
    """Displacement fields produced by rotating a centred grid.

    Args:
        grid_size: Three-entry list (it is concatenated with other lists).
        vec: Per-sample rotation axes.
        ang: Per-sample rotation angles; its first axis gives the batch size.

    Returns:
        ndarray: Array of shape ``(batch, *grid_size, 3)`` holding rotated
        minus original centred coordinates.
    """
    axes = np.array(vec)
    angles = np.array(ang)
    n_samples = angles.shape[0]
    # bias_scale=1 centres the grid.
    centred = get_reference_grid(grid_size, bias_scale=1.)
    points = np.reshape(np.tile(centred, [n_samples, 1, 1, 1, 1]), [n_samples, -1, 3])
    rotated = np.matmul(points, vecang2rotmats(axes, angles))
    return np.reshape(rotated, [n_samples] + grid_size + [3]) - centred
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def get_reference_grid(grid_size, bias_scale=0.):
    """Return the voxel-index coordinate grid of a 3D volume.

    Args:
        grid_size: Three-entry grid size.
        bias_scale: Fraction of ``(grid_size - 1) / 2`` subtracted from the
            coordinates; ``1`` centres the grid on zero.

    Returns:
        ndarray: Float array of shape ``(*grid_size, 3)`` whose last axis holds
        the (i, j, k) index of each voxel, minus the bias.
    """
    extents = (grid_size[0], grid_size[1], grid_size[2])
    # np.indices yields the same 'ij'-ordered index arrays, component axis first.
    coords = np.moveaxis(np.indices(extents), 0, -1).astype('float')
    offset = bias_scale * (np.array(grid_size) - 1) / 2.
    return coords - offset
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def resample_linear(inputs, ddf=None, sample_coords=None,random_boundary=True):
    """Trilinearly resample a NumPy volume at arbitrary voxel coordinates.

    Args:
        inputs: Array indexed as ``[batch, x, y, z, ...]``. When
            ``random_boundary`` is True its six border faces are blended
            IN PLACE towards the global minimum by one random factor.
        ddf: Displacement added to the identity grid; used only when
            ``sample_coords`` is None.
        sample_coords: Absolute sampling positions, last axis of length 3.
        random_boundary: Enable the random border blending described above.

    Returns:
        ndarray: Interpolated values, one per sampling position.
    """
    if random_boundary:
        blend = np.random.uniform(0., 1.)
        floor_val = np.min(inputs)
        for axis in (1, 2, 3):
            for edge in (0, -1):
                face = [slice(None)] * 4
                face[axis] = edge
                face = tuple(face)
                inputs[face] = floor_val * blend + (1 - blend) * inputs[face]

    input_size = inputs.shape[1:4]
    sample_coords = get_reference_grid(input_size) + ddf if sample_coords is None else sample_coords
    spatial_rank = 3
    per_axis = [sample_coords[..., a] for a in range(sample_coords.shape[-1])]
    floors = [np.floor(c) for c in per_axis]

    def clamp_index(values, size, plus=0):
        # Keep integer corners so that corner and corner+1 both stay inside the volume.
        return np.maximum(np.minimum(values, size - 2 + plus), 0 + plus)

    def clamp_position(values, size, plus=0.):
        return np.maximum(np.minimum(values, size - 1 + plus), 0 + plus)

    per_axis = [clamp_position(c.astype('float32'), input_size[a]) for a, c in enumerate(per_axis)]
    lower = [clamp_index(f.astype('int32'), input_size[a]) for a, f in enumerate(floors)]
    upper = [clamp_index((f + 1).astype('int32'), input_size[a], 1) for a, f in enumerate(floors)]

    # Interpolation weights towards the upper corner and towards the lower corner.
    weight = [np.expand_dims(pos - lo.astype('float32'), -1) for pos, lo in zip(per_axis, lower)]
    weight_c = [np.expand_dims(hi.astype('float32') - pos, -1) for pos, hi in zip(per_axis, upper)]

    shape = list(lower[0].shape)
    batch_coords = np.tile(np.reshape(range(shape[0]), [shape[0]] + [1] * (len(shape) - 1)), [1] + shape[1:])
    corners = (lower, upper)

    # Gather the 8 cell corners; bit k of the code (most significant first)
    # picks the lower/upper index along axis k.
    samples = []
    for code in range(2 ** spatial_rank):
        bits = [(code >> shift) & 1 for shift in reversed(range(spatial_rank))]
        samples.append(inputs[batch_coords, corners[bits[0]][0], corners[bits[1]][1], corners[bits[2]][2], ...])

    # Collapse one axis at a time (axis 0 first); same arithmetic order as the
    # recursive pyramid formulation.
    level = samples
    for axis in range(len(weight)):
        half = len(level) // 2
        level = [level[j] * weight_c[axis] + level[j + half] * weight[axis] for j in range(half)]
    return level[0]
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def vecang2rotmats(vec, ang):
    """Stack one 3x3 rotation matrix per (axis, angle) pair.

    Args:
        vec: Array whose first axis enumerates rotation axes.
        ang: Array whose first axis enumerates the matching angles.

    Returns:
        ndarray: Array of shape ``(len(vec), 3, 3)``.
    """
    matrices = []
    for i in range(len(vec)):
        single = vecang2rotmat(vec[i, ...], ang[i, ...])
        matrices.append(np.reshape(single, [3, 3]))
    return np.stack(matrices, 0)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def vecang2rotmat(vec, ang):
    """Rotation matrix of the quaternion built from an axis and an angle.

    Args:
        vec: Rotation axis passed to ``Quaternion(axis=...)``.
        ang: Rotation angle passed to ``Quaternion(angle=...)``.

    Returns:
        The ``rotation_matrix`` attribute of that quaternion.
    """
    return quater.Quaternion(axis=vec, angle=ang).rotation_matrix
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def images_to_vectors(images, device=None):
    """Flatten a batch of images into vectors of length 16384 (128 * 128).

    Args:
        images: Tensor with ``images.size(0)`` samples of 16384 elements each.
        device: Optional target device. ``None`` (default) keeps the tensor on
            its current device. The function previously read an undefined
            module-level ``device`` and therefore raised ``NameError``.

    Returns:
        Tensor: View of shape ``(N, 16384)``.
    """
    vectors = images.view(images.size(0), 16384)
    return vectors if device is None else vectors.to(device)
|
| 405 |
+
|
| 406 |
+
def vectors_to_images(vectors, device=None):
    """Reshape a batch of length-16384 vectors into ``(N, 1, 128, 128)`` images.

    Args:
        vectors: Tensor with ``vectors.size(0)`` samples of 16384 elements each.
        device: Optional target device. ``None`` (default) keeps the tensor on
            its current device. The function previously read an undefined
            module-level ``device`` and therefore raised ``NameError``.

    Returns:
        Tensor: View of shape ``(N, 1, 128, 128)``.
    """
    images = vectors.view(vectors.size(0), 1, 128, 128)
    return images if device is None else images.to(device)
|
| 408 |
+
|
| 409 |
+
def noise(size, device=None):
    """Draw ``size`` standard-normal noise vectors of length 100.

    Args:
        size: Number of vectors.
        device: Optional target device. ``None`` (default) leaves the tensor
            where ``torch.randn`` created it. The function previously read an
            undefined module-level ``device`` and therefore raised ``NameError``.

    Returns:
        Tensor: Noise of shape ``(size, 100)``.
    """
    n = torch.randn(size, 100)
    return n if device is None else n.to(device)
|
| 412 |
+
|
| 413 |
+
def ones_target(size, device=None):
    """Return a ``(size, 1)`` tensor of ones.

    Args:
        size: Number of rows.
        device: Optional target device. ``None`` (default) leaves the tensor
            where ``torch.ones`` created it. The function previously read an
            undefined module-level ``device`` and therefore raised ``NameError``.

    Returns:
        Tensor: Ones of shape ``(size, 1)``.
    """
    data = torch.ones(size, 1)
    return data if device is None else data.to(device)
|
| 416 |
+
|
| 417 |
+
def zeros_target(size, device=None):
    """Return a ``(size, 1)`` tensor of zeros.

    Args:
        size: Number of rows.
        device: Optional target device. ``None`` (default) leaves the tensor
            where ``torch.zeros`` created it. The function previously read an
            undefined module-level ``device`` and therefore raised ``NameError``.

    Returns:
        Tensor: Zeros of shape ``(size, 1)``.
    """
    data = torch.zeros(size, 1)
    return data if device is None else data.to(device)
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
def eval_detJ_lab(disp=None,vol1=None,vol2=None,thresh=0.5):
    """Count voxels where the deformation's Jacobian determinant is negative.

    The Jacobian is the finite-difference gradient of ``disp`` plus identity.

    Args:
        disp: NumPy displacement field of shape ``(N, ndims, *spatial)``.
        vol1: Optional volume; when given together with ``thresh`` the count is
            restricted to ``vol1 > thresh`` voxels whose Laplacian of that mask
            is below 0.1, with the mask subsampled by 2 along its last three axes.
            NOTE(review): the subsampled mask must broadcast against the
            determinant map -- presumably ``disp`` is at half resolution; confirm.
        vol2: Unused.
        thresh: Mask threshold; ``None`` disables masking.

    Returns:
        The number of (masked) voxels with negative determinant.
    """
    ndims = disp.ndim - 2
    # Identity checks: ``vol1 == None`` on an ndarray is element-wise and made
    # the ``or`` raise "truth value of an array is ambiguous".
    if vol1 is None or thresh is None:
        label = 1
    else:
        label = vol1 > thresh
        label = label * (spimg.laplace(label) < 0.1)
        rescale_factor = 2
        label = label[..., ::rescale_factor, ::rescale_factor, ::rescale_factor]

    # Channels-last layout, then Jacob[..., c, a] = d disp_c / d axis_a.
    disp = np.transpose(disp, [0, *range(2, ndims + 2), 1])
    Jacob = np.stack(np.gradient(disp, axis=[*range(1, ndims + 1)]), -1)
    for ii in range(ndims):
        Jacob[..., ii, ii] = Jacob[..., ii, ii] + 1
    return np.sum((np.linalg.det(Jacob) < 0) * label)
|
| 443 |
+
|
| 444 |
+
def eval_def_mag(disp=None,vol1=None,vol2=None,thresh=0.5):
    """Summarise the magnitude of a displacement field.

    Args:
        disp: NumPy array whose axis 1 holds the displacement components and
            axis 0 the samples.
        vol1, vol2, thresh: Unused.

    Returns:
        list: ``[avg_mag, max_mag]`` -- the mean magnitude over everything and
        the per-sample maximum magnitude averaged over samples.
    """
    magnitude = np.sqrt(np.sum(np.square(disp), axis=1))
    n_samples = magnitude.shape[0]
    per_sample_peak = np.max(np.reshape(magnitude, [n_samples, -1]), axis=-1)
    return [np.mean(magnitude), np.mean(per_sample_peak)]
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def converet_to_nibabel(vol_tensor,ndims=3):
    """Wrap a batched volume in a ``nib.Nifti1Image``.

    The leading axis is squeezed away. A single remaining channel is squeezed
    too, while several channels are moved to the last axis (4D image).

    Args:
        vol_tensor: NumPy array, or an object offering ``.cpu().numpy()``.
        ndims: ``3`` uses an identity affine; ``2`` uses an identity affine
            whose ``[2, 2]`` entry is zero.

    Returns:
        nib.Nifti1Image: The image built from the converted array.
    """
    if isinstance(vol_tensor, np.ndarray):
        volume = vol_tensor
    else:
        volume = vol_tensor.cpu().numpy()
    volume = volume.squeeze(0)

    if ndims == 3:
        map_eyes = np.eye(4)
    elif ndims == 2:
        map_eyes = np.eye(4)
        map_eyes[2, 2] = 0

    n_channels = volume.shape[0]
    if n_channels == 1:
        volume = volume.squeeze(0)
    elif n_channels > 1:
        # Channels last: stored as a 4D volume.
        volume = np.transpose(volume, [1, 2, 3, 0])

    return nib.Nifti1Image(volume, affine=map_eyes)
|
| 480 |
+
|
| 481 |
+
def print_memory_usage(tag=""):
    """Print the CUDA memory allocated and reserved, in GB.

    Args:
        tag: Label shown in square brackets at the start of the line.
    """
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] Allocated: {allocated_gb:.2f} GB | Cached: {reserved_gb:.2f} GB")
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
if __name__ == "__main__":
    # Manual smoke test: warp and crop a random 2D batch, then print a random mask.
    vol_shape = [4, 1, 64, 64]

    vol = Variable(torch.tensor(np.random.uniform(-1, 1, vol_shape), dtype=torch.float32))
    vol_res = random_resample(vol)
    vol_crop = img_crop(vol_res)

    mask = get_random_deformed_mask(vol.shape[2:])

    print(mask)

    # print(vol.tolist())
    # print(vol_res.tolist())
|