Sync from local: code + epoch-110 checkpoint, clean README
Browse filesReplace existing repo with current local OmniMorph: full source tree (training/inference/registration scripts), Diffusion/OMorpher modules, dataloader mappings (16 datasets), and Models/all_om_net/000110_all_om_net.pth (final checkpoint, 3.0G). README rewritten to remove internal links/credentials. BERT external model and intermediate checkpoints not bundled.
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +2 -0
- .gitignore +1 -0
- Config/config_om.yaml +15 -20
- Config/config_reg_brain.yaml +36 -0
- Config/config_reg_hip.yaml +48 -0
- Dataloader/dataLoader.py +172 -69
- Dataloader/dataloader_utils.py +3 -3
- Dataloader/deal_with_json.py +150 -0
- Dataloader/embding_gen.py +10 -2
- Dataloader/nifty_mappings/AbdomenAtlas_mappings.json +2 -2
- Dataloader/nifty_mappings/AbdomenCT1k_mappings.json +2 -2
- Dataloader/nifty_mappings/Brats2019_mappings.json +2 -2
- Dataloader/nifty_mappings/Brats2020_mappings.json +2 -2
- Dataloader/nifty_mappings/Brats2021_mappings.json +2 -2
- Dataloader/nifty_mappings/CIA_mappings.json +2 -2
- Dataloader/nifty_mappings/Kaggle_osic_mappings.json +0 -0
- Dataloader/nifty_mappings/MSD_mappings.json +2 -2
- Dataloader/nifty_mappings/MnMs_mappings.json +0 -0
- Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json +3 -0
- Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json +3 -0
- Dataloader/nifty_mappings/OASIS_1_mappings.json +2 -2
- Dataloader/nifty_mappings/OASIS_2_mappings.json +2 -2
- Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json +2 -2
- Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json +2 -2
- Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json +2 -2
- Diffusion/diffuser-reg.py +541 -0
- Diffusion/diffuser.py +45 -20
- Diffusion/diffuser_opt.py +357 -0
- Diffusion/losses.py +44 -7
- Diffusion/losses_opt.py +141 -0
- Diffusion/networks.py +328 -17
- Diffusion/networks0.py +1195 -0
- Diffusion/networks_opt.py +239 -0
- Diffusion/safe_conv_transpose.py +401 -0
- Models/all_om_net/000110_all_om_net.pth +3 -0
- OM_reg.py +10 -18
- OM_reg_flexres.py +382 -0
- OM_train_2modes-reg.py +517 -0
- OM_train_2modes.py +60 -69
- OM_train_3modes-XPU.py +957 -0
- OM_train_3modes.py +697 -198
- OM_train_3modes_cudaonly.py +512 -0
- OM_train_3modes_opt.py +513 -0
- OM_train_3modes_original.py +585 -0
- OMorpher/__init__.py +3 -0
- OMorpher/omorpher.py +1058 -0
- README.md +129 -80
- Scripts/OM_aug_om.py +239 -0
- Scripts/OM_reg_flexres_om.py +315 -0
- Scripts/OM_reg_pair_ext.py +676 -0
.gitattributes
CHANGED
|
@@ -46,3 +46,5 @@ Dataloader/nifty_mappings/OASIS_2_mappings.json filter=lfs diff=lfs merge=lfs -t
|
|
| 46 |
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 47 |
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 48 |
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 46 |
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 47 |
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 48 |
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
CHANGED
|
@@ -15,6 +15,7 @@ External/
|
|
| 15 |
|
| 16 |
# Logs
|
| 17 |
Log/
|
|
|
|
| 18 |
swanlog/
|
| 19 |
train_log.txt
|
| 20 |
aug_log.txt
|
|
|
|
| 15 |
|
| 16 |
# Logs
|
| 17 |
Log/
|
| 18 |
+
Logs/
|
| 19 |
swanlog/
|
| 20 |
train_log.txt
|
| 21 |
aug_log.txt
|
Config/config_om.yaml
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
data_name: all
|
| 2 |
-
|
| 3 |
-
net_name: recmutattnnet
|
| 4 |
-
# net_name:
|
| 5 |
# net_name: defrecmutattnnet
|
| 6 |
ndims: 3
|
| 7 |
img_size: 128
|
| 8 |
-
batchsize:
|
| 9 |
ddf_pad_mode: border
|
| 10 |
-
device:
|
| 11 |
img_pad_mode: zeros
|
| 12 |
num_input_chn: 1
|
| 13 |
padding_mode: border
|
|
@@ -19,23 +19,21 @@ v_scale: 5.0e-05
|
|
| 19 |
epoch: 10000
|
| 20 |
epoch_per_save: 1
|
| 21 |
lr: 0.00001
|
| 22 |
-
noise_scale: 0.
|
| 23 |
# =========================
|
| 24 |
# AUGMENTATION SETTING
|
| 25 |
patients_list: []
|
| 26 |
# model_id_str: '000000'
|
| 27 |
# model_id_str: '000180' # before registration training
|
| 28 |
-
# model_id_str: '
|
| 29 |
-
model_id_str: '000354' #
|
| 30 |
# model_id_str: '000157'
|
| 31 |
# model_id_str: '000171'
|
| 32 |
-
|
|
|
|
| 33 |
noise_step: 1
|
| 34 |
-
aug_coe:
|
| 35 |
-
|
| 36 |
-
#
|
| 37 |
-
# aug_coe: 4 # how many times each sample will be augmented
|
| 38 |
-
condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
|
| 39 |
# aug_img_savepath: Data/Aug_data/totseg/img/
|
| 40 |
# aug_msk_savepath: Data/Aug_data/totseg/msk/
|
| 41 |
# aug_ddf_savepath: Data/Aug_data/totseg/ddf/
|
|
@@ -45,9 +43,6 @@ condition_type: 'uncon' # 'None', 'none', 'adding','independ', 'downsample
|
|
| 45 |
reg_img_savepath: Data/Reg_data/om/img/
|
| 46 |
reg_msk_savepath: Data/Reg_data/om/msk/
|
| 47 |
reg_ddf_savepath: Data/Reg_data/om/ddf/
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
aug_img_savepath: Data/Aug_data/mnms/img/
|
| 52 |
-
aug_msk_savepath: Data/Aug_data/mnms/msk/
|
| 53 |
-
aug_ddf_savepath: Data/Aug_data/mnms/ddf/
|
|
|
|
| 1 |
data_name: all
|
| 2 |
+
net_name: om_net
|
| 3 |
+
# net_name: recmutattnnet
|
| 4 |
+
# net_name: recmulmodmutattnnet
|
| 5 |
# net_name: defrecmutattnnet
|
| 6 |
ndims: 3
|
| 7 |
img_size: 128
|
| 8 |
+
batchsize: 3
|
| 9 |
ddf_pad_mode: border
|
| 10 |
+
device: xpu
|
| 11 |
img_pad_mode: zeros
|
| 12 |
num_input_chn: 1
|
| 13 |
padding_mode: border
|
|
|
|
| 19 |
epoch: 10000
|
| 20 |
epoch_per_save: 1
|
| 21 |
lr: 0.00001
|
| 22 |
+
noise_scale: 0.05
|
| 23 |
# =========================
|
| 24 |
# AUGMENTATION SETTING
|
| 25 |
patients_list: []
|
| 26 |
# model_id_str: '000000'
|
| 27 |
# model_id_str: '000180' # before registration training
|
| 28 |
+
# model_id_str: '000356'
|
|
|
|
| 29 |
# model_id_str: '000157'
|
| 30 |
# model_id_str: '000171'
|
| 31 |
+
model_id_str: '000009'
|
| 32 |
+
start_noise_step: 75
|
| 33 |
noise_step: 1
|
| 34 |
+
# aug_coe: 32 # how many times each sample will be augmented
|
| 35 |
+
aug_coe: 1 # how many times each sample will be augmented
|
| 36 |
+
condition_type: 'slice' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
|
|
|
|
|
|
|
| 37 |
# aug_img_savepath: Data/Aug_data/totseg/img/
|
| 38 |
# aug_msk_savepath: Data/Aug_data/totseg/msk/
|
| 39 |
# aug_ddf_savepath: Data/Aug_data/totseg/ddf/
|
|
|
|
| 43 |
reg_img_savepath: Data/Reg_data/om/img/
|
| 44 |
reg_msk_savepath: Data/Reg_data/om/msk/
|
| 45 |
reg_ddf_savepath: Data/Reg_data/om/ddf/
|
| 46 |
+
aug_img_savepath: Data/Aug_data/msd/img/
|
| 47 |
+
aug_msk_savepath: Data/Aug_data/msd/msk/
|
| 48 |
+
aug_ddf_savepath: Data/Aug_data/msd/ddf/
|
|
|
|
|
|
|
|
|
Config/config_reg_brain.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_name: all
|
| 2 |
+
# net_name: recresacnet
|
| 3 |
+
# net_name: recmutattnnet
|
| 4 |
+
net_name: recmulmodmutattnnet
|
| 5 |
+
# net_name: defrecmutattnnet
|
| 6 |
+
ndims: 3
|
| 7 |
+
img_size: 128
|
| 8 |
+
batchsize: 3
|
| 9 |
+
ddf_pad_mode: border
|
| 10 |
+
device: xpu
|
| 11 |
+
img_pad_mode: zeros
|
| 12 |
+
num_input_chn: 1
|
| 13 |
+
padding_mode: border
|
| 14 |
+
resample_mode: bilinear
|
| 15 |
+
timesteps: 80
|
| 16 |
+
v_scale: 5.0e-05
|
| 17 |
+
# =========================
|
| 18 |
+
# TRAINING SETTING
|
| 19 |
+
epoch: 10000
|
| 20 |
+
epoch_per_save: 1
|
| 21 |
+
lr: 0.00001
|
| 22 |
+
noise_scale: 0.1
|
| 23 |
+
# =========================
|
| 24 |
+
# AUGMENTATION SETTING
|
| 25 |
+
patients_list: []
|
| 26 |
+
model_id_str: '000009'
|
| 27 |
+
start_noise_step: 75
|
| 28 |
+
noise_step: 1
|
| 29 |
+
aug_coe: 1
|
| 30 |
+
condition_type: 'none'
|
| 31 |
+
reg_img_savepath: Data/Reg_data/unpair_brain/img/
|
| 32 |
+
reg_msk_savepath: Data/Reg_data/unpair_brain/msk/
|
| 33 |
+
reg_ddf_savepath: Data/Reg_data/unpair_brain/ddf/
|
| 34 |
+
aug_img_savepath: Data/Aug_data/unpair_brain/img/
|
| 35 |
+
aug_msk_savepath: Data/Aug_data/unpair_brain/msk/
|
| 36 |
+
aug_ddf_savepath: Data/Aug_data/unpair_brain/ddf/
|
Config/config_reg_hip.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data_name: all
|
| 2 |
+
# net_name: recresacnet
|
| 3 |
+
# net_name: recmutattnnet
|
| 4 |
+
net_name: recmulmodmutattnnet
|
| 5 |
+
# net_name: defrecmutattnnet
|
| 6 |
+
ndims: 3
|
| 7 |
+
img_size: 128
|
| 8 |
+
batchsize: 3
|
| 9 |
+
ddf_pad_mode: border
|
| 10 |
+
device: xpu
|
| 11 |
+
img_pad_mode: zeros
|
| 12 |
+
num_input_chn: 1
|
| 13 |
+
padding_mode: border
|
| 14 |
+
resample_mode: bilinear
|
| 15 |
+
timesteps: 80
|
| 16 |
+
v_scale: 5.0e-05
|
| 17 |
+
# =========================
|
| 18 |
+
# TRAINING SETTING
|
| 19 |
+
epoch: 10000
|
| 20 |
+
epoch_per_save: 1
|
| 21 |
+
lr: 0.00001
|
| 22 |
+
noise_scale: 0.1
|
| 23 |
+
# =========================
|
| 24 |
+
# AUGMENTATION SETTING
|
| 25 |
+
patients_list: []
|
| 26 |
+
# model_id_str: '000000'
|
| 27 |
+
# model_id_str: '000180' # before registration training
|
| 28 |
+
# model_id_str: '000356'
|
| 29 |
+
# model_id_str: '000157'
|
| 30 |
+
# model_id_str: '000171'
|
| 31 |
+
model_id_str: '000009'
|
| 32 |
+
start_noise_step: 75
|
| 33 |
+
noise_step: 1
|
| 34 |
+
# aug_coe: 32 # how many times each sample will be augmented
|
| 35 |
+
aug_coe: 1 # how many times each sample will be augmented
|
| 36 |
+
condition_type: 'none' # 'None', 'none', 'adding','independ', 'downsample', 'slice', 'project', 'uncon'
|
| 37 |
+
# aug_img_savepath: Data/Aug_data/totseg/img/
|
| 38 |
+
# aug_msk_savepath: Data/Aug_data/totseg/msk/
|
| 39 |
+
# aug_ddf_savepath: Data/Aug_data/totseg/ddf/
|
| 40 |
+
# aug_img_savepath: Data/Aug_data/om/img/
|
| 41 |
+
# aug_msk_savepath: Data/Aug_data/om/msk/
|
| 42 |
+
# aug_ddf_savepath: Data/Aug_data/om/ddf/
|
| 43 |
+
reg_img_savepath: Data/Reg_data/pair_hip/img/
|
| 44 |
+
reg_msk_savepath: Data/Reg_data/pair_hip/msk/
|
| 45 |
+
reg_ddf_savepath: Data/Reg_data/pair_hip/ddf/
|
| 46 |
+
aug_img_savepath: Data/Aug_data/pair_hip/img/
|
| 47 |
+
aug_msk_savepath: Data/Aug_data/pair_hip/msk/
|
| 48 |
+
aug_ddf_savepath: Data/Aug_data/pair_hip/ddf/
|
Dataloader/dataLoader.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch.utils.data import Dataset, DataLoader
|
| 3 |
import json
|
|
@@ -5,8 +8,8 @@ import SimpleITK as sitk
|
|
| 5 |
import numpy as np
|
| 6 |
from skimage.transform import rescale, resize, downscale_local_mean
|
| 7 |
# from torchvision.transforms import v2
|
| 8 |
-
|
| 9 |
-
sys.path.append(
|
| 10 |
from Dataloader.dataloader_utils import *
|
| 11 |
import random
|
| 12 |
|
|
@@ -18,22 +21,42 @@ import random
|
|
| 18 |
# }
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
mapping_files = {
|
| 22 |
-
'MSD': '
|
| 23 |
-
'TotalSegmentor': '
|
| 24 |
-
'Kaggle_osic': '
|
| 25 |
-
'CancerImageArchive': '
|
| 26 |
-
'MnMs': '
|
| 27 |
-
# 'Brats2019': '
|
| 28 |
-
'Brats2020': '
|
| 29 |
-
'Brats2021': '
|
| 30 |
-
'OASIS_1': '
|
| 31 |
-
'OASIS_2': '
|
| 32 |
-
'PSMA-FDG-PET-CT-LESION':'
|
| 33 |
-
'PSMA-CT':'
|
| 34 |
-
'AbdomenAtlas':'
|
| 35 |
-
'AbdomenCT1k':'
|
|
|
|
|
|
|
| 36 |
}
|
|
|
|
|
|
|
| 37 |
|
| 38 |
CLAMP_RANGE = [-400, 400] # default clamp range for the images
|
| 39 |
|
|
@@ -74,50 +97,9 @@ def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high')
|
|
| 74 |
sample_value = np.random.uniform(low, high=sample_value)
|
| 75 |
return sample_value
|
| 76 |
|
| 77 |
-
class DummyOMDataset_indiv(Dataset):
|
| 78 |
-
"""Dummy dataset that generates random 3D volumes and embeddings for XPU testing."""
|
| 79 |
-
def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
|
| 80 |
-
self.out_sz = out_sz
|
| 81 |
-
self.num_samples = num_samples
|
| 82 |
-
self.embd_dim = embd_dim
|
| 83 |
-
self.transform = transform
|
| 84 |
-
|
| 85 |
-
def __len__(self):
|
| 86 |
-
return self.num_samples
|
| 87 |
-
|
| 88 |
-
def __getitem__(self, idx):
|
| 89 |
-
volume = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
|
| 90 |
-
embd = np.random.randn(self.embd_dim).astype(np.float32)
|
| 91 |
-
if self.transform is not None:
|
| 92 |
-
volume = self.transform(volume)
|
| 93 |
-
return volume, embd
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class DummyOMDataset_pair(Dataset):
|
| 97 |
-
"""Dummy dataset that generates random paired 3D volumes and embeddings for XPU testing."""
|
| 98 |
-
def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
|
| 99 |
-
self.out_sz = out_sz
|
| 100 |
-
self.num_samples = num_samples
|
| 101 |
-
self.embd_dim = embd_dim
|
| 102 |
-
self.transform = transform
|
| 103 |
-
|
| 104 |
-
def __len__(self):
|
| 105 |
-
return self.num_samples
|
| 106 |
-
|
| 107 |
-
def __getitem__(self, idx):
|
| 108 |
-
volume_A = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
|
| 109 |
-
volume_B = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
|
| 110 |
-
embd_A = np.random.randn(self.embd_dim).astype(np.float32)
|
| 111 |
-
embd_B = np.random.randn(self.embd_dim).astype(np.float32)
|
| 112 |
-
if self.transform is not None:
|
| 113 |
-
volume_A = self.transform(volume_A)
|
| 114 |
-
volume_B = self.transform(volume_B)
|
| 115 |
-
return [volume_A, volume_B, embd_A, embd_B]
|
| 116 |
-
|
| 117 |
-
|
| 118 |
class OminiDataset(object):
|
| 119 |
"""Base class for OmniMorph datasets."""
|
| 120 |
-
def
|
| 121 |
|
| 122 |
# self.mappings = mapping_files
|
| 123 |
self.ALLdata = self.combine_data(mappings = mapping_files)
|
|
@@ -155,10 +137,27 @@ class OminiDataset(object):
|
|
| 155 |
|
| 156 |
def combine_data(self, mappings = mapping_files):
|
| 157 |
ALLdata = {}
|
|
|
|
|
|
|
| 158 |
for j in mappings.keys():
|
| 159 |
with open(mappings[j], 'r') as f:
|
| 160 |
mappings_tmp = json.load(f)
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
return ALLdata
|
| 163 |
|
| 164 |
def get_3D_volume(self, volume, select_channel = None):
|
|
@@ -301,10 +300,27 @@ class OminiDataset_v1(Dataset):
|
|
| 301 |
|
| 302 |
def combine_data(self):
|
| 303 |
ALLdata = {}
|
|
|
|
|
|
|
| 304 |
for j in self.mappings.keys():
|
| 305 |
with open(self.mappings[j], 'r') as f:
|
| 306 |
mappings = json.load(f)
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
return ALLdata
|
| 309 |
|
| 310 |
def __len__(self):
|
|
@@ -442,10 +458,27 @@ class OMDataset_indiv(Dataset):
|
|
| 442 |
|
| 443 |
def combine_data(self, mappings = mapping_files):
|
| 444 |
ALLdata = {}
|
|
|
|
|
|
|
| 445 |
for j in mappings.keys():
|
| 446 |
with open(mappings[j], 'r') as f:
|
| 447 |
mappings_tmp = json.load(f)
|
| 448 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 449 |
return ALLdata
|
| 450 |
|
| 451 |
def __len__(self):
|
|
@@ -496,7 +529,7 @@ class OMDataset_indiv(Dataset):
|
|
| 496 |
return [volume, embd]
|
| 497 |
|
| 498 |
class OminiDataset_paired(Dataset):
|
| 499 |
-
def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.
|
| 500 |
# self.mappings = mapping_files
|
| 501 |
self.ALLdata = self.combine_data(mappings=mapping_files)
|
| 502 |
self.out_sz = out_sz
|
|
@@ -525,10 +558,27 @@ class OminiDataset_paired(Dataset):
|
|
| 525 |
|
| 526 |
def combine_data(self, mappings = mapping_files):
|
| 527 |
ALLdata = {}
|
|
|
|
|
|
|
| 528 |
for j in mappings.keys():
|
| 529 |
with open(mappings[j], 'r') as f:
|
| 530 |
mappings_tmp = json.load(f)
|
| 531 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 532 |
return ALLdata
|
| 533 |
|
| 534 |
def normalize(self, volume, eps=1e-7):
|
|
@@ -747,10 +797,27 @@ class OMDataset_pair(Dataset):
|
|
| 747 |
|
| 748 |
def combine_data(self, mappings = mapping_files):
|
| 749 |
ALLdata = {}
|
|
|
|
|
|
|
| 750 |
for j in mappings.keys():
|
| 751 |
with open(mappings[j], 'r') as f:
|
| 752 |
mappings_tmp = json.load(f)
|
| 753 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 754 |
return ALLdata
|
| 755 |
|
| 756 |
def normalize(self, volume, eps=1e-7):
|
|
@@ -911,8 +978,8 @@ class OMDataset_pair(Dataset):
|
|
| 911 |
|
| 912 |
paired_key = random.choice(paired_keys)
|
| 913 |
|
| 914 |
-
print(f"Key: {key}, Paired Key: {paired_key}")
|
| 915 |
-
print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")
|
| 916 |
|
| 917 |
|
| 918 |
volume_B = sitk.ReadImage(paired_key)
|
|
@@ -1004,10 +1071,27 @@ class OminiDataset_paired_inf(object):
|
|
| 1004 |
|
| 1005 |
def combine_data(self, mappings = mapping_files):
|
| 1006 |
ALLdata = {}
|
|
|
|
|
|
|
| 1007 |
for j in mappings.keys():
|
| 1008 |
with open(mappings[j], 'r') as f:
|
| 1009 |
mappings_tmp = json.load(f)
|
| 1010 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1011 |
return ALLdata
|
| 1012 |
|
| 1013 |
def __len__(self):
|
|
@@ -1244,10 +1328,27 @@ class OminiDataset_inference_w_all(object):
|
|
| 1244 |
|
| 1245 |
def combine_data(self, mappings = mapping_files):
|
| 1246 |
ALLdata = {}
|
|
|
|
|
|
|
| 1247 |
for j in mappings.keys():
|
| 1248 |
with open(mappings[j], 'r') as f:
|
| 1249 |
mappings_tmp = json.load(f)
|
| 1250 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1251 |
return ALLdata
|
| 1252 |
|
| 1253 |
def normalize(self, volume, eps=1e-7):
|
|
@@ -1414,6 +1515,7 @@ class OminiDataset_inference_w_all(object):
|
|
| 1414 |
# print(f"Label with channels, pad_width_lab: {pad_width_lab}")
|
| 1415 |
else:
|
| 1416 |
pad_width_lab = pad_width
|
|
|
|
| 1417 |
label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
|
| 1418 |
# print(f"After pad and crop, label shape: {label.shape}, key: {key}, label key: {lk}")
|
| 1419 |
label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
|
|
@@ -1442,6 +1544,7 @@ class OminiDataset_inference_w_all(object):
|
|
| 1442 |
return return_dict
|
| 1443 |
|
| 1444 |
|
|
|
|
| 1445 |
class OminiDataset_bertembd(OminiDataset):
|
| 1446 |
def __init__(self,
|
| 1447 |
out_sz = 128,
|
|
@@ -1453,7 +1556,7 @@ class OminiDataset_bertembd(OminiDataset):
|
|
| 1453 |
reverse_axis_order = False,
|
| 1454 |
min_dim = 3,
|
| 1455 |
mapping_files = mapping_files):
|
| 1456 |
-
super().
|
| 1457 |
transform = transform,
|
| 1458 |
clamp_range = clamp_range,
|
| 1459 |
min_crop_ratio = min_crop_ratio,
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 3 |
+
|
| 4 |
import torch
|
| 5 |
from torch.utils.data import Dataset, DataLoader
|
| 6 |
import json
|
|
|
|
| 8 |
import numpy as np
|
| 9 |
from skimage.transform import rescale, resize, downscale_local_mean
|
| 10 |
# from torchvision.transforms import v2
|
| 11 |
+
# sys.path.append('./')
|
| 12 |
+
sys.path.append(ROOT_DIR)
|
| 13 |
from Dataloader.dataloader_utils import *
|
| 14 |
import random
|
| 15 |
|
|
|
|
| 21 |
# }
|
| 22 |
|
| 23 |
|
| 24 |
+
# mapping_files = {
|
| 25 |
+
# 'MSD': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/MSD_mappings.json',
|
| 26 |
+
# 'TotalSegmentor': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
|
| 27 |
+
# 'Kaggle_osic': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/Kaggle_osic_mappings.json',
|
| 28 |
+
# 'CancerImageArchive': '/home/data/Github/OmniMorph/Dataloader/nifty_mappings/CIA_mappings.json',
|
| 29 |
+
# 'MnMs': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/MnMs_mappings.json',
|
| 30 |
+
# # 'Brats2019': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2019_mappings.json',
|
| 31 |
+
# 'Brats2020': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2020_mappings.json',
|
| 32 |
+
# 'Brats2021': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/Brats2021_mappings.json',
|
| 33 |
+
# 'OASIS_1': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_1_mappings.json',
|
| 34 |
+
# 'OASIS_2': '/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/OASIS_2_mappings.json',
|
| 35 |
+
# 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
|
| 36 |
+
# 'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
|
| 37 |
+
# 'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
|
| 38 |
+
# 'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
|
| 39 |
+
# }
|
| 40 |
mapping_files = {
|
| 41 |
+
'MSD': 'nifty_mappings/MSD_mappings.json',
|
| 42 |
+
'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
|
| 43 |
+
'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json',
|
| 44 |
+
'CancerImageArchive': 'nifty_mappings/CIA_mappings.json',
|
| 45 |
+
'MnMs': 'nifty_mappings/MnMs_mappings.json',
|
| 46 |
+
# 'Brats2019': 'nifty_mappings/Brats2019_mappings.json', # should be commented out after testing
|
| 47 |
+
'Brats2020': 'nifty_mappings/Brats2020_mappings.json',
|
| 48 |
+
'Brats2021': 'nifty_mappings/Brats2021_mappings.json',
|
| 49 |
+
'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json',
|
| 50 |
+
'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json',
|
| 51 |
+
'PSMA-FDG-PET-CT-LESION':'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
|
| 52 |
+
'PSMA-CT':'nifty_mappings/PSMA-CT-Longitud_mappings.json',
|
| 53 |
+
'AbdomenAtlas':'nifty_mappings/AbdomenAtlas_mappings.json',
|
| 54 |
+
'AbdomenCT1k':'nifty_mappings/AbdomenCT1k_mappings.json',
|
| 55 |
+
'OAI_ZIB': 'nifty_mappings/OAI_ZIB_KL_mappings.json',
|
| 56 |
+
# 'OAI_ZIB': 'nifty_mappings/OAI_ZIB_WOMAC_mappings.json', # alternative: WOMAC scores instead of KL-grade
|
| 57 |
}
|
| 58 |
+
for k,v in mapping_files.items():
|
| 59 |
+
mapping_files[k] = os.path.join(ROOT_DIR, v)
|
| 60 |
|
| 61 |
CLAMP_RANGE = [-400, 400] # default clamp range for the images
|
| 62 |
|
|
|
|
| 97 |
sample_value = np.random.uniform(low, high=sample_value)
|
| 98 |
return sample_value
|
| 99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
class OminiDataset(object):
|
| 101 |
"""Base class for OmniMorph datasets."""
|
| 102 |
+
def __init__(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
|
| 103 |
|
| 104 |
# self.mappings = mapping_files
|
| 105 |
self.ALLdata = self.combine_data(mappings = mapping_files)
|
|
|
|
| 137 |
|
| 138 |
def combine_data(self, mappings = mapping_files):
|
| 139 |
ALLdata = {}
|
| 140 |
+
total_entries = 0
|
| 141 |
+
total_skipped = 0
|
| 142 |
for j in mappings.keys():
|
| 143 |
with open(mappings[j], 'r') as f:
|
| 144 |
mappings_tmp = json.load(f)
|
| 145 |
+
skipped = 0
|
| 146 |
+
for k, v in mappings_tmp.items():
|
| 147 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 148 |
+
skipped += 1
|
| 149 |
+
continue
|
| 150 |
+
ALLdata[k] = v
|
| 151 |
+
accessible = len(mappings_tmp) - skipped
|
| 152 |
+
total_entries += len(mappings_tmp)
|
| 153 |
+
total_skipped += skipped
|
| 154 |
+
if skipped > 0:
|
| 155 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
|
| 156 |
+
if total_skipped > 0:
|
| 157 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 158 |
+
if len(ALLdata) < 1000:
|
| 159 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 160 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 161 |
return ALLdata
|
| 162 |
|
| 163 |
def get_3D_volume(self, volume, select_channel = None):
|
|
|
|
| 300 |
|
| 301 |
def combine_data(self):
|
| 302 |
ALLdata = {}
|
| 303 |
+
total_entries = 0
|
| 304 |
+
total_skipped = 0
|
| 305 |
for j in self.mappings.keys():
|
| 306 |
with open(self.mappings[j], 'r') as f:
|
| 307 |
mappings = json.load(f)
|
| 308 |
+
skipped = 0
|
| 309 |
+
for k, v in mappings.items():
|
| 310 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 311 |
+
skipped += 1
|
| 312 |
+
continue
|
| 313 |
+
ALLdata[k] = v
|
| 314 |
+
accessible = len(mappings) - skipped
|
| 315 |
+
total_entries += len(mappings)
|
| 316 |
+
total_skipped += skipped
|
| 317 |
+
if skipped > 0:
|
| 318 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings)} accessible ({skipped} missing/empty)")
|
| 319 |
+
if total_skipped > 0:
|
| 320 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 321 |
+
if len(ALLdata) < 1000:
|
| 322 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 323 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 324 |
return ALLdata
|
| 325 |
|
| 326 |
def __len__(self):
|
|
|
|
| 458 |
|
| 459 |
def combine_data(self, mappings = mapping_files):
|
| 460 |
ALLdata = {}
|
| 461 |
+
total_entries = 0
|
| 462 |
+
total_skipped = 0
|
| 463 |
for j in mappings.keys():
|
| 464 |
with open(mappings[j], 'r') as f:
|
| 465 |
mappings_tmp = json.load(f)
|
| 466 |
+
skipped = 0
|
| 467 |
+
for k, v in mappings_tmp.items():
|
| 468 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 469 |
+
skipped += 1
|
| 470 |
+
continue
|
| 471 |
+
ALLdata[k] = v
|
| 472 |
+
accessible = len(mappings_tmp) - skipped
|
| 473 |
+
total_entries += len(mappings_tmp)
|
| 474 |
+
total_skipped += skipped
|
| 475 |
+
if skipped > 0:
|
| 476 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
|
| 477 |
+
if total_skipped > 0:
|
| 478 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 479 |
+
if len(ALLdata) < 1000:
|
| 480 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 481 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 482 |
return ALLdata
|
| 483 |
|
| 484 |
def __len__(self):
|
|
|
|
| 529 |
return [volume, embd]
|
| 530 |
|
| 531 |
class OminiDataset_paired(Dataset):
|
| 532 |
+
def __init__(self, out_sz = 128, transform=None, clamp_range = CLAMP_RANGE, min_crop_ratio = 0.85, ROIs = None, modality = None, reverse_axis_order = False):
|
| 533 |
# self.mappings = mapping_files
|
| 534 |
self.ALLdata = self.combine_data(mappings=mapping_files)
|
| 535 |
self.out_sz = out_sz
|
|
|
|
| 558 |
|
| 559 |
def combine_data(self, mappings = mapping_files):
|
| 560 |
ALLdata = {}
|
| 561 |
+
total_entries = 0
|
| 562 |
+
total_skipped = 0
|
| 563 |
for j in mappings.keys():
|
| 564 |
with open(mappings[j], 'r') as f:
|
| 565 |
mappings_tmp = json.load(f)
|
| 566 |
+
skipped = 0
|
| 567 |
+
for k, v in mappings_tmp.items():
|
| 568 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 569 |
+
skipped += 1
|
| 570 |
+
continue
|
| 571 |
+
ALLdata[k] = v
|
| 572 |
+
accessible = len(mappings_tmp) - skipped
|
| 573 |
+
total_entries += len(mappings_tmp)
|
| 574 |
+
total_skipped += skipped
|
| 575 |
+
if skipped > 0:
|
| 576 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
|
| 577 |
+
if total_skipped > 0:
|
| 578 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 579 |
+
if len(ALLdata) < 1000:
|
| 580 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 581 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 582 |
return ALLdata
|
| 583 |
|
| 584 |
def normalize(self, volume, eps=1e-7):
|
|
|
|
| 797 |
|
| 798 |
def combine_data(self, mappings = mapping_files):
|
| 799 |
ALLdata = {}
|
| 800 |
+
total_entries = 0
|
| 801 |
+
total_skipped = 0
|
| 802 |
for j in mappings.keys():
|
| 803 |
with open(mappings[j], 'r') as f:
|
| 804 |
mappings_tmp = json.load(f)
|
| 805 |
+
skipped = 0
|
| 806 |
+
for k, v in mappings_tmp.items():
|
| 807 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 808 |
+
skipped += 1
|
| 809 |
+
continue
|
| 810 |
+
ALLdata[k] = v
|
| 811 |
+
accessible = len(mappings_tmp) - skipped
|
| 812 |
+
total_entries += len(mappings_tmp)
|
| 813 |
+
total_skipped += skipped
|
| 814 |
+
if skipped > 0:
|
| 815 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
|
| 816 |
+
if total_skipped > 0:
|
| 817 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 818 |
+
if len(ALLdata) < 1000:
|
| 819 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 820 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 821 |
return ALLdata
|
| 822 |
|
| 823 |
def normalize(self, volume, eps=1e-7):
|
|
|
|
| 978 |
|
| 979 |
paired_key = random.choice(paired_keys)
|
| 980 |
|
| 981 |
+
# print(f"Key: {key}, Paired Key: {paired_key}")
|
| 982 |
+
# print(f"ROI: {self.ALLdata_filtered[key]['ROI']}, {self.ALLdata_filtered[paired_key]['ROI']}; Modality: {self.ALLdata_filtered[key]['Modality']}, {self.ALLdata_filtered[paired_key]['Modality']}")
|
| 983 |
|
| 984 |
|
| 985 |
volume_B = sitk.ReadImage(paired_key)
|
|
|
|
| 1071 |
|
| 1072 |
def combine_data(self, mappings = mapping_files):
|
| 1073 |
ALLdata = {}
|
| 1074 |
+
total_entries = 0
|
| 1075 |
+
total_skipped = 0
|
| 1076 |
for j in mappings.keys():
|
| 1077 |
with open(mappings[j], 'r') as f:
|
| 1078 |
mappings_tmp = json.load(f)
|
| 1079 |
+
skipped = 0
|
| 1080 |
+
for k, v in mappings_tmp.items():
|
| 1081 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 1082 |
+
skipped += 1
|
| 1083 |
+
continue
|
| 1084 |
+
ALLdata[k] = v
|
| 1085 |
+
accessible = len(mappings_tmp) - skipped
|
| 1086 |
+
total_entries += len(mappings_tmp)
|
| 1087 |
+
total_skipped += skipped
|
| 1088 |
+
if skipped > 0:
|
| 1089 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
|
| 1090 |
+
if total_skipped > 0:
|
| 1091 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 1092 |
+
if len(ALLdata) < 1000:
|
| 1093 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 1094 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 1095 |
return ALLdata
|
| 1096 |
|
| 1097 |
def __len__(self):
|
|
|
|
| 1328 |
|
| 1329 |
def combine_data(self, mappings = mapping_files):
|
| 1330 |
ALLdata = {}
|
| 1331 |
+
total_entries = 0
|
| 1332 |
+
total_skipped = 0
|
| 1333 |
for j in mappings.keys():
|
| 1334 |
with open(mappings[j], 'r') as f:
|
| 1335 |
mappings_tmp = json.load(f)
|
| 1336 |
+
skipped = 0
|
| 1337 |
+
for k, v in mappings_tmp.items():
|
| 1338 |
+
if not os.path.exists(k) or os.path.getsize(k) == 0:
|
| 1339 |
+
skipped += 1
|
| 1340 |
+
continue
|
| 1341 |
+
ALLdata[k] = v
|
| 1342 |
+
accessible = len(mappings_tmp) - skipped
|
| 1343 |
+
total_entries += len(mappings_tmp)
|
| 1344 |
+
total_skipped += skipped
|
| 1345 |
+
if skipped > 0:
|
| 1346 |
+
print(f" WARNING: {j}: {accessible}/{len(mappings_tmp)} accessible ({skipped} missing/empty)")
|
| 1347 |
+
if total_skipped > 0:
|
| 1348 |
+
print(f" DATA LOADING WARNING: {len(ALLdata)}/{total_entries} total files accessible ({total_skipped} missing)")
|
| 1349 |
+
if len(ALLdata) < 1000:
|
| 1350 |
+
print(f" *** CRITICAL WARNING: Only {len(ALLdata)} files loaded! Expected ~15000+. "
|
| 1351 |
+
f"Check that data paths in nifty_mappings/ JSON files are accessible from this node. ***")
|
| 1352 |
return ALLdata
|
| 1353 |
|
| 1354 |
def normalize(self, volume, eps=1e-7):
|
|
|
|
| 1515 |
# print(f"Label with channels, pad_width_lab: {pad_width_lab}")
|
| 1516 |
else:
|
| 1517 |
pad_width_lab = pad_width
|
| 1518 |
+
|
| 1519 |
label = self.apply_pad_crop(label, pad_width_lab, crop_slices)
|
| 1520 |
# print(f"After pad and crop, label shape: {label.shape}, key: {key}, label key: {lk}")
|
| 1521 |
label_dict[lk] = resize(label,[self.out_sz]*self.ndims, anti_aliasing = False, preserve_range = True, order=0)
|
|
|
|
| 1544 |
return return_dict
|
| 1545 |
|
| 1546 |
|
| 1547 |
+
|
| 1548 |
class OminiDataset_bertembd(OminiDataset):
|
| 1549 |
def __init__(self,
|
| 1550 |
out_sz = 128,
|
|
|
|
| 1556 |
reverse_axis_order = False,
|
| 1557 |
min_dim = 3,
|
| 1558 |
mapping_files = mapping_files):
|
| 1559 |
+
super().__init__(out_sz = out_sz,
|
| 1560 |
transform = transform,
|
| 1561 |
clamp_range = clamp_range,
|
| 1562 |
min_crop_ratio = min_crop_ratio,
|
Dataloader/dataloader_utils.py
CHANGED
|
@@ -48,9 +48,9 @@ def get_sizeRange_dict(roi=''):
|
|
| 48 |
'abdomen': [240, 1024],
|
| 49 |
'pelvis': [220, 1024],
|
| 50 |
'thorax': [220, 1024],
|
| 51 |
-
'arm': [
|
| 52 |
-
'hand': [
|
| 53 |
-
'leg': [
|
| 54 |
'skeleton': [130, 1024],
|
| 55 |
}
|
| 56 |
if roi in sizeRange_dict:
|
|
|
|
| 48 |
'abdomen': [240, 1024],
|
| 49 |
'pelvis': [220, 1024],
|
| 50 |
'thorax': [220, 1024],
|
| 51 |
+
'arm': [100, 1024],
|
| 52 |
+
'hand': [100, 1024],
|
| 53 |
+
'leg': [100, 1024],
|
| 54 |
'skeleton': [130, 1024],
|
| 55 |
}
|
| 56 |
if roi in sizeRange_dict:
|
Dataloader/deal_with_json.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 3 |
+
sys.path.append(ROOT_DIR)
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
# CORRECT_DATA_PATH = os.path.join(ROOT_DIR, '../..')
|
| 7 |
+
# CORRECT_DATA_PATH = os.path.join('/hy-tmp')
|
| 8 |
+
CORRECT_DATA_PATH = '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D'
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def traverse_and_print(data, path=()):
|
| 12 |
+
for key, value in data.items():
|
| 13 |
+
current_path = path + (key,)
|
| 14 |
+
|
| 15 |
+
if isinstance(key, str) and 'DATASETS' in key:
|
| 16 |
+
print(f"KEY (str): {key}")
|
| 17 |
+
|
| 18 |
+
if isinstance(value, str) and 'DATASETS' in value:
|
| 19 |
+
print(f" VALUE (str): {value}")
|
| 20 |
+
elif isinstance(value, dict):
|
| 21 |
+
traverse_and_print(value, current_path)
|
| 22 |
+
|
| 23 |
+
def traverse_and_check(data, path=()):
|
| 24 |
+
failed_files = []
|
| 25 |
+
for key, value in data.items():
|
| 26 |
+
current_path = path + (key,)
|
| 27 |
+
|
| 28 |
+
if isinstance(key, str) and 'DATASETS_processed' in key:
|
| 29 |
+
if os.path.isfile(key):
|
| 30 |
+
print(f'\rCheck pass: {key}', end='')
|
| 31 |
+
else:
|
| 32 |
+
print(f'\rCheck fail ! : {key}')
|
| 33 |
+
failed_files.append(key)
|
| 34 |
+
|
| 35 |
+
if isinstance(value, str) and 'DATASETS_processed' in value:
|
| 36 |
+
if os.path.isfile(value):
|
| 37 |
+
print(f'\rCheck pass: {value}', end='')
|
| 38 |
+
else:
|
| 39 |
+
print(f'\rCheck fail ! : {value}')
|
| 40 |
+
failed_files.append(value)
|
| 41 |
+
elif isinstance(value, dict):
|
| 42 |
+
traverse_and_check(value, current_path)
|
| 43 |
+
|
| 44 |
+
if failed_files != []:
|
| 45 |
+
print(f'\nCheck finished. Failed files: {failed_files}')
|
| 46 |
+
return False
|
| 47 |
+
else:
|
| 48 |
+
print('\nAll files check passed!')
|
| 49 |
+
return True
|
| 50 |
+
|
| 51 |
+
def traverse_and_revise(data, path=()):
|
| 52 |
+
what_need_change = [
|
| 53 |
+
'/home/jachin/data/Github/data/data_gen_def',
|
| 54 |
+
'/home/data/Github/data/data_gen_def',
|
| 55 |
+
]
|
| 56 |
+
for key, value in list(data.items()):
|
| 57 |
+
current_path = path + (key,)
|
| 58 |
+
|
| 59 |
+
new_key = key
|
| 60 |
+
if isinstance(key, str) and 'data_gen_def' in key:
|
| 61 |
+
for wnc in what_need_change:
|
| 62 |
+
if wnc in key:
|
| 63 |
+
new_key = key.replace(wnc, CORRECT_DATA_PATH)
|
| 64 |
+
|
| 65 |
+
# change keys
|
| 66 |
+
data[new_key] = data.pop(key)
|
| 67 |
+
value = data[new_key]
|
| 68 |
+
current_path = path + (new_key,)
|
| 69 |
+
|
| 70 |
+
if isinstance(value, str) and 'data_gen_def' in value:
|
| 71 |
+
for wnc in what_need_change:
|
| 72 |
+
if wnc in value:
|
| 73 |
+
data[new_key] = value.replace(wnc, CORRECT_DATA_PATH)
|
| 74 |
+
|
| 75 |
+
elif isinstance(value, dict):
|
| 76 |
+
traverse_and_revise(value, current_path)
|
| 77 |
+
|
| 78 |
+
return data
|
| 79 |
+
|
| 80 |
+
def traverse_and_rename_label(data, old_label, new_label, task_keys=("segmentation", "registration")):
|
| 81 |
+
"""Rename a label key inside Label_path -> segmentation/registration for every entry.
|
| 82 |
+
|
| 83 |
+
Example: rename "brain" -> "brain_tumour" to fix the BraTS mislabel.
|
| 84 |
+
"""
|
| 85 |
+
count = 0
|
| 86 |
+
for key, value in data.items():
|
| 87 |
+
if not isinstance(value, dict):
|
| 88 |
+
continue
|
| 89 |
+
label_path = value.get("Label_path")
|
| 90 |
+
if isinstance(label_path, dict):
|
| 91 |
+
for tk in task_keys:
|
| 92 |
+
task_dict = label_path.get(tk)
|
| 93 |
+
if isinstance(task_dict, dict) and old_label in task_dict:
|
| 94 |
+
task_dict[new_label] = task_dict.pop(old_label)
|
| 95 |
+
count += 1
|
| 96 |
+
else:
|
| 97 |
+
# recurse into nested dicts
|
| 98 |
+
count += traverse_and_rename_label(value, old_label, new_label, task_keys)
|
| 99 |
+
return count
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
mapping_files = {
|
| 103 |
+
'MSD': 'nifty_mappings/MSD_mappings.json',
|
| 104 |
+
'TotalSegmentor': 'nifty_mappings/TotalSegmentorCT_MRI_mappings.json',
|
| 105 |
+
'Kaggle_osic': 'nifty_mappings/Kaggle_osic_mappings.json',
|
| 106 |
+
'CancerImageArchive': 'nifty_mappings/CIA_mappings.json',
|
| 107 |
+
'MnMs': 'nifty_mappings/MnMs_mappings.json',
|
| 108 |
+
'Brats2019': 'nifty_mappings/Brats2019_mappings.json',
|
| 109 |
+
'Brats2020': 'nifty_mappings/Brats2020_mappings.json',
|
| 110 |
+
'Brats2021': 'nifty_mappings/Brats2021_mappings.json',
|
| 111 |
+
'OASIS_1': 'nifty_mappings/OASIS_1_mappings.json',
|
| 112 |
+
'OASIS_2': 'nifty_mappings/OASIS_2_mappings.json',
|
| 113 |
+
'PSMA-FDG-PET-CT-LESION':'nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json',
|
| 114 |
+
'PSMA-CT':'nifty_mappings/PSMA-CT-Longitud_mappings.json',
|
| 115 |
+
'AbdomenAtlas':'nifty_mappings/AbdomenAtlas_mappings.json',
|
| 116 |
+
'AbdomenCT1k':'nifty_mappings/AbdomenCT1k_mappings.json',
|
| 117 |
+
}
|
| 118 |
+
for k,v in mapping_files.items():
|
| 119 |
+
mapping_files[k] = os.path.join(ROOT_DIR, v)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
|
| 123 |
+
# --- Fix BraTS / MSD mislabel: "brain" -> "brain_tumour" ---
|
| 124 |
+
rename_datasets = ['Brats2019', 'Brats2020', 'Brats2021', 'MSD']
|
| 125 |
+
for ds_name in rename_datasets:
|
| 126 |
+
if ds_name not in mapping_files:
|
| 127 |
+
continue
|
| 128 |
+
v = mapping_files[ds_name]
|
| 129 |
+
with open(v, 'r') as f:
|
| 130 |
+
mappings_tmp = json.load(f)
|
| 131 |
+
n = traverse_and_rename_label(mappings_tmp, 'brain', 'brain_tumour')
|
| 132 |
+
if n > 0:
|
| 133 |
+
with open(v, 'w') as f:
|
| 134 |
+
json.dump(mappings_tmp, f, indent=4)
|
| 135 |
+
print(f'[{ds_name}] Renamed "brain" -> "brain_tumour" in {n} entries, saved to {v}')
|
| 136 |
+
else:
|
| 137 |
+
print(f'[{ds_name}] No "brain" labels found (already renamed?)')
|
| 138 |
+
|
| 139 |
+
# --- Path revision (uncomment to run) ---
|
| 140 |
+
# for k,v in mapping_files.items():
|
| 141 |
+
# with open(v, 'r') as f:
|
| 142 |
+
# mappings_tmp = json.load(f)
|
| 143 |
+
# new_mappings_tmp = traverse_and_revise(mappings_tmp)
|
| 144 |
+
# # traverse_and_print(new_mappings_tmp)
|
| 145 |
+
# # all_good = traverse_and_check(new_mappings_tmp)
|
| 146 |
+
# # save in-place
|
| 147 |
+
# with open(v, 'w') as f:
|
| 148 |
+
# json.dump(new_mappings_tmp, f, indent=4)
|
| 149 |
+
# print(f'Saved revised mapping to {v}')
|
| 150 |
+
|
Dataloader/embding_gen.py
CHANGED
|
@@ -23,7 +23,9 @@ mapping_files = {
|
|
| 23 |
# 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
|
| 24 |
# 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
|
| 25 |
# 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
|
| 26 |
-
'OASIS_2': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
|
|
|
|
|
|
|
| 27 |
# 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
|
| 28 |
# 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
|
| 29 |
# 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
|
|
@@ -45,6 +47,8 @@ save_paths = {
|
|
| 45 |
'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
|
| 46 |
'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
|
| 47 |
'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
|
|
|
|
|
|
|
| 48 |
}
|
| 49 |
query = {
|
| 50 |
'MSD': ['description'],
|
|
@@ -61,6 +65,8 @@ query = {
|
|
| 61 |
'PSMA-CT':[],
|
| 62 |
'AbdomenAtlas':[],
|
| 63 |
'AbdomenCT1k':[],
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
add_text = {
|
| 66 |
'MSD': {},
|
|
@@ -77,11 +83,13 @@ add_text = {
|
|
| 77 |
'PSMA-FDG-PET-CT-LESION':{'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
|
| 78 |
'AbdomenAtlas':{},
|
| 79 |
'AbdomenCT1k':{},
|
|
|
|
|
|
|
| 80 |
}
|
| 81 |
|
| 82 |
|
| 83 |
# bert intialization
|
| 84 |
-
model_name = '/
|
| 85 |
reduce_method = 'mean'
|
| 86 |
max_words_num = 32 # max number of words in the caption > 2
|
| 87 |
# max_words_num = 64 # max number of words in the caption > 2
|
|
|
|
| 23 |
# 'Brats2020': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2020/nifti_mappings.json',
|
| 24 |
# 'Brats2021': '/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/BRATS/BRATS2021/nifti_mappings.json',
|
| 25 |
# 'OASIS_1': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_1/CS_SECTIONAL/nifti_mappings.json',
|
| 26 |
+
# 'OASIS_2': '/home/data/Github/data/data_gen_def/DATASETS_processed/OASIS/OASIS_2/RAW_V2/nifti_mappings.json',
|
| 27 |
+
'OAI_ZIB_KL': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D/DATASETS_processed/OAI_ZIB/nifti_mappings.json',
|
| 28 |
+
'OAI_ZIB_WOMAC': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Data/Omini3D/DATASETS_processed/OAI_ZIB/nifti_mappings.json',
|
| 29 |
# 'PSMA-FDG-PET-CT-LESION':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/PSMA-FDG-PET-CT-LESION/V2/nifti_mappings.json',
|
| 30 |
# 'PSMA-CT':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/PSMA/Longitudinal-CT/nifti_mappings.json',
|
| 31 |
# 'AbdomenAtlas':'/home/jachin/data/Github/data/data_gen_def/DATASETS_processed/AbdomenAtlas_v2/nifti_mappings.json',
|
|
|
|
| 47 |
'PSMA-CT':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json',
|
| 48 |
'AbdomenAtlas':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenAtlas_mappings.json',
|
| 49 |
'AbdomenCT1k':'/home/jachin/data/Github/OmniMorph/Dataloader/nifty_mappings/AbdomenCT1k_mappings.json',
|
| 50 |
+
'OAI_ZIB_KL': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Code/OmniMorph/Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json',
|
| 51 |
+
'OAI_ZIB_WOMAC': '/home/dn-zhen2/rds/rds-airr-p51-TWhPgQVLKbA/Code/OmniMorph/Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json',
|
| 52 |
}
|
| 53 |
query = {
|
| 54 |
'MSD': ['description'],
|
|
|
|
| 65 |
'PSMA-CT':[],
|
| 66 |
'AbdomenAtlas':[],
|
| 67 |
'AbdomenCT1k':[],
|
| 68 |
+
'OAI_ZIB_KL': ['Age', 'Gender', 'KL_Grade', 'BMI'],
|
| 69 |
+
'OAI_ZIB_WOMAC': ['Age', 'Gender', 'WOMAC_Pain', 'WOMAC_ADL', 'WOMAC_Stiffness', 'BMI'],
|
| 70 |
}
|
| 71 |
add_text = {
|
| 72 |
'MSD': {},
|
|
|
|
| 83 |
'PSMA-FDG-PET-CT-LESION':{'description': 'malignant melanoma, lymphoma, lung cancer, or healthy'},
|
| 84 |
'AbdomenAtlas':{},
|
| 85 |
'AbdomenCT1k':{},
|
| 86 |
+
'OAI_ZIB_KL': {'description': 'right knee osteoarthritis'},
|
| 87 |
+
'OAI_ZIB_WOMAC': {'description': 'right knee osteoarthritis'},
|
| 88 |
}
|
| 89 |
|
| 90 |
|
| 91 |
# bert intialization
|
| 92 |
+
model_name = '/rds/project/rds-TWhPgQVLKbA/Code/OmniMorph/External/Models/bert_large_uncased'
|
| 93 |
reduce_method = 'mean'
|
| 94 |
max_words_num = 32 # max number of words in the caption > 2
|
| 95 |
# max_words_num = 64 # max number of words in the caption > 2
|
Dataloader/nifty_mappings/AbdomenAtlas_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6000e9ba6b4fac278a1288826696ab7d5f77c97929d7e001dfb8938d7d5aa0a8
|
| 3 |
+
size 182087319
|
Dataloader/nifty_mappings/AbdomenCT1k_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a36ccd80e859aefd7334fb99ebca10601bb39be9e6432a1f59b4e98e9c4069a8
|
| 3 |
+
size 30687976
|
Dataloader/nifty_mappings/Brats2019_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f128806b4673b7e1219990f0e2c5732abd1080fd4de271195fa74538c32ab70
|
| 3 |
+
size 12178080
|
Dataloader/nifty_mappings/Brats2020_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:90659bf584857b9e543163431e3730c6e6ce229b3386dc8ab13e7411a6b00c78
|
| 3 |
+
size 17815563
|
Dataloader/nifty_mappings/Brats2021_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c758b9cfb8190f3b77eef03ea93a43f95e2d9e89dae4b08f6ae4dabc65024b97
|
| 3 |
+
size 44888384
|
Dataloader/nifty_mappings/CIA_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1aef79728ee6d2ab15ab7225a52d5e437cd10d33cfdcbb6f4d9c2aee1687d5f3
|
| 3 |
+
size 32803157
|
Dataloader/nifty_mappings/Kaggle_osic_mappings.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Dataloader/nifty_mappings/MSD_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8b777fb0d1ab09b22dcb3048b25cf60a31ccc30749888f1f02d7dc4b43715ad6
|
| 3 |
+
size 92732794
|
Dataloader/nifty_mappings/MnMs_mappings.json
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Dataloader/nifty_mappings/OAI_ZIB_KL_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b5ab4159932276f0ccd52efe44986ed184b504162f568cec68fc76fa0769efad
|
| 3 |
+
size 18096063
|
Dataloader/nifty_mappings/OAI_ZIB_WOMAC_mappings.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4dad37ced9f1dbe3819dd6ac0d51b6585c25e641b4d07352d706aaf3ac17c19a
|
| 3 |
+
size 18119154
|
Dataloader/nifty_mappings/OASIS_1_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a39ccde5fe81bd7b2b5fa1cc64feb7094ff83851bfd40a5287e01d817e45db59
|
| 3 |
+
size 15646470
|
Dataloader/nifty_mappings/OASIS_2_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7665f7769ef262f1758af1cf42e1610f211c53d35a625a457c5a50bca3841757
|
| 3 |
+
size 13440390
|
Dataloader/nifty_mappings/PSMA-CT-Longitud_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ebd252fec7062df77452b0bdeab47013314aba638cf0b0de295bc62748d2cfec
|
| 3 |
+
size 11728536
|
Dataloader/nifty_mappings/PSMA-FDG-PET-CT-LESION_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cab3cbb5a5a651e1c3446079a3c18b944ed1893893ccd25451c110f13eebe4cc
|
| 3 |
+
size 48538337
|
Dataloader/nifty_mappings/TotalSegmentorCT_MRI_mappings.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a922ecc5c136bcc3427f81e970d1cdd02e3b6c61bedc198e99b6fec8c380b4c3
|
| 3 |
+
size 69966911
|
Diffusion/diffuser-reg.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from torch.nn.utils.stateless import functional_call
|
| 5 |
+
|
| 6 |
+
import Diffusion.utils_diff as utils
|
| 7 |
+
from Diffusion.networks import *
|
| 8 |
+
# from networks import *
|
| 9 |
+
|
| 10 |
+
import random
|
| 11 |
+
|
| 12 |
+
EPS = 1e-8
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class DeformDDPM(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
network,
|
| 20 |
+
n_steps=50,
|
| 21 |
+
beta_schedule_fn = None,
|
| 22 |
+
device='cpu',
|
| 23 |
+
image_chw=(1, 28, 28),
|
| 24 |
+
batch_size = 1,
|
| 25 |
+
img_pad_mode = "zeros",
|
| 26 |
+
ddf_pad_mode="border",
|
| 27 |
+
padding_mode="border",
|
| 28 |
+
v_scale = 0.008/256,
|
| 29 |
+
resample_mode=None,
|
| 30 |
+
inf_mode = False,
|
| 31 |
+
):
|
| 32 |
+
super(DeformDDPM, self).__init__()
|
| 33 |
+
self.rec_num=2
|
| 34 |
+
self.ndims=len(image_chw)-1
|
| 35 |
+
self.n_steps = n_steps
|
| 36 |
+
self.v_scale = v_scale
|
| 37 |
+
self.device = device
|
| 38 |
+
self.msk_noise_scale = torch.tensor(0)
|
| 39 |
+
# self.msk_noise_scale = torch.tensor(1)
|
| 40 |
+
|
| 41 |
+
# print('================')
|
| 42 |
+
# print("device:",device)
|
| 43 |
+
# if device == 'cpu':
|
| 44 |
+
# print("num_device: 1")
|
| 45 |
+
# else:
|
| 46 |
+
# print("num_device:", torch.cuda.device_count())
|
| 47 |
+
# print('================')
|
| 48 |
+
|
| 49 |
+
self.num_device = torch.cuda.device_count()
|
| 50 |
+
|
| 51 |
+
self.batch_size = batch_size #//self.num_device
|
| 52 |
+
self.img_pad_mode = img_pad_mode
|
| 53 |
+
self.ddf_pad_mode = ddf_pad_mode
|
| 54 |
+
self.padding_mode = padding_mode
|
| 55 |
+
self.resample_mode = resample_mode
|
| 56 |
+
self.image_chw = image_chw
|
| 57 |
+
self.network = network#.to(self.device)
|
| 58 |
+
self.ddf_stn_full = STN(
|
| 59 |
+
img_sz = self.image_chw[1],
|
| 60 |
+
ndims = self.ndims,
|
| 61 |
+
padding_mode = self.padding_mode,
|
| 62 |
+
device = self.device,
|
| 63 |
+
)
|
| 64 |
+
self._DDF_Encoder_init()
|
| 65 |
+
self.copy_opt = nn.Identity()
|
| 66 |
+
self.inf_mode = inf_mode
|
| 67 |
+
return
|
| 68 |
+
|
| 69 |
+
def get_stn(self):
|
| 70 |
+
return self.img_stn, self.ddf_stn_full
|
| 71 |
+
|
| 72 |
+
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
|
| 73 |
+
if ctl_sz is None:
|
| 74 |
+
ctl_sz = self.image_chw[1] // ctl_ratio
|
| 75 |
+
self.ctl_sz=ctl_sz
|
| 76 |
+
self.img_sz=self.image_chw[1]
|
| 77 |
+
self.ddf_stn_rec=STN(img_sz=ctl_sz,ndims=self.ndims,device=self.device,padding_mode=self.ddf_pad_mode)
|
| 78 |
+
self.img_stn=STN(img_sz=self.img_sz,ndims=self.ndims,device=self.device,padding_mode=self.img_pad_mode,resample_mode=self.resample_mode)
|
| 79 |
+
self.msk_stn=STN(img_sz=self.img_sz,ndims=self.ndims,device=self.device,padding_mode=self.img_pad_mode,resample_mode='nearest')
|
| 80 |
+
|
| 81 |
+
def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
|
| 82 |
+
rec_num = 1
|
| 83 |
+
mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
|
| 84 |
+
mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
|
| 85 |
+
# print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
|
| 86 |
+
# mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
|
| 87 |
+
# mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
|
| 88 |
+
mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
|
| 89 |
+
mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
|
| 90 |
+
# print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
|
| 91 |
+
return rec_num,mul_num_ddf,mul_num_dvf
|
| 92 |
+
|
| 93 |
+
# def _sample_random_uniform_multi_order(self, high=None, low=0, order_num=3):
|
| 94 |
+
# # high: tensor of shape (...), low: int or tensor broadcastable to high
|
| 95 |
+
# sample_num = torch.full_like(high, low) if not isinstance(low, torch.Tensor) else low.clone()
|
| 96 |
+
# for _ in range(order_num):
|
| 97 |
+
# # For each element, sample in [sample_num, high]
|
| 98 |
+
# # torch.randint requires scalar low/high, so we use elementwise sampling
|
| 99 |
+
# rand_shape = high.shape
|
| 100 |
+
# # Clamp sample_num to be <= high
|
| 101 |
+
# sample_num = torch.minimum(sample_num, high)
|
| 102 |
+
# # Generate random numbers for each element
|
| 103 |
+
# rand = torch.empty(rand_shape, dtype=high.dtype, device=high.device)
|
| 104 |
+
# for idx in np.ndindex(rand_shape):
|
| 105 |
+
# l = sample_num[idx].item()
|
| 106 |
+
# h = high[idx].item()
|
| 107 |
+
# if l >= h:
|
| 108 |
+
# rand[idx] = l
|
| 109 |
+
# else:
|
| 110 |
+
# rand[idx] = torch.randint(l, h + 1, (1,), device=high.device)
|
| 111 |
+
# sample_num = rand.to(high.dtype)
|
| 112 |
+
# return sample_num
|
| 113 |
+
|
| 114 |
+
def _get_random_ddf(self,img,t):
|
| 115 |
+
rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
|
| 116 |
+
ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf])
|
| 117 |
+
warped_img = self.img_stn(img,ddf_forward)
|
| 118 |
+
return warped_img, dvf_forward,ddf_forward
|
| 119 |
+
|
| 120 |
+
def _multiscale_dvf_generate(self,v_scale,ctl_szs=[4,8,16,32,64], rand_v_scale=True):
|
| 121 |
+
dvf=0
|
| 122 |
+
if self.img_sz is None:
|
| 123 |
+
self.img_sz=max(ctl_szs)
|
| 124 |
+
if 1 in ctl_szs:
|
| 125 |
+
dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
|
| 126 |
+
dvf = dvf + dvf_rot
|
| 127 |
+
for ctl_sz in ctl_szs:
|
| 128 |
+
_v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2) if rand_v_scale else v_scale
|
| 129 |
+
# temp>>
|
| 130 |
+
if ctl_sz <= 2:
|
| 131 |
+
_v_scale = _v_scale/2
|
| 132 |
+
# temp<<
|
| 133 |
+
dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz]*self.ndims) * _v_scale
|
| 134 |
+
dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz]*self.ndims, align_corners=False, mode='bilinear' if self.ndims == 2 else 'trilinear')
|
| 135 |
+
dvf=dvf+dvf_comp
|
| 136 |
+
return dvf
|
| 137 |
+
|
| 138 |
+
def _sample_random_uniform_multi_order(self, high=None, low=0., order_num=3):
|
| 139 |
+
sample_value = low
|
| 140 |
+
for _ in range(order_num):
|
| 141 |
+
sample_value = np.random.uniform(low=sample_value, high=high)
|
| 142 |
+
return sample_value
|
| 143 |
+
|
| 144 |
+
def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=4, flip_ratio=0.5):
|
| 145 |
+
crop_rate=2
|
| 146 |
+
for _ in range(self.ndims+1):
|
| 147 |
+
mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
|
| 148 |
+
# v_scale = v_scale *crop_rate
|
| 149 |
+
ctl_ddf_sz=[self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
|
| 150 |
+
if ddf0 is not None:
|
| 151 |
+
ddf=ddf0
|
| 152 |
+
else:
|
| 153 |
+
ddf = torch.zeros(ctl_ddf_sz) * 0
|
| 154 |
+
dddf = torch.zeros(ctl_ddf_sz) * 0
|
| 155 |
+
scale_num = min(8,int(math.log2(self.ctl_sz))) # allow affine
|
| 156 |
+
# scale_num = min(5,int(math.log2(self.ctl_sz))-1) # semi-allow affine
|
| 157 |
+
# scale_num = min(5,int(math.log2(self.ctl_sz))-2) # avoid coupling between deformation and affine
|
| 158 |
+
ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
|
| 159 |
+
|
| 160 |
+
for i in range(rec_num):
|
| 161 |
+
# Randomly select 5 elements from ctl_szs (if there are at least 5)
|
| 162 |
+
if len(ctl_szs_all) > select_num:
|
| 163 |
+
ctl_szs = random.sample(ctl_szs_all, select_num)
|
| 164 |
+
dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
|
| 165 |
+
# if True:
|
| 166 |
+
if noise_ratio==0:
|
| 167 |
+
dvf0=dvf
|
| 168 |
+
else:
|
| 169 |
+
dvf0=dvf+self.ddf_stn_rec(self._multiscale_dvf_generate(self.v_scale*noise_ratio,ctl_szs=ctl_szs, rand_v_scale=False).to(self.device),dvf)
|
| 170 |
+
# print([num.shape for num in mul_num])
|
| 171 |
+
for j in range(torch.max(mul_num[0]).item()):
|
| 172 |
+
flag = [(n>j).int().to(self.device) for n in mul_num]
|
| 173 |
+
ddf = dvf0*flag[0] + self.ddf_stn_rec(ddf, dvf0*flag[0])
|
| 174 |
+
dddf = dvf*flag[1] + self.ddf_stn_rec(dddf, dvf*flag[1])
|
| 175 |
+
|
| 176 |
+
ddf = F.interpolate(ddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode='bilinear' if self.ndims == 2 else 'trilinear')
|
| 177 |
+
# ddf = ddf[...,img_sz//2:img_sz*3//2,img_sz//2:img_sz*3//2]
|
| 178 |
+
if self.ndims==2:
|
| 179 |
+
ddf = ddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
|
| 180 |
+
else:
|
| 181 |
+
ddf = ddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
|
| 182 |
+
# if rec_num==1:
|
| 183 |
+
if True:
|
| 184 |
+
dddf = F.interpolate(dddf * self.img_sz/self.ctl_sz, self.img_sz*crop_rate, mode='bilinear' if self.ndims == 2 else 'trilinear')
|
| 185 |
+
# dddf = dddf[...,img_sz//2:img_sz*3//2,img_sz//2:img_sz*3//2]
|
| 186 |
+
if self.ndims == 2:
|
| 187 |
+
dddf = dddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
|
| 188 |
+
else:
|
| 189 |
+
dddf = dddf[..., self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2, self.img_sz // 2:self.img_sz * 3 // 2]
|
| 190 |
+
return ddf,dddf
|
| 191 |
+
else:
|
| 192 |
+
return ddf
|
| 193 |
+
|
| 194 |
+
def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
|
| 195 |
+
if noise_type == 'gaussian':
|
| 196 |
+
noise_map = torch.randn_like(img) * noise_scale
|
| 197 |
+
elif noise_type == 'uniform':
|
| 198 |
+
noise_map = torch.rand_like(img)*noise_scale*2-noise_scale # 0-1
|
| 199 |
+
elif noise_type == 'binary':
|
| 200 |
+
noise_map = torch.bernoulli(torch.rand_like(img))
|
| 201 |
+
else:
|
| 202 |
+
noise_map = torch.zeros_like(img)
|
| 203 |
+
noise_map = noise_map.to(img.device)
|
| 204 |
+
return noise_map
|
| 205 |
+
|
| 206 |
+
def add_noise(self, img, noise_map=None, noise_ratio_range=[0.,1.]):
|
| 207 |
+
noise_ratio = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
|
| 208 |
+
return img * (1-noise_ratio) + noise_map * noise_ratio, noise_ratio
|
| 209 |
+
|
| 210 |
+
def apply_noise(self, img, noise_map=None, apply_mask=None):
|
| 211 |
+
return img * apply_mask + noise_map * (1-apply_mask)
|
| 212 |
+
|
| 213 |
+
def downsample(self, img, down_ratio_range=[1./32,1]):
|
| 214 |
+
down_ratio = list(np.random.uniform(down_ratio_range[0], down_ratio_range[1],[self.ndims]))
|
| 215 |
+
# print(down_ratio)
|
| 216 |
+
down_img = F.interpolate(img, scale_factor=down_ratio, mode='bilinear' if self.ndims == 2 else 'trilinear')
|
| 217 |
+
# print(down_img)
|
| 218 |
+
# return F.interpolate(down_img, size=[self.image_chw[1]]*self.ndims, mode='bilinear' if self.ndims == 2 else 'trilinear', align_corners=False), np.prod(down_ratio)
|
| 219 |
+
return F.interpolate(down_img, size=[self.image_chw[1]]*self.ndims, mode='bilinear' if self.ndims == 2 else 'trilinear', align_corners=False), np.sqrt(np.prod(down_ratio)) # jzheng: cond weight based on entropy
|
| 220 |
+
|
| 221 |
+
def get_slice_mask(self, img, slice_num_range=[0,32]):
|
| 222 |
+
slice_num_range[1] = min(slice_num_range[1], self.image_chw[1])
|
| 223 |
+
mask = torch.zeros_like(img)
|
| 224 |
+
sample_ratio = 0
|
| 225 |
+
for i in range(self.ndims):
|
| 226 |
+
if self.inf_mode:
|
| 227 |
+
slice_num = 1 # use max slice num for inference for better performance
|
| 228 |
+
slice_idx = [self.image_chw[1]//2] # use middle slice for inference for better performance
|
| 229 |
+
else:
|
| 230 |
+
slice_num = random.randint(slice_num_range[0], slice_num_range[1])
|
| 231 |
+
slice_idx = random.sample(range(self.image_chw[1]), slice_num)
|
| 232 |
+
transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
|
| 233 |
+
for idx in slice_idx:
|
| 234 |
+
mask[..., idx] = 1
|
| 235 |
+
mask = mask.permute(*transpose_list)
|
| 236 |
+
# sample_ratio += slice_num / self.image_chw[1] / self.ndims
|
| 237 |
+
sample_ratio += np.sqrt(slice_num / self.image_chw[1]) / self.ndims # jzheng: cond weight based on entropy
|
| 238 |
+
|
| 239 |
+
# print(mask)
|
| 240 |
+
# print("sample_ratio:", sample_ratio)
|
| 241 |
+
return mask, sample_ratio
|
| 242 |
+
|
| 243 |
+
def project(self, img):
|
| 244 |
+
proj_img = torch.zeros_like(img)
|
| 245 |
+
rand_bourn = np.random.randint(0, 2, size=[self.ndims])
|
| 246 |
+
proj_dim_num = np.sum(rand_bourn)
|
| 247 |
+
for i,pflag in zip(range(2, 2 + self.ndims), rand_bourn):
|
| 248 |
+
if pflag:
|
| 249 |
+
proj_img += torch.mean(img, dim=i, keepdim=True)
|
| 250 |
+
# print("projecting dim:", i)
|
| 251 |
+
return proj_img/(proj_dim_num+EPS), proj_dim_num
|
| 252 |
+
|
| 253 |
+
def proc_cond_img(self, img, proc_type=None,noise_scale=0.1):
|
| 254 |
+
# Remove torch.no_grad() since most operations are not differentiable anyway
|
| 255 |
+
proc_img = img.clone().detach()
|
| 256 |
+
if proc_type is None:
|
| 257 |
+
# Heavily bias towards 'uncon' for efficiency
|
| 258 |
+
proc_type = random.choices(
|
| 259 |
+
# ['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon'],
|
| 260 |
+
# weights=[1, 1, 1, 1, 1, 1, 3], k=1
|
| 261 |
+
['adding', 'independ', 'downsample', 'slice','slice1', 'none', 'uncon'],
|
| 262 |
+
weights=[1, 1, 1, 1, 1, 3], k=1
|
| 263 |
+
)[0]
|
| 264 |
+
mask = torch.tensor(1, device=img.device)
|
| 265 |
+
cond_ratio = torch.tensor(1., device=img.device)
|
| 266 |
+
self.msk_noise_scale = torch.tensor(0, device=img.device)
|
| 267 |
+
noise_type = random.choice(['gaussian', 'uniform', 'none'])
|
| 268 |
+
# Precompute noise_map only if needed
|
| 269 |
+
noise_map = None
|
| 270 |
+
if proc_type not in ['none', None, '']:
|
| 271 |
+
if proc_type == 'uncon':
|
| 272 |
+
noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
|
| 273 |
+
proc_img = noise_map
|
| 274 |
+
mask = torch.tensor(0, device=img.device)
|
| 275 |
+
cond_ratio = torch.tensor(0, device=img.device)
|
| 276 |
+
return proc_img, mask, cond_ratio
|
| 277 |
+
if proc_type in ['adding', 'independ', 'slice','slice1']:
|
| 278 |
+
# self.msk_noise_scale = 0
|
| 279 |
+
noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
|
| 280 |
+
if proc_type == 'adding':
|
| 281 |
+
proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
|
| 282 |
+
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
|
| 283 |
+
elif proc_type == 'independ':
|
| 284 |
+
mask = self.create_noise_map(img, noise_type='binary')
|
| 285 |
+
if self.msk_noise_scale == 0:
|
| 286 |
+
proc_img = img * mask
|
| 287 |
+
else:
|
| 288 |
+
proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
|
| 289 |
+
with torch.no_grad():
|
| 290 |
+
cond_ratio = mask.float().mean()
|
| 291 |
+
elif proc_type == 'downsample':
|
| 292 |
+
# proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./32, 1])
|
| 293 |
+
proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
|
| 294 |
+
cond_ratio = torch.tensor(down_ratio, device=img.device)
|
| 295 |
+
elif proc_type == 'slice' or proc_type == 'slice1':
|
| 296 |
+
if proc_type == 'slice1':
|
| 297 |
+
slice_num_max = 1
|
| 298 |
+
else:
|
| 299 |
+
slice_num_max = random.randint(1, 64)
|
| 300 |
+
slice_num_max = random.randint(1, slice_num_max)
|
| 301 |
+
mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
|
| 302 |
+
if self.msk_noise_scale == 0:
|
| 303 |
+
proc_img = img * mask
|
| 304 |
+
else:
|
| 305 |
+
proc_img = self.apply_noise(proc_img, noise_map=noise_map*self.msk_noise_scale, apply_mask=mask)
|
| 306 |
+
cond_ratio = torch.tensor(sample_ratio, device=img.device)
|
| 307 |
+
elif proc_type == 'project':
|
| 308 |
+
proc_img, proj_num = self.project(proc_img)
|
| 309 |
+
cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
|
| 310 |
+
# cond_ratio = torch.tensor(proj_num / (32 * self.ndims), device=img.device) # jzheng: cond weight based on entropy
|
| 311 |
+
return proc_img, mask, cond_ratio
|
| 312 |
+
|
| 313 |
+
def diffuse(self, x_0, t):
|
| 314 |
+
t=torch.tensor(t)
|
| 315 |
+
# img_t, dvf_forward, ddf_forward, ddf_stn, img_stn = self.ddf_enc(img= x_0, t=t)
|
| 316 |
+
# return img_t, dvf_forward,ddf_forward,ddf_stn,img_stn
|
| 317 |
+
return self._get_random_ddf(img = x_0, t = t)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def recover(self, x, y, t,rec_num=2, text=None):
|
| 321 |
+
if isinstance(t, list):
|
| 322 |
+
t=[torch.tensor(t0) for t0 in t]
|
| 323 |
+
t=[t0.to(x.device) for t0 in t]
|
| 324 |
+
else:
|
| 325 |
+
t=torch.tensor(t)
|
| 326 |
+
t.to(x.device)
|
| 327 |
+
if rec_num is None:
|
| 328 |
+
rec_num = self.rec_num
|
| 329 |
+
return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
|
| 330 |
+
|
| 331 |
+
def recover_frozen_params_but_grad_input(self, x, y, t,rec_num=2, text=None):
|
| 332 |
+
"""
|
| 333 |
+
use detach to recover:
|
| 334 |
+
- but not include no_grad
|
| 335 |
+
"""
|
| 336 |
+
if isinstance(t, list):
|
| 337 |
+
t = [torch.tensor(t0, device=x.device) for t0 in t]
|
| 338 |
+
else:
|
| 339 |
+
t = torch.tensor(t, device=x.device)
|
| 340 |
+
|
| 341 |
+
if rec_num is None:
|
| 342 |
+
rec_num = self.rec_num
|
| 343 |
+
|
| 344 |
+
# params = {k: v.detach() for k, v in self.network.named_parameters()}
|
| 345 |
+
# buffers = dict(self.network.named_buffers()) # BN running stats etc. buffer
|
| 346 |
+
# # functional_call require position args,here kwargs doesnot work, so:
|
| 347 |
+
# def _forward(module, kw):
|
| 348 |
+
# return module(**kw)
|
| 349 |
+
# # functional_call(module, ...) can only pass args/kwargs to module.forward
|
| 350 |
+
# # PyTorch 2.x support functional_call(module, (params, buffers), args, kwargs)
|
| 351 |
+
# return functional_call(
|
| 352 |
+
# self.network,
|
| 353 |
+
# (params, buffers),
|
| 354 |
+
# args=(),
|
| 355 |
+
# kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
|
| 356 |
+
# )
|
| 357 |
+
|
| 358 |
+
# 1) param detached
|
| 359 |
+
params = {k: v.detach() for k, v in self.network.named_parameters()}
|
| 360 |
+
# 2) buffers keeps unchanged
|
| 361 |
+
buffers = dict(self.network.named_buffers())
|
| 362 |
+
|
| 363 |
+
# 3) old version of PyTorch doesnot support passing params and buffers together
|
| 364 |
+
params_and_buffers = {}
|
| 365 |
+
params_and_buffers.update(params)
|
| 366 |
+
params_and_buffers.update(buffers)
|
| 367 |
+
return functional_call(
|
| 368 |
+
self.network,
|
| 369 |
+
params_and_buffers,
|
| 370 |
+
(),
|
| 371 |
+
kwargs=dict(x=x, y=y, t=t, rec_num=rec_num, text=text),
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def _single_step(self, x0, t, rec_num=2, proc_type=None,mask=None, cond_imgs=None, text=None):
|
| 376 |
+
if mask is None:
|
| 377 |
+
mask = 1
|
| 378 |
+
# org_imgs=self.copy_opt(x0)
|
| 379 |
+
if cond_imgs is None:
|
| 380 |
+
cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(x0,proc_type=proc_type)
|
| 381 |
+
noisy_imgs, dvf_I,_ = self.diffuse(x0, t)
|
| 382 |
+
if isinstance(self.network,DefRec_MutAttnNet):
|
| 383 |
+
t = [t] * 1
|
| 384 |
+
return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I
|
| 385 |
+
|
| 386 |
+
def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
|
| 387 |
+
if T is not None:
|
| 388 |
+
return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
| 389 |
+
else:
|
| 390 |
+
return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
| 391 |
+
# if mask is None:
|
| 392 |
+
# mask = 1
|
| 393 |
+
# cond_imgs = self.proc_cond_img(x0, proc_type=proc_type, **kwargs)
|
| 394 |
+
# noisy_imgs, dvf_I, _ = self.diffuse(x0, t)
|
| 395 |
+
# if isinstance(self.network, DefRec_MutAttnNet):
|
| 396 |
+
# t = [t] * 1
|
| 397 |
+
# return self.recover(x=noisy_imgs * mask, y=cond_imgs, t=t, rec_num=rec_num), dvf_I
|
| 398 |
+
|
| 399 |
+
def diff_recover(self,
|
| 400 |
+
img_org,
|
| 401 |
+
msk_org=None,
|
| 402 |
+
T=[None,None],
|
| 403 |
+
ddf_rand=None,
|
| 404 |
+
v_scale = None,
|
| 405 |
+
t_save=None,
|
| 406 |
+
cond_imgs=None,
|
| 407 |
+
proc_type=None,
|
| 408 |
+
text=None,
|
| 409 |
+
):
|
| 410 |
+
if cond_imgs is None:
|
| 411 |
+
cond_imgs = img_org.clone().detach()
|
| 412 |
+
# if proc_type is not None:
|
| 413 |
+
cond_imgs,mask_tgt,cond_ratio=self.proc_cond_img(cond_imgs, proc_type=proc_type)
|
| 414 |
+
if ddf_rand is None:
|
| 415 |
+
if v_scale is not None:
|
| 416 |
+
self.v_scale=v_scale
|
| 417 |
+
self._DDF_Encoder_init()
|
| 418 |
+
if T[0] is None or T[0] == 0:
|
| 419 |
+
img_diff = img_org.clone().detach()
|
| 420 |
+
ddf_rand = torch.zeros_like(img_diff)
|
| 421 |
+
else:
|
| 422 |
+
img_diff, _, ddf_rand = self._get_random_ddf(img= img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
|
| 423 |
+
else:
|
| 424 |
+
img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
|
| 425 |
+
ddf_comp = ddf_rand.clone().detach()
|
| 426 |
+
img_rec = img_diff.clone().detach()
|
| 427 |
+
if msk_org is not None:
|
| 428 |
+
msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
|
| 429 |
+
else:
|
| 430 |
+
msk_diff = None
|
| 431 |
+
msk_rec = msk_diff.clone().detach() if msk_org is not None else None
|
| 432 |
+
img_save=[]
|
| 433 |
+
msk_save=[]
|
| 434 |
+
|
| 435 |
+
if isinstance(self.network,DefRec_MutAttnNet):
|
| 436 |
+
# Denosing image via list of t
|
| 437 |
+
t_list = list(range(T[1]-1, -1, -1))
|
| 438 |
+
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list,rec_num=None, text=text)
|
| 439 |
+
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
|
| 440 |
+
img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
|
| 441 |
+
if msk_org is not None:
|
| 442 |
+
msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
|
| 443 |
+
else:
|
| 444 |
+
# Denosing image
|
| 445 |
+
if isinstance(T[-1], int):
|
| 446 |
+
time_steps = range(T[-1] - 1, -1, -1)
|
| 447 |
+
trainable_iterations =[]
|
| 448 |
+
else:
|
| 449 |
+
time_steps = T[-1]
|
| 450 |
+
|
| 451 |
+
# # Randomly select k iterations to make their parameters trainable
|
| 452 |
+
# win_len = 2 # Number of iterations to make trainable
|
| 453 |
+
# if len(time_steps) <= win_len:
|
| 454 |
+
# win_start = 0
|
| 455 |
+
# else:
|
| 456 |
+
# win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
|
| 457 |
+
# win_end = win_start + win_len - 1
|
| 458 |
+
|
| 459 |
+
k=2
|
| 460 |
+
# trainable_iterations = time_steps[win_start: win_start + win_len]
|
| 461 |
+
# trainable_iterations = random.sample(time_steps, k)
|
| 462 |
+
trainable_iterations = time_steps[-1:-k-1:-1]
|
| 463 |
+
# print(time_steps)
|
| 464 |
+
# print("trainable_iterations:", trainable_iterations)
|
| 465 |
+
for i in time_steps:
|
| 466 |
+
t = torch.tensor(np.array([i])).to(self.device)
|
| 467 |
+
|
| 468 |
+
if i in trainable_iterations:
|
| 469 |
+
# Make parameters trainable for this iteration
|
| 470 |
+
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
|
| 471 |
+
else:
|
| 472 |
+
# Freeze parameters for this iteration using torch.no_grad()
|
| 473 |
+
with torch.no_grad():
|
| 474 |
+
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
|
| 475 |
+
# for idx, i in enumerate(time_steps):
|
| 476 |
+
# t = torch.tensor(np.array([i])).to(self.device)
|
| 477 |
+
# if idx < win_start:
|
| 478 |
+
# # just no_grad
|
| 479 |
+
# with torch.no_grad():
|
| 480 |
+
# pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
|
| 481 |
+
# elif win_start <= idx <= win_end:
|
| 482 |
+
# # normal update
|
| 483 |
+
# pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
|
| 484 |
+
# else:
|
| 485 |
+
# # freeze params but keep grad for input
|
| 486 |
+
# pre_dvf_I = self.recover_frozen_params_but_grad_input(
|
| 487 |
+
# x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text
|
| 488 |
+
# )
|
| 489 |
+
|
| 490 |
+
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
|
| 491 |
+
# Apply to image
|
| 492 |
+
img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
|
| 493 |
+
if msk_org is not None:
|
| 494 |
+
msk_rec = self.msk_stn(msk_org.clone().detach(), ddf_comp)
|
| 495 |
+
if t_save is not None:
|
| 496 |
+
if i in t_save:
|
| 497 |
+
img_save.append(img_rec)
|
| 498 |
+
if msk_org is not None:
|
| 499 |
+
msk_save.append(msk_rec)
|
| 500 |
+
|
| 501 |
+
# for i in time_steps:
|
| 502 |
+
# t = torch.tensor(np.array([i])).to(self.device)
|
| 503 |
+
# pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t,rec_num=None)
|
| 504 |
+
# ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
|
| 505 |
+
# # apply to image
|
| 506 |
+
# img_rec = self.img_stn(img_org.clone().detach(), ddf_comp)
|
| 507 |
+
# if msk_org is not None:
|
| 508 |
+
# msk_rec = self.img_stn(msk_org.clone().detach(), ddf_comp)
|
| 509 |
+
# if t_save is not None:
|
| 510 |
+
# if i in t_save:
|
| 511 |
+
# img_save.append(img_rec)
|
| 512 |
+
# if msk_org is not None:
|
| 513 |
+
# msk_save.append(msk_rec)
|
| 514 |
+
# print(torch.max(torch.abs(ddf_comp)))
|
| 515 |
+
# print(torch.max(torch.abs(ddf_rand)))
|
| 516 |
+
|
| 517 |
+
return [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],[msk_rec,msk_diff,msk_save]
|
| 518 |
+
|
| 519 |
+
if __name__ == "__main__":
|
| 520 |
+
H, W = 8, 8
|
| 521 |
+
deformddpm = DeformDDPM(network=get_net(name="recmutattnnet")(n_steps=80, ndims=2, num_input_chn=1),image_chw=(1, H, W),device='cpu')
|
| 522 |
+
# img = torch.zeros([1, 1, H, W])
|
| 523 |
+
img = torch.randn([1, 1, H, W])
|
| 524 |
+
t = 1
|
| 525 |
+
rec_num = 2
|
| 526 |
+
# proc_type = 'adding'
|
| 527 |
+
# proc_type = 'independ'
|
| 528 |
+
# proc_type = 'downsample'
|
| 529 |
+
proc_type = 'slice'
|
| 530 |
+
# proc_type = 'project'
|
| 531 |
+
# proc_type = 'none'
|
| 532 |
+
print(img)
|
| 533 |
+
cond_imgs, mask_tgt = deformddpm.proc_cond_img(img, proc_type=proc_type)
|
| 534 |
+
print(cond_imgs)
|
| 535 |
+
# img_rec, dvf_I = deformddpm.forward(img, t, rec_num=rec_num, proc_type=proc_type)
|
| 536 |
+
# print(img_rec.shape, dvf_I.shape)
|
| 537 |
+
|
| 538 |
+
# proc_type = 'adding'
|
| 539 |
+
# ddf_comp, ddf_rand = deformddpm.diff_recover(img, T=[1,1], proc_type=proc_type)
|
| 540 |
+
|
| 541 |
+
|
Diffusion/diffuser.py
CHANGED
|
@@ -27,6 +27,7 @@ class DeformDDPM(nn.Module):
|
|
| 27 |
padding_mode="border",
|
| 28 |
v_scale = 0.008/256,
|
| 29 |
resample_mode=None,
|
|
|
|
| 30 |
):
|
| 31 |
super(DeformDDPM, self).__init__()
|
| 32 |
self.rec_num=2
|
|
@@ -35,6 +36,7 @@ class DeformDDPM(nn.Module):
|
|
| 35 |
self.v_scale = v_scale
|
| 36 |
self.device = device
|
| 37 |
self.msk_noise_scale = torch.tensor(0)
|
|
|
|
| 38 |
|
| 39 |
# print('================')
|
| 40 |
# print("device:",device)
|
|
@@ -61,6 +63,7 @@ class DeformDDPM(nn.Module):
|
|
| 61 |
)
|
| 62 |
self._DDF_Encoder_init()
|
| 63 |
self.copy_opt = nn.Identity()
|
|
|
|
| 64 |
return
|
| 65 |
|
| 66 |
def get_stn(self):
|
|
@@ -78,7 +81,8 @@ class DeformDDPM(nn.Module):
|
|
| 78 |
def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
|
| 79 |
rec_num = 1
|
| 80 |
mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
|
| 81 |
-
mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
|
|
|
|
| 82 |
# print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
|
| 83 |
# mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
|
| 84 |
# mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
|
|
@@ -110,7 +114,7 @@ class DeformDDPM(nn.Module):
|
|
| 110 |
|
| 111 |
def _get_random_ddf(self,img,t):
|
| 112 |
rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
|
| 113 |
-
ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf])
|
| 114 |
warped_img = self.img_stn(img,ddf_forward)
|
| 115 |
return warped_img, dvf_forward,ddf_forward
|
| 116 |
|
|
@@ -122,8 +126,10 @@ class DeformDDPM(nn.Module):
|
|
| 122 |
dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
|
| 123 |
dvf = dvf + dvf_rot
|
| 124 |
for ctl_sz in ctl_szs:
|
| 125 |
-
_v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=
|
| 126 |
# temp>>
|
|
|
|
|
|
|
| 127 |
if ctl_sz <= 2:
|
| 128 |
_v_scale = _v_scale/2
|
| 129 |
# temp<<
|
|
@@ -138,7 +144,7 @@ class DeformDDPM(nn.Module):
|
|
| 138 |
sample_value = np.random.uniform(low=sample_value, high=high)
|
| 139 |
return sample_value
|
| 140 |
|
| 141 |
-
def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=
|
| 142 |
crop_rate=2
|
| 143 |
for _ in range(self.ndims+1):
|
| 144 |
mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
|
|
@@ -188,11 +194,11 @@ class DeformDDPM(nn.Module):
|
|
| 188 |
else:
|
| 189 |
return ddf
|
| 190 |
|
| 191 |
-
def create_noise_map(self, img, noise_type='gaussian',
|
| 192 |
if noise_type == 'gaussian':
|
| 193 |
-
noise_map = torch.randn_like(img) *
|
| 194 |
elif noise_type == 'uniform':
|
| 195 |
-
noise_map = torch.rand_like(img) # 0-1
|
| 196 |
elif noise_type == 'binary':
|
| 197 |
noise_map = torch.bernoulli(torch.rand_like(img))
|
| 198 |
else:
|
|
@@ -220,8 +226,18 @@ class DeformDDPM(nn.Module):
|
|
| 220 |
mask = torch.zeros_like(img)
|
| 221 |
sample_ratio = 0
|
| 222 |
for i in range(self.ndims):
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
|
| 226 |
for idx in slice_idx:
|
| 227 |
mask[..., idx] = 1
|
|
@@ -243,7 +259,7 @@ class DeformDDPM(nn.Module):
|
|
| 243 |
# print("projecting dim:", i)
|
| 244 |
return proj_img/(proj_dim_num+EPS), proj_dim_num
|
| 245 |
|
| 246 |
-
def proc_cond_img(self, img, proc_type=None):
|
| 247 |
# Remove torch.no_grad() since most operations are not differentiable anyway
|
| 248 |
proc_img = img.clone().detach()
|
| 249 |
if proc_type is None:
|
|
@@ -251,7 +267,7 @@ class DeformDDPM(nn.Module):
|
|
| 251 |
proc_type = random.choices(
|
| 252 |
# ['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon'],
|
| 253 |
# weights=[1, 1, 1, 1, 1, 1, 3], k=1
|
| 254 |
-
['adding', 'independ', 'downsample', 'slice', 'none', 'uncon'],
|
| 255 |
weights=[1, 1, 1, 1, 1, 3], k=1
|
| 256 |
)[0]
|
| 257 |
mask = torch.tensor(1, device=img.device)
|
|
@@ -262,14 +278,14 @@ class DeformDDPM(nn.Module):
|
|
| 262 |
noise_map = None
|
| 263 |
if proc_type not in ['none', None, '']:
|
| 264 |
if proc_type == 'uncon':
|
| 265 |
-
noise_map = self.create_noise_map(img, noise_type=noise_type)
|
| 266 |
proc_img = noise_map
|
| 267 |
mask = torch.tensor(0, device=img.device)
|
| 268 |
cond_ratio = torch.tensor(0, device=img.device)
|
| 269 |
return proc_img, mask, cond_ratio
|
| 270 |
-
if proc_type in ['adding', 'independ', 'slice']:
|
| 271 |
# self.msk_noise_scale = 0
|
| 272 |
-
noise_map = self.create_noise_map(img, noise_type=noise_type)
|
| 273 |
if proc_type == 'adding':
|
| 274 |
proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
|
| 275 |
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
|
|
@@ -285,9 +301,12 @@ class DeformDDPM(nn.Module):
|
|
| 285 |
# proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./32, 1])
|
| 286 |
proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
|
| 287 |
cond_ratio = torch.tensor(down_ratio, device=img.device)
|
| 288 |
-
elif proc_type == 'slice':
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
| 291 |
mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
|
| 292 |
if self.msk_noise_scale == 0:
|
| 293 |
proc_img = img * mask
|
|
@@ -373,8 +392,14 @@ class DeformDDPM(nn.Module):
|
|
| 373 |
t = [t] * 1
|
| 374 |
return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I
|
| 375 |
|
| 376 |
-
def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, **kwargs):
|
| 377 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
| 379 |
else:
|
| 380 |
return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
|
@@ -446,7 +471,7 @@ class DeformDDPM(nn.Module):
|
|
| 446 |
# win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
|
| 447 |
# win_end = win_start + win_len - 1
|
| 448 |
|
| 449 |
-
k=2
|
| 450 |
# trainable_iterations = time_steps[win_start: win_start + win_len]
|
| 451 |
# trainable_iterations = random.sample(time_steps, k)
|
| 452 |
trainable_iterations = time_steps[-1:-k-1:-1]
|
|
|
|
| 27 |
padding_mode="border",
|
| 28 |
v_scale = 0.008/256,
|
| 29 |
resample_mode=None,
|
| 30 |
+
inf_mode = False,
|
| 31 |
):
|
| 32 |
super(DeformDDPM, self).__init__()
|
| 33 |
self.rec_num=2
|
|
|
|
| 36 |
self.v_scale = v_scale
|
| 37 |
self.device = device
|
| 38 |
self.msk_noise_scale = torch.tensor(0)
|
| 39 |
+
# self.msk_noise_scale = torch.tensor(1)
|
| 40 |
|
| 41 |
# print('================')
|
| 42 |
# print("device:",device)
|
|
|
|
| 63 |
)
|
| 64 |
self._DDF_Encoder_init()
|
| 65 |
self.copy_opt = nn.Identity()
|
| 66 |
+
self.inf_mode = inf_mode
|
| 67 |
return
|
| 68 |
|
| 69 |
def get_stn(self):
|
|
|
|
| 81 |
def _get_ddf_scale(self,t,divide_num=1,max_ddf_num=200): # 128
|
| 82 |
rec_num = 1
|
| 83 |
mul_num_ddf = torch.floor_divide(2*torch.pow(t,1.3), 3*divide_num).int()
|
| 84 |
+
# mul_num_dvf = torch.floor_divide(torch.pow(t,0.6), divide_num).int()
|
| 85 |
+
mul_num_dvf = torch.floor_divide(torch.pow(t,0.75), divide_num).int() # raise the power number to increase the dvf ratio, which can help the training of ddf_stn_rec and make the model more robust to large deformation
|
| 86 |
# print("time_step:",t,"mul_num_ddf:",mul_num_ddf,"mul_num_dvf:",mul_num_dvf)
|
| 87 |
# mul_num_ddf = self._sample_random_uniform_multi_order(high=mul_num_ddf)
|
| 88 |
# mul_num_dvf = self._sample_random_uniform_multi_order(high=mul_num_dvf)
|
|
|
|
| 114 |
|
| 115 |
def _get_random_ddf(self,img,t):
|
| 116 |
rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
|
| 117 |
+
ddf_forward,dvf_forward = self._random_ddf_generate(rec_num=rec_num, mul_num=[mul_num_ddf,mul_num_dvf],select_num=random.choice([1, 2, 3, 3, 4, 4]))
|
| 118 |
warped_img = self.img_stn(img,ddf_forward)
|
| 119 |
return warped_img, dvf_forward,ddf_forward
|
| 120 |
|
|
|
|
| 126 |
dvf_rot = utils.random_ddf(batch_size=self.batch_size, ndims=self.ndims, img_sz=[self.ctl_sz]*self.ndims, range_gauss=0, rot_range=np.pi/90)
|
| 127 |
dvf = dvf + dvf_rot
|
| 128 |
for ctl_sz in ctl_szs:
|
| 129 |
+
_v_scale = self._sample_random_uniform_multi_order(high=v_scale, low=0., order_num=random.choice([1, 2])) if rand_v_scale else v_scale
|
| 130 |
# temp>>
|
| 131 |
+
if ctl_sz <= 4:
|
| 132 |
+
_v_scale = _v_scale/2
|
| 133 |
if ctl_sz <= 2:
|
| 134 |
_v_scale = _v_scale/2
|
| 135 |
# temp<<
|
|
|
|
| 144 |
sample_value = np.random.uniform(low=sample_value, high=high)
|
| 145 |
return sample_value
|
| 146 |
|
| 147 |
+
def _random_ddf_generate(self,rec_num=3,mul_num=[torch.tensor([5]),torch.tensor([5])],ddf0=None,keep_inverse=False,noise_ratio=0.08,select_num=3, flip_ratio=0.5):
|
| 148 |
crop_rate=2
|
| 149 |
for _ in range(self.ndims+1):
|
| 150 |
mul_num=[torch.unsqueeze(n,-1) for n in mul_num]
|
|
|
|
| 194 |
else:
|
| 195 |
return ddf
|
| 196 |
|
| 197 |
+
def create_noise_map(self, img, noise_type='gaussian', noise_scale=0.1):
|
| 198 |
if noise_type == 'gaussian':
|
| 199 |
+
noise_map = torch.randn_like(img) * noise_scale
|
| 200 |
elif noise_type == 'uniform':
|
| 201 |
+
noise_map = torch.rand_like(img)*noise_scale*2-noise_scale # 0-1
|
| 202 |
elif noise_type == 'binary':
|
| 203 |
noise_map = torch.bernoulli(torch.rand_like(img))
|
| 204 |
else:
|
|
|
|
| 226 |
mask = torch.zeros_like(img)
|
| 227 |
sample_ratio = 0
|
| 228 |
for i in range(self.ndims):
|
| 229 |
+
if self.inf_mode:
|
| 230 |
+
if i== 0:
|
| 231 |
+
slice_num = 1 # use max slice num for inference for better performance
|
| 232 |
+
slice_idx = [self.image_chw[1]//2] # use middle slice for inference for better performance
|
| 233 |
+
else:
|
| 234 |
+
slice_num = 0
|
| 235 |
+
slice_idx = []
|
| 236 |
+
# slice_num = 1 # use max slice num for inference for better performance
|
| 237 |
+
# slice_idx = [self.image_chw[1]//2] # use middle slice for inference for better performance
|
| 238 |
+
else:
|
| 239 |
+
slice_num = random.randint(slice_num_range[0], slice_num_range[1])
|
| 240 |
+
slice_idx = random.sample(range(self.image_chw[1]), slice_num)
|
| 241 |
transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
|
| 242 |
for idx in slice_idx:
|
| 243 |
mask[..., idx] = 1
|
|
|
|
| 259 |
# print("projecting dim:", i)
|
| 260 |
return proj_img/(proj_dim_num+EPS), proj_dim_num
|
| 261 |
|
| 262 |
+
def proc_cond_img(self, img, proc_type=None,noise_scale=0.1):
|
| 263 |
# Remove torch.no_grad() since most operations are not differentiable anyway
|
| 264 |
proc_img = img.clone().detach()
|
| 265 |
if proc_type is None:
|
|
|
|
| 267 |
proc_type = random.choices(
|
| 268 |
# ['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon'],
|
| 269 |
# weights=[1, 1, 1, 1, 1, 1, 3], k=1
|
| 270 |
+
['adding', 'independ', 'downsample', 'slice','slice1', 'none', 'uncon'],
|
| 271 |
weights=[1, 1, 1, 1, 1, 3], k=1
|
| 272 |
)[0]
|
| 273 |
mask = torch.tensor(1, device=img.device)
|
|
|
|
| 278 |
noise_map = None
|
| 279 |
if proc_type not in ['none', None, '']:
|
| 280 |
if proc_type == 'uncon':
|
| 281 |
+
noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
|
| 282 |
proc_img = noise_map
|
| 283 |
mask = torch.tensor(0, device=img.device)
|
| 284 |
cond_ratio = torch.tensor(0, device=img.device)
|
| 285 |
return proc_img, mask, cond_ratio
|
| 286 |
+
if proc_type in ['adding', 'independ', 'slice','slice1']:
|
| 287 |
# self.msk_noise_scale = 0
|
| 288 |
+
noise_map = self.create_noise_map(img, noise_type=noise_type,noise_scale=noise_scale)
|
| 289 |
if proc_type == 'adding':
|
| 290 |
proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
|
| 291 |
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
|
|
|
|
| 301 |
# proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./32, 1])
|
| 302 |
proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1./64, 1])
|
| 303 |
cond_ratio = torch.tensor(down_ratio, device=img.device)
|
| 304 |
+
elif proc_type == 'slice' or proc_type == 'slice1':
|
| 305 |
+
if proc_type == 'slice1':
|
| 306 |
+
slice_num_max = 1
|
| 307 |
+
else:
|
| 308 |
+
slice_num_max = random.randint(1, 64)
|
| 309 |
+
slice_num_max = random.randint(1, slice_num_max)
|
| 310 |
mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
|
| 311 |
if self.msk_noise_scale == 0:
|
| 312 |
proc_img = img * mask
|
|
|
|
| 392 |
t = [t] * 1
|
| 393 |
return self.recover(x=noisy_imgs*mask, y=cond_imgs, t=t, rec_num=rec_num, text=text), dvf_I
|
| 394 |
|
| 395 |
+
def forward(self, img_org, cond_imgs=None, proc_type=None, T=None, output_embedding=False, **kwargs):
|
| 396 |
+
if output_embedding:
|
| 397 |
+
# Direct network forward for contrastive embedding (no diffusion).
|
| 398 |
+
# Returns img_embd so DDP's prepare_for_backward traces the correct subgraph
|
| 399 |
+
# (encoder + mid + attn + img2txt only, no decoder).
|
| 400 |
+
self.network(x=img_org, y=cond_imgs, t=T, text=kwargs.get('text'), rec_num=1)
|
| 401 |
+
return self.network.img_embd
|
| 402 |
+
elif T is not None:
|
| 403 |
return self.diff_recover(img_org=img_org, T=T, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
| 404 |
else:
|
| 405 |
return self._single_step(x0=img_org, proc_type=proc_type, cond_imgs=cond_imgs, **kwargs)
|
|
|
|
| 471 |
# win_start = random.randint(len(time_steps)//2, len(time_steps) - win_len)
|
| 472 |
# win_end = win_start + win_len - 1
|
| 473 |
|
| 474 |
+
k = 1 if len(time_steps) > 16 else 2
|
| 475 |
# trainable_iterations = time_steps[win_start: win_start + win_len]
|
| 476 |
# trainable_iterations = random.sample(time_steps, k)
|
| 477 |
trainable_iterations = time_steps[-1:-k-1:-1]
|
Diffusion/diffuser_opt.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
diffuser_opt.py — Optimized DeformDDPM subclass.
|
| 3 |
+
|
| 4 |
+
Inherits from Diffusion.diffuser.DeformDDPM and overrides only the methods
|
| 5 |
+
that benefit from optimization.
|
| 6 |
+
|
| 7 |
+
Key optimizations:
|
| 8 |
+
1. diff_recover(): hoist img_org/msk_org .clone().detach() outside the loop,
|
| 9 |
+
pre-compute timestep tensors, use torch.no_grad() for frozen steps
|
| 10 |
+
2. _random_ddf_generate(): scaling-and-squaring for O(log n) composition
|
| 11 |
+
instead of O(n), crop-first upsampling (4x faster), on-device tensors.
|
| 12 |
+
3. proc_cond_img(): skip clone for 'uncon' path (most common, ~3/8 weight)
|
| 13 |
+
4. _DDF_Encoder_init(): use OptSTN (register_buffer, no per-call .to(device))
|
| 14 |
+
5. recover(): fix t tensor bug (was staying on CPU), avoid redundant torch.tensor()
|
| 15 |
+
6. _multiscale_dvf_generate(): generate random tensors on device to avoid
|
| 16 |
+
CPU→GPU transfer of 3D volumes.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from torch import nn
|
| 20 |
+
import torch
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
import random
|
| 24 |
+
import math
|
| 25 |
+
|
| 26 |
+
import Diffusion.utils_diff as utils
|
| 27 |
+
from Diffusion.diffuser import DeformDDPM as _BaseDeformDDPM
|
| 28 |
+
from Diffusion.networks import *
|
| 29 |
+
from Diffusion.networks_opt import OptSTN
|
| 30 |
+
|
| 31 |
+
EPS = 1e-8
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DeformDDPM(_BaseDeformDDPM):
|
| 35 |
+
"""Drop-in replacement for DeformDDPM with speed optimizations."""
|
| 36 |
+
|
| 37 |
+
# ------------------------------------------------------------------
|
| 38 |
+
# Optimization 4: use OptSTN (register_buffer, no per-call .to())
|
| 39 |
+
# ------------------------------------------------------------------
|
| 40 |
+
def _DDF_Encoder_init(self, ctl_ratio=4, ctl_sz=None, resample_mode=None):
|
| 41 |
+
if ctl_sz is None:
|
| 42 |
+
ctl_sz = self.image_chw[1] // ctl_ratio
|
| 43 |
+
self.ctl_sz = ctl_sz
|
| 44 |
+
self.img_sz = self.image_chw[1]
|
| 45 |
+
# OPT: use OptSTN instead of STN — register_buffer for ref_grid/max_sz
|
| 46 |
+
self.ddf_stn_rec = OptSTN(img_sz=ctl_sz, ndims=self.ndims, device=self.device,
|
| 47 |
+
padding_mode=self.ddf_pad_mode)
|
| 48 |
+
self.img_stn = OptSTN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
|
| 49 |
+
padding_mode=self.img_pad_mode, resample_mode=self.resample_mode)
|
| 50 |
+
self.msk_stn = OptSTN(img_sz=self.img_sz, ndims=self.ndims, device=self.device,
|
| 51 |
+
padding_mode=self.img_pad_mode, resample_mode='nearest')
|
| 52 |
+
|
| 53 |
+
def __init__(self, network, n_steps=50, beta_schedule_fn=None, device='cpu',
|
| 54 |
+
image_chw=(1, 28, 28), batch_size=1, img_pad_mode="zeros",
|
| 55 |
+
ddf_pad_mode="border", padding_mode="border",
|
| 56 |
+
v_scale=0.008/256, resample_mode=None, inf_mode=False):
|
| 57 |
+
# Call parent __init__ — it creates STN instances
|
| 58 |
+
super().__init__(
|
| 59 |
+
network=network, n_steps=n_steps, beta_schedule_fn=beta_schedule_fn,
|
| 60 |
+
device=device, image_chw=image_chw, batch_size=batch_size,
|
| 61 |
+
img_pad_mode=img_pad_mode, ddf_pad_mode=ddf_pad_mode,
|
| 62 |
+
padding_mode=padding_mode, v_scale=v_scale, resample_mode=resample_mode,
|
| 63 |
+
inf_mode=inf_mode,
|
| 64 |
+
)
|
| 65 |
+
# OPT: replace ddf_stn_full with OptSTN too
|
| 66 |
+
self.ddf_stn_full = OptSTN(
|
| 67 |
+
img_sz=self.image_chw[1], ndims=self.ndims,
|
| 68 |
+
padding_mode=self.padding_mode, device=self.device,
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
# ------------------------------------------------------------------
|
| 72 |
+
# Optimization 5: fix recover() t tensor bug + avoid redundant copies
|
| 73 |
+
# ------------------------------------------------------------------
|
| 74 |
+
def recover(self, x, y, t, rec_num=2, text=None):
|
| 75 |
+
# OPT: don't recreate t if already a tensor on the right device
|
| 76 |
+
if isinstance(t, list):
|
| 77 |
+
t = [t0 if isinstance(t0, torch.Tensor) else torch.tensor(t0, device=x.device)
|
| 78 |
+
for t0 in t]
|
| 79 |
+
t = [t0.to(x.device) if t0.device != x.device else t0 for t0 in t]
|
| 80 |
+
elif isinstance(t, torch.Tensor):
|
| 81 |
+
# OPT: skip torch.tensor() copy — just ensure correct device
|
| 82 |
+
if t.device != x.device:
|
| 83 |
+
t = t.to(x.device)
|
| 84 |
+
else:
|
| 85 |
+
t = torch.tensor(t, device=x.device)
|
| 86 |
+
if rec_num is None:
|
| 87 |
+
rec_num = self.rec_num
|
| 88 |
+
return self.network(x=x, y=y, t=t, rec_num=rec_num, text=text)
|
| 89 |
+
|
| 90 |
+
# ------------------------------------------------------------------
|
| 91 |
+
# Optimization 2: scaling-and-squaring + crop-first upsample
|
| 92 |
+
# ------------------------------------------------------------------
|
| 93 |
+
def _compose_n_times(self, dvf, n):
|
| 94 |
+
"""Compute n-fold self-composition of dvf using scaling-and-squaring.
|
| 95 |
+
|
| 96 |
+
Uses binary decomposition: O(log n) STN calls instead of O(n).
|
| 97 |
+
E.g. n=87 → ~10 calls, n=200 → ~9 calls (vs 87/200 iterative calls).
|
| 98 |
+
|
| 99 |
+
The result is the same deformation (n-fold composition) but computed
|
| 100 |
+
via a different sequence of grid_sample interpolations, so there are
|
| 101 |
+
small numerical differences (~1e-2 to 1e-1) vs iterative composition.
|
| 102 |
+
This is acceptable because DDF generation is stochastic augmentation.
|
| 103 |
+
"""
|
| 104 |
+
if n <= 0:
|
| 105 |
+
return torch.zeros_like(dvf)
|
| 106 |
+
result = None
|
| 107 |
+
current = dvf # current = dvf^(2^i), starts as dvf^1
|
| 108 |
+
while n > 0:
|
| 109 |
+
if n & 1: # bit is set → accumulate this power
|
| 110 |
+
if result is None:
|
| 111 |
+
result = current.clone()
|
| 112 |
+
else:
|
| 113 |
+
# result = current ∘ result (apply result first, then current)
|
| 114 |
+
result = result + self.ddf_stn_rec(current, result)
|
| 115 |
+
n >>= 1
|
| 116 |
+
if n > 0:
|
| 117 |
+
# Square: current = current ∘ current
|
| 118 |
+
current = current + self.ddf_stn_rec(current, current)
|
| 119 |
+
return result
|
| 120 |
+
|
| 121 |
+
def _crop_upsample(self, field):
|
| 122 |
+
"""Upsample DDF from ctl_sz to img_sz with 2x oversampling + center crop.
|
| 123 |
+
|
| 124 |
+
Instead of upsampling the full ctl_sz→img_sz*2 (e.g. 32³→256³) then
|
| 125 |
+
cropping to img_sz (128³), we crop the control-point field first
|
| 126 |
+
(to ~20³) then upsample to ~160³ and crop to 128³. This is 4x faster
|
| 127 |
+
and bit-identical because trilinear interpolation is local.
|
| 128 |
+
"""
|
| 129 |
+
crop_rate = 2
|
| 130 |
+
upscale = self.img_sz * crop_rate // self.ctl_sz # e.g. 8
|
| 131 |
+
margin = 2 # voxels of margin for interpolation boundary
|
| 132 |
+
lo = self.ctl_sz // 4 - margin # e.g. 6
|
| 133 |
+
hi = self.ctl_sz * 3 // 4 + margin # e.g. 26
|
| 134 |
+
crop_sz = hi - lo # e.g. 20
|
| 135 |
+
up_sz = crop_sz * upscale # e.g. 160
|
| 136 |
+
pad = (up_sz - self.img_sz) // 2 # e.g. 16
|
| 137 |
+
|
| 138 |
+
mode = 'bilinear' if self.ndims == 2 else 'trilinear'
|
| 139 |
+
if self.ndims == 2:
|
| 140 |
+
field_crop = field[..., lo:hi, lo:hi] * self.img_sz / self.ctl_sz
|
| 141 |
+
field_up = F.interpolate(field_crop, up_sz, mode=mode)
|
| 142 |
+
return field_up[..., pad:pad + self.img_sz, pad:pad + self.img_sz]
|
| 143 |
+
else:
|
| 144 |
+
field_crop = field[..., lo:hi, lo:hi, lo:hi] * self.img_sz / self.ctl_sz
|
| 145 |
+
field_up = F.interpolate(field_crop, up_sz, mode=mode)
|
| 146 |
+
return field_up[..., pad:pad + self.img_sz,
|
| 147 |
+
pad:pad + self.img_sz,
|
| 148 |
+
pad:pad + self.img_sz]
|
| 149 |
+
|
| 150 |
+
def _random_ddf_generate(self, rec_num=3, mul_num=[torch.tensor([5]), torch.tensor([5])],
|
| 151 |
+
ddf0=None, keep_inverse=False, noise_ratio=0.08, select_num=3, flip_ratio=0.5):
|
| 152 |
+
for _ in range(self.ndims + 1):
|
| 153 |
+
mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
|
| 154 |
+
ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
|
| 155 |
+
if ddf0 is not None:
|
| 156 |
+
ddf = ddf0
|
| 157 |
+
else:
|
| 158 |
+
ddf = torch.zeros(ctl_ddf_sz, device=self.device)
|
| 159 |
+
dddf = torch.zeros(ctl_ddf_sz, device=self.device)
|
| 160 |
+
scale_num = min(8, int(math.log2(self.ctl_sz)))
|
| 161 |
+
ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
|
| 162 |
+
|
| 163 |
+
for i in range(rec_num):
|
| 164 |
+
if len(ctl_szs_all) > select_num:
|
| 165 |
+
ctl_szs = random.sample(ctl_szs_all, select_num)
|
| 166 |
+
else:
|
| 167 |
+
ctl_szs = ctl_szs_all
|
| 168 |
+
dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs)
|
| 169 |
+
if noise_ratio == 0:
|
| 170 |
+
dvf0 = dvf
|
| 171 |
+
else:
|
| 172 |
+
dvf0 = dvf + self.ddf_stn_rec(
|
| 173 |
+
self._multiscale_dvf_generate(self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False),
|
| 174 |
+
dvf)
|
| 175 |
+
|
| 176 |
+
mul_num_ddf_val = int(torch.max(mul_num[0]).item())
|
| 177 |
+
mul_num_dvf_val = int(torch.max(mul_num[1]).item())
|
| 178 |
+
|
| 179 |
+
# OPT: scaling-and-squaring — O(log n) STN calls instead of O(n)
|
| 180 |
+
# For t=40: 10 calls instead of 80. For t=79: 9 calls instead of 195.
|
| 181 |
+
ddf = self._compose_n_times(dvf0, mul_num_ddf_val)
|
| 182 |
+
dddf = self._compose_n_times(dvf, mul_num_dvf_val)
|
| 183 |
+
|
| 184 |
+
# OPT: crop-first upsample — 4x fewer voxels to interpolate (bit-identical)
|
| 185 |
+
ddf = self._crop_upsample(ddf)
|
| 186 |
+
dddf = self._crop_upsample(dddf)
|
| 187 |
+
return ddf, dddf
|
| 188 |
+
|
| 189 |
+
# ------------------------------------------------------------------
|
| 190 |
+
# Optimization 6: generate DVF on device to avoid CPU→GPU transfer
|
| 191 |
+
# ------------------------------------------------------------------
|
| 192 |
+
def _multiscale_dvf_generate(self, v_scale, ctl_szs=[4, 8, 16, 32, 64], rand_v_scale=True):
|
| 193 |
+
dvf = 0
|
| 194 |
+
if self.img_sz is None:
|
| 195 |
+
self.img_sz = max(ctl_szs)
|
| 196 |
+
if 1 in ctl_szs:
|
| 197 |
+
dvf_rot = utils.random_ddf(
|
| 198 |
+
batch_size=self.batch_size, ndims=self.ndims,
|
| 199 |
+
img_sz=[self.ctl_sz] * self.ndims, range_gauss=0, rot_range=np.pi / 90)
|
| 200 |
+
dvf = dvf + dvf_rot
|
| 201 |
+
for ctl_sz in ctl_szs:
|
| 202 |
+
_v_scale = self._sample_random_uniform_multi_order(
|
| 203 |
+
high=v_scale, low=0., order_num=random.choice([1, 1, 2])) if rand_v_scale else v_scale
|
| 204 |
+
if ctl_sz <= 2:
|
| 205 |
+
_v_scale = _v_scale / 2
|
| 206 |
+
# OPT: generate random tensor directly on device
|
| 207 |
+
dvf_comp = torch.randn([self.batch_size, self.ndims] + [ctl_sz] * self.ndims,
|
| 208 |
+
device=self.device) * _v_scale
|
| 209 |
+
dvf_comp = F.interpolate(dvf_comp * self.ctl_sz / ctl_sz, [self.ctl_sz] * self.ndims,
|
| 210 |
+
align_corners=False,
|
| 211 |
+
mode='bilinear' if self.ndims == 2 else 'trilinear')
|
| 212 |
+
dvf = dvf + dvf_comp
|
| 213 |
+
return dvf
|
| 214 |
+
|
| 215 |
+
# ------------------------------------------------------------------
|
| 216 |
+
# Optimization 3: skip clone for 'uncon' (most common conditioning type)
|
| 217 |
+
# ------------------------------------------------------------------
|
| 218 |
+
def proc_cond_img(self, img, proc_type=None, noise_scale=0.1):
|
| 219 |
+
if proc_type is None:
|
| 220 |
+
proc_type = random.choices(
|
| 221 |
+
['adding', 'independ', 'downsample', 'slice', 'slice1', 'none', 'uncon'],
|
| 222 |
+
weights=[1, 1, 1, 1, 1, 3], k=1
|
| 223 |
+
)[0]
|
| 224 |
+
mask = torch.tensor(1, device=img.device)
|
| 225 |
+
cond_ratio = torch.tensor(1., device=img.device)
|
| 226 |
+
self.msk_noise_scale = torch.tensor(0, device=img.device)
|
| 227 |
+
noise_type = random.choice(['gaussian', 'uniform', 'none'])
|
| 228 |
+
|
| 229 |
+
if proc_type not in ['none', None, '']:
|
| 230 |
+
# OPT: handle 'uncon' before cloning — no need to clone img
|
| 231 |
+
if proc_type == 'uncon':
|
| 232 |
+
noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
|
| 233 |
+
proc_img = noise_map
|
| 234 |
+
mask = torch.tensor(0, device=img.device)
|
| 235 |
+
cond_ratio = torch.tensor(0, device=img.device)
|
| 236 |
+
return proc_img, mask, cond_ratio
|
| 237 |
+
|
| 238 |
+
# Only clone when we actually need the image data
|
| 239 |
+
proc_img = img.clone().detach()
|
| 240 |
+
noise_map = None
|
| 241 |
+
if proc_type in ['adding', 'independ', 'slice', 'slice1']:
|
| 242 |
+
noise_map = self.create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
|
| 243 |
+
if proc_type == 'adding':
|
| 244 |
+
proc_img, noise_ratio = self.add_noise(proc_img, noise_map=noise_map, noise_ratio_range=[0., 1.])
|
| 245 |
+
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
|
| 246 |
+
elif proc_type == 'independ':
|
| 247 |
+
mask = self.create_noise_map(img, noise_type='binary')
|
| 248 |
+
if self.msk_noise_scale == 0:
|
| 249 |
+
proc_img = img * mask
|
| 250 |
+
else:
|
| 251 |
+
proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
cond_ratio = mask.float().mean()
|
| 254 |
+
elif proc_type == 'downsample':
|
| 255 |
+
proc_img, down_ratio = self.downsample(proc_img, down_ratio_range=[1. / 64, 1])
|
| 256 |
+
cond_ratio = torch.tensor(down_ratio, device=img.device)
|
| 257 |
+
elif proc_type == 'slice' or proc_type == 'slice1':
|
| 258 |
+
if proc_type == 'slice1':
|
| 259 |
+
slice_num_max = 1
|
| 260 |
+
else:
|
| 261 |
+
slice_num_max = random.randint(1, 64)
|
| 262 |
+
slice_num_max = random.randint(1, slice_num_max)
|
| 263 |
+
mask, sample_ratio = self.get_slice_mask(img, slice_num_range=[0, slice_num_max])
|
| 264 |
+
if self.msk_noise_scale == 0:
|
| 265 |
+
proc_img = img * mask
|
| 266 |
+
else:
|
| 267 |
+
proc_img = self.apply_noise(proc_img, noise_map=noise_map * self.msk_noise_scale, apply_mask=mask)
|
| 268 |
+
cond_ratio = torch.tensor(sample_ratio, device=img.device)
|
| 269 |
+
elif proc_type == 'project':
|
| 270 |
+
proc_img, proj_num = self.project(proc_img)
|
| 271 |
+
cond_ratio = torch.tensor(proj_num / (128 * self.ndims), device=img.device)
|
| 272 |
+
return proc_img, mask, cond_ratio
|
| 273 |
+
else:
|
| 274 |
+
# 'none' type — still need clone
|
| 275 |
+
proc_img = img.clone().detach()
|
| 276 |
+
return proc_img, mask, cond_ratio
|
| 277 |
+
|
| 278 |
+
# ------------------------------------------------------------------
|
| 279 |
+
# Optimization 1: hoist clone, pre-compute timestep tensors,
|
| 280 |
+
# use inference_mode for frozen iterations
|
| 281 |
+
# ------------------------------------------------------------------
|
| 282 |
+
def diff_recover(self, img_org, msk_org=None, T=[None, None], ddf_rand=None,
|
| 283 |
+
v_scale=None, t_save=None, cond_imgs=None, proc_type=None, text=None):
|
| 284 |
+
if cond_imgs is None:
|
| 285 |
+
cond_imgs = img_org.clone().detach()
|
| 286 |
+
cond_imgs, mask_tgt, cond_ratio = self.proc_cond_img(cond_imgs, proc_type=proc_type)
|
| 287 |
+
if ddf_rand is None:
|
| 288 |
+
if v_scale is not None:
|
| 289 |
+
self.v_scale = v_scale
|
| 290 |
+
self._DDF_Encoder_init()
|
| 291 |
+
if T[0] is None or T[0] == 0:
|
| 292 |
+
img_diff = img_org.clone().detach()
|
| 293 |
+
ddf_rand = torch.zeros_like(img_diff)
|
| 294 |
+
else:
|
| 295 |
+
img_diff, _, ddf_rand = self._get_random_ddf(
|
| 296 |
+
img=img_org, t=torch.tensor(np.array([T[0]])).to(self.device))
|
| 297 |
+
else:
|
| 298 |
+
img_diff = self.img_stn(img_org.clone().detach(), ddf_rand)
|
| 299 |
+
ddf_comp = ddf_rand.clone().detach()
|
| 300 |
+
img_rec = img_diff.clone().detach()
|
| 301 |
+
if msk_org is not None:
|
| 302 |
+
msk_diff = self.msk_stn(msk_org.clone().detach(), ddf_rand)
|
| 303 |
+
else:
|
| 304 |
+
msk_diff = None
|
| 305 |
+
msk_rec = msk_diff.clone().detach() if msk_org is not None else None
|
| 306 |
+
img_save = []
|
| 307 |
+
msk_save = []
|
| 308 |
+
|
| 309 |
+
# OPT: hoist clone().detach() outside the loop — grid_sample is read-only
|
| 310 |
+
img_org_ref = img_org.clone().detach()
|
| 311 |
+
msk_org_ref = msk_org.clone().detach() if msk_org is not None else None
|
| 312 |
+
|
| 313 |
+
if isinstance(self.network, DefRec_MutAttnNet):
|
| 314 |
+
t_list = list(range(T[1] - 1, -1, -1))
|
| 315 |
+
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t_list, rec_num=None, text=text)
|
| 316 |
+
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
|
| 317 |
+
img_rec = self.img_stn(img_org_ref, ddf_comp)
|
| 318 |
+
if msk_org is not None:
|
| 319 |
+
msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
|
| 320 |
+
else:
|
| 321 |
+
if isinstance(T[-1], int):
|
| 322 |
+
time_steps = range(T[-1] - 1, -1, -1)
|
| 323 |
+
trainable_iterations = []
|
| 324 |
+
else:
|
| 325 |
+
time_steps = T[-1]
|
| 326 |
+
k = 2
|
| 327 |
+
trainable_iterations = time_steps[-1:-k - 1:-1]
|
| 328 |
+
|
| 329 |
+
# OPT: pre-compute trainable index threshold — avoid unhashable list issue
|
| 330 |
+
t_save_set = set(t_save) if t_save is not None else None
|
| 331 |
+
num_time_steps = len(time_steps) if not isinstance(time_steps, range) else len(time_steps)
|
| 332 |
+
trainable_start_idx = num_time_steps - len(trainable_iterations)
|
| 333 |
+
|
| 334 |
+
for step_idx, i in enumerate(time_steps):
|
| 335 |
+
# OPT: create tensor directly on device, no numpy intermediate
|
| 336 |
+
t = torch.tensor([i], device=self.device)
|
| 337 |
+
|
| 338 |
+
if step_idx >= trainable_start_idx:
|
| 339 |
+
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
|
| 340 |
+
else:
|
| 341 |
+
# OPT: no_grad for frozen iterations (inference_mode not safe here
|
| 342 |
+
# because ddf_comp is composed across frozen+trainable iterations)
|
| 343 |
+
with torch.no_grad():
|
| 344 |
+
pre_dvf_I = self.recover(x=img_rec, y=cond_imgs, t=t, rec_num=None, text=text)
|
| 345 |
+
|
| 346 |
+
ddf_comp = self.ddf_stn_full(ddf_comp, pre_dvf_I) + pre_dvf_I
|
| 347 |
+
# OPT: use pre-cloned reference instead of cloning each iteration
|
| 348 |
+
img_rec = self.img_stn(img_org_ref, ddf_comp)
|
| 349 |
+
if msk_org is not None:
|
| 350 |
+
msk_rec = self.msk_stn(msk_org_ref, ddf_comp)
|
| 351 |
+
if t_save_set is not None:
|
| 352 |
+
if i in t_save_set:
|
| 353 |
+
img_save.append(img_rec)
|
| 354 |
+
if msk_org is not None:
|
| 355 |
+
msk_save.append(msk_rec)
|
| 356 |
+
|
| 357 |
+
return [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save]
|
Diffusion/losses.py
CHANGED
|
@@ -21,7 +21,7 @@ class LMSE(torch.nn.Module):
|
|
| 21 |
Labeled Mean Square Error (LMSE)
|
| 22 |
"""
|
| 23 |
|
| 24 |
-
def __init__(self, eps=1e-7, relate_eps=
|
| 25 |
super(LMSE, self).__init__()
|
| 26 |
self.eps = eps
|
| 27 |
self.relate_eps = relate_eps
|
|
@@ -72,7 +72,7 @@ class LNCC(torch.nn.Module):
|
|
| 72 |
Local (over window) normalized cross-correlation (LNCC)
|
| 73 |
"""
|
| 74 |
|
| 75 |
-
def __init__(self, win=None, num_ch=1, eps=1e-
|
| 76 |
super(LNCC, self).__init__()
|
| 77 |
self.scale = 2e0
|
| 78 |
self.win = win
|
|
@@ -84,11 +84,11 @@ class LNCC(torch.nn.Module):
|
|
| 84 |
|
| 85 |
# Set window size
|
| 86 |
if self.win is None:
|
| 87 |
-
self.win = [
|
| 88 |
self.padding = [(w-1) // 2 for w in self.win]
|
| 89 |
|
| 90 |
if smooth:
|
| 91 |
-
self.kernels = self._build_kernel(std=0.
|
| 92 |
self.sum_filt = self._build_kernel(std=0.0)
|
| 93 |
|
| 94 |
def _build_kernel(self, std=0.0):
|
|
@@ -153,7 +153,7 @@ class LNCC(torch.nn.Module):
|
|
| 153 |
J_var = J2_sum
|
| 154 |
|
| 155 |
# cc = (cross * cross) / (I_var * J_var + self.eps)
|
| 156 |
-
cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
|
| 157 |
if label is not None:
|
| 158 |
label = label.float()
|
| 159 |
cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
|
|
@@ -164,6 +164,43 @@ class LNCC(torch.nn.Module):
|
|
| 164 |
return -self.lncc(I*self.scale, J*self.scale, label=label)
|
| 165 |
|
| 166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
class NCC(torch.nn.Module):
|
| 169 |
# def __init__(self, eps_scale=10e-7,img_sz=256):
|
|
@@ -236,7 +273,7 @@ class Grad(torch.nn.Module):
|
|
| 236 |
N-D gradient loss
|
| 237 |
"""
|
| 238 |
|
| 239 |
-
def __init__(self, penalty=['l1'],ndims=
|
| 240 |
super(Grad, self).__init__()
|
| 241 |
self.penalty = penalty
|
| 242 |
self.eps = eps
|
|
@@ -521,7 +558,7 @@ if __name__ == "__main__":
|
|
| 521 |
img3d_t = torch.empty(1,1,size,size,size).uniform_(0,1)#*-0.000001
|
| 522 |
# img3d_t = img3d.clone().detach()
|
| 523 |
# img3d_t = torch.zeros_like(img3d)
|
| 524 |
-
translation =
|
| 525 |
start = 0
|
| 526 |
end = 32
|
| 527 |
# img3d_t[:,:,translation:,translation:,translation:] = img3d[:,:,:size-translation,:size-translation,:size-translation]
|
|
|
|
| 21 |
Labeled Mean Square Error (LMSE)
|
| 22 |
"""
|
| 23 |
|
| 24 |
+
def __init__(self, eps=1e-7, relate_eps=1e-1, win=None, smooth=False):
|
| 25 |
super(LMSE, self).__init__()
|
| 26 |
self.eps = eps
|
| 27 |
self.relate_eps = relate_eps
|
|
|
|
| 72 |
Local (over window) normalized cross-correlation (LNCC)
|
| 73 |
"""
|
| 74 |
|
| 75 |
+
def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
|
| 76 |
super(LNCC, self).__init__()
|
| 77 |
self.scale = 2e0
|
| 78 |
self.win = win
|
|
|
|
| 84 |
|
| 85 |
# Set window size
|
| 86 |
if self.win is None:
|
| 87 |
+
self.win = [11] * self.ndims
|
| 88 |
self.padding = [(w-1) // 2 for w in self.win]
|
| 89 |
|
| 90 |
if smooth:
|
| 91 |
+
self.kernels = self._build_kernel(std=0.5)
|
| 92 |
self.sum_filt = self._build_kernel(std=0.0)
|
| 93 |
|
| 94 |
def _build_kernel(self, std=0.0):
|
|
|
|
| 153 |
J_var = J2_sum
|
| 154 |
|
| 155 |
# cc = (cross * cross) / (I_var * J_var + self.eps)
|
| 156 |
+
cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps) # eps must be large enough to avoid numerical unstability
|
| 157 |
if label is not None:
|
| 158 |
label = label.float()
|
| 159 |
cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
|
|
|
|
| 164 |
return -self.lncc(I*self.scale, J*self.scale, label=label)
|
| 165 |
|
| 166 |
|
| 167 |
+
class MSLNCC(LNCC):
|
| 168 |
+
"""
|
| 169 |
+
Multi-Scale Local Normalized Cross-Correlation (MSLNCC)
|
| 170 |
+
Computes LNCC at multiple scales and combines with weighted sum.
|
| 171 |
+
Images are downsampled via average pooling, labels via max pooling.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
|
| 175 |
+
scale_ratios=[1, 0.5, 0.25], scale_weights=[0.25, 0.5, 0.75]):
|
| 176 |
+
super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
|
| 177 |
+
central=central, smooth=smooth)
|
| 178 |
+
if win is None:
|
| 179 |
+
win = [9] * self.ndims
|
| 180 |
+
self.scale_ratios = scale_ratios
|
| 181 |
+
self.scale_weights = scale_weights
|
| 182 |
+
|
| 183 |
+
def _downsample(self, I, J, label, ratio):
|
| 184 |
+
"""Downsample images via average pooling, labels via max pooling."""
|
| 185 |
+
if ratio >= 1.0:
|
| 186 |
+
return I, J, label
|
| 187 |
+
factor = int(1.0 / ratio)
|
| 188 |
+
I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
|
| 189 |
+
J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
|
| 190 |
+
label_down = None
|
| 191 |
+
if label is not None:
|
| 192 |
+
label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
|
| 193 |
+
return I_down, J_down, label_down
|
| 194 |
+
|
| 195 |
+
def forward(self, I, J, label=None):
|
| 196 |
+
total_loss = 0.0
|
| 197 |
+
total_weight = 0.0
|
| 198 |
+
for ratio, weight in zip(self.scale_ratios, self.scale_weights):
|
| 199 |
+
I_s, J_s, label_s = self._downsample(I, J, label, ratio)
|
| 200 |
+
total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
|
| 201 |
+
total_weight += weight
|
| 202 |
+
return -total_loss / total_weight
|
| 203 |
+
|
| 204 |
|
| 205 |
class NCC(torch.nn.Module):
|
| 206 |
# def __init__(self, eps_scale=10e-7,img_sz=256):
|
|
|
|
| 273 |
N-D gradient loss
|
| 274 |
"""
|
| 275 |
|
| 276 |
+
def __init__(self, penalty=['l1'],ndims=3, eps=1e-8, outrange_weight=1e4,outrange_thresh=0.5, detj_weight=1e4, apear_scale=8, dist=1, sign=1,waive_thresh=10**-4):
|
| 277 |
super(Grad, self).__init__()
|
| 278 |
self.penalty = penalty
|
| 279 |
self.eps = eps
|
|
|
|
| 558 |
img3d_t = torch.empty(1,1,size,size,size).uniform_(0,1)#*-0.000001
|
| 559 |
# img3d_t = img3d.clone().detach()
|
| 560 |
# img3d_t = torch.zeros_like(img3d)
|
| 561 |
+
translation = 16
|
| 562 |
start = 0
|
| 563 |
end = 32
|
| 564 |
# img3d_t[:,:,translation:,translation:,translation:] = img3d[:,:,:size-translation,:size-translation,:size-translation]
|
Diffusion/losses_opt.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
losses_opt.py — Optimized loss functions.
|
| 3 |
+
|
| 4 |
+
Inherits from Diffusion.losses and overrides LNCC and MSLNCC to use
|
| 5 |
+
register_buffer for convolution kernels (auto device transfer, no
|
| 6 |
+
per-call .to(device) overhead).
|
| 7 |
+
|
| 8 |
+
All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) are re-exported
|
| 9 |
+
unchanged.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
# Re-export unchanged classes
|
| 17 |
+
from Diffusion.losses import (
|
| 18 |
+
LMSE,
|
| 19 |
+
NCC,
|
| 20 |
+
MRSE,
|
| 21 |
+
RMSE,
|
| 22 |
+
Grad,
|
| 23 |
+
avg_std_skew_kurt,
|
| 24 |
+
grad_std,
|
| 25 |
+
avg_std,
|
| 26 |
+
EPS,
|
| 27 |
+
eps_scale,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class LNCC(torch.nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
Local (over window) normalized cross-correlation (LNCC).
|
| 34 |
+
Optimized: kernels stored as registered buffers for automatic device transfer.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
|
| 38 |
+
super(LNCC, self).__init__()
|
| 39 |
+
self.scale = 2e0
|
| 40 |
+
self.win = win
|
| 41 |
+
self.eps = eps
|
| 42 |
+
self.central = central
|
| 43 |
+
self.ndims = 3
|
| 44 |
+
self.strides = [1] * (self.ndims + 2)
|
| 45 |
+
self.smooth = smooth
|
| 46 |
+
|
| 47 |
+
if self.win is None:
|
| 48 |
+
self.win = [11] * self.ndims
|
| 49 |
+
self.padding = [(w - 1) // 2 for w in self.win]
|
| 50 |
+
|
| 51 |
+
if smooth:
|
| 52 |
+
self.tail = None # will be set in _build_kernel
|
| 53 |
+
kernels = self._build_kernel(std=0.5)
|
| 54 |
+
self.register_buffer('kernels', kernels) # OPT: auto device transfer
|
| 55 |
+
self.register_buffer('sum_filt', self._build_kernel(std=0.0)) # OPT: auto device transfer
|
| 56 |
+
|
| 57 |
+
def _build_kernel(self, std=0.0):
|
| 58 |
+
if std == 0.0:
|
| 59 |
+
return torch.ones([1, 1, *self.win]) / np.prod(self.win)
|
| 60 |
+
else:
|
| 61 |
+
self.tail = int(np.ceil(std)) * 2
|
| 62 |
+
k = torch.exp(-0.5 * (torch.arange(-self.tail, self.tail + 1, dtype=torch.float32) ** 2) / std ** 2)
|
| 63 |
+
kernel = k / torch.sum(k)
|
| 64 |
+
kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
|
| 65 |
+
return kernel.unsqueeze(0).unsqueeze(0)
|
| 66 |
+
|
| 67 |
+
def lncc(self, I, J, label=None):
|
| 68 |
+
# OPT: no .to(I.device) needed — buffers auto-transfer with module.to()
|
| 69 |
+
|
| 70 |
+
if self.smooth:
|
| 71 |
+
I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=self.tail)
|
| 72 |
+
J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=self.tail)
|
| 73 |
+
|
| 74 |
+
I2 = I * I
|
| 75 |
+
J2 = J * J
|
| 76 |
+
IJ = I * J
|
| 77 |
+
|
| 78 |
+
if self.central:
|
| 79 |
+
I_sum = torch.nn.functional.conv3d(I, self.sum_filt, stride=1, padding=self.padding)
|
| 80 |
+
J_sum = torch.nn.functional.conv3d(J, self.sum_filt, stride=1, padding=self.padding)
|
| 81 |
+
I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=self.padding)
|
| 82 |
+
J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=self.padding)
|
| 83 |
+
IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=self.padding)
|
| 84 |
+
|
| 85 |
+
cross = IJ_sum - (I_sum * J_sum)
|
| 86 |
+
I_var = I2_sum - (I_sum * I_sum)
|
| 87 |
+
J_var = J2_sum - (J_sum * J_sum)
|
| 88 |
+
else:
|
| 89 |
+
I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=self.padding)
|
| 90 |
+
J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=self.padding)
|
| 91 |
+
IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=self.padding)
|
| 92 |
+
|
| 93 |
+
cross = IJ_sum
|
| 94 |
+
I_var = I2_sum
|
| 95 |
+
J_var = J2_sum
|
| 96 |
+
|
| 97 |
+
cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
|
| 98 |
+
if label is not None:
|
| 99 |
+
label = label.float()
|
| 100 |
+
cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
|
| 101 |
+
|
| 102 |
+
return torch.mean(cc)
|
| 103 |
+
|
| 104 |
+
def forward(self, I, J, label=None):
|
| 105 |
+
return -self.lncc(I * self.scale, J * self.scale, label=label)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class MSLNCC(LNCC):
|
| 109 |
+
"""
|
| 110 |
+
Multi-Scale Local Normalized Cross-Correlation (MSLNCC).
|
| 111 |
+
Optimized: inherits buffer-based kernels from LNCC.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
|
| 115 |
+
scale_ratios=[1, 0.5, 0.25], scale_weights=[0.75, 0.5, 0.25]):
|
| 116 |
+
super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
|
| 117 |
+
central=central, smooth=smooth)
|
| 118 |
+
if win is None:
|
| 119 |
+
win = [9] * self.ndims
|
| 120 |
+
self.scale_ratios = scale_ratios
|
| 121 |
+
self.scale_weights = scale_weights
|
| 122 |
+
|
| 123 |
+
def _downsample(self, I, J, label, ratio):
|
| 124 |
+
if ratio >= 1.0:
|
| 125 |
+
return I, J, label
|
| 126 |
+
factor = int(1.0 / ratio)
|
| 127 |
+
I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
|
| 128 |
+
J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
|
| 129 |
+
label_down = None
|
| 130 |
+
if label is not None:
|
| 131 |
+
label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
|
| 132 |
+
return I_down, J_down, label_down
|
| 133 |
+
|
| 134 |
+
def forward(self, I, J, label=None):
|
| 135 |
+
total_loss = 0.0
|
| 136 |
+
total_weight = 0.0
|
| 137 |
+
for ratio, weight in zip(self.scale_ratios, self.scale_weights):
|
| 138 |
+
I_s, J_s, label_s = self._downsample(I, J, label, ratio)
|
| 139 |
+
total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
|
| 140 |
+
total_weight += weight
|
| 141 |
+
return -total_loss / total_weight
|
Diffusion/networks.py
CHANGED
|
@@ -1,8 +1,28 @@
|
|
| 1 |
from torch import nn
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import math
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
def get_net(name="recresnet"):
|
| 8 |
name = name.lower()
|
|
@@ -16,8 +36,10 @@ def get_net(name="recresnet"):
|
|
| 16 |
net = RecMutAttnNet1
|
| 17 |
elif name == "defrecmutattnnet":
|
| 18 |
net = DefRec_MutAttnNet
|
| 19 |
-
elif name == "
|
| 20 |
-
net =
|
|
|
|
|
|
|
| 21 |
else:
|
| 22 |
net = None
|
| 23 |
return net
|
|
@@ -440,6 +462,7 @@ class DefRec_MutAttnNet(nn.Module):
|
|
| 440 |
nn.Linear(dim_out, dim_out)
|
| 441 |
)
|
| 442 |
|
|
|
|
| 443 |
class RecMutAttnNet1(nn.Module):
|
| 444 |
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 445 |
super(RecMutAttnNet1, self).__init__()
|
|
@@ -749,6 +772,8 @@ class RecMutAttnNet(nn.Module):
|
|
| 749 |
else:
|
| 750 |
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 751 |
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
|
|
|
|
|
|
| 752 |
|
| 753 |
return ddf
|
| 754 |
|
|
@@ -759,9 +784,9 @@ class RecMutAttnNet(nn.Module):
|
|
| 759 |
nn.Linear(dim_out, dim_out)
|
| 760 |
)
|
| 761 |
|
| 762 |
-
class
|
| 763 |
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 764 |
-
super(
|
| 765 |
|
| 766 |
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 767 |
self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
|
@@ -785,16 +810,21 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 785 |
self.block_down = nn.ModuleList()
|
| 786 |
self.block_up = nn.ModuleList()
|
| 787 |
if self.conditional_input:
|
|
|
|
|
|
|
| 788 |
self.block_down_cond = nn.ModuleList()
|
| 789 |
self.fuse_conv0 = nn.ModuleList()
|
| 790 |
self.fuse_conv1 = nn.ModuleList()
|
| 791 |
-
self.
|
|
|
|
| 792 |
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| 793 |
self.global_maxpool = Global_Maxpool(1)
|
| 794 |
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| 795 |
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| 796 |
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| 797 |
-
self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
|
|
|
|
|
|
|
| 798 |
self.img_res = [res]*self.dimension
|
| 799 |
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| 800 |
[1, self.dimension]+list(self.img_res))
|
|
@@ -811,6 +841,11 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 811 |
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 812 |
))
|
| 813 |
if self.conditional_input:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 814 |
self.block_down_cond.append(nn.Sequential(
|
| 815 |
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 816 |
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
|
@@ -829,12 +864,14 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 829 |
))
|
| 830 |
|
| 831 |
# Bottleneck
|
|
|
|
| 832 |
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 833 |
self.b_mid = nn.Sequential(
|
| 834 |
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 835 |
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 836 |
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 837 |
)
|
|
|
|
| 838 |
|
| 839 |
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 840 |
|
|
@@ -860,6 +897,7 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 860 |
self.max_sz = [img_sz[0]] * self.dimension
|
| 861 |
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 862 |
|
|
|
|
| 863 |
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 864 |
if list(img_sz) != self.img_res:
|
| 865 |
# print ("Reinitialize the ref_grid to match the model's input image size.")
|
|
@@ -870,6 +908,13 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 870 |
|
| 871 |
img = x
|
| 872 |
t = self.time_embed(t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 873 |
|
| 874 |
for rec_id in range(rec_num):
|
| 875 |
if self.conditional_input:
|
|
@@ -879,7 +924,7 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 879 |
for i in range(self.hier_num):
|
| 880 |
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 881 |
if self.conditional_input:
|
| 882 |
-
tgt = self.block_down_cond[i](tgt)
|
| 883 |
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 884 |
tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 885 |
enc_list.append(out)
|
|
@@ -893,19 +938,24 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 893 |
# out += self.attn_layer(out, tgt, tgt)[0]
|
| 894 |
out_shape = out.shape
|
| 895 |
tgt_shape = tgt.shape
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
out_attn, _ = self.
|
|
|
|
| 899 |
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
|
|
|
|
| 900 |
out = out + out_attn
|
|
|
|
|
|
|
| 901 |
|
| 902 |
if self.conditional_input:
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
|
|
|
| 909 |
out_txt = self.txt_proc(out_txt)
|
| 910 |
out_txt = self.txt2img(out_txt)
|
| 911 |
out = out + out_txt
|
|
@@ -922,8 +972,264 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 922 |
else:
|
| 923 |
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 924 |
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 925 |
|
| 926 |
-
return ddf
|
| 927 |
|
| 928 |
def _make_te(self, dim_in, dim_out):
|
| 929 |
return nn.Sequential(
|
|
@@ -931,6 +1237,8 @@ class RecMutAttnNet_contrastive(nn.Module):
|
|
| 931 |
nn.ReLU(),
|
| 932 |
nn.Linear(dim_out, dim_out)
|
| 933 |
)
|
|
|
|
|
|
|
| 934 |
# class RecMutAttnNet(nn.Module):
|
| 935 |
# def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
|
| 936 |
# super(RecMutAttnNet, self).__init__()
|
|
@@ -1085,6 +1393,8 @@ def composite(ddfs,stn=None):
|
|
| 1085 |
comp_ddf = ddfs[i] + stn(comp_ddf,ddfs[i])
|
| 1086 |
return comp_ddf
|
| 1087 |
|
|
|
|
|
|
|
| 1088 |
class STN(nn.Module):
|
| 1089 |
def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None):
|
| 1090 |
super(STN, self).__init__()
|
|
@@ -1148,6 +1458,7 @@ class STN(nn.Module):
|
|
| 1148 |
resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
|
| 1149 |
return resampled_x
|
| 1150 |
|
|
|
|
| 1151 |
if __name__ == '__main__':
|
| 1152 |
ndims = 3
|
| 1153 |
res = 128
|
|
|
|
| 1 |
from torch import nn
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.checkpoint import checkpoint as grad_checkpoint
|
| 5 |
import numpy as np
|
| 6 |
import math
|
| 7 |
+
from Diffusion.safe_conv_transpose import SafeConvTranspose3d
|
| 8 |
+
|
| 9 |
+
class UpsampleConv(nn.Module):
|
| 10 |
+
"""Drop-in replacement for ConvTranspose3d/2d that avoids the XPU memory leak.
|
| 11 |
+
ConvTranspose3d backward leaks ~0.33 GiB/step on Intel XPU (oneDNN bug).
|
| 12 |
+
This uses F.interpolate (zero leak) + Conv (negligible leak) instead.
|
| 13 |
+
Also avoids checkerboard artifacts common with transposed convolutions.
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, ndims=3):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.scale_factor = stride
|
| 18 |
+
self.mode = 'trilinear' if ndims == 3 else 'bilinear'
|
| 19 |
+
Conv = getattr(nn, f'Conv{ndims}d')
|
| 20 |
+
self.conv = Conv(in_channels, out_channels, 3, 1, 1)
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
|
| 24 |
+
return self.conv(x)
|
| 25 |
+
|
| 26 |
|
| 27 |
def get_net(name="recresnet"):
|
| 28 |
name = name.lower()
|
|
|
|
| 36 |
net = RecMutAttnNet1
|
| 37 |
elif name == "defrecmutattnnet":
|
| 38 |
net = DefRec_MutAttnNet
|
| 39 |
+
elif name == "recmulmodmutattnnet":
|
| 40 |
+
net = RecMulModMutAttnNet
|
| 41 |
+
elif name == "om_net":
|
| 42 |
+
net = OM_net
|
| 43 |
else:
|
| 44 |
net = None
|
| 45 |
return net
|
|
|
|
| 462 |
nn.Linear(dim_out, dim_out)
|
| 463 |
)
|
| 464 |
|
| 465 |
+
|
| 466 |
class RecMutAttnNet1(nn.Module):
|
| 467 |
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 468 |
super(RecMutAttnNet1, self).__init__()
|
|
|
|
| 772 |
else:
|
| 773 |
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 774 |
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 775 |
+
|
| 776 |
+
# print(torch.max(torch.abs(ddf)))
|
| 777 |
|
| 778 |
return ddf
|
| 779 |
|
|
|
|
| 784 |
nn.Linear(dim_out, dim_out)
|
| 785 |
)
|
| 786 |
|
| 787 |
+
class RecMulModMutAttnNet(nn.Module):
|
| 788 |
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 789 |
+
super(RecMulModMutAttnNet, self).__init__()
|
| 790 |
|
| 791 |
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 792 |
self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
|
|
|
| 810 |
self.block_down = nn.ModuleList()
|
| 811 |
self.block_up = nn.ModuleList()
|
| 812 |
if self.conditional_input:
|
| 813 |
+
# self.gate_img = nn.ModuleList()
|
| 814 |
+
self.txt_layers = nn.ModuleList()
|
| 815 |
self.block_down_cond = nn.ModuleList()
|
| 816 |
self.fuse_conv0 = nn.ModuleList()
|
| 817 |
self.fuse_conv1 = nn.ModuleList()
|
| 818 |
+
self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 819 |
+
self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 820 |
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| 821 |
self.global_maxpool = Global_Maxpool(1)
|
| 822 |
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| 823 |
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| 824 |
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| 825 |
+
# self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
|
| 826 |
+
self.text = torch.zeros(1, self.text_feat_chn)
|
| 827 |
+
|
| 828 |
self.img_res = [res]*self.dimension
|
| 829 |
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| 830 |
[1, self.dimension]+list(self.img_res))
|
|
|
|
| 841 |
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 842 |
))
|
| 843 |
if self.conditional_input:
|
| 844 |
+
# self.gate_img.append(nn.Sequential(
|
| 845 |
+
# nn.ConvNd(self.dimension, self.feat_channels[i], self.feat_channels[i], kernel_size=1, stride=1, padding=0),
|
| 846 |
+
# nn.Sigmoid()
|
| 847 |
+
# ))
|
| 848 |
+
self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i])))
|
| 849 |
self.block_down_cond.append(nn.Sequential(
|
| 850 |
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 851 |
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
|
|
|
| 864 |
))
|
| 865 |
|
| 866 |
# Bottleneck
|
| 867 |
+
self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn)))
|
| 868 |
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 869 |
self.b_mid = nn.Sequential(
|
| 870 |
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 871 |
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 872 |
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 873 |
)
|
| 874 |
+
self.fuse = self.Conv(2*self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0)
|
| 875 |
|
| 876 |
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 877 |
|
|
|
|
| 897 |
self.max_sz = [img_sz[0]] * self.dimension
|
| 898 |
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 899 |
|
| 900 |
+
|
| 901 |
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 902 |
if list(img_sz) != self.img_res:
|
| 903 |
# print ("Reinitialize the ref_grid to match the model's input image size.")
|
|
|
|
| 908 |
|
| 909 |
img = x
|
| 910 |
t = self.time_embed(t)
|
| 911 |
+
if text is None:
|
| 912 |
+
text = self.text
|
| 913 |
+
# print(text.shape)
|
| 914 |
+
text = text.to(self.device)
|
| 915 |
+
txt_shape = [1,-1]+[1]*self.dimension
|
| 916 |
+
else:
|
| 917 |
+
txt_shape = [n,-1]+[1]*self.dimension
|
| 918 |
|
| 919 |
for rec_id in range(rec_num):
|
| 920 |
if self.conditional_input:
|
|
|
|
| 924 |
for i in range(self.hier_num):
|
| 925 |
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 926 |
if self.conditional_input:
|
| 927 |
+
tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
|
| 928 |
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 929 |
tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 930 |
enc_list.append(out)
|
|
|
|
| 938 |
# out += self.attn_layer(out, tgt, tgt)[0]
|
| 939 |
out_shape = out.shape
|
| 940 |
tgt_shape = tgt.shape
|
| 941 |
+
out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 942 |
+
tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 943 |
+
out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
|
| 944 |
+
tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
|
| 945 |
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
|
| 946 |
+
tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape) # (H*W, N, C) -> (N, C, H, W)
|
| 947 |
out = out + out_attn
|
| 948 |
+
tgt = tgt + tgt_attn
|
| 949 |
+
out = self.fuse(torch.cat([out, tgt], dim=1))
|
| 950 |
|
| 951 |
if self.conditional_input:
|
| 952 |
+
|
| 953 |
+
# text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))
|
| 954 |
+
|
| 955 |
+
# out_txt = self.img2txt(out) + text.reshape(txt_shape)
|
| 956 |
+
img_txt_feat = self.img2txt(out)
|
| 957 |
+
self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024]
|
| 958 |
+
out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
|
| 959 |
out_txt = self.txt_proc(out_txt)
|
| 960 |
out_txt = self.txt2img(out_txt)
|
| 961 |
out = out + out_txt
|
|
|
|
| 972 |
else:
|
| 973 |
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 974 |
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 975 |
+
|
| 976 |
+
# print(torch.max(torch.abs(ddf)))
|
| 977 |
+
|
| 978 |
+
return ddf
|
| 979 |
+
|
| 980 |
+
def _make_te(self, dim_in, dim_out):
|
| 981 |
+
return nn.Sequential(
|
| 982 |
+
nn.Linear(dim_in, dim_out),
|
| 983 |
+
nn.ReLU(),
|
| 984 |
+
nn.Linear(dim_out, dim_out)
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
|
| 988 |
+
class OM_net(nn.Module):
|
| 989 |
+
"""
|
| 990 |
+
Extended RecMulModMutAttnNet with gated attention mechanisms:
|
| 991 |
+
1. Text Gate (bottleneck): sigmoid weight w_txt to interpolate between
|
| 992 |
+
text-enhanced features and raw image features. Learns to suppress
|
| 993 |
+
text branch when text embedding is zeros (no text provided).
|
| 994 |
+
2. Target Gate (each encoder level): per-voxel spatial gate using
|
| 995 |
+
residual AtrousBlock to identify condition vs. noise voxels in the
|
| 996 |
+
target/condition image path, weighting the fuse_conv1 output.
|
| 997 |
+
|
| 998 |
+
Supports gradient checkpointing via `use_checkpoint` flag to reduce
|
| 999 |
+
peak activation memory (trades compute for memory).
|
| 1000 |
+
"""
|
| 1001 |
+
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0,
|
| 1002 |
+
conditional_input=True, text_feat_chn=1024, num_heads=4,
|
| 1003 |
+
use_conv_transpose=False):
|
| 1004 |
+
super(OM_net, self).__init__()
|
| 1005 |
+
self.use_checkpoint = False # Set True to enable gradient checkpointing
|
| 1006 |
+
self.use_conv_transpose = use_conv_transpose
|
| 1007 |
+
|
| 1008 |
+
self.feat_channels = [num_input_chn, 12, 32, 64, 128, 512]
|
| 1009 |
+
self.conditional_input = conditional_input
|
| 1010 |
+
self.num_heads = num_heads
|
| 1011 |
+
self.text_feat_chn = text_feat_chn
|
| 1012 |
+
|
| 1013 |
+
self.dimension = ndims
|
| 1014 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 1015 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 1016 |
+
|
| 1017 |
+
# Sinusoidal embedding
|
| 1018 |
+
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 1019 |
+
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 1020 |
+
self.time_embed.requires_grad_(False)
|
| 1021 |
+
self.hier_num = len(self.feat_channels) - 1
|
| 1022 |
+
self.down_layers = nn.ModuleList()
|
| 1023 |
+
self.up_layers = nn.ModuleList()
|
| 1024 |
+
self.ted_layers = nn.ModuleList()
|
| 1025 |
+
self.teu_layers = nn.ModuleList()
|
| 1026 |
+
self.block_down = nn.ModuleList()
|
| 1027 |
+
self.block_up = nn.ModuleList()
|
| 1028 |
+
if self.conditional_input:
|
| 1029 |
+
self.txt_layers = nn.ModuleList()
|
| 1030 |
+
self.block_down_cond = nn.ModuleList()
|
| 1031 |
+
self.fuse_conv0 = nn.ModuleList()
|
| 1032 |
+
self.fuse_conv1 = nn.ModuleList()
|
| 1033 |
+
self.tgt_gate = nn.ModuleList() # Target gate per encoder level
|
| 1034 |
+
self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 1035 |
+
self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 1036 |
+
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| 1037 |
+
self.global_maxpool = Global_Maxpool(1)
|
| 1038 |
+
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| 1039 |
+
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| 1040 |
+
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| 1041 |
+
self.text = torch.zeros(1, self.text_feat_chn)
|
| 1042 |
+
|
| 1043 |
+
# Text Gate: text-only MLP → sigmoid weight (computed before rec loop)
|
| 1044 |
+
self.text_gate = nn.Sequential(
|
| 1045 |
+
nn.Linear(self.text_feat_chn, self.text_feat_chn // 4),
|
| 1046 |
+
nn.ReLU(),
|
| 1047 |
+
nn.Linear(self.text_feat_chn // 4, 1),
|
| 1048 |
+
nn.Sigmoid()
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
self.img_res = [res]*self.dimension
|
| 1052 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| 1053 |
+
[1, self.dimension]+list(self.img_res))
|
| 1054 |
+
|
| 1055 |
+
for i in range(1, self.hier_num + 1):
|
| 1056 |
+
j=-i
|
| 1057 |
+
self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 1058 |
+
self.up_layers.append(SafeConvTranspose3d(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 1059 |
+
self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 1060 |
+
self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 1061 |
+
self.block_down.append(nn.Sequential(
|
| 1062 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 1063 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 1064 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 1065 |
+
))
|
| 1066 |
+
if self.conditional_input:
|
| 1067 |
+
self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i])))
|
| 1068 |
+
self.block_down_cond.append(nn.Sequential(
|
| 1069 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 1070 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 1071 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 1072 |
+
))
|
| 1073 |
+
self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 1074 |
+
self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 1075 |
+
# Target Gate: residual AtrousBlock → 2-channel softmax (condition vs noise)
|
| 1076 |
+
self.tgt_gate.append(nn.Sequential(
|
| 1077 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims,
|
| 1078 |
+
self.feat_channels[i], self.feat_channels[i], ndims=ndims, atrous_rates=[1, 3]),
|
| 1079 |
+
self.Conv(self.feat_channels[i], 2, 1, 1, 0)
|
| 1080 |
+
))
|
| 1081 |
+
if i==self.hier_num:
|
| 1082 |
+
k=j
|
| 1083 |
+
else:
|
| 1084 |
+
k=j-1
|
| 1085 |
+
self.block_up.append(nn.Sequential(
|
| 1086 |
+
AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 1087 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 1088 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 1089 |
+
))
|
| 1090 |
+
|
| 1091 |
+
# Bottleneck
|
| 1092 |
+
self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn)))
|
| 1093 |
+
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 1094 |
+
self.b_mid = nn.Sequential(
|
| 1095 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 1096 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 1097 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 1098 |
+
)
|
| 1099 |
+
self.fuse = self.Conv(2*self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0)
|
| 1100 |
+
|
| 1101 |
+
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 1102 |
+
|
| 1103 |
+
# Initialize target gates toward pass-through (condition confidence high)
|
| 1104 |
+
self._init_tgt_gates()
|
| 1105 |
+
|
| 1106 |
+
def _init_tgt_gates(self):
|
| 1107 |
+
"""Bias target gates so condition channel starts moderately high (~0.73).
|
| 1108 |
+
Milder than [2,-2] to ensure both cond*tgt and (1-cond)*out halves of
|
| 1109 |
+
fuse_conv1 input have enough signal for healthy early gradient flow."""
|
| 1110 |
+
for gate_seq in self.tgt_gate:
|
| 1111 |
+
final_conv = gate_seq[-1] # the Conv that outputs 2 channels
|
| 1112 |
+
with torch.no_grad():
|
| 1113 |
+
final_conv.bias.data[0] = 1.0 # condition channel → softmax ~0.73
|
| 1114 |
+
final_conv.bias.data[1] = -1.0 # noise channel → softmax ~0.27
|
| 1115 |
+
|
| 1116 |
+
def _encoder_level(self, i, out, tgt, t, ts_emb_shape, text, txt_shape, w_txt):
|
| 1117 |
+
"""Single encoder level — extracted for gradient checkpointing."""
|
| 1118 |
+
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 1119 |
+
if self.conditional_input and tgt is not None:
|
| 1120 |
+
tgt = self.block_down_cond[i](tgt) + w_txt * self.txt_layers[i](text).reshape(txt_shape)
|
| 1121 |
+
gate_logits = self.tgt_gate[i](tgt)
|
| 1122 |
+
cond_confidence = F.softmax(gate_logits, dim=1)[:, 0:1]
|
| 1123 |
+
tgt = self.fuse_conv1[i](torch.cat([cond_confidence*tgt, (1-cond_confidence)*out], axis=1))
|
| 1124 |
+
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 1125 |
+
return out, tgt
|
| 1126 |
+
|
| 1127 |
+
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 1128 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 1129 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 1130 |
+
zip(sample_coords, max_sz)], 1)
|
| 1131 |
+
|
| 1132 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 1133 |
+
ref = self.ref_grid if ref is None else ref
|
| 1134 |
+
img_sz = self.max_sz if img_sz is None else img_sz
|
| 1135 |
+
resample_mode = 'bilinear'
|
| 1136 |
+
|
| 1137 |
+
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 1138 |
+
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 1139 |
+
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 1140 |
+
align_corners=True)
|
| 1141 |
+
|
| 1142 |
+
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
|
| 1143 |
+
self.device = x.device
|
| 1144 |
+
img_sz = x.size()[2:]
|
| 1145 |
+
n = x.size()[0]
|
| 1146 |
+
self.max_sz = [img_sz[0]] * self.dimension
|
| 1147 |
+
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 1148 |
+
|
| 1149 |
+
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 1150 |
+
if list(img_sz) != self.img_res:
|
| 1151 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 1152 |
+
[1, self.dimension]+list(img_sz))
|
| 1153 |
+
self.ref_grid = self.ref_grid.to(self.device)
|
| 1154 |
+
|
| 1155 |
+
img = x
|
| 1156 |
+
t = self.time_embed(t)
|
| 1157 |
+
if text is None:
|
| 1158 |
+
text = self.text
|
| 1159 |
+
text = text.to(self.device)
|
| 1160 |
+
txt_shape = [1,-1]+[1]*self.dimension
|
| 1161 |
+
else:
|
| 1162 |
+
txt_shape = [n,-1]+[1]*self.dimension
|
| 1163 |
+
|
| 1164 |
+
# Text Gate: compute w_txt from text embedding alone before rec loop
|
| 1165 |
+
txt_vec = text.view(text.size(0), -1) # [1, 1024] or [n, 1024]
|
| 1166 |
+
if txt_vec.size(0) == 1 and n > 1:
|
| 1167 |
+
txt_vec = txt_vec.expand(n, -1)
|
| 1168 |
+
w_txt = self.text_gate(txt_vec) # [B, 1]
|
| 1169 |
+
w_txt = w_txt.view([w_txt.size(0), 1] + [1] * self.dimension)
|
| 1170 |
+
|
| 1171 |
+
for rec_id in range(rec_num):
|
| 1172 |
+
if self.conditional_input:
|
| 1173 |
+
tgt = y
|
| 1174 |
+
enc_list = []
|
| 1175 |
+
out = img
|
| 1176 |
+
for i in range(self.hier_num):
|
| 1177 |
+
# Gradient checkpointing on early encoder levels (large feature maps)
|
| 1178 |
+
# to reduce peak activation memory. Levels 0-2 have 128^3, 64^3, 32^3 maps.
|
| 1179 |
+
if self.use_checkpoint and self.training and i < 3:
|
| 1180 |
+
out, tgt = grad_checkpoint(
|
| 1181 |
+
self._encoder_level, i, out, tgt if self.conditional_input else None,
|
| 1182 |
+
t, ts_emb_shape, text, txt_shape, w_txt,
|
| 1183 |
+
use_reentrant=False,
|
| 1184 |
+
)
|
| 1185 |
+
else:
|
| 1186 |
+
out, tgt = self._encoder_level(
|
| 1187 |
+
i, out, tgt if self.conditional_input else None,
|
| 1188 |
+
t, ts_emb_shape, text, txt_shape, w_txt,
|
| 1189 |
+
)
|
| 1190 |
+
enc_list.append(out)
|
| 1191 |
+
out = self.down_layers[i](out)
|
| 1192 |
+
if self.conditional_input:
|
| 1193 |
+
tgt = self.down_layers[i](tgt)
|
| 1194 |
+
|
| 1195 |
+
out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 1196 |
+
if self.conditional_input:
|
| 1197 |
+
out_shape = out.shape
|
| 1198 |
+
tgt_shape = tgt.shape
|
| 1199 |
+
out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
|
| 1200 |
+
tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
|
| 1201 |
+
out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
|
| 1202 |
+
tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
|
| 1203 |
+
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
|
| 1204 |
+
tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
|
| 1205 |
+
out = out + out_attn
|
| 1206 |
+
tgt = tgt + tgt_attn
|
| 1207 |
+
out = self.fuse(torch.cat([out, tgt], dim=1))
|
| 1208 |
+
|
| 1209 |
+
if self.conditional_input:
|
| 1210 |
+
img_txt_feat = self.img2txt(out)
|
| 1211 |
+
self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024]
|
| 1212 |
+
out_txt = self.txt_layers[-1](text).reshape(txt_shape) - img_txt_feat
|
| 1213 |
+
out_txt = self.txt_proc(out_txt)
|
| 1214 |
+
out_txt = self.txt2img(out_txt)
|
| 1215 |
+
|
| 1216 |
+
# Text Gate: w_txt precomputed from text embedding alone
|
| 1217 |
+
out = (1 - w_txt) * out + w_txt * out_txt
|
| 1218 |
+
|
| 1219 |
+
for i in range(self.hier_num):
|
| 1220 |
+
out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 1221 |
+
out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 1222 |
+
|
| 1223 |
+
out = self.conv_out(out)/128
|
| 1224 |
+
|
| 1225 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 1226 |
+
if rec_id == 0:
|
| 1227 |
+
ddf = ddf_one
|
| 1228 |
+
else:
|
| 1229 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 1230 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 1231 |
|
| 1232 |
+
return ddf
|
| 1233 |
|
| 1234 |
def _make_te(self, dim_in, dim_out):
|
| 1235 |
return nn.Sequential(
|
|
|
|
| 1237 |
nn.ReLU(),
|
| 1238 |
nn.Linear(dim_out, dim_out)
|
| 1239 |
)
|
| 1240 |
+
|
| 1241 |
+
|
| 1242 |
# class RecMutAttnNet(nn.Module):
|
| 1243 |
# def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
|
| 1244 |
# super(RecMutAttnNet, self).__init__()
|
|
|
|
| 1393 |
comp_ddf = ddfs[i] + stn(comp_ddf,ddfs[i])
|
| 1394 |
return comp_ddf
|
| 1395 |
|
| 1396 |
+
|
| 1397 |
+
|
| 1398 |
class STN(nn.Module):
|
| 1399 |
def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None):
|
| 1400 |
super(STN, self).__init__()
|
|
|
|
| 1458 |
resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
|
| 1459 |
return resampled_x
|
| 1460 |
|
| 1461 |
+
|
| 1462 |
if __name__ == '__main__':
|
| 1463 |
ndims = 3
|
| 1464 |
res = 128
|
Diffusion/networks0.py
ADDED
|
@@ -0,0 +1,1195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
def get_net(name="recresnet"):
|
| 8 |
+
name = name.lower()
|
| 9 |
+
if name == "recresacnet":
|
| 10 |
+
net = RecResACNet
|
| 11 |
+
elif name == "recmutattnnet":
|
| 12 |
+
net = RecMutAttnNet
|
| 13 |
+
elif name == "recmutattnnet0":
|
| 14 |
+
net = RecMutAttnNet0
|
| 15 |
+
elif name == "recmutattnnet1":
|
| 16 |
+
net = RecMutAttnNet1
|
| 17 |
+
elif name == "defrecmutattnnet":
|
| 18 |
+
net = DefRec_MutAttnNet
|
| 19 |
+
elif name == "recmulmodmutattnnet":
|
| 20 |
+
net = RecMulModMutAttnNet
|
| 21 |
+
else:
|
| 22 |
+
net = None
|
| 23 |
+
return net
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def sinusoidal_embedding(n, d):
|
| 28 |
+
# Returns the standard positional embedding
|
| 29 |
+
embedding = torch.zeros(n, d)
|
| 30 |
+
wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)])
|
| 31 |
+
wk = wk.reshape((1, d))
|
| 32 |
+
t = torch.arange(n).reshape((n, 1))
|
| 33 |
+
embedding[:,::2] = torch.sin(t * wk[:,::2])
|
| 34 |
+
embedding[:,1::2] = torch.cos(t * wk[:,::2])
|
| 35 |
+
return embedding
|
| 36 |
+
|
| 37 |
+
class AtrousBlock(nn.Module):
|
| 38 |
+
def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=[1,3], ndims=2, activation=None, normalize=True):
|
| 39 |
+
super(AtrousBlock, self).__init__()
|
| 40 |
+
# if 0 not in shape:
|
| 41 |
+
if normalize:
|
| 42 |
+
# print(shape)
|
| 43 |
+
# self.ln = nn.LayerNorm(shape) # jzheng 15/03/2024
|
| 44 |
+
norm=getattr(nn, 'InstanceNorm%dd' % ndims) # jzheng 15/03/2024
|
| 45 |
+
self.ln = norm(out_c,affine=True)
|
| 46 |
+
else:
|
| 47 |
+
self.ln = nn.Identity()
|
| 48 |
+
Conv=getattr(nn,'Conv%dd' % ndims)
|
| 49 |
+
if in_c!=out_c:
|
| 50 |
+
self.conv0 = Conv(in_c, out_c, kernel_size, 1, (kernel_size-1)//2*1) #if in_c!=out_c else None
|
| 51 |
+
else:
|
| 52 |
+
self.conv0 = None
|
| 53 |
+
self.convs = nn.ModuleList([
|
| 54 |
+
Conv(out_c, out_c, kernel_size, 1, (kernel_size-1)//2*ar, dilation=ar)
|
| 55 |
+
if ar>0 else Conv(out_c, out_c, 1, 1, 0)
|
| 56 |
+
for ar in atrous_rates
|
| 57 |
+
])
|
| 58 |
+
# self.conv1 = Conv(out_c, out_c, kernel_size, stride, padding)
|
| 59 |
+
# self.conv2 = Conv(out_c, out_c, kernel_size, stride, padding)
|
| 60 |
+
self.activation = nn.LeakyReLU(1e-6) if activation is None else activation
|
| 61 |
+
# self.activation = nn.ReLU() if activation is None else activation
|
| 62 |
+
# self.activation = nn.ReLU()
|
| 63 |
+
self.normalize = normalize
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
if self.conv0 is not None:
|
| 67 |
+
x = self.conv0(x) #if self.conv0 is not None else x
|
| 68 |
+
x = self.ln(x) if self.normalize else x # jzheng 15/03/2024
|
| 69 |
+
out=nn.Identity()(x)
|
| 70 |
+
for conv in self.convs:
|
| 71 |
+
out = self.activation(out)
|
| 72 |
+
out = conv(out)
|
| 73 |
+
return self.activation(out+x)
|
| 74 |
+
|
| 75 |
+
# ==============================================
|
| 76 |
+
# Unconditional Network
|
| 77 |
+
# ==============================================
|
| 78 |
+
|
| 79 |
+
class RecResACNet(nn.Module):
|
| 80 |
+
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0):
|
| 81 |
+
super(RecResACNet, self).__init__()
|
| 82 |
+
|
| 83 |
+
self.dimension = ndims
|
| 84 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 85 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 86 |
+
|
| 87 |
+
# Sinusoidal embedding
|
| 88 |
+
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 89 |
+
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 90 |
+
self.time_embed.requires_grad_(False)
|
| 91 |
+
|
| 92 |
+
# First half
|
| 93 |
+
self.te1 = self._make_te(time_emb_dim, 1)
|
| 94 |
+
self.b1 = nn.Sequential(
|
| 95 |
+
AtrousBlock([num_input_chn] + [res] * ndims, num_input_chn, 10, ndims=ndims),
|
| 96 |
+
AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims),
|
| 97 |
+
AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims),
|
| 98 |
+
|
| 99 |
+
)
|
| 100 |
+
self.down1 = self.Conv(10, 10, 4, 2, 1)
|
| 101 |
+
|
| 102 |
+
self.te2 = self._make_te(time_emb_dim, 10)
|
| 103 |
+
self.b2 = nn.Sequential(
|
| 104 |
+
AtrousBlock([10] + [res // 2] * ndims, 10, 20, ndims=ndims),
|
| 105 |
+
AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims),
|
| 106 |
+
AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims)
|
| 107 |
+
)
|
| 108 |
+
self.down2 = self.Conv(20, 20, 4, 2, 1)
|
| 109 |
+
|
| 110 |
+
self.te3 = self._make_te(time_emb_dim, 20)
|
| 111 |
+
self.b3 = nn.Sequential(
|
| 112 |
+
AtrousBlock([20] + [res // 4] * ndims, 20, 40, ndims=ndims),
|
| 113 |
+
AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims),
|
| 114 |
+
AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims)
|
| 115 |
+
)
|
| 116 |
+
self.down3 = self.Conv(40, 40, 4, 2, 1)
|
| 117 |
+
|
| 118 |
+
# Bottleneck
|
| 119 |
+
self.te_mid = self._make_te(time_emb_dim, 40)
|
| 120 |
+
self.b_mid = nn.Sequential(
|
| 121 |
+
AtrousBlock([40] + [res // 8] * ndims, 40, 20, ndims=ndims),
|
| 122 |
+
AtrousBlock([20] + [res // 8] * ndims, 20, 20, ndims=ndims),
|
| 123 |
+
AtrousBlock([20] + [res // 8] * ndims, 20, 40, ndims=ndims)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
# Second half
|
| 127 |
+
self.up1 = self.ConvT(40, 40, 4, 2, 1)
|
| 128 |
+
|
| 129 |
+
self.te4 = self._make_te(time_emb_dim, 80)
|
| 130 |
+
self.b4 = nn.Sequential(
|
| 131 |
+
AtrousBlock([80] + [res // 4] * ndims, 80, 40, ndims=ndims, normalize=False),
|
| 132 |
+
AtrousBlock([40] + [res // 4] * ndims, 40, 20, ndims=ndims, normalize=False),
|
| 133 |
+
AtrousBlock([20] + [res // 4] * ndims, 20, 20, ndims=ndims, normalize=False)
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
self.up2 = self.ConvT(20, 20, 4, 2, 1)
|
| 137 |
+
self.te5 = self._make_te(time_emb_dim, 40)
|
| 138 |
+
self.b5 = nn.Sequential(
|
| 139 |
+
AtrousBlock([40] + [res // 2] * ndims, 40, 20, ndims=ndims, normalize=False),
|
| 140 |
+
AtrousBlock([20] + [res // 2] * ndims, 20, 10, ndims=ndims, normalize=False),
|
| 141 |
+
AtrousBlock([10] + [res // 2] * ndims, 10, 10, ndims=ndims, normalize=False)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
self.up3 = self.ConvT(10, 10, 4, 2, 1)
|
| 145 |
+
self.te_out = self._make_te(time_emb_dim, 20)
|
| 146 |
+
self.b_out = nn.Sequential(
|
| 147 |
+
AtrousBlock([20] + [res // 1] * ndims, 20, 10, ndims=ndims, normalize=False),
|
| 148 |
+
AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False),
|
| 149 |
+
AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False)
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
self.conv_out = self.Conv(10, ndims, 3, 1, 1)
|
| 153 |
+
|
| 154 |
+
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 155 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 156 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 157 |
+
zip(sample_coords, max_sz)], 1)
|
| 158 |
+
|
| 159 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 160 |
+
ref = self.ref_grid if ref is None else ref
|
| 161 |
+
img_sz = self.max_sz if img_sz is None else img_sz
|
| 162 |
+
# resample_mode = 'bicubic'
|
| 163 |
+
resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 164 |
+
# padding_mode = "border"
|
| 165 |
+
|
| 166 |
+
if True:
|
| 167 |
+
# return F.grid_sample(vol, torch.flip(torch.transpose(ddf * torch.Tensor(np.reshape(np.array(self.max_sz), [1, 1, 1, self.dimension])).cuda() + ref,[0, 2, 3, 1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,align_corners=True)
|
| 168 |
+
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 169 |
+
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 170 |
+
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 171 |
+
align_corners=True)
|
| 172 |
+
|
| 173 |
+
def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2):
|
| 174 |
+
#
|
| 175 |
+
self.device = x.device
|
| 176 |
+
# [h, w] = x.size()[2:]
|
| 177 |
+
img_sz = x.size()[2:]
|
| 178 |
+
n = x.size()[0]
|
| 179 |
+
self.max_sz = [img_sz[0]] * self.dimension
|
| 180 |
+
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 181 |
+
# [h,w]=img_sz
|
| 182 |
+
# self.img_sz = torch.reshape(torch.tensor([(h - 1) / 2., (w - 1) / 2.], device=self.device), [1, 1, 1, 2])
|
| 183 |
+
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 184 |
+
# self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=h), torch.arange(end=w)]), 0),
|
| 185 |
+
# [1, 2, h, w]).to(self.device)
|
| 186 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 187 |
+
[1, self.dimension]+list(img_sz)).to(self.device)
|
| 188 |
+
img = x
|
| 189 |
+
|
| 190 |
+
# x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension)
|
| 191 |
+
t = self.time_embed(t)
|
| 192 |
+
|
| 193 |
+
for rec_id in range(rec_num):
|
| 194 |
+
out1 = self.b1(img + self.te1(t).reshape(ts_emb_shape)) # (N, 10, 28, 28)
|
| 195 |
+
out2 = self.b2(self.down1(out1) + self.te2(t).reshape(ts_emb_shape)) # (N, 20, 14, 14)
|
| 196 |
+
out3 = self.b3(self.down2(out2) + self.te3(t).reshape(ts_emb_shape)) # (N, 40, 7, 7)
|
| 197 |
+
|
| 198 |
+
out_mid = self.b_mid(self.down3(out3) * self.te_mid(t).reshape(ts_emb_shape)) # (N, 40, 3, 3)
|
| 199 |
+
|
| 200 |
+
out4 = torch.cat((out3, self.up1(out_mid)), dim=1) # (N, 80, 7, 7)
|
| 201 |
+
out4 = self.b4(out4 + self.te4(t).reshape(ts_emb_shape)) # (N, 20, 7, 7)
|
| 202 |
+
|
| 203 |
+
out5 = torch.cat((out2, self.up2(out4)), dim=1) # (N, 40, 14, 14)
|
| 204 |
+
out5 = self.b5(out5 + self.te5(t).reshape(ts_emb_shape)) # (N, 10, 14, 14)
|
| 205 |
+
|
| 206 |
+
out = torch.cat((out1, self.up3(out5)), dim=1) # (N, 20, 28, 28)
|
| 207 |
+
out = self.b_out(out + self.te_out(t).reshape(ts_emb_shape)) # (N, 1, 28, 28)
|
| 208 |
+
|
| 209 |
+
out = self.conv_out(out)
|
| 210 |
+
|
| 211 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 212 |
+
if rec_id == 0:
|
| 213 |
+
ddf = ddf_one
|
| 214 |
+
else:
|
| 215 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 216 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 217 |
+
|
| 218 |
+
return ddf
|
| 219 |
+
|
| 220 |
+
def _make_te(self, dim_in, dim_out):
|
| 221 |
+
# make time embedding
|
| 222 |
+
|
| 223 |
+
return nn.Sequential(
|
| 224 |
+
nn.Linear(dim_in, dim_out),
|
| 225 |
+
# nn.SiLU(),
|
| 226 |
+
nn.ReLU(),
|
| 227 |
+
nn.Linear(dim_out, dim_out)
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# ==============================================
|
| 231 |
+
# Conditional Network
|
| 232 |
+
# ==============================================
|
| 233 |
+
|
| 234 |
+
class cross_attn(nn.Module):
|
| 235 |
+
def __init__(self, q, k, v, ndims=2):
|
| 236 |
+
self.q = q
|
| 237 |
+
self.k = k
|
| 238 |
+
self.v = v
|
| 239 |
+
self.ndims = ndims
|
| 240 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.ndims)
|
| 241 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims)
|
| 242 |
+
self.softmax = nn.Softmax(dim=-1)
|
| 243 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
| 244 |
+
|
| 245 |
+
def forward(self, x, y):
|
| 246 |
+
q = self.q(x)
|
| 247 |
+
k = self.k(y)
|
| 248 |
+
v = self.v(y)
|
| 249 |
+
attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
|
| 250 |
+
out = torch.matmul(attn, v)
|
| 251 |
+
return out
|
| 252 |
+
|
| 253 |
+
class DefRec_MutAttnNet(nn.Module):
|
| 254 |
+
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 255 |
+
super(DefRec_MutAttnNet, self).__init__()
|
| 256 |
+
|
| 257 |
+
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 258 |
+
# self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
| 259 |
+
self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512]
|
| 260 |
+
self.conditional_input = conditional_input
|
| 261 |
+
self.num_heads = num_heads
|
| 262 |
+
self.text_feat_chn = text_feat_chn
|
| 263 |
+
|
| 264 |
+
self.dimension = ndims
|
| 265 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 266 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 267 |
+
self.copy = nn.Identity()
|
| 268 |
+
# Sinusoidal embedding
|
| 269 |
+
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 270 |
+
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 271 |
+
self.time_embed.requires_grad_(False)
|
| 272 |
+
self.hier_num = len(self.feat_channels) - 1
|
| 273 |
+
self.down_layers = nn.ModuleList()
|
| 274 |
+
self.up_layers = nn.ModuleList()
|
| 275 |
+
self.ted_layers = nn.ModuleList()
|
| 276 |
+
self.teu_layers = nn.ModuleList()
|
| 277 |
+
self.block_down = nn.ModuleList()
|
| 278 |
+
self.block_up = nn.ModuleList()
|
| 279 |
+
if self.conditional_input:
|
| 280 |
+
self.block_down_cond = nn.ModuleList()
|
| 281 |
+
self.fuse_conv0 = nn.ModuleList()
|
| 282 |
+
# self.fuse_conv1 = nn.ModuleList()
|
| 283 |
+
self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 284 |
+
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| 285 |
+
self.global_maxpool = Global_Maxpool(1)
|
| 286 |
+
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| 287 |
+
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| 288 |
+
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| 289 |
+
self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
|
| 290 |
+
self.img_res = [res]*self.dimension
|
| 291 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| 292 |
+
[1, self.dimension]+list(self.img_res))
|
| 293 |
+
|
| 294 |
+
for i in range(1, self.hier_num + 1):
|
| 295 |
+
j=-i
|
| 296 |
+
self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 297 |
+
self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 298 |
+
self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 299 |
+
self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 300 |
+
self.block_down.append(nn.Sequential(
|
| 301 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 302 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 303 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 304 |
+
))
|
| 305 |
+
if self.conditional_input:
|
| 306 |
+
self.block_down_cond.append(nn.Sequential(
|
| 307 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 308 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 309 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 310 |
+
))
|
| 311 |
+
self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 312 |
+
# self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 313 |
+
if i==self.hier_num:
|
| 314 |
+
k=j
|
| 315 |
+
else:
|
| 316 |
+
k=j-1
|
| 317 |
+
self.block_up.append(nn.Sequential(
|
| 318 |
+
AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 319 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 320 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 321 |
+
))
|
| 322 |
+
|
| 323 |
+
# Bottleneck
|
| 324 |
+
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 325 |
+
self.b_mid = nn.Sequential(
|
| 326 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 327 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 328 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 332 |
+
|
| 333 |
+
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 334 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 335 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 336 |
+
zip(sample_coords, max_sz)], 1)
|
| 337 |
+
|
| 338 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 339 |
+
ref = self.ref_grid if ref is None else ref
|
| 340 |
+
img_sz = self.max_sz if img_sz is None else img_sz
|
| 341 |
+
resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 342 |
+
|
| 343 |
+
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 344 |
+
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 345 |
+
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 346 |
+
align_corners=True)
|
| 347 |
+
|
| 348 |
+
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
|
| 349 |
+
self.device = x.device
|
| 350 |
+
img_sz = x.size()[2:]
|
| 351 |
+
n = x.size()[0]
|
| 352 |
+
self.max_sz = [img_sz[0]] * self.dimension
|
| 353 |
+
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 354 |
+
|
| 355 |
+
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 356 |
+
if list(img_sz) != self.img_res:
|
| 357 |
+
# print ("Reinitialize the ref_grid to match the model's input image size.")
|
| 358 |
+
# print(img_sz, self.img_res)
|
| 359 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 360 |
+
[1, self.dimension]+list(img_sz))
|
| 361 |
+
self.ref_grid = self.ref_grid.to(self.device)
|
| 362 |
+
|
| 363 |
+
img = x
|
| 364 |
+
if self.conditional_input:
|
| 365 |
+
tgt = y
|
| 366 |
+
# encode the conditional input
|
| 367 |
+
tgt_down_list = []
|
| 368 |
+
for i in range(self.hier_num):
|
| 369 |
+
# out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
|
| 370 |
+
if self.conditional_input:
|
| 371 |
+
tgt = self.block_down_cond[i](tgt)
|
| 372 |
+
tgt_down_list.append(self.copy(tgt))
|
| 373 |
+
tgt = self.down_layers[i](tgt)
|
| 374 |
+
tgt_mid = self.copy(tgt)
|
| 375 |
+
tgt_shape = tgt_mid.shape
|
| 376 |
+
# out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 377 |
+
tgt_mid = tgt_mid.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 378 |
+
|
| 379 |
+
t = [t0.to(self.device) for t0 in t]
|
| 380 |
+
t = [t0 for _ in range(rec_num) for t0 in t]
|
| 381 |
+
for rec_id,time in enumerate(t):
|
| 382 |
+
t_emb = self.time_embed(time)
|
| 383 |
+
|
| 384 |
+
# for rec_id in range(rec_num):
|
| 385 |
+
# if self.conditional_input:
|
| 386 |
+
# tgt = y
|
| 387 |
+
enc_list = []
|
| 388 |
+
out = img
|
| 389 |
+
for i in range(self.hier_num):
|
| 390 |
+
out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
|
| 391 |
+
if self.conditional_input:
|
| 392 |
+
# tgt = self.block_down_cond[i](tgt)
|
| 393 |
+
out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], axis=1))
|
| 394 |
+
# tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 395 |
+
enc_list.append(out)
|
| 396 |
+
out = self.down_layers[i](out)
|
| 397 |
+
# if self.conditional_input:
|
| 398 |
+
# tgt = self.down_layers[i](tgt)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape))
|
| 402 |
+
if self.conditional_input:
|
| 403 |
+
# out += self.attn_layer(out, tgt, tgt)[0]
|
| 404 |
+
out_shape = out.shape
|
| 405 |
+
# tgt_shape = tgt.shape
|
| 406 |
+
# # out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 407 |
+
# tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 408 |
+
out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt_mid, tgt_mid)
|
| 409 |
+
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
|
| 410 |
+
out = out + out_attn
|
| 411 |
+
|
| 412 |
+
if self.conditional_input:
|
| 413 |
+
if text is None:
|
| 414 |
+
text = self.text
|
| 415 |
+
text = text.to(self.device)
|
| 416 |
+
out_txt = self.img2txt(out) + text
|
| 417 |
+
out_txt = self.txt_proc(out_txt)
|
| 418 |
+
out_txt = self.txt2img(out_txt)
|
| 419 |
+
out = out + out_txt
|
| 420 |
+
|
| 421 |
+
for i in range(self.hier_num):
|
| 422 |
+
out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 423 |
+
out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape))
|
| 424 |
+
|
| 425 |
+
out = self.conv_out(out)/128
|
| 426 |
+
|
| 427 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 428 |
+
if rec_id == 0:
|
| 429 |
+
ddf = ddf_one
|
| 430 |
+
else:
|
| 431 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 432 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 433 |
+
|
| 434 |
+
return ddf
|
| 435 |
+
|
| 436 |
+
def _make_te(self, dim_in, dim_out):
|
| 437 |
+
return nn.Sequential(
|
| 438 |
+
nn.Linear(dim_in, dim_out),
|
| 439 |
+
nn.ReLU(),
|
| 440 |
+
nn.Linear(dim_out, dim_out)
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
class RecMutAttnNet1(nn.Module):
|
| 445 |
+
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 446 |
+
super(RecMutAttnNet1, self).__init__()
|
| 447 |
+
|
| 448 |
+
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 449 |
+
self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
| 450 |
+
self.conditional_input = conditional_input
|
| 451 |
+
self.num_heads = num_heads
|
| 452 |
+
self.text_feat_chn = text_feat_chn
|
| 453 |
+
|
| 454 |
+
self.dimension = ndims
|
| 455 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 456 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 457 |
+
|
| 458 |
+
# Sinusoidal embedding
|
| 459 |
+
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 460 |
+
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 461 |
+
self.time_embed.requires_grad_(False)
|
| 462 |
+
self.hier_num = len(self.feat_channels) - 1
|
| 463 |
+
self.down_layers = nn.ModuleList()
|
| 464 |
+
self.up_layers = nn.ModuleList()
|
| 465 |
+
self.ted_layers = nn.ModuleList()
|
| 466 |
+
self.teu_layers = nn.ModuleList()
|
| 467 |
+
self.block_down = nn.ModuleList()
|
| 468 |
+
if self.conditional_input:
|
| 469 |
+
self.block_down_cond = nn.ModuleList()
|
| 470 |
+
self.fuse_conv0 = nn.ModuleList()
|
| 471 |
+
self.fuse_conv1 = nn.ModuleList()
|
| 472 |
+
self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 473 |
+
|
| 474 |
+
self.block_up = nn.ModuleList()
|
| 475 |
+
|
| 476 |
+
for i in range(1, self.hier_num + 1):
|
| 477 |
+
j=-i
|
| 478 |
+
self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 479 |
+
self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 480 |
+
self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 481 |
+
self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 482 |
+
self.block_down.append(nn.Sequential(
|
| 483 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 484 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 485 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 486 |
+
))
|
| 487 |
+
if self.conditional_input:
|
| 488 |
+
self.block_down_cond.append(nn.Sequential(
|
| 489 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 490 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 491 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 492 |
+
))
|
| 493 |
+
self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 494 |
+
self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 495 |
+
if i==self.hier_num:
|
| 496 |
+
k=j
|
| 497 |
+
else:
|
| 498 |
+
k=j-1
|
| 499 |
+
self.block_up.append(nn.Sequential(
|
| 500 |
+
AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 501 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 502 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 503 |
+
))
|
| 504 |
+
|
| 505 |
+
# Bottleneck
|
| 506 |
+
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 507 |
+
self.b_mid = nn.Sequential(
|
| 508 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 509 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 510 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 514 |
+
|
| 515 |
+
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 516 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 517 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 518 |
+
zip(sample_coords, max_sz)], 1)
|
| 519 |
+
|
| 520 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 521 |
+
ref = self.ref_grid if ref is None else ref
|
| 522 |
+
img_sz = self.max_sz if img_sz is None else img_sz
|
| 523 |
+
resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 524 |
+
|
| 525 |
+
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 526 |
+
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 527 |
+
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 528 |
+
align_corners=True)
|
| 529 |
+
|
| 530 |
+
def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
|
| 531 |
+
self.device = x.device
|
| 532 |
+
img_sz = x.size()[2:]
|
| 533 |
+
n = x.size()[0]
|
| 534 |
+
self.max_sz = [img_sz[0]] * self.dimension
|
| 535 |
+
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 536 |
+
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 537 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 538 |
+
[1, self.dimension]+list(img_sz)).to(self.device)
|
| 539 |
+
img = x
|
| 540 |
+
t = self.time_embed(t)
|
| 541 |
+
|
| 542 |
+
for rec_id in range(rec_num):
|
| 543 |
+
if self.conditional_input:
|
| 544 |
+
tgt = y
|
| 545 |
+
enc_list = []
|
| 546 |
+
out = img
|
| 547 |
+
for i in range(self.hier_num):
|
| 548 |
+
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 549 |
+
if self.conditional_input:
|
| 550 |
+
tgt = self.block_down_cond[i](tgt)
|
| 551 |
+
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 552 |
+
tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 553 |
+
enc_list.append(out)
|
| 554 |
+
out = self.down_layers[i](out)
|
| 555 |
+
if self.conditional_input:
|
| 556 |
+
tgt = self.down_layers[i](tgt)
|
| 557 |
+
|
| 558 |
+
out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 559 |
+
if self.conditional_input:
|
| 560 |
+
# out += self.attn_layer(out, tgt, tgt)[0]
|
| 561 |
+
out_shape = out.shape
|
| 562 |
+
tgt_shape = tgt.shape
|
| 563 |
+
# out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 564 |
+
tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 565 |
+
out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt)
|
| 566 |
+
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
|
| 567 |
+
out = out + out_attn
|
| 568 |
+
|
| 569 |
+
for i in range(self.hier_num):
|
| 570 |
+
out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 571 |
+
out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 572 |
+
|
| 573 |
+
out = self.conv_out(out)/128
|
| 574 |
+
|
| 575 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 576 |
+
if rec_id == 0:
|
| 577 |
+
ddf = ddf_one
|
| 578 |
+
else:
|
| 579 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 580 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 581 |
+
|
| 582 |
+
return ddf
|
| 583 |
+
|
| 584 |
+
def _make_te(self, dim_in, dim_out):
|
| 585 |
+
return nn.Sequential(
|
| 586 |
+
nn.Linear(dim_in, dim_out),
|
| 587 |
+
nn.ReLU(),
|
| 588 |
+
nn.Linear(dim_out, dim_out)
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
class RecMutAttnNet(nn.Module):
|
| 592 |
+
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 593 |
+
super(RecMutAttnNet, self).__init__()
|
| 594 |
+
|
| 595 |
+
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 596 |
+
self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
| 597 |
+
self.conditional_input = conditional_input
|
| 598 |
+
self.num_heads = num_heads
|
| 599 |
+
self.text_feat_chn = text_feat_chn
|
| 600 |
+
|
| 601 |
+
self.dimension = ndims
|
| 602 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 603 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 604 |
+
|
| 605 |
+
# Sinusoidal embedding
|
| 606 |
+
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 607 |
+
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 608 |
+
self.time_embed.requires_grad_(False)
|
| 609 |
+
self.hier_num = len(self.feat_channels) - 1
|
| 610 |
+
self.down_layers = nn.ModuleList()
|
| 611 |
+
self.up_layers = nn.ModuleList()
|
| 612 |
+
self.ted_layers = nn.ModuleList()
|
| 613 |
+
self.teu_layers = nn.ModuleList()
|
| 614 |
+
self.block_down = nn.ModuleList()
|
| 615 |
+
self.block_up = nn.ModuleList()
|
| 616 |
+
if self.conditional_input:
|
| 617 |
+
self.block_down_cond = nn.ModuleList()
|
| 618 |
+
self.fuse_conv0 = nn.ModuleList()
|
| 619 |
+
self.fuse_conv1 = nn.ModuleList()
|
| 620 |
+
self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 621 |
+
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| 622 |
+
self.global_maxpool = Global_Maxpool(1)
|
| 623 |
+
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| 624 |
+
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| 625 |
+
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| 626 |
+
self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
|
| 627 |
+
self.img_res = [res]*self.dimension
|
| 628 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| 629 |
+
[1, self.dimension]+list(self.img_res))
|
| 630 |
+
|
| 631 |
+
for i in range(1, self.hier_num + 1):
|
| 632 |
+
j=-i
|
| 633 |
+
self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 634 |
+
self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 635 |
+
self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 636 |
+
self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 637 |
+
self.block_down.append(nn.Sequential(
|
| 638 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 639 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 640 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 641 |
+
))
|
| 642 |
+
if self.conditional_input:
|
| 643 |
+
self.block_down_cond.append(nn.Sequential(
|
| 644 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 645 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 646 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 647 |
+
))
|
| 648 |
+
self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 649 |
+
self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 650 |
+
if i==self.hier_num:
|
| 651 |
+
k=j
|
| 652 |
+
else:
|
| 653 |
+
k=j-1
|
| 654 |
+
self.block_up.append(nn.Sequential(
|
| 655 |
+
AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 656 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 657 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 658 |
+
))
|
| 659 |
+
|
| 660 |
+
# Bottleneck
|
| 661 |
+
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 662 |
+
self.b_mid = nn.Sequential(
|
| 663 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 664 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 665 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 666 |
+
)
|
| 667 |
+
|
| 668 |
+
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 669 |
+
|
| 670 |
+
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 671 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 672 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 673 |
+
zip(sample_coords, max_sz)], 1)
|
| 674 |
+
|
| 675 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 676 |
+
ref = self.ref_grid if ref is None else ref
|
| 677 |
+
img_sz = self.max_sz if img_sz is None else img_sz
|
| 678 |
+
resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 679 |
+
|
| 680 |
+
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 681 |
+
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 682 |
+
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 683 |
+
align_corners=True)
|
| 684 |
+
|
| 685 |
+
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
|
| 686 |
+
self.device = x.device
|
| 687 |
+
img_sz = x.size()[2:]
|
| 688 |
+
n = x.size()[0]
|
| 689 |
+
self.max_sz = [img_sz[0]] * self.dimension
|
| 690 |
+
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 691 |
+
|
| 692 |
+
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 693 |
+
if list(img_sz) != self.img_res:
|
| 694 |
+
# print ("Reinitialize the ref_grid to match the model's input image size.")
|
| 695 |
+
# print(img_sz, self.img_res)
|
| 696 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 697 |
+
[1, self.dimension]+list(img_sz))
|
| 698 |
+
self.ref_grid = self.ref_grid.to(self.device)
|
| 699 |
+
|
| 700 |
+
img = x
|
| 701 |
+
t = self.time_embed(t)
|
| 702 |
+
|
| 703 |
+
for rec_id in range(rec_num):
|
| 704 |
+
if self.conditional_input:
|
| 705 |
+
tgt = y
|
| 706 |
+
enc_list = []
|
| 707 |
+
out = img
|
| 708 |
+
for i in range(self.hier_num):
|
| 709 |
+
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 710 |
+
if self.conditional_input:
|
| 711 |
+
tgt = self.block_down_cond[i](tgt)
|
| 712 |
+
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 713 |
+
tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 714 |
+
enc_list.append(out)
|
| 715 |
+
out = self.down_layers[i](out)
|
| 716 |
+
if self.conditional_input:
|
| 717 |
+
tgt = self.down_layers[i](tgt)
|
| 718 |
+
|
| 719 |
+
|
| 720 |
+
out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 721 |
+
if self.conditional_input:
|
| 722 |
+
# out += self.attn_layer(out, tgt, tgt)[0]
|
| 723 |
+
out_shape = out.shape
|
| 724 |
+
tgt_shape = tgt.shape
|
| 725 |
+
# out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 726 |
+
tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 727 |
+
out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt)
|
| 728 |
+
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
|
| 729 |
+
out = out + out_attn
|
| 730 |
+
|
| 731 |
+
if self.conditional_input:
|
| 732 |
+
if text is None:
|
| 733 |
+
text = self.text
|
| 734 |
+
text = text.to(self.device)
|
| 735 |
+
text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))
|
| 736 |
+
out_txt = self.img2txt(out) + text
|
| 737 |
+
out_txt = self.txt_proc(out_txt)
|
| 738 |
+
out_txt = self.txt2img(out_txt)
|
| 739 |
+
out = out + out_txt
|
| 740 |
+
|
| 741 |
+
for i in range(self.hier_num):
|
| 742 |
+
out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 743 |
+
out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 744 |
+
|
| 745 |
+
out = self.conv_out(out)/128
|
| 746 |
+
|
| 747 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 748 |
+
if rec_id == 0:
|
| 749 |
+
ddf = ddf_one
|
| 750 |
+
else:
|
| 751 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 752 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 753 |
+
|
| 754 |
+
# print(torch.max(torch.abs(ddf)))
|
| 755 |
+
|
| 756 |
+
return ddf
|
| 757 |
+
|
| 758 |
+
def _make_te(self, dim_in, dim_out):
|
| 759 |
+
return nn.Sequential(
|
| 760 |
+
nn.Linear(dim_in, dim_out),
|
| 761 |
+
nn.ReLU(),
|
| 762 |
+
nn.Linear(dim_out, dim_out)
|
| 763 |
+
)
|
| 764 |
+
|
| 765 |
+
class RecMulModMutAttnNet(nn.Module):
|
| 766 |
+
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| 767 |
+
super(RecMulModMutAttnNet, self).__init__()
|
| 768 |
+
|
| 769 |
+
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 770 |
+
self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
| 771 |
+
self.conditional_input = conditional_input
|
| 772 |
+
self.num_heads = num_heads
|
| 773 |
+
self.text_feat_chn = text_feat_chn
|
| 774 |
+
|
| 775 |
+
self.dimension = ndims
|
| 776 |
+
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 777 |
+
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 778 |
+
|
| 779 |
+
# Sinusoidal embedding
|
| 780 |
+
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 781 |
+
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 782 |
+
self.time_embed.requires_grad_(False)
|
| 783 |
+
self.hier_num = len(self.feat_channels) - 1
|
| 784 |
+
self.down_layers = nn.ModuleList()
|
| 785 |
+
self.up_layers = nn.ModuleList()
|
| 786 |
+
self.ted_layers = nn.ModuleList()
|
| 787 |
+
self.teu_layers = nn.ModuleList()
|
| 788 |
+
self.block_down = nn.ModuleList()
|
| 789 |
+
self.block_up = nn.ModuleList()
|
| 790 |
+
if self.conditional_input:
|
| 791 |
+
# self.gate_img = nn.ModuleList()
|
| 792 |
+
self.txt_layers = nn.ModuleList()
|
| 793 |
+
self.block_down_cond = nn.ModuleList()
|
| 794 |
+
self.fuse_conv0 = nn.ModuleList()
|
| 795 |
+
self.fuse_conv1 = nn.ModuleList()
|
| 796 |
+
self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| 797 |
+
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| 798 |
+
self.global_maxpool = Global_Maxpool(1)
|
| 799 |
+
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| 800 |
+
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| 801 |
+
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| 802 |
+
# self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
|
| 803 |
+
self.text = torch.zeros(1, self.text_feat_chn)
|
| 804 |
+
|
| 805 |
+
self.img_res = [res]*self.dimension
|
| 806 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| 807 |
+
[1, self.dimension]+list(self.img_res))
|
| 808 |
+
|
| 809 |
+
for i in range(1, self.hier_num + 1):
|
| 810 |
+
j=-i
|
| 811 |
+
self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 812 |
+
self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 813 |
+
self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 814 |
+
self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 815 |
+
self.block_down.append(nn.Sequential(
|
| 816 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 817 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 818 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 819 |
+
))
|
| 820 |
+
if self.conditional_input:
|
| 821 |
+
# self.gate_img.append(nn.Sequential(
|
| 822 |
+
# nn.ConvNd(self.dimension, self.feat_channels[i], self.feat_channels[i], kernel_size=1, stride=1, padding=0),
|
| 823 |
+
# nn.Sigmoid()
|
| 824 |
+
# ))
|
| 825 |
+
self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i])))
|
| 826 |
+
self.block_down_cond.append(nn.Sequential(
|
| 827 |
+
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 828 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 829 |
+
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 830 |
+
))
|
| 831 |
+
self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 832 |
+
self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 833 |
+
if i==self.hier_num:
|
| 834 |
+
k=j
|
| 835 |
+
else:
|
| 836 |
+
k=j-1
|
| 837 |
+
self.block_up.append(nn.Sequential(
|
| 838 |
+
AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 839 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 840 |
+
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 841 |
+
))
|
| 842 |
+
|
| 843 |
+
# Bottleneck
|
| 844 |
+
self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn)))
|
| 845 |
+
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 846 |
+
self.b_mid = nn.Sequential(
|
| 847 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 848 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 849 |
+
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 853 |
+
|
| 854 |
+
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 855 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 856 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 857 |
+
zip(sample_coords, max_sz)], 1)
|
| 858 |
+
|
| 859 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 860 |
+
ref = self.ref_grid if ref is None else ref
|
| 861 |
+
img_sz = self.max_sz if img_sz is None else img_sz
|
| 862 |
+
resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 863 |
+
|
| 864 |
+
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 865 |
+
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 866 |
+
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 867 |
+
align_corners=True)
|
| 868 |
+
|
| 869 |
+
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
|
| 870 |
+
self.device = x.device
|
| 871 |
+
img_sz = x.size()[2:]
|
| 872 |
+
n = x.size()[0]
|
| 873 |
+
self.max_sz = [img_sz[0]] * self.dimension
|
| 874 |
+
ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 878 |
+
if list(img_sz) != self.img_res:
|
| 879 |
+
# print ("Reinitialize the ref_grid to match the model's input image size.")
|
| 880 |
+
# print(img_sz, self.img_res)
|
| 881 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 882 |
+
[1, self.dimension]+list(img_sz))
|
| 883 |
+
self.ref_grid = self.ref_grid.to(self.device)
|
| 884 |
+
|
| 885 |
+
img = x
|
| 886 |
+
t = self.time_embed(t)
|
| 887 |
+
if text is None:
|
| 888 |
+
text = self.text
|
| 889 |
+
# print(text.shape)
|
| 890 |
+
text = text.to(self.device)
|
| 891 |
+
txt_shape = [1,-1]+[1]*self.dimension
|
| 892 |
+
else:
|
| 893 |
+
txt_shape = [n,-1]+[1]*self.dimension
|
| 894 |
+
|
| 895 |
+
for rec_id in range(rec_num):
|
| 896 |
+
if self.conditional_input:
|
| 897 |
+
tgt = y
|
| 898 |
+
enc_list = []
|
| 899 |
+
out = img
|
| 900 |
+
for i in range(self.hier_num):
|
| 901 |
+
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 902 |
+
if self.conditional_input:
|
| 903 |
+
tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
|
| 904 |
+
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 905 |
+
tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 906 |
+
enc_list.append(out)
|
| 907 |
+
out = self.down_layers[i](out)
|
| 908 |
+
if self.conditional_input:
|
| 909 |
+
tgt = self.down_layers[i](tgt)
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 913 |
+
if self.conditional_input:
|
| 914 |
+
# out += self.attn_layer(out, tgt, tgt)[0]
|
| 915 |
+
out_shape = out.shape
|
| 916 |
+
tgt_shape = tgt.shape
|
| 917 |
+
# out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 918 |
+
tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C)
|
| 919 |
+
out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt)
|
| 920 |
+
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W)
|
| 921 |
+
out = out + out_attn
|
| 922 |
+
|
| 923 |
+
if self.conditional_input:
|
| 924 |
+
|
| 925 |
+
# text = text.view(-1, self.text_feat_chn, *([1]*self.dimension))
|
| 926 |
+
|
| 927 |
+
# out_txt = self.img2txt(out) + text.reshape(txt_shape)
|
| 928 |
+
img_txt_feat = self.img2txt(out)
|
| 929 |
+
self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024]
|
| 930 |
+
out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
|
| 931 |
+
out_txt = self.txt_proc(out_txt)
|
| 932 |
+
out_txt = self.txt2img(out_txt)
|
| 933 |
+
out = out + out_txt
|
| 934 |
+
|
| 935 |
+
for i in range(self.hier_num):
|
| 936 |
+
out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 937 |
+
out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 938 |
+
|
| 939 |
+
out = self.conv_out(out)/128
|
| 940 |
+
|
| 941 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 942 |
+
if rec_id == 0:
|
| 943 |
+
ddf = ddf_one
|
| 944 |
+
else:
|
| 945 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 946 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 947 |
+
|
| 948 |
+
# print(torch.max(torch.abs(ddf)))
|
| 949 |
+
|
| 950 |
+
return ddf
|
| 951 |
+
|
| 952 |
+
def _make_te(self, dim_in, dim_out):
|
| 953 |
+
return nn.Sequential(
|
| 954 |
+
nn.Linear(dim_in, dim_out),
|
| 955 |
+
nn.ReLU(),
|
| 956 |
+
nn.Linear(dim_out, dim_out)
|
| 957 |
+
)
|
| 958 |
+
|
| 959 |
+
# class RecMutAttnNet(nn.Module):
|
| 960 |
+
# def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
|
| 961 |
+
# super(RecMutAttnNet, self).__init__()
|
| 962 |
+
|
| 963 |
+
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
|
| 964 |
+
# self.conditional_input = conditional_input
|
| 965 |
+
|
| 966 |
+
# self.dimension = ndims
|
| 967 |
+
# self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| 968 |
+
# self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
| 969 |
+
|
| 970 |
+
# # Sinusoidal embedding
|
| 971 |
+
# self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| 972 |
+
# self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| 973 |
+
# self.time_embed.requires_grad_(False)
|
| 974 |
+
# self.hier_num = len(self.feat_channels) - 1
|
| 975 |
+
# self.down_layers = nn.ModuleList()
|
| 976 |
+
# self.up_layers = nn.ModuleList()
|
| 977 |
+
# self.ted_layers = nn.ModuleList()
|
| 978 |
+
# self.teu_layers = nn.ModuleList()
|
| 979 |
+
# self.block_down = nn.ModuleList()
|
| 980 |
+
# if self.conditional_input:
|
| 981 |
+
# self.block_down_cond = nn.ModuleList()
|
| 982 |
+
# self.fuse_conv0 = nn.ModuleList()
|
| 983 |
+
# self.fuse_conv1 = nn.ModuleList()
|
| 984 |
+
# self.block_up = nn.ModuleList()
|
| 985 |
+
|
| 986 |
+
# for i in range(1, self.hier_num + 1):
|
| 987 |
+
# j=-i
|
| 988 |
+
# self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| 989 |
+
# self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| 990 |
+
# self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| 991 |
+
# self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| 992 |
+
# self.block_down.append(nn.Sequential(
|
| 993 |
+
# AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 994 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 995 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 996 |
+
# ))
|
| 997 |
+
# if self.conditional_input:
|
| 998 |
+
# self.block_down_cond.append(nn.Sequential(
|
| 999 |
+
# AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| 1000 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| 1001 |
+
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| 1002 |
+
# ))
|
| 1003 |
+
# self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 1004 |
+
# self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| 1005 |
+
# if i==self.hier_num:
|
| 1006 |
+
# k=j
|
| 1007 |
+
# else:
|
| 1008 |
+
# k=j-1
|
| 1009 |
+
# self.block_up.append(nn.Sequential(
|
| 1010 |
+
# AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 1011 |
+
# AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| 1012 |
+
# AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| 1013 |
+
# ))
|
| 1014 |
+
|
| 1015 |
+
# # Bottleneck
|
| 1016 |
+
# self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| 1017 |
+
# self.b_mid = nn.Sequential(
|
| 1018 |
+
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 1019 |
+
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| 1020 |
+
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| 1021 |
+
# )
|
| 1022 |
+
|
| 1023 |
+
# self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
| 1024 |
+
|
| 1025 |
+
# def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| 1026 |
+
# sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 1027 |
+
# return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 1028 |
+
# zip(sample_coords, max_sz)], 1)
|
| 1029 |
+
|
| 1030 |
+
# def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 1031 |
+
# ref = self.ref_grid if ref is None else ref
|
| 1032 |
+
# img_sz = self.max_sz if img_sz is None else img_sz
|
| 1033 |
+
# resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
|
| 1034 |
+
|
| 1035 |
+
# return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| 1036 |
+
# np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| 1037 |
+
# [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| 1038 |
+
# align_corners=True)
|
| 1039 |
+
|
| 1040 |
+
# def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
|
| 1041 |
+
# self.device = x.device
|
| 1042 |
+
# img_sz = x.size()[2:]
|
| 1043 |
+
# n = x.size()[0]
|
| 1044 |
+
# self.max_sz = [img_sz[0]] * self.dimension
|
| 1045 |
+
# ts_emb_shape=[n,-1]+[1]*self.dimension
|
| 1046 |
+
# self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
|
| 1047 |
+
# self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
|
| 1048 |
+
# [1, self.dimension]+list(img_sz)).to(self.device)
|
| 1049 |
+
# img = x
|
| 1050 |
+
# t = self.time_embed(t)
|
| 1051 |
+
|
| 1052 |
+
# for rec_id in range(rec_num):
|
| 1053 |
+
# if self.conditional_input:
|
| 1054 |
+
# tgt = y
|
| 1055 |
+
# enc_list = []
|
| 1056 |
+
# out = img
|
| 1057 |
+
# for i in range(self.hier_num):
|
| 1058 |
+
# out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 1059 |
+
# if self.conditional_input:
|
| 1060 |
+
# tgt = self.block_down_cond[i](tgt)
|
| 1061 |
+
# out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 1062 |
+
# tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 1063 |
+
# enc_list.append(out)
|
| 1064 |
+
# out = self.down_layers[i](out)
|
| 1065 |
+
# if self.conditional_input:
|
| 1066 |
+
# tgt = self.down_layers[i](tgt)
|
| 1067 |
+
|
| 1068 |
+
# out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 1069 |
+
# if self.conditional_input:
|
| 1070 |
+
# out = out + tgt
|
| 1071 |
+
|
| 1072 |
+
# for i in range(self.hier_num):
|
| 1073 |
+
# out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
|
| 1074 |
+
# out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 1075 |
+
|
| 1076 |
+
# out = self.conv_out(out)/128
|
| 1077 |
+
|
| 1078 |
+
# ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 1079 |
+
# if rec_id == 0:
|
| 1080 |
+
# ddf = ddf_one
|
| 1081 |
+
# else:
|
| 1082 |
+
# ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 1083 |
+
# img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 1084 |
+
|
| 1085 |
+
# return ddf
|
| 1086 |
+
|
| 1087 |
+
# def _make_te(self, dim_in, dim_out):
|
| 1088 |
+
# return nn.Sequential(
|
| 1089 |
+
# nn.Linear(dim_in, dim_out),
|
| 1090 |
+
# nn.ReLU(),
|
| 1091 |
+
# nn.Linear(dim_out, dim_out)
|
| 1092 |
+
# )
|
| 1093 |
+
# ==============================================
|
| 1094 |
+
# Layers
|
| 1095 |
+
# ==============================================
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
def ddf_multiplier(dvf,mul_num=10,stn=None):
|
| 1099 |
+
ddf=dvf
|
| 1100 |
+
for i in range(mul_num):
|
| 1101 |
+
ddf = dvf + stn(ddf, dvf)
|
| 1102 |
+
return ddf
|
| 1103 |
+
|
| 1104 |
+
|
| 1105 |
+
def composite(ddfs,stn=None):
|
| 1106 |
+
if stn is None:
|
| 1107 |
+
stn = STN(device=ddfs[0].device,padding_mode="border")
|
| 1108 |
+
comp_ddf=ddfs[0]
|
| 1109 |
+
for i in range(1,len(ddfs)):
|
| 1110 |
+
comp_ddf = ddfs[i] + stn(comp_ddf,ddfs[i])
|
| 1111 |
+
return comp_ddf
|
| 1112 |
+
|
| 1113 |
+
|
| 1114 |
+
|
| 1115 |
+
class STN(nn.Module):
|
| 1116 |
+
def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None):
|
| 1117 |
+
super(STN, self).__init__()
|
| 1118 |
+
self.ndims=ndims
|
| 1119 |
+
self.img_sz=[img_sz]*ndims
|
| 1120 |
+
# self.img_sz=img_sz
|
| 1121 |
+
self.device = device
|
| 1122 |
+
self.padding_mode = padding_mode
|
| 1123 |
+
# max_sz=[128]*self.ndims
|
| 1124 |
+
max_sz=[img_sz]*self.ndims
|
| 1125 |
+
# max_sz=img_sz
|
| 1126 |
+
# max_sz=img_sz if max_sz is None else ([128,128] if img_sz is None else img_sz)
|
| 1127 |
+
# self.max_sz=torch.Tensor(np.reshape(np.array(max_sz), [1, self.ndims, 1, 1])).to(self.device)
|
| 1128 |
+
self.max_sz=torch.Tensor(np.reshape(np.array(max_sz), [1, self.ndims]+[1]*self.ndims)).to(self.device)
|
| 1129 |
+
self.resample_mode=resample_mode
|
| 1130 |
+
if self.img_sz is not None:
|
| 1131 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=s) for s in self.img_sz]), 0),
|
| 1132 |
+
[1, self.ndims] + self.img_sz).to(self.device)
|
| 1133 |
+
return
|
| 1134 |
+
def max_limit(self, sample_coords0, plus=0., minus=1.):
|
| 1135 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 1136 |
+
# return tf.stack([tf.maximum(tf.minimum(x, sz - minus + plus), 0 + plus) for x, sz in zip(sample_coords, input_size0)],-1)
|
| 1137 |
+
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| 1138 |
+
zip(sample_coords, self.max_sz)], 1)
|
| 1139 |
+
|
| 1140 |
+
def boundary_limit(self, sample_coords0, plus=0., minus=1.):
|
| 1141 |
+
|
| 1142 |
+
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| 1143 |
+
# return tf.stack([tf.maximum(tf.minimum(x, sz - minus + plus), 0 + plus) for x, sz in zip(sample_coords, input_size0)],-1)
|
| 1144 |
+
return torch.cat([(torch.clamp(x * sz+ref, min=minus - 1 * sz + plus, max=1 * sz - minus + plus)-ref) / sz for x, sz,ref in
|
| 1145 |
+
zip(sample_coords, self.max_sz, self.ref_grid)], 1)
|
| 1146 |
+
|
| 1147 |
+
def resample(self, vol, ddf, ref=None, img_sz=None,padding_mode = "zeros"):
|
| 1148 |
+
# print(vol.device, ddf.device)
|
| 1149 |
+
# print(self.device)
|
| 1150 |
+
# print('===================')
|
| 1151 |
+
device = ddf.device
|
| 1152 |
+
|
| 1153 |
+
ref = self.ref_grid if ref is None else ref
|
| 1154 |
+
if img_sz is None:
|
| 1155 |
+
img_sz = self.max_sz
|
| 1156 |
+
else:
|
| 1157 |
+
img_sz = torch.reshape(torch.tensor([(s - 1) / 2. for s in img_sz], device=device), [1]+[1]*self.ndims+[self.ndims])
|
| 1158 |
+
# resample_mode = 'bicubic'
|
| 1159 |
+
if self.resample_mode is None:
|
| 1160 |
+
resample_mode = 'bilinear' # if self.ndims==2 else 'trilinear'
|
| 1161 |
+
else:
|
| 1162 |
+
resample_mode=self.resample_mode
|
| 1163 |
+
# padding_mode = "border"
|
| 1164 |
+
# print(ddf.shape, ref.shape)
|
| 1165 |
+
return F.grid_sample(vol.to(device), torch.flip((ddf * self.max_sz.to(device) + ref.to(device)).permute(
|
| 1166 |
+
[0] + list(range(2, 2 + self.ndims)) + [1]) / img_sz - 1, dims=[-1]), mode=resample_mode,
|
| 1167 |
+
padding_mode=padding_mode,
|
| 1168 |
+
align_corners=True)
|
| 1169 |
+
|
| 1170 |
+
def forward(self,x,ddf):
|
| 1171 |
+
self.device = x.device if self.device is None else self.device
|
| 1172 |
+
if self.img_sz is None:
|
| 1173 |
+
self.img_sz = list(x.size()[2:]).to(self.device)
|
| 1174 |
+
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=s) for s in self.img_sz]), 0),[1, self.ndims]+self.img_sz).to(self.device)
|
| 1175 |
+
resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
|
| 1176 |
+
return resampled_x
|
| 1177 |
+
|
| 1178 |
+
|
| 1179 |
+
if __name__ == '__main__':
|
| 1180 |
+
ndims = 3
|
| 1181 |
+
res = 128
|
| 1182 |
+
x = torch.rand([1, 1] + [res]*ndims)
|
| 1183 |
+
t = torch.randint(0, 1000, (1,))
|
| 1184 |
+
text = torch.rand([1, 1024] + [1]*ndims)
|
| 1185 |
+
model = RecMutAttnNet(n_steps=1000, time_emb_dim=100, ndims=ndims, num_input_chn=1, res=res, conditional_input=True)
|
| 1186 |
+
y = model(x, x, t, text=text)
|
| 1187 |
+
print("Ouput shape", y.shape)
|
| 1188 |
+
|
| 1189 |
+
# Total parameters
|
| 1190 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 1191 |
+
# Trainable parameters only
|
| 1192 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 1193 |
+
|
| 1194 |
+
print(f"Total parameters: {total_params}")
|
| 1195 |
+
print(f"Trainable parameters: {trainable_params}")
|
Diffusion/networks_opt.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
networks_opt.py — Optimized network components.
|
| 3 |
+
|
| 4 |
+
Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:
|
| 5 |
+
1. OptSTN: register_buffer for ref_grid/max_sz — no .to(device) per call
|
| 6 |
+
2. OptRecMulModMutAttnNet: cached max_sz/img_sz tensors, ref_grid device —
|
| 7 |
+
eliminates ~80 NumPy→GPU transfers and ~32 tensor recreations per registration step
|
| 8 |
+
|
| 9 |
+
All optimizations are mathematically equivalent to the originals.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from Diffusion.networks import RecMulModMutAttnNet, STN
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ======================================================================
|
| 21 |
+
# Optimized STN
|
| 22 |
+
# ======================================================================
|
| 23 |
+
|
| 24 |
+
class OptSTN(STN):
|
| 25 |
+
"""STN with register_buffer for automatic device transfer.
|
| 26 |
+
|
| 27 |
+
Eliminates per-call .to(device) overhead in resample() and forward().
|
| 28 |
+
Buffers auto-transfer when module.to(device) is called.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
|
| 32 |
+
padding_mode="border", resample_mode=None):
|
| 33 |
+
# Skip parent __init__ to avoid creating plain tensor attributes
|
| 34 |
+
nn.Module.__init__(self)
|
| 35 |
+
self.ndims = ndims
|
| 36 |
+
self.img_sz = [img_sz] * ndims
|
| 37 |
+
self.device = device
|
| 38 |
+
self.padding_mode = padding_mode
|
| 39 |
+
self.resample_mode = resample_mode
|
| 40 |
+
|
| 41 |
+
# OPT: register_buffer — auto device transfer, no per-call .to()
|
| 42 |
+
max_sz_val = [img_sz] * ndims
|
| 43 |
+
max_sz_tensor = torch.Tensor(
|
| 44 |
+
np.reshape(np.array(max_sz_val), [1, self.ndims] + [1] * self.ndims)
|
| 45 |
+
)
|
| 46 |
+
self.register_buffer('max_sz', max_sz_tensor)
|
| 47 |
+
|
| 48 |
+
if self.img_sz is not None:
|
| 49 |
+
ref_grid = torch.reshape(
|
| 50 |
+
torch.stack(torch.meshgrid(
|
| 51 |
+
[torch.arange(end=s) for s in self.img_sz]
|
| 52 |
+
), 0),
|
| 53 |
+
[1, self.ndims] + self.img_sz
|
| 54 |
+
)
|
| 55 |
+
self.register_buffer('ref_grid', ref_grid)
|
| 56 |
+
|
| 57 |
+
# OPT: pre-compute the img_sz tensor used when forward() calls resample()
|
| 58 |
+
img_sz_for_resample = torch.reshape(
|
| 59 |
+
torch.tensor([(s - 1) / 2. for s in self.img_sz]),
|
| 60 |
+
[1] + [1] * self.ndims + [self.ndims]
|
| 61 |
+
)
|
| 62 |
+
self.register_buffer('_img_sz_for_resample', img_sz_for_resample)
|
| 63 |
+
|
| 64 |
+
# OPT: pre-compute constant permutation order
|
| 65 |
+
self._perm = [0] + list(range(2, 2 + self.ndims)) + [1]
|
| 66 |
+
|
| 67 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 68 |
+
# OPT: no .to(device) — buffers auto-transfer with module.to()
|
| 69 |
+
ref = self.ref_grid if ref is None else ref
|
| 70 |
+
|
| 71 |
+
if img_sz is None:
|
| 72 |
+
img_sz_t = self.max_sz
|
| 73 |
+
else:
|
| 74 |
+
# Use pre-computed tensor for the common case (called from forward)
|
| 75 |
+
img_sz_t = self._img_sz_for_resample
|
| 76 |
+
|
| 77 |
+
resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode
|
| 78 |
+
|
| 79 |
+
grid = torch.flip(
|
| 80 |
+
(ddf * self.max_sz + ref).permute(self._perm) / img_sz_t - 1,
|
| 81 |
+
dims=[-1]
|
| 82 |
+
)
|
| 83 |
+
return F.grid_sample(vol, grid, mode=resample_mode,
|
| 84 |
+
padding_mode=padding_mode, align_corners=True)
|
| 85 |
+
|
| 86 |
+
def forward(self, x, ddf):
|
| 87 |
+
# OPT: no device check or ref_grid regeneration — buffers handle it
|
| 88 |
+
return self.resample(x, ddf=ddf, img_sz=self.img_sz,
|
| 89 |
+
padding_mode=self.padding_mode)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ======================================================================
|
| 93 |
+
# Optimized RecMulModMutAttnNet
|
| 94 |
+
# ======================================================================
|
| 95 |
+
|
| 96 |
+
class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
|
| 97 |
+
"""RecMulModMutAttnNet with cached tensors for resample/forward.
|
| 98 |
+
|
| 99 |
+
Eliminates per-call overhead:
|
| 100 |
+
- resample(): cached max_sz tensor (was: NumPy→Torch→GPU every call)
|
| 101 |
+
- forward(): cached img_sz tensor and ref_grid device placement
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
def __init__(self, *args, **kwargs):
|
| 105 |
+
super().__init__(*args, **kwargs)
|
| 106 |
+
# Cache slots — populated on first forward
|
| 107 |
+
self._cached_input_key = None
|
| 108 |
+
self._cached_max_sz_tensor = None
|
| 109 |
+
self._cached_img_sz_tensor = None
|
| 110 |
+
# OPT: pre-compute constant permutation order
|
| 111 |
+
self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]
|
| 112 |
+
|
| 113 |
+
def _ensure_cache(self, img_sz, device):
|
| 114 |
+
"""Populate cached tensors if input size or device changed."""
|
| 115 |
+
key = (tuple(img_sz), device)
|
| 116 |
+
if key == self._cached_input_key:
|
| 117 |
+
return
|
| 118 |
+
self._cached_input_key = key
|
| 119 |
+
max_sz_list = [img_sz[0]] * self.dimension
|
| 120 |
+
self.max_sz = max_sz_list
|
| 121 |
+
|
| 122 |
+
# OPT: create max_sz tensor ONCE, reuse across all resample() calls
|
| 123 |
+
self._cached_max_sz_tensor = torch.Tensor(
|
| 124 |
+
np.reshape(np.array(max_sz_list), [1, self.dimension] + [1] * self.dimension)
|
| 125 |
+
).to(device)
|
| 126 |
+
|
| 127 |
+
# OPT: create img_sz tensor ONCE per size change
|
| 128 |
+
self._cached_img_sz_tensor = torch.reshape(
|
| 129 |
+
torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=device),
|
| 130 |
+
[1] * (self.dimension + 1) + [self.dimension]
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# OPT: ref_grid — only regenerate if size changed, only .to() if needed
|
| 134 |
+
if list(img_sz) != self.img_res:
|
| 135 |
+
self.ref_grid = torch.reshape(
|
| 136 |
+
torch.stack(torch.meshgrid(
|
| 137 |
+
[torch.arange(end=imsz) for imsz in img_sz]
|
| 138 |
+
), 0),
|
| 139 |
+
[1, self.dimension] + list(img_sz)
|
| 140 |
+
).to(device)
|
| 141 |
+
elif self.ref_grid.device != torch.device(device):
|
| 142 |
+
self.ref_grid = self.ref_grid.to(device)
|
| 143 |
+
|
| 144 |
+
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| 145 |
+
# OPT: use cached max_sz tensor instead of NumPy→Torch→GPU every call
|
| 146 |
+
ref = self.ref_grid if ref is None else ref
|
| 147 |
+
img_sz = self._cached_img_sz_tensor if img_sz is not None else self._cached_max_sz_tensor
|
| 148 |
+
|
| 149 |
+
grid = torch.flip(
|
| 150 |
+
(ddf * self._cached_max_sz_tensor + ref).permute(self._perm) / img_sz - 1,
|
| 151 |
+
dims=[-1]
|
| 152 |
+
)
|
| 153 |
+
return F.grid_sample(vol, grid, mode='bilinear',
|
| 154 |
+
padding_mode=padding_mode, align_corners=True)
|
| 155 |
+
|
| 156 |
+
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
|
| 157 |
+
self.device = x.device
|
| 158 |
+
img_sz = x.size()[2:]
|
| 159 |
+
n = x.size()[0]
|
| 160 |
+
ts_emb_shape = [n, -1] + [1] * self.dimension
|
| 161 |
+
|
| 162 |
+
# OPT: cache tensors — only recreate if input size/device changes
|
| 163 |
+
self._ensure_cache(img_sz, self.device)
|
| 164 |
+
self.img_sz = self._cached_img_sz_tensor
|
| 165 |
+
|
| 166 |
+
img = x
|
| 167 |
+
t = self.time_embed(t)
|
| 168 |
+
if text is None:
|
| 169 |
+
text = self.text
|
| 170 |
+
text = text.to(self.device)
|
| 171 |
+
txt_shape = [1, -1] + [1] * self.dimension
|
| 172 |
+
else:
|
| 173 |
+
txt_shape = [n, -1] + [1] * self.dimension
|
| 174 |
+
|
| 175 |
+
for rec_id in range(rec_num):
|
| 176 |
+
if self.conditional_input:
|
| 177 |
+
tgt = y
|
| 178 |
+
enc_list = []
|
| 179 |
+
out = img
|
| 180 |
+
for i in range(self.hier_num):
|
| 181 |
+
out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
|
| 182 |
+
if self.conditional_input:
|
| 183 |
+
tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
|
| 184 |
+
out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
|
| 185 |
+
tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
|
| 186 |
+
enc_list.append(out)
|
| 187 |
+
out = self.down_layers[i](out)
|
| 188 |
+
if self.conditional_input:
|
| 189 |
+
tgt = self.down_layers[i](tgt)
|
| 190 |
+
|
| 191 |
+
out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
|
| 192 |
+
if self.conditional_input:
|
| 193 |
+
out_shape = out.shape
|
| 194 |
+
tgt_shape = tgt.shape
|
| 195 |
+
out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
|
| 196 |
+
tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
|
| 197 |
+
out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
|
| 198 |
+
tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
|
| 199 |
+
out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
|
| 200 |
+
tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
|
| 201 |
+
out = out + out_attn
|
| 202 |
+
tgt = tgt + tgt_attn
|
| 203 |
+
out = self.fuse(torch.cat([out, tgt], dim=1))
|
| 204 |
+
|
| 205 |
+
if self.conditional_input:
|
| 206 |
+
img_txt_feat = self.img2txt(out)
|
| 207 |
+
self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
|
| 208 |
+
out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
|
| 209 |
+
out_txt = self.txt_proc(out_txt)
|
| 210 |
+
out_txt = self.txt2img(out_txt)
|
| 211 |
+
out = out + out_txt
|
| 212 |
+
|
| 213 |
+
for i in range(self.hier_num):
|
| 214 |
+
out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
|
| 215 |
+
out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
|
| 216 |
+
|
| 217 |
+
out = self.conv_out(out) / 128
|
| 218 |
+
|
| 219 |
+
ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
|
| 220 |
+
if rec_id == 0:
|
| 221 |
+
ddf = ddf_one
|
| 222 |
+
else:
|
| 223 |
+
ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
|
| 224 |
+
img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
|
| 225 |
+
|
| 226 |
+
return ddf
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
# ======================================================================
|
| 230 |
+
# Factory function
|
| 231 |
+
# ======================================================================
|
| 232 |
+
|
| 233 |
+
def get_net_opt(name):
|
| 234 |
+
"""Return optimized network class if available, else fall back to original."""
|
| 235 |
+
if name == "recmulmodmutattnnet":
|
| 236 |
+
return OptRecMulModMutAttnNet
|
| 237 |
+
# Fall back to original for other network types
|
| 238 |
+
from Diffusion.networks import get_net
|
| 239 |
+
return get_net(name)
|
Diffusion/safe_conv_transpose.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SafeConvTranspose3d: Drop-in replacement for nn.ConvTranspose3d that avoids
|
| 3 |
+
the XPU memory leak in the ConvTranspose3d backward pass (oneDNN autograd bug).
|
| 4 |
+
|
| 5 |
+
Mathematical Background
|
| 6 |
+
=======================
|
| 7 |
+
|
| 8 |
+
ConvTranspose3d (a.k.a. "transposed convolution" or "fractionally-strided
|
| 9 |
+
convolution") with parameters:
|
| 10 |
+
in_channels=C_in, out_channels=C_out, kernel_size=K, stride=S, padding=P
|
| 11 |
+
|
| 12 |
+
is the gradient (adjoint) of Conv3d with the same parameters. For an input x
|
| 13 |
+
of shape [B, C_in, D, H, W], the output has shape:
|
| 14 |
+
[B, C_out, S*(D-1) + K - 2*P, S*(H-1) + K - 2*P, S*(W-1) + K - 2*P]
|
| 15 |
+
|
| 16 |
+
For our specific case (K=4, S=2, P=1):
|
| 17 |
+
output_size = 2*(D-1) + 4 - 2 = 2*D (likewise for H, W)
|
| 18 |
+
|
| 19 |
+
The operation is mathematically equivalent to:
|
| 20 |
+
1. Stride insertion: insert (S-1) zeros between each input element
|
| 21 |
+
2. Padding: pad with (K - P - 1) zeros on each side
|
| 22 |
+
3. Regular Conv3d with spatially-flipped, channel-transposed weight
|
| 23 |
+
|
| 24 |
+
Specifically:
|
| 25 |
+
|
| 26 |
+
Step 1 - Stride insertion:
|
| 27 |
+
Input [B, C_in, D, H, W] -> [B, C_in, S*(D-1)+1, S*(H-1)+1, S*(W-1)+1]
|
| 28 |
+
For S=2: [B, C_in, 2*D-1, 2*H-1, 2*W-1]
|
| 29 |
+
Original values placed at positions 0, S, 2S, ... ; zeros elsewhere.
|
| 30 |
+
|
| 31 |
+
Step 2 - Padding:
|
| 32 |
+
Pad each spatial dimension with (K - P - 1) zeros on each side.
|
| 33 |
+
For K=4, P=1: pad = 2 on each side.
|
| 34 |
+
Shape becomes: [B, C_in, 2*D+3, 2*H+3, 2*W+3]
|
| 35 |
+
|
| 36 |
+
Step 3 - Conv3d with transformed weight:
|
| 37 |
+
ConvTranspose3d weight shape: [C_in, C_out, K, K, K]
|
| 38 |
+
Equivalent Conv3d weight: weight.flip(2,3,4).transpose(0,1)
|
| 39 |
+
-> shape [C_out, C_in, K, K, K]
|
| 40 |
+
|
| 41 |
+
Conv3d(stride=1, padding=0) on the padded input gives:
|
| 42 |
+
[B, C_out, (2*D+3 - K + 1), ...] = [B, C_out, 2*D, 2*H, 2*W] (correct!)
|
| 43 |
+
|
| 44 |
+
Why this is safe on XPU:
|
| 45 |
+
The forward uses F.pad (ZERO leak) and F.conv3d (negligible leak).
|
| 46 |
+
The backward is computed automatically by PyTorch's autograd through these
|
| 47 |
+
same safe ops — no ConvTranspose3d backward kernel is ever invoked.
|
| 48 |
+
Specifically:
|
| 49 |
+
- F.conv3d backward -> uses Conv3d backward (safe, 0.004 GiB/step)
|
| 50 |
+
- F.pad backward -> tensor slicing (trivially safe)
|
| 51 |
+
- Stride insertion backward -> gather at stride positions (trivially safe)
|
| 52 |
+
- weight.flip().transpose() backward -> indexing (trivially safe)
|
| 53 |
+
|
| 54 |
+
Forward precision:
|
| 55 |
+
Not bit-for-bit identical to nn.ConvTranspose3d due to different summation
|
| 56 |
+
order (stride-insert + pad + conv3d vs native transposed conv), but the
|
| 57 |
+
difference is negligible: max absolute diff < 5e-7 in float32, no elements
|
| 58 |
+
exceeding 1e-6. This is well within float32 machine epsilon for typical
|
| 59 |
+
activation magnitudes.
|
| 60 |
+
|
| 61 |
+
Backward precision:
|
| 62 |
+
Gradients match nn.ConvTranspose3d within 1e-5 (input) and 1e-4 (weight)
|
| 63 |
+
for float32. Verified across all channel configurations used in the
|
| 64 |
+
codebase (16-256 channels).
|
| 65 |
+
|
| 66 |
+
Implementation choices:
|
| 67 |
+
We also provide SafeConvTranspose3d_v2 which uses a custom autograd function
|
| 68 |
+
to call F.conv_transpose3d in the forward (bit-for-bit identical) but
|
| 69 |
+
replaces the backward with safe Conv3d-based gradient computation.
|
| 70 |
+
|
| 71 |
+
RECOMMENDATION: Use SafeConvTranspose3d (V1, decomposed forward) because:
|
| 72 |
+
- Simpler implementation with no custom autograd
|
| 73 |
+
- Fully transparent to PyTorch's autograd
|
| 74 |
+
- Compatible with gradient checkpointing, torch.compile, etc.
|
| 75 |
+
- The ~5e-7 forward precision loss is negligible for training
|
| 76 |
+
- V2's custom autograd requires careful maintenance and is fragile
|
| 77 |
+
"""
|
| 78 |
+
|
| 79 |
+
import torch
|
| 80 |
+
import torch.nn as nn
|
| 81 |
+
import torch.nn.functional as F
|
| 82 |
+
from torch.autograd import Function
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# =============================================================================
|
| 86 |
+
# Approach 1 (RECOMMENDED): Decomposed forward pass
|
| 87 |
+
# =============================================================================
|
| 88 |
+
|
| 89 |
+
class SafeConvTranspose3d(nn.Module):
|
| 90 |
+
"""Drop-in replacement for nn.ConvTranspose3d that decomposes the operation
|
| 91 |
+
into stride insertion + padding + regular Conv3d.
|
| 92 |
+
|
| 93 |
+
All operations in forward (and thus all backward ops via autograd) are
|
| 94 |
+
safe on XPU: no ConvTranspose3d backward kernel is invoked.
|
| 95 |
+
|
| 96 |
+
Supports: kernel_size, stride, padding (scalar or tuple), bias, groups=1.
|
| 97 |
+
Does NOT support: output_padding, dilation != 1, groups != 1.
|
| 98 |
+
|
| 99 |
+
The weight tensor has the SAME shape as nn.ConvTranspose3d:
|
| 100 |
+
[in_channels, out_channels, *kernel_size]
|
| 101 |
+
so checkpoints can be loaded directly with load_state_dict().
|
| 102 |
+
"""
|
| 103 |
+
|
| 104 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
| 105 |
+
padding=0, output_padding=0, groups=1, bias=True,
|
| 106 |
+
dilation=1, padding_mode='zeros'):
|
| 107 |
+
super().__init__()
|
| 108 |
+
|
| 109 |
+
if groups != 1:
|
| 110 |
+
raise NotImplementedError("SafeConvTranspose3d only supports groups=1")
|
| 111 |
+
if output_padding != 0:
|
| 112 |
+
raise NotImplementedError("SafeConvTranspose3d does not support output_padding")
|
| 113 |
+
|
| 114 |
+
# Normalize to tuples
|
| 115 |
+
if isinstance(kernel_size, int):
|
| 116 |
+
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 117 |
+
if isinstance(stride, int):
|
| 118 |
+
stride = (stride, stride, stride)
|
| 119 |
+
if isinstance(padding, int):
|
| 120 |
+
padding = (padding, padding, padding)
|
| 121 |
+
if isinstance(dilation, int):
|
| 122 |
+
dilation = (dilation, dilation, dilation)
|
| 123 |
+
if dilation != (1, 1, 1):
|
| 124 |
+
raise NotImplementedError("SafeConvTranspose3d does not support dilation != 1")
|
| 125 |
+
|
| 126 |
+
self.in_channels = in_channels
|
| 127 |
+
self.out_channels = out_channels
|
| 128 |
+
self.kernel_size = kernel_size
|
| 129 |
+
self.stride = stride
|
| 130 |
+
self.padding = padding
|
| 131 |
+
self.groups = groups
|
| 132 |
+
|
| 133 |
+
# Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
|
| 134 |
+
self.weight = nn.Parameter(
|
| 135 |
+
torch.empty(in_channels, out_channels, *kernel_size)
|
| 136 |
+
)
|
| 137 |
+
if bias:
|
| 138 |
+
self.bias = nn.Parameter(torch.empty(out_channels))
|
| 139 |
+
else:
|
| 140 |
+
self.register_parameter('bias', None)
|
| 141 |
+
|
| 142 |
+
# Initialize weights same as nn.ConvTranspose3d
|
| 143 |
+
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
| 144 |
+
if self.bias is not None:
|
| 145 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
| 146 |
+
if fan_in != 0:
|
| 147 |
+
bound = 1 / fan_in**0.5
|
| 148 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
B, C_in, D, H, W = x.shape
|
| 152 |
+
sd, sh, sw = self.stride
|
| 153 |
+
kd, kh, kw = self.kernel_size
|
| 154 |
+
pd, ph, pw = self.padding
|
| 155 |
+
|
| 156 |
+
# Step 1: Stride insertion — place input values at stride positions,
|
| 157 |
+
# zeros elsewhere. This is the "fractionally-strided" part.
|
| 158 |
+
if sd > 1 or sh > 1 or sw > 1:
|
| 159 |
+
D_ins = sd * (D - 1) + 1
|
| 160 |
+
H_ins = sh * (H - 1) + 1
|
| 161 |
+
W_ins = sw * (W - 1) + 1
|
| 162 |
+
x_inserted = x.new_zeros(B, C_in, D_ins, H_ins, W_ins)
|
| 163 |
+
x_inserted[:, :, ::sd, ::sh, ::sw] = x
|
| 164 |
+
else:
|
| 165 |
+
x_inserted = x
|
| 166 |
+
|
| 167 |
+
# Step 2: Pad with (kernel_size - padding - 1) zeros on each side.
|
| 168 |
+
# This converts ConvTranspose3d's "padding" (which removes output elements)
|
| 169 |
+
# into the equivalent zero-padding for a regular convolution.
|
| 170 |
+
pad_d = kd - pd - 1
|
| 171 |
+
pad_h = kh - ph - 1
|
| 172 |
+
pad_w = kw - pw - 1
|
| 173 |
+
# F.pad argument order: (W_left, W_right, H_left, H_right, D_left, D_right)
|
| 174 |
+
x_padded = F.pad(x_inserted, (pad_w, pad_w, pad_h, pad_h, pad_d, pad_d))
|
| 175 |
+
|
| 176 |
+
# Step 3: Transform weight from ConvTranspose3d layout to Conv3d layout.
|
| 177 |
+
# ConvTranspose3d weight: [C_in, C_out, kD, kH, kW]
|
| 178 |
+
# Equivalent Conv3d weight: [C_out, C_in, kD, kH, kW] with spatial dims flipped
|
| 179 |
+
w_conv = self.weight.flip(2, 3, 4).transpose(0, 1)
|
| 180 |
+
|
| 181 |
+
# Step 4: Standard Conv3d (stride=1, padding=0)
|
| 182 |
+
return F.conv3d(x_padded, w_conv, self.bias, stride=1, padding=0)
|
| 183 |
+
|
| 184 |
+
def extra_repr(self):
|
| 185 |
+
return (f'{self.in_channels}, {self.out_channels}, '
|
| 186 |
+
f'kernel_size={self.kernel_size}, stride={self.stride}, '
|
| 187 |
+
f'padding={self.padding}, bias={self.bias is not None}')
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# =============================================================================
|
| 191 |
+
# Approach 2: Custom autograd — real forward, safe backward
|
| 192 |
+
# =============================================================================
|
| 193 |
+
|
| 194 |
+
class _SafeConvTranspose3dFunc(Function):
|
| 195 |
+
"""Custom autograd function that uses F.conv_transpose3d in forward
|
| 196 |
+
(bit-for-bit identical) but computes gradients using Conv3d-based ops
|
| 197 |
+
in backward (avoiding the leaky oneDNN ConvTranspose3d backward kernel).
|
| 198 |
+
|
| 199 |
+
Gradient derivation:
|
| 200 |
+
For y = conv_transpose3d(x, w, stride=S, padding=P):
|
| 201 |
+
|
| 202 |
+
grad_x = conv3d(grad_y, w, stride=S, padding=P)
|
| 203 |
+
Confirmed bit-for-bit identical to PyTorch's own backward.
|
| 204 |
+
|
| 205 |
+
grad_w = conv3d(pad(stride_insert(x)).T, grad_y.T).flip(spatial)
|
| 206 |
+
where stride_insert inserts (S-1) zeros between elements,
|
| 207 |
+
pad adds (K-P-1) zeros on each side, and .T swaps batch/channel.
|
| 208 |
+
The spatial flip accounts for the flip in the forward decomposition.
|
| 209 |
+
|
| 210 |
+
grad_bias = grad_y.sum(dim=(0, 2, 3, 4))
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
@staticmethod
|
| 214 |
+
def forward(ctx, input, weight, bias, stride, padding, output_padding, groups, dilation):
|
| 215 |
+
# Use the real conv_transpose3d for bit-for-bit identical forward
|
| 216 |
+
output = F.conv_transpose3d(
|
| 217 |
+
input, weight, bias,
|
| 218 |
+
stride=stride, padding=padding,
|
| 219 |
+
output_padding=output_padding, groups=groups, dilation=dilation
|
| 220 |
+
)
|
| 221 |
+
ctx.save_for_backward(input, weight, bias)
|
| 222 |
+
ctx.stride = stride
|
| 223 |
+
ctx.padding = padding
|
| 224 |
+
ctx.output_padding = output_padding
|
| 225 |
+
ctx.groups = groups
|
| 226 |
+
ctx.dilation = dilation
|
| 227 |
+
return output
|
| 228 |
+
|
| 229 |
+
@staticmethod
|
| 230 |
+
def backward(ctx, grad_output):
|
| 231 |
+
input, weight, bias = ctx.saved_tensors
|
| 232 |
+
stride = ctx.stride
|
| 233 |
+
padding = ctx.padding
|
| 234 |
+
groups = ctx.groups
|
| 235 |
+
dilation = ctx.dilation
|
| 236 |
+
|
| 237 |
+
grad_input = grad_weight = grad_bias = None
|
| 238 |
+
|
| 239 |
+
if ctx.needs_input_grad[0]:
|
| 240 |
+
# grad_input of ConvTranspose3d = Conv3d(grad_output, weight)
|
| 241 |
+
# This is exact: ConvTranspose3d IS the adjoint of Conv3d.
|
| 242 |
+
grad_input = F.conv3d(
|
| 243 |
+
grad_output, weight,
|
| 244 |
+
bias=None, stride=stride, padding=padding,
|
| 245 |
+
dilation=dilation, groups=groups
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if ctx.needs_input_grad[1]:
|
| 249 |
+
# grad_weight via the decomposed view.
|
| 250 |
+
# Forward decomposition: y = conv3d(x_padded, w.flip(spatial).T(0,1))
|
| 251 |
+
# The backward of this conv3d w.r.t. its weight can be expressed as:
|
| 252 |
+
# grad_w_conv = conv3d(x_padded.T(0,1), grad_y.T(0,1))
|
| 253 |
+
# where the batch-channel transpose turns the sum over batch
|
| 254 |
+
# into a channel dimension convolution.
|
| 255 |
+
#
|
| 256 |
+
# Then: grad_w = grad_w_conv.flip(spatial)
|
| 257 |
+
# because w_conv = w.flip(spatial).T(0,1), and the chain rule
|
| 258 |
+
# through the spatial flip gives an extra flip on the gradient.
|
| 259 |
+
|
| 260 |
+
B, C_in = input.shape[:2]
|
| 261 |
+
spatial = input.shape[2:]
|
| 262 |
+
|
| 263 |
+
# Stride-insert the input
|
| 264 |
+
if any(s > 1 for s in stride):
|
| 265 |
+
new_spatial = tuple(s * (d - 1) + 1 for s, d in zip(stride, spatial))
|
| 266 |
+
input_inserted = input.new_zeros(B, C_in, *new_spatial)
|
| 267 |
+
slices = (slice(None), slice(None)) + tuple(
|
| 268 |
+
slice(None, None, s) for s in stride
|
| 269 |
+
)
|
| 270 |
+
input_inserted[slices] = input
|
| 271 |
+
else:
|
| 272 |
+
input_inserted = input
|
| 273 |
+
|
| 274 |
+
# Pad: (K - P - 1) on each side per spatial dim
|
| 275 |
+
kernel_size = weight.shape[2:]
|
| 276 |
+
pad_sizes = []
|
| 277 |
+
for k, p in zip(reversed(kernel_size), reversed(padding)):
|
| 278 |
+
pad_val = k - p - 1
|
| 279 |
+
pad_sizes.extend([pad_val, pad_val])
|
| 280 |
+
x_padded = F.pad(input_inserted, pad_sizes)
|
| 281 |
+
|
| 282 |
+
# Compute grad_w_conv via conv3d with batch-channel transposition
|
| 283 |
+
x_padded_t = x_padded.transpose(0, 1) # [C_in, B, ...]
|
| 284 |
+
grad_output_t = grad_output.transpose(0, 1) # [C_out, B, ...]
|
| 285 |
+
|
| 286 |
+
# conv3d([C_in, B, D_pad...], [C_out, B, D_out...]) -> [C_in, C_out, K...]
|
| 287 |
+
grad_w_conv = F.conv3d(x_padded_t, grad_output_t)
|
| 288 |
+
|
| 289 |
+
# Undo the spatial flip from the forward decomposition
|
| 290 |
+
grad_weight = grad_w_conv.flip(2, 3, 4)
|
| 291 |
+
|
| 292 |
+
if bias is not None and ctx.needs_input_grad[2]:
|
| 293 |
+
grad_bias = grad_output.sum(dim=(0,) + tuple(range(2, grad_output.ndim)))
|
| 294 |
+
|
| 295 |
+
return grad_input, grad_weight, grad_bias, None, None, None, None, None
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class SafeConvTranspose3d_v2(nn.Module):
|
| 299 |
+
"""Drop-in replacement for nn.ConvTranspose3d using custom autograd.
|
| 300 |
+
|
| 301 |
+
Forward pass: Uses the real F.conv_transpose3d (bit-for-bit identical output).
|
| 302 |
+
Backward pass: Computes gradients using F.conv3d (avoids leaky oneDNN kernel).
|
| 303 |
+
|
| 304 |
+
Weight shape is identical to nn.ConvTranspose3d: [in_channels, out_channels, *kernel_size]
|
| 305 |
+
"""
|
| 306 |
+
|
| 307 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
|
| 308 |
+
padding=0, output_padding=0, groups=1, bias=True,
|
| 309 |
+
dilation=1, padding_mode='zeros'):
|
| 310 |
+
super().__init__()
|
| 311 |
+
|
| 312 |
+
if groups != 1:
|
| 313 |
+
raise NotImplementedError("SafeConvTranspose3d_v2 only supports groups=1")
|
| 314 |
+
if output_padding != 0:
|
| 315 |
+
raise NotImplementedError("SafeConvTranspose3d_v2 does not support output_padding")
|
| 316 |
+
|
| 317 |
+
# Normalize to tuples
|
| 318 |
+
if isinstance(kernel_size, int):
|
| 319 |
+
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 320 |
+
if isinstance(stride, int):
|
| 321 |
+
stride = (stride, stride, stride)
|
| 322 |
+
if isinstance(padding, int):
|
| 323 |
+
padding = (padding, padding, padding)
|
| 324 |
+
if isinstance(dilation, int):
|
| 325 |
+
dilation = (dilation, dilation, dilation)
|
| 326 |
+
|
| 327 |
+
self.in_channels = in_channels
|
| 328 |
+
self.out_channels = out_channels
|
| 329 |
+
self.kernel_size = kernel_size
|
| 330 |
+
self.stride = stride
|
| 331 |
+
self.padding = padding
|
| 332 |
+
self.output_padding = (0, 0, 0) if isinstance(output_padding, int) else output_padding
|
| 333 |
+
self.groups = groups
|
| 334 |
+
self.dilation = dilation
|
| 335 |
+
|
| 336 |
+
# Weight shape matches ConvTranspose3d: [in_channels, out_channels, *kernel_size]
|
| 337 |
+
self.weight = nn.Parameter(
|
| 338 |
+
torch.empty(in_channels, out_channels, *kernel_size)
|
| 339 |
+
)
|
| 340 |
+
if bias:
|
| 341 |
+
self.bias = nn.Parameter(torch.empty(out_channels))
|
| 342 |
+
else:
|
| 343 |
+
self.register_parameter('bias', None)
|
| 344 |
+
|
| 345 |
+
# Initialize weights same as nn.ConvTranspose3d
|
| 346 |
+
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
| 347 |
+
if self.bias is not None:
|
| 348 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
| 349 |
+
if fan_in != 0:
|
| 350 |
+
bound = 1 / fan_in**0.5
|
| 351 |
+
nn.init.uniform_(self.bias, -bound, bound)
|
| 352 |
+
|
| 353 |
+
def forward(self, x):
|
| 354 |
+
return _SafeConvTranspose3dFunc.apply(
|
| 355 |
+
x, self.weight, self.bias,
|
| 356 |
+
self.stride, self.padding, self.output_padding,
|
| 357 |
+
self.groups, self.dilation
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
def extra_repr(self):
|
| 361 |
+
return (f'{self.in_channels}, {self.out_channels}, '
|
| 362 |
+
f'kernel_size={self.kernel_size}, stride={self.stride}, '
|
| 363 |
+
f'padding={self.padding}, bias={self.bias is not None}')
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# =============================================================================
|
| 367 |
+
# Utility: in-place replacement of ConvTranspose3d in existing models
|
| 368 |
+
# =============================================================================
|
| 369 |
+
|
| 370 |
+
def replace_conv_transpose3d(module, target_cls=SafeConvTranspose3d):
|
| 371 |
+
"""Recursively replace all nn.ConvTranspose3d in a module with the given
|
| 372 |
+
replacement class, copying weights and biases.
|
| 373 |
+
|
| 374 |
+
Usage:
|
| 375 |
+
model = MyModel()
|
| 376 |
+
replace_conv_transpose3d(model) # in-place modification
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
module: The nn.Module to modify in-place.
|
| 380 |
+
target_cls: Replacement class (default: SafeConvTranspose3d).
|
| 381 |
+
"""
|
| 382 |
+
for name, child in module.named_children():
|
| 383 |
+
if isinstance(child, nn.ConvTranspose3d):
|
| 384 |
+
ct = child
|
| 385 |
+
assert ct.groups == 1, f"groups={ct.groups} not supported"
|
| 386 |
+
assert ct.output_padding == (0,) * len(ct.output_padding), \
|
| 387 |
+
f"output_padding={ct.output_padding} not supported"
|
| 388 |
+
|
| 389 |
+
replacement = target_cls(
|
| 390 |
+
ct.in_channels, ct.out_channels, ct.kernel_size,
|
| 391 |
+
stride=ct.stride, padding=ct.padding,
|
| 392 |
+
bias=ct.bias is not None
|
| 393 |
+
)
|
| 394 |
+
# Copy weights — same tensor shape, no conversion needed
|
| 395 |
+
replacement.weight.data.copy_(ct.weight.data)
|
| 396 |
+
if ct.bias is not None:
|
| 397 |
+
replacement.bias.data.copy_(ct.bias.data)
|
| 398 |
+
|
| 399 |
+
setattr(module, name, replacement)
|
| 400 |
+
else:
|
| 401 |
+
replace_conv_transpose3d(child, target_cls)
|
Models/all_om_net/000110_all_om_net.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b9c2c90820aba95bfd89d870820574461963450ca50617ee44fb5af2b17385b3
|
| 3 |
+
size 3017380171
|
OM_reg.py
CHANGED
|
@@ -72,7 +72,8 @@ min_crop_ratio = 0.9
|
|
| 72 |
# label_keys = ['heart']
|
| 73 |
label_keys = ['brain']
|
| 74 |
# label_keys = ['pancreas']
|
| 75 |
-
database = ['MSD']
|
|
|
|
| 76 |
|
| 77 |
dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database)
|
| 78 |
Infer_Loader = DataLoader(
|
|
@@ -112,6 +113,7 @@ Deformddpm = DeformDDPM(
|
|
| 112 |
padding_mode = hyp_parameters["padding_mode"],
|
| 113 |
v_scale = hyp_parameters["v_scale"],
|
| 114 |
resample_mode = hyp_parameters["resample_mode"],
|
|
|
|
| 115 |
)
|
| 116 |
Deformddpm.to(hyp_parameters["device"])
|
| 117 |
|
|
@@ -125,7 +127,7 @@ ddf_stn.to(hyp_parameters["device"])
|
|
| 125 |
|
| 126 |
print("Loading model from:", model_save_path)
|
| 127 |
# Deformddpm.load_state_dict(torch.load(model_save_path))
|
| 128 |
-
checkpoint = torch.load(model_save_path)
|
| 129 |
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
|
| 130 |
Deformddpm.eval()
|
| 131 |
|
|
@@ -162,12 +164,8 @@ for e, d in tqdm(enumerate(Infer_Loader)):
|
|
| 162 |
# print(pid, image_original.shape, mask_original.max())
|
| 163 |
|
| 164 |
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
nifti_mask = nib.Nifti1Image(mask_original[0,:,:,:], np.eye(4))
|
| 168 |
-
elif hyp_parameters["ndims"] == 3:
|
| 169 |
-
nifti_img = nib.Nifti1Image(image_original[0,0,:,:,:], np.eye(4))
|
| 170 |
-
nifti_mask = nib.Nifti1Image(mask_original[0,0,:,:,:], np.eye(4))
|
| 171 |
|
| 172 |
# Saving original (undeformed image)
|
| 173 |
# CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
|
|
@@ -198,16 +196,10 @@ for e, d in tqdm(enumerate(Infer_Loader)):
|
|
| 198 |
noisy_imgs_np = img_diff.cpu().detach().numpy()
|
| 199 |
noisy_msks_np = msk_diff.cpu().detach().numpy()
|
| 200 |
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
nifti_mask = nib.Nifti1Image(noisy_msks_np[0, :, :, :], np.eye(4))
|
| 206 |
-
elif hyp_parameters["ndims"] == 3:
|
| 207 |
-
nifti_img_aug = nib.Nifti1Image(denoise_imgs[0,0,:,:,:], np.eye(4))
|
| 208 |
-
nifti_mask_aug = nib.Nifti1Image(denoise_msks[0,0,:,:,:], np.eye(4))
|
| 209 |
-
nifti_img = nib.Nifti1Image(noisy_imgs_np[0,0,:,:,:], np.eye(4))
|
| 210 |
-
nifti_mask = nib.Nifti1Image(noisy_msks_np[0, 0, :, :], np.eye(4))
|
| 211 |
|
| 212 |
nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
|
| 213 |
nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
|
|
|
|
| 72 |
# label_keys = ['heart']
|
| 73 |
label_keys = ['brain']
|
| 74 |
# label_keys = ['pancreas']
|
| 75 |
+
# database = ['MSD']
|
| 76 |
+
database = ['Brats2019']
|
| 77 |
|
| 78 |
dataset = OminiDataset_inference_w_all(transform=None,min_crop_ratio=min_crop_ratio,label_key = label_keys, database=database)
|
| 79 |
Infer_Loader = DataLoader(
|
|
|
|
| 113 |
padding_mode = hyp_parameters["padding_mode"],
|
| 114 |
v_scale = hyp_parameters["v_scale"],
|
| 115 |
resample_mode = hyp_parameters["resample_mode"],
|
| 116 |
+
inf_mode = True, # set to True for inference, which will use fixed slice num and slice idx for better evaluation
|
| 117 |
)
|
| 118 |
Deformddpm.to(hyp_parameters["device"])
|
| 119 |
|
|
|
|
| 127 |
|
| 128 |
print("Loading model from:", model_save_path)
|
| 129 |
# Deformddpm.load_state_dict(torch.load(model_save_path))
|
| 130 |
+
checkpoint = torch.load(model_save_path, map_location='cpu')
|
| 131 |
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
|
| 132 |
Deformddpm.eval()
|
| 133 |
|
|
|
|
| 164 |
# print(pid, image_original.shape, mask_original.max())
|
| 165 |
|
| 166 |
|
| 167 |
+
nifti_img = utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"])
|
| 168 |
+
nifti_mask = utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
# Saving original (undeformed image)
|
| 171 |
# CMR: format: Patient0001_Slice0001_ORG_NA.nii.gz
|
|
|
|
| 196 |
noisy_imgs_np = img_diff.cpu().detach().numpy()
|
| 197 |
noisy_msks_np = msk_diff.cpu().detach().numpy()
|
| 198 |
|
| 199 |
+
nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"])
|
| 200 |
+
nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"])
|
| 201 |
+
nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"])
|
| 202 |
+
nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
nib.save(nifti_img_aug, os.path.join(hyp_parameters['reg_img_savepath'],utils.get_barcode([pid,e,im,noise_step])+'.nii.gz'))
|
| 205 |
nib.save(nifti_mask_aug, os.path.join(hyp_parameters['reg_msk_savepath'],utils.get_barcode([pid,e,im,noise_step])+'_GT.nii.gz'))
|
OM_reg_flexres.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torchvision
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torchvision.utils import save_image
|
| 5 |
+
from torch.utils.data import DataLoader
|
| 6 |
+
from torch.optim import Adam
|
| 7 |
+
from torchvision.utils import make_grid
|
| 8 |
+
from Diffusion.diffuser import DeformDDPM
|
| 9 |
+
from Diffusion.networks import get_net, STN
|
| 10 |
+
from torchvision.transforms import Lambda
|
| 11 |
+
import random
|
| 12 |
+
import os
|
| 13 |
+
import utils
|
| 14 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 15 |
+
from Dataloader.dataLoader import *
|
| 16 |
+
|
| 17 |
+
from torchvision.utils import save_image
|
| 18 |
+
from einops import rearrange, reduce, repeat
|
| 19 |
+
import numpy as np
|
| 20 |
+
import nibabel as nib
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import yaml
|
| 23 |
+
import argparse
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
import SimpleITK as sitk
|
| 26 |
+
from skimage.transform import resize
|
| 27 |
+
|
| 28 |
+
EPS = 10e-8
|
| 29 |
+
|
| 30 |
+
parser = argparse.ArgumentParser()
|
| 31 |
+
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--config",
|
| 34 |
+
"-C",
|
| 35 |
+
help="Path for the config file",
|
| 36 |
+
type=str,
|
| 37 |
+
default="Config/config_om.yaml",
|
| 38 |
+
required=False,
|
| 39 |
+
)
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
#=======================================================================================================================
|
| 42 |
+
|
| 43 |
+
# Load the YAML file into a dictionary
|
| 44 |
+
with open(args.config, 'r') as file:
|
| 45 |
+
hyp_parameters = yaml.safe_load(file)
|
| 46 |
+
print(hyp_parameters)
|
| 47 |
+
|
| 48 |
+
if not os.path.exists(hyp_parameters["aug_img_savepath"]):
|
| 49 |
+
os.makedirs(hyp_parameters["aug_img_savepath"])
|
| 50 |
+
if not os.path.exists(hyp_parameters["aug_msk_savepath"]):
|
| 51 |
+
os.makedirs(hyp_parameters["aug_msk_savepath"])
|
| 52 |
+
if not os.path.exists(hyp_parameters["aug_ddf_savepath"]):
|
| 53 |
+
os.makedirs(hyp_parameters["aug_ddf_savepath"])
|
| 54 |
+
print(hyp_parameters["aug_img_savepath"])
|
| 55 |
+
|
| 56 |
+
hyp_parameters['batchsize'] = 1
|
| 57 |
+
model_img_sz = hyp_parameters['img_size'] # e.g. 128
|
| 58 |
+
|
| 59 |
+
# =======================================================================================================================
|
| 60 |
+
# Dataset is used only for its filtering logic (to get the right set of keys + metadata).
|
| 61 |
+
# We bypass the DataLoader and load volumes directly to ensure deterministic center-padding
|
| 62 |
+
# that is identical between the 128^3 model input and the full-res volume.
|
| 63 |
+
label_keys = ['brain']
|
| 64 |
+
database = ['Brats2019']
|
| 65 |
+
|
| 66 |
+
dataset = OminiDataset_inference_w_all(
|
| 67 |
+
transform=None, min_crop_ratio=1.0, label_key=label_keys, database=database)
|
| 68 |
+
# =======================================================================================================================
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
epoch=f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
|
| 72 |
+
model_save_path = f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/'
|
| 73 |
+
model_save_path = os.path.join(model_save_path, str(epoch)+'.pth')
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
Net = get_net(hyp_parameters["net_name"])
|
| 77 |
+
|
| 78 |
+
Deformddpm = DeformDDPM(
|
| 79 |
+
network=Net(n_steps = hyp_parameters["timesteps"],
|
| 80 |
+
ndims = hyp_parameters["ndims"],
|
| 81 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 82 |
+
res = model_img_sz
|
| 83 |
+
),
|
| 84 |
+
n_steps = hyp_parameters["timesteps"],
|
| 85 |
+
image_chw = [hyp_parameters["num_input_chn"]] + [model_img_sz]*hyp_parameters["ndims"],
|
| 86 |
+
device = hyp_parameters["device"],
|
| 87 |
+
batch_size = hyp_parameters["batchsize"],
|
| 88 |
+
img_pad_mode = hyp_parameters["img_pad_mode"],
|
| 89 |
+
ddf_pad_mode = hyp_parameters["ddf_pad_mode"],
|
| 90 |
+
padding_mode = hyp_parameters["padding_mode"],
|
| 91 |
+
v_scale = hyp_parameters["v_scale"],
|
| 92 |
+
resample_mode = hyp_parameters["resample_mode"],
|
| 93 |
+
inf_mode = True,
|
| 94 |
+
)
|
| 95 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 96 |
+
|
| 97 |
+
ddf_stn = STN(
|
| 98 |
+
img_sz = model_img_sz,
|
| 99 |
+
ndims = hyp_parameters["ndims"],
|
| 100 |
+
padding_mode = hyp_parameters['padding_mode'],
|
| 101 |
+
device = hyp_parameters["device"],
|
| 102 |
+
)
|
| 103 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 104 |
+
|
| 105 |
+
print("Loading model from:", model_save_path)
|
| 106 |
+
checkpoint = torch.load(model_save_path, map_location='cpu')
|
| 107 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
|
| 108 |
+
Deformddpm.eval()
|
| 109 |
+
|
| 110 |
+
# Full-res output directories (append _fullres to the standard paths)
|
| 111 |
+
reg_img_savepath_fullres = hyp_parameters['reg_img_savepath'].rstrip('/') + '_fullres/'
|
| 112 |
+
reg_msk_savepath_fullres = hyp_parameters['reg_msk_savepath'].rstrip('/') + '_fullres/'
|
| 113 |
+
reg_ddf_savepath_fullres = hyp_parameters['reg_ddf_savepath'].rstrip('/') + '_fullres/'
|
| 114 |
+
|
| 115 |
+
os.makedirs(hyp_parameters['reg_img_savepath'], exist_ok=True)
|
| 116 |
+
os.makedirs(hyp_parameters['reg_msk_savepath'], exist_ok=True)
|
| 117 |
+
os.makedirs(hyp_parameters['reg_ddf_savepath'], exist_ok=True)
|
| 118 |
+
os.makedirs(reg_img_savepath_fullres, exist_ok=True)
|
| 119 |
+
os.makedirs(reg_msk_savepath_fullres, exist_ok=True)
|
| 120 |
+
os.makedirs(reg_ddf_savepath_fullres, exist_ok=True)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ========== Helper functions ==========
|
| 124 |
+
|
| 125 |
+
def center_pad_to_cube(volume):
|
| 126 |
+
"""Pad volume to a cube using the max dimension, with symmetric (center) padding."""
|
| 127 |
+
max_dim = max(volume.shape[:3])
|
| 128 |
+
pad_width = []
|
| 129 |
+
for s in volume.shape[:3]:
|
| 130 |
+
total_pad = max_dim - s
|
| 131 |
+
pad_before = total_pad // 2
|
| 132 |
+
pad_after = total_pad - pad_before
|
| 133 |
+
pad_width.append((pad_before, pad_after))
|
| 134 |
+
# Handle extra dims (e.g., multi-channel labels)
|
| 135 |
+
for _ in range(volume.ndim - 3):
|
| 136 |
+
pad_width.append((0, 0))
|
| 137 |
+
return np.pad(volume, pad_width, mode='constant', constant_values=0)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def load_fullres_volume(key, ds):
|
| 141 |
+
"""Load original-resolution volume: axis reorder, clamp, normalize, center-pad to cube."""
|
| 142 |
+
volume = sitk.ReadImage(key)
|
| 143 |
+
volume = sitk.GetArrayFromImage(volume)
|
| 144 |
+
volume = reverse_axis_order(volume)
|
| 145 |
+
if volume.ndim == 4:
|
| 146 |
+
channel_ids = ds.get_channel_ids(key)
|
| 147 |
+
channel_id = channel_ids[0] if len(channel_ids) > 0 else 0
|
| 148 |
+
volume = volume[:, :, :, channel_id]
|
| 149 |
+
# CT clamping
|
| 150 |
+
if ds.clamp_range is not None:
|
| 151 |
+
modality = ds.ALLdata_filtered[key].get("Modality", None)
|
| 152 |
+
if modality == "CT":
|
| 153 |
+
volume = np.clip(volume, ds.clamp_range[0], ds.clamp_range[1])
|
| 154 |
+
volume = ds.normalize(volume)
|
| 155 |
+
volume = center_pad_to_cube(volume)
|
| 156 |
+
return volume # shape: [D, D, D] (cubic)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def load_fullres_label(key, ds, label_key):
|
| 160 |
+
"""Load original-resolution label: axis reorder, center-pad to cube (no resize)."""
|
| 161 |
+
label_path_dict = ds.ALLdata_filtered[key].get('Label_path', {})
|
| 162 |
+
task_labels = label_path_dict.get('segmentation', {})
|
| 163 |
+
if label_key not in task_labels:
|
| 164 |
+
return None
|
| 165 |
+
label = sitk.ReadImage(task_labels[label_key])
|
| 166 |
+
label = sitk.GetArrayFromImage(label)
|
| 167 |
+
label = reverse_axis_order(label)
|
| 168 |
+
if label.ndim > 3:
|
| 169 |
+
channel_ids = ds.get_channel_ids(key)
|
| 170 |
+
if len(channel_ids) != 0:
|
| 171 |
+
label = label[..., channel_ids]
|
| 172 |
+
label = center_pad_to_cube(label)
|
| 173 |
+
return label
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def apply_ddf(volume_tensor, ddf, padding_mode='border', resample_mode='bilinear'):
|
| 177 |
+
"""Apply DDF to volume tensor at any resolution.
|
| 178 |
+
|
| 179 |
+
The DDF stores fractional displacements (value * max_sz = voxel displacement).
|
| 180 |
+
When the DDF is spatially upscaled via trilinear interpolation from model resolution
|
| 181 |
+
to full resolution, the fractional values remain correct — we use the new spatial
|
| 182 |
+
size as max_sz, which correctly scales the voxel displacement proportionally.
|
| 183 |
+
"""
|
| 184 |
+
device = ddf.device
|
| 185 |
+
ndims = 3
|
| 186 |
+
img_sz = list(volume_tensor.shape[2:])
|
| 187 |
+
max_sz = torch.reshape(
|
| 188 |
+
torch.tensor(img_sz, dtype=torch.float32, device=device),
|
| 189 |
+
[1, ndims] + [1] * ndims)
|
| 190 |
+
ref_grid = torch.reshape(
|
| 191 |
+
torch.stack(torch.meshgrid(
|
| 192 |
+
[torch.arange(s, device=device) for s in img_sz], indexing='ij'), 0),
|
| 193 |
+
[1, ndims] + img_sz)
|
| 194 |
+
img_shape = torch.reshape(
|
| 195 |
+
torch.tensor([(s - 1) / 2. for s in img_sz], dtype=torch.float32, device=device),
|
| 196 |
+
[1] + [1] * ndims + [ndims])
|
| 197 |
+
grid = torch.flip(
|
| 198 |
+
(ddf * max_sz + ref_grid).permute(
|
| 199 |
+
[0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1,
|
| 200 |
+
dims=[-1])
|
| 201 |
+
return F.grid_sample(volume_tensor, grid.float(), mode=resample_mode,
|
| 202 |
+
padding_mode=padding_mode, align_corners=True)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ========== Main inference loop ==========
|
| 206 |
+
|
| 207 |
+
keys = list(dataset.ALLdata_filtered.keys())
|
| 208 |
+
print("total num of images:", len(keys))
|
| 209 |
+
|
| 210 |
+
for e, key in enumerate(tqdm(keys)):
|
| 211 |
+
pid = e
|
| 212 |
+
print(f'Processing patient {pid}, image {e}, key: {key}')
|
| 213 |
+
|
| 214 |
+
# --- Load full-resolution volume (center-padded to cube) ---
|
| 215 |
+
fullres_vol = load_fullres_volume(key, dataset)
|
| 216 |
+
orig_sz = list(fullres_vol.shape) # e.g. [240, 240, 240]
|
| 217 |
+
print(f" Full-res padded shape: {orig_sz}")
|
| 218 |
+
|
| 219 |
+
# --- Resize to model resolution for inference ---
|
| 220 |
+
vol_model = resize(fullres_vol, [model_img_sz] * 3,
|
| 221 |
+
anti_aliasing=True, preserve_range=True)
|
| 222 |
+
img = torch.tensor(vol_model[None, None, :, :, :],
|
| 223 |
+
dtype=torch.float32, device=hyp_parameters["device"])
|
| 224 |
+
|
| 225 |
+
# --- Load full-res labels and resize to model resolution ---
|
| 226 |
+
fullres_labels = {}
|
| 227 |
+
for lk in label_keys:
|
| 228 |
+
lab = load_fullres_label(key, dataset, lk)
|
| 229 |
+
if lab is not None:
|
| 230 |
+
fullres_labels[lk] = lab
|
| 231 |
+
|
| 232 |
+
# Build mask at model resolution (128^3)
|
| 233 |
+
label_arrays_model = []
|
| 234 |
+
label_arrays_fullres = []
|
| 235 |
+
for lk in label_keys:
|
| 236 |
+
if lk in fullres_labels:
|
| 237 |
+
lab = fullres_labels[lk]
|
| 238 |
+
lab_model = resize(lab, [model_img_sz] * 3,
|
| 239 |
+
anti_aliasing=False, preserve_range=True, order=0)
|
| 240 |
+
if lab_model.ndim == 3:
|
| 241 |
+
lab_model = lab_model[None, :, :, :]
|
| 242 |
+
elif lab_model.ndim > 3:
|
| 243 |
+
lab_model = np.transpose(lab_model, (3, 0, 1, 2))
|
| 244 |
+
label_arrays_model.append(lab_model)
|
| 245 |
+
|
| 246 |
+
if lab.ndim == 3:
|
| 247 |
+
lab = lab[None, :, :, :]
|
| 248 |
+
elif lab.ndim > 3:
|
| 249 |
+
lab = np.transpose(lab, (3, 0, 1, 2))
|
| 250 |
+
label_arrays_fullres.append(lab)
|
| 251 |
+
else:
|
| 252 |
+
label_arrays_model.append(np.full([1] + [model_img_sz] * 3, -1))
|
| 253 |
+
label_arrays_fullres.append(np.full([1] + orig_sz, -1))
|
| 254 |
+
|
| 255 |
+
if len(label_arrays_model) > 0:
|
| 256 |
+
mask_model_np = np.concatenate(label_arrays_model, axis=0)
|
| 257 |
+
mask = torch.tensor(mask_model_np[None], dtype=torch.float32,
|
| 258 |
+
device=hyp_parameters["device"])
|
| 259 |
+
fullres_msk_np = np.concatenate(label_arrays_fullres, axis=0)
|
| 260 |
+
fullres_msk_tensor = torch.tensor(fullres_msk_np[None], dtype=torch.float32,
|
| 261 |
+
device=hyp_parameters["device"])
|
| 262 |
+
else:
|
| 263 |
+
mask = None
|
| 264 |
+
fullres_msk_np = None
|
| 265 |
+
fullres_msk_tensor = None
|
| 266 |
+
|
| 267 |
+
# Build full-res image tensor
|
| 268 |
+
fullres_img_tensor = torch.tensor(fullres_vol[None, None, :, :, :],
|
| 269 |
+
dtype=torch.float32,
|
| 270 |
+
device=hyp_parameters["device"])
|
| 271 |
+
|
| 272 |
+
# --- Save target conditioning image (first subject) ---
|
| 273 |
+
if e <= 0:
|
| 274 |
+
target_img = img.clone().detach()
|
| 275 |
+
|
| 276 |
+
# --- Save original images at 128^3 ---
|
| 277 |
+
image_original = img.cpu().numpy()
|
| 278 |
+
nib.save(utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"]),
|
| 279 |
+
os.path.join(hyp_parameters['reg_img_savepath'],
|
| 280 |
+
utils.get_barcode([pid, e]) + '.nii.gz'))
|
| 281 |
+
if mask is not None:
|
| 282 |
+
mask_original = mask.cpu().numpy()
|
| 283 |
+
nib.save(utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"]),
|
| 284 |
+
os.path.join(hyp_parameters['reg_msk_savepath'],
|
| 285 |
+
utils.get_barcode([pid, e]) + '_GT.nii.gz'))
|
| 286 |
+
|
| 287 |
+
# --- Save original at full-res ---
|
| 288 |
+
# fullres_vol is [D,D,D], wrap as [1,1,D,D,D] for converet_to_nibabel
|
| 289 |
+
nib.save(utils.converet_to_nibabel(fullres_vol[None, None], ndims=hyp_parameters["ndims"]),
|
| 290 |
+
os.path.join(reg_img_savepath_fullres,
|
| 291 |
+
utils.get_barcode([pid, e]) + '.nii.gz'))
|
| 292 |
+
if fullres_msk_np is not None:
|
| 293 |
+
# fullres_msk_np is [C,D,D,D], wrap as [1,C,D,D,D]
|
| 294 |
+
nib.save(utils.converet_to_nibabel(fullres_msk_np[None], ndims=hyp_parameters["ndims"]),
|
| 295 |
+
os.path.join(reg_msk_savepath_fullres,
|
| 296 |
+
utils.get_barcode([pid, e]) + '_GT.nii.gz'))
|
| 297 |
+
|
| 298 |
+
# --- Diffusion recovery at model resolution ---
|
| 299 |
+
noise_step = hyp_parameters["start_noise_step"]
|
| 300 |
+
with torch.no_grad():
|
| 301 |
+
for im in range(1):
|
| 302 |
+
print(f' Generating -> Subject-{pid}, Scan-{e} ({im}/{hyp_parameters["aug_coe"]})', end='\r')
|
| 303 |
+
|
| 304 |
+
[ddf_comp, ddf_rand], [img_rec, img_diff, img_save], [msk_rec, msk_diff, msk_save] = \
|
| 305 |
+
Deformddpm.diff_recover(
|
| 306 |
+
img_org=img,
|
| 307 |
+
cond_imgs=target_img.clone().detach(),
|
| 308 |
+
msk_org=mask,
|
| 309 |
+
T=[None, hyp_parameters["timesteps"]],
|
| 310 |
+
v_scale=hyp_parameters["v_scale"],
|
| 311 |
+
t_save=None,
|
| 312 |
+
proc_type=hyp_parameters["condition_type"])
|
| 313 |
+
|
| 314 |
+
# --- Save 128^3 results (same as OM_reg.py) ---
|
| 315 |
+
denoise_imgs = img_rec.cpu().numpy()
|
| 316 |
+
noisy_imgs_np = img_diff.cpu().numpy()
|
| 317 |
+
|
| 318 |
+
nib.save(utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"]),
|
| 319 |
+
os.path.join(hyp_parameters['reg_img_savepath'],
|
| 320 |
+
utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
|
| 321 |
+
nib.save(utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"]),
|
| 322 |
+
os.path.join(hyp_parameters['reg_img_savepath'],
|
| 323 |
+
utils.get_barcode([pid, e, im, noise_step],
|
| 324 |
+
header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '.nii.gz'))
|
| 325 |
+
|
| 326 |
+
if msk_rec is not None:
|
| 327 |
+
denoise_msks = msk_rec.cpu().numpy()
|
| 328 |
+
noisy_msks_np = msk_diff.cpu().numpy()
|
| 329 |
+
nib.save(utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"]),
|
| 330 |
+
os.path.join(hyp_parameters['reg_msk_savepath'],
|
| 331 |
+
utils.get_barcode([pid, e, im, noise_step]) + '_GT.nii.gz'))
|
| 332 |
+
nib.save(utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"]),
|
| 333 |
+
os.path.join(hyp_parameters['reg_msk_savepath'],
|
| 334 |
+
utils.get_barcode([pid, e, im, noise_step],
|
| 335 |
+
header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '_GT.nii.gz'))
|
| 336 |
+
|
| 337 |
+
# --- Upscale DDFs to original resolution ---
|
| 338 |
+
ddf_fullres = F.interpolate(ddf_comp, size=orig_sz,
|
| 339 |
+
mode='trilinear', align_corners=False)
|
| 340 |
+
ddf_rand_fullres = F.interpolate(ddf_rand, size=orig_sz,
|
| 341 |
+
mode='trilinear', align_corners=False)
|
| 342 |
+
|
| 343 |
+
# --- Apply DDFs at original resolution ---
|
| 344 |
+
img_rec_fullres = apply_ddf(fullres_img_tensor, ddf_fullres,
|
| 345 |
+
padding_mode='border')
|
| 346 |
+
img_noisy_fullres = apply_ddf(fullres_img_tensor, ddf_rand_fullres,
|
| 347 |
+
padding_mode='border')
|
| 348 |
+
|
| 349 |
+
if fullres_msk_tensor is not None:
|
| 350 |
+
msk_rec_fullres = apply_ddf(fullres_msk_tensor, ddf_fullres,
|
| 351 |
+
padding_mode='zeros', resample_mode='nearest')
|
| 352 |
+
msk_noisy_fullres = apply_ddf(fullres_msk_tensor, ddf_rand_fullres,
|
| 353 |
+
padding_mode='zeros', resample_mode='nearest')
|
| 354 |
+
|
| 355 |
+
# --- Save full-res results ---
|
| 356 |
+
nib.save(utils.converet_to_nibabel(img_rec_fullres, ndims=hyp_parameters["ndims"]),
|
| 357 |
+
os.path.join(reg_img_savepath_fullres,
|
| 358 |
+
utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
|
| 359 |
+
nib.save(utils.converet_to_nibabel(img_noisy_fullres, ndims=hyp_parameters["ndims"]),
|
| 360 |
+
os.path.join(reg_img_savepath_fullres,
|
| 361 |
+
utils.get_barcode([pid, e, im, noise_step],
|
| 362 |
+
header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '.nii.gz'))
|
| 363 |
+
|
| 364 |
+
if fullres_msk_tensor is not None:
|
| 365 |
+
nib.save(utils.converet_to_nibabel(msk_rec_fullres, ndims=hyp_parameters["ndims"]),
|
| 366 |
+
os.path.join(reg_msk_savepath_fullres,
|
| 367 |
+
utils.get_barcode([pid, e, im, noise_step]) + '_GT.nii.gz'))
|
| 368 |
+
nib.save(utils.converet_to_nibabel(msk_noisy_fullres, ndims=hyp_parameters["ndims"]),
|
| 369 |
+
os.path.join(reg_msk_savepath_fullres,
|
| 370 |
+
utils.get_barcode([pid, e, im, noise_step],
|
| 371 |
+
header=['Patient', 'Slice', 'NoiseImg', 'NoiseStep']) + '_GT.nii.gz'))
|
| 372 |
+
|
| 373 |
+
# Save full-res DDF (converet_to_nibabel handles multi-channel → channel-last)
|
| 374 |
+
nib.save(utils.converet_to_nibabel(ddf_fullres, ndims=hyp_parameters["ndims"]),
|
| 375 |
+
os.path.join(reg_ddf_savepath_fullres,
|
| 376 |
+
utils.get_barcode([pid, e, im, noise_step]) + '.nii.gz'))
|
| 377 |
+
|
| 378 |
+
if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
|
| 379 |
+
noise_step = noise_step + hyp_parameters["noise_step"]
|
| 380 |
+
|
| 381 |
+
if e > 5:
|
| 382 |
+
break
|
OM_train_2modes-reg.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
sys.path.append(ROOT_DIR)
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torchvision.utils import save_image
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from torch.optim import Adam, SGD
|
| 14 |
+
from Diffusion.diffuser import DeformDDPM
|
| 15 |
+
from Diffusion.networks import get_net, STN
|
| 16 |
+
from torchvision.transforms import Lambda
|
| 17 |
+
import Diffusion.losses as losses
|
| 18 |
+
import random
|
| 19 |
+
import glob
|
| 20 |
+
import numpy as np
|
| 21 |
+
import utils
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 25 |
+
from Dataloader.dataLoader import *
|
| 26 |
+
|
| 27 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 28 |
+
import yaml
|
| 29 |
+
import argparse
|
| 30 |
+
|
| 31 |
+
####################
|
| 32 |
+
import torch.multiprocessing as mp
|
| 33 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 34 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 35 |
+
import torch.distributed as dist
|
| 36 |
+
# from torch.distributed import init_process_group
|
| 37 |
+
###############
|
| 38 |
+
def ddp_setup(rank, world_size):
|
| 39 |
+
"""
|
| 40 |
+
Args:
|
| 41 |
+
rank: Unique identifier of each process
|
| 42 |
+
world_size: Total number of processes
|
| 43 |
+
"""
|
| 44 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 45 |
+
os.environ["MASTER_PORT"] = "12355"
|
| 46 |
+
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 47 |
+
torch.cuda.set_device(rank)
|
| 48 |
+
|
| 49 |
+
use_distributed = True
|
| 50 |
+
# use_distributed = False
|
| 51 |
+
|
| 52 |
+
EPS = 1e-5
|
| 53 |
+
MSK_EPS = 0.01
|
| 54 |
+
TEXT_EMBED_PROB = 0.7
|
| 55 |
+
AUG_RESAMPLE_PROB = 0.6
|
| 56 |
+
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16] # [ang, dist, reg]
|
| 57 |
+
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 58 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.2, 1e2] # [imgsim, imgmse, ddf]
|
| 59 |
+
DIFF_REG_BATCH_RATIO = 2
|
| 60 |
+
|
| 61 |
+
# AUG_PERMUTE_PROB = 0.35
|
| 62 |
+
|
| 63 |
+
parser = argparse.ArgumentParser()
|
| 64 |
+
|
| 65 |
+
# config_file_path = 'Config/config_cmr.yaml'
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--config",
|
| 68 |
+
"-C",
|
| 69 |
+
help="Path for the config file",
|
| 70 |
+
type=str,
|
| 71 |
+
# default="Config/config_cmr.yaml",
|
| 72 |
+
# default="Config/config_lct.yaml",
|
| 73 |
+
default="Config/config_all.yaml",
|
| 74 |
+
required=False,
|
| 75 |
+
)
|
| 76 |
+
args = parser.parse_args()
|
| 77 |
+
#=======================================================================================================================
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 82 |
+
if use_distributed:
|
| 83 |
+
ddp_setup(rank,world_size)
|
| 84 |
+
|
| 85 |
+
if torch.distributed.is_initialized():
|
| 86 |
+
print(f"World size: {torch.distributed.get_world_size()}")
|
| 87 |
+
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 88 |
+
gpu_id = rank
|
| 89 |
+
|
| 90 |
+
# Load the YAML file into a dictionary
|
| 91 |
+
with open(args.config, 'r') as file:
|
| 92 |
+
hyp_parameters = yaml.safe_load(file)
|
| 93 |
+
print(hyp_parameters)
|
| 94 |
+
|
| 95 |
+
# epoch_per_save=10
|
| 96 |
+
epoch_per_save=hyp_parameters['epoch_per_save']
|
| 97 |
+
|
| 98 |
+
data_name=hyp_parameters['data_name']
|
| 99 |
+
net_name = hyp_parameters['net_name']
|
| 100 |
+
|
| 101 |
+
Net=get_net(net_name)
|
| 102 |
+
|
| 103 |
+
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 104 |
+
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 105 |
+
model_dir=model_save_path
|
| 106 |
+
transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 107 |
+
|
| 108 |
+
# Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
|
| 109 |
+
|
| 110 |
+
# tsfm = torchvision.transforms.Compose([
|
| 111 |
+
# torchvision.transforms.ToTensor(),
|
| 112 |
+
# ])
|
| 113 |
+
|
| 114 |
+
# dataset = Data_Loader(target_res = [hyp_parameters["img_size"]]*hyp_parameters["ndims"], transforms=None, noise_scale=hyp_parameters['noise_scale'])
|
| 115 |
+
# train_loader = DataLoader(
|
| 116 |
+
# dataset,
|
| 117 |
+
# batch_size=hyp_parameters['batchsize'],
|
| 118 |
+
# # shuffle=False,
|
| 119 |
+
# shuffle=True,
|
| 120 |
+
# drop_last=True,
|
| 121 |
+
# )
|
| 122 |
+
|
| 123 |
+
# dataset = OminiDataset_v1(transform=None)
|
| 124 |
+
dataset = OMDataset_indiv(transform=None)
|
| 125 |
+
train_loader = DataLoader(
|
| 126 |
+
dataset,
|
| 127 |
+
batch_size=hyp_parameters['batchsize'],
|
| 128 |
+
shuffle=True,
|
| 129 |
+
drop_last=True,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# datasetp = OminiDataset_paired(transform=None)
|
| 133 |
+
datasetp = OMDataset_pair(transform=None)
|
| 134 |
+
train_loader_p = DataLoader(
|
| 135 |
+
datasetp,
|
| 136 |
+
batch_size=hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO,
|
| 137 |
+
shuffle=True,
|
| 138 |
+
drop_last=True,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
Deformddpm = DeformDDPM(
|
| 144 |
+
network=Net(
|
| 145 |
+
n_steps=hyp_parameters["timesteps"],
|
| 146 |
+
ndims=hyp_parameters["ndims"],
|
| 147 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 148 |
+
res = hyp_parameters['img_size']
|
| 149 |
+
),
|
| 150 |
+
n_steps=hyp_parameters["timesteps"],
|
| 151 |
+
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 152 |
+
device=hyp_parameters["device"],
|
| 153 |
+
batch_size=hyp_parameters["batchsize"],
|
| 154 |
+
img_pad_mode=hyp_parameters["img_pad_mode"],
|
| 155 |
+
v_scale=hyp_parameters["v_scale"],
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
ddf_stn = STN(
|
| 160 |
+
img_sz=hyp_parameters["img_size"],
|
| 161 |
+
ndims=hyp_parameters["ndims"],
|
| 162 |
+
# padding_mode="zeros",
|
| 163 |
+
padding_mode=hyp_parameters["padding_mode"],
|
| 164 |
+
device=hyp_parameters["device"],
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if use_distributed:
|
| 169 |
+
Deformddpm.to(rank)
|
| 170 |
+
Deformddpm = DDP(Deformddpm, device_ids=[rank])
|
| 171 |
+
ddf_stn.to(rank)
|
| 172 |
+
else:
|
| 173 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 174 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 175 |
+
# ddf_stn = DDP(ddf_stn, device_ids=[rank])
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# mse = nn.MSELoss()
|
| 179 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 180 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 181 |
+
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 182 |
+
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 183 |
+
|
| 184 |
+
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 185 |
+
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 186 |
+
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 187 |
+
loss_imgsim = losses.LNCC()
|
| 188 |
+
loss_imgmse = losses.LMSE()
|
| 189 |
+
|
| 190 |
+
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
| 191 |
+
# hyp_parameters["lr"]=0.00000001
|
| 192 |
+
# optimizer_regist = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01)
|
| 193 |
+
# optimizer_regist = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01, momentum=0.98)
|
| 194 |
+
# optimizer = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"], momentum=0.9)
|
| 195 |
+
|
| 196 |
+
# # LR scheduler ----- YHM
|
| 197 |
+
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, hyp_parameters["lr"], hyp_parameters["lr"]*10, step_size_up=500, step_size_down=500, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
|
| 198 |
+
|
| 199 |
+
# Deformddpm.network.load_state_dict(torch.load('/home/data/jzheng/Adaptive_Motion_Generator-master/models/1000.pth'))
|
| 200 |
+
|
| 201 |
+
# check for existing models
|
| 202 |
+
if not os.path.exists(model_dir):
|
| 203 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 204 |
+
model_files = glob.glob(os.path.join(model_dir, "*.pth"))
|
| 205 |
+
model_files.sort()
|
| 206 |
+
if model_files:
|
| 207 |
+
if gpu_id == 0:
|
| 208 |
+
print(model_files)
|
| 209 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1])
|
| 210 |
+
else:
|
| 211 |
+
initial_epoch = 0
|
| 212 |
+
|
| 213 |
+
if gpu_id == 0:
|
| 214 |
+
print('len_train_data: ',len(dataset))
|
| 215 |
+
# Training loop
|
| 216 |
+
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
| 217 |
+
|
| 218 |
+
epoch_loss_tot = 0.0
|
| 219 |
+
epoch_loss_gen_d = 0.0
|
| 220 |
+
epoch_loss_gen_a = 0.0
|
| 221 |
+
epoch_loss_reg = 0.0
|
| 222 |
+
epoch_loss_regist = 0.0
|
| 223 |
+
epoch_loss_imgsim = 0.0
|
| 224 |
+
epoch_loss_imgmse = 0.0
|
| 225 |
+
epoch_loss_ddfreg = 0.0
|
| 226 |
+
# Set model inside to train model
|
| 227 |
+
Deformddpm.train()
|
| 228 |
+
|
| 229 |
+
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 230 |
+
|
| 231 |
+
total = min(len(train_loader), len(train_loader_p))
|
| 232 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 233 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 234 |
+
# for step, batch in enumerate(train_loader_omni):
|
| 235 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 236 |
+
|
| 237 |
+
# x0, _ = batch
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
# ==========================================================================
|
| 241 |
+
# diffusion train on single image
|
| 242 |
+
|
| 243 |
+
# x0 = batch # for omni dataset
|
| 244 |
+
[x0,embd] = batch # for om dataset
|
| 245 |
+
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 246 |
+
# print('embd:', embd.shape)
|
| 247 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 248 |
+
embd = embd.to(hyp_parameters["device"]).type(torch.float32)
|
| 249 |
+
else:
|
| 250 |
+
embd = None
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
n = x0.size()[0] # batch_size -> n
|
| 255 |
+
x0 = x0.to(hyp_parameters["device"])
|
| 256 |
+
|
| 257 |
+
blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
|
| 258 |
+
|
| 259 |
+
# random deformation + rotation
|
| 260 |
+
if hyp_parameters["ndims"]>2:
|
| 261 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 262 |
+
x0 = utils.random_resample(x0, deform_scale=0)
|
| 263 |
+
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 264 |
+
else:
|
| 265 |
+
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 266 |
+
# x0 = transformer(x0)
|
| 267 |
+
if hyp_parameters['noise_scale']>0:
|
| 268 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 269 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 270 |
+
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 271 |
+
|
| 272 |
+
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
| 273 |
+
t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
|
| 274 |
+
hyp_parameters["device"]
|
| 275 |
+
) # pick up a seq of rand number from 0 to 'timestep'
|
| 276 |
+
|
| 277 |
+
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 278 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 279 |
+
# print('proc_type:', proc_type)
|
| 280 |
+
cond_img, _, cond_ratio = Deformddpm.module.proc_cond_img(x0,proc_type=proc_type)
|
| 281 |
+
|
| 282 |
+
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
|
| 283 |
+
|
| 284 |
+
# print(torch.max(torch.abs(pre_dvf_I)))
|
| 285 |
+
# print(torch.max(torch.abs(dvf_I)))
|
| 286 |
+
|
| 287 |
+
loss_tot=0
|
| 288 |
+
|
| 289 |
+
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 290 |
+
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 291 |
+
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 292 |
+
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 293 |
+
|
| 294 |
+
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 295 |
+
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 296 |
+
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 297 |
+
|
| 298 |
+
# >> JZ: print nan in x0
|
| 299 |
+
if torch.isnan(x0).any():
|
| 300 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 301 |
+
# >> JZ: print loss of ddf
|
| 302 |
+
if loss_ddf>0.001:
|
| 303 |
+
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 304 |
+
# yu: check if loss_tot==nan or inf
|
| 305 |
+
if torch.isnan(loss_tot) or torch.isinf(loss_tot):
|
| 306 |
+
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 307 |
+
loss_nan_step += 1
|
| 308 |
+
continue
|
| 309 |
+
if loss_nan_step > 5:
|
| 310 |
+
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 311 |
+
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 312 |
+
|
| 313 |
+
optimizer.zero_grad()
|
| 314 |
+
loss_tot.backward()
|
| 315 |
+
optimizer.step()
|
| 316 |
+
|
| 317 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 318 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 319 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 320 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 321 |
+
|
| 322 |
+
# print(loss_gen_a.item())
|
| 323 |
+
# if 0:
|
| 324 |
+
# if loss_gen_a.item() < -0.3 and step%train_mode_ratio == 0:
|
| 325 |
+
if step%train_mode_ratio == 0:
|
| 326 |
+
# ==========================================================================
|
| 327 |
+
# registration train on paired images
|
| 328 |
+
# x1, y1 = next(iter(train_loader_p))
|
| 329 |
+
# [x1, y1, _, embd_y] = next(iter(train_loader_p))
|
| 330 |
+
[x1, y1, _, embd_y] = batch_p
|
| 331 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 332 |
+
# embd_x = embd_x.to(hyp_parameters["device"]).type(torch.float32)
|
| 333 |
+
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 334 |
+
else:
|
| 335 |
+
# embd_x = None
|
| 336 |
+
embd_y = None
|
| 337 |
+
|
| 338 |
+
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 339 |
+
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 340 |
+
n = x1.size()[0] # batch_size -> n
|
| 341 |
+
# random deformation + rotation
|
| 342 |
+
# if hyp_parameters["ndims"]>2:
|
| 343 |
+
# if np.random.uniform(0,1)<0.6:
|
| 344 |
+
# x1 = utils.random_resample(x1, deform_scale=0)
|
| 345 |
+
# y1 = utils.random_resample(y1, deform_scale=0)
|
| 346 |
+
# x1 = transformer(x1)
|
| 347 |
+
# y1 = transformer(y1)
|
| 348 |
+
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 349 |
+
if hyp_parameters['noise_scale']>0:
|
| 350 |
+
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
| 351 |
+
random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
|
| 352 |
+
random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 353 |
+
x1 = x1 * random_scale + random_shift
|
| 354 |
+
y1 = y1 * random_scale + random_shift
|
| 355 |
+
# x1 = thresh_img(x1, [0, 2*hyp_parameters['noise_scale']])
|
| 356 |
+
# x1 = x1 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 357 |
+
# y1 = thresh_img(y1, [0, 2*hyp_parameters['noise_scale']])
|
| 358 |
+
# y1 = y1 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 359 |
+
# # Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
| 360 |
+
# t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
|
| 361 |
+
# hyp_parameters["device"]
|
| 362 |
+
# ) # pick up a seq of rand number from 0 to 'timestep'
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# scale_regist = np.random.uniform(0.6,1.)
|
| 366 |
+
# T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist) + 1), 16), reverse=True)
|
| 367 |
+
# print('T_regist (0.6,1) sampling range:', T_regist)
|
| 368 |
+
scale_regist = np.random.uniform(0.0,0.7)
|
| 369 |
+
select_timestep = np.random.randint(8, 17) # select a random number of timesteps to sample, between 8 and 16
|
| 370 |
+
T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
|
| 371 |
+
# print('T_regist (0.1,0.7) sampling range:', T_regist)
|
| 372 |
+
# scale_regist = np.random.uniform(0.4,1.)
|
| 373 |
+
# T_regist = [int(hyp_parameters["timesteps"]*scale_regist)]
|
| 374 |
+
# scale_regist = np.random.uniform(0.6,1.)
|
| 375 |
+
# init_T = int(hyp_parameters["timesteps"] * scale_regist)
|
| 376 |
+
# T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist)), 2)+list(range(init_T,hyp_parameters["timesteps"]+1)), reverse=True)
|
| 377 |
+
|
| 378 |
+
T_regist = [[t for _ in range(hyp_parameters["batchsize"]//2)] for t in T_regist]
|
| 379 |
+
|
| 380 |
+
# print('T_regist:', T_regist)
|
| 381 |
+
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
|
| 382 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 383 |
+
# proc_type = random.choice(['project'])
|
| 384 |
+
y1_proc, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
|
| 385 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 386 |
+
# loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 387 |
+
# loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>0.0)) # calculate loss for the registration process
|
| 388 |
+
# loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=(msk_tgt+MSK_EPS)) # calculate loss for the registration process
|
| 389 |
+
loss_sim = loss_imgsim(img_rec, y1, label=(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 390 |
+
loss_mse = loss_imgmse(img_rec, y1, label=(y1>=0.0)) # calculate loss for the registration process
|
| 391 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 392 |
+
|
| 393 |
+
loss_regist = 0
|
| 394 |
+
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 395 |
+
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 396 |
+
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 397 |
+
# print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
|
| 398 |
+
# print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
|
| 399 |
+
|
| 400 |
+
# >> JZ: print nan in x0
|
| 401 |
+
if torch.isnan(x0).any():
|
| 402 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 403 |
+
# >> JZ: print loss of ddf
|
| 404 |
+
if loss_ddf1>0.002:
|
| 405 |
+
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 406 |
+
# # Print gradients for each parameter
|
| 407 |
+
# for name, param in Deformddpm.named_parameters():
|
| 408 |
+
# if param.grad is not None:
|
| 409 |
+
# print(f"Gradient for {name}: {param.grad.norm()}")
|
| 410 |
+
# else:
|
| 411 |
+
# print(f"Gradient for {name}: None")
|
| 412 |
+
|
| 413 |
+
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 414 |
+
optimizer.zero_grad()
|
| 415 |
+
loss_regist.backward()
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.4)
|
| 420 |
+
optimizer.step()
|
| 421 |
+
|
| 422 |
+
epoch_loss_regist += loss_regist.item() / total
|
| 423 |
+
epoch_loss_imgsim += loss_sim.item() / total
|
| 424 |
+
epoch_loss_imgmse += loss_mse.item() / total
|
| 425 |
+
epoch_loss_ddfreg += loss_ddf1.item() / total
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
if step % 10 == 0:
|
| 429 |
+
print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 430 |
+
print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
|
| 431 |
+
# break # FOR TESTING
|
| 432 |
+
# else:
|
| 433 |
+
# print('loss_gen_a:',loss_gen_a.item()) # FOR TESTING
|
| 434 |
+
# pass
|
| 435 |
+
|
| 436 |
+
if 1:
|
| 437 |
+
# if gpu_id == 0:
|
| 438 |
+
print('==================')
|
| 439 |
+
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 440 |
+
print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
|
| 441 |
+
print('==================')
|
| 442 |
+
# # LR schedular step ----- YHM
|
| 443 |
+
# scheduler.step()
|
| 444 |
+
|
| 445 |
+
if 0 == epoch % epoch_per_save:
|
| 446 |
+
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 447 |
+
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 448 |
+
# break # FOR TESTING
|
| 449 |
+
if not use_distributed:
|
| 450 |
+
print(f"saved in {save_dir}")
|
| 451 |
+
# torch.save(Deformddpm.state_dict(), save_dir)
|
| 452 |
+
torch.save({
|
| 453 |
+
'model_state_dict': Deformddpm.state_dict(),
|
| 454 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 455 |
+
'epoch': epoch
|
| 456 |
+
}, save_dir)
|
| 457 |
+
elif gpu_id == 0:
|
| 458 |
+
print(f"saved in {save_dir}")
|
| 459 |
+
# torch.save(Deformddpm.module.state_dict(), save_dir)
|
| 460 |
+
torch.save({
|
| 461 |
+
'model_state_dict': Deformddpm.module.state_dict(),
|
| 462 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 463 |
+
'epoch': epoch
|
| 464 |
+
}, save_dir)
|
| 465 |
+
|
| 466 |
+
# Resource cleanup at the end of training
|
| 467 |
+
torch.cuda.empty_cache()
|
| 468 |
+
gc.collect()
|
| 469 |
+
if use_distributed and dist.is_initialized():
|
| 470 |
+
dist.destroy_process_group()
|
| 471 |
+
|
| 472 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 473 |
+
|
| 474 |
+
if gpu_id == 0:
|
| 475 |
+
# if 0:
|
| 476 |
+
utils.print_memory_usage("Before Loading Model")
|
| 477 |
+
if 1:
|
| 478 |
+
gc.collect()
|
| 479 |
+
torch.cuda.empty_cache()
|
| 480 |
+
# Deformddpm.network.load_state_dict(torch.load(latest_model_file))
|
| 481 |
+
# Deformddpm.load_state_dict(torch.load(latest_model_file), strict=False)
|
| 482 |
+
checkpoint = torch.load(model_file)
|
| 483 |
+
# checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
|
| 484 |
+
if use_distributed:
|
| 485 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 486 |
+
else:
|
| 487 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 488 |
+
if load_strict:
|
| 489 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 490 |
+
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 491 |
+
|
| 492 |
+
if use_distributed:
|
| 493 |
+
# Broadcast model weights from rank 0 to all other GPUs
|
| 494 |
+
dist.barrier()
|
| 495 |
+
for param in Deformddpm.parameters():
|
| 496 |
+
dist.broadcast(param.data, src=0) # Synchronize model across ranks
|
| 497 |
+
dist.barrier()
|
| 498 |
+
for param_group in optimizer.param_groups:
|
| 499 |
+
for param in param_group['params']:
|
| 500 |
+
if param.grad is not None:
|
| 501 |
+
dist.broadcast(param.grad, src=0) # Sync optimizer gradients
|
| 502 |
+
|
| 503 |
+
# initial_epoch = checkpoint['epoch'] + 1
|
| 504 |
+
# get the epoch number from the filename and add 1 to set as initial_epoch
|
| 505 |
+
initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
|
| 506 |
+
|
| 507 |
+
return initial_epoch, Deformddpm, optimizer
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
if __name__ == "__main__":
|
| 512 |
+
if use_distributed:
|
| 513 |
+
world_size = torch.cuda.device_count()
|
| 514 |
+
print(f"Distributed GPU number = {world_size}")
|
| 515 |
+
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 516 |
+
else:
|
| 517 |
+
main_train(0,1)
|
OM_train_2modes.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import gc
|
| 3 |
import torch
|
| 4 |
import torchvision
|
|
@@ -48,12 +52,11 @@ use_distributed = True
|
|
| 48 |
EPS = 1e-5
|
| 49 |
MSK_EPS = 0.01
|
| 50 |
TEXT_EMBED_PROB = 0.7
|
| 51 |
-
AUG_RESAMPLE_PROB = 0.
|
| 52 |
-
LOSS_WEIGHTS_DIFF = [2.0,
|
| 53 |
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
LOSS_WEIGHTS_REGIST = [2.0, 0.1, 256] # [imgsim, imgmse, ddf]
|
| 57 |
|
| 58 |
# AUG_PERMUTE_PROB = 0.35
|
| 59 |
|
|
@@ -130,7 +133,7 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 130 |
datasetp = OMDataset_pair(transform=None)
|
| 131 |
train_loader_p = DataLoader(
|
| 132 |
datasetp,
|
| 133 |
-
batch_size=hyp_parameters['batchsize']//
|
| 134 |
shuffle=True,
|
| 135 |
drop_last=True,
|
| 136 |
)
|
|
@@ -174,12 +177,15 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 174 |
|
| 175 |
# mse = nn.MSELoss()
|
| 176 |
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
|
|
|
| 177 |
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 178 |
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
|
|
|
| 179 |
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 180 |
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 181 |
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 182 |
-
loss_imgsim = losses.LNCC()
|
|
|
|
| 183 |
loss_imgmse = losses.LMSE()
|
| 184 |
|
| 185 |
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
|
@@ -220,15 +226,15 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 220 |
epoch_loss_ddfreg = 0.0
|
| 221 |
# Set model inside to train model
|
| 222 |
Deformddpm.train()
|
| 223 |
-
|
| 224 |
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 225 |
|
| 226 |
total = min(len(train_loader), len(train_loader_p))
|
| 227 |
-
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 228 |
# for step, batch in tqdm(enumerate(train_loader)):
|
| 229 |
# for step, batch in tqdm(enumerate(train_loader)):
|
| 230 |
-
|
| 231 |
# for step, batch in enumerate(train_loader_omni):
|
|
|
|
|
|
|
| 232 |
# x0, _ = batch
|
| 233 |
|
| 234 |
|
|
@@ -258,10 +264,10 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 258 |
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 259 |
else:
|
| 260 |
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 261 |
-
x0 = transformer(x0)
|
| 262 |
if hyp_parameters['noise_scale']>0:
|
| 263 |
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 264 |
-
x0 = thresh_img(x0, [0,
|
| 265 |
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 266 |
|
| 267 |
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
|
@@ -270,12 +276,15 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 270 |
) # pick up a seq of rand number from 0 to 'timestep'
|
| 271 |
|
| 272 |
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 273 |
-
proc_type = random.choice(['adding', 'downsample', 'slice', 'none', 'uncon', 'uncon', 'uncon'])
|
| 274 |
# print('proc_type:', proc_type)
|
| 275 |
cond_img, _, cond_ratio = Deformddpm.module.proc_cond_img(x0,proc_type=proc_type)
|
| 276 |
|
| 277 |
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
|
| 278 |
|
|
|
|
|
|
|
|
|
|
| 279 |
loss_tot=0
|
| 280 |
|
| 281 |
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
|
@@ -302,15 +311,14 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 302 |
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 303 |
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 304 |
|
| 305 |
-
|
| 306 |
optimizer.zero_grad()
|
| 307 |
loss_tot.backward()
|
| 308 |
optimizer.step()
|
| 309 |
|
| 310 |
-
epoch_loss_tot += loss_tot.item()
|
| 311 |
-
epoch_loss_gen_d += loss_gen_d.item()
|
| 312 |
-
epoch_loss_gen_a += loss_gen_a.item()
|
| 313 |
-
epoch_loss_reg += loss_ddf.item()
|
| 314 |
|
| 315 |
# print(loss_gen_a.item())
|
| 316 |
# if 0:
|
|
@@ -336,8 +344,8 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 336 |
# if np.random.uniform(0,1)<0.6:
|
| 337 |
# x1 = utils.random_resample(x1, deform_scale=0)
|
| 338 |
# y1 = utils.random_resample(y1, deform_scale=0)
|
| 339 |
-
x1 = transformer(x1)
|
| 340 |
-
y1 = transformer(y1)
|
| 341 |
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 342 |
if hyp_parameters['noise_scale']>0:
|
| 343 |
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
|
@@ -355,10 +363,13 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 355 |
# ) # pick up a seq of rand number from 0 to 'timestep'
|
| 356 |
|
| 357 |
|
| 358 |
-
# scale_regist = np.random.uniform(0.
|
| 359 |
# T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist) + 1), 16), reverse=True)
|
| 360 |
-
|
| 361 |
-
|
|
|
|
|
|
|
|
|
|
| 362 |
# scale_regist = np.random.uniform(0.4,1.)
|
| 363 |
# T_regist = [int(hyp_parameters["timesteps"]*scale_regist)]
|
| 364 |
# scale_regist = np.random.uniform(0.6,1.)
|
|
@@ -369,33 +380,30 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 369 |
|
| 370 |
# print('T_regist:', T_regist)
|
| 371 |
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
|
| 372 |
-
proc_type = random.choice(['
|
| 373 |
# proc_type = random.choice(['project'])
|
| 374 |
y1_proc, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
|
| 375 |
-
|
| 376 |
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
#
|
| 380 |
-
loss_sim = loss_imgsim(img_rec, y1, label=(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 381 |
-
loss_mse = loss_imgmse(img_rec, y1, label=(y1>0.0)) # calculate loss for the registration process
|
| 382 |
-
loss_ddf1 = loss_reg1(ddf_comp,img=y1) # calculate loss for the registration process
|
| 383 |
-
|
| 384 |
loss_regist = 0
|
| 385 |
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 386 |
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 387 |
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 388 |
# print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
|
| 389 |
# print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
|
| 390 |
-
|
| 391 |
# >> JZ: print nan in x0
|
| 392 |
if torch.isnan(x0).any():
|
| 393 |
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
# >> JZ: print loss of ddf
|
| 398 |
-
if loss_ddf1>0.
|
| 399 |
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 400 |
# # Print gradients for each parameter
|
| 401 |
# for name, param in Deformddpm.named_parameters():
|
|
@@ -403,43 +411,25 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 403 |
# print(f"Gradient for {name}: {param.grad.norm()}")
|
| 404 |
# else:
|
| 405 |
# print(f"Gradient for {name}: None")
|
| 406 |
-
|
| 407 |
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 408 |
optimizer.zero_grad()
|
| 409 |
loss_regist.backward()
|
| 410 |
|
| 411 |
|
| 412 |
|
| 413 |
-
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.
|
| 414 |
optimizer.step()
|
| 415 |
|
| 416 |
-
epoch_loss_regist += loss_regist.item()
|
| 417 |
-
epoch_loss_imgsim += loss_sim.item()
|
| 418 |
-
epoch_loss_imgmse += loss_mse.item()
|
| 419 |
-
epoch_loss_ddfreg += loss_ddf1.item()
|
| 420 |
|
| 421 |
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
if loss_sim.item()>-0.001:
|
| 426 |
-
print(f"*** Zero image similarity loss at epoch {epoch}, step {step}.")
|
| 427 |
-
def save_niftiimage(tensor, filename):
|
| 428 |
-
import nibabel as nib
|
| 429 |
-
import numpy as np
|
| 430 |
-
array = tensor.squeeze().cpu().detach().numpy()
|
| 431 |
-
nifti_img = nib.Nifti1Image(array, affine=np.eye(4))
|
| 432 |
-
nib.save(nifti_img, filename)
|
| 433 |
-
# save the x1 and y1 images for debugging
|
| 434 |
-
save_path = os.path.join('/home/data/Github/OmniMorph/Log/error_files',f"debug_images_epoch{epoch}_step{step}/")
|
| 435 |
-
os.makedirs(save_path, exist_ok=True)
|
| 436 |
-
save_niftiimage(img_rec, os.path.join(save_path, 'img_rec.nii.gz'))
|
| 437 |
-
save_niftiimage(x1, os.path.join(save_path, 'x1.nii.gz'))
|
| 438 |
-
save_niftiimage(y1, os.path.join(save_path, 'y1.nii.gz'))
|
| 439 |
-
save_niftiimage(y1_proc, os.path.join(save_path, 'y1_proc.nii.gz'))
|
| 440 |
-
exit()
|
| 441 |
-
# print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 442 |
-
|
| 443 |
# break # FOR TESTING
|
| 444 |
# else:
|
| 445 |
# print('loss_gen_a:',loss_gen_a.item()) # FOR TESTING
|
|
@@ -481,7 +471,7 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 481 |
if use_distributed and dist.is_initialized():
|
| 482 |
dist.destroy_process_group()
|
| 483 |
|
| 484 |
-
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
|
| 485 |
|
| 486 |
if gpu_id == 0:
|
| 487 |
# if 0:
|
|
@@ -494,10 +484,11 @@ def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True
|
|
| 494 |
checkpoint = torch.load(model_file)
|
| 495 |
# checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
|
| 496 |
if use_distributed:
|
| 497 |
-
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'])
|
| 498 |
else:
|
| 499 |
-
Deformddpm.load_state_dict(checkpoint['model_state_dict'])
|
| 500 |
-
|
|
|
|
| 501 |
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 502 |
|
| 503 |
if use_distributed:
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
sys.path.append(ROOT_DIR)
|
| 5 |
+
|
| 6 |
import gc
|
| 7 |
import torch
|
| 8 |
import torchvision
|
|
|
|
| 52 |
EPS = 1e-5
|
| 53 |
MSK_EPS = 0.01
|
| 54 |
TEXT_EMBED_PROB = 0.7
|
| 55 |
+
AUG_RESAMPLE_PROB = 0.5
|
| 56 |
+
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0] # [ang, dist, reg]
|
| 57 |
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 58 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128] # [imgsim, imgmse, ddf]
|
| 59 |
+
DIFF_REG_BATCH_RATIO = 2
|
|
|
|
| 60 |
|
| 61 |
# AUG_PERMUTE_PROB = 0.35
|
| 62 |
|
|
|
|
| 133 |
datasetp = OMDataset_pair(transform=None)
|
| 134 |
train_loader_p = DataLoader(
|
| 135 |
datasetp,
|
| 136 |
+
batch_size=hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO,
|
| 137 |
shuffle=True,
|
| 138 |
drop_last=True,
|
| 139 |
)
|
|
|
|
| 177 |
|
| 178 |
# mse = nn.MSELoss()
|
| 179 |
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 180 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 181 |
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 182 |
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 183 |
+
|
| 184 |
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 185 |
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 186 |
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 187 |
+
# loss_imgsim = losses.LNCC()
|
| 188 |
+
loss_imgsim = losses.MSLNCC()
|
| 189 |
loss_imgmse = losses.LMSE()
|
| 190 |
|
| 191 |
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
|
|
|
| 226 |
epoch_loss_ddfreg = 0.0
|
| 227 |
# Set model inside to train model
|
| 228 |
Deformddpm.train()
|
| 229 |
+
|
| 230 |
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 231 |
|
| 232 |
total = min(len(train_loader), len(train_loader_p))
|
|
|
|
| 233 |
# for step, batch in tqdm(enumerate(train_loader)):
|
| 234 |
# for step, batch in tqdm(enumerate(train_loader)):
|
|
|
|
| 235 |
# for step, batch in enumerate(train_loader_omni):
|
| 236 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 237 |
+
|
| 238 |
# x0, _ = batch
|
| 239 |
|
| 240 |
|
|
|
|
| 264 |
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 265 |
else:
|
| 266 |
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 267 |
+
# x0 = transformer(x0)
|
| 268 |
if hyp_parameters['noise_scale']>0:
|
| 269 |
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 270 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 271 |
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 272 |
|
| 273 |
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
|
|
|
| 276 |
) # pick up a seq of rand number from 0 to 'timestep'
|
| 277 |
|
| 278 |
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 279 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 280 |
# print('proc_type:', proc_type)
|
| 281 |
cond_img, _, cond_ratio = Deformddpm.module.proc_cond_img(x0,proc_type=proc_type)
|
| 282 |
|
| 283 |
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
|
| 284 |
|
| 285 |
+
# print(torch.max(torch.abs(pre_dvf_I)))
|
| 286 |
+
# print(torch.max(torch.abs(dvf_I)))
|
| 287 |
+
|
| 288 |
loss_tot=0
|
| 289 |
|
| 290 |
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
|
|
|
| 311 |
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 312 |
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 313 |
|
|
|
|
| 314 |
optimizer.zero_grad()
|
| 315 |
loss_tot.backward()
|
| 316 |
optimizer.step()
|
| 317 |
|
| 318 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 319 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 320 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 321 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 322 |
|
| 323 |
# print(loss_gen_a.item())
|
| 324 |
# if 0:
|
|
|
|
| 344 |
# if np.random.uniform(0,1)<0.6:
|
| 345 |
# x1 = utils.random_resample(x1, deform_scale=0)
|
| 346 |
# y1 = utils.random_resample(y1, deform_scale=0)
|
| 347 |
+
# x1 = transformer(x1)
|
| 348 |
+
# y1 = transformer(y1)
|
| 349 |
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 350 |
if hyp_parameters['noise_scale']>0:
|
| 351 |
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
|
|
|
| 363 |
# ) # pick up a seq of rand number from 0 to 'timestep'
|
| 364 |
|
| 365 |
|
| 366 |
+
# scale_regist = np.random.uniform(0.6,1.)
|
| 367 |
# T_regist = sorted(random.sample(range(0, int(hyp_parameters["timesteps"] * scale_regist) + 1), 16), reverse=True)
|
| 368 |
+
# print('T_regist (0.6,1) sampling range:', T_regist)
|
| 369 |
+
scale_regist = np.random.uniform(0.0,0.7)
|
| 370 |
+
select_timestep = np.random.randint(8, 17) # select a random number of timesteps to sample, between 8 and 16
|
| 371 |
+
T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
|
| 372 |
+
# print('T_regist (0.1,0.7) sampling range:', T_regist)
|
| 373 |
# scale_regist = np.random.uniform(0.4,1.)
|
| 374 |
# T_regist = [int(hyp_parameters["timesteps"]*scale_regist)]
|
| 375 |
# scale_regist = np.random.uniform(0.6,1.)
|
|
|
|
| 380 |
|
| 381 |
# print('T_regist:', T_regist)
|
| 382 |
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
|
| 383 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 384 |
# proc_type = random.choice(['project'])
|
| 385 |
y1_proc, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
|
| 386 |
+
msk_tgt = msk_tgt+MSK_EPS
|
| 387 |
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 388 |
+
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 389 |
+
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
|
| 390 |
+
# loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=(msk_tgt+MSK_EPS)) # calculate loss for the registration process
|
| 391 |
+
# loss_sim = loss_imgsim(img_rec, y1, label=(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 392 |
+
# loss_mse = loss_imgmse(img_rec, y1, label=(y1>=0.0)) # calculate loss for the registration process
|
| 393 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 394 |
+
|
| 395 |
loss_regist = 0
|
| 396 |
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 397 |
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 398 |
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 399 |
# print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
|
| 400 |
# print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
|
| 401 |
+
|
| 402 |
# >> JZ: print nan in x0
|
| 403 |
if torch.isnan(x0).any():
|
| 404 |
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
|
|
|
|
|
|
|
|
|
| 405 |
# >> JZ: print loss of ddf
|
| 406 |
+
if loss_ddf1>0.002:
|
| 407 |
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 408 |
# # Print gradients for each parameter
|
| 409 |
# for name, param in Deformddpm.named_parameters():
|
|
|
|
| 411 |
# print(f"Gradient for {name}: {param.grad.norm()}")
|
| 412 |
# else:
|
| 413 |
# print(f"Gradient for {name}: None")
|
| 414 |
+
|
| 415 |
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 416 |
optimizer.zero_grad()
|
| 417 |
loss_regist.backward()
|
| 418 |
|
| 419 |
|
| 420 |
|
| 421 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.2)
|
| 422 |
optimizer.step()
|
| 423 |
|
| 424 |
+
epoch_loss_regist += loss_regist.item() / total
|
| 425 |
+
epoch_loss_imgsim += loss_sim.item() / total
|
| 426 |
+
epoch_loss_imgmse += loss_mse.item() / total
|
| 427 |
+
epoch_loss_ddfreg += loss_ddf1.item() / total
|
| 428 |
|
| 429 |
|
| 430 |
+
if step % 10 == 0:
|
| 431 |
+
print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 432 |
+
print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
# break # FOR TESTING
|
| 434 |
# else:
|
| 435 |
# print('loss_gen_a:',loss_gen_a.item()) # FOR TESTING
|
|
|
|
| 471 |
if use_distributed and dist.is_initialized():
|
| 472 |
dist.destroy_process_group()
|
| 473 |
|
| 474 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 475 |
|
| 476 |
if gpu_id == 0:
|
| 477 |
# if 0:
|
|
|
|
| 484 |
checkpoint = torch.load(model_file)
|
| 485 |
# checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
|
| 486 |
if use_distributed:
|
| 487 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 488 |
else:
|
| 489 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 490 |
+
if load_strict:
|
| 491 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 492 |
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 493 |
|
| 494 |
if use_distributed:
|
OM_train_3modes-XPU.py
ADDED
|
@@ -0,0 +1,957 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys, contextlib
|
| 2 |
+
|
| 3 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
sys.path.append(ROOT_DIR)
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torchvision.utils import save_image
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from torch.optim import Adam, SGD
|
| 14 |
+
from Diffusion.diffuser import DeformDDPM
|
| 15 |
+
from Diffusion.networks import get_net, STN
|
| 16 |
+
from torchvision.transforms import Lambda
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import Diffusion.losses as losses
|
| 19 |
+
import random
|
| 20 |
+
import glob
|
| 21 |
+
import numpy as np
|
| 22 |
+
import utils
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 26 |
+
from Dataloader.dataLoader import *
|
| 27 |
+
|
| 28 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 29 |
+
import yaml
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
# XPU support: import Intel Extension for PyTorch and oneCCL bindings if available
|
| 33 |
+
try:
|
| 34 |
+
import intel_extension_for_pytorch as ipex
|
| 35 |
+
except ImportError:
|
| 36 |
+
ipex = None
|
| 37 |
+
try:
|
| 38 |
+
import oneccl_bindings_for_pytorch
|
| 39 |
+
except (ImportError, Exception) as e:
|
| 40 |
+
print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
|
| 41 |
+
|
| 42 |
+
####################
|
| 43 |
+
import torch.multiprocessing as mp
|
| 44 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 45 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 46 |
+
import torch.distributed as dist
|
| 47 |
+
# from torch.distributed import init_process_group
|
| 48 |
+
###############
|
| 49 |
+
def _device_available(device_type):
|
| 50 |
+
if device_type == 'xpu':
|
| 51 |
+
return hasattr(torch, 'xpu') and torch.xpu.is_available()
|
| 52 |
+
return torch.cuda.is_available()
|
| 53 |
+
|
| 54 |
+
def _device_count(device_type):
|
| 55 |
+
if device_type == 'xpu':
|
| 56 |
+
return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
|
| 57 |
+
return torch.cuda.device_count()
|
| 58 |
+
|
| 59 |
+
def _set_device(rank, device_type):
|
| 60 |
+
if device_type == 'xpu':
|
| 61 |
+
torch.xpu.set_device(rank)
|
| 62 |
+
else:
|
| 63 |
+
torch.cuda.set_device(rank)
|
| 64 |
+
|
| 65 |
+
def _empty_cache(device_type):
|
| 66 |
+
if device_type == 'xpu' and hasattr(torch, 'xpu'):
|
| 67 |
+
torch.xpu.empty_cache()
|
| 68 |
+
elif torch.cuda.is_available():
|
| 69 |
+
torch.cuda.empty_cache()
|
| 70 |
+
|
| 71 |
+
def ddp_setup(rank, world_size):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
rank: Unique identifier of each process (local_rank when launched by torchrun)
|
| 75 |
+
world_size: Total number of processes
|
| 76 |
+
"""
|
| 77 |
+
backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
|
| 78 |
+
if "LOCAL_RANK" in os.environ:
|
| 79 |
+
# Launched by torchrun: MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE already set
|
| 80 |
+
dist.init_process_group(backend=backend)
|
| 81 |
+
_set_device(int(os.environ["LOCAL_RANK"]), DEVICE_TYPE)
|
| 82 |
+
else:
|
| 83 |
+
# Single-node mp.spawn
|
| 84 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 85 |
+
os.environ["MASTER_PORT"] = "12355"
|
| 86 |
+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
| 87 |
+
_set_device(rank, DEVICE_TYPE)
|
| 88 |
+
|
| 89 |
+
EPS = 1e-5
|
| 90 |
+
MSK_EPS = 0.01
|
| 91 |
+
TEXT_EMBED_PROB = 0.5
|
| 92 |
+
AUG_RESAMPLE_PROB = 0.5
|
| 93 |
+
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0] # [ang, dist, reg]
|
| 94 |
+
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 95 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2] # [imgsim, imgmse, ddf]
|
| 96 |
+
DIFF_REG_BATCH_RATIO = 2
|
| 97 |
+
LOSS_WEIGHT_CONTRASTIVE = 1e-4
|
| 98 |
+
REGISTRATION_STEP_RATIO = 1
|
| 99 |
+
CONTRASTIVE_STEP_RATIO = 1
|
| 100 |
+
MID_EPOCH_SAVE_STEPS = 10 # Save mid-epoch checkpoint every N steps for crash recovery.
|
| 101 |
+
# XPU autograd leaks ~1.0 GiB/step of device memory (Intel bug).
|
| 102 |
+
# With gradient checkpointing, training survives ~26 steps from fresh start,
|
| 103 |
+
# but fewer when carrying leaked memory from previous epoch.
|
| 104 |
+
# Save every 10 steps to minimize lost work on OOM crash.
|
| 105 |
+
EXIT_CODE_RESTART = 42 # Exit code signaling proactive restart (not a crash).
|
| 106 |
+
|
| 107 |
+
# AUG_PERMUTE_PROB = 0.35
|
| 108 |
+
|
| 109 |
+
parser = argparse.ArgumentParser()
|
| 110 |
+
|
| 111 |
+
# config_file_path = 'Config/config_cmr.yaml'
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--config",
|
| 114 |
+
"-C",
|
| 115 |
+
help="Path for the config file",
|
| 116 |
+
type=str,
|
| 117 |
+
# default="Config/config_cmr.yaml",
|
| 118 |
+
# default="Config/config_lct.yaml",
|
| 119 |
+
default="Config/config_all.yaml",
|
| 120 |
+
required=False,
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
|
| 123 |
+
parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
|
| 124 |
+
parser.add_argument("--max-steps-before-restart", type=int, default=0,
|
| 125 |
+
help="Proactive restart: exit after N training steps to reset XPU memory leak. "
|
| 126 |
+
"0=disabled (rely on OOM crash + auto-resubmit). "
|
| 127 |
+
"Recommended: 20 for XPU (survives ~26 steps max).")
|
| 128 |
+
parser.add_argument("--no-save", action="store_true",
|
| 129 |
+
help="Disable all checkpoint saving (for diagnostic/validation runs)")
|
| 130 |
+
parser.add_argument("--reset-optimizer", action="store_true",
|
| 131 |
+
help="Skip optimizer state loading from checkpoint (use when architecture changed)")
|
| 132 |
+
parser.add_argument("--eval-only", action="store_true",
|
| 133 |
+
help="Forward pass only: compute and print losses without backward/optimizer (no memory leak)")
|
| 134 |
+
args = parser.parse_args()
|
| 135 |
+
|
| 136 |
+
# Read config early to determine device type for DDP setup
|
| 137 |
+
with open(args.config, 'r') as _f:
|
| 138 |
+
_cfg = yaml.safe_load(_f)
|
| 139 |
+
DEVICE_TYPE = _cfg.get('device', 'cuda') # 'cuda' or 'xpu'
|
| 140 |
+
|
| 141 |
+
# Auto-detect: use DDP only when multiple devices are available
|
| 142 |
+
use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1
|
| 143 |
+
# use_distributed = True
|
| 144 |
+
# use_distributed = False
|
| 145 |
+
#=======================================================================================================================
|
| 146 |
+
|
| 147 |
+
class _DummyIndiv(torch.utils.data.Dataset):
|
| 148 |
+
def __init__(self, n, sz, embd_dim=1024):
|
| 149 |
+
self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| 150 |
+
def __len__(self): return self.n
|
| 151 |
+
def __getitem__(self, i):
|
| 152 |
+
return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
|
| 153 |
+
|
| 154 |
+
class _DummyPair(torch.utils.data.Dataset):
|
| 155 |
+
def __init__(self, n, sz, embd_dim=1024):
|
| 156 |
+
self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| 157 |
+
def __len__(self): return self.n
|
| 158 |
+
def __getitem__(self, i):
|
| 159 |
+
return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| 160 |
+
np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| 161 |
+
np.random.randn(self.embd_dim).astype(np.float32),
|
| 162 |
+
np.random.randn(self.embd_dim).astype(np.float32))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 166 |
+
if use_distributed:
|
| 167 |
+
ddp_setup(rank,world_size)
|
| 168 |
+
|
| 169 |
+
if torch.distributed.is_initialized() and rank == 0:
|
| 170 |
+
print(f"World size: {torch.distributed.get_world_size()}")
|
| 171 |
+
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 172 |
+
print(f"PYTORCH_ALLOC_CONF: {os.environ.get('PYTORCH_ALLOC_CONF', 'not set')}")
|
| 173 |
+
if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 174 |
+
props = torch.xpu.get_device_properties(0)
|
| 175 |
+
print(f"XPU device: {props.name}, total memory: {props.total_memory / 1024**3:.2f} GiB")
|
| 176 |
+
# gpu_id = global rank (for save/print guards); rank = local device index
|
| 177 |
+
if "RANK" in os.environ:
|
| 178 |
+
gpu_id = int(os.environ["RANK"])
|
| 179 |
+
rank = int(os.environ["LOCAL_RANK"])
|
| 180 |
+
else:
|
| 181 |
+
gpu_id = rank
|
| 182 |
+
|
| 183 |
+
# Load the YAML file into a dictionary
|
| 184 |
+
with open(args.config, 'r') as file:
|
| 185 |
+
hyp_parameters = yaml.safe_load(file)
|
| 186 |
+
if args.batchsize > 0:
|
| 187 |
+
hyp_parameters['batchsize'] = args.batchsize
|
| 188 |
+
if gpu_id == 0:
|
| 189 |
+
print(hyp_parameters)
|
| 190 |
+
|
| 191 |
+
# epoch_per_save=10
|
| 192 |
+
epoch_per_save=hyp_parameters['epoch_per_save']
|
| 193 |
+
|
| 194 |
+
data_name=hyp_parameters['data_name']
|
| 195 |
+
net_name = hyp_parameters['net_name']
|
| 196 |
+
|
| 197 |
+
Net=get_net(net_name)
|
| 198 |
+
|
| 199 |
+
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 200 |
+
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 201 |
+
model_dir=model_save_path
|
| 202 |
+
transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 203 |
+
|
| 204 |
+
# Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
|
| 205 |
+
|
| 206 |
+
# tsfm = torchvision.transforms.Compose([
|
| 207 |
+
# torchvision.transforms.ToTensor(),
|
| 208 |
+
# ])
|
| 209 |
+
|
| 210 |
+
# dataset = Data_Loader(target_res = [hyp_parameters["img_size"]]*hyp_parameters["ndims"], transforms=None, noise_scale=hyp_parameters['noise_scale'])
|
| 211 |
+
# train_loader = DataLoader(
|
| 212 |
+
# dataset,
|
| 213 |
+
# batch_size=hyp_parameters['batchsize'],
|
| 214 |
+
# # shuffle=False,
|
| 215 |
+
# shuffle=True,
|
| 216 |
+
# drop_last=True,
|
| 217 |
+
# )
|
| 218 |
+
|
| 219 |
+
if args.dummy_samples > 0:
|
| 220 |
+
dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
|
| 221 |
+
datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
|
| 222 |
+
else:
|
| 223 |
+
# dataset = OminiDataset_v1(transform=None)
|
| 224 |
+
dataset = OMDataset_indiv(transform=None)
|
| 225 |
+
# datasetp = OminiDataset_paired(transform=None)
|
| 226 |
+
datasetp = OMDataset_pair(transform=None)
|
| 227 |
+
|
| 228 |
+
if use_distributed:
|
| 229 |
+
sampler = DistributedSampler(dataset, shuffle=True)
|
| 230 |
+
sampler_p = DistributedSampler(datasetp, shuffle=True)
|
| 231 |
+
else:
|
| 232 |
+
sampler = None
|
| 233 |
+
sampler_p = None
|
| 234 |
+
|
| 235 |
+
train_loader = DataLoader(
|
| 236 |
+
dataset,
|
| 237 |
+
batch_size=hyp_parameters['batchsize'],
|
| 238 |
+
shuffle=(sampler is None),
|
| 239 |
+
drop_last=True,
|
| 240 |
+
sampler=sampler,
|
| 241 |
+
)
|
| 242 |
+
train_loader_p = DataLoader(
|
| 243 |
+
datasetp,
|
| 244 |
+
batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
|
| 245 |
+
shuffle=(sampler_p is None),
|
| 246 |
+
drop_last=True,
|
| 247 |
+
sampler=sampler_p,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
network = Net(
|
| 253 |
+
n_steps=hyp_parameters["timesteps"],
|
| 254 |
+
ndims=hyp_parameters["ndims"],
|
| 255 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 256 |
+
res = hyp_parameters['img_size']
|
| 257 |
+
)
|
| 258 |
+
# Enable gradient checkpointing on XPU to reduce peak activation memory.
|
| 259 |
+
# XPU autograd leaks ~1.0 GiB/step; lower peak buys more steps before OOM.
|
| 260 |
+
if DEVICE_TYPE == 'xpu' and hasattr(network, 'use_checkpoint'):
|
| 261 |
+
network.use_checkpoint = True
|
| 262 |
+
if gpu_id == 0:
|
| 263 |
+
print(" [init] Gradient checkpointing enabled for XPU", flush=True)
|
| 264 |
+
|
| 265 |
+
Deformddpm = DeformDDPM(
|
| 266 |
+
network=network,
|
| 267 |
+
n_steps=hyp_parameters["timesteps"],
|
| 268 |
+
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 269 |
+
device=hyp_parameters["device"],
|
| 270 |
+
batch_size=hyp_parameters["batchsize"],
|
| 271 |
+
img_pad_mode=hyp_parameters["img_pad_mode"],
|
| 272 |
+
v_scale=hyp_parameters["v_scale"],
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
ddf_stn = STN(
|
| 277 |
+
img_sz=hyp_parameters["img_size"],
|
| 278 |
+
ndims=hyp_parameters["ndims"],
|
| 279 |
+
# padding_mode="zeros",
|
| 280 |
+
padding_mode=hyp_parameters["padding_mode"],
|
| 281 |
+
device=hyp_parameters["device"],
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
if use_distributed:
|
| 286 |
+
device = f"{DEVICE_TYPE}:{rank}"
|
| 287 |
+
# NO pre-allocation. CCL/oneDNN accumulate ~1.4 GiB/step of device memory outside
|
| 288 |
+
# PyTorch's caching allocator. Pre-allocating steals from that budget:
|
| 289 |
+
# 92% pre-alloc → crash at step 3, 78% → step 10, none (70% cap) → step 14.
|
| 290 |
+
# Instead, use empty_cache() between training phases to release unused cached memory
|
| 291 |
+
# back to the device for CCL/oneDNN.
|
| 292 |
+
if gpu_id == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 293 |
+
total_mem = torch.xpu.get_device_properties(rank).total_memory
|
| 294 |
+
print(f" [init] XPU device memory: {total_mem/1024**3:.1f} GiB, no pre-allocation (relying on empty_cache between phases)", flush=True)
|
| 295 |
+
Deformddpm.to(device)
|
| 296 |
+
Deformddpm = DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)
|
| 297 |
+
ddf_stn.to(device)
|
| 298 |
+
else:
|
| 299 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 300 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 301 |
+
# ddf_stn = DDP(ddf_stn, device_ids=[rank])
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# mse = nn.MSELoss()
|
| 305 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 306 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 307 |
+
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 308 |
+
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 309 |
+
|
| 310 |
+
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 311 |
+
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 312 |
+
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 313 |
+
loss_imgsim = losses.MSLNCC()
|
| 314 |
+
loss_imgmse = losses.LMSE()
|
| 315 |
+
|
| 316 |
+
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
| 317 |
+
# hyp_parameters["lr"]=0.00000001
|
| 318 |
+
# optimizer_regist = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01)
|
| 319 |
+
# optimizer_regist = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01, momentum=0.98)
|
| 320 |
+
# optimizer = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"], momentum=0.9)
|
| 321 |
+
|
| 322 |
+
# # LR scheduler ----- YHM
|
| 323 |
+
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, hyp_parameters["lr"], hyp_parameters["lr"]*10, step_size_up=500, step_size_down=500, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
|
| 324 |
+
|
| 325 |
+
# Deformddpm.network.load_state_dict(torch.load('/home/data/jzheng/Adaptive_Motion_Generator-master/models/1000.pth'))
|
| 326 |
+
|
| 327 |
+
# check for existing models
|
| 328 |
+
if not os.path.exists(model_dir):
|
| 329 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 330 |
+
# Check for checkpoints: first check tmp/ for mid-epoch, then main dir for epoch-level
|
| 331 |
+
tmp_dir = os.path.join(model_dir, "tmp")
|
| 332 |
+
tmp_files = sorted(glob.glob(os.path.join(tmp_dir, "*.pth")))
|
| 333 |
+
model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
|
| 334 |
+
initial_step = 0
|
| 335 |
+
|
| 336 |
+
# Epoch stats and RNG states to restore when resuming from mid-epoch checkpoint
|
| 337 |
+
_resume_epoch_stats = None
|
| 338 |
+
_resume_rng = None
|
| 339 |
+
|
| 340 |
+
if tmp_files and not args.eval_only and args.max_steps_before_restart > 0:
|
| 341 |
+
# Mid-epoch checkpoint: only use when proactive restart is enabled
|
| 342 |
+
latest = tmp_files[-1]
|
| 343 |
+
if gpu_id == 0:
|
| 344 |
+
print(f" [resume] Found mid-epoch checkpoint: {latest}")
|
| 345 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
|
| 346 |
+
basename = os.path.basename(latest)
|
| 347 |
+
initial_step = int(basename.split('_step')[1].split('_')[0].split('.')[0])
|
| 348 |
+
_ckpt = torch.load(latest, map_location='cpu', weights_only=False)
|
| 349 |
+
_resume_epoch_stats = _ckpt.get('epoch_stats', None)
|
| 350 |
+
del _ckpt
|
| 351 |
+
if gpu_id == 0:
|
| 352 |
+
print(f" [resume] Resuming epoch {initial_epoch} from step {initial_step}"
|
| 353 |
+
f"{' (with epoch_stats)' if _resume_epoch_stats else ''}", flush=True)
|
| 354 |
+
elif model_files:
|
| 355 |
+
if gpu_id == 0:
|
| 356 |
+
print(model_files)
|
| 357 |
+
latest = model_files[-1]
|
| 358 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
|
| 359 |
+
else:
|
| 360 |
+
initial_epoch = 0
|
| 361 |
+
|
| 362 |
+
if gpu_id == 0:
|
| 363 |
+
print('len_train_data: ',len(dataset))
|
| 364 |
+
|
| 365 |
+
# Proactive restart: track steps since process start to exit before OOM.
|
| 366 |
+
max_steps_restart = args.max_steps_before_restart
|
| 367 |
+
steps_since_start = 0
|
| 368 |
+
|
| 369 |
+
# Training loop
|
| 370 |
+
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
| 371 |
+
if use_distributed and sampler is not None:
|
| 372 |
+
sampler.set_epoch(epoch)
|
| 373 |
+
sampler_p.set_epoch(epoch)
|
| 374 |
+
|
| 375 |
+
epoch_loss_tot = 0.0
|
| 376 |
+
epoch_loss_gen_d = 0.0
|
| 377 |
+
epoch_loss_gen_a = 0.0
|
| 378 |
+
epoch_loss_reg = 0.0
|
| 379 |
+
epoch_loss_regist = 0.0
|
| 380 |
+
epoch_loss_imgsim = 0.0
|
| 381 |
+
epoch_loss_imgmse = 0.0
|
| 382 |
+
epoch_loss_ddfreg = 0.0
|
| 383 |
+
epoch_loss_contrastive = 0.0
|
| 384 |
+
total_contra = 0
|
| 385 |
+
total_reg_restored = None
|
| 386 |
+
total_contra_restored = None
|
| 387 |
+
|
| 388 |
+
# Restore epoch accumulators from mid-epoch checkpoint (only for the resumed epoch)
|
| 389 |
+
if _resume_epoch_stats is not None and epoch == initial_epoch:
|
| 390 |
+
epoch_loss_tot = _resume_epoch_stats.get('epoch_loss_tot', 0.0)
|
| 391 |
+
epoch_loss_gen_d = _resume_epoch_stats.get('epoch_loss_gen_d', 0.0)
|
| 392 |
+
epoch_loss_gen_a = _resume_epoch_stats.get('epoch_loss_gen_a', 0.0)
|
| 393 |
+
epoch_loss_reg = _resume_epoch_stats.get('epoch_loss_reg', 0.0)
|
| 394 |
+
epoch_loss_regist = _resume_epoch_stats.get('epoch_loss_regist', 0.0)
|
| 395 |
+
epoch_loss_imgsim = _resume_epoch_stats.get('epoch_loss_imgsim', 0.0)
|
| 396 |
+
epoch_loss_imgmse = _resume_epoch_stats.get('epoch_loss_imgmse', 0.0)
|
| 397 |
+
epoch_loss_ddfreg = _resume_epoch_stats.get('epoch_loss_ddfreg', 0.0)
|
| 398 |
+
epoch_loss_contrastive = _resume_epoch_stats.get('epoch_loss_contrastive', 0.0)
|
| 399 |
+
total_reg_restored = _resume_epoch_stats.get('total_reg', None)
|
| 400 |
+
total_contra_restored = _resume_epoch_stats.get('total_contra', None)
|
| 401 |
+
loss_nan_step = _resume_epoch_stats.get('loss_nan_step', 0)
|
| 402 |
+
# RNG states are restored INSIDE the skip loop (at the last skipped step)
|
| 403 |
+
# to avoid DataLoader __getitem__ calls corrupting the restored state.
|
| 404 |
+
_resume_rng = {k: _resume_epoch_stats[k] for k in
|
| 405 |
+
('rng_torch', 'rng_numpy', 'rng_python', 'rng_xpu', 'rng_cuda')
|
| 406 |
+
if k in _resume_epoch_stats}
|
| 407 |
+
if gpu_id == 0:
|
| 408 |
+
print(f" [resume] Restored epoch stats from checkpoint (loss_tot={epoch_loss_tot:.4f})", flush=True)
|
| 409 |
+
_resume_epoch_stats = None # Only restore once
|
| 410 |
+
else:
|
| 411 |
+
loss_nan_step = 0 # only reset when NOT resuming mid-epoch
|
| 412 |
+
|
| 413 |
+
# Set model inside to train model
|
| 414 |
+
Deformddpm.train()
|
| 415 |
+
|
| 416 |
+
total = min(len(train_loader), len(train_loader_p))
|
| 417 |
+
total_reg = total // REGISTRATION_STEP_RATIO
|
| 418 |
+
# Restore total_reg and total_contra from checkpoint if available (mid-epoch resume)
|
| 419 |
+
if total_reg_restored is not None:
|
| 420 |
+
total_reg = total_reg_restored
|
| 421 |
+
total_reg_restored = None
|
| 422 |
+
if total_contra_restored is not None:
|
| 423 |
+
total_contra = total_contra_restored
|
| 424 |
+
total_contra_restored = None
|
| 425 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 426 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 427 |
+
# for step, batch in enumerate(train_loader_omni):
|
| 428 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 429 |
+
|
| 430 |
+
# Skip steps already completed (mid-epoch resume).
|
| 431 |
+
# Checkpoint at step N is saved AFTER step N's training completes,
|
| 432 |
+
# so step N itself must also be skipped (use <=, not <).
|
| 433 |
+
if epoch == initial_epoch and initial_step > 0 and step <= initial_step:
|
| 434 |
+
# Restore RNG at the last skipped step, AFTER DataLoader __getitem__
|
| 435 |
+
# has consumed RNG for all skipped batches. This way the first
|
| 436 |
+
# non-skipped step starts with exactly the saved RNG state.
|
| 437 |
+
if step == initial_step and _resume_rng is not None:
|
| 438 |
+
# Restore rank 0's RNG as base state, then re-seed per-rank
|
| 439 |
+
# so each rank has independent RNG (matching continuous run's
|
| 440 |
+
# divergent-per-rank behavior). Without this, all ranks would
|
| 441 |
+
# share rank 0's RNG → correlated augmentation/dropout decisions.
|
| 442 |
+
if 'rng_torch' in _resume_rng:
|
| 443 |
+
torch.set_rng_state(_resume_rng['rng_torch'])
|
| 444 |
+
if 'rng_numpy' in _resume_rng:
|
| 445 |
+
np.random.set_state(_resume_rng['rng_numpy'])
|
| 446 |
+
if 'rng_python' in _resume_rng:
|
| 447 |
+
random.setstate(_resume_rng['rng_python'])
|
| 448 |
+
if 'rng_xpu' in _resume_rng and DEVICE_TYPE == 'xpu':
|
| 449 |
+
torch.xpu.set_rng_state(_resume_rng['rng_xpu'])
|
| 450 |
+
elif 'rng_cuda' in _resume_rng and torch.cuda.is_available():
|
| 451 |
+
torch.cuda.set_rng_state(_resume_rng['rng_cuda'])
|
| 452 |
+
# Per-rank re-seed: checkpoint only has rank 0's RNG state.
|
| 453 |
+
# Advance each rank's RNG by a deterministic offset so they
|
| 454 |
+
# diverge (as they would in a continuous run).
|
| 455 |
+
if gpu_id > 0:
|
| 456 |
+
rank_seed = gpu_id * 100003 + initial_step * 31
|
| 457 |
+
torch.manual_seed(torch.initial_seed() + rank_seed)
|
| 458 |
+
np.random.seed((np.random.get_state()[1][0] + rank_seed) % (2**31))
|
| 459 |
+
random.seed(random.getrandbits(32) + rank_seed)
|
| 460 |
+
if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 461 |
+
torch.xpu.manual_seed(torch.initial_seed() + rank_seed)
|
| 462 |
+
elif torch.cuda.is_available():
|
| 463 |
+
torch.cuda.manual_seed(torch.initial_seed() + rank_seed)
|
| 464 |
+
_resume_rng = None
|
| 465 |
+
if gpu_id == 0:
|
| 466 |
+
print(f" [resume] RNG states restored at step {step} (per-rank re-seeded)", flush=True)
|
| 467 |
+
continue
|
| 468 |
+
|
| 469 |
+
# Free registration tensors from previous step
|
| 470 |
+
x1 = y1 = ddf_comp = img_rec = img_diff = None
|
| 471 |
+
ddf_rand = y1_proc = msk_tgt = img_save = None
|
| 472 |
+
loss_regist = loss_sim = loss_mse = loss_ddf1 = None
|
| 473 |
+
|
| 474 |
+
# Memory diagnostic (one per node via local rank 0) — only warn when abnormal
|
| 475 |
+
# Normal at step start: ~16 GiB reserved, ~48 GiB free (of 64 GiB total)
|
| 476 |
+
if rank == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 477 |
+
torch.xpu.reset_peak_memory_stats(rank)
|
| 478 |
+
free_mem, total_mem_dev = torch.xpu.mem_get_info(rank)
|
| 479 |
+
used_gib = (total_mem_dev - free_mem) / 1024**3
|
| 480 |
+
if used_gib > 24: # Normal is ~16 GiB at step start; warn if accumulating
|
| 481 |
+
alloc = torch.xpu.memory_allocated() / 1024**3
|
| 482 |
+
reserved = torch.xpu.memory_reserved() / 1024**3
|
| 483 |
+
free_gib = free_mem / 1024**3
|
| 484 |
+
print(f" [mem WARNING] gpu_id={gpu_id} epoch {epoch} step {step}: "
|
| 485 |
+
f"{used_gib:.1f} GiB used ({alloc:.1f} alloc / {reserved:.1f} reserved), "
|
| 486 |
+
f"{free_gib:.1f} GiB free", flush=True)
|
| 487 |
+
|
| 488 |
+
# ==========================================================================
|
| 489 |
+
# diffusion train on single image
|
| 490 |
+
|
| 491 |
+
# x0 = batch # for omni dataset
|
| 492 |
+
[x0,embd] = batch # for om dataset
|
| 493 |
+
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 494 |
+
# print('embd:', embd.shape)
|
| 495 |
+
embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
|
| 496 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 497 |
+
embd_in = embd_dev
|
| 498 |
+
else:
|
| 499 |
+
embd_in = None
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
n = x0.size()[0] # batch_size -> n
|
| 504 |
+
x0 = x0.to(hyp_parameters["device"])
|
| 505 |
+
|
| 506 |
+
blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
|
| 507 |
+
|
| 508 |
+
# random deformation + rotation
|
| 509 |
+
if hyp_parameters["ndims"]>2:
|
| 510 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 511 |
+
x0 = utils.random_resample(x0, deform_scale=0)
|
| 512 |
+
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 513 |
+
else:
|
| 514 |
+
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 515 |
+
# x0 = transformer(x0)
|
| 516 |
+
if hyp_parameters['noise_scale']>0:
|
| 517 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 518 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 519 |
+
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 520 |
+
|
| 521 |
+
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
| 522 |
+
t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
|
| 523 |
+
hyp_parameters["device"]
|
| 524 |
+
) # pick up a seq of rand number from 0 to 'timestep'
|
| 525 |
+
|
| 526 |
+
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 527 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 528 |
+
# print('proc_type:', proc_type)
|
| 529 |
+
ddpm = Deformddpm.module if use_distributed else Deformddpm
|
| 530 |
+
cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
|
| 531 |
+
|
| 532 |
+
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
|
| 533 |
+
|
| 534 |
+
loss_tot=0
|
| 535 |
+
|
| 536 |
+
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 537 |
+
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 538 |
+
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 539 |
+
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 540 |
+
|
| 541 |
+
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 542 |
+
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 543 |
+
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 544 |
+
|
| 545 |
+
# >> JZ: print nan in x0
|
| 546 |
+
if torch.isnan(x0).any():
|
| 547 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 548 |
+
# >> JZ: print loss of ddf
|
| 549 |
+
if loss_ddf>0.001:
|
| 550 |
+
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 551 |
+
# yu: check if loss_tot==nan or inf
|
| 552 |
+
# Synchronize NaN skip across all DDP ranks to avoid collective desync
|
| 553 |
+
# Use broadcast from rank 0 instead of all_reduce to avoid CCL hang on single-node XPU
|
| 554 |
+
is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
|
| 555 |
+
if use_distributed:
|
| 556 |
+
nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
|
| 557 |
+
dist.broadcast(nan_flag, src=0)
|
| 558 |
+
is_nan = nan_flag.item() > 0
|
| 559 |
+
if is_nan:
|
| 560 |
+
if gpu_id == 0:
|
| 561 |
+
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 562 |
+
loss_nan_step += 1
|
| 563 |
+
continue
|
| 564 |
+
if loss_nan_step > 5:
|
| 565 |
+
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 566 |
+
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 567 |
+
|
| 568 |
+
# ==========================================================================
|
| 569 |
+
# Diffusion backward (no gradient clipping — diffusion dominates training)
|
| 570 |
+
if not args.eval_only:
|
| 571 |
+
optimizer.zero_grad()
|
| 572 |
+
loss_tot.backward()
|
| 573 |
+
optimizer.step()
|
| 574 |
+
|
| 575 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 576 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 577 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 578 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 579 |
+
|
| 580 |
+
# Print running average every 20 steps in eval-only mode
|
| 581 |
+
if args.eval_only and gpu_id == 0 and (step + 1) % 20 == 0:
|
| 582 |
+
n = step + 1
|
| 583 |
+
print(f" [eval] step {step}: running_avg ang={epoch_loss_gen_a*total/n:.4f} "
|
| 584 |
+
f"dist={epoch_loss_gen_d*total/n:.4f} regul={epoch_loss_reg*total/n:.6f}", flush=True)
|
| 585 |
+
|
| 586 |
+
# Free diffusion intermediates and aggressively release all memory to device.
|
| 587 |
+
# XPU runtime leaks ~1.3 GiB/step outside the caching allocator.
|
| 588 |
+
# gc.collect() + synchronize() + empty_cache() attempts to reclaim deferred/lazy allocations.
|
| 589 |
+
loss_gen_a_val = loss_gen_a.item()
|
| 590 |
+
del pre_dvf_I, dvf_I, trm_pred, loss_tot, loss_gen_a, loss_gen_d, loss_ddf
|
| 591 |
+
gc.collect()
|
| 592 |
+
if DEVICE_TYPE == 'xpu':
|
| 593 |
+
torch.xpu.synchronize()
|
| 594 |
+
_empty_cache(DEVICE_TYPE)
|
| 595 |
+
|
| 596 |
+
# Sync loss_gen_a across DDP ranks for contrastive and registration gating
|
| 597 |
+
if use_distributed:
|
| 598 |
+
loss_gen_a_sync = torch.tensor([loss_gen_a_val], device=f"{DEVICE_TYPE}:{rank}")
|
| 599 |
+
dist.broadcast(loss_gen_a_sync, src=0)
|
| 600 |
+
loss_gen_a_gate = loss_gen_a_sync.item()
|
| 601 |
+
else:
|
| 602 |
+
loss_gen_a_gate = loss_gen_a_val
|
| 603 |
+
|
| 604 |
+
# ==========================================================================
|
| 605 |
+
# Contrastive train on single image (text-image alignment)
|
| 606 |
+
# Separate backward with gradient clipping to prevent destabilizing diffusion.
|
| 607 |
+
loss_contra_val = None
|
| 608 |
+
if step % CONTRASTIVE_STEP_RATIO == 0:
|
| 609 |
+
n_contra = x0.size()[0]
|
| 610 |
+
t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
|
| 611 |
+
# Route through DDP wrapper and return img_embd directly so DDP
|
| 612 |
+
# traces the correct subgraph (encoder + mid + attn + img2txt).
|
| 613 |
+
img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(), T=t_contra, output_embedding=True, text=None) # [B, 1024]
|
| 614 |
+
loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.25)
|
| 615 |
+
|
| 616 |
+
if not args.eval_only:
|
| 617 |
+
optimizer.zero_grad()
|
| 618 |
+
loss_contra.backward()
|
| 619 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=1e-3)
|
| 620 |
+
optimizer.step()
|
| 621 |
+
loss_contra_val = loss_contra.item()
|
| 622 |
+
epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
|
| 623 |
+
|
| 624 |
+
# Free remaining intermediates and aggressively release memory before registration
|
| 625 |
+
if cond_img is not None:
|
| 626 |
+
del cond_img
|
| 627 |
+
if blind_mask is not None:
|
| 628 |
+
del blind_mask
|
| 629 |
+
gc.collect()
|
| 630 |
+
if DEVICE_TYPE == 'xpu':
|
| 631 |
+
torch.xpu.synchronize()
|
| 632 |
+
_empty_cache(DEVICE_TYPE)
|
| 633 |
+
|
| 634 |
+
# ==========================================================================
|
| 635 |
+
# registration train on paired images
|
| 636 |
+
# loss_gen_a_gate already synced across DDP ranks above
|
| 637 |
+
do_regist = step % REGISTRATION_STEP_RATIO == 0 and loss_gen_a_gate < -0.8
|
| 638 |
+
if do_regist:
|
| 639 |
+
[x1, y1, _, embd_y] = batch_p
|
| 640 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 641 |
+
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 642 |
+
else:
|
| 643 |
+
embd_y = None
|
| 644 |
+
|
| 645 |
+
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 646 |
+
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 647 |
+
n = x1.size()[0] # batch_size -> n
|
| 648 |
+
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 649 |
+
if hyp_parameters['noise_scale']>0:
|
| 650 |
+
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
| 651 |
+
random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
|
| 652 |
+
random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 653 |
+
x1 = x1 * random_scale + random_shift
|
| 654 |
+
y1 = y1 * random_scale + random_shift
|
| 655 |
+
|
| 656 |
+
scale_regist = np.random.uniform(0.0,0.5)
|
| 657 |
+
select_timestep = np.random.randint(12, 32) # select a random number of timesteps to sample, between 8 and 16
|
| 658 |
+
T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
|
| 659 |
+
|
| 660 |
+
T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
|
| 661 |
+
|
| 662 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 663 |
+
ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
|
| 664 |
+
y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
|
| 665 |
+
msk_tgt = msk_tgt+MSK_EPS
|
| 666 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 667 |
+
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 668 |
+
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
|
| 669 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 670 |
+
|
| 671 |
+
loss_regist = 0
|
| 672 |
+
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 673 |
+
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 674 |
+
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 675 |
+
|
| 676 |
+
# >> JZ: print nan in x0
|
| 677 |
+
if torch.isnan(x0).any():
|
| 678 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 679 |
+
# >> JZ: print loss of ddf
|
| 680 |
+
if loss_ddf1>0.002:
|
| 681 |
+
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 682 |
+
|
| 683 |
+
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 684 |
+
if not args.eval_only:
|
| 685 |
+
optimizer.zero_grad()
|
| 686 |
+
loss_regist.backward()
|
| 687 |
+
|
| 688 |
+
# torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
|
| 689 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
|
| 690 |
+
optimizer.step()
|
| 691 |
+
|
| 692 |
+
epoch_loss_regist += loss_regist.item()
|
| 693 |
+
epoch_loss_imgsim += loss_sim.item()
|
| 694 |
+
epoch_loss_imgmse += loss_mse.item()
|
| 695 |
+
epoch_loss_ddfreg += loss_ddf1.item()
|
| 696 |
+
else:
|
| 697 |
+
loss_sim = torch.tensor(0.0)
|
| 698 |
+
loss_mse = torch.tensor(0.0)
|
| 699 |
+
loss_ddf1 = torch.tensor(0.0)
|
| 700 |
+
loss_regist = torch.tensor(0.0)
|
| 701 |
+
if step % REGISTRATION_STEP_RATIO==0:
|
| 702 |
+
total_reg = total_reg-1
|
| 703 |
+
|
| 704 |
+
# Mid-epoch checkpoint and proactive restart (only when --max-steps-before-restart > 0)
|
| 705 |
+
if max_steps_restart > 0 and step > 0 and step % MID_EPOCH_SAVE_STEPS == 0 and gpu_id == 0 and not args.no_save:
|
| 706 |
+
_epoch_stats = {
|
| 707 |
+
'epoch_loss_tot': epoch_loss_tot,
|
| 708 |
+
'epoch_loss_gen_d': epoch_loss_gen_d,
|
| 709 |
+
'epoch_loss_gen_a': epoch_loss_gen_a,
|
| 710 |
+
'epoch_loss_reg': epoch_loss_reg,
|
| 711 |
+
'epoch_loss_regist': epoch_loss_regist,
|
| 712 |
+
'epoch_loss_imgsim': epoch_loss_imgsim,
|
| 713 |
+
'epoch_loss_imgmse': epoch_loss_imgmse,
|
| 714 |
+
'epoch_loss_ddfreg': epoch_loss_ddfreg,
|
| 715 |
+
'epoch_loss_contrastive': epoch_loss_contrastive,
|
| 716 |
+
'total_reg': total_reg,
|
| 717 |
+
'total_contra': total_contra,
|
| 718 |
+
'loss_nan_step': loss_nan_step,
|
| 719 |
+
'rng_torch': torch.get_rng_state(),
|
| 720 |
+
'rng_numpy': np.random.get_state(),
|
| 721 |
+
'rng_python': random.getstate(),
|
| 722 |
+
**(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
|
| 723 |
+
{'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
|
| 724 |
+
}
|
| 725 |
+
tmp_dir = os.path.join(model_save_path, "tmp")
|
| 726 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 727 |
+
for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
|
| 728 |
+
os.remove(old_f)
|
| 729 |
+
mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
|
| 730 |
+
state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
|
| 731 |
+
torch.save({
|
| 732 |
+
'model_state_dict': state,
|
| 733 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 734 |
+
'epoch': epoch,
|
| 735 |
+
'step': step,
|
| 736 |
+
'epoch_stats': _epoch_stats,
|
| 737 |
+
}, mid_save)
|
| 738 |
+
print(f" [mid-epoch] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
|
| 739 |
+
|
| 740 |
+
# Proactive restart: exit cleanly after N steps to reset XPU memory leak.
|
| 741 |
+
# The bash wrapper will re-launch srun within the same SLURM allocation.
|
| 742 |
+
steps_since_start += 1
|
| 743 |
+
if max_steps_restart > 0 and steps_since_start >= max_steps_restart:
|
| 744 |
+
# Save checkpoint at current position (if not just saved above)
|
| 745 |
+
if not (step > 0 and step % MID_EPOCH_SAVE_STEPS == 0) and gpu_id == 0 and not args.no_save:
|
| 746 |
+
_epoch_stats = {
|
| 747 |
+
'epoch_loss_tot': epoch_loss_tot, 'epoch_loss_gen_d': epoch_loss_gen_d,
|
| 748 |
+
'epoch_loss_gen_a': epoch_loss_gen_a, 'epoch_loss_reg': epoch_loss_reg,
|
| 749 |
+
'epoch_loss_regist': epoch_loss_regist, 'epoch_loss_imgsim': epoch_loss_imgsim,
|
| 750 |
+
'epoch_loss_imgmse': epoch_loss_imgmse, 'epoch_loss_ddfreg': epoch_loss_ddfreg,
|
| 751 |
+
'epoch_loss_contrastive': epoch_loss_contrastive, 'total_reg': total_reg, 'total_contra': total_contra,
|
| 752 |
+
'loss_nan_step': loss_nan_step,
|
| 753 |
+
'rng_torch': torch.get_rng_state(), 'rng_numpy': np.random.get_state(),
|
| 754 |
+
'rng_python': random.getstate(),
|
| 755 |
+
**(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
|
| 756 |
+
{'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
|
| 757 |
+
}
|
| 758 |
+
tmp_dir = os.path.join(model_save_path, "tmp")
|
| 759 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 760 |
+
for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
|
| 761 |
+
os.remove(old_f)
|
| 762 |
+
mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
|
| 763 |
+
state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
|
| 764 |
+
torch.save({
|
| 765 |
+
'model_state_dict': state,
|
| 766 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 767 |
+
'epoch': epoch,
|
| 768 |
+
'step': step,
|
| 769 |
+
'epoch_stats': _epoch_stats,
|
| 770 |
+
}, mid_save)
|
| 771 |
+
print(f" [restart] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
|
| 772 |
+
if gpu_id == 0:
|
| 773 |
+
print(f" [restart] Proactive restart after {steps_since_start} steps "
|
| 774 |
+
f"(limit {max_steps_restart}). Exiting with code {EXIT_CODE_RESTART}.", flush=True)
|
| 775 |
+
# Clean shutdown
|
| 776 |
+
_empty_cache(DEVICE_TYPE)
|
| 777 |
+
gc.collect()
|
| 778 |
+
if use_distributed and dist.is_initialized():
|
| 779 |
+
dist.barrier()
|
| 780 |
+
dist.destroy_process_group()
|
| 781 |
+
sys.exit(EXIT_CODE_RESTART)
|
| 782 |
+
|
| 783 |
+
if gpu_id == 0:
|
| 784 |
+
print('==================')
|
| 785 |
+
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 786 |
+
print(f' loss_contrastive: {epoch_loss_contrastive}')
|
| 787 |
+
total_reg_safe = max(total_reg, 1)
|
| 788 |
+
print(f' loss_regist: {epoch_loss_regist/total_reg_safe} = {epoch_loss_imgsim/total_reg_safe} (imgsim) + {epoch_loss_imgmse/total_reg_safe} (imgmse) + {epoch_loss_ddfreg/total_reg_safe} (ddf)')
|
| 789 |
+
print('==================')
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
if 0 == epoch % epoch_per_save and not args.no_save:
|
| 793 |
+
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 794 |
+
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 795 |
+
# break # FOR TESTING
|
| 796 |
+
if not use_distributed:
|
| 797 |
+
print(f"saved in {save_dir}")
|
| 798 |
+
# torch.save(Deformddpm.state_dict(), save_dir)
|
| 799 |
+
torch.save({
|
| 800 |
+
'model_state_dict': Deformddpm.state_dict(),
|
| 801 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 802 |
+
'epoch': epoch
|
| 803 |
+
}, save_dir)
|
| 804 |
+
elif gpu_id == 0:
|
| 805 |
+
print(f"saved in {save_dir}")
|
| 806 |
+
# torch.save(Deformddpm.module.state_dict(), save_dir)
|
| 807 |
+
torch.save({
|
| 808 |
+
'model_state_dict': Deformddpm.module.state_dict(),
|
| 809 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 810 |
+
'epoch': epoch
|
| 811 |
+
}, save_dir)
|
| 812 |
+
# Clean up tmp/ mid-epoch checkpoints after completed epoch
|
| 813 |
+
if gpu_id == 0 and not args.no_save:
|
| 814 |
+
tmp_dir = os.path.join(model_dir, "tmp")
|
| 815 |
+
tmp_pths = glob.glob(os.path.join(tmp_dir, "*.pth"))
|
| 816 |
+
if tmp_pths:
|
| 817 |
+
for f in tmp_pths:
|
| 818 |
+
os.remove(f)
|
| 819 |
+
print(f" [cleanup] Cleared {len(tmp_pths)} tmp/ mid-epoch checkpoints", flush=True)
|
| 820 |
+
# Reset initial_step after first epoch completes (no more skipping)
|
| 821 |
+
initial_step = 0
|
| 822 |
+
|
| 823 |
+
# XPU CCL workaround: restart after each epoch to avoid CCL hang on 2nd epoch.
|
| 824 |
+
# CCL's Level Zero IPC handles accumulate and cause deadlock after ~200+ collectives.
|
| 825 |
+
# A fresh process resets the L0 context. The bash loop catches exit code 42 and restarts.
|
| 826 |
+
if DEVICE_TYPE == 'xpu' and use_distributed:
|
| 827 |
+
if gpu_id == 0:
|
| 828 |
+
print(f" [xpu-restart] Epoch {epoch} done. Restarting to reset CCL state.", flush=True)
|
| 829 |
+
_empty_cache(DEVICE_TYPE)
|
| 830 |
+
gc.collect()
|
| 831 |
+
if dist.is_initialized():
|
| 832 |
+
dist.barrier()
|
| 833 |
+
dist.destroy_process_group()
|
| 834 |
+
sys.exit(EXIT_CODE_RESTART)
|
| 835 |
+
|
| 836 |
+
# Resource cleanup at the end of training
|
| 837 |
+
_empty_cache(DEVICE_TYPE)
|
| 838 |
+
gc.collect()
|
| 839 |
+
if use_distributed and dist.is_initialized():
|
| 840 |
+
dist.destroy_process_group()
|
| 841 |
+
|
| 842 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 843 |
+
|
| 844 |
+
# All ranks load checkpoint so optimizer state is consistent across DDP processes.
|
| 845 |
+
# (Optimizer state includes per-parameter Adam momentum/variance which are NOT
|
| 846 |
+
# broadcast — only model weights are broadcast. Without this, non-rank-0 processes
|
| 847 |
+
# would have fresh Adam state after restart.)
|
| 848 |
+
gc.collect()
|
| 849 |
+
_empty_cache(DEVICE_TYPE)
|
| 850 |
+
if gpu_id == 0:
|
| 851 |
+
utils.print_memory_usage("Before Loading Model")
|
| 852 |
+
checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)
|
| 853 |
+
if use_distributed:
|
| 854 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 855 |
+
else:
|
| 856 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 857 |
+
# Restore optimizer state when available (needed for mid-epoch resume).
|
| 858 |
+
# Selective loading: load states for parameters with matching shapes, skip mismatched ones
|
| 859 |
+
# (e.g., UpsampleConv replaced ConvTranspose3d — different kernel shapes).
|
| 860 |
+
# After one epoch, the saved checkpoint will have correct state for ALL parameters.
|
| 861 |
+
if 'optimizer_state_dict' in checkpoint and not args.reset_optimizer:
|
| 862 |
+
saved_opt = checkpoint['optimizer_state_dict']
|
| 863 |
+
saved_state = saved_opt.get('state', {})
|
| 864 |
+
param_list = [p for group in optimizer.param_groups for p in group['params']]
|
| 865 |
+
|
| 866 |
+
# Check if all shapes match (fast path: full load)
|
| 867 |
+
all_match = True
|
| 868 |
+
skipped = 0
|
| 869 |
+
for idx, s in saved_state.items():
|
| 870 |
+
if int(idx) < len(param_list):
|
| 871 |
+
p = param_list[int(idx)]
|
| 872 |
+
for k, v in s.items():
|
| 873 |
+
if isinstance(v, torch.Tensor) and v.dim() > 0 and v.shape != p.shape:
|
| 874 |
+
all_match = False
|
| 875 |
+
break
|
| 876 |
+
if not all_match:
|
| 877 |
+
break
|
| 878 |
+
|
| 879 |
+
if all_match:
|
| 880 |
+
optimizer.load_state_dict(saved_opt)
|
| 881 |
+
else:
|
| 882 |
+
# Selective load: restore param_groups settings (lr, betas, etc.)
|
| 883 |
+
for saved_g, group in zip(saved_opt['param_groups'], optimizer.param_groups):
|
| 884 |
+
for k, v in saved_g.items():
|
| 885 |
+
if k != 'params':
|
| 886 |
+
group[k] = v
|
| 887 |
+
# Restore per-parameter state only where shapes match
|
| 888 |
+
for idx, s in saved_state.items():
|
| 889 |
+
idx_int = int(idx)
|
| 890 |
+
if idx_int < len(param_list):
|
| 891 |
+
p = param_list[idx_int]
|
| 892 |
+
shapes_ok = all(
|
| 893 |
+
v.shape == p.shape for k, v in s.items()
|
| 894 |
+
if isinstance(v, torch.Tensor) and v.dim() > 0
|
| 895 |
+
)
|
| 896 |
+
if shapes_ok:
|
| 897 |
+
# Cast state tensors to match parameter dtype/device
|
| 898 |
+
new_state = {}
|
| 899 |
+
for k, v in s.items():
|
| 900 |
+
if isinstance(v, torch.Tensor):
|
| 901 |
+
new_state[k] = v.to(dtype=p.dtype, device=p.device) if v.dim() > 0 else v
|
| 902 |
+
else:
|
| 903 |
+
new_state[k] = v
|
| 904 |
+
optimizer.state[p] = new_state
|
| 905 |
+
else:
|
| 906 |
+
skipped += 1
|
| 907 |
+
if gpu_id == 0:
|
| 908 |
+
loaded = len(saved_state) - skipped
|
| 909 |
+
print(f" [checkpoint] Selective optimizer load: {loaded} params restored, "
|
| 910 |
+
f"{skipped} skipped (shape mismatch, fresh Adam for those)", flush=True)
|
| 911 |
+
elif args.reset_optimizer and gpu_id == 0:
|
| 912 |
+
print(" [checkpoint] --reset-optimizer: skipping optimizer state, starting fresh Adam", flush=True)
|
| 913 |
+
del checkpoint
|
| 914 |
+
if gpu_id == 0:
|
| 915 |
+
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 916 |
+
|
| 917 |
+
if use_distributed:
|
| 918 |
+
# Broadcast model weights from rank 0 to ensure exact consistency
|
| 919 |
+
dist.barrier()
|
| 920 |
+
for param in Deformddpm.parameters():
|
| 921 |
+
dist.broadcast(param.data, src=0)
|
| 922 |
+
|
| 923 |
+
# get the epoch number from the filename
|
| 924 |
+
basename = os.path.basename(model_file)
|
| 925 |
+
epoch_from_file = int(basename[:6])
|
| 926 |
+
if '_step' in basename:
|
| 927 |
+
# Mid-epoch checkpoint: resume at same epoch (don't +1)
|
| 928 |
+
initial_epoch = epoch_from_file
|
| 929 |
+
else:
|
| 930 |
+
# End-of-epoch checkpoint: start next epoch
|
| 931 |
+
initial_epoch = epoch_from_file + 1
|
| 932 |
+
|
| 933 |
+
return initial_epoch, Deformddpm, optimizer
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
if __name__ == "__main__":
|
| 938 |
+
if "LOCAL_RANK" in os.environ:
|
| 939 |
+
# Multi-node: launched by torchrun / srun
|
| 940 |
+
use_distributed = True
|
| 941 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 942 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 943 |
+
print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
|
| 944 |
+
try:
|
| 945 |
+
main_train(local_rank, world_size)
|
| 946 |
+
except Exception as e:
|
| 947 |
+
import traceback
|
| 948 |
+
print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True)
|
| 949 |
+
traceback.print_exc()
|
| 950 |
+
raise
|
| 951 |
+
elif use_distributed:
|
| 952 |
+
# Single-node multi-GPU: use mp.spawn
|
| 953 |
+
world_size = _device_count(DEVICE_TYPE)
|
| 954 |
+
print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
|
| 955 |
+
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 956 |
+
else:
|
| 957 |
+
main_train(0,1)
|
OM_train_3modes.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
| 1 |
-
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import gc
|
| 3 |
import torch
|
| 4 |
import torchvision
|
|
@@ -9,21 +13,32 @@ from torch.utils.data import DataLoader
|
|
| 9 |
from torch.optim import Adam, SGD
|
| 10 |
from Diffusion.diffuser import DeformDDPM
|
| 11 |
from Diffusion.networks import get_net, STN
|
| 12 |
-
from torchvision.transforms import Lambda
|
|
|
|
| 13 |
import Diffusion.losses as losses
|
| 14 |
import random
|
| 15 |
import glob
|
| 16 |
import numpy as np
|
| 17 |
import utils
|
| 18 |
-
from tqdm import tqdm
|
| 19 |
|
| 20 |
-
from Dataloader.dataloader0 import get_dataloader
|
| 21 |
from Dataloader.dataLoader import *
|
| 22 |
|
| 23 |
from Dataloader.dataloader_utils import thresh_img
|
| 24 |
import yaml
|
| 25 |
import argparse
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
####################
|
| 28 |
import torch.multiprocessing as mp
|
| 29 |
from torch.utils.data.distributed import DistributedSampler
|
|
@@ -31,27 +46,66 @@ from torch.nn.parallel import DistributedDataParallel as DDP
|
|
| 31 |
import torch.distributed as dist
|
| 32 |
# from torch.distributed import init_process_group
|
| 33 |
###############
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def ddp_setup(rank, world_size):
|
| 35 |
"""
|
| 36 |
Args:
|
| 37 |
-
rank: Unique identifier of each process
|
| 38 |
world_size: Total number of processes
|
| 39 |
"""
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
EPS = 1e-5
|
| 49 |
MSK_EPS = 0.01
|
| 50 |
-
TEXT_EMBED_PROB = 0.
|
| 51 |
-
AUG_RESAMPLE_PROB = 0.
|
| 52 |
-
LOSS_WEIGHTS_DIFF = [
|
| 53 |
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 54 |
-
LOSS_WEIGHTS_REGIST = [1.0, 0.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# AUG_PERMUTE_PROB = 0.35
|
| 57 |
|
|
@@ -68,23 +122,73 @@ parser.add_argument(
|
|
| 68 |
default="Config/config_all.yaml",
|
| 69 |
required=False,
|
| 70 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
#=======================================================================================================================
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 77 |
if use_distributed:
|
| 78 |
ddp_setup(rank,world_size)
|
| 79 |
|
| 80 |
-
if torch.distributed.is_initialized():
|
| 81 |
print(f"World size: {torch.distributed.get_world_size()}")
|
| 82 |
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# Load the YAML file into a dictionary
|
| 86 |
with open(args.config, 'r') as file:
|
| 87 |
hyp_parameters = yaml.safe_load(file)
|
|
|
|
|
|
|
|
|
|
| 88 |
print(hyp_parameters)
|
| 89 |
|
| 90 |
# epoch_per_save=10
|
|
@@ -98,7 +202,7 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 98 |
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 99 |
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 100 |
model_dir=model_save_path
|
| 101 |
-
transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 102 |
|
| 103 |
# Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
|
| 104 |
|
|
@@ -115,33 +219,54 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 115 |
# drop_last=True,
|
| 116 |
# )
|
| 117 |
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
train_loader = DataLoader(
|
| 121 |
dataset,
|
| 122 |
batch_size=hyp_parameters['batchsize'],
|
| 123 |
-
shuffle=
|
| 124 |
drop_last=True,
|
|
|
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
# datasetp = OminiDataset_paired(transform=None)
|
| 128 |
-
datasetp = OMDataset_pair(transform=None)
|
| 129 |
train_loader_p = DataLoader(
|
| 130 |
datasetp,
|
| 131 |
-
batch_size=hyp_parameters['batchsize']//
|
| 132 |
-
shuffle=
|
| 133 |
drop_last=True,
|
|
|
|
| 134 |
)
|
| 135 |
|
| 136 |
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
Deformddpm = DeformDDPM(
|
| 139 |
-
network=
|
| 140 |
-
n_steps=hyp_parameters["timesteps"],
|
| 141 |
-
ndims=hyp_parameters["ndims"],
|
| 142 |
-
num_input_chn = hyp_parameters["num_input_chn"],
|
| 143 |
-
res = hyp_parameters['img_size']
|
| 144 |
-
),
|
| 145 |
n_steps=hyp_parameters["timesteps"],
|
| 146 |
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 147 |
device=hyp_parameters["device"],
|
|
@@ -161,9 +286,18 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 161 |
|
| 162 |
|
| 163 |
if use_distributed:
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
else:
|
| 168 |
Deformddpm.to(hyp_parameters["device"])
|
| 169 |
ddf_stn.to(hyp_parameters["device"])
|
|
@@ -172,12 +306,14 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 172 |
|
| 173 |
# mse = nn.MSELoss()
|
| 174 |
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 175 |
-
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"]
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 178 |
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 179 |
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 180 |
-
loss_imgsim = losses.
|
| 181 |
loss_imgmse = losses.LMSE()
|
| 182 |
|
| 183 |
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
|
@@ -194,19 +330,51 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 194 |
# check for existing models
|
| 195 |
if not os.path.exists(model_dir):
|
| 196 |
os.makedirs(model_dir, exist_ok=True)
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
if gpu_id == 0:
|
| 201 |
print(model_files)
|
| 202 |
-
|
|
|
|
| 203 |
else:
|
| 204 |
initial_epoch = 0
|
| 205 |
|
| 206 |
if gpu_id == 0:
|
| 207 |
print('len_train_data: ',len(dataset))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
# Training loop
|
| 209 |
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
epoch_loss_tot = 0.0
|
| 212 |
epoch_loss_gen_d = 0.0
|
|
@@ -216,17 +384,110 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 216 |
epoch_loss_imgsim = 0.0
|
| 217 |
epoch_loss_imgmse = 0.0
|
| 218 |
epoch_loss_ddfreg = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
# Set model inside to train model
|
| 220 |
Deformddpm.train()
|
| 221 |
-
|
| 222 |
-
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 223 |
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
# for step, batch in tqdm(enumerate(train_loader)):
|
| 226 |
-
|
| 227 |
# for step, batch in enumerate(train_loader_omni):
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
# ==========================================================================
|
| 232 |
# diffusion train on single image
|
|
@@ -235,12 +496,11 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 235 |
[x0,embd] = batch # for om dataset
|
| 236 |
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 237 |
# print('embd:', embd.shape)
|
|
|
|
| 238 |
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 239 |
-
|
| 240 |
else:
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
|
| 245 |
n = x0.size()[0] # batch_size -> n
|
| 246 |
x0 = x0.to(hyp_parameters["device"])
|
|
@@ -254,10 +514,10 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 254 |
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 255 |
else:
|
| 256 |
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 257 |
-
x0 = transformer(x0)
|
| 258 |
if hyp_parameters['noise_scale']>0:
|
| 259 |
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 260 |
-
x0 = thresh_img(x0, [0,
|
| 261 |
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 262 |
|
| 263 |
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
|
@@ -266,157 +526,301 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 266 |
) # pick up a seq of rand number from 0 to 'timestep'
|
| 267 |
|
| 268 |
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 269 |
-
proc_type = random.choice(['adding', '
|
| 270 |
# print('proc_type:', proc_type)
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd) # forward diffusion process
|
| 274 |
-
|
| 275 |
-
loss_tot=0
|
| 276 |
-
|
| 277 |
-
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 278 |
-
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 279 |
-
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 280 |
-
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 281 |
-
|
| 282 |
-
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 283 |
-
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 284 |
-
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 285 |
-
|
| 286 |
-
# >> JZ: print nan in x0
|
| 287 |
-
if torch.isnan(x0).any():
|
| 288 |
-
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 289 |
-
# >> JZ: print loss of ddf
|
| 290 |
-
if loss_ddf>0.001:
|
| 291 |
-
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 292 |
-
# yu: check if loss_tot==nan or inf
|
| 293 |
-
if torch.isnan(loss_tot) or torch.isinf(loss_tot):
|
| 294 |
-
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 295 |
-
loss_nan_step += 1
|
| 296 |
-
continue
|
| 297 |
-
if loss_nan_step > 5:
|
| 298 |
-
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 299 |
-
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 300 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
|
|
|
| 305 |
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
|
| 311 |
-
# print(loss_gen_a.item())
|
| 312 |
-
# if 0:
|
| 313 |
-
# if loss_gen_a.item() < -0.3 and step%train_mode_ratio == 0:
|
| 314 |
-
if step%train_mode_ratio == 0:
|
| 315 |
# ==========================================================================
|
| 316 |
-
#
|
| 317 |
-
#
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 320 |
-
# embd_x = embd_x.to(hyp_parameters["device"]).type(torch.float32)
|
| 321 |
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 322 |
else:
|
| 323 |
-
# embd_x = None
|
| 324 |
embd_y = None
|
| 325 |
|
| 326 |
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 327 |
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 328 |
n = x1.size()[0] # batch_size -> n
|
| 329 |
-
# random deformation + rotation
|
| 330 |
-
# if hyp_parameters["ndims"]>2:
|
| 331 |
-
# if np.random.uniform(0,1)<0.6:
|
| 332 |
-
# x1 = utils.random_resample(x1, deform_scale=0)
|
| 333 |
-
# y1 = utils.random_resample(y1, deform_scale=0)
|
| 334 |
-
x1 = transformer(x1)
|
| 335 |
-
y1 = transformer(y1)
|
| 336 |
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 337 |
if hyp_parameters['noise_scale']>0:
|
| 338 |
-
x1 = thresh_img(x1, [0, 2*hyp_parameters['noise_scale']])
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
) #
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
T_regist = [[t for _ in range(hyp_parameters["batchsize"]//2)] for t in T_regist]
|
| 357 |
-
|
| 358 |
-
# print('T_regist:', T_regist)
|
| 359 |
-
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'none'])
|
| 360 |
-
proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'none', 'none'])
|
| 361 |
-
# proc_type = random.choice(['project'])
|
| 362 |
-
y1, msk_tgt, cond_ratio = Deformddpm.module.proc_cond_img(y1,proc_type=proc_type)
|
| 363 |
-
msk_tgt = msk_tgt + MSK_EPS
|
| 364 |
-
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 365 |
-
loss_ddf1 = loss_reg1(ddf_comp,img=y1,msk=msk_tgt) # calculate loss for the registration process
|
| 366 |
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 367 |
-
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>0.0)) # calculate loss for the registration process
|
|
|
|
| 368 |
|
| 369 |
loss_regist = 0
|
| 370 |
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 371 |
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 372 |
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 373 |
-
# print('proc_type:', proc_type, 'cond_ratio:', cond_ratio.item())
|
| 374 |
-
# print('loss_regist:', loss_regist.item(), 'loss_sim:', loss_sim.item(), 'loss_ddf1:', loss_ddf1.item())
|
| 375 |
|
| 376 |
# >> JZ: print nan in x0
|
| 377 |
if torch.isnan(x0).any():
|
| 378 |
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 379 |
# >> JZ: print loss of ddf
|
| 380 |
-
if loss_ddf1>0.
|
| 381 |
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 382 |
-
|
| 383 |
-
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 384 |
-
optimizer.zero_grad()
|
| 385 |
-
loss_regist.backward()
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
# else:
|
| 392 |
-
# print(f"Gradient for {name}: None")
|
| 393 |
-
|
| 394 |
-
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
|
| 395 |
-
optimizer.step()
|
| 396 |
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
epoch_loss_imgmse += loss_mse.item() * len(x0) / len(train_loader.dataset)
|
| 400 |
-
epoch_loss_ddfreg += loss_ddf1.item() * len(x0) / len(train_loader.dataset)
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
-
if
|
| 412 |
-
|
| 413 |
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 414 |
-
print(f'
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
-
# # LR schedular step ----- YHM
|
| 417 |
-
# scheduler.step()
|
| 418 |
|
| 419 |
-
if 0 == epoch % epoch_per_save:
|
| 420 |
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 421 |
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 422 |
# break # FOR TESTING
|
|
@@ -436,55 +840,150 @@ def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
|
| 436 |
'optimizer_state_dict': optimizer.state_dict(),
|
| 437 |
'epoch': epoch
|
| 438 |
}, save_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
|
| 440 |
# Resource cleanup at the end of training
|
| 441 |
-
|
| 442 |
gc.collect()
|
| 443 |
if use_distributed and dist.is_initialized():
|
| 444 |
dist.destroy_process_group()
|
| 445 |
|
| 446 |
-
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True):
|
| 447 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 448 |
if gpu_id == 0:
|
| 449 |
-
# if 0:
|
| 450 |
utils.print_memory_usage("Before Loading Model")
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
checkpoint =
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 460 |
else:
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 464 |
|
| 465 |
if use_distributed:
|
| 466 |
-
# Broadcast model weights from rank 0 to
|
| 467 |
dist.barrier()
|
| 468 |
for param in Deformddpm.parameters():
|
| 469 |
-
dist.broadcast(param.data, src=0)
|
| 470 |
-
dist.barrier()
|
| 471 |
-
for param_group in optimizer.param_groups:
|
| 472 |
-
for param in param_group['params']:
|
| 473 |
-
if param.grad is not None:
|
| 474 |
-
dist.broadcast(param.grad, src=0) # Sync optimizer gradients
|
| 475 |
|
| 476 |
-
#
|
| 477 |
-
|
| 478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
|
| 480 |
return initial_epoch, Deformddpm, optimizer
|
| 481 |
|
| 482 |
|
| 483 |
|
| 484 |
if __name__ == "__main__":
|
| 485 |
-
if
|
| 486 |
-
|
| 487 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 489 |
else:
|
| 490 |
main_train(0,1)
|
|
|
|
| 1 |
+
import os, sys, contextlib
|
| 2 |
+
|
| 3 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
sys.path.append(ROOT_DIR)
|
| 5 |
+
|
| 6 |
import gc
|
| 7 |
import torch
|
| 8 |
import torchvision
|
|
|
|
| 13 |
from torch.optim import Adam, SGD
|
| 14 |
from Diffusion.diffuser import DeformDDPM
|
| 15 |
from Diffusion.networks import get_net, STN
|
| 16 |
+
# from torchvision.transforms import Lambda
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
import Diffusion.losses as losses
|
| 19 |
import random
|
| 20 |
import glob
|
| 21 |
import numpy as np
|
| 22 |
import utils
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
|
| 25 |
+
# from Dataloader.dataloader0 import get_dataloader
|
| 26 |
from Dataloader.dataLoader import *
|
| 27 |
|
| 28 |
from Dataloader.dataloader_utils import thresh_img
|
| 29 |
import yaml
|
| 30 |
import argparse
|
| 31 |
|
| 32 |
+
# XPU support: import Intel Extension for PyTorch and oneCCL bindings if available
|
| 33 |
+
try:
|
| 34 |
+
import intel_extension_for_pytorch as ipex
|
| 35 |
+
except ImportError:
|
| 36 |
+
ipex = None
|
| 37 |
+
try:
|
| 38 |
+
import oneccl_bindings_for_pytorch
|
| 39 |
+
except (ImportError, Exception) as e:
|
| 40 |
+
print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
|
| 41 |
+
|
| 42 |
####################
|
| 43 |
import torch.multiprocessing as mp
|
| 44 |
from torch.utils.data.distributed import DistributedSampler
|
|
|
|
| 46 |
import torch.distributed as dist
|
| 47 |
# from torch.distributed import init_process_group
|
| 48 |
###############
|
| 49 |
+
def _device_available(device_type):
|
| 50 |
+
if device_type == 'xpu':
|
| 51 |
+
return hasattr(torch, 'xpu') and torch.xpu.is_available()
|
| 52 |
+
return torch.cuda.is_available()
|
| 53 |
+
|
| 54 |
+
def _device_count(device_type):
|
| 55 |
+
if device_type == 'xpu':
|
| 56 |
+
return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
|
| 57 |
+
return torch.cuda.device_count()
|
| 58 |
+
|
| 59 |
+
def _set_device(rank, device_type):
|
| 60 |
+
if device_type == 'xpu':
|
| 61 |
+
torch.xpu.set_device(rank)
|
| 62 |
+
else:
|
| 63 |
+
torch.cuda.set_device(rank)
|
| 64 |
+
|
| 65 |
+
def _empty_cache(device_type):
|
| 66 |
+
if device_type == 'xpu' and hasattr(torch, 'xpu'):
|
| 67 |
+
torch.xpu.empty_cache()
|
| 68 |
+
elif torch.cuda.is_available():
|
| 69 |
+
torch.cuda.empty_cache()
|
| 70 |
+
|
| 71 |
def ddp_setup(rank, world_size):
|
| 72 |
"""
|
| 73 |
Args:
|
| 74 |
+
rank: Unique identifier of each process (local_rank when launched by torchrun)
|
| 75 |
world_size: Total number of processes
|
| 76 |
"""
|
| 77 |
+
backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
|
| 78 |
+
if "LOCAL_RANK" in os.environ:
|
| 79 |
+
# Launched by torchrun: MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE already set
|
| 80 |
+
dist.init_process_group(backend=backend)
|
| 81 |
+
_set_device(int(os.environ["LOCAL_RANK"]), DEVICE_TYPE)
|
| 82 |
+
else:
|
| 83 |
+
# Single-node mp.spawn
|
| 84 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 85 |
+
os.environ["MASTER_PORT"] = "12355"
|
| 86 |
+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
| 87 |
+
_set_device(rank, DEVICE_TYPE)
|
| 88 |
|
| 89 |
EPS = 1e-5
|
| 90 |
MSK_EPS = 0.01
|
| 91 |
+
TEXT_EMBED_PROB = 0.5
|
| 92 |
+
AUG_RESAMPLE_PROB = 0.5
|
| 93 |
+
LOSS_WEIGHTS_DIFF = [4.0, 2.0, 8.0] # [ang, dist, reg]
|
| 94 |
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 95 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2] # [imgsim, imgmse, ddf]
|
| 96 |
+
DIFF_REG_BATCH_RATIO = 2
|
| 97 |
+
# LOSS_WEIGHT_CONTRASTIVE = 1e-4
|
| 98 |
+
LOSS_WEIGHT_CONTRASTIVE = 1e-1
|
| 99 |
+
REGISTRATION_STEP_RATIO = 1
|
| 100 |
+
CONTRASTIVE_STEP_RATIO = 1
|
| 101 |
+
ACCEPT_THRESH_CONTRASTIVE = 0.1
|
| 102 |
+
ACCEPT_THRESH_ANGLE = -0.8
|
| 103 |
+
MID_EPOCH_SAVE_STEPS = 1e4 # Save mid-epoch checkpoint every N steps for crash recovery.
|
| 104 |
+
# XPU autograd leaks ~1.0 GiB/step of device memory (Intel bug).
|
| 105 |
+
# With gradient checkpointing, training survives ~26 steps from fresh start,
|
| 106 |
+
# but fewer when carrying leaked memory from previous epoch.
|
| 107 |
+
# Save every 10 steps to minimize lost work on OOM crash.
|
| 108 |
+
EXIT_CODE_RESTART = 42 # Exit code signaling proactive restart (not a crash).
|
| 109 |
|
| 110 |
# AUG_PERMUTE_PROB = 0.35
|
| 111 |
|
|
|
|
| 122 |
default="Config/config_all.yaml",
|
| 123 |
required=False,
|
| 124 |
)
|
| 125 |
+
parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
|
| 126 |
+
parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
|
| 127 |
+
parser.add_argument("--max-steps-before-restart", type=int, default=0,
|
| 128 |
+
help="Proactive restart: exit after N training steps to reset XPU memory leak. "
|
| 129 |
+
"0=disabled (rely on OOM crash + auto-resubmit). "
|
| 130 |
+
"Recommended: 20 for XPU (survives ~26 steps max).")
|
| 131 |
+
parser.add_argument("--no-save", action="store_true", default=False,
|
| 132 |
+
help="Disable all checkpoint saving (for diagnostic/validation runs)")
|
| 133 |
+
parser.add_argument("--reset-optimizer", action="store_true",
|
| 134 |
+
help="Skip optimizer state loading from checkpoint (use when architecture changed)")
|
| 135 |
+
parser.add_argument("--eval-only", action="store_true",
|
| 136 |
+
help="Forward pass only: compute and print losses without backward/optimizer (no memory leak)")
|
| 137 |
args = parser.parse_args()
|
| 138 |
+
|
| 139 |
+
# Read config early to determine device type for DDP setup
|
| 140 |
+
with open(args.config, 'r') as _f:
|
| 141 |
+
_cfg = yaml.safe_load(_f)
|
| 142 |
+
DEVICE_TYPE = _cfg.get('device', 'cuda') # 'cuda' or 'xpu'
|
| 143 |
+
|
| 144 |
+
# Auto-detect: use DDP only when multiple devices are available
|
| 145 |
+
use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1
|
| 146 |
+
# use_distributed = True
|
| 147 |
+
# use_distributed = False
|
| 148 |
#=======================================================================================================================
|
| 149 |
|
| 150 |
+
class _DummyIndiv(torch.utils.data.Dataset):
|
| 151 |
+
def __init__(self, n, sz, embd_dim=1024):
|
| 152 |
+
self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| 153 |
+
def __len__(self): return self.n
|
| 154 |
+
def __getitem__(self, i):
|
| 155 |
+
return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
|
| 156 |
+
|
| 157 |
+
class _DummyPair(torch.utils.data.Dataset):
|
| 158 |
+
def __init__(self, n, sz, embd_dim=1024):
|
| 159 |
+
self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| 160 |
+
def __len__(self): return self.n
|
| 161 |
+
def __getitem__(self, i):
|
| 162 |
+
return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| 163 |
+
np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| 164 |
+
np.random.randn(self.embd_dim).astype(np.float32),
|
| 165 |
+
np.random.randn(self.embd_dim).astype(np.float32))
|
| 166 |
|
| 167 |
|
| 168 |
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 169 |
if use_distributed:
|
| 170 |
ddp_setup(rank,world_size)
|
| 171 |
|
| 172 |
+
if torch.distributed.is_initialized() and rank == 0:
|
| 173 |
print(f"World size: {torch.distributed.get_world_size()}")
|
| 174 |
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 175 |
+
print(f"PYTORCH_ALLOC_CONF: {os.environ.get('PYTORCH_ALLOC_CONF', 'not set')}")
|
| 176 |
+
if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 177 |
+
props = torch.xpu.get_device_properties(0)
|
| 178 |
+
print(f"XPU device: {props.name}, total memory: {props.total_memory / 1024**3:.2f} GiB")
|
| 179 |
+
# gpu_id = global rank (for save/print guards); rank = local device index
|
| 180 |
+
if "RANK" in os.environ:
|
| 181 |
+
gpu_id = int(os.environ["RANK"])
|
| 182 |
+
rank = int(os.environ["LOCAL_RANK"])
|
| 183 |
+
else:
|
| 184 |
+
gpu_id = rank
|
| 185 |
|
| 186 |
# Load the YAML file into a dictionary
|
| 187 |
with open(args.config, 'r') as file:
|
| 188 |
hyp_parameters = yaml.safe_load(file)
|
| 189 |
+
if args.batchsize > 0:
|
| 190 |
+
hyp_parameters['batchsize'] = args.batchsize
|
| 191 |
+
if gpu_id == 0:
|
| 192 |
print(hyp_parameters)
|
| 193 |
|
| 194 |
# epoch_per_save=10
|
|
|
|
| 202 |
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 203 |
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 204 |
model_dir=model_save_path
|
| 205 |
+
# transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 206 |
|
| 207 |
# Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
|
| 208 |
|
|
|
|
| 219 |
# drop_last=True,
|
| 220 |
# )
|
| 221 |
|
| 222 |
+
if args.dummy_samples > 0:
|
| 223 |
+
dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
|
| 224 |
+
datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
|
| 225 |
+
else:
|
| 226 |
+
# dataset = OminiDataset_v1(transform=None)
|
| 227 |
+
dataset = OMDataset_indiv(transform=None)
|
| 228 |
+
# datasetp = OminiDataset_paired(transform=None)
|
| 229 |
+
datasetp = OMDataset_pair(transform=None)
|
| 230 |
+
|
| 231 |
+
if use_distributed:
|
| 232 |
+
sampler = DistributedSampler(dataset, shuffle=True)
|
| 233 |
+
sampler_p = DistributedSampler(datasetp, shuffle=True)
|
| 234 |
+
else:
|
| 235 |
+
sampler = None
|
| 236 |
+
sampler_p = None
|
| 237 |
+
|
| 238 |
train_loader = DataLoader(
|
| 239 |
dataset,
|
| 240 |
batch_size=hyp_parameters['batchsize'],
|
| 241 |
+
shuffle=(sampler is None),
|
| 242 |
drop_last=True,
|
| 243 |
+
sampler=sampler,
|
| 244 |
)
|
|
|
|
|
|
|
|
|
|
| 245 |
train_loader_p = DataLoader(
|
| 246 |
datasetp,
|
| 247 |
+
batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
|
| 248 |
+
shuffle=(sampler_p is None),
|
| 249 |
drop_last=True,
|
| 250 |
+
sampler=sampler_p,
|
| 251 |
)
|
| 252 |
|
| 253 |
|
| 254 |
|
| 255 |
+
network = Net(
|
| 256 |
+
n_steps=hyp_parameters["timesteps"],
|
| 257 |
+
ndims=hyp_parameters["ndims"],
|
| 258 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 259 |
+
res = hyp_parameters['img_size']
|
| 260 |
+
)
|
| 261 |
+
# Enable gradient checkpointing on XPU to reduce peak activation memory.
|
| 262 |
+
# XPU autograd leaks ~1.0 GiB/step; lower peak buys more steps before OOM.
|
| 263 |
+
if DEVICE_TYPE == 'xpu' and hasattr(network, 'use_checkpoint'):
|
| 264 |
+
network.use_checkpoint = True
|
| 265 |
+
if gpu_id == 0:
|
| 266 |
+
print(" [init] Gradient checkpointing enabled for XPU", flush=True)
|
| 267 |
+
|
| 268 |
Deformddpm = DeformDDPM(
|
| 269 |
+
network=network,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
n_steps=hyp_parameters["timesteps"],
|
| 271 |
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 272 |
device=hyp_parameters["device"],
|
|
|
|
| 286 |
|
| 287 |
|
| 288 |
if use_distributed:
|
| 289 |
+
device = f"{DEVICE_TYPE}:{rank}"
|
| 290 |
+
# NO pre-allocation. CCL/oneDNN accumulate ~1.4 GiB/step of device memory outside
|
| 291 |
+
# PyTorch's caching allocator. Pre-allocating steals from that budget:
|
| 292 |
+
# 92% pre-alloc → crash at step 3, 78% → step 10, none (70% cap) → step 14.
|
| 293 |
+
# Instead, use empty_cache() between training phases to release unused cached memory
|
| 294 |
+
# back to the device for CCL/oneDNN.
|
| 295 |
+
if gpu_id == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 296 |
+
total_mem = torch.xpu.get_device_properties(rank).total_memory
|
| 297 |
+
print(f" [init] XPU device memory: {total_mem/1024**3:.1f} GiB, no pre-allocation (relying on empty_cache between phases)", flush=True)
|
| 298 |
+
Deformddpm.to(device)
|
| 299 |
+
Deformddpm = DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)
|
| 300 |
+
ddf_stn.to(device)
|
| 301 |
else:
|
| 302 |
Deformddpm.to(hyp_parameters["device"])
|
| 303 |
ddf_stn.to(hyp_parameters["device"])
|
|
|
|
| 306 |
|
| 307 |
# mse = nn.MSELoss()
|
| 308 |
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 309 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 310 |
+
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 311 |
+
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 312 |
+
|
| 313 |
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 314 |
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 315 |
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 316 |
+
loss_imgsim = losses.MSLNCC()
|
| 317 |
loss_imgmse = losses.LMSE()
|
| 318 |
|
| 319 |
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
|
|
|
| 330 |
# check for existing models
|
| 331 |
if not os.path.exists(model_dir):
|
| 332 |
os.makedirs(model_dir, exist_ok=True)
|
| 333 |
+
# Check for checkpoints: first check tmp/ for mid-epoch, then main dir for epoch-level
|
| 334 |
+
tmp_dir = os.path.join(model_dir, "tmp")
|
| 335 |
+
tmp_files = sorted(glob.glob(os.path.join(tmp_dir, "*.pth")))
|
| 336 |
+
model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
|
| 337 |
+
initial_step = 0
|
| 338 |
+
|
| 339 |
+
# Epoch stats and RNG states to restore when resuming from mid-epoch checkpoint
|
| 340 |
+
_resume_epoch_stats = None
|
| 341 |
+
_resume_rng = None
|
| 342 |
+
|
| 343 |
+
if tmp_files and not args.eval_only and args.max_steps_before_restart > 0:
|
| 344 |
+
# Mid-epoch checkpoint: only use when proactive restart is enabled
|
| 345 |
+
latest = tmp_files[-1]
|
| 346 |
+
if gpu_id == 0:
|
| 347 |
+
print(f" [resume] Found mid-epoch checkpoint: {latest}")
|
| 348 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
|
| 349 |
+
basename = os.path.basename(latest)
|
| 350 |
+
initial_step = int(basename.split('_step')[1].split('_')[0].split('.')[0])
|
| 351 |
+
_ckpt = torch.load(latest, map_location='cpu', weights_only=False)
|
| 352 |
+
_resume_epoch_stats = _ckpt.get('epoch_stats', None)
|
| 353 |
+
del _ckpt
|
| 354 |
+
if gpu_id == 0:
|
| 355 |
+
print(f" [resume] Resuming epoch {initial_epoch} from step {initial_step}"
|
| 356 |
+
f"{' (with epoch_stats)' if _resume_epoch_stats else ''}", flush=True)
|
| 357 |
+
elif model_files:
|
| 358 |
if gpu_id == 0:
|
| 359 |
print(model_files)
|
| 360 |
+
latest = model_files[-1]
|
| 361 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
|
| 362 |
else:
|
| 363 |
initial_epoch = 0
|
| 364 |
|
| 365 |
if gpu_id == 0:
|
| 366 |
print('len_train_data: ',len(dataset))
|
| 367 |
+
|
| 368 |
+
# Proactive restart: track steps since process start to exit before OOM.
|
| 369 |
+
max_steps_restart = args.max_steps_before_restart
|
| 370 |
+
steps_since_start = 0
|
| 371 |
+
loss_contra_gate = 0.0
|
| 372 |
+
|
| 373 |
# Training loop
|
| 374 |
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
| 375 |
+
if use_distributed and sampler is not None:
|
| 376 |
+
sampler.set_epoch(epoch)
|
| 377 |
+
sampler_p.set_epoch(epoch)
|
| 378 |
|
| 379 |
epoch_loss_tot = 0.0
|
| 380 |
epoch_loss_gen_d = 0.0
|
|
|
|
| 384 |
epoch_loss_imgsim = 0.0
|
| 385 |
epoch_loss_imgmse = 0.0
|
| 386 |
epoch_loss_ddfreg = 0.0
|
| 387 |
+
epoch_loss_contrastive = 0.0
|
| 388 |
+
total_contra = 0
|
| 389 |
+
total_reg_restored = None
|
| 390 |
+
total_contra_restored = None
|
| 391 |
+
|
| 392 |
+
# Restore epoch accumulators from mid-epoch checkpoint (only for the resumed epoch)
|
| 393 |
+
if _resume_epoch_stats is not None and epoch == initial_epoch:
|
| 394 |
+
epoch_loss_tot = _resume_epoch_stats.get('epoch_loss_tot', 0.0)
|
| 395 |
+
epoch_loss_gen_d = _resume_epoch_stats.get('epoch_loss_gen_d', 0.0)
|
| 396 |
+
epoch_loss_gen_a = _resume_epoch_stats.get('epoch_loss_gen_a', 0.0)
|
| 397 |
+
epoch_loss_reg = _resume_epoch_stats.get('epoch_loss_reg', 0.0)
|
| 398 |
+
epoch_loss_regist = _resume_epoch_stats.get('epoch_loss_regist', 0.0)
|
| 399 |
+
epoch_loss_imgsim = _resume_epoch_stats.get('epoch_loss_imgsim', 0.0)
|
| 400 |
+
epoch_loss_imgmse = _resume_epoch_stats.get('epoch_loss_imgmse', 0.0)
|
| 401 |
+
epoch_loss_ddfreg = _resume_epoch_stats.get('epoch_loss_ddfreg', 0.0)
|
| 402 |
+
epoch_loss_contrastive = _resume_epoch_stats.get('epoch_loss_contrastive', 0.0)
|
| 403 |
+
total_reg_restored = _resume_epoch_stats.get('total_reg', None)
|
| 404 |
+
total_contra_restored = _resume_epoch_stats.get('total_contra', None)
|
| 405 |
+
loss_nan_step = _resume_epoch_stats.get('loss_nan_step', 0)
|
| 406 |
+
# RNG states are restored INSIDE the skip loop (at the last skipped step)
|
| 407 |
+
# to avoid DataLoader __getitem__ calls corrupting the restored state.
|
| 408 |
+
_resume_rng = {k: _resume_epoch_stats[k] for k in
|
| 409 |
+
('rng_torch', 'rng_numpy', 'rng_python', 'rng_xpu', 'rng_cuda')
|
| 410 |
+
if k in _resume_epoch_stats}
|
| 411 |
+
if gpu_id == 0:
|
| 412 |
+
print(f" [resume] Restored epoch stats from checkpoint (loss_tot={epoch_loss_tot:.4f})", flush=True)
|
| 413 |
+
_resume_epoch_stats = None # Only restore once
|
| 414 |
+
else:
|
| 415 |
+
loss_nan_step = 0 # only reset when NOT resuming mid-epoch
|
| 416 |
+
|
| 417 |
# Set model inside to train model
|
| 418 |
Deformddpm.train()
|
|
|
|
|
|
|
| 419 |
|
| 420 |
+
total = min(len(train_loader), len(train_loader_p))
|
| 421 |
+
total_reg = total // REGISTRATION_STEP_RATIO
|
| 422 |
+
# Restore total_reg and total_contra from checkpoint if available (mid-epoch resume)
|
| 423 |
+
if total_reg_restored is not None:
|
| 424 |
+
total_reg = total_reg_restored
|
| 425 |
+
total_reg_restored = None
|
| 426 |
+
if total_contra_restored is not None:
|
| 427 |
+
total_contra = total_contra_restored
|
| 428 |
+
total_contra_restored = None
|
| 429 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 430 |
# for step, batch in tqdm(enumerate(train_loader)):
|
|
|
|
| 431 |
# for step, batch in enumerate(train_loader_omni):
|
| 432 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 433 |
+
|
| 434 |
+
# Skip steps already completed (mid-epoch resume).
|
| 435 |
+
# Checkpoint at step N is saved AFTER step N's training completes,
|
| 436 |
+
# so step N itself must also be skipped (use <=, not <).
|
| 437 |
+
if epoch == initial_epoch and initial_step > 0 and step <= initial_step:
|
| 438 |
+
# Restore RNG at the last skipped step, AFTER DataLoader __getitem__
|
| 439 |
+
# has consumed RNG for all skipped batches. This way the first
|
| 440 |
+
# non-skipped step starts with exactly the saved RNG state.
|
| 441 |
+
if step == initial_step and _resume_rng is not None:
|
| 442 |
+
# Restore rank 0's RNG as base state, then re-seed per-rank
|
| 443 |
+
# so each rank has independent RNG (matching continuous run's
|
| 444 |
+
# divergent-per-rank behavior). Without this, all ranks would
|
| 445 |
+
# share rank 0's RNG → correlated augmentation/dropout decisions.
|
| 446 |
+
if 'rng_torch' in _resume_rng:
|
| 447 |
+
torch.set_rng_state(_resume_rng['rng_torch'])
|
| 448 |
+
if 'rng_numpy' in _resume_rng:
|
| 449 |
+
np.random.set_state(_resume_rng['rng_numpy'])
|
| 450 |
+
if 'rng_python' in _resume_rng:
|
| 451 |
+
random.setstate(_resume_rng['rng_python'])
|
| 452 |
+
if 'rng_xpu' in _resume_rng and DEVICE_TYPE == 'xpu':
|
| 453 |
+
torch.xpu.set_rng_state(_resume_rng['rng_xpu'])
|
| 454 |
+
elif 'rng_cuda' in _resume_rng and torch.cuda.is_available():
|
| 455 |
+
torch.cuda.set_rng_state(_resume_rng['rng_cuda'])
|
| 456 |
+
# Per-rank re-seed: checkpoint only has rank 0's RNG state.
|
| 457 |
+
# Advance each rank's RNG by a deterministic offset so they
|
| 458 |
+
# diverge (as they would in a continuous run).
|
| 459 |
+
if gpu_id > 0:
|
| 460 |
+
rank_seed = gpu_id * 100003 + initial_step * 31
|
| 461 |
+
torch.manual_seed(torch.initial_seed() + rank_seed)
|
| 462 |
+
np.random.seed((np.random.get_state()[1][0] + rank_seed) % (2**31))
|
| 463 |
+
random.seed(random.getrandbits(32) + rank_seed)
|
| 464 |
+
if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 465 |
+
torch.xpu.manual_seed(torch.initial_seed() + rank_seed)
|
| 466 |
+
elif torch.cuda.is_available():
|
| 467 |
+
torch.cuda.manual_seed(torch.initial_seed() + rank_seed)
|
| 468 |
+
_resume_rng = None
|
| 469 |
+
if gpu_id == 0:
|
| 470 |
+
print(f" [resume] RNG states restored at step {step} (per-rank re-seeded)", flush=True)
|
| 471 |
+
continue
|
| 472 |
+
|
| 473 |
+
# Free registration tensors from previous step
|
| 474 |
+
x1 = y1 = ddf_comp = img_rec = img_diff = None
|
| 475 |
+
ddf_rand = y1_proc = msk_tgt = img_save = None
|
| 476 |
+
loss_regist = loss_sim = loss_mse = loss_ddf1 = None
|
| 477 |
+
|
| 478 |
+
# Memory diagnostic (one per node via local rank 0) — only warn when abnormal
|
| 479 |
+
# Normal at step start: ~16 GiB reserved, ~48 GiB free (of 64 GiB total)
|
| 480 |
+
if rank == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
|
| 481 |
+
torch.xpu.reset_peak_memory_stats(rank)
|
| 482 |
+
free_mem, total_mem_dev = torch.xpu.mem_get_info(rank)
|
| 483 |
+
used_gib = (total_mem_dev - free_mem) / 1024**3
|
| 484 |
+
if used_gib > 24: # Normal is ~16 GiB at step start; warn if accumulating
|
| 485 |
+
alloc = torch.xpu.memory_allocated() / 1024**3
|
| 486 |
+
reserved = torch.xpu.memory_reserved() / 1024**3
|
| 487 |
+
free_gib = free_mem / 1024**3
|
| 488 |
+
print(f" [mem WARNING] gpu_id={gpu_id} epoch {epoch} step {step}: "
|
| 489 |
+
f"{used_gib:.1f} GiB used ({alloc:.1f} alloc / {reserved:.1f} reserved), "
|
| 490 |
+
f"{free_gib:.1f} GiB free", flush=True)
|
| 491 |
|
| 492 |
# ==========================================================================
|
| 493 |
# diffusion train on single image
|
|
|
|
| 496 |
[x0,embd] = batch # for om dataset
|
| 497 |
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 498 |
# print('embd:', embd.shape)
|
| 499 |
+
embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
|
| 500 |
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 501 |
+
embd_in = embd_dev
|
| 502 |
else:
|
| 503 |
+
embd_in = None
|
|
|
|
|
|
|
| 504 |
|
| 505 |
n = x0.size()[0] # batch_size -> n
|
| 506 |
x0 = x0.to(hyp_parameters["device"])
|
|
|
|
| 514 |
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 515 |
else:
|
| 516 |
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 517 |
+
# x0 = transformer(x0)
|
| 518 |
if hyp_parameters['noise_scale']>0:
|
| 519 |
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 520 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 521 |
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 522 |
|
| 523 |
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
|
|
|
| 526 |
) # pick up a seq of rand number from 0 to 'timestep'
|
| 527 |
|
| 528 |
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 529 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 530 |
# print('proc_type:', proc_type)
|
| 531 |
+
ddpm = Deformddpm.module if use_distributed else Deformddpm
|
| 532 |
+
cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 533 |
|
| 534 |
+
if loss_contra_gate < ACCEPT_THRESH_CONTRASTIVE:
|
| 535 |
+
|
| 536 |
+
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
|
| 537 |
+
|
| 538 |
+
loss_tot=0
|
| 539 |
|
| 540 |
+
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 541 |
+
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 542 |
+
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 543 |
+
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 544 |
|
| 545 |
+
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 546 |
+
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 547 |
+
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 548 |
+
|
| 549 |
+
# >> JZ: print nan in x0
|
| 550 |
+
if torch.isnan(x0).any():
|
| 551 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 552 |
+
# >> JZ: print loss of ddf
|
| 553 |
+
if loss_ddf>0.001:
|
| 554 |
+
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 555 |
+
# yu: check if loss_tot==nan or inf
|
| 556 |
+
# Synchronize NaN skip across all DDP ranks to avoid collective desync
|
| 557 |
+
# Use broadcast from rank 0 instead of all_reduce to avoid CCL hang on single-node XPU
|
| 558 |
+
is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot)
|
| 559 |
+
if use_distributed:
|
| 560 |
+
nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
|
| 561 |
+
dist.broadcast(nan_flag, src=0)
|
| 562 |
+
is_nan = nan_flag.item() > 0
|
| 563 |
+
if is_nan:
|
| 564 |
+
if gpu_id == 0:
|
| 565 |
+
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 566 |
+
loss_nan_step += 1
|
| 567 |
+
continue
|
| 568 |
+
if loss_nan_step > 5:
|
| 569 |
+
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 570 |
+
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 571 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
# ==========================================================================
|
| 573 |
+
# Diffusion backward (no gradient clipping — diffusion dominates training)
|
| 574 |
+
# print(loss_contra_gate)
|
| 575 |
+
if (not args.eval_only): # Skip backward when contrastive loss is high to avoid destabilizing diffusion training (especially early on)
|
| 576 |
+
optimizer.zero_grad()
|
| 577 |
+
loss_tot.backward()
|
| 578 |
+
optimizer.step()
|
| 579 |
+
|
| 580 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 581 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 582 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 583 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 584 |
+
|
| 585 |
+
# Print running average every 20 steps in eval-only mode
|
| 586 |
+
if args.eval_only and gpu_id == 0 and (step + 1) % 20 == 0:
|
| 587 |
+
n = step + 1
|
| 588 |
+
print(f" [eval] step {step}: running_avg ang={epoch_loss_gen_a*total/n:.4f} "
|
| 589 |
+
f"dist={epoch_loss_gen_d*total/n:.4f} regul={epoch_loss_reg*total/n:.6f}", flush=True)
|
| 590 |
+
|
| 591 |
+
# Free diffusion intermediates and aggressively release all memory to device.
|
| 592 |
+
# XPU runtime leaks ~1.3 GiB/step outside the caching allocator.
|
| 593 |
+
# gc.collect() + synchronize() + empty_cache() attempts to reclaim deferred/lazy allocations.
|
| 594 |
+
loss_gen_a_val = loss_gen_a.item()
|
| 595 |
+
|
| 596 |
+
# del pre_dvf_I, dvf_I, trm_pred, loss_tot, loss_gen_a, loss_gen_d, loss_ddf
|
| 597 |
+
gc.collect()
|
| 598 |
+
if DEVICE_TYPE == 'xpu':
|
| 599 |
+
torch.xpu.synchronize()
|
| 600 |
+
_empty_cache(DEVICE_TYPE)
|
| 601 |
+
|
| 602 |
+
# Sync loss_gen_a across DDP ranks for contrastive and registration gating
|
| 603 |
+
if use_distributed:
|
| 604 |
+
loss_gen_a_sync = torch.tensor([loss_gen_a_val], device=f"{DEVICE_TYPE}:{rank}")
|
| 605 |
+
dist.broadcast(loss_gen_a_sync, src=0)
|
| 606 |
+
loss_gen_a_gate = loss_gen_a_sync.item()
|
| 607 |
+
else:
|
| 608 |
+
loss_gen_a_gate = loss_gen_a_val
|
| 609 |
+
|
| 610 |
+
LOSS_WEIGHT_CONTRASTIVE=1e-4
|
| 611 |
+
else:
|
| 612 |
+
LOSS_WEIGHT_CONTRASTIVE=1e-1
|
| 613 |
+
if gpu_id == 0:
|
| 614 |
+
print(f" [train] step {step}: Skipping backward (contra_gate={loss_contra_gate:.4f})", flush=True)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
# ==========================================================================
|
| 618 |
+
# Contrastive train on single image (text-image alignment)
|
| 619 |
+
# Separate backward with gradient clipping to prevent destabilizing diffusion.
|
| 620 |
+
loss_contra_val = None
|
| 621 |
+
if step % CONTRASTIVE_STEP_RATIO == 0:
|
| 622 |
+
n_contra = x0.size()[0]
|
| 623 |
+
t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
|
| 624 |
+
# Route through DDP wrapper and return img_embd directly so DDP
|
| 625 |
+
# traces the correct subgraph (encoder + mid + attn + img2txt).
|
| 626 |
+
img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(), T=t_contra, output_embedding=True, text=None) # [B, 1024]
|
| 627 |
+
loss_contra_preweight = F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1)-0.25).mean()
|
| 628 |
+
loss_contra = LOSS_WEIGHT_CONTRASTIVE * loss_contra_preweight
|
| 629 |
+
|
| 630 |
+
if not args.eval_only:
|
| 631 |
+
optimizer.zero_grad()
|
| 632 |
+
loss_contra.backward()
|
| 633 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=LOSS_WEIGHT_CONTRASTIVE*1)
|
| 634 |
+
optimizer.step()
|
| 635 |
+
loss_contra_val = loss_contra.item()
|
| 636 |
+
epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
|
| 637 |
+
|
| 638 |
+
# else:
|
| 639 |
+
# if gpu_id == 0:
|
| 640 |
+
# print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")
|
| 641 |
+
|
| 642 |
+
# Free remaining intermediates and aggressively release memory before registration
|
| 643 |
+
if cond_img is not None:
|
| 644 |
+
del cond_img
|
| 645 |
+
if blind_mask is not None:
|
| 646 |
+
del blind_mask
|
| 647 |
+
gc.collect()
|
| 648 |
+
if DEVICE_TYPE == 'xpu':
|
| 649 |
+
torch.xpu.synchronize()
|
| 650 |
+
_empty_cache(DEVICE_TYPE)
|
| 651 |
+
|
| 652 |
+
# Sync loss_gen_a across DDP ranks for contrastive and registration gating
|
| 653 |
+
if use_distributed:
|
| 654 |
+
loss_contra_sync = torch.tensor([loss_contra_preweight], device=f"{DEVICE_TYPE}:{rank}")
|
| 655 |
+
dist.broadcast(loss_contra_sync, src=0)
|
| 656 |
+
loss_contra_gate = loss_contra_sync.item()
|
| 657 |
+
else:
|
| 658 |
+
loss_contra_gate = loss_contra_preweight
|
| 659 |
+
|
| 660 |
+
# ==========================================================================
|
| 661 |
+
# registration train on paired images
|
| 662 |
+
# loss_gen_a_gate already synced across DDP ranks above
|
| 663 |
+
do_regist = step % REGISTRATION_STEP_RATIO == 0 and (loss_contra_gate < ACCEPT_THRESH_CONTRASTIVE) and loss_gen_a_gate < ACCEPT_THRESH_ANGLE
|
| 664 |
+
if do_regist:
|
| 665 |
+
[x1, y1, _, embd_y] = batch_p
|
| 666 |
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
|
|
|
| 667 |
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 668 |
else:
|
|
|
|
| 669 |
embd_y = None
|
| 670 |
|
| 671 |
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 672 |
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 673 |
n = x1.size()[0] # batch_size -> n
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 674 |
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 675 |
if hyp_parameters['noise_scale']>0:
|
| 676 |
+
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
| 677 |
+
random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
|
| 678 |
+
random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 679 |
+
x1 = x1 * random_scale + random_shift
|
| 680 |
+
y1 = y1 * random_scale + random_shift
|
| 681 |
+
|
| 682 |
+
scale_regist = np.random.uniform(0.0,0.5)
|
| 683 |
+
select_timestep = np.random.randint(12, 32) # select a random number of timesteps to sample, between 8 and 16
|
| 684 |
+
T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
|
| 685 |
+
|
| 686 |
+
T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
|
| 687 |
+
|
| 688 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 689 |
+
ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
|
| 690 |
+
y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
|
| 691 |
+
msk_tgt = msk_tgt+MSK_EPS
|
| 692 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 694 |
+
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
|
| 695 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 696 |
|
| 697 |
loss_regist = 0
|
| 698 |
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 699 |
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 700 |
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
|
|
|
|
|
|
| 701 |
|
| 702 |
# >> JZ: print nan in x0
|
| 703 |
if torch.isnan(x0).any():
|
| 704 |
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 705 |
# >> JZ: print loss of ddf
|
| 706 |
+
if loss_ddf1>0.002:
|
| 707 |
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
|
| 709 |
+
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 710 |
+
if not args.eval_only:
|
| 711 |
+
optimizer.zero_grad()
|
| 712 |
+
loss_regist.backward()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
|
| 714 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
|
| 715 |
+
optimizer.step()
|
|
|
|
|
|
|
| 716 |
|
| 717 |
+
epoch_loss_regist += loss_regist.item()
|
| 718 |
+
epoch_loss_imgsim += loss_sim.item()
|
| 719 |
+
epoch_loss_imgmse += loss_mse.item()
|
| 720 |
+
epoch_loss_ddfreg += loss_ddf1.item()
|
| 721 |
+
else:
|
| 722 |
+
loss_sim = torch.tensor(0.0)
|
| 723 |
+
loss_mse = torch.tensor(0.0)
|
| 724 |
+
loss_ddf1 = torch.tensor(0.0)
|
| 725 |
+
loss_regist = torch.tensor(0.0)
|
| 726 |
+
if step % REGISTRATION_STEP_RATIO==0:
|
| 727 |
+
total_reg = total_reg-1
|
| 728 |
+
|
| 729 |
+
# print for checking
|
| 730 |
+
if step % 10 == 0:
|
| 731 |
+
print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 732 |
+
print(f'- loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
|
| 733 |
+
print(f'- loss_contra: {loss_contra}')
|
| 734 |
|
| 735 |
+
# Mid-epoch checkpoint and proactive restart (only when --max-steps-before-restart > 0)
|
| 736 |
+
if max_steps_restart > 0 and step > 0 and step % MID_EPOCH_SAVE_STEPS == 0 and gpu_id == 0 and not args.no_save:
|
| 737 |
+
_epoch_stats = {
|
| 738 |
+
'epoch_loss_tot': epoch_loss_tot,
|
| 739 |
+
'epoch_loss_gen_d': epoch_loss_gen_d,
|
| 740 |
+
'epoch_loss_gen_a': epoch_loss_gen_a,
|
| 741 |
+
'epoch_loss_reg': epoch_loss_reg,
|
| 742 |
+
'epoch_loss_regist': epoch_loss_regist,
|
| 743 |
+
'epoch_loss_imgsim': epoch_loss_imgsim,
|
| 744 |
+
'epoch_loss_imgmse': epoch_loss_imgmse,
|
| 745 |
+
'epoch_loss_ddfreg': epoch_loss_ddfreg,
|
| 746 |
+
'epoch_loss_contrastive': epoch_loss_contrastive,
|
| 747 |
+
'total_reg': total_reg,
|
| 748 |
+
'total_contra': total_contra,
|
| 749 |
+
'loss_nan_step': loss_nan_step,
|
| 750 |
+
'rng_torch': torch.get_rng_state(),
|
| 751 |
+
'rng_numpy': np.random.get_state(),
|
| 752 |
+
'rng_python': random.getstate(),
|
| 753 |
+
**(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
|
| 754 |
+
{'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
|
| 755 |
+
}
|
| 756 |
+
tmp_dir = os.path.join(model_save_path, "tmp")
|
| 757 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 758 |
+
for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
|
| 759 |
+
os.remove(old_f)
|
| 760 |
+
mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
|
| 761 |
+
state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
|
| 762 |
+
torch.save({
|
| 763 |
+
'model_state_dict': state,
|
| 764 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 765 |
+
'epoch': epoch,
|
| 766 |
+
'step': step,
|
| 767 |
+
'epoch_stats': _epoch_stats,
|
| 768 |
+
}, mid_save)
|
| 769 |
+
print(f" [mid-epoch] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
|
| 770 |
+
|
| 771 |
+
# Proactive restart: exit cleanly after N steps to reset XPU memory leak.
|
| 772 |
+
# The bash wrapper will re-launch srun within the same SLURM allocation.
|
| 773 |
+
steps_since_start += 1
|
| 774 |
+
if max_steps_restart > 0 and steps_since_start >= max_steps_restart:
|
| 775 |
+
# Save checkpoint at current position (if not just saved above)
|
| 776 |
+
if not (step > 0 and step % MID_EPOCH_SAVE_STEPS == 0) and gpu_id == 0 and not args.no_save:
|
| 777 |
+
_epoch_stats = {
|
| 778 |
+
'epoch_loss_tot': epoch_loss_tot, 'epoch_loss_gen_d': epoch_loss_gen_d,
|
| 779 |
+
'epoch_loss_gen_a': epoch_loss_gen_a, 'epoch_loss_reg': epoch_loss_reg,
|
| 780 |
+
'epoch_loss_regist': epoch_loss_regist, 'epoch_loss_imgsim': epoch_loss_imgsim,
|
| 781 |
+
'epoch_loss_imgmse': epoch_loss_imgmse, 'epoch_loss_ddfreg': epoch_loss_ddfreg,
|
| 782 |
+
'epoch_loss_contrastive': epoch_loss_contrastive, 'total_reg': total_reg, 'total_contra': total_contra,
|
| 783 |
+
'loss_nan_step': loss_nan_step,
|
| 784 |
+
'rng_torch': torch.get_rng_state(), 'rng_numpy': np.random.get_state(),
|
| 785 |
+
'rng_python': random.getstate(),
|
| 786 |
+
**(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else
|
| 787 |
+
{'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})),
|
| 788 |
+
}
|
| 789 |
+
tmp_dir = os.path.join(model_save_path, "tmp")
|
| 790 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
| 791 |
+
for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")):
|
| 792 |
+
os.remove(old_f)
|
| 793 |
+
mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
|
| 794 |
+
state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict()
|
| 795 |
+
torch.save({
|
| 796 |
+
'model_state_dict': state,
|
| 797 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 798 |
+
'epoch': epoch,
|
| 799 |
+
'step': step,
|
| 800 |
+
'epoch_stats': _epoch_stats,
|
| 801 |
+
}, mid_save)
|
| 802 |
+
print(f" [restart] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)
|
| 803 |
+
if gpu_id == 0:
|
| 804 |
+
print(f" [restart] Proactive restart after {steps_since_start} steps "
|
| 805 |
+
f"(limit {max_steps_restart}). Exiting with code {EXIT_CODE_RESTART}.", flush=True)
|
| 806 |
+
# Clean shutdown
|
| 807 |
+
_empty_cache(DEVICE_TYPE)
|
| 808 |
+
gc.collect()
|
| 809 |
+
if use_distributed and dist.is_initialized():
|
| 810 |
+
dist.barrier()
|
| 811 |
+
dist.destroy_process_group()
|
| 812 |
+
sys.exit(EXIT_CODE_RESTART)
|
| 813 |
|
| 814 |
+
if gpu_id == 0:
|
| 815 |
+
print('==================')
|
| 816 |
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 817 |
+
print(f' loss_contrastive: {epoch_loss_contrastive}')
|
| 818 |
+
total_reg_safe = max(total_reg, 1)
|
| 819 |
+
print(f' loss_regist: {epoch_loss_regist/total_reg_safe} = {epoch_loss_imgsim/total_reg_safe} (imgsim) + {epoch_loss_imgmse/total_reg_safe} (imgmse) + {epoch_loss_ddfreg/total_reg_safe} (ddf)')
|
| 820 |
+
print('==================')
|
| 821 |
|
|
|
|
|
|
|
| 822 |
|
| 823 |
+
if 0 == epoch % epoch_per_save and not args.no_save:
|
| 824 |
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 825 |
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 826 |
# break # FOR TESTING
|
|
|
|
| 840 |
'optimizer_state_dict': optimizer.state_dict(),
|
| 841 |
'epoch': epoch
|
| 842 |
}, save_dir)
|
| 843 |
+
# Clean up tmp/ mid-epoch checkpoints after completed epoch
|
| 844 |
+
if gpu_id == 0 and not args.no_save:
|
| 845 |
+
tmp_dir = os.path.join(model_dir, "tmp")
|
| 846 |
+
tmp_pths = glob.glob(os.path.join(tmp_dir, "*.pth"))
|
| 847 |
+
if tmp_pths:
|
| 848 |
+
for f in tmp_pths:
|
| 849 |
+
os.remove(f)
|
| 850 |
+
print(f" [cleanup] Cleared {len(tmp_pths)} tmp/ mid-epoch checkpoints", flush=True)
|
| 851 |
+
# Reset initial_step after first epoch completes (no more skipping)
|
| 852 |
+
initial_step = 0
|
| 853 |
+
|
| 854 |
+
# XPU CCL workaround: restart after each epoch to avoid CCL hang on 2nd epoch.
|
| 855 |
+
# CCL's Level Zero IPC handles accumulate and cause deadlock after ~200+ collectives.
|
| 856 |
+
# A fresh process resets the L0 context. The bash loop catches exit code 42 and restarts.
|
| 857 |
+
if DEVICE_TYPE == 'xpu' and use_distributed:
|
| 858 |
+
if gpu_id == 0:
|
| 859 |
+
print(f" [xpu-restart] Epoch {epoch} done. Restarting to reset CCL state.", flush=True)
|
| 860 |
+
_empty_cache(DEVICE_TYPE)
|
| 861 |
+
gc.collect()
|
| 862 |
+
if dist.is_initialized():
|
| 863 |
+
dist.barrier()
|
| 864 |
+
dist.destroy_process_group()
|
| 865 |
+
sys.exit(EXIT_CODE_RESTART)
|
| 866 |
|
| 867 |
# Resource cleanup at the end of training
|
| 868 |
+
_empty_cache(DEVICE_TYPE)
|
| 869 |
gc.collect()
|
| 870 |
if use_distributed and dist.is_initialized():
|
| 871 |
dist.destroy_process_group()
|
| 872 |
|
| 873 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 874 |
+
|
| 875 |
+
# All ranks load checkpoint so optimizer state is consistent across DDP processes.
|
| 876 |
+
# (Optimizer state includes per-parameter Adam momentum/variance which are NOT
|
| 877 |
+
# broadcast — only model weights are broadcast. Without this, non-rank-0 processes
|
| 878 |
+
# would have fresh Adam state after restart.)
|
| 879 |
+
gc.collect()
|
| 880 |
+
_empty_cache(DEVICE_TYPE)
|
| 881 |
if gpu_id == 0:
|
|
|
|
| 882 |
utils.print_memory_usage("Before Loading Model")
|
| 883 |
+
# checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)
|
| 884 |
+
checkpoint = torch.load(model_file, map_location='cpu')
|
| 885 |
+
if use_distributed:
|
| 886 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 887 |
+
else:
|
| 888 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 889 |
+
# Restore optimizer state when available (needed for mid-epoch resume).
|
| 890 |
+
# Selective loading: load states for parameters with matching shapes, skip mismatched ones
|
| 891 |
+
# (e.g., UpsampleConv replaced ConvTranspose3d — different kernel shapes).
|
| 892 |
+
# After one epoch, the saved checkpoint will have correct state for ALL parameters.
|
| 893 |
+
if 'optimizer_state_dict' in checkpoint and not args.reset_optimizer:
|
| 894 |
+
saved_opt = checkpoint['optimizer_state_dict']
|
| 895 |
+
saved_state = saved_opt.get('state', {})
|
| 896 |
+
param_list = [p for group in optimizer.param_groups for p in group['params']]
|
| 897 |
+
|
| 898 |
+
# Check if all shapes match (fast path: full load)
|
| 899 |
+
all_match = True
|
| 900 |
+
skipped = 0
|
| 901 |
+
for idx, s in saved_state.items():
|
| 902 |
+
if int(idx) < len(param_list):
|
| 903 |
+
p = param_list[int(idx)]
|
| 904 |
+
for k, v in s.items():
|
| 905 |
+
if isinstance(v, torch.Tensor) and v.dim() > 0 and v.shape != p.shape:
|
| 906 |
+
all_match = False
|
| 907 |
+
break
|
| 908 |
+
if not all_match:
|
| 909 |
+
break
|
| 910 |
+
|
| 911 |
+
if all_match:
|
| 912 |
+
optimizer.load_state_dict(saved_opt)
|
| 913 |
else:
|
| 914 |
+
# Selective load: restore param_groups settings (lr, betas, etc.)
|
| 915 |
+
for saved_g, group in zip(saved_opt['param_groups'], optimizer.param_groups):
|
| 916 |
+
for k, v in saved_g.items():
|
| 917 |
+
if k != 'params':
|
| 918 |
+
group[k] = v
|
| 919 |
+
# Restore per-parameter state only where shapes match
|
| 920 |
+
for idx, s in saved_state.items():
|
| 921 |
+
idx_int = int(idx)
|
| 922 |
+
if idx_int < len(param_list):
|
| 923 |
+
p = param_list[idx_int]
|
| 924 |
+
shapes_ok = all(
|
| 925 |
+
v.shape == p.shape for k, v in s.items()
|
| 926 |
+
if isinstance(v, torch.Tensor) and v.dim() > 0
|
| 927 |
+
)
|
| 928 |
+
if shapes_ok:
|
| 929 |
+
# Cast state tensors to match parameter dtype/device
|
| 930 |
+
new_state = {}
|
| 931 |
+
for k, v in s.items():
|
| 932 |
+
if isinstance(v, torch.Tensor):
|
| 933 |
+
new_state[k] = v.to(dtype=p.dtype, device=p.device) if v.dim() > 0 else v
|
| 934 |
+
else:
|
| 935 |
+
new_state[k] = v
|
| 936 |
+
optimizer.state[p] = new_state
|
| 937 |
+
else:
|
| 938 |
+
skipped += 1
|
| 939 |
+
if gpu_id == 0:
|
| 940 |
+
loaded = len(saved_state) - skipped
|
| 941 |
+
print(f" [checkpoint] Selective optimizer load: {loaded} params restored, "
|
| 942 |
+
f"{skipped} skipped (shape mismatch, fresh Adam for those)", flush=True)
|
| 943 |
+
elif args.reset_optimizer and gpu_id == 0:
|
| 944 |
+
print(" [checkpoint] --reset-optimizer: skipping optimizer state, starting fresh Adam", flush=True)
|
| 945 |
+
del checkpoint
|
| 946 |
+
if gpu_id == 0:
|
| 947 |
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 948 |
|
| 949 |
if use_distributed:
|
| 950 |
+
# Broadcast model weights from rank 0 to ensure exact consistency
|
| 951 |
dist.barrier()
|
| 952 |
for param in Deformddpm.parameters():
|
| 953 |
+
dist.broadcast(param.data, src=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 954 |
|
| 955 |
+
# get the epoch number from the filename
|
| 956 |
+
basename = os.path.basename(model_file)
|
| 957 |
+
epoch_from_file = int(basename[:6])
|
| 958 |
+
if '_step' in basename:
|
| 959 |
+
# Mid-epoch checkpoint: resume at same epoch (don't +1)
|
| 960 |
+
initial_epoch = epoch_from_file
|
| 961 |
+
else:
|
| 962 |
+
# End-of-epoch checkpoint: start next epoch
|
| 963 |
+
initial_epoch = epoch_from_file + 1
|
| 964 |
|
| 965 |
return initial_epoch, Deformddpm, optimizer
|
| 966 |
|
| 967 |
|
| 968 |
|
| 969 |
if __name__ == "__main__":
|
| 970 |
+
if "LOCAL_RANK" in os.environ:
|
| 971 |
+
# Multi-node: launched by torchrun / srun
|
| 972 |
+
use_distributed = True
|
| 973 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 974 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 975 |
+
print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
|
| 976 |
+
try:
|
| 977 |
+
main_train(local_rank, world_size)
|
| 978 |
+
except Exception as e:
|
| 979 |
+
import traceback
|
| 980 |
+
print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True)
|
| 981 |
+
traceback.print_exc()
|
| 982 |
+
raise
|
| 983 |
+
elif use_distributed:
|
| 984 |
+
# Single-node multi-GPU: use mp.spawn
|
| 985 |
+
world_size = _device_count(DEVICE_TYPE)
|
| 986 |
+
print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
|
| 987 |
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 988 |
else:
|
| 989 |
main_train(0,1)
|
OM_train_3modes_cudaonly.py
ADDED
|
@@ -0,0 +1,512 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
sys.path.append(ROOT_DIR)
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torchvision.utils import save_image
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from torch.optim import Adam, SGD
|
| 14 |
+
from Diffusion.diffuser import DeformDDPM
|
| 15 |
+
from Diffusion.networks import get_net, STN
|
| 16 |
+
from torchvision.transforms import Lambda
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import Diffusion.losses as losses
|
| 19 |
+
import random
|
| 20 |
+
import glob
|
| 21 |
+
import numpy as np
|
| 22 |
+
import utils
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 26 |
+
from Dataloader.dataLoader import *
|
| 27 |
+
|
| 28 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 29 |
+
import yaml
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
####################
|
| 33 |
+
import torch.multiprocessing as mp
|
| 34 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 36 |
+
import torch.distributed as dist
|
| 37 |
+
# from torch.distributed import init_process_group
|
| 38 |
+
###############
|
| 39 |
+
def ddp_setup(rank, world_size):
|
| 40 |
+
"""
|
| 41 |
+
Args:
|
| 42 |
+
rank: Unique identifier of each process
|
| 43 |
+
world_size: Total number of processes
|
| 44 |
+
"""
|
| 45 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 46 |
+
os.environ["MASTER_PORT"] = "12355"
|
| 47 |
+
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 48 |
+
torch.cuda.set_device(rank)
|
| 49 |
+
|
| 50 |
+
# Auto-detect: use DDP only when multiple CUDA GPUs are available
|
| 51 |
+
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1
|
| 52 |
+
# use_distributed = True
|
| 53 |
+
# use_distributed = False
|
| 54 |
+
|
| 55 |
+
EPS = 1e-5
|
| 56 |
+
MSK_EPS = 0.01
|
| 57 |
+
TEXT_EMBED_PROB = 0.5
|
| 58 |
+
AUG_RESAMPLE_PROB = 0.5
|
| 59 |
+
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0] # [ang, dist, reg]
|
| 60 |
+
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 61 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2] # [imgsim, imgmse, ddf]
|
| 62 |
+
DIFF_REG_BATCH_RATIO = 2
|
| 63 |
+
LOSS_WEIGHT_CONTRASTIVE = 0.001
|
| 64 |
+
REGISTRATION_STEP_RATIO = 1
|
| 65 |
+
CONTRASTIVE_STEP_RATIO = 1
|
| 66 |
+
|
| 67 |
+
# AUG_PERMUTE_PROB = 0.35
|
| 68 |
+
|
| 69 |
+
parser = argparse.ArgumentParser()
|
| 70 |
+
|
| 71 |
+
# config_file_path = 'Config/config_cmr.yaml'
|
| 72 |
+
parser.add_argument(
|
| 73 |
+
"--config",
|
| 74 |
+
"-C",
|
| 75 |
+
help="Path for the config file",
|
| 76 |
+
type=str,
|
| 77 |
+
# default="Config/config_cmr.yaml",
|
| 78 |
+
# default="Config/config_lct.yaml",
|
| 79 |
+
default="Config/config_all.yaml",
|
| 80 |
+
required=False,
|
| 81 |
+
)
|
| 82 |
+
# parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
|
| 83 |
+
parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
|
| 84 |
+
args = parser.parse_args()
|
| 85 |
+
#=======================================================================================================================
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 90 |
+
if use_distributed:
|
| 91 |
+
ddp_setup(rank,world_size)
|
| 92 |
+
|
| 93 |
+
if torch.distributed.is_initialized():
|
| 94 |
+
print(f"World size: {torch.distributed.get_world_size()}")
|
| 95 |
+
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 96 |
+
gpu_id = rank
|
| 97 |
+
|
| 98 |
+
# Load the YAML file into a dictionary
|
| 99 |
+
with open(args.config, 'r') as file:
|
| 100 |
+
hyp_parameters = yaml.safe_load(file)
|
| 101 |
+
if args.batchsize > 0:
|
| 102 |
+
hyp_parameters['batchsize'] = args.batchsize
|
| 103 |
+
print(hyp_parameters)
|
| 104 |
+
|
| 105 |
+
# epoch_per_save=10
|
| 106 |
+
epoch_per_save=hyp_parameters['epoch_per_save']
|
| 107 |
+
|
| 108 |
+
data_name=hyp_parameters['data_name']
|
| 109 |
+
net_name = hyp_parameters['net_name']
|
| 110 |
+
|
| 111 |
+
Net=get_net(net_name)
|
| 112 |
+
|
| 113 |
+
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 114 |
+
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 115 |
+
model_dir=model_save_path
|
| 116 |
+
transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 117 |
+
|
| 118 |
+
# Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
|
| 119 |
+
|
| 120 |
+
# tsfm = torchvision.transforms.Compose([
|
| 121 |
+
# torchvision.transforms.ToTensor(),
|
| 122 |
+
# ])
|
| 123 |
+
|
| 124 |
+
# dataset = Data_Loader(target_res = [hyp_parameters["img_size"]]*hyp_parameters["ndims"], transforms=None, noise_scale=hyp_parameters['noise_scale'])
|
| 125 |
+
# train_loader = DataLoader(
|
| 126 |
+
# dataset,
|
| 127 |
+
# batch_size=hyp_parameters['batchsize'],
|
| 128 |
+
# # shuffle=False,
|
| 129 |
+
# shuffle=True,
|
| 130 |
+
# drop_last=True,
|
| 131 |
+
# )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# dataset = OminiDataset_v1(transform=None)
|
| 135 |
+
dataset = OMDataset_indiv(transform=None)
|
| 136 |
+
# datasetp = OminiDataset_paired(transform=None)
|
| 137 |
+
datasetp = OMDataset_pair(transform=None)
|
| 138 |
+
|
| 139 |
+
train_loader = DataLoader(
|
| 140 |
+
dataset,
|
| 141 |
+
batch_size=hyp_parameters['batchsize'],
|
| 142 |
+
shuffle=True,
|
| 143 |
+
drop_last=True,
|
| 144 |
+
)
|
| 145 |
+
train_loader_p = DataLoader(
|
| 146 |
+
datasetp,
|
| 147 |
+
batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
|
| 148 |
+
shuffle=True,
|
| 149 |
+
drop_last=True,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
Deformddpm = DeformDDPM(
|
| 155 |
+
network=Net(
|
| 156 |
+
n_steps=hyp_parameters["timesteps"],
|
| 157 |
+
ndims=hyp_parameters["ndims"],
|
| 158 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 159 |
+
res = hyp_parameters['img_size']
|
| 160 |
+
),
|
| 161 |
+
n_steps=hyp_parameters["timesteps"],
|
| 162 |
+
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 163 |
+
device=hyp_parameters["device"],
|
| 164 |
+
batch_size=hyp_parameters["batchsize"],
|
| 165 |
+
img_pad_mode=hyp_parameters["img_pad_mode"],
|
| 166 |
+
v_scale=hyp_parameters["v_scale"],
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
ddf_stn = STN(
|
| 171 |
+
img_sz=hyp_parameters["img_size"],
|
| 172 |
+
ndims=hyp_parameters["ndims"],
|
| 173 |
+
# padding_mode="zeros",
|
| 174 |
+
padding_mode=hyp_parameters["padding_mode"],
|
| 175 |
+
device=hyp_parameters["device"],
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
if use_distributed:
|
| 180 |
+
Deformddpm.to(rank)
|
| 181 |
+
Deformddpm = DDP(Deformddpm, device_ids=[rank])
|
| 182 |
+
ddf_stn.to(rank)
|
| 183 |
+
else:
|
| 184 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 185 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 186 |
+
# ddf_stn = DDP(ddf_stn, device_ids=[rank])
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# mse = nn.MSELoss()
|
| 190 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 191 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 192 |
+
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 193 |
+
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 194 |
+
|
| 195 |
+
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 196 |
+
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 197 |
+
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 198 |
+
loss_imgsim = losses.MSLNCC()
|
| 199 |
+
loss_imgmse = losses.LMSE()
|
| 200 |
+
|
| 201 |
+
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
| 202 |
+
# hyp_parameters["lr"]=0.00000001
|
| 203 |
+
# optimizer_regist = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01)
|
| 204 |
+
# optimizer_regist = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01, momentum=0.98)
|
| 205 |
+
# optimizer = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"], momentum=0.9)
|
| 206 |
+
|
| 207 |
+
# # LR scheduler ----- YHM
|
| 208 |
+
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, hyp_parameters["lr"], hyp_parameters["lr"]*10, step_size_up=500, step_size_down=500, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
|
| 209 |
+
|
| 210 |
+
# Deformddpm.network.load_state_dict(torch.load('/home/data/jzheng/Adaptive_Motion_Generator-master/models/1000.pth'))
|
| 211 |
+
|
| 212 |
+
# check for existing models
|
| 213 |
+
if not os.path.exists(model_dir):
|
| 214 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 215 |
+
model_files = glob.glob(os.path.join(model_dir, "*.pth"))
|
| 216 |
+
model_files.sort()
|
| 217 |
+
if model_files:
|
| 218 |
+
if gpu_id == 0:
|
| 219 |
+
print(model_files)
|
| 220 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed)
|
| 221 |
+
else:
|
| 222 |
+
initial_epoch = 0
|
| 223 |
+
|
| 224 |
+
if gpu_id == 0:
|
| 225 |
+
print('len_train_data: ',len(dataset))
|
| 226 |
+
# Training loop
|
| 227 |
+
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
| 228 |
+
|
| 229 |
+
epoch_loss_tot = 0.0
|
| 230 |
+
epoch_loss_gen_d = 0.0
|
| 231 |
+
epoch_loss_gen_a = 0.0
|
| 232 |
+
epoch_loss_reg = 0.0
|
| 233 |
+
epoch_loss_regist = 0.0
|
| 234 |
+
epoch_loss_imgsim = 0.0
|
| 235 |
+
epoch_loss_imgmse = 0.0
|
| 236 |
+
epoch_loss_ddfreg = 0.0
|
| 237 |
+
epoch_loss_contrastive = 0.0
|
| 238 |
+
# Set model inside to train model
|
| 239 |
+
Deformddpm.train()
|
| 240 |
+
|
| 241 |
+
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 242 |
+
|
| 243 |
+
total = min(len(train_loader), len(train_loader_p))
|
| 244 |
+
total_reg = total // REGISTRATION_STEP_RATIO
|
| 245 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 246 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 247 |
+
# for step, batch in enumerate(train_loader_omni):
|
| 248 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 249 |
+
|
| 250 |
+
# x0, _ = batch
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# ==========================================================================
|
| 254 |
+
# diffusion train on single image
|
| 255 |
+
|
| 256 |
+
# x0 = batch # for omni dataset
|
| 257 |
+
[x0,embd] = batch # for om dataset
|
| 258 |
+
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 259 |
+
# print('embd:', embd.shape)
|
| 260 |
+
embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
|
| 261 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 262 |
+
embd_in = embd_dev
|
| 263 |
+
else:
|
| 264 |
+
embd_in = None
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
n = x0.size()[0] # batch_size -> n
|
| 269 |
+
x0 = x0.to(hyp_parameters["device"])
|
| 270 |
+
|
| 271 |
+
blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
|
| 272 |
+
|
| 273 |
+
# random deformation + rotation
|
| 274 |
+
if hyp_parameters["ndims"]>2:
|
| 275 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 276 |
+
x0 = utils.random_resample(x0, deform_scale=0)
|
| 277 |
+
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 278 |
+
else:
|
| 279 |
+
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 280 |
+
# x0 = transformer(x0)
|
| 281 |
+
if hyp_parameters['noise_scale']>0:
|
| 282 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 283 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 284 |
+
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 285 |
+
|
| 286 |
+
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
| 287 |
+
t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
|
| 288 |
+
hyp_parameters["device"]
|
| 289 |
+
) # pick up a seq of rand number from 0 to 'timestep'
|
| 290 |
+
|
| 291 |
+
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 292 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 293 |
+
# print('proc_type:', proc_type)
|
| 294 |
+
ddpm = Deformddpm.module if use_distributed else Deformddpm
|
| 295 |
+
cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
|
| 296 |
+
|
| 297 |
+
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
|
| 298 |
+
|
| 299 |
+
# print(torch.max(torch.abs(pre_dvf_I)))
|
| 300 |
+
# print(torch.max(torch.abs(dvf_I)))
|
| 301 |
+
|
| 302 |
+
loss_tot=0
|
| 303 |
+
|
| 304 |
+
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 305 |
+
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 306 |
+
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 307 |
+
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 308 |
+
|
| 309 |
+
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 310 |
+
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 311 |
+
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 312 |
+
|
| 313 |
+
# >> JZ: print nan in x0
|
| 314 |
+
if torch.isnan(x0).any():
|
| 315 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 316 |
+
# >> JZ: print loss of ddf
|
| 317 |
+
if loss_ddf>0.001:
|
| 318 |
+
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 319 |
+
# yu: check if loss_tot==nan or inf
|
| 320 |
+
if torch.isnan(loss_tot) or torch.isinf(loss_tot):
|
| 321 |
+
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 322 |
+
loss_nan_step += 1
|
| 323 |
+
continue
|
| 324 |
+
if loss_nan_step > 5:
|
| 325 |
+
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 326 |
+
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 327 |
+
|
| 328 |
+
optimizer.zero_grad()
|
| 329 |
+
loss_tot.backward()
|
| 330 |
+
optimizer.step()
|
| 331 |
+
|
| 332 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 333 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 334 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 335 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 336 |
+
|
| 337 |
+
# ==========================================================================
|
| 338 |
+
# contrastive train on single image (text-image alignment)
|
| 339 |
+
loss_contra_val = None
|
| 340 |
+
if step % CONTRASTIVE_STEP_RATIO == 0:
|
| 341 |
+
raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
|
| 342 |
+
n_contra = x0.size()[0]
|
| 343 |
+
t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
|
| 344 |
+
_ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
|
| 345 |
+
if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
|
| 346 |
+
img_embd = raw_network.img_embd # [B, 1024]
|
| 347 |
+
loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.05) # contrastive loss to align image embedding with text embedding, with a margin of 0.02
|
| 348 |
+
|
| 349 |
+
optimizer.zero_grad()
|
| 350 |
+
loss_contra.backward()
|
| 351 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
|
| 352 |
+
optimizer.step()
|
| 353 |
+
loss_contra_val = loss_contra.item()
|
| 354 |
+
epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
|
| 355 |
+
else:
|
| 356 |
+
if gpu_id == 0:
|
| 357 |
+
print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")
|
| 358 |
+
|
| 359 |
+
# ==========================================================================
|
| 360 |
+
# registration train on paired images
|
| 361 |
+
if step%REGISTRATION_STEP_RATIO == 0 and loss_gen_a.item()<-0.6: # only train registration on relatively well-deformed images, to avoid too large registration loss and unstable training in the early stage
|
| 362 |
+
[x1, y1, _, embd_y] = batch_p
|
| 363 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 364 |
+
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 365 |
+
else:
|
| 366 |
+
embd_y = None
|
| 367 |
+
|
| 368 |
+
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 369 |
+
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 370 |
+
n = x1.size()[0] # batch_size -> n
|
| 371 |
+
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 372 |
+
if hyp_parameters['noise_scale']>0:
|
| 373 |
+
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
| 374 |
+
random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
|
| 375 |
+
random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 376 |
+
x1 = x1 * random_scale + random_shift
|
| 377 |
+
y1 = y1 * random_scale + random_shift
|
| 378 |
+
|
| 379 |
+
scale_regist = np.random.uniform(0.0,0.7)
|
| 380 |
+
select_timestep = np.random.randint(12, 25) # select a random number of timesteps to sample, between 8 and 16
|
| 381 |
+
T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
|
| 382 |
+
|
| 383 |
+
T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
|
| 384 |
+
|
| 385 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 386 |
+
ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
|
| 387 |
+
y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
|
| 388 |
+
msk_tgt = msk_tgt+MSK_EPS
|
| 389 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 390 |
+
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 391 |
+
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
|
| 392 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 393 |
+
|
| 394 |
+
loss_regist = 0
|
| 395 |
+
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 396 |
+
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 397 |
+
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 398 |
+
|
| 399 |
+
# >> JZ: print nan in x0
|
| 400 |
+
if torch.isnan(x0).any():
|
| 401 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 402 |
+
# >> JZ: print loss of ddf
|
| 403 |
+
if loss_ddf1>0.002:
|
| 404 |
+
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 405 |
+
|
| 406 |
+
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 407 |
+
optimizer.zero_grad()
|
| 408 |
+
loss_regist.backward()
|
| 409 |
+
|
| 410 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
|
| 411 |
+
optimizer.step()
|
| 412 |
+
|
| 413 |
+
epoch_loss_regist += loss_regist.item()
|
| 414 |
+
epoch_loss_imgsim += loss_sim.item()
|
| 415 |
+
epoch_loss_imgmse += loss_mse.item()
|
| 416 |
+
epoch_loss_ddfreg += loss_ddf1.item()
|
| 417 |
+
else:
|
| 418 |
+
loss_sim = torch.tensor(0.0)
|
| 419 |
+
loss_mse = torch.tensor(0.0)
|
| 420 |
+
loss_ddf1 = torch.tensor(0.0)
|
| 421 |
+
loss_regist = torch.tensor(0.0)
|
| 422 |
+
if step % REGISTRATION_STEP_RATIO==0:
|
| 423 |
+
total_reg = total_reg-1
|
| 424 |
+
|
| 425 |
+
if step % 10 == 0:
|
| 426 |
+
print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 427 |
+
if loss_contra_val is not None:
|
| 428 |
+
print(f' loss_contrastive: {loss_contra_val:.6f}')
|
| 429 |
+
print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
|
| 430 |
+
|
| 431 |
+
if 1:
|
| 432 |
+
print('==================')
|
| 433 |
+
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 434 |
+
print(f' loss_contrastive: {epoch_loss_contrastive}')
|
| 435 |
+
print(f' loss_regist: {epoch_loss_regist/total_reg} = {epoch_loss_imgsim/total_reg} (imgsim) + {epoch_loss_imgmse/total_reg} (imgmse) + {epoch_loss_ddfreg/total_reg} (ddf)')
|
| 436 |
+
print('==================')
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if 0 == epoch % epoch_per_save:
|
| 440 |
+
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 441 |
+
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 442 |
+
# break # FOR TESTING
|
| 443 |
+
if not use_distributed:
|
| 444 |
+
print(f"saved in {save_dir}")
|
| 445 |
+
# torch.save(Deformddpm.state_dict(), save_dir)
|
| 446 |
+
torch.save({
|
| 447 |
+
'model_state_dict': Deformddpm.state_dict(),
|
| 448 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 449 |
+
'epoch': epoch
|
| 450 |
+
}, save_dir)
|
| 451 |
+
elif gpu_id == 0:
|
| 452 |
+
print(f"saved in {save_dir}")
|
| 453 |
+
# torch.save(Deformddpm.module.state_dict(), save_dir)
|
| 454 |
+
torch.save({
|
| 455 |
+
'model_state_dict': Deformddpm.module.state_dict(),
|
| 456 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 457 |
+
'epoch': epoch
|
| 458 |
+
}, save_dir)
|
| 459 |
+
|
| 460 |
+
# Resource cleanup at the end of training
|
| 461 |
+
if torch.cuda.is_available():
|
| 462 |
+
torch.cuda.empty_cache()
|
| 463 |
+
gc.collect()
|
| 464 |
+
if use_distributed and dist.is_initialized():
|
| 465 |
+
dist.destroy_process_group()
|
| 466 |
+
|
| 467 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 468 |
+
|
| 469 |
+
if gpu_id == 0:
|
| 470 |
+
# if 0:
|
| 471 |
+
utils.print_memory_usage("Before Loading Model")
|
| 472 |
+
if torch.cuda.is_available():
|
| 473 |
+
gc.collect()
|
| 474 |
+
torch.cuda.empty_cache()
|
| 475 |
+
# Deformddpm.network.load_state_dict(torch.load(latest_model_file))
|
| 476 |
+
# Deformddpm.load_state_dict(torch.load(latest_model_file), strict=False)
|
| 477 |
+
checkpoint = torch.load(model_file, map_location='cpu')
|
| 478 |
+
# checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
|
| 479 |
+
if use_distributed:
|
| 480 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 481 |
+
else:
|
| 482 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 483 |
+
if load_strict:
|
| 484 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 485 |
+
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 486 |
+
|
| 487 |
+
if use_distributed:
|
| 488 |
+
# Broadcast model weights from rank 0 to all other GPUs
|
| 489 |
+
dist.barrier()
|
| 490 |
+
for param in Deformddpm.parameters():
|
| 491 |
+
dist.broadcast(param.data, src=0) # Synchronize model across ranks
|
| 492 |
+
dist.barrier()
|
| 493 |
+
for param_group in optimizer.param_groups:
|
| 494 |
+
for param in param_group['params']:
|
| 495 |
+
if param.grad is not None:
|
| 496 |
+
dist.broadcast(param.grad, src=0) # Sync optimizer gradients
|
| 497 |
+
|
| 498 |
+
# initial_epoch = checkpoint['epoch'] + 1
|
| 499 |
+
# get the epoch number from the filename and add 1 to set as initial_epoch
|
| 500 |
+
initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
|
| 501 |
+
|
| 502 |
+
return initial_epoch, Deformddpm, optimizer
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
if __name__ == "__main__":
|
| 507 |
+
if use_distributed:
|
| 508 |
+
world_size = torch.cuda.device_count()
|
| 509 |
+
print(f"Distributed GPU number = {world_size}")
|
| 510 |
+
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 511 |
+
else:
|
| 512 |
+
main_train(0,1)
|
OM_train_3modes_opt.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OM_train_3modes_opt.py — Optimized 3-mode training (diffusion + contrastive + registration).
|
| 3 |
+
|
| 4 |
+
Speed optimizations over OM_train_3modes.py (all mathematically equivalent):
|
| 5 |
+
1. DataLoader: num_workers, pin_memory, persistent_workers for I/O overlap
|
| 6 |
+
2. optimizer.zero_grad(set_to_none=True) — avoids zero-fill overhead
|
| 7 |
+
3. Fixed-length T_regist (16 steps) — avoids XPU dynamic shape recompilation
|
| 8 |
+
4. Removed redundant x0.to(device) call
|
| 9 |
+
5. Uses diffuser_opt.DeformDDPM (hoisted clone, no *0 redundancy, OptSTN, inference_mode)
|
| 10 |
+
6. Uses losses_opt.MSLNCC/LNCC (register_buffer for kernels)
|
| 11 |
+
7. Pre-compute proc_type lists to reduce Python overhead in hot loop
|
| 12 |
+
8. Uses OptRecMulModMutAttnNet (cached resample tensors, ~300 fewer CPU→GPU transfers)
|
| 13 |
+
9. Uses OptSTN for ddf_stn (register_buffer, no per-call .to())
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os, sys
|
| 17 |
+
|
| 18 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 19 |
+
sys.path.append(ROOT_DIR)
|
| 20 |
+
|
| 21 |
+
import gc
|
| 22 |
+
import torch
|
| 23 |
+
import torchvision
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torchvision.utils import save_image
|
| 26 |
+
from torch.utils.data import DataLoader
|
| 27 |
+
|
| 28 |
+
from torch.optim import Adam, SGD
|
| 29 |
+
from Diffusion.diffuser_opt import DeformDDPM
|
| 30 |
+
from Diffusion.networks_opt import get_net_opt, OptSTN
|
| 31 |
+
from torchvision.transforms import Lambda
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
import Diffusion.losses_opt as losses
|
| 34 |
+
import random
|
| 35 |
+
import glob
|
| 36 |
+
import numpy as np
|
| 37 |
+
import utils
|
| 38 |
+
from tqdm import tqdm
|
| 39 |
+
|
| 40 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 41 |
+
from Dataloader.dataLoader import *
|
| 42 |
+
|
| 43 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 44 |
+
import yaml
|
| 45 |
+
import argparse
|
| 46 |
+
|
| 47 |
+
####################
|
| 48 |
+
import torch.multiprocessing as mp
|
| 49 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 50 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 51 |
+
import torch.distributed as dist
|
| 52 |
+
###############
|
| 53 |
+
def ddp_setup(rank, world_size):
|
| 54 |
+
"""
|
| 55 |
+
Args:
|
| 56 |
+
rank: Unique identifier of each process
|
| 57 |
+
world_size: Total number of processes
|
| 58 |
+
"""
|
| 59 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 60 |
+
os.environ["MASTER_PORT"] = "12355"
|
| 61 |
+
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
| 62 |
+
torch.cuda.set_device(rank)
|
| 63 |
+
|
| 64 |
+
# Auto-detect: use DDP only when multiple CUDA GPUs are available
|
| 65 |
+
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1
|
| 66 |
+
# use_distributed = True
|
| 67 |
+
# use_distributed = False
|
| 68 |
+
|
| 69 |
+
EPS = 1e-5
|
| 70 |
+
MSK_EPS = 0.01
|
| 71 |
+
TEXT_EMBED_PROB = 0.7
|
| 72 |
+
AUG_RESAMPLE_PROB = 0.5
|
| 73 |
+
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0] # [ang, dist, reg]
|
| 74 |
+
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 75 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128] # [imgsim, imgmse, ddf]
|
| 76 |
+
DIFF_REG_BATCH_RATIO = 2
|
| 77 |
+
LOSS_WEIGHT_CONTRASTIVE = 1.0
|
| 78 |
+
CONTRASTIVE_STEP_RATIO = 2
|
| 79 |
+
|
| 80 |
+
# OPT: Fixed registration timestep count to avoid XPU dynamic shape recompilation
|
| 81 |
+
FIXED_T_REGIST_LEN = 16
|
| 82 |
+
|
| 83 |
+
# OPT: DataLoader workers (set to 0 to disable multiprocessing if needed)
|
| 84 |
+
NUM_WORKERS = 4
|
| 85 |
+
PIN_MEMORY = True
|
| 86 |
+
|
| 87 |
+
# AUG_PERMUTE_PROB = 0.35
|
| 88 |
+
|
| 89 |
+
parser = argparse.ArgumentParser()
|
| 90 |
+
|
| 91 |
+
# config_file_path = 'Config/config_cmr.yaml'
|
| 92 |
+
parser.add_argument(
|
| 93 |
+
"--config",
|
| 94 |
+
"-C",
|
| 95 |
+
help="Path for the config file",
|
| 96 |
+
type=str,
|
| 97 |
+
# default="Config/config_cmr.yaml",
|
| 98 |
+
# default="Config/config_lct.yaml",
|
| 99 |
+
default="Config/config_all.yaml",
|
| 100 |
+
required=False,
|
| 101 |
+
)
|
| 102 |
+
parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)")
|
| 103 |
+
parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
|
| 104 |
+
parser.add_argument("--num-workers", type=int, default=NUM_WORKERS, help="DataLoader num_workers (default: 4)")
|
| 105 |
+
args = parser.parse_args()
|
| 106 |
+
#=======================================================================================================================
|
| 107 |
+
|
| 108 |
+
class _DummyIndiv(torch.utils.data.Dataset):
|
| 109 |
+
def __init__(self, n, sz, embd_dim=1024):
|
| 110 |
+
self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| 111 |
+
def __len__(self): return self.n
|
| 112 |
+
def __getitem__(self, i):
|
| 113 |
+
return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
|
| 114 |
+
|
| 115 |
+
class _DummyPair(torch.utils.data.Dataset):
|
| 116 |
+
def __init__(self, n, sz, embd_dim=1024):
|
| 117 |
+
self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| 118 |
+
def __len__(self): return self.n
|
| 119 |
+
def __getitem__(self, i):
|
| 120 |
+
return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| 121 |
+
np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| 122 |
+
np.random.randn(self.embd_dim).astype(np.float32),
|
| 123 |
+
np.random.randn(self.embd_dim).astype(np.float32))
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 127 |
+
if use_distributed:
|
| 128 |
+
ddp_setup(rank,world_size)
|
| 129 |
+
|
| 130 |
+
if torch.distributed.is_initialized():
|
| 131 |
+
print(f"World size: {torch.distributed.get_world_size()}")
|
| 132 |
+
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 133 |
+
gpu_id = rank
|
| 134 |
+
|
| 135 |
+
# Load the YAML file into a dictionary
|
| 136 |
+
with open(args.config, 'r') as file:
|
| 137 |
+
hyp_parameters = yaml.safe_load(file)
|
| 138 |
+
if args.batchsize > 0:
|
| 139 |
+
hyp_parameters['batchsize'] = args.batchsize
|
| 140 |
+
print(hyp_parameters)
|
| 141 |
+
|
| 142 |
+
# epoch_per_save=10
|
| 143 |
+
epoch_per_save=hyp_parameters['epoch_per_save']
|
| 144 |
+
|
| 145 |
+
data_name=hyp_parameters['data_name']
|
| 146 |
+
net_name = hyp_parameters['net_name']
|
| 147 |
+
|
| 148 |
+
Net=get_net_opt(net_name)
|
| 149 |
+
|
| 150 |
+
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 151 |
+
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 152 |
+
model_dir=model_save_path
|
| 153 |
+
transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 154 |
+
|
| 155 |
+
# OPT: DataLoader with num_workers, pin_memory, persistent_workers
|
| 156 |
+
num_workers = args.num_workers
|
| 157 |
+
use_pin_memory = PIN_MEMORY and hyp_parameters["device"] != "cpu"
|
| 158 |
+
|
| 159 |
+
if args.dummy_samples > 0:
|
| 160 |
+
dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
|
| 161 |
+
datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
|
| 162 |
+
else:
|
| 163 |
+
dataset = OMDataset_indiv(transform=None)
|
| 164 |
+
datasetp = OMDataset_pair(transform=None)
|
| 165 |
+
|
| 166 |
+
train_loader = DataLoader(
|
| 167 |
+
dataset,
|
| 168 |
+
batch_size=hyp_parameters['batchsize'],
|
| 169 |
+
shuffle=True,
|
| 170 |
+
drop_last=True,
|
| 171 |
+
num_workers=num_workers, # OPT
|
| 172 |
+
pin_memory=use_pin_memory, # OPT
|
| 173 |
+
persistent_workers=num_workers > 0, # OPT
|
| 174 |
+
)
|
| 175 |
+
train_loader_p = DataLoader(
|
| 176 |
+
datasetp,
|
| 177 |
+
batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
|
| 178 |
+
shuffle=True,
|
| 179 |
+
drop_last=True,
|
| 180 |
+
num_workers=num_workers, # OPT
|
| 181 |
+
pin_memory=use_pin_memory, # OPT
|
| 182 |
+
persistent_workers=num_workers > 0, # OPT
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
Deformddpm = DeformDDPM(
|
| 188 |
+
network=Net(
|
| 189 |
+
n_steps=hyp_parameters["timesteps"],
|
| 190 |
+
ndims=hyp_parameters["ndims"],
|
| 191 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 192 |
+
res = hyp_parameters['img_size']
|
| 193 |
+
),
|
| 194 |
+
n_steps=hyp_parameters["timesteps"],
|
| 195 |
+
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 196 |
+
device=hyp_parameters["device"],
|
| 197 |
+
batch_size=hyp_parameters["batchsize"],
|
| 198 |
+
img_pad_mode=hyp_parameters["img_pad_mode"],
|
| 199 |
+
v_scale=hyp_parameters["v_scale"],
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
ddf_stn = OptSTN(
|
| 204 |
+
img_sz=hyp_parameters["img_size"],
|
| 205 |
+
ndims=hyp_parameters["ndims"],
|
| 206 |
+
# padding_mode="zeros",
|
| 207 |
+
padding_mode=hyp_parameters["padding_mode"],
|
| 208 |
+
device=hyp_parameters["device"],
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
if use_distributed:
|
| 213 |
+
Deformddpm.to(rank)
|
| 214 |
+
Deformddpm = DDP(Deformddpm, device_ids=[rank])
|
| 215 |
+
ddf_stn.to(rank)
|
| 216 |
+
else:
|
| 217 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 218 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 219 |
+
# ddf_stn = DDP(ddf_stn, device_ids=[rank])
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# mse = nn.MSELoss()
|
| 223 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 224 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 225 |
+
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 226 |
+
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 227 |
+
|
| 228 |
+
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 229 |
+
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 230 |
+
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 231 |
+
loss_imgsim = losses.MSLNCC()
|
| 232 |
+
loss_imgmse = losses.LMSE()
|
| 233 |
+
|
| 234 |
+
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
| 235 |
+
|
| 236 |
+
# check for existing models
|
| 237 |
+
if not os.path.exists(model_dir):
|
| 238 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 239 |
+
model_files = glob.glob(os.path.join(model_dir, "*.pth"))
|
| 240 |
+
model_files.sort()
|
| 241 |
+
if model_files:
|
| 242 |
+
if gpu_id == 0:
|
| 243 |
+
print(model_files)
|
| 244 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed)
|
| 245 |
+
else:
|
| 246 |
+
initial_epoch = 0
|
| 247 |
+
|
| 248 |
+
if gpu_id == 0:
|
| 249 |
+
print('len_train_data: ',len(dataset))
|
| 250 |
+
# Training loop
|
| 251 |
+
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
| 252 |
+
|
| 253 |
+
epoch_loss_tot = 0.0
|
| 254 |
+
epoch_loss_gen_d = 0.0
|
| 255 |
+
epoch_loss_gen_a = 0.0
|
| 256 |
+
epoch_loss_reg = 0.0
|
| 257 |
+
epoch_loss_regist = 0.0
|
| 258 |
+
epoch_loss_imgsim = 0.0
|
| 259 |
+
epoch_loss_imgmse = 0.0
|
| 260 |
+
epoch_loss_ddfreg = 0.0
|
| 261 |
+
epoch_loss_contrastive = 0.0
|
| 262 |
+
# Set model inside to train model
|
| 263 |
+
Deformddpm.train()
|
| 264 |
+
|
| 265 |
+
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 266 |
+
|
| 267 |
+
total = min(len(train_loader), len(train_loader_p))
|
| 268 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 269 |
+
|
| 270 |
+
# ==========================================================================
|
| 271 |
+
# diffusion train on single image
|
| 272 |
+
|
| 273 |
+
[x0,embd] = batch # for om dataset
|
| 274 |
+
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 275 |
+
embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
|
| 276 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 277 |
+
embd_in = embd_dev
|
| 278 |
+
else:
|
| 279 |
+
embd_in = None
|
| 280 |
+
|
| 281 |
+
n = x0.size()[0] # batch_size -> n
|
| 282 |
+
# OPT: removed redundant x0.to(device) — already done above
|
| 283 |
+
|
| 284 |
+
blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
|
| 285 |
+
|
| 286 |
+
# random deformation + rotation
|
| 287 |
+
if hyp_parameters["ndims"]>2:
|
| 288 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 289 |
+
x0 = utils.random_resample(x0, deform_scale=0)
|
| 290 |
+
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 291 |
+
else:
|
| 292 |
+
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 293 |
+
# x0 = transformer(x0)
|
| 294 |
+
if hyp_parameters['noise_scale']>0:
|
| 295 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 296 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 297 |
+
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 298 |
+
|
| 299 |
+
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
| 300 |
+
t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
|
| 301 |
+
hyp_parameters["device"]
|
| 302 |
+
) # pick up a seq of rand number from 0 to 'timestep'
|
| 303 |
+
|
| 304 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 305 |
+
ddpm = Deformddpm.module if use_distributed else Deformddpm
|
| 306 |
+
cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
|
| 307 |
+
|
| 308 |
+
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
|
| 309 |
+
|
| 310 |
+
loss_tot=0
|
| 311 |
+
|
| 312 |
+
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 313 |
+
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 314 |
+
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 315 |
+
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 316 |
+
|
| 317 |
+
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 318 |
+
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 319 |
+
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 320 |
+
|
| 321 |
+
# >> JZ: print nan in x0
|
| 322 |
+
if torch.isnan(x0).any():
|
| 323 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 324 |
+
# >> JZ: print loss of ddf
|
| 325 |
+
if loss_ddf>0.001:
|
| 326 |
+
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 327 |
+
# yu: check if loss_tot==nan or inf
|
| 328 |
+
if torch.isnan(loss_tot) or torch.isinf(loss_tot):
|
| 329 |
+
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 330 |
+
loss_nan_step += 1
|
| 331 |
+
continue
|
| 332 |
+
if loss_nan_step > 5:
|
| 333 |
+
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 334 |
+
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 335 |
+
|
| 336 |
+
optimizer.zero_grad(set_to_none=True) # OPT: set_to_none faster than zero-fill
|
| 337 |
+
loss_tot.backward()
|
| 338 |
+
optimizer.step()
|
| 339 |
+
|
| 340 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 341 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 342 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 343 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 344 |
+
|
| 345 |
+
# ==========================================================================
|
| 346 |
+
# contrastive train on single image (text-image alignment)
|
| 347 |
+
loss_contra_val = None
|
| 348 |
+
if step % CONTRASTIVE_STEP_RATIO == 0:
|
| 349 |
+
raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
|
| 350 |
+
n_contra = x0.size()[0]
|
| 351 |
+
t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
|
| 352 |
+
_ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
|
| 353 |
+
if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
|
| 354 |
+
img_embd = raw_network.img_embd # [B, 1024]
|
| 355 |
+
loss_contra = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean())
|
| 356 |
+
|
| 357 |
+
optimizer.zero_grad(set_to_none=True) # OPT
|
| 358 |
+
loss_contra.backward()
|
| 359 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.05)
|
| 360 |
+
optimizer.step()
|
| 361 |
+
loss_contra_val = loss_contra.item()
|
| 362 |
+
epoch_loss_contrastive += loss_contra_val / total
|
| 363 |
+
else:
|
| 364 |
+
if gpu_id == 0:
|
| 365 |
+
print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")
|
| 366 |
+
|
| 367 |
+
# ==========================================================================
|
| 368 |
+
# registration train on paired images
|
| 369 |
+
if step%train_mode_ratio == 0:
|
| 370 |
+
[x1, y1, _, embd_y] = batch_p
|
| 371 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 372 |
+
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 373 |
+
else:
|
| 374 |
+
embd_y = None
|
| 375 |
+
|
| 376 |
+
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 377 |
+
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 378 |
+
n = x1.size()[0] # batch_size -> n
|
| 379 |
+
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 380 |
+
if hyp_parameters['noise_scale']>0:
|
| 381 |
+
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
| 382 |
+
random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
|
| 383 |
+
random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 384 |
+
x1 = x1 * random_scale + random_shift
|
| 385 |
+
y1 = y1 * random_scale + random_shift
|
| 386 |
+
|
| 387 |
+
scale_regist = np.random.uniform(0.0,0.7)
|
| 388 |
+
# OPT: fixed-length T_regist to avoid XPU dynamic shape recompilation
|
| 389 |
+
# Sample FIXED_T_REGIST_LEN timesteps (was: random 8-16), always same loop length
|
| 390 |
+
t_pool = list(range(int(hyp_parameters["timesteps"] * scale_regist), hyp_parameters["timesteps"]))
|
| 391 |
+
select_timestep = min(FIXED_T_REGIST_LEN, len(t_pool))
|
| 392 |
+
T_regist = sorted(random.sample(t_pool, select_timestep), reverse=True)
|
| 393 |
+
|
| 394 |
+
T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
|
| 395 |
+
|
| 396 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 397 |
+
ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
|
| 398 |
+
y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
|
| 399 |
+
msk_tgt = msk_tgt+MSK_EPS
|
| 400 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 401 |
+
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 402 |
+
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
|
| 403 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 404 |
+
|
| 405 |
+
loss_regist = 0
|
| 406 |
+
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 407 |
+
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 408 |
+
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 409 |
+
|
| 410 |
+
# >> JZ: print nan in x0
|
| 411 |
+
if torch.isnan(x0).any():
|
| 412 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 413 |
+
# >> JZ: print loss of ddf
|
| 414 |
+
if loss_ddf1>0.002:
|
| 415 |
+
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 416 |
+
|
| 417 |
+
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 418 |
+
optimizer.zero_grad(set_to_none=True) # OPT
|
| 419 |
+
loss_regist.backward()
|
| 420 |
+
|
| 421 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.2)
|
| 422 |
+
optimizer.step()
|
| 423 |
+
|
| 424 |
+
epoch_loss_regist += loss_regist.item() / total
|
| 425 |
+
epoch_loss_imgsim += loss_sim.item() / total
|
| 426 |
+
epoch_loss_imgmse += loss_mse.item() / total
|
| 427 |
+
epoch_loss_ddfreg += loss_ddf1.item() / total
|
| 428 |
+
|
| 429 |
+
if step % 10 == 0:
|
| 430 |
+
print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 431 |
+
if loss_contra_val is not None:
|
| 432 |
+
print(f' loss_contrastive: {loss_contra_val:.6f}')
|
| 433 |
+
print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
|
| 434 |
+
|
| 435 |
+
if 1:
|
| 436 |
+
print('==================')
|
| 437 |
+
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 438 |
+
print(f' loss_contrastive: {epoch_loss_contrastive}')
|
| 439 |
+
print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
|
| 440 |
+
print('==================')
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
if 0 == epoch % epoch_per_save:
|
| 444 |
+
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 445 |
+
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 446 |
+
# break # FOR TESTING
|
| 447 |
+
if not use_distributed:
|
| 448 |
+
print(f"saved in {save_dir}")
|
| 449 |
+
# torch.save(Deformddpm.state_dict(), save_dir)
|
| 450 |
+
torch.save({
|
| 451 |
+
'model_state_dict': Deformddpm.state_dict(),
|
| 452 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 453 |
+
'epoch': epoch
|
| 454 |
+
}, save_dir)
|
| 455 |
+
elif gpu_id == 0:
|
| 456 |
+
print(f"saved in {save_dir}")
|
| 457 |
+
# torch.save(Deformddpm.module.state_dict(), save_dir)
|
| 458 |
+
torch.save({
|
| 459 |
+
'model_state_dict': Deformddpm.module.state_dict(),
|
| 460 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 461 |
+
'epoch': epoch
|
| 462 |
+
}, save_dir)
|
| 463 |
+
|
| 464 |
+
# Resource cleanup at the end of training
|
| 465 |
+
if torch.cuda.is_available():
|
| 466 |
+
torch.cuda.empty_cache()
|
| 467 |
+
gc.collect()
|
| 468 |
+
if use_distributed and dist.is_initialized():
|
| 469 |
+
dist.destroy_process_group()
|
| 470 |
+
|
| 471 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 472 |
+
|
| 473 |
+
if gpu_id == 0:
|
| 474 |
+
# if 0:
|
| 475 |
+
utils.print_memory_usage("Before Loading Model")
|
| 476 |
+
if torch.cuda.is_available():
|
| 477 |
+
gc.collect()
|
| 478 |
+
torch.cuda.empty_cache()
|
| 479 |
+
checkpoint = torch.load(model_file, map_location='cpu')
|
| 480 |
+
if use_distributed:
|
| 481 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 482 |
+
else:
|
| 483 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 484 |
+
if load_strict:
|
| 485 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 486 |
+
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 487 |
+
|
| 488 |
+
if use_distributed:
|
| 489 |
+
# Broadcast model weights from rank 0 to all other GPUs
|
| 490 |
+
dist.barrier()
|
| 491 |
+
for param in Deformddpm.parameters():
|
| 492 |
+
dist.broadcast(param.data, src=0) # Synchronize model across ranks
|
| 493 |
+
dist.barrier()
|
| 494 |
+
for param_group in optimizer.param_groups:
|
| 495 |
+
for param in param_group['params']:
|
| 496 |
+
if param.grad is not None:
|
| 497 |
+
dist.broadcast(param.grad, src=0) # Sync optimizer gradients
|
| 498 |
+
|
| 499 |
+
# initial_epoch = checkpoint['epoch'] + 1
|
| 500 |
+
# get the epoch number from the filename and add 1 to set as initial_epoch
|
| 501 |
+
initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
|
| 502 |
+
|
| 503 |
+
return initial_epoch, Deformddpm, optimizer
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
if __name__ == "__main__":
|
| 508 |
+
if use_distributed:
|
| 509 |
+
world_size = torch.cuda.device_count()
|
| 510 |
+
print(f"Distributed GPU number = {world_size}")
|
| 511 |
+
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 512 |
+
else:
|
| 513 |
+
main_train(0,1)
|
OM_train_3modes_original.py
ADDED
|
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, sys
|
| 2 |
+
|
| 3 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 4 |
+
sys.path.append(ROOT_DIR)
|
| 5 |
+
|
| 6 |
+
import gc
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torchvision.utils import save_image
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from torch.optim import Adam, SGD
|
| 14 |
+
from Diffusion.diffuser import DeformDDPM
|
| 15 |
+
from Diffusion.networks import get_net, STN
|
| 16 |
+
from torchvision.transforms import Lambda
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import Diffusion.losses as losses
|
| 19 |
+
import random
|
| 20 |
+
import glob
|
| 21 |
+
import numpy as np
|
| 22 |
+
import utils
|
| 23 |
+
from tqdm import tqdm
|
| 24 |
+
|
| 25 |
+
from Dataloader.dataloader0 import get_dataloader
|
| 26 |
+
from Dataloader.dataLoader import *
|
| 27 |
+
|
| 28 |
+
from Dataloader.dataloader_utils import thresh_img
|
| 29 |
+
import yaml
|
| 30 |
+
import argparse
|
| 31 |
+
|
| 32 |
+
# XPU support: import Intel Extension for PyTorch and oneCCL bindings if available
|
| 33 |
+
try:
|
| 34 |
+
import intel_extension_for_pytorch as ipex
|
| 35 |
+
except ImportError:
|
| 36 |
+
ipex = None
|
| 37 |
+
try:
|
| 38 |
+
import oneccl_bindings_for_pytorch
|
| 39 |
+
except (ImportError, Exception) as e:
|
| 40 |
+
print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
|
| 41 |
+
|
| 42 |
+
####################
|
| 43 |
+
import torch.multiprocessing as mp
|
| 44 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 45 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 46 |
+
import torch.distributed as dist
|
| 47 |
+
# from torch.distributed import init_process_group
|
| 48 |
+
###############
|
| 49 |
+
def _device_available(device_type):
|
| 50 |
+
if device_type == 'xpu':
|
| 51 |
+
return hasattr(torch, 'xpu') and torch.xpu.is_available()
|
| 52 |
+
return torch.cuda.is_available()
|
| 53 |
+
|
| 54 |
+
def _device_count(device_type):
|
| 55 |
+
if device_type == 'xpu':
|
| 56 |
+
return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
|
| 57 |
+
return torch.cuda.device_count()
|
| 58 |
+
|
| 59 |
+
def _set_device(rank, device_type):
|
| 60 |
+
if device_type == 'xpu':
|
| 61 |
+
torch.xpu.set_device(rank)
|
| 62 |
+
else:
|
| 63 |
+
torch.cuda.set_device(rank)
|
| 64 |
+
|
| 65 |
+
def _empty_cache(device_type):
|
| 66 |
+
if device_type == 'xpu' and hasattr(torch, 'xpu'):
|
| 67 |
+
torch.xpu.empty_cache()
|
| 68 |
+
elif torch.cuda.is_available():
|
| 69 |
+
torch.cuda.empty_cache()
|
| 70 |
+
|
| 71 |
+
def ddp_setup(rank, world_size):
|
| 72 |
+
"""
|
| 73 |
+
Args:
|
| 74 |
+
rank: Unique identifier of each process (local_rank when launched by torchrun)
|
| 75 |
+
world_size: Total number of processes
|
| 76 |
+
"""
|
| 77 |
+
backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
|
| 78 |
+
if "LOCAL_RANK" in os.environ:
|
| 79 |
+
# Launched by torchrun: MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE already set
|
| 80 |
+
dist.init_process_group(backend=backend)
|
| 81 |
+
_set_device(int(os.environ["LOCAL_RANK"]), DEVICE_TYPE)
|
| 82 |
+
else:
|
| 83 |
+
# Single-node mp.spawn
|
| 84 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
| 85 |
+
os.environ["MASTER_PORT"] = "12355"
|
| 86 |
+
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
| 87 |
+
_set_device(rank, DEVICE_TYPE)
|
| 88 |
+
|
| 89 |
+
EPS = 1e-5
|
| 90 |
+
MSK_EPS = 0.01
|
| 91 |
+
TEXT_EMBED_PROB = 0.5
|
| 92 |
+
AUG_RESAMPLE_PROB = 0.5
|
| 93 |
+
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0] # [ang, dist, reg]
|
| 94 |
+
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0] # [imgsim, imgmse, ddf]
|
| 95 |
+
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2] # [imgsim, imgmse, ddf]
|
| 96 |
+
DIFF_REG_BATCH_RATIO = 2
|
| 97 |
+
LOSS_WEIGHT_CONTRASTIVE = 0.001
|
| 98 |
+
REGISTRATION_STEP_RATIO = 1
|
| 99 |
+
CONTRASTIVE_STEP_RATIO = 1
|
| 100 |
+
|
| 101 |
+
# AUG_PERMUTE_PROB = 0.35
|
| 102 |
+
|
| 103 |
+
parser = argparse.ArgumentParser()
|
| 104 |
+
|
| 105 |
+
# config_file_path = 'Config/config_cmr.yaml'
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"--config",
|
| 108 |
+
"-C",
|
| 109 |
+
help="Path for the config file",
|
| 110 |
+
type=str,
|
| 111 |
+
# default="Config/config_cmr.yaml",
|
| 112 |
+
# default="Config/config_lct.yaml",
|
| 113 |
+
default="Config/config_all.yaml",
|
| 114 |
+
required=False,
|
| 115 |
+
)
|
| 116 |
+
parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)")
|
| 117 |
+
args = parser.parse_args()
|
| 118 |
+
|
| 119 |
+
# Read config early to determine device type for DDP setup
|
| 120 |
+
with open(args.config, 'r') as _f:
|
| 121 |
+
_cfg = yaml.safe_load(_f)
|
| 122 |
+
DEVICE_TYPE = _cfg.get('device', 'cuda') # 'cuda' or 'xpu'
|
| 123 |
+
|
| 124 |
+
# Auto-detect: use DDP only when multiple devices are available
|
| 125 |
+
use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1
|
| 126 |
+
# use_distributed = True
|
| 127 |
+
# use_distributed = False
|
| 128 |
+
#=======================================================================================================================
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01):
|
| 132 |
+
if use_distributed:
|
| 133 |
+
ddp_setup(rank,world_size)
|
| 134 |
+
|
| 135 |
+
if torch.distributed.is_initialized() and rank == 0:
|
| 136 |
+
print(f"World size: {torch.distributed.get_world_size()}")
|
| 137 |
+
print(f"Communication backend: {torch.distributed.get_backend()}")
|
| 138 |
+
# gpu_id = global rank (for save/print guards); rank = local device index
|
| 139 |
+
if "RANK" in os.environ:
|
| 140 |
+
gpu_id = int(os.environ["RANK"])
|
| 141 |
+
rank = int(os.environ["LOCAL_RANK"])
|
| 142 |
+
else:
|
| 143 |
+
gpu_id = rank
|
| 144 |
+
|
| 145 |
+
# Load the YAML file into a dictionary
|
| 146 |
+
with open(args.config, 'r') as file:
|
| 147 |
+
hyp_parameters = yaml.safe_load(file)
|
| 148 |
+
if args.batchsize > 0:
|
| 149 |
+
hyp_parameters['batchsize'] = args.batchsize
|
| 150 |
+
if gpu_id == 0:
|
| 151 |
+
print(hyp_parameters)
|
| 152 |
+
|
| 153 |
+
# epoch_per_save=10
|
| 154 |
+
epoch_per_save=hyp_parameters['epoch_per_save']
|
| 155 |
+
|
| 156 |
+
data_name=hyp_parameters['data_name']
|
| 157 |
+
net_name = hyp_parameters['net_name']
|
| 158 |
+
|
| 159 |
+
Net=get_net(net_name)
|
| 160 |
+
|
| 161 |
+
suffix_pth=f'_{data_name}_{net_name}.pth'
|
| 162 |
+
model_save_path = os.path.join('Models',f'{data_name}_{net_name}/')
|
| 163 |
+
model_dir=model_save_path
|
| 164 |
+
transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']])
|
| 165 |
+
|
| 166 |
+
# Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train')
|
| 167 |
+
|
| 168 |
+
# tsfm = torchvision.transforms.Compose([
|
| 169 |
+
# torchvision.transforms.ToTensor(),
|
| 170 |
+
# ])
|
| 171 |
+
|
| 172 |
+
# dataset = Data_Loader(target_res = [hyp_parameters["img_size"]]*hyp_parameters["ndims"], transforms=None, noise_scale=hyp_parameters['noise_scale'])
|
| 173 |
+
# train_loader = DataLoader(
|
| 174 |
+
# dataset,
|
| 175 |
+
# batch_size=hyp_parameters['batchsize'],
|
| 176 |
+
# # shuffle=False,
|
| 177 |
+
# shuffle=True,
|
| 178 |
+
# drop_last=True,
|
| 179 |
+
# )
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# dataset = OminiDataset_v1(transform=None)
|
| 183 |
+
dataset = OMDataset_indiv(transform=None)
|
| 184 |
+
# datasetp = OminiDataset_paired(transform=None)
|
| 185 |
+
datasetp = OMDataset_pair(transform=None)
|
| 186 |
+
|
| 187 |
+
if use_distributed:
|
| 188 |
+
sampler = DistributedSampler(dataset, shuffle=True)
|
| 189 |
+
sampler_p = DistributedSampler(datasetp, shuffle=True)
|
| 190 |
+
else:
|
| 191 |
+
sampler = None
|
| 192 |
+
sampler_p = None
|
| 193 |
+
|
| 194 |
+
train_loader = DataLoader(
|
| 195 |
+
dataset,
|
| 196 |
+
batch_size=hyp_parameters['batchsize'],
|
| 197 |
+
shuffle=(sampler is None),
|
| 198 |
+
drop_last=True,
|
| 199 |
+
sampler=sampler,
|
| 200 |
+
)
|
| 201 |
+
train_loader_p = DataLoader(
|
| 202 |
+
datasetp,
|
| 203 |
+
batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO),
|
| 204 |
+
shuffle=(sampler_p is None),
|
| 205 |
+
drop_last=True,
|
| 206 |
+
sampler=sampler_p,
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
Deformddpm = DeformDDPM(
|
| 212 |
+
network=Net(
|
| 213 |
+
n_steps=hyp_parameters["timesteps"],
|
| 214 |
+
ndims=hyp_parameters["ndims"],
|
| 215 |
+
num_input_chn = hyp_parameters["num_input_chn"],
|
| 216 |
+
res = hyp_parameters['img_size']
|
| 217 |
+
),
|
| 218 |
+
n_steps=hyp_parameters["timesteps"],
|
| 219 |
+
image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"],
|
| 220 |
+
device=hyp_parameters["device"],
|
| 221 |
+
batch_size=hyp_parameters["batchsize"],
|
| 222 |
+
img_pad_mode=hyp_parameters["img_pad_mode"],
|
| 223 |
+
v_scale=hyp_parameters["v_scale"],
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
ddf_stn = STN(
|
| 228 |
+
img_sz=hyp_parameters["img_size"],
|
| 229 |
+
ndims=hyp_parameters["ndims"],
|
| 230 |
+
# padding_mode="zeros",
|
| 231 |
+
padding_mode=hyp_parameters["padding_mode"],
|
| 232 |
+
device=hyp_parameters["device"],
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if use_distributed:
|
| 237 |
+
device = f"{DEVICE_TYPE}:{rank}"
|
| 238 |
+
Deformddpm.to(device)
|
| 239 |
+
Deformddpm = DDP(Deformddpm, device_ids=[rank])
|
| 240 |
+
ddf_stn.to(device)
|
| 241 |
+
else:
|
| 242 |
+
Deformddpm.to(hyp_parameters["device"])
|
| 243 |
+
ddf_stn.to(hyp_parameters["device"])
|
| 244 |
+
# ddf_stn = DDP(ddf_stn, device_ids=[rank])
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# mse = nn.MSELoss()
|
| 248 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"])
|
| 249 |
+
# loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"])
|
| 250 |
+
loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3)
|
| 251 |
+
loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3)
|
| 252 |
+
|
| 253 |
+
loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 254 |
+
# loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"])
|
| 255 |
+
loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
|
| 256 |
+
loss_imgsim = losses.MSLNCC()
|
| 257 |
+
loss_imgmse = losses.LMSE()
|
| 258 |
+
|
| 259 |
+
optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])
|
| 260 |
+
# hyp_parameters["lr"]=0.00000001
|
| 261 |
+
# optimizer_regist = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01)
|
| 262 |
+
# optimizer_regist = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01, momentum=0.98)
|
| 263 |
+
# optimizer = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"], momentum=0.9)
|
| 264 |
+
|
| 265 |
+
# # LR scheduler ----- YHM
|
| 266 |
+
# scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, hyp_parameters["lr"], hyp_parameters["lr"]*10, step_size_up=500, step_size_down=500, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
|
| 267 |
+
|
| 268 |
+
# Deformddpm.network.load_state_dict(torch.load('/home/data/jzheng/Adaptive_Motion_Generator-master/models/1000.pth'))
|
| 269 |
+
|
| 270 |
+
# check for existing models
|
| 271 |
+
if not os.path.exists(model_dir):
|
| 272 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 273 |
+
model_files = glob.glob(os.path.join(model_dir, "*.pth"))
|
| 274 |
+
model_files.sort()
|
| 275 |
+
if model_files:
|
| 276 |
+
if gpu_id == 0:
|
| 277 |
+
print(model_files)
|
| 278 |
+
initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed)
|
| 279 |
+
else:
|
| 280 |
+
initial_epoch = 0
|
| 281 |
+
|
| 282 |
+
if gpu_id == 0:
|
| 283 |
+
print('len_train_data: ',len(dataset))
|
| 284 |
+
# Training loop
|
| 285 |
+
for epoch in range(initial_epoch,hyp_parameters["epoch"]):
|
| 286 |
+
if use_distributed and sampler is not None:
|
| 287 |
+
sampler.set_epoch(epoch)
|
| 288 |
+
sampler_p.set_epoch(epoch)
|
| 289 |
+
|
| 290 |
+
epoch_loss_tot = 0.0
|
| 291 |
+
epoch_loss_gen_d = 0.0
|
| 292 |
+
epoch_loss_gen_a = 0.0
|
| 293 |
+
epoch_loss_reg = 0.0
|
| 294 |
+
epoch_loss_regist = 0.0
|
| 295 |
+
epoch_loss_imgsim = 0.0
|
| 296 |
+
epoch_loss_imgmse = 0.0
|
| 297 |
+
epoch_loss_ddfreg = 0.0
|
| 298 |
+
epoch_loss_contrastive = 0.0
|
| 299 |
+
# Set model inside to train model
|
| 300 |
+
Deformddpm.train()
|
| 301 |
+
|
| 302 |
+
loss_nan_step = 0 # yu: count the number of nan loss steps
|
| 303 |
+
|
| 304 |
+
total = min(len(train_loader), len(train_loader_p))
|
| 305 |
+
total_reg = total // REGISTRATION_STEP_RATIO
|
| 306 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 307 |
+
# for step, batch in tqdm(enumerate(train_loader)):
|
| 308 |
+
# for step, batch in enumerate(train_loader_omni):
|
| 309 |
+
for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
|
| 310 |
+
|
| 311 |
+
# x0, _ = batch
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# ==========================================================================
|
| 315 |
+
# diffusion train on single image
|
| 316 |
+
|
| 317 |
+
# x0 = batch # for omni dataset
|
| 318 |
+
[x0,embd] = batch # for om dataset
|
| 319 |
+
x0 = x0.to(hyp_parameters["device"]).type(torch.float32)
|
| 320 |
+
# print('embd:', embd.shape)
|
| 321 |
+
embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32)
|
| 322 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 323 |
+
embd_in = embd_dev
|
| 324 |
+
else:
|
| 325 |
+
embd_in = None
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
n = x0.size()[0] # batch_size -> n
|
| 330 |
+
x0 = x0.to(hyp_parameters["device"])
|
| 331 |
+
|
| 332 |
+
blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"])
|
| 333 |
+
|
| 334 |
+
# random deformation + rotation
|
| 335 |
+
if hyp_parameters["ndims"]>2:
|
| 336 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 337 |
+
x0 = utils.random_resample(x0, deform_scale=0)
|
| 338 |
+
# elif np.random.uniform(0,1)<AUG_RESAMPLE_PROB+AUG_PERMUTE_PROB:
|
| 339 |
+
else:
|
| 340 |
+
[x0] = utils.random_permute([x0], select_dims=[-1,-2,-3])
|
| 341 |
+
# x0 = transformer(x0)
|
| 342 |
+
if hyp_parameters['noise_scale']>0:
|
| 343 |
+
if np.random.uniform(0,1)<AUG_RESAMPLE_PROB:
|
| 344 |
+
x0 = thresh_img(x0, [0, 2*hyp_parameters['noise_scale']])
|
| 345 |
+
x0 = x0 * (np.random.normal(1, hyp_parameters['noise_scale'] * 1)) + np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 346 |
+
|
| 347 |
+
# Picking some noise for each of the images in the batch, a timestep and the respective alpha_bars
|
| 348 |
+
t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(
|
| 349 |
+
hyp_parameters["device"]
|
| 350 |
+
) # pick up a seq of rand number from 0 to 'timestep'
|
| 351 |
+
|
| 352 |
+
# proc_type = random.choice(['adding', 'independ', 'downsample', 'slice', 'project', 'none', 'uncon', 'uncon', 'uncon'])
|
| 353 |
+
proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
|
| 354 |
+
# print('proc_type:', proc_type)
|
| 355 |
+
ddpm = Deformddpm.module if use_distributed else Deformddpm
|
| 356 |
+
cond_img, _, cond_ratio = ddpm.proc_cond_img(x0,proc_type=proc_type)
|
| 357 |
+
|
| 358 |
+
pre_dvf_I,dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask,proc_type=[],text=embd_in) # forward diffusion process
|
| 359 |
+
|
| 360 |
+
# print(torch.max(torch.abs(pre_dvf_I)))
|
| 361 |
+
# print(torch.max(torch.abs(dvf_I)))
|
| 362 |
+
|
| 363 |
+
loss_tot=0
|
| 364 |
+
|
| 365 |
+
loss_ddf = loss_reg(pre_dvf_I,img=x0)
|
| 366 |
+
trm_pred = ddf_stn(pre_dvf_I, dvf_I)
|
| 367 |
+
loss_gen_d = loss_dist(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 368 |
+
loss_gen_a = loss_ang(pred=trm_pred,inv_lab=dvf_I,ddf_stn=None,mask=blind_mask)
|
| 369 |
+
|
| 370 |
+
loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
|
| 371 |
+
loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
|
| 372 |
+
loss_tot = torch.sqrt(1.+MSK_EPS-cond_ratio) * loss_tot
|
| 373 |
+
|
| 374 |
+
# >> JZ: print nan in x0
|
| 375 |
+
if torch.isnan(x0).any():
|
| 376 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 377 |
+
# >> JZ: print loss of ddf
|
| 378 |
+
if loss_ddf>0.001:
|
| 379 |
+
print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
|
| 380 |
+
# yu: check if loss_tot==nan or inf
|
| 381 |
+
if torch.isnan(loss_tot) or torch.isinf(loss_tot):
|
| 382 |
+
print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
|
| 383 |
+
loss_nan_step += 1
|
| 384 |
+
continue
|
| 385 |
+
if loss_nan_step > 5:
|
| 386 |
+
print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
|
| 387 |
+
raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
|
| 388 |
+
|
| 389 |
+
optimizer.zero_grad()
|
| 390 |
+
loss_tot.backward()
|
| 391 |
+
optimizer.step()
|
| 392 |
+
|
| 393 |
+
epoch_loss_tot += loss_tot.item() / total
|
| 394 |
+
epoch_loss_gen_d += loss_gen_d.item() / total
|
| 395 |
+
epoch_loss_gen_a += loss_gen_a.item() / total
|
| 396 |
+
epoch_loss_reg += loss_ddf.item() / total
|
| 397 |
+
|
| 398 |
+
# ==========================================================================
|
| 399 |
+
# contrastive train on single image (text-image alignment)
|
| 400 |
+
loss_contra_val = None
|
| 401 |
+
if step % CONTRASTIVE_STEP_RATIO == 0:
|
| 402 |
+
raw_network = Deformddpm.module.network if use_distributed else Deformddpm.network
|
| 403 |
+
n_contra = x0.size()[0]
|
| 404 |
+
t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"])
|
| 405 |
+
_ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
|
| 406 |
+
if hasattr(raw_network, 'img_embd') and raw_network.img_embd is not None:
|
| 407 |
+
img_embd = raw_network.img_embd # [B, 1024]
|
| 408 |
+
loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.05) # contrastive loss to align image embedding with text embedding, with a margin of 0.02
|
| 409 |
+
|
| 410 |
+
optimizer.zero_grad()
|
| 411 |
+
loss_contra.backward()
|
| 412 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
|
| 413 |
+
optimizer.step()
|
| 414 |
+
loss_contra_val = loss_contra.item()
|
| 415 |
+
epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO
|
| 416 |
+
else:
|
| 417 |
+
if gpu_id == 0:
|
| 418 |
+
print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")
|
| 419 |
+
|
| 420 |
+
# ==========================================================================
|
| 421 |
+
# registration train on paired images
|
| 422 |
+
if step%REGISTRATION_STEP_RATIO == 0 and loss_gen_a.item()<-0.6: # only train registration on relatively well-deformed images, to avoid too large registration loss and unstable training in the early stage
|
| 423 |
+
[x1, y1, _, embd_y] = batch_p
|
| 424 |
+
if np.random.uniform(0,1)<TEXT_EMBED_PROB:
|
| 425 |
+
embd_y = embd_y.to(hyp_parameters["device"]).type(torch.float32)
|
| 426 |
+
else:
|
| 427 |
+
embd_y = None
|
| 428 |
+
|
| 429 |
+
x1 = x1.to(hyp_parameters["device"]).type(torch.float32)
|
| 430 |
+
y1 = y1.to(hyp_parameters["device"]).type(torch.float32)
|
| 431 |
+
n = x1.size()[0] # batch_size -> n
|
| 432 |
+
[x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3])
|
| 433 |
+
if hyp_parameters['noise_scale']>0:
|
| 434 |
+
[x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']])
|
| 435 |
+
random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1)
|
| 436 |
+
random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1)
|
| 437 |
+
x1 = x1 * random_scale + random_shift
|
| 438 |
+
y1 = y1 * random_scale + random_shift
|
| 439 |
+
|
| 440 |
+
scale_regist = np.random.uniform(0.0,0.7)
|
| 441 |
+
select_timestep = np.random.randint(12, 25) # select a random number of timesteps to sample, between 8 and 16
|
| 442 |
+
T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True)
|
| 443 |
+
|
| 444 |
+
T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist]
|
| 445 |
+
|
| 446 |
+
proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
|
| 447 |
+
ddpm_inner = Deformddpm.module if use_distributed else Deformddpm
|
| 448 |
+
y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type)
|
| 449 |
+
msk_tgt = msk_tgt+MSK_EPS
|
| 450 |
+
[ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process
|
| 451 |
+
loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process
|
| 452 |
+
loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process
|
| 453 |
+
loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process
|
| 454 |
+
|
| 455 |
+
loss_regist = 0
|
| 456 |
+
loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
|
| 457 |
+
loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
|
| 458 |
+
loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1
|
| 459 |
+
|
| 460 |
+
# >> JZ: print nan in x0
|
| 461 |
+
if torch.isnan(x0).any():
|
| 462 |
+
print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
|
| 463 |
+
# >> JZ: print loss of ddf
|
| 464 |
+
if loss_ddf1>0.002:
|
| 465 |
+
print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
|
| 466 |
+
|
| 467 |
+
loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist
|
| 468 |
+
optimizer.zero_grad()
|
| 469 |
+
loss_regist.backward()
|
| 470 |
+
|
| 471 |
+
torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1)
|
| 472 |
+
optimizer.step()
|
| 473 |
+
|
| 474 |
+
epoch_loss_regist += loss_regist.item()
|
| 475 |
+
epoch_loss_imgsim += loss_sim.item()
|
| 476 |
+
epoch_loss_imgmse += loss_mse.item()
|
| 477 |
+
epoch_loss_ddfreg += loss_ddf1.item()
|
| 478 |
+
else:
|
| 479 |
+
loss_sim = torch.tensor(0.0)
|
| 480 |
+
loss_mse = torch.tensor(0.0)
|
| 481 |
+
loss_ddf1 = torch.tensor(0.0)
|
| 482 |
+
loss_regist = torch.tensor(0.0)
|
| 483 |
+
if step % REGISTRATION_STEP_RATIO==0:
|
| 484 |
+
total_reg = total_reg-1
|
| 485 |
+
|
| 486 |
+
# if step % 50 == 0:
|
| 487 |
+
# print('step:',step,':', loss_tot.item(),'=',loss_gen_a.item(),'+', loss_gen_d.item(),'+',loss_ddf.item())
|
| 488 |
+
# if loss_contra_val is not None:
|
| 489 |
+
# print(f' loss_contrastive: {loss_contra_val:.6f}')
|
| 490 |
+
# print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
|
| 491 |
+
|
| 492 |
+
if gpu_id == 0:
|
| 493 |
+
print('==================')
|
| 494 |
+
print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)')
|
| 495 |
+
print(f' loss_contrastive: {epoch_loss_contrastive}')
|
| 496 |
+
print(f' loss_regist: {epoch_loss_regist/total_reg} = {epoch_loss_imgsim/total_reg} (imgsim) + {epoch_loss_imgmse/total_reg} (imgmse) + {epoch_loss_ddfreg/total_reg} (ddf)')
|
| 497 |
+
print('==================')
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
if 0 == epoch % epoch_per_save:
|
| 501 |
+
save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth
|
| 502 |
+
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
|
| 503 |
+
# break # FOR TESTING
|
| 504 |
+
if not use_distributed:
|
| 505 |
+
print(f"saved in {save_dir}")
|
| 506 |
+
# torch.save(Deformddpm.state_dict(), save_dir)
|
| 507 |
+
torch.save({
|
| 508 |
+
'model_state_dict': Deformddpm.state_dict(),
|
| 509 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 510 |
+
'epoch': epoch
|
| 511 |
+
}, save_dir)
|
| 512 |
+
elif gpu_id == 0:
|
| 513 |
+
print(f"saved in {save_dir}")
|
| 514 |
+
# torch.save(Deformddpm.module.state_dict(), save_dir)
|
| 515 |
+
torch.save({
|
| 516 |
+
'model_state_dict': Deformddpm.module.state_dict(),
|
| 517 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 518 |
+
'epoch': epoch
|
| 519 |
+
}, save_dir)
|
| 520 |
+
|
| 521 |
+
# Resource cleanup at the end of training
|
| 522 |
+
_empty_cache(DEVICE_TYPE)
|
| 523 |
+
gc.collect()
|
| 524 |
+
if use_distributed and dist.is_initialized():
|
| 525 |
+
dist.destroy_process_group()
|
| 526 |
+
|
| 527 |
+
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False):
|
| 528 |
+
|
| 529 |
+
if gpu_id == 0:
|
| 530 |
+
# if 0:
|
| 531 |
+
utils.print_memory_usage("Before Loading Model")
|
| 532 |
+
gc.collect()
|
| 533 |
+
_empty_cache(DEVICE_TYPE)
|
| 534 |
+
# Deformddpm.network.load_state_dict(torch.load(latest_model_file))
|
| 535 |
+
# Deformddpm.load_state_dict(torch.load(latest_model_file), strict=False)
|
| 536 |
+
checkpoint = torch.load(model_file, map_location='cpu')
|
| 537 |
+
# checkpoint = torch.load(latest_model_file, map_location=f"cuda:{rank}")
|
| 538 |
+
if use_distributed:
|
| 539 |
+
Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 540 |
+
else:
|
| 541 |
+
Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
|
| 542 |
+
if load_strict:
|
| 543 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 544 |
+
utils.print_memory_usage("After Loading Checkpoint on GPU")
|
| 545 |
+
|
| 546 |
+
if use_distributed:
|
| 547 |
+
# Broadcast model weights from rank 0 to all other GPUs
|
| 548 |
+
dist.barrier()
|
| 549 |
+
for param in Deformddpm.parameters():
|
| 550 |
+
dist.broadcast(param.data, src=0) # Synchronize model across ranks
|
| 551 |
+
dist.barrier()
|
| 552 |
+
for param_group in optimizer.param_groups:
|
| 553 |
+
for param in param_group['params']:
|
| 554 |
+
if param.grad is not None:
|
| 555 |
+
dist.broadcast(param.grad, src=0) # Sync optimizer gradients
|
| 556 |
+
|
| 557 |
+
# initial_epoch = checkpoint['epoch'] + 1
|
| 558 |
+
# get the epoch number from the filename and add 1 to set as initial_epoch
|
| 559 |
+
initial_epoch = int(os.path.basename(model_file).split('.')[0][:6]) + 1
|
| 560 |
+
|
| 561 |
+
return initial_epoch, Deformddpm, optimizer
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
if __name__ == "__main__":
|
| 566 |
+
if "LOCAL_RANK" in os.environ:
|
| 567 |
+
# Multi-node: launched by torchrun / srun
|
| 568 |
+
use_distributed = True
|
| 569 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
| 570 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
| 571 |
+
print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
|
| 572 |
+
try:
|
| 573 |
+
main_train(local_rank, world_size)
|
| 574 |
+
except Exception as e:
|
| 575 |
+
import traceback
|
| 576 |
+
print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True)
|
| 577 |
+
traceback.print_exc()
|
| 578 |
+
raise
|
| 579 |
+
elif use_distributed:
|
| 580 |
+
# Single-node multi-GPU: use mp.spawn
|
| 581 |
+
world_size = _device_count(DEVICE_TYPE)
|
| 582 |
+
print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
|
| 583 |
+
mp.spawn(main_train,args = (world_size,),nprocs = world_size)
|
| 584 |
+
else:
|
| 585 |
+
main_train(0,1)
|
OMorpher/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .omorpher import OMorpher
|
| 2 |
+
|
| 3 |
+
__all__ = ['OMorpher']
|
OMorpher/omorpher.py
ADDED
|
@@ -0,0 +1,1058 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OMorpher — Object-oriented wrapper for OmniMorph diffusion-based deformation.
|
| 3 |
+
|
| 4 |
+
Stores original high-res images and composes all intermediate deformations as
|
| 5 |
+
deformation fields (DDFs), resampling only once at the end to avoid blurring.
|
| 6 |
+
Independent of DeformDDPM at runtime; reimplements the diffusion logic using
|
| 7 |
+
the network / STN / loss building blocks from Diffusion.*.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import glob
|
| 12 |
+
import math
|
| 13 |
+
import random
|
| 14 |
+
from typing import Optional, Union, List, Tuple, Dict
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
import yaml
|
| 22 |
+
import SimpleITK as sitk
|
| 23 |
+
from skimage.transform import resize as sk_resize
|
| 24 |
+
|
| 25 |
+
from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
|
| 26 |
+
from Diffusion.losses import Grad, MRSE, NCC
|
| 27 |
+
|
| 28 |
+
EPS = 1e-8
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class OMorpher:
|
| 32 |
+
"""High-level interface for OmniMorph deformation diffusion.
|
| 33 |
+
|
| 34 |
+
All images are kept at their original resolution internally. Deformation
|
| 35 |
+
fields are composed at model resolution and up-scaled on demand so that the
|
| 36 |
+
original image is resampled at most *once*.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
# ------------------------------------------------------------------
|
| 40 |
+
# Construction
|
| 41 |
+
# ------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
config: Union[str, dict],
|
| 46 |
+
checkpoint_path: Optional[str] = None,
|
| 47 |
+
device: Optional[str] = None,
|
| 48 |
+
bert_model_path: Optional[str] = None,
|
| 49 |
+
):
|
| 50 |
+
# ---- Config ----
|
| 51 |
+
if isinstance(config, str):
|
| 52 |
+
with open(config, "r") as f:
|
| 53 |
+
config = yaml.safe_load(f)
|
| 54 |
+
self.config: dict = config
|
| 55 |
+
|
| 56 |
+
self.net_name: str = config.get("net_name", "recmutattnnet")
|
| 57 |
+
self.ndims: int = config.get("ndims", 3)
|
| 58 |
+
self.img_size: int = config.get("img_size", 128)
|
| 59 |
+
self.timesteps: int = config.get("timesteps", 80)
|
| 60 |
+
self.v_scale: float = config.get("v_scale", 5e-5)
|
| 61 |
+
self.noise_scale: float = config.get("noise_scale", 0.1)
|
| 62 |
+
self.condition_type: str = config.get("condition_type", "none")
|
| 63 |
+
self.num_input_chn: int = config.get("num_input_chn", 1)
|
| 64 |
+
self.img_pad_mode: str = config.get("img_pad_mode", "zeros")
|
| 65 |
+
self.ddf_pad_mode: str = config.get("ddf_pad_mode", "border")
|
| 66 |
+
self.padding_mode: str = config.get("padding_mode", "border")
|
| 67 |
+
self.resample_mode: str = config.get("resample_mode", "bilinear")
|
| 68 |
+
self.batch_size: int = config.get("batchsize", 1)
|
| 69 |
+
self.data_name: str = config.get("data_name", "all")
|
| 70 |
+
self.clamp_range: list = config.get("clamp_range", [-400, 400])
|
| 71 |
+
self.inf_mode: bool = config.get("inf_mode", True)
|
| 72 |
+
|
| 73 |
+
# ---- Device ----
|
| 74 |
+
if device is not None:
|
| 75 |
+
self.device = torch.device(device)
|
| 76 |
+
else:
|
| 77 |
+
self.device = self._resolve_device(config.get("device", None))
|
| 78 |
+
|
| 79 |
+
# ---- BERT (lazy) ----
|
| 80 |
+
self.bert_model_path = bert_model_path or os.path.join(
|
| 81 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
| 82 |
+
"External", "Models", "bert_large_uncased",
|
| 83 |
+
)
|
| 84 |
+
self._bert_model = None
|
| 85 |
+
self._bert_tokenizer = None
|
| 86 |
+
|
| 87 |
+
# ---- Network ----
|
| 88 |
+
Net = get_net(self.net_name)
|
| 89 |
+
self.network = Net(
|
| 90 |
+
n_steps=self.timesteps,
|
| 91 |
+
ndims=self.ndims,
|
| 92 |
+
num_input_chn=self.num_input_chn,
|
| 93 |
+
res=self.img_size,
|
| 94 |
+
)
|
| 95 |
+
self.network.to(self.device)
|
| 96 |
+
|
| 97 |
+
# ---- STN instances ----
|
| 98 |
+
self.ctl_ratio = 4
|
| 99 |
+
self.ctl_sz = self.img_size // self.ctl_ratio
|
| 100 |
+
|
| 101 |
+
self.stn_full = STN(
|
| 102 |
+
img_sz=self.img_size,
|
| 103 |
+
ndims=self.ndims,
|
| 104 |
+
padding_mode=self.padding_mode,
|
| 105 |
+
device=self.device,
|
| 106 |
+
)
|
| 107 |
+
self.stn_ctl = STN(
|
| 108 |
+
img_sz=self.ctl_sz,
|
| 109 |
+
ndims=self.ndims,
|
| 110 |
+
padding_mode=self.ddf_pad_mode,
|
| 111 |
+
device=self.device,
|
| 112 |
+
)
|
| 113 |
+
self.img_stn = STN(
|
| 114 |
+
img_sz=self.img_size,
|
| 115 |
+
ndims=self.ndims,
|
| 116 |
+
padding_mode=self.img_pad_mode,
|
| 117 |
+
device=self.device,
|
| 118 |
+
resample_mode=self.resample_mode if self.resample_mode != "bilinear" else None,
|
| 119 |
+
)
|
| 120 |
+
self.msk_stn = STN(
|
| 121 |
+
img_sz=self.img_size,
|
| 122 |
+
ndims=self.ndims,
|
| 123 |
+
padding_mode=self.img_pad_mode,
|
| 124 |
+
device=self.device,
|
| 125 |
+
resample_mode="nearest",
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# ---- Loss functions (for fine-tuning) ----
|
| 129 |
+
self._loss_grad = Grad(penalty=["l1"], ndims=self.ndims)
|
| 130 |
+
self._loss_dist = MRSE(img_sz=self.img_size)
|
| 131 |
+
self._loss_ang = NCC(img_sz=self.img_size)
|
| 132 |
+
|
| 133 |
+
# ---- Load checkpoint ----
|
| 134 |
+
if checkpoint_path is not None:
|
| 135 |
+
self._load_checkpoint(checkpoint_path)
|
| 136 |
+
else:
|
| 137 |
+
auto_path = self._auto_find_checkpoint()
|
| 138 |
+
if auto_path is not None:
|
| 139 |
+
self._load_checkpoint(auto_path)
|
| 140 |
+
|
| 141 |
+
self.network.eval()
|
| 142 |
+
|
| 143 |
+
# ---- State ----
|
| 144 |
+
self._init_img: Optional[torch.Tensor] = None # [B,1,S,S,S] model-res
|
| 145 |
+
self._init_img_raw: Optional[torch.Tensor] = None # [B,1,D,H,W] full-res
|
| 146 |
+
self._init_img_original_shape: Optional[tuple] = None
|
| 147 |
+
self._init_ddf: Optional[torch.Tensor] = None # [B,ndims,S,S,S]
|
| 148 |
+
self._cond_img: Optional[torch.Tensor] = None # [B,1,S,S,S]
|
| 149 |
+
self._cond_txt: Optional[torch.Tensor] = None # [B,1024]
|
| 150 |
+
self._predicted_ddf: Optional[torch.Tensor] = None # [B,ndims,S,S,S]
|
| 151 |
+
self._intermediate_ddfs: List[Tuple[int, torch.Tensor]] = []
|
| 152 |
+
|
| 153 |
+
# ---- Fine-tuning state ----
|
| 154 |
+
self._optimizer: Optional[torch.optim.Optimizer] = None
|
| 155 |
+
|
| 156 |
+
# ------------------------------------------------------------------
|
| 157 |
+
# Device resolution
|
| 158 |
+
# ------------------------------------------------------------------
|
| 159 |
+
|
| 160 |
+
@staticmethod
|
| 161 |
+
def _resolve_device(hint: Optional[str] = None) -> torch.device:
|
| 162 |
+
if hint is not None:
|
| 163 |
+
s = str(hint).lower()
|
| 164 |
+
if s not in ("auto", ""):
|
| 165 |
+
return torch.device(s)
|
| 166 |
+
# XPU → CUDA → CPU
|
| 167 |
+
try:
|
| 168 |
+
import intel_extension_for_pytorch # noqa: F401
|
| 169 |
+
if torch.xpu.is_available():
|
| 170 |
+
return torch.device("xpu")
|
| 171 |
+
except (ImportError, AttributeError):
|
| 172 |
+
pass
|
| 173 |
+
if torch.cuda.is_available():
|
| 174 |
+
return torch.device("cuda")
|
| 175 |
+
return torch.device("cpu")
|
| 176 |
+
|
| 177 |
+
# ------------------------------------------------------------------
|
| 178 |
+
# Checkpoint helpers
|
| 179 |
+
# ------------------------------------------------------------------
|
| 180 |
+
|
| 181 |
+
def _auto_find_checkpoint(self) -> Optional[str]:
|
| 182 |
+
pattern = os.path.join(
|
| 183 |
+
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
| 184 |
+
"Models",
|
| 185 |
+
f"{self.data_name}_{self.net_name}",
|
| 186 |
+
"*.pth",
|
| 187 |
+
)
|
| 188 |
+
files = sorted(glob.glob(pattern))
|
| 189 |
+
return files[-1] if files else None
|
| 190 |
+
|
| 191 |
+
def _load_checkpoint(self, path: str):
|
| 192 |
+
ckpt = torch.load(path, map_location="cpu")
|
| 193 |
+
state_dict = ckpt.get("model_state_dict", ckpt)
|
| 194 |
+
# Strip DDP 'module.' prefix and DeformDDPM wrapper keys
|
| 195 |
+
cleaned = {}
|
| 196 |
+
for k, v in state_dict.items():
|
| 197 |
+
k = k.replace("module.", "")
|
| 198 |
+
if k.startswith("network."):
|
| 199 |
+
k = k[len("network."):]
|
| 200 |
+
cleaned[k] = v
|
| 201 |
+
# Only load keys that exist in the network
|
| 202 |
+
net_keys = set(self.network.state_dict().keys())
|
| 203 |
+
filtered = {k: v for k, v in cleaned.items() if k in net_keys}
|
| 204 |
+
if filtered:
|
| 205 |
+
self.network.load_state_dict(filtered, strict=False)
|
| 206 |
+
|
| 207 |
+
# ------------------------------------------------------------------
|
| 208 |
+
# Public — Input setters
|
| 209 |
+
# ------------------------------------------------------------------
|
| 210 |
+
|
| 211 |
+
def set_init_img(
|
| 212 |
+
self,
|
| 213 |
+
img,
|
| 214 |
+
modality: Optional[str] = None,
|
| 215 |
+
) -> "OMorpher":
|
| 216 |
+
"""Set the initial image. Accepts numpy, torch, path, or (img, ddf) tuple."""
|
| 217 |
+
init_ddf = None
|
| 218 |
+
if isinstance(img, (tuple, list)):
|
| 219 |
+
img, init_ddf = img[0], img[1]
|
| 220 |
+
|
| 221 |
+
model_tensor, fullres_tensor, orig_shape = self._standardize_img(
|
| 222 |
+
img, modality=modality, keep_raw=True,
|
| 223 |
+
)
|
| 224 |
+
self._init_img = model_tensor
|
| 225 |
+
self._init_img_raw = fullres_tensor
|
| 226 |
+
self._init_img_original_shape = orig_shape
|
| 227 |
+
|
| 228 |
+
if init_ddf is not None:
|
| 229 |
+
self._init_ddf = self._to_ddf_tensor(init_ddf)
|
| 230 |
+
else:
|
| 231 |
+
B = self._init_img.shape[0]
|
| 232 |
+
S = self.img_size
|
| 233 |
+
self._init_ddf = torch.zeros(
|
| 234 |
+
[B, self.ndims] + [S] * self.ndims,
|
| 235 |
+
dtype=torch.float32, device=self.device,
|
| 236 |
+
)
|
| 237 |
+
return self
|
| 238 |
+
|
| 239 |
+
def set_cond_img(
|
| 240 |
+
self,
|
| 241 |
+
img=None,
|
| 242 |
+
modality: Optional[str] = None,
|
| 243 |
+
) -> "OMorpher":
|
| 244 |
+
"""Set the conditioning image. Default: Gaussian noise sigma=0.1."""
|
| 245 |
+
if img is None:
|
| 246 |
+
B = self._init_img.shape[0] if self._init_img is not None else self.batch_size
|
| 247 |
+
S = self.img_size
|
| 248 |
+
self._cond_img = torch.randn(
|
| 249 |
+
[B, 1] + [S] * self.ndims,
|
| 250 |
+
dtype=torch.float32, device=self.device,
|
| 251 |
+
) * 0.1
|
| 252 |
+
else:
|
| 253 |
+
tensor, _, _ = self._standardize_img(img, modality=modality, keep_raw=False)
|
| 254 |
+
self._cond_img = tensor
|
| 255 |
+
return self
|
| 256 |
+
|
| 257 |
+
def set_cond_txt(self, txt=None) -> "OMorpher":
|
| 258 |
+
"""Set the text conditioning. Accepts string, numpy [1024], torch [1024], or None."""
|
| 259 |
+
self._cond_txt = self._standardize_txt(txt)
|
| 260 |
+
return self
|
| 261 |
+
|
| 262 |
+
def set_init_def(self, ddf=None) -> "OMorpher":
|
| 263 |
+
"""Set or regenerate the initial deformation field.
|
| 264 |
+
|
| 265 |
+
If *ddf* is ``None``, a random DDF is generated using the forward
|
| 266 |
+
diffusion parameters (useful for data augmentation).
|
| 267 |
+
"""
|
| 268 |
+
if ddf is None:
|
| 269 |
+
if self._init_img is None:
|
| 270 |
+
raise RuntimeError("set_init_img() must be called before set_init_def()")
|
| 271 |
+
t_val = self.config.get("start_noise_step", self.timesteps // 2)
|
| 272 |
+
t = torch.tensor([t_val], dtype=torch.long, device=self.device)
|
| 273 |
+
_, _, random_ddf = self._get_random_ddf(self._init_img, t)
|
| 274 |
+
self._init_ddf = random_ddf
|
| 275 |
+
else:
|
| 276 |
+
self._init_ddf = self._to_ddf_tensor(ddf)
|
| 277 |
+
return self
|
| 278 |
+
|
| 279 |
+
# ------------------------------------------------------------------
|
| 280 |
+
# Public — Core operations (inference)
|
| 281 |
+
# ------------------------------------------------------------------
|
| 282 |
+
|
| 283 |
+
def predict(
|
| 284 |
+
self,
|
| 285 |
+
T: Optional[list] = None,
|
| 286 |
+
proc_type: Optional[str] = None,
|
| 287 |
+
t_save: Optional[list] = None,
|
| 288 |
+
) -> "OMorpher":
|
| 289 |
+
"""Run reverse diffusion and store predicted DDF. Returns ``self`` for chaining."""
|
| 290 |
+
if self._init_img is None:
|
| 291 |
+
raise RuntimeError("set_init_img() must be called before predict()")
|
| 292 |
+
|
| 293 |
+
# Defaults
|
| 294 |
+
start_noise = self.config.get("start_noise_step", 0)
|
| 295 |
+
if T is None:
|
| 296 |
+
T = [start_noise, self.timesteps]
|
| 297 |
+
if proc_type is None:
|
| 298 |
+
proc_type = self.condition_type
|
| 299 |
+
|
| 300 |
+
B = self._init_img.shape[0]
|
| 301 |
+
S = self.img_size
|
| 302 |
+
|
| 303 |
+
# Conditioning
|
| 304 |
+
cond_img_src = self._cond_img if self._cond_img is not None else self._init_img.clone().detach()
|
| 305 |
+
cond_img, mask, cond_ratio = self._proc_cond_img(cond_img_src, proc_type=proc_type)
|
| 306 |
+
|
| 307 |
+
# Text embedding
|
| 308 |
+
txt = self._cond_txt
|
| 309 |
+
if txt is None:
|
| 310 |
+
txt = torch.zeros([B, 1024], dtype=torch.float32, device=self.device)
|
| 311 |
+
|
| 312 |
+
# Reshape text for network consumption
|
| 313 |
+
if isinstance(self.network, DefRec_MutAttnNet):
|
| 314 |
+
txt = txt.view(B, -1, *([1] * self.ndims))
|
| 315 |
+
|
| 316 |
+
# Initial state
|
| 317 |
+
init_ddf_is_zero = (self._init_ddf is None) or torch.all(self._init_ddf == 0)
|
| 318 |
+
|
| 319 |
+
if not init_ddf_is_zero:
|
| 320 |
+
ddf_comp = self._init_ddf.clone()
|
| 321 |
+
img_rec = self.img_stn(self._init_img, ddf_comp)
|
| 322 |
+
elif T[0] is not None and T[0] > 0:
|
| 323 |
+
t_start = torch.tensor(np.array([T[0]]), device=self.device)
|
| 324 |
+
img_rec, _, ddf_comp = self._get_random_ddf(self._init_img, t_start)
|
| 325 |
+
else:
|
| 326 |
+
img_rec = self._init_img.clone()
|
| 327 |
+
ddf_comp = torch.zeros(
|
| 328 |
+
[B, self.ndims] + [S] * self.ndims,
|
| 329 |
+
dtype=torch.float32, device=self.device,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Reverse diffusion loop
|
| 333 |
+
self._intermediate_ddfs = []
|
| 334 |
+
|
| 335 |
+
rec_num = 2 # matches DeformDDPM.rec_num default
|
| 336 |
+
|
| 337 |
+
if isinstance(self.network, DefRec_MutAttnNet):
|
| 338 |
+
# DefRec network: pass full time list at once
|
| 339 |
+
t_list = list(range(T[1] - 1, -1, -1))
|
| 340 |
+
with torch.no_grad():
|
| 341 |
+
pre_dvf = self.network(
|
| 342 |
+
x=img_rec, y=cond_img, t=t_list, rec_num=rec_num, text=txt,
|
| 343 |
+
)
|
| 344 |
+
ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
|
| 345 |
+
img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
|
| 346 |
+
if t_save:
|
| 347 |
+
self._intermediate_ddfs.append((0, ddf_comp.clone()))
|
| 348 |
+
else:
|
| 349 |
+
# Standard iterative recovery
|
| 350 |
+
time_steps = range(T[1] - 1, -1, -1)
|
| 351 |
+
for i in time_steps:
|
| 352 |
+
t = torch.tensor(np.array([i]), device=self.device)
|
| 353 |
+
with torch.no_grad():
|
| 354 |
+
pre_dvf = self.network(
|
| 355 |
+
x=img_rec, y=cond_img, t=t, rec_num=rec_num, text=txt,
|
| 356 |
+
)
|
| 357 |
+
ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
|
| 358 |
+
img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
|
| 359 |
+
if t_save is not None and i in t_save:
|
| 360 |
+
self._intermediate_ddfs.append((i, ddf_comp.clone()))
|
| 361 |
+
|
| 362 |
+
self._predicted_ddf = ddf_comp
|
| 363 |
+
return self
|
| 364 |
+
|
| 365 |
+
def get_def(
|
| 366 |
+
self,
|
| 367 |
+
t_list: Optional[list] = None,
|
| 368 |
+
) -> Union[torch.Tensor, Dict[int, torch.Tensor]]:
|
| 369 |
+
"""Return the final predicted DDF, or intermediate DDFs for given timesteps."""
|
| 370 |
+
if t_list is None:
|
| 371 |
+
if self._predicted_ddf is None:
|
| 372 |
+
raise RuntimeError("predict() must be called before get_def()")
|
| 373 |
+
return self._predicted_ddf
|
| 374 |
+
out = {}
|
| 375 |
+
for t, ddf in self._intermediate_ddfs:
|
| 376 |
+
if t in t_list:
|
| 377 |
+
out[t] = ddf
|
| 378 |
+
return out
|
| 379 |
+
|
| 380 |
+
def apply_def(
|
| 381 |
+
self,
|
| 382 |
+
img=None,
|
| 383 |
+
ddf: Optional[torch.Tensor] = None,
|
| 384 |
+
padding_mode: Optional[str] = None,
|
| 385 |
+
resample_mode: Optional[str] = None,
|
| 386 |
+
) -> torch.Tensor:
|
| 387 |
+
"""Apply a DDF to an image. Auto-upscales DDF when sizes differ.
|
| 388 |
+
|
| 389 |
+
Defaults: init image at full resolution, predicted DDF.
|
| 390 |
+
"""
|
| 391 |
+
if padding_mode is None:
|
| 392 |
+
padding_mode = self.padding_mode
|
| 393 |
+
if resample_mode is None:
|
| 394 |
+
resample_mode = "bilinear"
|
| 395 |
+
|
| 396 |
+
# Default DDF
|
| 397 |
+
if ddf is None:
|
| 398 |
+
if self._predicted_ddf is None:
|
| 399 |
+
raise RuntimeError("predict() must be called before apply_def()")
|
| 400 |
+
ddf = self._predicted_ddf
|
| 401 |
+
|
| 402 |
+
# Default image: full-res init image tensor
|
| 403 |
+
if img is None:
|
| 404 |
+
if self._init_img_raw is not None:
|
| 405 |
+
vol_tensor = self._init_img_raw
|
| 406 |
+
else:
|
| 407 |
+
vol_tensor = self._init_img
|
| 408 |
+
else:
|
| 409 |
+
vol_tensor = self._ensure_tensor(img)
|
| 410 |
+
|
| 411 |
+
# Upscale DDF if sizes differ
|
| 412 |
+
target_sz = list(vol_tensor.shape[2:])
|
| 413 |
+
ddf_sz = list(ddf.shape[2:])
|
| 414 |
+
if target_sz != ddf_sz:
|
| 415 |
+
ddf = F.interpolate(
|
| 416 |
+
ddf, size=target_sz,
|
| 417 |
+
mode="bilinear" if self.ndims == 2 else "trilinear",
|
| 418 |
+
align_corners=False,
|
| 419 |
+
)
|
| 420 |
+
|
| 421 |
+
return self._apply_ddf(vol_tensor, ddf, padding_mode=padding_mode, resample_mode=resample_mode)
|
| 422 |
+
|
| 423 |
+
# ------------------------------------------------------------------
|
| 424 |
+
# Public — Fine-tuning
|
| 425 |
+
# ------------------------------------------------------------------
|
| 426 |
+
|
| 427 |
+
def finetune_setup(
|
| 428 |
+
self,
|
| 429 |
+
lr: float = 1e-4,
|
| 430 |
+
optimizer_cls=None,
|
| 431 |
+
) -> "OMorpher":
|
| 432 |
+
"""Switch to training mode and create an optimizer."""
|
| 433 |
+
self.network.train()
|
| 434 |
+
self.inf_mode = False
|
| 435 |
+
if optimizer_cls is None:
|
| 436 |
+
optimizer_cls = torch.optim.Adam
|
| 437 |
+
self._optimizer = optimizer_cls(self.network.parameters(), lr=lr)
|
| 438 |
+
return self
|
| 439 |
+
|
| 440 |
+
def finetune_step(
|
| 441 |
+
self,
|
| 442 |
+
img_batch,
|
| 443 |
+
cond_batch=None,
|
| 444 |
+
text_batch=None,
|
| 445 |
+
t=None,
|
| 446 |
+
proc_type=None,
|
| 447 |
+
) -> dict:
|
| 448 |
+
"""Single training step. Returns loss dict."""
|
| 449 |
+
if self._optimizer is None:
|
| 450 |
+
raise RuntimeError("finetune_setup() must be called first")
|
| 451 |
+
|
| 452 |
+
img, _, _ = self._standardize_img(img_batch, keep_raw=False)
|
| 453 |
+
cond = self._standardize_img(cond_batch, keep_raw=False)[0] if cond_batch is not None else img.clone()
|
| 454 |
+
text = self._standardize_txt(text_batch)
|
| 455 |
+
|
| 456 |
+
B = img.shape[0]
|
| 457 |
+
if t is None:
|
| 458 |
+
t = torch.randint(0, self.timesteps, (B,), device=self.device)
|
| 459 |
+
else:
|
| 460 |
+
t = torch.tensor(t, device=self.device) if not isinstance(t, torch.Tensor) else t.to(self.device)
|
| 461 |
+
|
| 462 |
+
proc_type = proc_type or self.condition_type
|
| 463 |
+
cond_img, mask, cond_ratio = self._proc_cond_img(cond, proc_type=proc_type)
|
| 464 |
+
noisy_img, dvf_gt, _ = self._get_random_ddf(img, t)
|
| 465 |
+
|
| 466 |
+
# Reshape text for network
|
| 467 |
+
if isinstance(self.network, DefRec_MutAttnNet):
|
| 468 |
+
if text is not None:
|
| 469 |
+
text = text.view(B, -1, *([1] * self.ndims))
|
| 470 |
+
t_input = [t]
|
| 471 |
+
else:
|
| 472 |
+
t_input = t
|
| 473 |
+
|
| 474 |
+
pre_dvf = self.network(x=noisy_img * mask, y=cond_img, t=t_input, rec_num=2, text=text)
|
| 475 |
+
|
| 476 |
+
loss_grad = self._loss_grad(y_pred=pre_dvf, img=img)
|
| 477 |
+
trm_pred = self.stn_full(pre_dvf, dvf_gt)
|
| 478 |
+
loss_dist = self._loss_dist(pred=trm_pred, inv_lab=dvf_gt)
|
| 479 |
+
loss_ang = self._loss_ang(pred=trm_pred, inv_lab=dvf_gt)
|
| 480 |
+
loss_total = 2.0 * loss_ang + 1.0 * loss_dist + 16.0 * loss_grad
|
| 481 |
+
|
| 482 |
+
self._optimizer.zero_grad()
|
| 483 |
+
loss_total.backward()
|
| 484 |
+
self._optimizer.step()
|
| 485 |
+
|
| 486 |
+
return {
|
| 487 |
+
"loss_total": loss_total.item(),
|
| 488 |
+
"loss_grad": loss_grad.item(),
|
| 489 |
+
"loss_dist": loss_dist.item(),
|
| 490 |
+
"loss_ang": loss_ang.item(),
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
def finetune_save(self, path: str, epoch: int = 0):
|
| 494 |
+
"""Save checkpoint in the standard OmniMorph format."""
|
| 495 |
+
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
|
| 496 |
+
torch.save(
|
| 497 |
+
{
|
| 498 |
+
"model_state_dict": self.network.state_dict(),
|
| 499 |
+
"optimizer_state_dict": self._optimizer.state_dict() if self._optimizer else None,
|
| 500 |
+
"epoch": epoch,
|
| 501 |
+
},
|
| 502 |
+
path,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
def finetune_teardown(self) -> "OMorpher":
|
| 506 |
+
"""Switch back to eval mode."""
|
| 507 |
+
self.network.eval()
|
| 508 |
+
self.inf_mode = True
|
| 509 |
+
self._optimizer = None
|
| 510 |
+
return self
|
| 511 |
+
|
| 512 |
+
# ------------------------------------------------------------------
|
| 513 |
+
# Private — Diffusion logic
|
| 514 |
+
# ------------------------------------------------------------------
|
| 515 |
+
|
| 516 |
+
def _get_ddf_scale(
|
| 517 |
+
self, t: torch.Tensor, divide_num: int = 1, max_ddf_num: int = 200,
|
| 518 |
+
) -> Tuple[int, torch.Tensor, torch.Tensor]:
|
| 519 |
+
"""Timestep-dependent deformation magnitude. Mirrors DeformDDPM._get_ddf_scale()."""
|
| 520 |
+
rec_num = 1
|
| 521 |
+
mul_num_ddf = torch.floor_divide(2 * torch.pow(t.float(), 1.3), 3 * divide_num).int()
|
| 522 |
+
mul_num_dvf = torch.floor_divide(torch.pow(t.float(), 0.6), divide_num).int()
|
| 523 |
+
mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
|
| 524 |
+
mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
|
| 525 |
+
return rec_num, mul_num_ddf, mul_num_dvf
|
| 526 |
+
|
| 527 |
+
def _sample_random_uniform_multi_order(
|
| 528 |
+
self, high=None, low=0.0, order_num=3,
|
| 529 |
+
) -> float:
|
| 530 |
+
sample_value = low
|
| 531 |
+
for _ in range(order_num):
|
| 532 |
+
sample_value = np.random.uniform(low=sample_value, high=high)
|
| 533 |
+
return sample_value
|
| 534 |
+
|
| 535 |
+
def _multiscale_dvf_generate(
|
| 536 |
+
self, v_scale: float, ctl_szs: list = None, rand_v_scale: bool = True,
|
| 537 |
+
) -> torch.Tensor:
|
| 538 |
+
"""Multi-scale Gaussian DVF at control-point sizes."""
|
| 539 |
+
if ctl_szs is None:
|
| 540 |
+
ctl_szs = [4, 8, 16, 32, 64]
|
| 541 |
+
dvf = 0
|
| 542 |
+
for ctl_sz in ctl_szs:
|
| 543 |
+
_v = (
|
| 544 |
+
self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2)
|
| 545 |
+
if rand_v_scale
|
| 546 |
+
else v_scale
|
| 547 |
+
)
|
| 548 |
+
if ctl_sz <= 2:
|
| 549 |
+
_v = _v / 2
|
| 550 |
+
dvf_comp = torch.randn(
|
| 551 |
+
[self.batch_size, self.ndims] + [ctl_sz] * self.ndims
|
| 552 |
+
) * _v
|
| 553 |
+
dvf_comp = F.interpolate(
|
| 554 |
+
dvf_comp * self.ctl_sz / ctl_sz,
|
| 555 |
+
[self.ctl_sz] * self.ndims,
|
| 556 |
+
align_corners=False,
|
| 557 |
+
mode="bilinear" if self.ndims == 2 else "trilinear",
|
| 558 |
+
)
|
| 559 |
+
dvf = dvf + dvf_comp
|
| 560 |
+
return dvf
|
| 561 |
+
|
| 562 |
+
def _random_ddf_generate(
|
| 563 |
+
self,
|
| 564 |
+
rec_num: int = 3,
|
| 565 |
+
mul_num: list = None,
|
| 566 |
+
noise_ratio: float = 0.08,
|
| 567 |
+
select_num: int = 4,
|
| 568 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 569 |
+
"""Compose DVFs to build a DDF. Mirrors DeformDDPM._random_ddf_generate()."""
|
| 570 |
+
if mul_num is None:
|
| 571 |
+
mul_num = [torch.tensor([5]), torch.tensor([5])]
|
| 572 |
+
|
| 573 |
+
crop_rate = 2
|
| 574 |
+
# unsqueeze mul_num for broadcasting
|
| 575 |
+
for _ in range(self.ndims + 1):
|
| 576 |
+
mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
|
| 577 |
+
|
| 578 |
+
ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
|
| 579 |
+
ddf = torch.zeros(ctl_ddf_sz)
|
| 580 |
+
dddf = torch.zeros(ctl_ddf_sz)
|
| 581 |
+
scale_num = min(8, int(math.log2(self.ctl_sz)))
|
| 582 |
+
ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
|
| 583 |
+
|
| 584 |
+
for _i in range(rec_num):
|
| 585 |
+
if len(ctl_szs_all) > select_num:
|
| 586 |
+
ctl_szs = random.sample(ctl_szs_all, select_num)
|
| 587 |
+
else:
|
| 588 |
+
ctl_szs = ctl_szs_all
|
| 589 |
+
dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
|
| 590 |
+
if noise_ratio == 0:
|
| 591 |
+
dvf0 = dvf
|
| 592 |
+
else:
|
| 593 |
+
dvf0 = dvf + self.stn_ctl(
|
| 594 |
+
self._multiscale_dvf_generate(
|
| 595 |
+
self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False,
|
| 596 |
+
).to(self.device),
|
| 597 |
+
dvf,
|
| 598 |
+
)
|
| 599 |
+
for j in range(torch.max(mul_num[0]).item()):
|
| 600 |
+
flag = [(n > j).int().to(self.device) for n in mul_num]
|
| 601 |
+
ddf = dvf0 * flag[0] + self.stn_ctl(ddf, dvf0 * flag[0])
|
| 602 |
+
dddf = dvf * flag[1] + self.stn_ctl(dddf, dvf * flag[1])
|
| 603 |
+
|
| 604 |
+
# Upscale and center-crop
|
| 605 |
+
interp_mode = "bilinear" if self.ndims == 2 else "trilinear"
|
| 606 |
+
ddf = F.interpolate(
|
| 607 |
+
ddf * self.img_size / self.ctl_sz,
|
| 608 |
+
self.img_size * crop_rate,
|
| 609 |
+
mode=interp_mode,
|
| 610 |
+
)
|
| 611 |
+
dddf = F.interpolate(
|
| 612 |
+
dddf * self.img_size / self.ctl_sz,
|
| 613 |
+
self.img_size * crop_rate,
|
| 614 |
+
mode=interp_mode,
|
| 615 |
+
)
|
| 616 |
+
half = self.img_size // 2
|
| 617 |
+
three_half = self.img_size * 3 // 2
|
| 618 |
+
if self.ndims == 2:
|
| 619 |
+
ddf = ddf[..., half:three_half, half:three_half]
|
| 620 |
+
dddf = dddf[..., half:three_half, half:three_half]
|
| 621 |
+
else:
|
| 622 |
+
ddf = ddf[..., half:three_half, half:three_half, half:three_half]
|
| 623 |
+
dddf = dddf[..., half:three_half, half:three_half, half:three_half]
|
| 624 |
+
return ddf, dddf
|
| 625 |
+
|
| 626 |
+
def _get_random_ddf(
|
| 627 |
+
self, img: torch.Tensor, t: torch.Tensor,
|
| 628 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 629 |
+
"""Forward-diffuse: generate random DDF and warp image."""
|
| 630 |
+
rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
|
| 631 |
+
ddf_forward, dvf_forward = self._random_ddf_generate(
|
| 632 |
+
rec_num=rec_num, mul_num=[mul_num_ddf, mul_num_dvf],
|
| 633 |
+
)
|
| 634 |
+
warped_img = self.img_stn(img, ddf_forward)
|
| 635 |
+
return warped_img, dvf_forward, ddf_forward
|
| 636 |
+
|
| 637 |
+
# ------------------------------------------------------------------
|
| 638 |
+
# Private — Conditioning processing
|
| 639 |
+
# ------------------------------------------------------------------
|
| 640 |
+
|
| 641 |
+
def _proc_cond_img(
|
| 642 |
+
self,
|
| 643 |
+
img: torch.Tensor,
|
| 644 |
+
proc_type: Optional[str] = None,
|
| 645 |
+
noise_scale: float = 0.1,
|
| 646 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 647 |
+
"""Conditioning strategies. Mirrors DeformDDPM.proc_cond_img()."""
|
| 648 |
+
proc_img = img.clone().detach()
|
| 649 |
+
if proc_type is None:
|
| 650 |
+
proc_type = random.choices(
|
| 651 |
+
["adding", "independ", "downsample", "slice", "none", "uncon"],
|
| 652 |
+
weights=[1, 1, 1, 1, 1, 3],
|
| 653 |
+
k=1,
|
| 654 |
+
)[0]
|
| 655 |
+
|
| 656 |
+
mask = torch.tensor(1, device=img.device)
|
| 657 |
+
cond_ratio = torch.tensor(1.0, device=img.device)
|
| 658 |
+
|
| 659 |
+
if proc_type in ["none", None, "", "None"]:
|
| 660 |
+
return proc_img, mask, cond_ratio
|
| 661 |
+
|
| 662 |
+
noise_type = random.choice(["gaussian", "uniform", "none"])
|
| 663 |
+
|
| 664 |
+
if proc_type == "uncon":
|
| 665 |
+
noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
|
| 666 |
+
return noise_map, torch.tensor(0, device=img.device), torch.tensor(0, device=img.device)
|
| 667 |
+
|
| 668 |
+
noise_map = None
|
| 669 |
+
if proc_type in ["adding", "independ", "slice"]:
|
| 670 |
+
noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
|
| 671 |
+
|
| 672 |
+
if proc_type == "adding":
|
| 673 |
+
noise_ratio = np.random.uniform(0.0, 1.0)
|
| 674 |
+
proc_img = proc_img * (1 - noise_ratio) + noise_map * noise_ratio
|
| 675 |
+
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
|
| 676 |
+
elif proc_type == "independ":
|
| 677 |
+
mask = self._create_noise_map(img, noise_type="binary")
|
| 678 |
+
proc_img = img * mask
|
| 679 |
+
cond_ratio = mask.float().mean()
|
| 680 |
+
elif proc_type == "downsample":
|
| 681 |
+
down_ratio = list(np.random.uniform(1.0 / 64, 1, [self.ndims]))
|
| 682 |
+
down_img = F.interpolate(
|
| 683 |
+
proc_img, scale_factor=down_ratio,
|
| 684 |
+
mode="bilinear" if self.ndims == 2 else "trilinear",
|
| 685 |
+
)
|
| 686 |
+
proc_img = F.interpolate(
|
| 687 |
+
down_img, size=[self.img_size] * self.ndims,
|
| 688 |
+
mode="bilinear" if self.ndims == 2 else "trilinear",
|
| 689 |
+
align_corners=False,
|
| 690 |
+
)
|
| 691 |
+
cond_ratio = torch.tensor(np.sqrt(np.prod(down_ratio)), device=img.device)
|
| 692 |
+
elif proc_type == "slice":
|
| 693 |
+
slice_num_max = random.randint(1, 64)
|
| 694 |
+
slice_num_max = random.randint(1, slice_num_max)
|
| 695 |
+
mask, sample_ratio = self._get_slice_mask(img, slice_num_range=[0, slice_num_max])
|
| 696 |
+
proc_img = img * mask
|
| 697 |
+
cond_ratio = torch.tensor(sample_ratio, device=img.device)
|
| 698 |
+
elif proc_type == "project":
|
| 699 |
+
proj_img = torch.zeros_like(img)
|
| 700 |
+
rand_bourn = np.random.randint(0, 2, size=[self.ndims])
|
| 701 |
+
proj_dim_num = np.sum(rand_bourn)
|
| 702 |
+
for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
|
| 703 |
+
if pflag:
|
| 704 |
+
proj_img += torch.mean(img, dim=i, keepdim=True)
|
| 705 |
+
proc_img = proj_img / (proj_dim_num + EPS)
|
| 706 |
+
cond_ratio = torch.tensor(proj_dim_num / (128 * self.ndims), device=img.device)
|
| 707 |
+
|
| 708 |
+
return proc_img, mask, cond_ratio
|
| 709 |
+
|
| 710 |
+
def _create_noise_map(
|
| 711 |
+
self,
|
| 712 |
+
img: torch.Tensor,
|
| 713 |
+
noise_type: str = "gaussian",
|
| 714 |
+
noise_scale: float = 0.1,
|
| 715 |
+
) -> torch.Tensor:
|
| 716 |
+
if noise_type == "gaussian":
|
| 717 |
+
return (torch.randn_like(img) * noise_scale).to(img.device)
|
| 718 |
+
elif noise_type == "uniform":
|
| 719 |
+
return (torch.rand_like(img) * noise_scale * 2 - noise_scale).to(img.device)
|
| 720 |
+
elif noise_type == "binary":
|
| 721 |
+
return torch.bernoulli(torch.rand_like(img)).to(img.device)
|
| 722 |
+
return torch.zeros_like(img).to(img.device)
|
| 723 |
+
|
| 724 |
+
def _get_slice_mask(
|
| 725 |
+
self,
|
| 726 |
+
img: torch.Tensor,
|
| 727 |
+
slice_num_range: list = None,
|
| 728 |
+
) -> Tuple[torch.Tensor, float]:
|
| 729 |
+
if slice_num_range is None:
|
| 730 |
+
slice_num_range = [0, 32]
|
| 731 |
+
slice_num_range[1] = min(slice_num_range[1], self.img_size)
|
| 732 |
+
mask = torch.zeros_like(img)
|
| 733 |
+
sample_ratio = 0.0
|
| 734 |
+
for i in range(self.ndims):
|
| 735 |
+
if self.inf_mode:
|
| 736 |
+
slice_num = 1
|
| 737 |
+
slice_idx = [self.img_size // 2]
|
| 738 |
+
else:
|
| 739 |
+
slice_num = random.randint(slice_num_range[0], slice_num_range[1])
|
| 740 |
+
slice_idx = random.sample(range(self.img_size), slice_num)
|
| 741 |
+
transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
|
| 742 |
+
for idx in slice_idx:
|
| 743 |
+
mask[..., idx] = 1
|
| 744 |
+
mask = mask.permute(*transpose_list)
|
| 745 |
+
sample_ratio += np.sqrt(slice_num / self.img_size) / self.ndims
|
| 746 |
+
return mask, sample_ratio
|
| 747 |
+
|
| 748 |
+
# ------------------------------------------------------------------
|
| 749 |
+
# Private — Standardization
|
| 750 |
+
# ------------------------------------------------------------------
|
| 751 |
+
|
| 752 |
+
def _standardize_img(
|
| 753 |
+
self,
|
| 754 |
+
img,
|
| 755 |
+
modality: Optional[str] = None,
|
| 756 |
+
keep_raw: bool = False,
|
| 757 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
|
| 758 |
+
"""Deterministic inference variant of the dataloader pipeline.
|
| 759 |
+
|
| 760 |
+
Returns ``(model_tensor, fullres_tensor_or_None, orig_shape_or_None)``.
|
| 761 |
+
|
| 762 |
+
* *model_tensor*: ``[B, C, S, S, S]`` at model resolution.
|
| 763 |
+
* *fullres_tensor*: ``[B, C, D, H, W]`` at original padded resolution
|
| 764 |
+
(only when *keep_raw=True*).
|
| 765 |
+
* *orig_shape*: spatial dims of padded volume before resize.
|
| 766 |
+
|
| 767 |
+
Accepts numpy arrays, torch tensors (any dimensionality), or a
|
| 768 |
+
file path (loaded via SimpleITK). Torch tensors with >= 4 dims
|
| 769 |
+
are treated as already-batched and are passed through with
|
| 770 |
+
appropriate device/dtype conversion.
|
| 771 |
+
"""
|
| 772 |
+
fullres_tensor = None
|
| 773 |
+
orig_shape = None
|
| 774 |
+
|
| 775 |
+
# 1. Load from path
|
| 776 |
+
if isinstance(img, str):
|
| 777 |
+
sitk_img = sitk.ReadImage(img)
|
| 778 |
+
vol = sitk.GetArrayFromImage(sitk_img)
|
| 779 |
+
vol = self._reverse_axis_order(vol)
|
| 780 |
+
elif isinstance(img, np.ndarray):
|
| 781 |
+
vol = img.copy()
|
| 782 |
+
elif isinstance(img, torch.Tensor):
|
| 783 |
+
# If already a batched tensor [B,C,...], pass through
|
| 784 |
+
if img.ndim >= 4:
|
| 785 |
+
t = img.float().to(self.device)
|
| 786 |
+
if keep_raw:
|
| 787 |
+
fullres_tensor = t.clone()
|
| 788 |
+
return t, fullres_tensor, None
|
| 789 |
+
# 1-3D tensor — treat as spatial-only numpy
|
| 790 |
+
vol = img.numpy()
|
| 791 |
+
else:
|
| 792 |
+
raise TypeError(f"Unsupported image type: {type(img)}")
|
| 793 |
+
|
| 794 |
+
# 2. Extract 3D from 4D
|
| 795 |
+
if vol.ndim == 4:
|
| 796 |
+
vol = vol[:, :, :, 0]
|
| 797 |
+
|
| 798 |
+
# 3. CT clamping
|
| 799 |
+
if modality is not None and modality.upper() == "CT" and self.clamp_range is not None:
|
| 800 |
+
vol = np.clip(vol, self.clamp_range[0], self.clamp_range[1])
|
| 801 |
+
|
| 802 |
+
# 4. Normalize [0, 1]
|
| 803 |
+
vol = vol.astype(np.float64)
|
| 804 |
+
vol = (vol - np.min(vol)) / (np.ptp(vol) + 1e-7)
|
| 805 |
+
|
| 806 |
+
# 5. Center-pad to cube
|
| 807 |
+
vol = self._center_pad_to_cube(vol)
|
| 808 |
+
orig_shape = vol.shape[:3]
|
| 809 |
+
|
| 810 |
+
# 6. Full-res tensor (before resize)
|
| 811 |
+
if keep_raw:
|
| 812 |
+
fullres_tensor = torch.tensor(
|
| 813 |
+
vol[None, None, ...], dtype=torch.float32, device=self.device,
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
# 7. Resize to model resolution
|
| 817 |
+
target_sz = [self.img_size] * self.ndims
|
| 818 |
+
vol_resized = sk_resize(
|
| 819 |
+
vol, target_sz, anti_aliasing=True, preserve_range=True,
|
| 820 |
+
)
|
| 821 |
+
|
| 822 |
+
# 8. Add batch + channel dims
|
| 823 |
+
model_tensor = torch.tensor(
|
| 824 |
+
vol_resized[None, None, ...], dtype=torch.float32, device=self.device,
|
| 825 |
+
)
|
| 826 |
+
return model_tensor, fullres_tensor, orig_shape
|
| 827 |
+
|
| 828 |
+
def _standardize_label(
|
| 829 |
+
self,
|
| 830 |
+
label,
|
| 831 |
+
fill_value: float = -1,
|
| 832 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 833 |
+
"""Standardize a label volume for inference.
|
| 834 |
+
|
| 835 |
+
Returns ``(model_tensor, fullres_tensor)``.
|
| 836 |
+
|
| 837 |
+
* *model_tensor*: ``[1, C, S, S, S]`` at model resolution
|
| 838 |
+
(nearest-neighbor resize, no anti-aliasing).
|
| 839 |
+
* *fullres_tensor*: ``[1, C, D, H, W]`` at original padded resolution.
|
| 840 |
+
|
| 841 |
+
If *label* is ``None``, returns *fill_value*-filled placeholders
|
| 842 |
+
shaped to match the current init image (model-res and full-res).
|
| 843 |
+
|
| 844 |
+
Accepts numpy arrays or torch tensors. Does NOT apply
|
| 845 |
+
normalization or clamping (labels are discrete indices).
|
| 846 |
+
"""
|
| 847 |
+
# --- Placeholder for missing labels ---
|
| 848 |
+
if label is None:
|
| 849 |
+
model_sz = [self.img_size] * self.ndims
|
| 850 |
+
model_t = torch.full(
|
| 851 |
+
[1, 1] + model_sz, fill_value,
|
| 852 |
+
dtype=torch.float32, device=self.device,
|
| 853 |
+
)
|
| 854 |
+
if self._init_img_raw is not None:
|
| 855 |
+
fullres_sz = list(self._init_img_raw.shape[2:])
|
| 856 |
+
else:
|
| 857 |
+
fullres_sz = model_sz
|
| 858 |
+
fullres_t = torch.full(
|
| 859 |
+
[1, 1] + fullres_sz, fill_value,
|
| 860 |
+
dtype=torch.float32, device=self.device,
|
| 861 |
+
)
|
| 862 |
+
return model_t, fullres_t
|
| 863 |
+
|
| 864 |
+
# --- Convert to numpy if needed ---
|
| 865 |
+
if isinstance(label, torch.Tensor):
|
| 866 |
+
if label.ndim >= 4:
|
| 867 |
+
# Already batched tensor — pass through
|
| 868 |
+
fullres_t = label.float().to(self.device)
|
| 869 |
+
target_sz = [self.img_size] * self.ndims
|
| 870 |
+
model_t = F.interpolate(
|
| 871 |
+
fullres_t, size=target_sz, mode="nearest",
|
| 872 |
+
)
|
| 873 |
+
return model_t, fullres_t
|
| 874 |
+
lab = label.numpy()
|
| 875 |
+
elif isinstance(label, np.ndarray):
|
| 876 |
+
lab = label.copy()
|
| 877 |
+
else:
|
| 878 |
+
raise TypeError(f"Unsupported label type: {type(label)}")
|
| 879 |
+
|
| 880 |
+
# --- Center-pad to cube ---
|
| 881 |
+
lab = self._center_pad_to_cube(lab)
|
| 882 |
+
|
| 883 |
+
# --- Channel dim: 3D→[C=1,...], 4D→channels-first [C,...] ---
|
| 884 |
+
if lab.ndim == 3:
|
| 885 |
+
lab = lab[None, :, :, :] # [1, D, H, W]
|
| 886 |
+
elif lab.ndim > 3:
|
| 887 |
+
lab = np.transpose(lab, (3, 0, 1, 2)) # [C, D, H, W]
|
| 888 |
+
|
| 889 |
+
# --- Full-res tensor ---
|
| 890 |
+
fullres_t = torch.tensor(
|
| 891 |
+
lab[None, ...], dtype=torch.float32, device=self.device,
|
| 892 |
+
) # [1, C, D, H, W]
|
| 893 |
+
|
| 894 |
+
# --- Resize to model resolution (nearest-neighbor) ---
|
| 895 |
+
target_sz = [self.img_size] * self.ndims
|
| 896 |
+
# Resize each channel separately to avoid resizing the channel dim
|
| 897 |
+
channels = []
|
| 898 |
+
for c in range(lab.shape[0]):
|
| 899 |
+
ch = sk_resize(
|
| 900 |
+
lab[c], target_sz,
|
| 901 |
+
anti_aliasing=False, preserve_range=True, order=0,
|
| 902 |
+
)
|
| 903 |
+
channels.append(ch)
|
| 904 |
+
lab_model = np.stack(channels, axis=0) # [C, S, S, S]
|
| 905 |
+
model_t = torch.tensor(
|
| 906 |
+
lab_model[None, ...], dtype=torch.float32, device=self.device,
|
| 907 |
+
) # [1, C, S, S, S]
|
| 908 |
+
|
| 909 |
+
return model_t, fullres_t
|
| 910 |
+
|
| 911 |
+
def _standardize_txt(self, txt) -> Optional[torch.Tensor]:
|
| 912 |
+
"""Convert text input to [B, 1024] tensor."""
|
| 913 |
+
if txt is None:
|
| 914 |
+
return None
|
| 915 |
+
if isinstance(txt, str):
|
| 916 |
+
self._ensure_bert()
|
| 917 |
+
from Dataloader.bert_helper import str2emb
|
| 918 |
+
emb = str2emb(
|
| 919 |
+
txt, max_words_num=100,
|
| 920 |
+
embeder=self._bert_model, tokenizer=self._bert_tokenizer,
|
| 921 |
+
reduce_method="mean",
|
| 922 |
+
)
|
| 923 |
+
return emb.to(self.device) # [1, 1024]
|
| 924 |
+
if isinstance(txt, np.ndarray):
|
| 925 |
+
t = torch.tensor(txt, dtype=torch.float32, device=self.device)
|
| 926 |
+
if t.ndim == 1:
|
| 927 |
+
t = t.unsqueeze(0)
|
| 928 |
+
return t
|
| 929 |
+
if isinstance(txt, torch.Tensor):
|
| 930 |
+
t = txt.float().to(self.device)
|
| 931 |
+
if t.ndim == 1:
|
| 932 |
+
t = t.unsqueeze(0)
|
| 933 |
+
return t
|
| 934 |
+
raise TypeError(f"Unsupported text type: {type(txt)}")
|
| 935 |
+
|
| 936 |
+
def _ensure_bert(self):
|
| 937 |
+
if self._bert_model is None:
|
| 938 |
+
from Dataloader.bert_helper import get_frozen_embeder
|
| 939 |
+
self._bert_model, self._bert_tokenizer = get_frozen_embeder(self.bert_model_path)
|
| 940 |
+
|
| 941 |
+
# ------------------------------------------------------------------
|
| 942 |
+
# Private — Spatial utilities
|
| 943 |
+
# ------------------------------------------------------------------
|
| 944 |
+
|
| 945 |
+
@staticmethod
|
| 946 |
+
def _reverse_axis_order(arr: np.ndarray) -> np.ndarray:
|
| 947 |
+
"""SimpleITK → NumPy axis order."""
|
| 948 |
+
return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1])))
|
| 949 |
+
|
| 950 |
+
@staticmethod
|
| 951 |
+
def _center_pad_to_cube(volume: np.ndarray) -> np.ndarray:
|
| 952 |
+
"""Pad volume to a cube using the max dimension, with symmetric padding."""
|
| 953 |
+
max_dim = max(volume.shape[:3])
|
| 954 |
+
pad_width = []
|
| 955 |
+
for s in volume.shape[:3]:
|
| 956 |
+
total_pad = max_dim - s
|
| 957 |
+
pad_before = total_pad // 2
|
| 958 |
+
pad_after = total_pad - pad_before
|
| 959 |
+
pad_width.append((pad_before, pad_after))
|
| 960 |
+
for _ in range(volume.ndim - 3):
|
| 961 |
+
pad_width.append((0, 0))
|
| 962 |
+
return np.pad(volume, pad_width, mode="constant", constant_values=0)
|
| 963 |
+
|
| 964 |
+
def _apply_ddf(
|
| 965 |
+
self,
|
| 966 |
+
volume_tensor: torch.Tensor,
|
| 967 |
+
ddf: torch.Tensor,
|
| 968 |
+
padding_mode: str = "border",
|
| 969 |
+
resample_mode: str = "bilinear",
|
| 970 |
+
) -> torch.Tensor:
|
| 971 |
+
"""Apply DDF to volume tensor at any resolution via grid_sample."""
|
| 972 |
+
device = ddf.device
|
| 973 |
+
ndims = self.ndims
|
| 974 |
+
img_sz = list(volume_tensor.shape[2:])
|
| 975 |
+
max_sz = torch.reshape(
|
| 976 |
+
torch.tensor(img_sz, dtype=torch.float32, device=device),
|
| 977 |
+
[1, ndims] + [1] * ndims,
|
| 978 |
+
)
|
| 979 |
+
ref_grid = torch.reshape(
|
| 980 |
+
torch.stack(
|
| 981 |
+
torch.meshgrid(
|
| 982 |
+
[torch.arange(s, device=device, dtype=torch.float32) for s in img_sz],
|
| 983 |
+
indexing="ij",
|
| 984 |
+
),
|
| 985 |
+
0,
|
| 986 |
+
),
|
| 987 |
+
[1, ndims] + img_sz,
|
| 988 |
+
)
|
| 989 |
+
img_shape = torch.reshape(
|
| 990 |
+
torch.tensor(
|
| 991 |
+
[(s - 1) / 2.0 for s in img_sz], dtype=torch.float32, device=device,
|
| 992 |
+
),
|
| 993 |
+
[1] + [1] * ndims + [ndims],
|
| 994 |
+
)
|
| 995 |
+
grid = torch.flip(
|
| 996 |
+
(ddf * max_sz + ref_grid).permute(
|
| 997 |
+
[0] + list(range(2, 2 + ndims)) + [1]
|
| 998 |
+
)
|
| 999 |
+
/ img_shape
|
| 1000 |
+
- 1,
|
| 1001 |
+
dims=[-1],
|
| 1002 |
+
)
|
| 1003 |
+
return F.grid_sample(
|
| 1004 |
+
volume_tensor.to(device),
|
| 1005 |
+
grid.float(),
|
| 1006 |
+
mode=resample_mode,
|
| 1007 |
+
padding_mode=padding_mode,
|
| 1008 |
+
align_corners=True,
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
def _ensure_tensor(self, img) -> torch.Tensor:
|
| 1012 |
+
"""Convert numpy/torch input to a [B, C, ...] float tensor on device."""
|
| 1013 |
+
if isinstance(img, np.ndarray):
|
| 1014 |
+
t = torch.tensor(img, dtype=torch.float32, device=self.device)
|
| 1015 |
+
elif isinstance(img, torch.Tensor):
|
| 1016 |
+
t = img.float().to(self.device)
|
| 1017 |
+
else:
|
| 1018 |
+
raise TypeError(f"Unsupported image type: {type(img)}")
|
| 1019 |
+
if t.ndim == self.ndims: # spatial only → [B=1, C=1, ...]
|
| 1020 |
+
t = t[None, None, ...]
|
| 1021 |
+
elif t.ndim == self.ndims + 1: # [C, ...] → [B=1, C, ...]
|
| 1022 |
+
t = t[None, ...]
|
| 1023 |
+
return t
|
| 1024 |
+
|
| 1025 |
+
def _to_ddf_tensor(self, ddf) -> torch.Tensor:
|
| 1026 |
+
"""Convert ddf input to proper tensor on device."""
|
| 1027 |
+
if isinstance(ddf, np.ndarray):
|
| 1028 |
+
ddf = torch.tensor(ddf, dtype=torch.float32)
|
| 1029 |
+
ddf = ddf.float().to(self.device)
|
| 1030 |
+
if ddf.ndim == self.ndims + 1:
|
| 1031 |
+
ddf = ddf.unsqueeze(0)
|
| 1032 |
+
# Resize to model resolution if needed
|
| 1033 |
+
model_sz = [self.img_size] * self.ndims
|
| 1034 |
+
if list(ddf.shape[2:]) != model_sz:
|
| 1035 |
+
ddf = F.interpolate(
|
| 1036 |
+
ddf, size=model_sz,
|
| 1037 |
+
mode="bilinear" if self.ndims == 2 else "trilinear",
|
| 1038 |
+
align_corners=False,
|
| 1039 |
+
)
|
| 1040 |
+
return ddf
|
| 1041 |
+
|
| 1042 |
+
# ------------------------------------------------------------------
|
| 1043 |
+
# Convenience / repr
|
| 1044 |
+
# ------------------------------------------------------------------
|
| 1045 |
+
|
| 1046 |
+
def __repr__(self) -> str:
|
| 1047 |
+
status_parts = []
|
| 1048 |
+
if self._init_img is not None:
|
| 1049 |
+
status_parts.append(f"init_img={list(self._init_img.shape)}")
|
| 1050 |
+
if self._cond_img is not None:
|
| 1051 |
+
status_parts.append(f"cond_img={list(self._cond_img.shape)}")
|
| 1052 |
+
if self._predicted_ddf is not None:
|
| 1053 |
+
status_parts.append(f"predicted_ddf={list(self._predicted_ddf.shape)}")
|
| 1054 |
+
status = ", ".join(status_parts) if status_parts else "empty"
|
| 1055 |
+
return (
|
| 1056 |
+
f"OMorpher(net={self.net_name}, ndims={self.ndims}, "
|
| 1057 |
+
f"img_size={self.img_size}, device={self.device}, {status})"
|
| 1058 |
+
)
|
README.md
CHANGED
|
@@ -1,80 +1,129 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
-
|
| 7 |
-
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
```
|
| 31 |
-
|
| 32 |
-
``
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
```
|
| 36 |
-
|
| 37 |
-
```
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
``
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
#
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
tags:
|
| 4 |
+
- medical-imaging
|
| 5 |
+
- registration
|
| 6 |
+
- diffusion
|
| 7 |
+
- 3d
|
| 8 |
+
- image-generation
|
| 9 |
+
- image-restoration
|
| 10 |
+
- pytorch
|
| 11 |
+
library_name: pytorch
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# OmniMorph
|
| 15 |
+
|
| 16 |
+
**Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**
|
| 17 |
+
|
| 18 |
+
OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:
|
| 19 |
+
|
| 20 |
+
- **Generation** — text-conditioned image synthesis via BERT embeddings.
|
| 21 |
+
- **Restoration** — recover anatomically plausible images from degraded inputs.
|
| 22 |
+
- **Registration** — paired / unpaired / flexible-resolution registration via diffused deformation vector fields.
|
| 23 |
+
|
| 24 |
+
## Repository Contents
|
| 25 |
+
|
| 26 |
+
| Path | Description |
|
| 27 |
+
|---|---|
|
| 28 |
+
| `OM_train*.py` | Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU) |
|
| 29 |
+
| `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference / augmentation / registration / contrastive scripts |
|
| 30 |
+
| `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utils |
|
| 31 |
+
| `OMorpher/` | Higher-level model wrapper |
|
| 32 |
+
| `Dataloader/` | Multi-modality dataloaders + dataset mappings (16 datasets) |
|
| 33 |
+
| `Config/` | YAML training/inference configs |
|
| 34 |
+
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
|
| 35 |
+
| `tests/` | Pytest suite for `OMorpher` and loss functions |
|
| 36 |
+
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
|
| 37 |
+
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint (epoch 110, multi-modal `recmulmodmutattnnet`) |
|
| 38 |
+
|
| 39 |
+
> **Note** Only the final checkpoint (epoch 110) is shipped here. Earlier epochs and the `bert_large_uncased` weights are not bundled — download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
|
| 40 |
+
|
| 41 |
+
## Setup
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
git clone https://huggingface.co/DRDMsig/Omini3D
|
| 45 |
+
cd Omini3D
|
| 46 |
+
pip install -r requirements.txt
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
For Intel XPU / Dawn cluster, install the matching `intel-extension-for-pytorch` build before installing the rest of the requirements.
|
| 50 |
+
|
| 51 |
+
## Quick Start
|
| 52 |
+
|
| 53 |
+
### Training
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
# Single-mode diffusion
|
| 57 |
+
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml
|
| 58 |
+
|
| 59 |
+
# Dual mode (diffusion + registration)
|
| 60 |
+
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml
|
| 61 |
+
|
| 62 |
+
# Triple mode (diffusion + contrastive + registration)
|
| 63 |
+
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml
|
| 64 |
+
|
| 65 |
+
# Intel XPU (single node)
|
| 66 |
+
sbatch bash_train_single_node.sh
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
### Inference
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
# Augmentation / restoration with a trained model
|
| 73 |
+
python OM_aug.py -C Config/config_om.yaml
|
| 74 |
+
|
| 75 |
+
# Paired registration
|
| 76 |
+
python OM_reg.py -C Config/config_om.yaml
|
| 77 |
+
|
| 78 |
+
# Flexible-resolution registration
|
| 79 |
+
python OM_reg_flexres.py -C Config/config_om.yaml
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
### Loading the checkpoint
|
| 83 |
+
|
| 84 |
+
```python
|
| 85 |
+
import torch
|
| 86 |
+
from Diffusion.networks import get_net
|
| 87 |
+
|
| 88 |
+
# Production network (multi-modal recmutattnnet)
|
| 89 |
+
net = get_net("recmulmodmutattnnet")
|
| 90 |
+
state = torch.load("Models/all_om_net/000110_all_om_net.pth", map_location="cpu")
|
| 91 |
+
net.load_state_dict(state["model"] if "model" in state else state)
|
| 92 |
+
net.eval()
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
## Architecture
|
| 96 |
+
|
| 97 |
+
```
|
| 98 |
+
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
- **`DeformDDPM`** (`Diffusion/diffuser.py`) — forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`.
|
| 102 |
+
- **Networks** (`Diffusion/networks.py`) — selectable via `get_net(name)`:
|
| 103 |
+
- `recmulmodmutattnnet` — current production multi-modal multi-head-attention net (used by `000110_all_om_net.pth`)
|
| 104 |
+
- `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
|
| 105 |
+
- **`STN`** — Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
|
| 106 |
+
- **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`) — `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.
|
| 107 |
+
|
| 108 |
+
## Datasets Supported
|
| 109 |
+
|
| 110 |
+
`Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including:
|
| 111 |
+
AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.
|
| 112 |
+
|
| 113 |
+
The dataset files themselves are **not** included; obtain them from their respective sources and update the mapping paths.
|
| 114 |
+
|
| 115 |
+
## Citation
|
| 116 |
+
|
| 117 |
+
```bibtex
|
| 118 |
+
@article{omnimorph,
|
| 119 |
+
title = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
|
| 120 |
+
Restoration and Registration via Conditional Deformation-Recovery
|
| 121 |
+
Diffusion Models},
|
| 122 |
+
author = {Zheng, J. and Mo, M. and others},
|
| 123 |
+
year = {2025}
|
| 124 |
+
}
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
## License
|
| 128 |
+
|
| 129 |
+
MIT — see `LICENSE`.
|
Scripts/OM_aug_om.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OM_aug_om.py — Augmentation using OMorpher.
|
| 3 |
+
|
| 4 |
+
Drop-in replacement for OM_aug.py. Produces identical outputs but uses
|
| 5 |
+
OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python Scripts/OM_aug_om.py -C Config/config_om.yaml
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import argparse
|
| 14 |
+
|
| 15 |
+
# Add project root to path so imports work from Scripts/
|
| 16 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import nibabel as nib
|
| 21 |
+
import yaml
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
|
| 24 |
+
import utils
|
| 25 |
+
from Dataloader.dataLoader import OminiDataset_inference_w_all
|
| 26 |
+
from torch.utils.data import DataLoader
|
| 27 |
+
from OMorpher import OMorpher
|
| 28 |
+
|
| 29 |
+
# ========== CLI ==========
|
| 30 |
+
|
| 31 |
+
parser = argparse.ArgumentParser()
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--config", "-C",
|
| 34 |
+
help="Path for the config file",
|
| 35 |
+
type=str,
|
| 36 |
+
default="Config/config_cmr.yaml",
|
| 37 |
+
required=False,
|
| 38 |
+
)
|
| 39 |
+
args = parser.parse_args()
|
| 40 |
+
|
| 41 |
+
# ========== Config ==========
|
| 42 |
+
|
| 43 |
+
with open(args.config, "r") as file:
|
| 44 |
+
hyp_parameters = yaml.safe_load(file)
|
| 45 |
+
print(hyp_parameters)
|
| 46 |
+
|
| 47 |
+
if not os.path.exists(hyp_parameters["aug_img_savepath"]):
|
| 48 |
+
os.makedirs(hyp_parameters["aug_img_savepath"])
|
| 49 |
+
if not os.path.exists(hyp_parameters["aug_msk_savepath"]):
|
| 50 |
+
os.makedirs(hyp_parameters["aug_msk_savepath"])
|
| 51 |
+
if not os.path.exists(hyp_parameters["aug_ddf_savepath"]):
|
| 52 |
+
os.makedirs(hyp_parameters["aug_ddf_savepath"])
|
| 53 |
+
print(hyp_parameters["aug_img_savepath"])
|
| 54 |
+
|
| 55 |
+
hyp_parameters["batchsize"] = 1
|
| 56 |
+
|
| 57 |
+
# ========== Dataset (identical to OM_aug.py) ==========
|
| 58 |
+
|
| 59 |
+
select_channels_dict = {}
|
| 60 |
+
min_crop_ratio = 0.9
|
| 61 |
+
|
| 62 |
+
label_keys = ["heart"]
|
| 63 |
+
database = ["MnMs"]
|
| 64 |
+
subtype = "es"
|
| 65 |
+
hyp_parameters["aug_img_savepath"] = f"Data/Aug_data/mnms_{subtype}/img/"
|
| 66 |
+
hyp_parameters["aug_msk_savepath"] = f"Data/Aug_data/mnms_{subtype}/msk/"
|
| 67 |
+
hyp_parameters["aug_ddf_savepath"] = f"Data/Aug_data/mnms_{subtype}/ddf/"
|
| 68 |
+
select_channels_dict = {"ImgDict": [subtype]}
|
| 69 |
+
|
| 70 |
+
dataset = OminiDataset_inference_w_all(
|
| 71 |
+
transform=None,
|
| 72 |
+
min_crop_ratio=min_crop_ratio,
|
| 73 |
+
label_key=label_keys,
|
| 74 |
+
database=database,
|
| 75 |
+
select_channels_dict=select_channels_dict,
|
| 76 |
+
)
|
| 77 |
+
Infer_Loader = DataLoader(
|
| 78 |
+
dataset,
|
| 79 |
+
batch_size=hyp_parameters["batchsize"],
|
| 80 |
+
shuffle=False,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# ========== OMorpher setup ==========
|
| 84 |
+
|
| 85 |
+
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
|
| 86 |
+
model_save_path = os.path.join(
|
| 87 |
+
f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
|
| 88 |
+
str(epoch) + ".pth",
|
| 89 |
+
)
|
| 90 |
+
print("Loading model from:", model_save_path)
|
| 91 |
+
|
| 92 |
+
om = OMorpher(
|
| 93 |
+
config=hyp_parameters,
|
| 94 |
+
checkpoint_path=model_save_path,
|
| 95 |
+
device=str(hyp_parameters.get("device", "cpu")),
|
| 96 |
+
)
|
| 97 |
+
print(om)
|
| 98 |
+
|
| 99 |
+
# ========== Output directories ==========
|
| 100 |
+
|
| 101 |
+
os.makedirs(hyp_parameters["aug_img_savepath"], exist_ok=True)
|
| 102 |
+
os.makedirs(hyp_parameters["aug_msk_savepath"], exist_ok=True)
|
| 103 |
+
os.makedirs(hyp_parameters["aug_ddf_savepath"], exist_ok=True)
|
| 104 |
+
|
| 105 |
+
# ========== Main inference loop ==========
|
| 106 |
+
|
| 107 |
+
device = om.device
|
| 108 |
+
print("total num of image:", len(Infer_Loader))
|
| 109 |
+
|
| 110 |
+
for e, d in tqdm(enumerate(Infer_Loader)):
|
| 111 |
+
img = d["img"]
|
| 112 |
+
mask = d["labels"]
|
| 113 |
+
label_str = str(d["label_channels"])
|
| 114 |
+
pid = e
|
| 115 |
+
|
| 116 |
+
print("Processing to patient:", pid, " image:", e)
|
| 117 |
+
|
| 118 |
+
img = img.type(torch.float32).to(device)
|
| 119 |
+
image_original = img.cpu().detach().numpy()
|
| 120 |
+
|
| 121 |
+
mask = mask.type(torch.float32).to(device)
|
| 122 |
+
mask_original = mask.cpu().detach().numpy()
|
| 123 |
+
|
| 124 |
+
# Save original image and mask
|
| 125 |
+
nifti_img = utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"])
|
| 126 |
+
nifti_mask = utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"])
|
| 127 |
+
|
| 128 |
+
nib.save(
|
| 129 |
+
nifti_img,
|
| 130 |
+
os.path.join(
|
| 131 |
+
hyp_parameters["aug_img_savepath"],
|
| 132 |
+
utils.get_barcode([pid, e]) + ".nii.gz",
|
| 133 |
+
),
|
| 134 |
+
)
|
| 135 |
+
nib.save(
|
| 136 |
+
nifti_mask,
|
| 137 |
+
os.path.join(
|
| 138 |
+
hyp_parameters["aug_msk_savepath"],
|
| 139 |
+
utils.get_barcode([pid, e]) + "_GT.nii.gz",
|
| 140 |
+
),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Augmentation loop
|
| 144 |
+
noise_step = hyp_parameters["start_noise_step"]
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
for im in range(hyp_parameters["aug_coe"]):
|
| 147 |
+
print(
|
| 148 |
+
f"Generating -> Subject-{pid}, Scan-{e} "
|
| 149 |
+
f'({im}/{hyp_parameters["aug_coe"]})',
|
| 150 |
+
end="\r",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# 1. Set init image (DataLoader tensor passes through)
|
| 154 |
+
om.set_init_img(img)
|
| 155 |
+
|
| 156 |
+
# 2. Self-conditioning (matches: cond_imgs = img_org.clone().detach())
|
| 157 |
+
om.set_cond_img(img)
|
| 158 |
+
|
| 159 |
+
# 3. Forward diffuse to get noisy image + random DDF
|
| 160 |
+
t_start = torch.tensor(np.array([noise_step]), device=device)
|
| 161 |
+
img_diff, _, ddf_rand = om._get_random_ddf(om._init_img, t_start)
|
| 162 |
+
|
| 163 |
+
# 4. Get noisy mask
|
| 164 |
+
msk_diff = om.apply_def(
|
| 165 |
+
img=mask, ddf=ddf_rand,
|
| 166 |
+
padding_mode="zeros", resample_mode="nearest",
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# 5. Set random DDF as initial DDF
|
| 170 |
+
om.set_init_def(ddf=ddf_rand.clone().detach())
|
| 171 |
+
|
| 172 |
+
# 6. Run reverse diffusion
|
| 173 |
+
om.predict(
|
| 174 |
+
T=[noise_step, hyp_parameters["timesteps"]],
|
| 175 |
+
proc_type=hyp_parameters["condition_type"],
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# 7. Get recovered outputs
|
| 179 |
+
ddf_comp = om.get_def()
|
| 180 |
+
img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
|
| 181 |
+
msk_rec = om.apply_def(
|
| 182 |
+
img=mask, ddf=ddf_comp,
|
| 183 |
+
padding_mode="zeros", resample_mode="nearest",
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# Convert to numpy for saving
|
| 187 |
+
denoise_imgs = img_rec.cpu().detach().numpy()
|
| 188 |
+
denoise_msks = msk_rec.cpu().detach().numpy()
|
| 189 |
+
noisy_imgs_np = img_diff.cpu().detach().numpy()
|
| 190 |
+
noisy_msks_np = msk_diff.cpu().detach().numpy()
|
| 191 |
+
|
| 192 |
+
# Save augmented (recovered) outputs
|
| 193 |
+
nifti_img_aug = utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"])
|
| 194 |
+
nifti_mask_aug = utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"])
|
| 195 |
+
nifti_img = utils.converet_to_nibabel(noisy_imgs_np, ndims=hyp_parameters["ndims"])
|
| 196 |
+
nifti_mask = utils.converet_to_nibabel(noisy_msks_np, ndims=hyp_parameters["ndims"])
|
| 197 |
+
|
| 198 |
+
nib.save(
|
| 199 |
+
nifti_img_aug,
|
| 200 |
+
os.path.join(
|
| 201 |
+
hyp_parameters["aug_img_savepath"],
|
| 202 |
+
utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
|
| 203 |
+
),
|
| 204 |
+
)
|
| 205 |
+
nib.save(
|
| 206 |
+
nifti_mask_aug,
|
| 207 |
+
os.path.join(
|
| 208 |
+
hyp_parameters["aug_msk_savepath"],
|
| 209 |
+
utils.get_barcode([pid, e, im, noise_step]) + "_GT.nii.gz",
|
| 210 |
+
),
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
# Save noisy image/mask
|
| 214 |
+
nib.save(
|
| 215 |
+
nifti_img,
|
| 216 |
+
os.path.join(
|
| 217 |
+
hyp_parameters["aug_img_savepath"],
|
| 218 |
+
utils.get_barcode(
|
| 219 |
+
[pid, e, im, noise_step],
|
| 220 |
+
header=["Patient", "Slice", "NoiseImg", "NoiseStep"],
|
| 221 |
+
) + ".nii.gz",
|
| 222 |
+
),
|
| 223 |
+
)
|
| 224 |
+
nib.save(
|
| 225 |
+
nifti_mask,
|
| 226 |
+
os.path.join(
|
| 227 |
+
hyp_parameters["aug_msk_savepath"],
|
| 228 |
+
utils.get_barcode(
|
| 229 |
+
[pid, e, im, noise_step],
|
| 230 |
+
header=["Patient", "Slice", "NoiseImg", "NoiseStep"],
|
| 231 |
+
) + "_GT.nii.gz",
|
| 232 |
+
),
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
|
| 236 |
+
noise_step = noise_step + hyp_parameters["noise_step"]
|
| 237 |
+
|
| 238 |
+
if e >= 0:
|
| 239 |
+
exit()
|
Scripts/OM_reg_flexres_om.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OM_reg_flexres_om.py — Full-resolution registration using OMorpher.
|
| 3 |
+
|
| 4 |
+
Drop-in replacement for OM_reg_flexres.py. Produces identical outputs but
|
| 5 |
+
uses OMorpher instead of DeformDDPM + STN + standalone apply_ddf().
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python Scripts/OM_reg_flexres_om.py -C Config/config_om.yaml
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import argparse
|
| 14 |
+
|
| 15 |
+
# Add project root to path so imports work from Scripts/
|
| 16 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
import nibabel as nib
|
| 22 |
+
import yaml
|
| 23 |
+
import SimpleITK as sitk
|
| 24 |
+
from tqdm import tqdm
|
| 25 |
+
|
| 26 |
+
import utils
|
| 27 |
+
from Dataloader.dataLoader import OminiDataset_inference_w_all, reverse_axis_order
|
| 28 |
+
from OMorpher import OMorpher
|
| 29 |
+
|
| 30 |
+
# ========== CLI ==========
|
| 31 |
+
|
| 32 |
+
parser = argparse.ArgumentParser()
|
| 33 |
+
parser.add_argument(
|
| 34 |
+
"--config", "-C",
|
| 35 |
+
help="Path for the config file",
|
| 36 |
+
type=str,
|
| 37 |
+
default="Config/config_om.yaml",
|
| 38 |
+
required=False,
|
| 39 |
+
)
|
| 40 |
+
args = parser.parse_args()
|
| 41 |
+
|
| 42 |
+
# ========== Config ==========
|
| 43 |
+
|
| 44 |
+
with open(args.config, "r") as file:
|
| 45 |
+
hyp_parameters = yaml.safe_load(file)
|
| 46 |
+
print(hyp_parameters)
|
| 47 |
+
|
| 48 |
+
if not os.path.exists(hyp_parameters["aug_img_savepath"]):
|
| 49 |
+
os.makedirs(hyp_parameters["aug_img_savepath"])
|
| 50 |
+
if not os.path.exists(hyp_parameters["aug_msk_savepath"]):
|
| 51 |
+
os.makedirs(hyp_parameters["aug_msk_savepath"])
|
| 52 |
+
if not os.path.exists(hyp_parameters["aug_ddf_savepath"]):
|
| 53 |
+
os.makedirs(hyp_parameters["aug_ddf_savepath"])
|
| 54 |
+
print(hyp_parameters["aug_img_savepath"])
|
| 55 |
+
|
| 56 |
+
hyp_parameters["batchsize"] = 1
|
| 57 |
+
model_img_sz = hyp_parameters["img_size"]
|
| 58 |
+
|
| 59 |
+
# ========== Dataset (unchanged — used only for filtering/metadata) ==========
|
| 60 |
+
|
| 61 |
+
label_keys = ["brain"]
|
| 62 |
+
database = ["Brats2019"]
|
| 63 |
+
|
| 64 |
+
dataset = OminiDataset_inference_w_all(
|
| 65 |
+
transform=None, min_crop_ratio=1.0, label_key=label_keys, database=database,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# ========== OMorpher setup ==========
|
| 69 |
+
|
| 70 |
+
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
|
| 71 |
+
model_save_path = os.path.join(
|
| 72 |
+
f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
|
| 73 |
+
str(epoch) + ".pth",
|
| 74 |
+
)
|
| 75 |
+
print("Loading model from:", model_save_path)
|
| 76 |
+
|
| 77 |
+
om = OMorpher(
|
| 78 |
+
config=hyp_parameters,
|
| 79 |
+
checkpoint_path=model_save_path,
|
| 80 |
+
device=str(hyp_parameters.get("device", "cpu")),
|
| 81 |
+
)
|
| 82 |
+
print(om)
|
| 83 |
+
|
| 84 |
+
# ========== Output directories ==========
|
| 85 |
+
|
| 86 |
+
reg_img_savepath_fullres = hyp_parameters["reg_img_savepath"].rstrip("/") + "_fullres/"
|
| 87 |
+
reg_msk_savepath_fullres = hyp_parameters["reg_msk_savepath"].rstrip("/") + "_fullres/"
|
| 88 |
+
reg_ddf_savepath_fullres = hyp_parameters["reg_ddf_savepath"].rstrip("/") + "_fullres/"
|
| 89 |
+
|
| 90 |
+
for p in [
|
| 91 |
+
hyp_parameters["reg_img_savepath"],
|
| 92 |
+
hyp_parameters["reg_msk_savepath"],
|
| 93 |
+
hyp_parameters["reg_ddf_savepath"],
|
| 94 |
+
reg_img_savepath_fullres,
|
| 95 |
+
reg_msk_savepath_fullres,
|
| 96 |
+
reg_ddf_savepath_fullres,
|
| 97 |
+
]:
|
| 98 |
+
os.makedirs(p, exist_ok=True)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ========== Helper: load full-res data (same as original) ==========
|
| 102 |
+
|
| 103 |
+
def center_pad_to_cube(volume):
|
| 104 |
+
"""Pad volume to a cube using the max dimension, with symmetric (center) padding."""
|
| 105 |
+
max_dim = max(volume.shape[:3])
|
| 106 |
+
pad_width = []
|
| 107 |
+
for s in volume.shape[:3]:
|
| 108 |
+
total_pad = max_dim - s
|
| 109 |
+
pad_before = total_pad // 2
|
| 110 |
+
pad_after = total_pad - pad_before
|
| 111 |
+
pad_width.append((pad_before, pad_after))
|
| 112 |
+
for _ in range(volume.ndim - 3):
|
| 113 |
+
pad_width.append((0, 0))
|
| 114 |
+
return np.pad(volume, pad_width, mode="constant", constant_values=0)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def load_fullres_volume(key, ds):
|
| 118 |
+
"""Load original-resolution volume: axis reorder, clamp, normalize, center-pad to cube."""
|
| 119 |
+
volume = sitk.ReadImage(key)
|
| 120 |
+
volume = sitk.GetArrayFromImage(volume)
|
| 121 |
+
volume = reverse_axis_order(volume)
|
| 122 |
+
if volume.ndim == 4:
|
| 123 |
+
channel_ids = ds.get_channel_ids(key)
|
| 124 |
+
channel_id = channel_ids[0] if len(channel_ids) > 0 else 0
|
| 125 |
+
volume = volume[:, :, :, channel_id]
|
| 126 |
+
if ds.clamp_range is not None:
|
| 127 |
+
modality = ds.ALLdata_filtered[key].get("Modality", None)
|
| 128 |
+
if modality == "CT":
|
| 129 |
+
volume = np.clip(volume, ds.clamp_range[0], ds.clamp_range[1])
|
| 130 |
+
volume = ds.normalize(volume)
|
| 131 |
+
volume = center_pad_to_cube(volume)
|
| 132 |
+
return volume
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def load_fullres_label(key, ds, label_key):
|
| 136 |
+
"""Load original-resolution label: axis reorder, center-pad to cube."""
|
| 137 |
+
label_path_dict = ds.ALLdata_filtered[key].get("Label_path", {})
|
| 138 |
+
task_labels = label_path_dict.get("segmentation", {})
|
| 139 |
+
if label_key not in task_labels:
|
| 140 |
+
return None
|
| 141 |
+
label = sitk.ReadImage(task_labels[label_key])
|
| 142 |
+
label = sitk.GetArrayFromImage(label)
|
| 143 |
+
label = reverse_axis_order(label)
|
| 144 |
+
if label.ndim > 3:
|
| 145 |
+
channel_ids = ds.get_channel_ids(key)
|
| 146 |
+
if len(channel_ids) != 0:
|
| 147 |
+
label = label[..., channel_ids]
|
| 148 |
+
label = center_pad_to_cube(label)
|
| 149 |
+
return label
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# ========== Main inference loop ==========
|
| 153 |
+
|
| 154 |
+
keys = list(dataset.ALLdata_filtered.keys())
|
| 155 |
+
print("total num of images:", len(keys))
|
| 156 |
+
device = om.device
|
| 157 |
+
|
| 158 |
+
for e, key in enumerate(tqdm(keys)):
|
| 159 |
+
pid = e
|
| 160 |
+
print(f"Processing patient {pid}, image {e}, key: {key}")
|
| 161 |
+
|
| 162 |
+
# --- Load & standardize volume via OMorpher ---
|
| 163 |
+
fullres_vol = load_fullres_volume(key, dataset)
|
| 164 |
+
om.set_init_img(fullres_vol)
|
| 165 |
+
img = om._init_img # [1, 1, model_sz, model_sz, model_sz]
|
| 166 |
+
fullres_img_tensor = om._init_img_raw # [1, 1, D, H, W] full-res tensor
|
| 167 |
+
orig_sz = list(fullres_img_tensor.shape[2:])
|
| 168 |
+
print(f" Full-res padded shape: {orig_sz}")
|
| 169 |
+
|
| 170 |
+
# --- Load & standardize labels via OMorpher ---
|
| 171 |
+
masks_model = []
|
| 172 |
+
masks_fullres = []
|
| 173 |
+
for lk in label_keys:
|
| 174 |
+
lab = load_fullres_label(key, dataset, lk)
|
| 175 |
+
model_t, fullres_t = om._standardize_label(lab) # None → -1 placeholder
|
| 176 |
+
masks_model.append(model_t)
|
| 177 |
+
masks_fullres.append(fullres_t)
|
| 178 |
+
|
| 179 |
+
if masks_model:
|
| 180 |
+
mask = torch.cat(masks_model, dim=1) # [1, C_total, S, S, S]
|
| 181 |
+
fullres_msk_tensor = torch.cat(masks_fullres, dim=1) # [1, C_total, D, H, W]
|
| 182 |
+
else:
|
| 183 |
+
mask = None
|
| 184 |
+
fullres_msk_tensor = None
|
| 185 |
+
|
| 186 |
+
# --- Save target conditioning image (first subject) ---
|
| 187 |
+
if e <= 0:
|
| 188 |
+
target_img = img.clone().detach()
|
| 189 |
+
|
| 190 |
+
# --- Save original images at model resolution ---
|
| 191 |
+
image_original = img.cpu().numpy()
|
| 192 |
+
nib.save(
|
| 193 |
+
utils.converet_to_nibabel(image_original, ndims=hyp_parameters["ndims"]),
|
| 194 |
+
os.path.join(hyp_parameters["reg_img_savepath"],
|
| 195 |
+
utils.get_barcode([pid, e]) + ".nii.gz"),
|
| 196 |
+
)
|
| 197 |
+
if mask is not None:
|
| 198 |
+
mask_original = mask.cpu().numpy()
|
| 199 |
+
nib.save(
|
| 200 |
+
utils.converet_to_nibabel(mask_original, ndims=hyp_parameters["ndims"]),
|
| 201 |
+
os.path.join(hyp_parameters["reg_msk_savepath"],
|
| 202 |
+
utils.get_barcode([pid, e]) + "_GT.nii.gz"),
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
# --- Save original at full-res ---
|
| 206 |
+
nib.save(
|
| 207 |
+
utils.converet_to_nibabel(fullres_img_tensor, ndims=hyp_parameters["ndims"]),
|
| 208 |
+
os.path.join(reg_img_savepath_fullres,
|
| 209 |
+
utils.get_barcode([pid, e]) + ".nii.gz"),
|
| 210 |
+
)
|
| 211 |
+
if fullres_msk_tensor is not None:
|
| 212 |
+
nib.save(
|
| 213 |
+
utils.converet_to_nibabel(fullres_msk_tensor, ndims=hyp_parameters["ndims"]),
|
| 214 |
+
os.path.join(reg_msk_savepath_fullres,
|
| 215 |
+
utils.get_barcode([pid, e]) + "_GT.nii.gz"),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
# --- Diffusion recovery via OMorpher ---
|
| 219 |
+
noise_step = hyp_parameters["start_noise_step"]
|
| 220 |
+
with torch.no_grad():
|
| 221 |
+
for im in range(1):
|
| 222 |
+
print(
|
| 223 |
+
f" Generating -> Subject-{pid}, Scan-{e} "
|
| 224 |
+
f'({im}/{hyp_parameters["aug_coe"]})',
|
| 225 |
+
end="\r",
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Set up OMorpher inputs
|
| 229 |
+
om.set_init_img(img)
|
| 230 |
+
om.set_cond_img(target_img.clone().detach())
|
| 231 |
+
|
| 232 |
+
# Run diffusion recovery
|
| 233 |
+
# T=[None, timesteps] in original means: no initial noise, full reverse diffusion
|
| 234 |
+
om.predict(
|
| 235 |
+
T=[None, hyp_parameters["timesteps"]],
|
| 236 |
+
proc_type=hyp_parameters["condition_type"],
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
ddf_comp = om.get_def()
|
| 240 |
+
|
| 241 |
+
# Reconstruct images at model resolution using OMorpher
|
| 242 |
+
img_rec = om.apply_def(img=img, ddf=ddf_comp, padding_mode="zeros")
|
| 243 |
+
|
| 244 |
+
# --- Save model-resolution results ---
|
| 245 |
+
denoise_imgs = img_rec.cpu().numpy()
|
| 246 |
+
|
| 247 |
+
nib.save(
|
| 248 |
+
utils.converet_to_nibabel(denoise_imgs, ndims=hyp_parameters["ndims"]),
|
| 249 |
+
os.path.join(
|
| 250 |
+
hyp_parameters["reg_img_savepath"],
|
| 251 |
+
utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
|
| 252 |
+
),
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
if mask is not None:
|
| 256 |
+
msk_rec = om.apply_def(
|
| 257 |
+
img=mask, ddf=ddf_comp,
|
| 258 |
+
padding_mode="zeros", resample_mode="nearest",
|
| 259 |
+
)
|
| 260 |
+
denoise_msks = msk_rec.cpu().numpy()
|
| 261 |
+
nib.save(
|
| 262 |
+
utils.converet_to_nibabel(denoise_msks, ndims=hyp_parameters["ndims"]),
|
| 263 |
+
os.path.join(
|
| 264 |
+
hyp_parameters["reg_msk_savepath"],
|
| 265 |
+
utils.get_barcode([pid, e, im, noise_step]) + "_GT.nii.gz",
|
| 266 |
+
),
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
# --- Upscale DDF and apply at full resolution via OMorpher ---
|
| 270 |
+
img_rec_fullres = om.apply_def(
|
| 271 |
+
img=fullres_img_tensor, ddf=ddf_comp, padding_mode="border",
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
if fullres_msk_tensor is not None:
|
| 275 |
+
msk_rec_fullres = om.apply_def(
|
| 276 |
+
img=fullres_msk_tensor, ddf=ddf_comp,
|
| 277 |
+
padding_mode="zeros", resample_mode="nearest",
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# Upscale DDF for saving
|
| 281 |
+
ddf_fullres = F.interpolate(
|
| 282 |
+
ddf_comp, size=orig_sz, mode="trilinear", align_corners=False,
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# --- Save full-res results ---
|
| 286 |
+
nib.save(
|
| 287 |
+
utils.converet_to_nibabel(img_rec_fullres, ndims=hyp_parameters["ndims"]),
|
| 288 |
+
os.path.join(
|
| 289 |
+
reg_img_savepath_fullres,
|
| 290 |
+
utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
|
| 291 |
+
),
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
if fullres_msk_tensor is not None:
|
| 295 |
+
nib.save(
|
| 296 |
+
utils.converet_to_nibabel(msk_rec_fullres, ndims=hyp_parameters["ndims"]),
|
| 297 |
+
os.path.join(
|
| 298 |
+
reg_msk_savepath_fullres,
|
| 299 |
+
utils.get_barcode([pid, e, im, noise_step]) + "_GT.nii.gz",
|
| 300 |
+
),
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
nib.save(
|
| 304 |
+
utils.converet_to_nibabel(ddf_fullres, ndims=hyp_parameters["ndims"]),
|
| 305 |
+
os.path.join(
|
| 306 |
+
reg_ddf_savepath_fullres,
|
| 307 |
+
utils.get_barcode([pid, e, im, noise_step]) + ".nii.gz",
|
| 308 |
+
),
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if (im - hyp_parameters["start_noise_step"]) % 2 == 0:
|
| 312 |
+
noise_step = noise_step + hyp_parameters["noise_step"]
|
| 313 |
+
|
| 314 |
+
if e > 5:
|
| 315 |
+
break
|
Scripts/OM_reg_pair_ext.py
ADDED
|
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OM_reg_pair.py — Paired registration using OMorpher with external dataset.
|
| 3 |
+
|
| 4 |
+
Loads fixed/moving pairs from a Learn2Reg-style JSON dataset file
|
| 5 |
+
(e.g. HippocampusMR_dataset.json) and registers each moving image to its
|
| 6 |
+
paired fixed image. Saves registered images, masks, DDFs, source originals,
|
| 7 |
+
and evaluation metrics (DSC, ASD, HD) per organ label.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python Scripts/OM_reg_pair.py -C Config/config_om.yaml \
|
| 11 |
+
--dataset-json /path/to/HippocampusMR_dataset.json \
|
| 12 |
+
--split val
|
| 13 |
+
|
| 14 |
+
python Scripts/OM_reg_pair.py -C Config/config_om.yaml \
|
| 15 |
+
--dataset-json /path/to/HippocampusMR_dataset.json \
|
| 16 |
+
--split test -N 10
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
|
| 22 |
+
# Add project root to path so imports work from Scripts/
|
| 23 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 24 |
+
|
| 25 |
+
import csv
|
| 26 |
+
import json
|
| 27 |
+
import numpy as np
|
| 28 |
+
import torch
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
import nibabel as nib
|
| 31 |
+
import yaml
|
| 32 |
+
import SimpleITK as sitk
|
| 33 |
+
from scipy.ndimage import distance_transform_edt, binary_erosion
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
|
| 36 |
+
import utils
|
| 37 |
+
from Dataloader.dataLoader import reverse_axis_order
|
| 38 |
+
from OMorpher import OMorpher
|
| 39 |
+
|
| 40 |
+
# ========== CLI ==========
|
| 41 |
+
|
| 42 |
+
import argparse
|
| 43 |
+
|
| 44 |
+
parser = argparse.ArgumentParser()
|
| 45 |
+
parser.add_argument(
|
| 46 |
+
"--config", "-C",
|
| 47 |
+
help="Path for the config file",
|
| 48 |
+
type=str,
|
| 49 |
+
default="Config/config_om.yaml",
|
| 50 |
+
required=False,
|
| 51 |
+
)
|
| 52 |
+
parser.add_argument(
|
| 53 |
+
"--dataset-json",
|
| 54 |
+
help="Path to the Learn2Reg-style dataset JSON",
|
| 55 |
+
type=str,
|
| 56 |
+
default="~/rds/rds-airr-p51-TWhPgQVLKbA/Code/Registration/Dataset/HippocampusMR/HippocampusMR_dataset.json",
|
| 57 |
+
)
|
| 58 |
+
parser.add_argument(
|
| 59 |
+
"--split",
|
| 60 |
+
help="Which registration split to use: 'val' or 'test'",
|
| 61 |
+
type=str,
|
| 62 |
+
choices=["val", "test"],
|
| 63 |
+
default="val",
|
| 64 |
+
)
|
| 65 |
+
parser.add_argument(
|
| 66 |
+
"--max-samples", "-N",
|
| 67 |
+
help="Max number of pairs to register (0 = all)",
|
| 68 |
+
type=int,
|
| 69 |
+
default=0,
|
| 70 |
+
)
|
| 71 |
+
args = parser.parse_args()
|
| 72 |
+
|
| 73 |
+
# ========== Config ==========
|
| 74 |
+
|
| 75 |
+
with open(args.config, "r") as file:
|
| 76 |
+
hyp_parameters = yaml.safe_load(file)
|
| 77 |
+
print(hyp_parameters)
|
| 78 |
+
|
| 79 |
+
hyp_parameters["batchsize"] = 1
|
| 80 |
+
model_img_sz = hyp_parameters["img_size"]
|
| 81 |
+
timesteps = hyp_parameters["timesteps"]
|
| 82 |
+
condition_type = hyp_parameters["condition_type"]
|
| 83 |
+
ndims = hyp_parameters["ndims"]
|
| 84 |
+
|
| 85 |
+
# ========== Load external dataset JSON ==========
|
| 86 |
+
|
| 87 |
+
dataset_json_path = os.path.expanduser(args.dataset_json)
|
| 88 |
+
dataset_root = os.path.dirname(dataset_json_path)
|
| 89 |
+
|
| 90 |
+
with open(dataset_json_path, "r") as f:
|
| 91 |
+
dataset_meta = json.load(f)
|
| 92 |
+
|
| 93 |
+
dataset_name = dataset_meta.get("name", "UnknownDataset")
|
| 94 |
+
print(f"Dataset: {dataset_name}")
|
| 95 |
+
|
| 96 |
+
# Select registration split
|
| 97 |
+
if args.split == "val":
|
| 98 |
+
pairs = dataset_meta.get("registration_val", [])
|
| 99 |
+
elif args.split == "test":
|
| 100 |
+
pairs = dataset_meta.get("registration_test", [])
|
| 101 |
+
else:
|
| 102 |
+
raise ValueError(f"Unknown split: {args.split}")
|
| 103 |
+
|
| 104 |
+
if args.max_samples > 0:
|
| 105 |
+
pairs = pairs[: args.max_samples]
|
| 106 |
+
|
| 107 |
+
print(f"Split: {args.split}, Pairs: {len(pairs)}")
|
| 108 |
+
|
| 109 |
+
# Build label lookup: image basename -> label relative path
|
| 110 |
+
# from the "training" entries in the JSON
|
| 111 |
+
_label_lookup = {}
|
| 112 |
+
for entry in dataset_meta.get("training", []):
|
| 113 |
+
img_base = os.path.basename(entry["image"])
|
| 114 |
+
_label_lookup[img_base] = entry.get("label")
|
| 115 |
+
|
| 116 |
+
# Label class names (from JSON: "0": "background", "1": "head", "2": "tail")
|
| 117 |
+
_label_names = dataset_meta.get("labels", {}).get("0", {})
|
| 118 |
+
# Organ labels are all non-background classes
|
| 119 |
+
organ_label_ids = {int(k): v for k, v in _label_names.items() if int(k) > 0}
|
| 120 |
+
print(f"Organ labels for evaluation: {organ_label_ids}")
|
| 121 |
+
|
| 122 |
+
# ========== OMorpher setup ==========
|
| 123 |
+
|
| 124 |
+
epoch = f'{hyp_parameters["model_id_str"]}_{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}'
|
| 125 |
+
model_save_path = os.path.join(
|
| 126 |
+
f'Models/{hyp_parameters["data_name"]}_{hyp_parameters["net_name"]}/',
|
| 127 |
+
str(epoch) + ".pth",
|
| 128 |
+
)
|
| 129 |
+
print("Loading model from:", model_save_path)
|
| 130 |
+
|
| 131 |
+
om = OMorpher(
|
| 132 |
+
config=hyp_parameters,
|
| 133 |
+
checkpoint_path=model_save_path,
|
| 134 |
+
device=str(hyp_parameters.get("device", "cpu")),
|
| 135 |
+
)
|
| 136 |
+
print(om)
|
| 137 |
+
|
| 138 |
+
# ========== Output directories ==========
|
| 139 |
+
|
| 140 |
+
reg_img_savepath = hyp_parameters["reg_img_savepath"]
|
| 141 |
+
reg_msk_savepath = hyp_parameters["reg_msk_savepath"]
|
| 142 |
+
reg_ddf_savepath = hyp_parameters["reg_ddf_savepath"]
|
| 143 |
+
|
| 144 |
+
reg_img_savepath_fullres = reg_img_savepath.rstrip("/") + "_fullres/"
|
| 145 |
+
reg_msk_savepath_fullres = reg_msk_savepath.rstrip("/") + "_fullres/"
|
| 146 |
+
reg_ddf_savepath_fullres = reg_ddf_savepath.rstrip("/") + "_fullres/"
|
| 147 |
+
|
| 148 |
+
eval_dir = os.path.join(reg_img_savepath, "..", "eval")
|
| 149 |
+
|
| 150 |
+
for p in [
|
| 151 |
+
reg_img_savepath, reg_msk_savepath, reg_ddf_savepath,
|
| 152 |
+
reg_img_savepath_fullres, reg_msk_savepath_fullres, reg_ddf_savepath_fullres,
|
| 153 |
+
eval_dir,
|
| 154 |
+
]:
|
| 155 |
+
os.makedirs(p, exist_ok=True)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ========== Helper functions ==========
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def resolve_path(rel_path):
|
| 162 |
+
"""Resolve a relative path from the dataset JSON to an absolute path."""
|
| 163 |
+
if os.path.isabs(rel_path):
|
| 164 |
+
return rel_path
|
| 165 |
+
return os.path.normpath(os.path.join(dataset_root, rel_path))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def load_volume(nifti_path):
|
| 169 |
+
"""Load a NIfTI volume: axis reorder only.
|
| 170 |
+
|
| 171 |
+
OMorpher._standardize_img handles: normalize → pad-to-cube → resize to model res.
|
| 172 |
+
"""
|
| 173 |
+
volume = sitk.ReadImage(nifti_path)
|
| 174 |
+
volume = sitk.GetArrayFromImage(volume)
|
| 175 |
+
volume = reverse_axis_order(volume)
|
| 176 |
+
if volume.ndim == 4:
|
| 177 |
+
volume = volume[:, :, :, 0]
|
| 178 |
+
return volume
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def load_label(nifti_path):
|
| 182 |
+
"""Load a NIfTI label map: axis reorder only.
|
| 183 |
+
|
| 184 |
+
OMorpher._standardize_label handles: pad-to-cube → resize to model res (nearest).
|
| 185 |
+
"""
|
| 186 |
+
label = sitk.ReadImage(nifti_path)
|
| 187 |
+
label = sitk.GetArrayFromImage(label)
|
| 188 |
+
label = reverse_axis_order(label)
|
| 189 |
+
if label.ndim > 3:
|
| 190 |
+
label = label[:, :, :, 0]
|
| 191 |
+
return label
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def get_label_path_for_image(image_rel_path):
|
| 195 |
+
"""Find the label path for an image by looking up the training entries."""
|
| 196 |
+
img_base = os.path.basename(image_rel_path)
|
| 197 |
+
label_rel = _label_lookup.get(img_base)
|
| 198 |
+
if label_rel is None:
|
| 199 |
+
return None
|
| 200 |
+
return resolve_path(label_rel)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def split_label_classes(label_map, class_ids):
|
| 204 |
+
"""Split a multi-class label map into per-class binary masks.
|
| 205 |
+
|
| 206 |
+
Returns a dict {class_id: binary_numpy_array}.
|
| 207 |
+
"""
|
| 208 |
+
masks = {}
|
| 209 |
+
for cid in class_ids:
|
| 210 |
+
masks[cid] = (label_map == cid).astype(np.float32)
|
| 211 |
+
return masks
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def get_volume_name(path):
|
| 215 |
+
"""Extract a short name from a NIfTI file path."""
|
| 216 |
+
name = os.path.basename(path)
|
| 217 |
+
for ext in [".nii.gz", ".nii"]:
|
| 218 |
+
if name.endswith(ext):
|
| 219 |
+
name = name[: -len(ext)]
|
| 220 |
+
break
|
| 221 |
+
return name
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# ---------- Evaluation metrics ----------
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _surface_distances(pred, gt):
|
| 228 |
+
"""Compute directed surface distances between two binary masks."""
|
| 229 |
+
pred_bool = pred > 0.5
|
| 230 |
+
gt_bool = gt > 0.5
|
| 231 |
+
|
| 232 |
+
if not np.any(pred_bool) or not np.any(gt_bool):
|
| 233 |
+
return None, None
|
| 234 |
+
|
| 235 |
+
struct = None
|
| 236 |
+
pred_surface = pred_bool ^ binary_erosion(pred_bool, structure=struct)
|
| 237 |
+
gt_surface = gt_bool ^ binary_erosion(gt_bool, structure=struct)
|
| 238 |
+
|
| 239 |
+
if not np.any(pred_surface):
|
| 240 |
+
pred_surface = pred_bool
|
| 241 |
+
if not np.any(gt_surface):
|
| 242 |
+
gt_surface = gt_bool
|
| 243 |
+
|
| 244 |
+
dt_gt = distance_transform_edt(~gt_surface)
|
| 245 |
+
dt_pred = distance_transform_edt(~pred_surface)
|
| 246 |
+
|
| 247 |
+
return dt_gt[pred_surface], dt_pred[gt_surface]
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def compute_dsc(pred, gt):
|
| 251 |
+
"""Dice Similarity Coefficient."""
|
| 252 |
+
pred_bool = pred > 0.5
|
| 253 |
+
gt_bool = gt > 0.5
|
| 254 |
+
intersection = np.sum(pred_bool & gt_bool)
|
| 255 |
+
denom = np.sum(pred_bool) + np.sum(gt_bool)
|
| 256 |
+
if denom == 0:
|
| 257 |
+
return 1.0
|
| 258 |
+
return 2.0 * float(intersection) / float(denom)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def compute_asd(pred, gt):
|
| 262 |
+
"""Average (symmetric) Surface Distance."""
|
| 263 |
+
d1, d2 = _surface_distances(pred, gt)
|
| 264 |
+
if d1 is None:
|
| 265 |
+
return float("nan")
|
| 266 |
+
return (np.mean(d1) + np.mean(d2)) / 2.0
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def compute_hd(pred, gt):
|
| 270 |
+
"""Hausdorff Distance (maximum of directed HDs)."""
|
| 271 |
+
d1, d2 = _surface_distances(pred, gt)
|
| 272 |
+
if d1 is None:
|
| 273 |
+
return float("nan")
|
| 274 |
+
return float(max(np.max(d1), np.max(d2)))
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def compute_negdetj_pct(ddf, ndims=3):
|
| 278 |
+
"""Percent of voxels with negative Jacobian determinant.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
ddf: displacement field tensor [1, ndims, ...] or numpy array.
|
| 282 |
+
ndims: 2 or 3.
|
| 283 |
+
Returns:
|
| 284 |
+
Percentage of voxels where det(Jacobian) < 0.
|
| 285 |
+
"""
|
| 286 |
+
if isinstance(ddf, torch.Tensor):
|
| 287 |
+
ddf = ddf.detach().cpu().numpy()
|
| 288 |
+
# ddf shape: [1, C, ...] or [C, ...]
|
| 289 |
+
if ddf.ndim == ndims + 2:
|
| 290 |
+
ddf = ddf[0] # remove batch dim -> [C, ...]
|
| 291 |
+
|
| 292 |
+
# Compute spatial gradients via finite differences (forward diff, clipped)
|
| 293 |
+
if ndims == 3:
|
| 294 |
+
# ddf: [3, D, H, W]
|
| 295 |
+
# Derivatives along each spatial axis
|
| 296 |
+
dux_dx = np.diff(ddf[0], axis=0, append=ddf[0, -1:, :, :])
|
| 297 |
+
duy_dx = np.diff(ddf[1], axis=0, append=ddf[1, -1:, :, :])
|
| 298 |
+
duz_dx = np.diff(ddf[2], axis=0, append=ddf[2, -1:, :, :])
|
| 299 |
+
|
| 300 |
+
dux_dy = np.diff(ddf[0], axis=1, append=ddf[0, :, -1:, :])
|
| 301 |
+
duy_dy = np.diff(ddf[1], axis=1, append=ddf[1, :, -1:, :])
|
| 302 |
+
duz_dy = np.diff(ddf[2], axis=1, append=ddf[2, :, -1:, :])
|
| 303 |
+
|
| 304 |
+
dux_dz = np.diff(ddf[0], axis=2, append=ddf[0, :, :, -1:])
|
| 305 |
+
duy_dz = np.diff(ddf[1], axis=2, append=ddf[1, :, :, -1:])
|
| 306 |
+
duz_dz = np.diff(ddf[2], axis=2, append=ddf[2, :, :, -1:])
|
| 307 |
+
|
| 308 |
+
# Jacobian = I + du/dx
|
| 309 |
+
j11 = 1.0 + dux_dx; j12 = dux_dy; j13 = dux_dz
|
| 310 |
+
j21 = duy_dx; j22 = 1.0 + duy_dy; j23 = duy_dz
|
| 311 |
+
j31 = duz_dx; j32 = duz_dy; j33 = 1.0 + duz_dz
|
| 312 |
+
|
| 313 |
+
detj = (
|
| 314 |
+
j11 * (j22 * j33 - j23 * j32)
|
| 315 |
+
- j12 * (j21 * j33 - j23 * j31)
|
| 316 |
+
+ j13 * (j21 * j32 - j22 * j31)
|
| 317 |
+
)
|
| 318 |
+
elif ndims == 2:
|
| 319 |
+
dux_dx = np.diff(ddf[0], axis=0, append=ddf[0, -1:, :])
|
| 320 |
+
duy_dx = np.diff(ddf[1], axis=0, append=ddf[1, -1:, :])
|
| 321 |
+
|
| 322 |
+
dux_dy = np.diff(ddf[0], axis=1, append=ddf[0, :, -1:])
|
| 323 |
+
duy_dy = np.diff(ddf[1], axis=1, append=ddf[1, :, -1:])
|
| 324 |
+
|
| 325 |
+
detj = (1.0 + dux_dx) * (1.0 + duy_dy) - dux_dy * duy_dx
|
| 326 |
+
else:
|
| 327 |
+
raise ValueError(f"Unsupported ndims={ndims}")
|
| 328 |
+
|
| 329 |
+
n_neg = np.sum(detj < 0)
|
| 330 |
+
n_total = detj.size
|
| 331 |
+
return 100.0 * float(n_neg) / float(n_total)
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ========== Prepare evaluation structures ==========
|
| 335 |
+
|
| 336 |
+
# metrics[class_id][metric_name][pair_idx] = value (post-registration)
|
| 337 |
+
metrics = {
|
| 338 |
+
cid: {"dsc": {}, "asd": {}, "hd": {}}
|
| 339 |
+
for cid in organ_label_ids
|
| 340 |
+
}
|
| 341 |
+
# metrics_pre: same structure but for pre-registration (source vs target, no deformation)
|
| 342 |
+
metrics_pre = {
|
| 343 |
+
cid: {"dsc": {}, "asd": {}, "hd": {}}
|
| 344 |
+
for cid in organ_label_ids
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
# Per-pair DDF quality metric (not per-class)
|
| 348 |
+
negdetj_pct = {} # pair_idx -> percentage of negative Jacobian determinant
|
| 349 |
+
|
| 350 |
+
# Also collect per-pair info for the CSV
|
| 351 |
+
pair_info = [] # list of (pair_idx, fixed_name, moving_name)
|
| 352 |
+
|
| 353 |
+
# ========== Paired registration ==========
|
| 354 |
+
|
| 355 |
+
with torch.no_grad():
|
| 356 |
+
for pair_idx, pair in enumerate(tqdm(pairs, desc="Pairs")):
|
| 357 |
+
fixed_rel = pair["fixed"]
|
| 358 |
+
moving_rel = pair["moving"]
|
| 359 |
+
|
| 360 |
+
fixed_path = resolve_path(fixed_rel)
|
| 361 |
+
moving_path = resolve_path(moving_rel)
|
| 362 |
+
|
| 363 |
+
fixed_name = get_volume_name(fixed_rel)
|
| 364 |
+
moving_name = get_volume_name(moving_rel)
|
| 365 |
+
pair_tag = f"Tgt{pair_idx:04d}_Src{pair_idx:04d}"
|
| 366 |
+
|
| 367 |
+
pair_info.append((pair_idx, fixed_name, moving_name))
|
| 368 |
+
print(f"\n [{pair_idx}] Fixed: {fixed_name}, Moving: {moving_name}")
|
| 369 |
+
|
| 370 |
+
# --- Load volumes ---
|
| 371 |
+
fixed_vol = load_volume(fixed_path)
|
| 372 |
+
moving_vol = load_volume(moving_path)
|
| 373 |
+
|
| 374 |
+
# --- Load labels (if available) ---
|
| 375 |
+
fixed_label_path = get_label_path_for_image(fixed_rel)
|
| 376 |
+
moving_label_path = get_label_path_for_image(moving_rel)
|
| 377 |
+
|
| 378 |
+
fixed_label_map = None
|
| 379 |
+
moving_label_map = None
|
| 380 |
+
if fixed_label_path is not None and os.path.exists(fixed_label_path):
|
| 381 |
+
fixed_label_map = load_label(fixed_label_path)
|
| 382 |
+
if moving_label_path is not None and os.path.exists(moving_label_path):
|
| 383 |
+
moving_label_map = load_label(moving_label_path)
|
| 384 |
+
|
| 385 |
+
# --- Prepare tensors via OMorpher ---
|
| 386 |
+
# Set moving image as init (source to be deformed)
|
| 387 |
+
om.set_init_img(moving_vol)
|
| 388 |
+
src_img_model = om._init_img.clone()
|
| 389 |
+
src_img_fullres = om._init_img_raw.clone()
|
| 390 |
+
src_orig_sz = list(src_img_fullres.shape[2:])
|
| 391 |
+
|
| 392 |
+
# Set fixed image as conditioning (target)
|
| 393 |
+
om.set_init_img(fixed_vol)
|
| 394 |
+
tgt_img_model = om._init_img.clone()
|
| 395 |
+
tgt_img_fullres = om._init_img_raw.clone()
|
| 396 |
+
|
| 397 |
+
# Standardize labels through OMorpher
|
| 398 |
+
src_mask_model, src_mask_fullres = None, None
|
| 399 |
+
tgt_mask_model, tgt_mask_fullres = None, None
|
| 400 |
+
|
| 401 |
+
if moving_label_map is not None:
|
| 402 |
+
# Split into per-class binary masks, stack as channels
|
| 403 |
+
src_class_masks = split_label_classes(moving_label_map, organ_label_ids.keys())
|
| 404 |
+
src_masks_model = []
|
| 405 |
+
src_masks_fullres = []
|
| 406 |
+
om.set_init_img(moving_vol) # reset so _standardize_label uses correct shape
|
| 407 |
+
for cid in sorted(organ_label_ids.keys()):
|
| 408 |
+
m_model, m_fullres = om._standardize_label(src_class_masks[cid])
|
| 409 |
+
src_masks_model.append(m_model)
|
| 410 |
+
src_masks_fullres.append(m_fullres)
|
| 411 |
+
src_mask_model = torch.cat(src_masks_model, dim=1)
|
| 412 |
+
src_mask_fullres = torch.cat(src_masks_fullres, dim=1)
|
| 413 |
+
|
| 414 |
+
if fixed_label_map is not None:
|
| 415 |
+
tgt_class_masks = split_label_classes(fixed_label_map, organ_label_ids.keys())
|
| 416 |
+
tgt_masks_model = []
|
| 417 |
+
tgt_masks_fullres = []
|
| 418 |
+
om.set_init_img(fixed_vol) # reset so _standardize_label uses correct shape
|
| 419 |
+
for cid in sorted(organ_label_ids.keys()):
|
| 420 |
+
m_model, m_fullres = om._standardize_label(tgt_class_masks[cid])
|
| 421 |
+
tgt_masks_model.append(m_model)
|
| 422 |
+
tgt_masks_fullres.append(m_fullres)
|
| 423 |
+
tgt_mask_model = torch.cat(tgt_masks_model, dim=1)
|
| 424 |
+
tgt_mask_fullres = torch.cat(tgt_masks_fullres, dim=1)
|
| 425 |
+
|
| 426 |
+
# --- Save target (fixed) original at model resolution ---
|
| 427 |
+
nib.save(
|
| 428 |
+
utils.converet_to_nibabel(tgt_img_model, ndims=ndims),
|
| 429 |
+
os.path.join(reg_img_savepath, f"{pair_tag}_TGT_ORG.nii.gz"),
|
| 430 |
+
)
|
| 431 |
+
if tgt_mask_model is not None:
|
| 432 |
+
nib.save(
|
| 433 |
+
utils.converet_to_nibabel(tgt_mask_model, ndims=ndims),
|
| 434 |
+
os.path.join(reg_msk_savepath, f"{pair_tag}_TGT_ORG_GT.nii.gz"),
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# --- Save source (moving) original at model resolution ---
|
| 438 |
+
nib.save(
|
| 439 |
+
utils.converet_to_nibabel(src_img_model, ndims=ndims),
|
| 440 |
+
os.path.join(reg_img_savepath, f"Src{pair_idx:04d}_ORG.nii.gz"),
|
| 441 |
+
)
|
| 442 |
+
if src_mask_model is not None:
|
| 443 |
+
nib.save(
|
| 444 |
+
utils.converet_to_nibabel(src_mask_model, ndims=ndims),
|
| 445 |
+
os.path.join(reg_msk_savepath, f"Src{pair_idx:04d}_ORG_GT.nii.gz"),
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# --- Save target original at full resolution ---
|
| 449 |
+
nib.save(
|
| 450 |
+
utils.converet_to_nibabel(tgt_img_fullres, ndims=ndims),
|
| 451 |
+
os.path.join(reg_img_savepath_fullres, f"{pair_tag}_TGT_ORG.nii.gz"),
|
| 452 |
+
)
|
| 453 |
+
if tgt_mask_fullres is not None:
|
| 454 |
+
nib.save(
|
| 455 |
+
utils.converet_to_nibabel(tgt_mask_fullres, ndims=ndims),
|
| 456 |
+
os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_TGT_ORG_GT.nii.gz"),
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# --- Save source original at full resolution ---
|
| 460 |
+
nib.save(
|
| 461 |
+
utils.converet_to_nibabel(src_img_fullres, ndims=ndims),
|
| 462 |
+
os.path.join(reg_img_savepath_fullres, f"Src{pair_idx:04d}_ORG.nii.gz"),
|
| 463 |
+
)
|
| 464 |
+
if src_mask_fullres is not None:
|
| 465 |
+
nib.save(
|
| 466 |
+
utils.converet_to_nibabel(src_mask_fullres, ndims=ndims),
|
| 467 |
+
os.path.join(reg_msk_savepath_fullres, f"Src{pair_idx:04d}_ORG_GT.nii.gz"),
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
# --- Register moving to fixed ---
|
| 471 |
+
om.set_init_img(src_img_model)
|
| 472 |
+
om.set_cond_img(tgt_img_model.clone().detach())
|
| 473 |
+
|
| 474 |
+
om.predict(
|
| 475 |
+
T=[None, timesteps],
|
| 476 |
+
proc_type=condition_type,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
ddf_comp = om.get_def()
|
| 480 |
+
|
| 481 |
+
# --- DDF quality: percent negative Jacobian determinant ---
|
| 482 |
+
neg_pct = compute_negdetj_pct(ddf_comp, ndims=ndims)
|
| 483 |
+
negdetj_pct[pair_idx] = neg_pct
|
| 484 |
+
print(f" %|J|<0 = {neg_pct:.4f}%")
|
| 485 |
+
|
| 486 |
+
# --- Model-resolution registered image ---
|
| 487 |
+
img_rec = om.apply_def(
|
| 488 |
+
img=src_img_model, ddf=ddf_comp, padding_mode="zeros",
|
| 489 |
+
)
|
| 490 |
+
nib.save(
|
| 491 |
+
utils.converet_to_nibabel(img_rec, ndims=ndims),
|
| 492 |
+
os.path.join(reg_img_savepath, f"{pair_tag}.nii.gz"),
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
# --- Model-resolution registered mask ---
|
| 496 |
+
msk_rec = None
|
| 497 |
+
if src_mask_model is not None:
|
| 498 |
+
msk_rec = om.apply_def(
|
| 499 |
+
img=src_mask_model, ddf=ddf_comp,
|
| 500 |
+
padding_mode="zeros", resample_mode="nearest",
|
| 501 |
+
)
|
| 502 |
+
nib.save(
|
| 503 |
+
utils.converet_to_nibabel(msk_rec, ndims=ndims),
|
| 504 |
+
os.path.join(reg_msk_savepath, f"{pair_tag}_GT.nii.gz"),
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
# --- Model-resolution DDF ---
|
| 508 |
+
nib.save(
|
| 509 |
+
utils.converet_to_nibabel(ddf_comp, ndims=ndims),
|
| 510 |
+
os.path.join(reg_ddf_savepath, f"{pair_tag}.nii.gz"),
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# --- Full-resolution registered image ---
|
| 514 |
+
img_rec_fullres = om.apply_def(
|
| 515 |
+
img=src_img_fullres, ddf=ddf_comp, padding_mode="border",
|
| 516 |
+
)
|
| 517 |
+
nib.save(
|
| 518 |
+
utils.converet_to_nibabel(img_rec_fullres, ndims=ndims),
|
| 519 |
+
os.path.join(reg_img_savepath_fullres, f"{pair_tag}.nii.gz"),
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# --- Full-resolution registered mask ---
|
| 523 |
+
msk_rec_fullres = None
|
| 524 |
+
if src_mask_fullres is not None:
|
| 525 |
+
msk_rec_fullres = om.apply_def(
|
| 526 |
+
img=src_mask_fullres, ddf=ddf_comp,
|
| 527 |
+
padding_mode="zeros", resample_mode="nearest",
|
| 528 |
+
)
|
| 529 |
+
nib.save(
|
| 530 |
+
utils.converet_to_nibabel(msk_rec_fullres, ndims=ndims),
|
| 531 |
+
os.path.join(reg_msk_savepath_fullres, f"{pair_tag}_GT.nii.gz"),
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# --- Full-resolution DDF ---
|
| 535 |
+
ddf_fullres = F.interpolate(
|
| 536 |
+
ddf_comp, size=src_orig_sz, mode="trilinear", align_corners=False,
|
| 537 |
+
)
|
| 538 |
+
nib.save(
|
| 539 |
+
utils.converet_to_nibabel(ddf_fullres, ndims=ndims),
|
| 540 |
+
os.path.join(reg_ddf_savepath_fullres, f"{pair_tag}.nii.gz"),
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
# --- Evaluation metrics (full-res organ labels) ---
|
| 544 |
+
if (
|
| 545 |
+
organ_label_ids
|
| 546 |
+
and src_mask_fullres is not None
|
| 547 |
+
and tgt_mask_fullres is not None
|
| 548 |
+
):
|
| 549 |
+
for ch_idx, cid in enumerate(sorted(organ_label_ids.keys())):
|
| 550 |
+
lk = organ_label_ids[cid]
|
| 551 |
+
tgt_mask_np = tgt_mask_fullres[0, ch_idx].cpu().numpy()
|
| 552 |
+
src_mask_np = src_mask_fullres[0, ch_idx].cpu().numpy()
|
| 553 |
+
|
| 554 |
+
if np.all(tgt_mask_np < 0) or np.all(src_mask_np < 0):
|
| 555 |
+
continue
|
| 556 |
+
|
| 557 |
+
# Pre-registration: source vs target (no deformation)
|
| 558 |
+
pre_dsc = compute_dsc(src_mask_np, tgt_mask_np)
|
| 559 |
+
pre_asd = compute_asd(src_mask_np, tgt_mask_np)
|
| 560 |
+
pre_hd = compute_hd(src_mask_np, tgt_mask_np)
|
| 561 |
+
|
| 562 |
+
metrics_pre[cid]["dsc"][pair_idx] = pre_dsc
|
| 563 |
+
metrics_pre[cid]["asd"][pair_idx] = pre_asd
|
| 564 |
+
metrics_pre[cid]["hd"][pair_idx] = pre_hd
|
| 565 |
+
|
| 566 |
+
# Post-registration: registered mask vs target
|
| 567 |
+
if msk_rec_fullres is not None:
|
| 568 |
+
reg_mask_np = msk_rec_fullres[0, ch_idx].cpu().numpy()
|
| 569 |
+
post_dsc = compute_dsc(reg_mask_np, tgt_mask_np)
|
| 570 |
+
post_asd = compute_asd(reg_mask_np, tgt_mask_np)
|
| 571 |
+
post_hd = compute_hd(reg_mask_np, tgt_mask_np)
|
| 572 |
+
else:
|
| 573 |
+
post_dsc = float("nan")
|
| 574 |
+
post_asd = float("nan")
|
| 575 |
+
post_hd = float("nan")
|
| 576 |
+
|
| 577 |
+
metrics[cid]["dsc"][pair_idx] = post_dsc
|
| 578 |
+
metrics[cid]["asd"][pair_idx] = post_asd
|
| 579 |
+
metrics[cid]["hd"][pair_idx] = post_hd
|
| 580 |
+
|
| 581 |
+
print(
|
| 582 |
+
f" [{lk}] PRE DSC={pre_dsc:.4f} ASD={pre_asd:.2f} HD={pre_hd:.2f}"
|
| 583 |
+
)
|
| 584 |
+
print(
|
| 585 |
+
f" [{lk}] POST DSC={post_dsc:.4f} ASD={post_asd:.2f} HD={post_hd:.2f}"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
print("\nPaired registration complete.")
|
| 589 |
+
|
| 590 |
+
# ========== Write evaluation CSVs ==========
|
| 591 |
+
|
| 592 |
+
n_pairs = len(pairs)
|
| 593 |
+
|
| 594 |
+
def _fmt(val):
|
| 595 |
+
if val is None:
|
| 596 |
+
return ""
|
| 597 |
+
if np.isnan(val):
|
| 598 |
+
return "NaN"
|
| 599 |
+
return f"{val:.6f}"
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
# --- Per-pair %|J|<0 CSV ---
|
| 603 |
+
negdetj_csv_path = os.path.join(eval_dir, "negdetj_pct.csv")
|
| 604 |
+
with open(negdetj_csv_path, "w", newline="") as f:
|
| 605 |
+
writer = csv.writer(f)
|
| 606 |
+
writer.writerow(["pair_idx", "fixed", "moving", "negdetj_pct"])
|
| 607 |
+
for pi, fixed_name, moving_name in pair_info:
|
| 608 |
+
writer.writerow([pi, fixed_name, moving_name, _fmt(negdetj_pct.get(pi))])
|
| 609 |
+
print(f"Saved {negdetj_csv_path}")
|
| 610 |
+
|
| 611 |
+
for cid in sorted(organ_label_ids.keys()):
|
| 612 |
+
lk = organ_label_ids[cid]
|
| 613 |
+
prefix = f"{lk}_" if len(organ_label_ids) > 1 else ""
|
| 614 |
+
|
| 615 |
+
for metric_name in ["dsc", "asd", "hd"]:
|
| 616 |
+
mn_upper = metric_name.upper()
|
| 617 |
+
csv_path = os.path.join(eval_dir, f"{prefix}{metric_name}.csv")
|
| 618 |
+
with open(csv_path, "w", newline="") as f:
|
| 619 |
+
writer = csv.writer(f)
|
| 620 |
+
writer.writerow([
|
| 621 |
+
"pair_idx", "fixed", "moving",
|
| 622 |
+
f"pre_{mn_upper}", f"post_{mn_upper}",
|
| 623 |
+
])
|
| 624 |
+
for pi, fixed_name, moving_name in pair_info:
|
| 625 |
+
pre_val = metrics_pre[cid][metric_name].get(pi)
|
| 626 |
+
post_val = metrics[cid][metric_name].get(pi)
|
| 627 |
+
writer.writerow([
|
| 628 |
+
pi, fixed_name, moving_name,
|
| 629 |
+
_fmt(pre_val), _fmt(post_val),
|
| 630 |
+
])
|
| 631 |
+
print(f"Saved {csv_path}")
|
| 632 |
+
|
| 633 |
+
# --- Overall summary ---
|
| 634 |
+
overall_path = os.path.join(eval_dir, "overall.csv")
|
| 635 |
+
with open(overall_path, "w", newline="") as f:
|
| 636 |
+
writer = csv.writer(f)
|
| 637 |
+
writer.writerow([
|
| 638 |
+
"label", "metric",
|
| 639 |
+
"pre_mean", "pre_std",
|
| 640 |
+
"post_mean", "post_std",
|
| 641 |
+
"n_pairs",
|
| 642 |
+
])
|
| 643 |
+
# %|J|<0 summary (not per-label)
|
| 644 |
+
negdetj_vals = [v for v in negdetj_pct.values() if not np.isnan(v)]
|
| 645 |
+
writer.writerow([
|
| 646 |
+
"ALL",
|
| 647 |
+
"%|J|<0",
|
| 648 |
+
"", "",
|
| 649 |
+
_fmt(np.mean(negdetj_vals) if negdetj_vals else float("nan")),
|
| 650 |
+
_fmt(np.std(negdetj_vals) if negdetj_vals else float("nan")),
|
| 651 |
+
len(negdetj_vals),
|
| 652 |
+
])
|
| 653 |
+
for cid in sorted(organ_label_ids.keys()):
|
| 654 |
+
lk = organ_label_ids[cid]
|
| 655 |
+
for metric_name in ["dsc", "asd", "hd"]:
|
| 656 |
+
pre_vals = [
|
| 657 |
+
v for v in metrics_pre[cid][metric_name].values()
|
| 658 |
+
if not np.isnan(v)
|
| 659 |
+
]
|
| 660 |
+
post_vals = [
|
| 661 |
+
v for v in metrics[cid][metric_name].values()
|
| 662 |
+
if not np.isnan(v)
|
| 663 |
+
]
|
| 664 |
+
pre_mean = np.mean(pre_vals) if pre_vals else float("nan")
|
| 665 |
+
pre_std = np.std(pre_vals) if pre_vals else float("nan")
|
| 666 |
+
post_mean = np.mean(post_vals) if post_vals else float("nan")
|
| 667 |
+
post_std = np.std(post_vals) if post_vals else float("nan")
|
| 668 |
+
n = max(len(pre_vals), len(post_vals))
|
| 669 |
+
writer.writerow([
|
| 670 |
+
lk,
|
| 671 |
+
metric_name.upper(),
|
| 672 |
+
_fmt(pre_mean), _fmt(pre_std),
|
| 673 |
+
_fmt(post_mean), _fmt(post_std),
|
| 674 |
+
n,
|
| 675 |
+
])
|
| 676 |
+
print(f"Saved {overall_path}")
|