# Omini3D / OM_train_3modes_opt.py
# Uploaded by: maxmo2009
# Commit 2af0e94 (verified): "Sync from local: code + epoch-110 checkpoint, clean README"
"""
OM_train_3modes_opt.py — Optimized 3-mode training (diffusion + contrastive + registration).
Speed optimizations over OM_train_3modes.py (mathematically equivalent unless noted):
1. DataLoader: num_workers, pin_memory, persistent_workers for I/O overlap
2. optimizer.zero_grad(set_to_none=True) — avoids zero-fill overhead
3. Fixed-length T_regist (up to 16 steps) — avoids XPU dynamic shape recompilation.
   NOTE: this replaces the original random 8–16 step count, so it changes the sampling.
4. Removed redundant x0.to(device) call
5. Uses diffuser_opt.DeformDDPM (hoisted clone, no *0 redundancy, OptSTN, inference_mode)
6. Uses losses_opt.MSLNCC/LNCC (register_buffer for kernels)
7. Pre-compute proc_type lists to reduce Python overhead in hot loop
8. Uses OptRecMulModMutAttnNet (cached resample tensors, ~300 fewer CPU→GPU transfers)
9. Uses OptSTN for ddf_stn (register_buffer, no per-call .to())
"""
import os, sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)
import gc
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from Diffusion.diffuser_opt import DeformDDPM
from Diffusion.networks_opt import get_net_opt, OptSTN
from torchvision.transforms import Lambda
import torch.nn.functional as F
import Diffusion.losses_opt as losses
import random
import glob
import numpy as np
import utils
from tqdm import tqdm
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *
from Dataloader.dataloader_utils import thresh_img
import yaml
import argparse
####################
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
###############
def ddp_setup(rank, world_size):
    """Join this process to a single-node NCCL process group and bind its GPU.

    The rendezvous address/port are fixed to ``localhost:12355``.

    Args:
        rank: Unique identifier of each process (also used as the CUDA device index).
        world_size: Total number of processes.
    """
    rendezvous = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12355",
    }
    os.environ.update(rendezvous)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    # Make "cuda" (without an index) resolve to this rank's device from now on.
    torch.cuda.set_device(rank)
# Auto-detect: run one DDP process per GPU only when several CUDA devices exist.
use_distributed = torch.cuda.is_available() and torch.cuda.device_count() > 1

# ---------------------------------------------------------------------------
# Training constants
# ---------------------------------------------------------------------------
EPS = 1e-5
MSK_EPS = 0.01  # small offset added to masks / cond_ratio in the loss weighting
TEXT_EMBED_PROB = 0.7  # probability that the text embedding is passed to the model
AUG_RESAMPLE_PROB = 0.5  # probability gate shared by the resample and threshold augmentations
# AUG_PERMUTE_PROB = 0.35  (currently unused)

# Loss weights.
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0]  # [ang, dist, reg]
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]  # [imgsim, imgmse, ddf]; earlier setting: [9.0, 1.0, 16.0]
LOSS_WEIGHT_CONTRASTIVE = 1.0

# Schedule ratios.
DIFF_REG_BATCH_RATIO = 2  # paired-loader batch size = batchsize // DIFF_REG_BATCH_RATIO
CONTRASTIVE_STEP_RATIO = 2  # contrastive update on every N-th step

# OPT: fixed registration timestep count to avoid XPU dynamic-shape recompilation.
FIXED_T_REGIST_LEN = 16
# OPT: DataLoader workers (set to 0 to disable worker processes) and pinned host memory.
NUM_WORKERS = 4
PIN_MEMORY = True

# ---------------------------------------------------------------------------
# Command line (parsed at import time; spawned DDP workers re-import this module)
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
# Other available configs: Config/config_cmr.yaml, Config/config_lct.yaml
parser.add_argument(
    "--config", "-C",
    type=str,
    default="Config/config_all.yaml",
    required=False,
    help="Path for the config file",
)
parser.add_argument(
    "--dummy-samples",
    type=int,
    default=0,
    help="Use dummy random data for testing (0=use real data)",
)
parser.add_argument(
    "--batchsize",
    type=int,
    default=0,
    help="Override batch size from config (0=use config value)",
)
parser.add_argument(
    "--num-workers",
    type=int,
    default=NUM_WORKERS,
    help="DataLoader num_workers (default: 4)",
)
args = parser.parse_args()
#=======================================================================================================================
class _DummyIndiv(torch.utils.data.Dataset):
    """Synthetic single-image dataset for smoke-testing the training loop.

    Item ``i`` is ``(volume, embedding)``:
      * ``volume``    -- uniform [0, 1) ``float64`` array of shape ``(1, sz, sz, sz)``
      * ``embedding`` -- standard-normal ``float32`` vector of length ``embd_dim``
    The index is ignored; every access draws fresh values from NumPy's global RNG.
    """

    def __init__(self, n, sz, embd_dim=1024):
        self.n = n
        self.sz = sz
        self.embd_dim = embd_dim

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        side = self.sz
        volume = np.random.rand(1, side, side, side).astype(np.float64)
        embedding = np.random.randn(self.embd_dim).astype(np.float32)
        return volume, embedding
class _DummyPair(torch.utils.data.Dataset):
    """Synthetic paired-image dataset for smoke-testing the registration branch.

    Item ``i`` is a 4-tuple ``(img_a, img_b, embd_a, embd_b)`` matching what
    ``main_train`` unpacks as ``[x1, y1, _, embd_y]``:
      * images     -- uniform [0, 1) ``float64`` arrays of shape ``(1, sz, sz, sz)``
      * embeddings -- standard-normal ``float32`` vectors of length ``embd_dim``
    The index is ignored; values come from NumPy's global RNG.
    """

    def __init__(self, n, sz, embd_dim=1024):
        self.n = n
        self.sz = sz
        self.embd_dim = embd_dim

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        shape = (1,) + (self.sz,) * 3
        img_a = np.random.rand(*shape).astype(np.float64)
        img_b = np.random.rand(*shape).astype(np.float64)
        embd_a = np.random.randn(self.embd_dim).astype(np.float32)
        embd_b = np.random.randn(self.embd_dim).astype(np.float32)
        return img_a, img_b, embd_a, embd_b
def main_train(rank=0, world_size=1, train_mode_ratio=1, thresh_imgsim=0.01):
    """Run 3-mode training: diffusion, text-image contrastive, and registration.

    Called once per process: directly as ``main_train(0, 1)`` or via
    ``mp.spawn`` with ``(rank, world_size)`` when ``use_distributed`` is set.
    Reads the YAML config named by ``--config`` (module-level ``args``), builds
    the data loaders, the DeformDDPM model and the losses, resumes from the
    lexicographically last ``*.pth`` in ``Models/<data_name>_<net_name>/`` if
    any, and trains up to ``hyp_parameters["epoch"]``, checkpointing every
    ``epoch_per_save`` epochs.

    Args:
        rank: process index; under DDP also the CUDA device index.
        world_size: number of processes in the DDP group.
        train_mode_ratio: run a registration update every ``train_mode_ratio`` steps.
        thresh_imgsim: intensity threshold on ``y1`` masking the image-similarity loss.

    Raises:
        ValueError: if more than 5 steps within one epoch produce a NaN/Inf
            diffusion or registration loss.
    """
    if use_distributed:
        ddp_setup(rank, world_size)
        if torch.distributed.is_initialized():
            print(f"World size: {torch.distributed.get_world_size()}")
            print(f"Communication backend: {torch.distributed.get_backend()}")
    gpu_id = rank

    # ---- configuration -----------------------------------------------------
    with open(args.config, 'r') as file:
        hyp_parameters = yaml.safe_load(file)
    if args.batchsize > 0:
        hyp_parameters['batchsize'] = args.batchsize
    print(hyp_parameters)
    device = hyp_parameters["device"]
    timesteps = hyp_parameters["timesteps"]
    ndims = hyp_parameters["ndims"]
    img_size = hyp_parameters['img_size']
    batch_size = hyp_parameters['batchsize']
    noise_scale = hyp_parameters['noise_scale']
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    Net = get_net_opt(net_name)
    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path

    # ---- data (OPT: worker processes, pinned memory, persistent workers) ----
    num_workers = args.num_workers
    use_pin_memory = PIN_MEMORY and device != "cpu"
    if args.dummy_samples > 0:
        dataset = _DummyIndiv(args.dummy_samples, img_size)
        datasetp = _DummyPair(args.dummy_samples, img_size)
    else:
        dataset = OMDataset_indiv(transform=None)
        datasetp = OMDataset_pair(transform=None)
    # NOTE(review): no DistributedSampler is used, so under DDP every rank
    # iterates the full dataset each epoch -- confirm this is intended.
    loader_kwargs = dict(
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        pin_memory=use_pin_memory,
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(dataset, batch_size=batch_size, **loader_kwargs)
    train_loader_p = DataLoader(
        datasetp,
        batch_size=max(1, batch_size // DIFF_REG_BATCH_RATIO),
        **loader_kwargs,
    )

    # ---- model -------------------------------------------------------------
    Deformddpm = DeformDDPM(
        network=Net(
            n_steps=timesteps,
            ndims=ndims,
            num_input_chn=hyp_parameters["num_input_chn"],
            res=img_size,
        ),
        n_steps=timesteps,
        image_chw=[1] + [img_size] * ndims,
        device=device,
        batch_size=batch_size,
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )
    ddf_stn = OptSTN(
        img_sz=img_size,
        ndims=ndims,
        padding_mode=hyp_parameters["padding_mode"],
        device=device,
    )
    if use_distributed:
        Deformddpm.to(rank)
        Deformddpm = DDP(Deformddpm, device_ids=[rank])
        ddf_stn.to(rank)
    else:
        Deformddpm.to(device)
        ddf_stn.to(device)

    # ---- losses / optimizer ------------------------------------------------
    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims, outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim = losses.MSLNCC()
    loss_imgmse = losses.LMSE()
    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # ---- resume from the latest checkpoint, if any -------------------------
    os.makedirs(model_dir, exist_ok=True)
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if model_files:
        if gpu_id == 0:
            print(model_files)
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(
            gpu_id, Deformddpm, optimizer, model_files[-1], use_distributed=use_distributed
        )
    else:
        initial_epoch = 0
    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # Unwrapped handles: DDP keeps the real module in `.module`.
    ddpm = Deformddpm.module if use_distributed else Deformddpm
    raw_network = ddpm.network
    # OPT: proc_type candidate lists built once, not per step.
    diff_proc_types = ['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon']
    regist_proc_types = ['downsample', 'slice', 'slice1', 'none', 'none']
    max_nan_steps = 5  # abort once an epoch has more non-finite-loss steps than this

    # ---- training loop -----------------------------------------------------
    for epoch in range(initial_epoch, hyp_parameters["epoch"]):
        epoch_loss_tot = 0.0
        epoch_loss_gen_d = 0.0
        epoch_loss_gen_a = 0.0
        epoch_loss_reg = 0.0
        epoch_loss_regist = 0.0
        epoch_loss_imgsim = 0.0
        epoch_loss_imgmse = 0.0
        epoch_loss_ddfreg = 0.0
        epoch_loss_contrastive = 0.0
        Deformddpm.train()
        loss_nan_step = 0  # steps skipped this epoch because of a NaN/Inf loss
        total = min(len(train_loader), len(train_loader_p))
        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
            # ==================================================================
            # 1) diffusion training on single images
            x0, embd = batch
            x0 = x0.to(device).type(torch.float32)
            embd_dev = embd.to(device).type(torch.float32)
            embd_in = embd_dev if np.random.uniform(0, 1) < TEXT_EMBED_PROB else None
            n = x0.size()[0]
            blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(device)
            # random deformation (resample) or axis permutation
            if ndims > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])
            if noise_scale > 0:
                # NOTE(review): intensity scale/shift is applied whenever noise_scale > 0
                # (as in the registration branch); only thresholding is gated -- confirm.
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 2 * noise_scale])
                x0 = x0 * (np.random.normal(1, noise_scale)) + np.random.normal(0, noise_scale)
            # one random timestep per image in the batch
            t = torch.randint(0, timesteps, (n,)).to(device)
            proc_type = random.choice(diff_proc_types)
            cond_img, _, cond_ratio = ddpm.proc_cond_img(x0, proc_type=proc_type)
            pre_dvf_I, dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd_in)
            loss_ddf = loss_reg(pre_dvf_I, img=x0)
            trm_pred = ddf_stn(pre_dvf_I, dvf_I)
            loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
            loss_tot = LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
            loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
            loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot
            if torch.isnan(x0).any():
                print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
            if loss_ddf > 0.001:
                print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")
            if torch.isnan(loss_tot) or torch.isinf(loss_tot):
                print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                loss_nan_step += 1
                # Checked here (not after `continue`) so a run of bad steps aborts immediately.
                if loss_nan_step > max_nan_steps:
                    print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                    raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                continue
            optimizer.zero_grad(set_to_none=True)  # OPT: set_to_none faster than zero-fill
            loss_tot.backward()
            optimizer.step()
            epoch_loss_tot += loss_tot.item() / total
            epoch_loss_gen_d += loss_gen_d.item() / total
            epoch_loss_gen_a += loss_gen_a.item() / total
            epoch_loss_reg += loss_ddf.item() / total

            # ==================================================================
            # 2) contrastive training on single images (text-image alignment)
            loss_contra_val = None
            if step % CONTRASTIVE_STEP_RATIO == 0:
                n_contra = x0.size()[0]
                t_contra = torch.randint(0, timesteps, (n_contra,)).to(device)
                # NOTE(review): under DDP this calls the unwrapped network, bypassing the
                # DDP forward -- confirm gradients are synchronised across ranks as intended.
                _ = raw_network(x=(x0 * blind_mask).detach(), y=cond_img.detach(), t=t_contra, text=None)
                img_embd = getattr(raw_network, 'img_embd', None)  # [B, 1024]
                if img_embd is not None:
                    loss_contra = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean())
                    optimizer.zero_grad(set_to_none=True)
                    loss_contra.backward()
                    torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.05)
                    optimizer.step()
                    loss_contra_val = loss_contra.item()
                    epoch_loss_contrastive += loss_contra_val / total
                elif gpu_id == 0:
                    print(f"*** Warning: Network does not have img_embd attribute for contrastive loss at epoch {epoch}, step {step}.")

            # ==================================================================
            # 3) registration training on paired images
            if step % train_mode_ratio == 0:
                x1, y1, _, embd_y = batch_p
                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(device).type(torch.float32)
                else:
                    embd_y = None
                x1 = x1.to(device).type(torch.float32)
                y1 = y1.to(device).type(torch.float32)
                n = x1.size()[0]  # actual paired batch size
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if noise_scale > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2 * noise_scale])
                    random_scale = np.random.normal(1, noise_scale)
                    random_shift = np.random.normal(0, noise_scale)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift
                scale_regist = np.random.uniform(0.0, 0.7)
                # OPT: at most FIXED_T_REGIST_LEN timesteps (was: random 8-16), descending.
                t_pool = list(range(int(timesteps * scale_regist), timesteps))
                select_timestep = min(FIXED_T_REGIST_LEN, len(t_pool))
                T_regist = sorted(random.sample(t_pool, select_timestep), reverse=True)
                # Replicate each timestep for every sample of *this* batch; the old
                # `batchsize // 2` only matched while DIFF_REG_BATCH_RATIO == 2.
                T_regist = [[t_reg] * n for t_reg in T_regist]
                proc_type = random.choice(regist_proc_types)
                y1_proc, msk_tgt, cond_ratio = ddpm.proc_cond_img(y1, proc_type=proc_type)
                msk_tgt = msk_tgt + MSK_EPS
                [ddf_comp, _], [img_rec, _, _], _ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[], text=embd_y)
                loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt * (y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt * (y1 >= 0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)
                loss_regist = (
                    LOSS_WEIGHTS_REGIST[0] * loss_sim
                    + LOSS_WEIGHTS_REGIST[1] * loss_mse
                    + LOSS_WEIGHTS_REGIST[2] * loss_ddf1
                )
                # Inspect the registration inputs (previously re-checked x0 by mistake).
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in input images x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1 > 0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")
                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
                if torch.isnan(loss_regist) or torch.isinf(loss_regist):
                    # Stepping on a non-finite loss would corrupt the weights.
                    print(f"*** Encountered NaN or Inf registration loss at epoch {epoch}, step {step}. Skipping this update.")
                    loss_nan_step += 1
                    if loss_nan_step > max_nan_steps:
                        print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                        raise ValueError("Too many NaN losses detected in loss_regist. Code terminated.")
                else:
                    optimizer.zero_grad(set_to_none=True)
                    loss_regist.backward()
                    torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.2)
                    optimizer.step()
                    epoch_loss_regist += loss_regist.item() / total
                    epoch_loss_imgsim += loss_sim.item() / total
                    epoch_loss_imgmse += loss_mse.item() / total
                    epoch_loss_ddfreg += loss_ddf1.item() / total

            if step % 10 == 0:
                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                if loss_contra_val is not None:
                    print(f' loss_contrastive: {loss_contra_val:.6f}')
                # step 0 always runs registration, so loss_regist etc. are defined here
                print(f' loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')

        # ---- epoch summary -------------------------------------------------
        print('==================')
        print(epoch, ':', epoch_loss_tot, '=', epoch_loss_gen_a, '+', epoch_loss_gen_d, '+', epoch_loss_reg, ' (ang+dist+regul)')
        print(f' loss_contrastive: {epoch_loss_contrastive}')
        print(f' loss_regist: {epoch_loss_regist} = {epoch_loss_imgsim} (imgsim) + {epoch_loss_imgmse} (imgmse) + {epoch_loss_ddfreg} (ddf)')
        print('==================')

        # ---- checkpoint (only one writer under DDP) -------------------------
        if epoch % epoch_per_save == 0:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            if not use_distributed or gpu_id == 0:
                print(f"saved in {save_dir}")
                torch.save({
                    'model_state_dict': ddpm.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'epoch': epoch,
                }, save_dir)

    # Resource cleanup at the end of training
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file, use_distributed=True, load_strict=False):
    """Restore a training checkpoint and return the epoch to resume from.

    Only the process with ``gpu_id == 0`` reads ``model_file``. Under DDP the
    loaded parameters are then broadcast from rank 0 to every other rank.

    Args:
        gpu_id: rank of the calling process; only rank 0 reads the file.
        Deformddpm: model to load into. When ``use_distributed`` it is the DDP
            wrapper and its ``.module`` receives the weights.
        optimizer: optimizer whose state is restored only when ``load_strict``.
        model_file: checkpoint path. Its basename must start with the 6-digit
            zero-padded epoch, as written by ``main_train``.
        use_distributed: whether a process group is active and the weights
            must be broadcast.
        load_strict: passed as ``strict`` to ``load_state_dict``; also gates
            restoring the optimizer state.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)``, where ``initial_epoch`` is
        the epoch parsed from the filename plus one.

    Raises:
        ValueError: if the epoch cannot be parsed from the checkpoint filename.
    """
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")
        if torch.cuda.is_available():
            gc.collect()
            torch.cuda.empty_cache()
        checkpoint = torch.load(model_file, map_location='cpu')
        target = Deformddpm.module if use_distributed else Deformddpm
        target.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)
        if load_strict:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        utils.print_memory_usage("After Loading Checkpoint on GPU")
    if use_distributed:
        # Broadcast model weights from rank 0 to all other GPUs.
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)
        dist.barrier()
        # The former per-parameter `.grad` broadcast was removed. Grads are None
        # at load time, so it did nothing, and it could hang if ranks disagreed.
        # NOTE(review): optimizer state restored on rank 0 (load_strict=True) is
        # not shared with the other ranks.
    # The epoch comes from the filename because only rank 0 has `checkpoint`.
    stem = os.path.basename(model_file).split('.')[0]
    try:
        initial_epoch = int(stem[:6]) + 1
    except ValueError as err:
        raise ValueError(
            f"Cannot infer the epoch from checkpoint name {model_file!r}: "
            "expected a 6-digit zero-padded prefix such as '000110'."
        ) from err
    return initial_epoch, Deformddpm, optimizer
if __name__ == "__main__":
    # Single process on one device, or one spawned worker per visible GPU.
    if not use_distributed:
        main_train(0, 1)
    else:
        world_size = torch.cuda.device_count()
        print(f"Distributed GPU number = {world_size}")
        # mp.spawn passes the worker index as the first argument (rank).
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)