import contextlib
import os
import sys

# Make the directory containing this script importable *before* the project
# packages (Diffusion, Dataloader, utils) are imported below.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)
import gc
import torch
import torchvision
from torch import nn
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
from torchvision.transforms import Lambda
import torch.nn.functional as F
import Diffusion.losses as losses
import random
import glob
import numpy as np
import utils
from tqdm import tqdm
from Dataloader.dataloader0 import get_dataloader
from Dataloader.dataLoader import *  # NOTE(review): star import; order kept so later imports still win
from Dataloader.dataloader_utils import thresh_img
import yaml
import argparse

# XPU support: import Intel Extension for PyTorch and oneCCL bindings if available.
try:
    import intel_extension_for_pytorch as ipex
except ImportError:
    ipex = None

# The oneCCL bindings are presumably what provides the "ccl" backend selected in
# ddp_setup for XPU runs. The broad catch is deliberate: a broken install can
# fail with non-ImportError exceptions, and CUDA runs do not need the module.
try:
    import oneccl_bindings_for_pytorch
except Exception as e:
    print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")

####################
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
# from torch.distributed import init_process_group
###############


def _device_available(device_type):
    """Return True if the accelerator backend named by ``device_type`` is usable.

    ``'xpu'`` checks that this torch build exposes ``torch.xpu`` and that a device
    is present; any other value is treated as CUDA.
    """
    if device_type == 'xpu':
        return hasattr(torch, 'xpu') and torch.xpu.is_available()
    return torch.cuda.is_available()


def _device_count(device_type):
    """Return the number of visible devices for ``device_type``.

    Returns 0 for ``'xpu'`` when this torch build has no ``torch.xpu`` module;
    any other value is treated as CUDA.
    """
    if device_type == 'xpu':
        return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
    return torch.cuda.device_count()


def _set_device(rank, device_type):
    """Make device index ``rank`` the current device (XPU for ``'xpu'``, else CUDA)."""
    if device_type == 'xpu':
        torch.xpu.set_device(rank)
    else:
        torch.cuda.set_device(rank)


def _empty_cache(device_type):
    """Release cached, unused allocator memory back to the device.

    Uses the XPU allocator when ``device_type == 'xpu'`` and ``torch.xpu``
    exists; otherwise empties the CUDA cache if CUDA is available. It is a no-op
    when neither applies.
    """
    if device_type == 'xpu' and hasattr(torch, 'xpu'):
        torch.xpu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()


def ddp_setup(rank, world_size):
    """Initialise the default process group and bind this process to its device.

    The backend is chosen from the module-level ``DEVICE_TYPE`` (read from the
    config file at import time): ``"ccl"`` for XPU, ``"nccl"`` otherwise.

    Args:
        rank: Process rank / local device index when started via ``mp.spawn``.
            Ignored when ``LOCAL_RANK`` is set (torchrun launch); in that case
            the device index comes from ``LOCAL_RANK``, and rank, world size and
            rendezvous address come from the environment.
        world_size: Total number of processes (used only on the mp.spawn path).
    """
    backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
    if "LOCAL_RANK" in os.environ:
        # Launched by torchrun: MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE already set.
        # Bind the device before creating the process group so the backend never
        # implicitly uses device 0 for this rank.
        _set_device(int(os.environ["LOCAL_RANK"]), DEVICE_TYPE)
        dist.init_process_group(backend=backend)
    else:
        # Single-node mp.spawn: rendezvous on a fixed local address/port.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12355"
        _set_device(rank, DEVICE_TYPE)
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)


EPS = 1e-5
MSK_EPS = 0.01  # added to the registration target mask and to cond_ratio in main_train
TEXT_EMBED_PROB = 0.5
AUG_RESAMPLE_PROB = 0.5
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 4.0]  # [ang, dist, reg]
# LOSS_WEIGHTS_REGIST = [9.0, 1.0, 16.0]  # [imgsim, imgmse, ddf]
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2]  # [imgsim, imgmse, ddf]
DIFF_REG_BATCH_RATIO = 2  # paired (registration) loader batch size = batchsize // this (min 1)
LOSS_WEIGHT_CONTRASTIVE = 1e-4  # scale on the text-image contrastive loss
REGISTRATION_STEP_RATIO = 1  # registration phase runs on steps where step % this == 0
CONTRASTIVE_STEP_RATIO = 1  # contrastive phase runs on steps where step % this == 0
MID_EPOCH_SAVE_STEPS = 10  # Save mid-epoch checkpoint every N steps for crash recovery.
# XPU autograd leaks ~1.0 GiB/step of device memory (Intel bug).
# With gradient checkpointing, training survives ~26 steps from fresh start,
# but fewer when carrying leaked memory from previous epoch.
# Save every 10 steps to minimize lost work on OOM crash.
EXIT_CODE_RESTART = 42  # Exit code signaling proactive restart (not a crash).
# AUG_PERMUTE_PROB = 0.35 parser = argparse.ArgumentParser() # config_file_path = 'Config/config_cmr.yaml' parser.add_argument( "--config", "-C", help="Path for the config file", type=str, # default="Config/config_cmr.yaml", # default="Config/config_lct.yaml", default="Config/config_all.yaml", required=False, ) parser.add_argument("--dummy-samples", type=int, default=0, help="Use dummy random data for testing (0=use real data)") parser.add_argument("--batchsize", type=int, default=0, help="Override batch size from config (0=use config value)") parser.add_argument("--max-steps-before-restart", type=int, default=0, help="Proactive restart: exit after N training steps to reset XPU memory leak. " "0=disabled (rely on OOM crash + auto-resubmit). " "Recommended: 20 for XPU (survives ~26 steps max).") parser.add_argument("--no-save", action="store_true", help="Disable all checkpoint saving (for diagnostic/validation runs)") parser.add_argument("--reset-optimizer", action="store_true", help="Skip optimizer state loading from checkpoint (use when architecture changed)") parser.add_argument("--eval-only", action="store_true", help="Forward pass only: compute and print losses without backward/optimizer (no memory leak)") args = parser.parse_args() # Read config early to determine device type for DDP setup with open(args.config, 'r') as _f: _cfg = yaml.safe_load(_f) DEVICE_TYPE = _cfg.get('device', 'cuda') # 'cuda' or 'xpu' # Auto-detect: use DDP only when multiple devices are available use_distributed = _device_available(DEVICE_TYPE) and _device_count(DEVICE_TYPE) > 1 # use_distributed = True # use_distributed = False #======================================================================================================================= class _DummyIndiv(torch.utils.data.Dataset): def __init__(self, n, sz, embd_dim=1024): self.n, self.sz, self.embd_dim = n, sz, embd_dim def __len__(self): return self.n def __getitem__(self, i): return np.random.rand(1, self.sz, self.sz, 
self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32) class _DummyPair(torch.utils.data.Dataset): def __init__(self, n, sz, embd_dim=1024): self.n, self.sz, self.embd_dim = n, sz, embd_dim def __len__(self): return self.n def __getitem__(self, i): return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32), np.random.randn(self.embd_dim).astype(np.float32)) def main_train(rank=0,world_size=1,train_mode_ratio=1,thresh_imgsim=0.01): if use_distributed: ddp_setup(rank,world_size) if torch.distributed.is_initialized() and rank == 0: print(f"World size: {torch.distributed.get_world_size()}") print(f"Communication backend: {torch.distributed.get_backend()}") print(f"PYTORCH_ALLOC_CONF: {os.environ.get('PYTORCH_ALLOC_CONF', 'not set')}") if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'): props = torch.xpu.get_device_properties(0) print(f"XPU device: {props.name}, total memory: {props.total_memory / 1024**3:.2f} GiB") # gpu_id = global rank (for save/print guards); rank = local device index if "RANK" in os.environ: gpu_id = int(os.environ["RANK"]) rank = int(os.environ["LOCAL_RANK"]) else: gpu_id = rank # Load the YAML file into a dictionary with open(args.config, 'r') as file: hyp_parameters = yaml.safe_load(file) if args.batchsize > 0: hyp_parameters['batchsize'] = args.batchsize if gpu_id == 0: print(hyp_parameters) # epoch_per_save=10 epoch_per_save=hyp_parameters['epoch_per_save'] data_name=hyp_parameters['data_name'] net_name = hyp_parameters['net_name'] Net=get_net(net_name) suffix_pth=f'_{data_name}_{net_name}.pth' model_save_path = os.path.join('Models',f'{data_name}_{net_name}/') model_dir=model_save_path transformer=utils.get_transformer(img_sz=hyp_parameters["ndims"]*[hyp_parameters['img_size']]) # Data_Loader=get_dataloader(data_name=hyp_parameters['data_name'], mode='train') # tsfm = 
torchvision.transforms.Compose([ # torchvision.transforms.ToTensor(), # ]) # dataset = Data_Loader(target_res = [hyp_parameters["img_size"]]*hyp_parameters["ndims"], transforms=None, noise_scale=hyp_parameters['noise_scale']) # train_loader = DataLoader( # dataset, # batch_size=hyp_parameters['batchsize'], # # shuffle=False, # shuffle=True, # drop_last=True, # ) if args.dummy_samples > 0: dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size']) datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size']) else: # dataset = OminiDataset_v1(transform=None) dataset = OMDataset_indiv(transform=None) # datasetp = OminiDataset_paired(transform=None) datasetp = OMDataset_pair(transform=None) if use_distributed: sampler = DistributedSampler(dataset, shuffle=True) sampler_p = DistributedSampler(datasetp, shuffle=True) else: sampler = None sampler_p = None train_loader = DataLoader( dataset, batch_size=hyp_parameters['batchsize'], shuffle=(sampler is None), drop_last=True, sampler=sampler, ) train_loader_p = DataLoader( datasetp, batch_size=max(1, hyp_parameters['batchsize']//DIFF_REG_BATCH_RATIO), shuffle=(sampler_p is None), drop_last=True, sampler=sampler_p, ) network = Net( n_steps=hyp_parameters["timesteps"], ndims=hyp_parameters["ndims"], num_input_chn = hyp_parameters["num_input_chn"], res = hyp_parameters['img_size'] ) # Enable gradient checkpointing on XPU to reduce peak activation memory. # XPU autograd leaks ~1.0 GiB/step; lower peak buys more steps before OOM. 
if DEVICE_TYPE == 'xpu' and hasattr(network, 'use_checkpoint'): network.use_checkpoint = True if gpu_id == 0: print(" [init] Gradient checkpointing enabled for XPU", flush=True) Deformddpm = DeformDDPM( network=network, n_steps=hyp_parameters["timesteps"], image_chw=[1] + [hyp_parameters["img_size"]]*hyp_parameters["ndims"], device=hyp_parameters["device"], batch_size=hyp_parameters["batchsize"], img_pad_mode=hyp_parameters["img_pad_mode"], v_scale=hyp_parameters["v_scale"], ) ddf_stn = STN( img_sz=hyp_parameters["img_size"], ndims=hyp_parameters["ndims"], # padding_mode="zeros", padding_mode=hyp_parameters["padding_mode"], device=hyp_parameters["device"], ) if use_distributed: device = f"{DEVICE_TYPE}:{rank}" # NO pre-allocation. CCL/oneDNN accumulate ~1.4 GiB/step of device memory outside # PyTorch's caching allocator. Pre-allocating steals from that budget: # 92% pre-alloc → crash at step 3, 78% → step 10, none (70% cap) → step 14. # Instead, use empty_cache() between training phases to release unused cached memory # back to the device for CCL/oneDNN. 
if gpu_id == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'): total_mem = torch.xpu.get_device_properties(rank).total_memory print(f" [init] XPU device memory: {total_mem/1024**3:.1f} GiB, no pre-allocation (relying on empty_cache between phases)", flush=True) Deformddpm.to(device) Deformddpm = DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True) ddf_stn.to(device) else: Deformddpm.to(hyp_parameters["device"]) ddf_stn.to(hyp_parameters["device"]) # ddf_stn = DDP(ddf_stn, device_ids=[rank]) # mse = nn.MSELoss() # loss_reg = losses.Grad(penalty=['l1', 'negdetj'], ndims=hyp_parameters["ndims"]) # loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"]) loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.2,outrange_weight=1e3) loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"],outrange_thresh=0.6,outrange_weight=1e3) loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"]) # loss_ang = losses.MRSE(img_sz=hyp_parameters["img_size"]) loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"]) loss_imgsim = losses.MSLNCC() loss_imgmse = losses.LMSE() optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]) # hyp_parameters["lr"]=0.00000001 # optimizer_regist = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01) # optimizer_regist = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"]*0.01, momentum=0.98) # optimizer = SGD(Deformddpm.parameters(), lr=hyp_parameters["lr"], momentum=0.9) # # LR scheduler ----- YHM # scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, hyp_parameters["lr"], hyp_parameters["lr"]*10, step_size_up=500, step_size_down=500, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=True, base_momentum=0.8, max_momentum=0.9, last_epoch=-1) # Deformddpm.network.load_state_dict(torch.load('/home/data/jzheng/Adaptive_Motion_Generator-master/models/1000.pth')) # check 
for existing models if not os.path.exists(model_dir): os.makedirs(model_dir, exist_ok=True) # Check for checkpoints: first check tmp/ for mid-epoch, then main dir for epoch-level tmp_dir = os.path.join(model_dir, "tmp") tmp_files = sorted(glob.glob(os.path.join(tmp_dir, "*.pth"))) model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth"))) initial_step = 0 # Epoch stats and RNG states to restore when resuming from mid-epoch checkpoint _resume_epoch_stats = None _resume_rng = None if tmp_files and not args.eval_only and args.max_steps_before_restart > 0: # Mid-epoch checkpoint: only use when proactive restart is enabled latest = tmp_files[-1] if gpu_id == 0: print(f" [resume] Found mid-epoch checkpoint: {latest}") initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed) basename = os.path.basename(latest) initial_step = int(basename.split('_step')[1].split('_')[0].split('.')[0]) _ckpt = torch.load(latest, map_location='cpu', weights_only=False) _resume_epoch_stats = _ckpt.get('epoch_stats', None) del _ckpt if gpu_id == 0: print(f" [resume] Resuming epoch {initial_epoch} from step {initial_step}" f"{' (with epoch_stats)' if _resume_epoch_stats else ''}", flush=True) elif model_files: if gpu_id == 0: print(model_files) latest = model_files[-1] initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed) else: initial_epoch = 0 if gpu_id == 0: print('len_train_data: ',len(dataset)) # Proactive restart: track steps since process start to exit before OOM. 
max_steps_restart = args.max_steps_before_restart steps_since_start = 0 # Training loop for epoch in range(initial_epoch,hyp_parameters["epoch"]): if use_distributed and sampler is not None: sampler.set_epoch(epoch) sampler_p.set_epoch(epoch) epoch_loss_tot = 0.0 epoch_loss_gen_d = 0.0 epoch_loss_gen_a = 0.0 epoch_loss_reg = 0.0 epoch_loss_regist = 0.0 epoch_loss_imgsim = 0.0 epoch_loss_imgmse = 0.0 epoch_loss_ddfreg = 0.0 epoch_loss_contrastive = 0.0 total_contra = 0 total_reg_restored = None total_contra_restored = None # Restore epoch accumulators from mid-epoch checkpoint (only for the resumed epoch) if _resume_epoch_stats is not None and epoch == initial_epoch: epoch_loss_tot = _resume_epoch_stats.get('epoch_loss_tot', 0.0) epoch_loss_gen_d = _resume_epoch_stats.get('epoch_loss_gen_d', 0.0) epoch_loss_gen_a = _resume_epoch_stats.get('epoch_loss_gen_a', 0.0) epoch_loss_reg = _resume_epoch_stats.get('epoch_loss_reg', 0.0) epoch_loss_regist = _resume_epoch_stats.get('epoch_loss_regist', 0.0) epoch_loss_imgsim = _resume_epoch_stats.get('epoch_loss_imgsim', 0.0) epoch_loss_imgmse = _resume_epoch_stats.get('epoch_loss_imgmse', 0.0) epoch_loss_ddfreg = _resume_epoch_stats.get('epoch_loss_ddfreg', 0.0) epoch_loss_contrastive = _resume_epoch_stats.get('epoch_loss_contrastive', 0.0) total_reg_restored = _resume_epoch_stats.get('total_reg', None) total_contra_restored = _resume_epoch_stats.get('total_contra', None) loss_nan_step = _resume_epoch_stats.get('loss_nan_step', 0) # RNG states are restored INSIDE the skip loop (at the last skipped step) # to avoid DataLoader __getitem__ calls corrupting the restored state. 
_resume_rng = {k: _resume_epoch_stats[k] for k in ('rng_torch', 'rng_numpy', 'rng_python', 'rng_xpu', 'rng_cuda') if k in _resume_epoch_stats} if gpu_id == 0: print(f" [resume] Restored epoch stats from checkpoint (loss_tot={epoch_loss_tot:.4f})", flush=True) _resume_epoch_stats = None # Only restore once else: loss_nan_step = 0 # only reset when NOT resuming mid-epoch # Set model inside to train model Deformddpm.train() total = min(len(train_loader), len(train_loader_p)) total_reg = total // REGISTRATION_STEP_RATIO # Restore total_reg and total_contra from checkpoint if available (mid-epoch resume) if total_reg_restored is not None: total_reg = total_reg_restored total_reg_restored = None if total_contra_restored is not None: total_contra = total_contra_restored total_contra_restored = None # for step, batch in tqdm(enumerate(train_loader)): # for step, batch in tqdm(enumerate(train_loader)): # for step, batch in enumerate(train_loader_omni): for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total): # Skip steps already completed (mid-epoch resume). # Checkpoint at step N is saved AFTER step N's training completes, # so step N itself must also be skipped (use <=, not <). if epoch == initial_epoch and initial_step > 0 and step <= initial_step: # Restore RNG at the last skipped step, AFTER DataLoader __getitem__ # has consumed RNG for all skipped batches. This way the first # non-skipped step starts with exactly the saved RNG state. if step == initial_step and _resume_rng is not None: # Restore rank 0's RNG as base state, then re-seed per-rank # so each rank has independent RNG (matching continuous run's # divergent-per-rank behavior). Without this, all ranks would # share rank 0's RNG → correlated augmentation/dropout decisions. 
if 'rng_torch' in _resume_rng: torch.set_rng_state(_resume_rng['rng_torch']) if 'rng_numpy' in _resume_rng: np.random.set_state(_resume_rng['rng_numpy']) if 'rng_python' in _resume_rng: random.setstate(_resume_rng['rng_python']) if 'rng_xpu' in _resume_rng and DEVICE_TYPE == 'xpu': torch.xpu.set_rng_state(_resume_rng['rng_xpu']) elif 'rng_cuda' in _resume_rng and torch.cuda.is_available(): torch.cuda.set_rng_state(_resume_rng['rng_cuda']) # Per-rank re-seed: checkpoint only has rank 0's RNG state. # Advance each rank's RNG by a deterministic offset so they # diverge (as they would in a continuous run). if gpu_id > 0: rank_seed = gpu_id * 100003 + initial_step * 31 torch.manual_seed(torch.initial_seed() + rank_seed) np.random.seed((np.random.get_state()[1][0] + rank_seed) % (2**31)) random.seed(random.getrandbits(32) + rank_seed) if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'): torch.xpu.manual_seed(torch.initial_seed() + rank_seed) elif torch.cuda.is_available(): torch.cuda.manual_seed(torch.initial_seed() + rank_seed) _resume_rng = None if gpu_id == 0: print(f" [resume] RNG states restored at step {step} (per-rank re-seeded)", flush=True) continue # Free registration tensors from previous step x1 = y1 = ddf_comp = img_rec = img_diff = None ddf_rand = y1_proc = msk_tgt = img_save = None loss_regist = loss_sim = loss_mse = loss_ddf1 = None # Memory diagnostic (one per node via local rank 0) — only warn when abnormal # Normal at step start: ~16 GiB reserved, ~48 GiB free (of 64 GiB total) if rank == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'): torch.xpu.reset_peak_memory_stats(rank) free_mem, total_mem_dev = torch.xpu.mem_get_info(rank) used_gib = (total_mem_dev - free_mem) / 1024**3 if used_gib > 24: # Normal is ~16 GiB at step start; warn if accumulating alloc = torch.xpu.memory_allocated() / 1024**3 reserved = torch.xpu.memory_reserved() / 1024**3 free_gib = free_mem / 1024**3 print(f" [mem WARNING] gpu_id={gpu_id} epoch {epoch} step {step}: " 
f"{used_gib:.1f} GiB used ({alloc:.1f} alloc / {reserved:.1f} reserved), " f"{free_gib:.1f} GiB free", flush=True) # ========================================================================== # diffusion train on single image # x0 = batch # for omni dataset [x0,embd] = batch # for om dataset x0 = x0.to(hyp_parameters["device"]).type(torch.float32) # print('embd:', embd.shape) embd_dev = embd.to(hyp_parameters["device"]).type(torch.float32) if np.random.uniform(0,1) n x0 = x0.to(hyp_parameters["device"]) blind_mask = utils.get_random_deformed_mask(x0.shape[2:],apply_possibility=0.6).to(hyp_parameters["device"]) # random deformation + rotation if hyp_parameters["ndims"]>2: if np.random.uniform(0,1)0: if np.random.uniform(0,1)> JZ: print nan in x0 if torch.isnan(x0).any(): print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.") # >> JZ: print loss of ddf if loss_ddf>0.001: print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.") # yu: check if loss_tot==nan or inf # Synchronize NaN skip across all DDP ranks to avoid collective desync # Use broadcast from rank 0 instead of all_reduce to avoid CCL hang on single-node XPU is_nan = torch.isnan(loss_tot) or torch.isinf(loss_tot) if use_distributed: nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}") dist.broadcast(nan_flag, src=0) is_nan = nan_flag.item() > 0 if is_nan: if gpu_id == 0: print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.") loss_nan_step += 1 continue if loss_nan_step > 5: print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.") raise ValueError("Too many NaN losses detected in loss_tot. 
Code terminated.") # ========================================================================== # Diffusion backward (no gradient clipping — diffusion dominates training) if not args.eval_only: optimizer.zero_grad() loss_tot.backward() optimizer.step() epoch_loss_tot += loss_tot.item() / total epoch_loss_gen_d += loss_gen_d.item() / total epoch_loss_gen_a += loss_gen_a.item() / total epoch_loss_reg += loss_ddf.item() / total # Print running average every 20 steps in eval-only mode if args.eval_only and gpu_id == 0 and (step + 1) % 20 == 0: n = step + 1 print(f" [eval] step {step}: running_avg ang={epoch_loss_gen_a*total/n:.4f} " f"dist={epoch_loss_gen_d*total/n:.4f} regul={epoch_loss_reg*total/n:.6f}", flush=True) # Free diffusion intermediates and aggressively release all memory to device. # XPU runtime leaks ~1.3 GiB/step outside the caching allocator. # gc.collect() + synchronize() + empty_cache() attempts to reclaim deferred/lazy allocations. loss_gen_a_val = loss_gen_a.item() del pre_dvf_I, dvf_I, trm_pred, loss_tot, loss_gen_a, loss_gen_d, loss_ddf gc.collect() if DEVICE_TYPE == 'xpu': torch.xpu.synchronize() _empty_cache(DEVICE_TYPE) # Sync loss_gen_a across DDP ranks for contrastive and registration gating if use_distributed: loss_gen_a_sync = torch.tensor([loss_gen_a_val], device=f"{DEVICE_TYPE}:{rank}") dist.broadcast(loss_gen_a_sync, src=0) loss_gen_a_gate = loss_gen_a_sync.item() else: loss_gen_a_gate = loss_gen_a_val # ========================================================================== # Contrastive train on single image (text-image alignment) # Separate backward with gradient clipping to prevent destabilizing diffusion. loss_contra_val = None if step % CONTRASTIVE_STEP_RATIO == 0: n_contra = x0.size()[0] t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(hyp_parameters["device"]) # Route through DDP wrapper and return img_embd directly so DDP # traces the correct subgraph (encoder + mid + attn + img2txt). 
img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(), T=t_contra, output_embedding=True, text=None) # [B, 1024] loss_contra = LOSS_WEIGHT_CONTRASTIVE * F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1).mean()-0.25) if not args.eval_only: optimizer.zero_grad() loss_contra.backward() torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=1e-3) optimizer.step() loss_contra_val = loss_contra.item() epoch_loss_contrastive += loss_contra_val / total * CONTRASTIVE_STEP_RATIO # Free remaining intermediates and aggressively release memory before registration if cond_img is not None: del cond_img if blind_mask is not None: del blind_mask gc.collect() if DEVICE_TYPE == 'xpu': torch.xpu.synchronize() _empty_cache(DEVICE_TYPE) # ========================================================================== # registration train on paired images # loss_gen_a_gate already synced across DDP ranks above do_regist = step % REGISTRATION_STEP_RATIO == 0 and loss_gen_a_gate < -0.8 if do_regist: [x1, y1, _, embd_y] = batch_p if np.random.uniform(0,1) n [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1,-2,-3]) if hyp_parameters['noise_scale']>0: [x1, y1] = thresh_img([x1, y1], [0, 2*hyp_parameters['noise_scale']]) random_scale = np.random.normal(1, hyp_parameters['noise_scale'] * 1) random_shift = np.random.normal(0, hyp_parameters['noise_scale'] * 1) x1 = x1 * random_scale + random_shift y1 = y1 * random_scale + random_shift scale_regist = np.random.uniform(0.0,0.5) select_timestep = np.random.randint(12, 32) # select a random number of timesteps to sample, between 8 and 16 T_regist = sorted(random.sample(range(int(hyp_parameters["timesteps"] * scale_regist),hyp_parameters["timesteps"]), select_timestep), reverse=True) T_regist = [[t for _ in range(max(1, hyp_parameters["batchsize"]//2))] for t in T_regist] proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none']) ddpm_inner = Deformddpm.module if use_distributed 
else Deformddpm y1_proc, msk_tgt, cond_ratio = ddpm_inner.proc_cond_img(y1,proc_type=proc_type) msk_tgt = msk_tgt+MSK_EPS [ddf_comp,ddf_rand],[img_rec,img_diff,img_save],_ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[],text=embd_y) # forward diffusion process loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt*(y1>thresh_imgsim)) # calculate loss for the registration process loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt*(y1>=0.0)) # calculate loss for the registration process loss_ddf1 = loss_reg1(ddf_comp, img=y1) # calculate loss for the registration process loss_regist = 0 loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1 # >> JZ: print nan in x0 if torch.isnan(x0).any(): print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.") # >> JZ: print loss of ddf if loss_ddf1>0.002: print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.") loss_regist = torch.sqrt(cond_ratio+MSK_EPS) *loss_regist if not args.eval_only: optimizer.zero_grad() loss_regist.backward() # torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.1) torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02) optimizer.step() epoch_loss_regist += loss_regist.item() epoch_loss_imgsim += loss_sim.item() epoch_loss_imgmse += loss_mse.item() epoch_loss_ddfreg += loss_ddf1.item() else: loss_sim = torch.tensor(0.0) loss_mse = torch.tensor(0.0) loss_ddf1 = torch.tensor(0.0) loss_regist = torch.tensor(0.0) if step % REGISTRATION_STEP_RATIO==0: total_reg = total_reg-1 # Mid-epoch checkpoint and proactive restart (only when --max-steps-before-restart > 0) if max_steps_restart > 0 and step > 0 and step % MID_EPOCH_SAVE_STEPS == 0 and gpu_id == 0 and not args.no_save: _epoch_stats = { 'epoch_loss_tot': epoch_loss_tot, 'epoch_loss_gen_d': epoch_loss_gen_d, 'epoch_loss_gen_a': epoch_loss_gen_a, 
'epoch_loss_reg': epoch_loss_reg, 'epoch_loss_regist': epoch_loss_regist, 'epoch_loss_imgsim': epoch_loss_imgsim, 'epoch_loss_imgmse': epoch_loss_imgmse, 'epoch_loss_ddfreg': epoch_loss_ddfreg, 'epoch_loss_contrastive': epoch_loss_contrastive, 'total_reg': total_reg, 'total_contra': total_contra, 'loss_nan_step': loss_nan_step, 'rng_torch': torch.get_rng_state(), 'rng_numpy': np.random.get_state(), 'rng_python': random.getstate(), **(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else {'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})), } tmp_dir = os.path.join(model_save_path, "tmp") os.makedirs(tmp_dir, exist_ok=True) for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")): os.remove(old_f) mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}") state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict() torch.save({ 'model_state_dict': state, 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch, 'step': step, 'epoch_stats': _epoch_stats, }, mid_save) print(f" [mid-epoch] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True) # Proactive restart: exit cleanly after N steps to reset XPU memory leak. # The bash wrapper will re-launch srun within the same SLURM allocation. 
steps_since_start += 1 if max_steps_restart > 0 and steps_since_start >= max_steps_restart: # Save checkpoint at current position (if not just saved above) if not (step > 0 and step % MID_EPOCH_SAVE_STEPS == 0) and gpu_id == 0 and not args.no_save: _epoch_stats = { 'epoch_loss_tot': epoch_loss_tot, 'epoch_loss_gen_d': epoch_loss_gen_d, 'epoch_loss_gen_a': epoch_loss_gen_a, 'epoch_loss_reg': epoch_loss_reg, 'epoch_loss_regist': epoch_loss_regist, 'epoch_loss_imgsim': epoch_loss_imgsim, 'epoch_loss_imgmse': epoch_loss_imgmse, 'epoch_loss_ddfreg': epoch_loss_ddfreg, 'epoch_loss_contrastive': epoch_loss_contrastive, 'total_reg': total_reg, 'total_contra': total_contra, 'loss_nan_step': loss_nan_step, 'rng_torch': torch.get_rng_state(), 'rng_numpy': np.random.get_state(), 'rng_python': random.getstate(), **(({'rng_xpu': torch.xpu.get_rng_state()} if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu') else {'rng_cuda': torch.cuda.get_rng_state()} if torch.cuda.is_available() else {})), } tmp_dir = os.path.join(model_save_path, "tmp") os.makedirs(tmp_dir, exist_ok=True) for old_f in glob.glob(os.path.join(tmp_dir, "*.pth")): os.remove(old_f) mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}") state = Deformddpm.module.state_dict() if use_distributed else Deformddpm.state_dict() torch.save({ 'model_state_dict': state, 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch, 'step': step, 'epoch_stats': _epoch_stats, }, mid_save) print(f" [restart] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True) if gpu_id == 0: print(f" [restart] Proactive restart after {steps_since_start} steps " f"(limit {max_steps_restart}). 
Exiting with code {EXIT_CODE_RESTART}.", flush=True) # Clean shutdown _empty_cache(DEVICE_TYPE) gc.collect() if use_distributed and dist.is_initialized(): dist.barrier() dist.destroy_process_group() sys.exit(EXIT_CODE_RESTART) if gpu_id == 0: print('==================') print(epoch,':', epoch_loss_tot,'=',epoch_loss_gen_a,'+', epoch_loss_gen_d,'+',epoch_loss_reg, ' (ang+dist+regul)') print(f' loss_contrastive: {epoch_loss_contrastive}') total_reg_safe = max(total_reg, 1) print(f' loss_regist: {epoch_loss_regist/total_reg_safe} = {epoch_loss_imgsim/total_reg_safe} (imgsim) + {epoch_loss_imgmse/total_reg_safe} (imgmse) + {epoch_loss_ddfreg/total_reg_safe} (ddf)') print('==================') if 0 == epoch % epoch_per_save and not args.no_save: save_dir=model_save_path + str(epoch).rjust(6, '0') + suffix_pth os.makedirs(os.path.dirname(model_save_path), exist_ok=True) # break # FOR TESTING if not use_distributed: print(f"saved in {save_dir}") # torch.save(Deformddpm.state_dict(), save_dir) torch.save({ 'model_state_dict': Deformddpm.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch }, save_dir) elif gpu_id == 0: print(f"saved in {save_dir}") # torch.save(Deformddpm.module.state_dict(), save_dir) torch.save({ 'model_state_dict': Deformddpm.module.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch }, save_dir) # Clean up tmp/ mid-epoch checkpoints after completed epoch if gpu_id == 0 and not args.no_save: tmp_dir = os.path.join(model_dir, "tmp") tmp_pths = glob.glob(os.path.join(tmp_dir, "*.pth")) if tmp_pths: for f in tmp_pths: os.remove(f) print(f" [cleanup] Cleared {len(tmp_pths)} tmp/ mid-epoch checkpoints", flush=True) # Reset initial_step after first epoch completes (no more skipping) initial_step = 0 # XPU CCL workaround: restart after each epoch to avoid CCL hang on 2nd epoch. # CCL's Level Zero IPC handles accumulate and cause deadlock after ~200+ collectives. # A fresh process resets the L0 context. 
The bash loop catches exit code 42 and restarts. if DEVICE_TYPE == 'xpu' and use_distributed: if gpu_id == 0: print(f" [xpu-restart] Epoch {epoch} done. Restarting to reset CCL state.", flush=True) _empty_cache(DEVICE_TYPE) gc.collect() if dist.is_initialized(): dist.barrier() dist.destroy_process_group() sys.exit(EXIT_CODE_RESTART) # Resource cleanup at the end of training _empty_cache(DEVICE_TYPE) gc.collect() if use_distributed and dist.is_initialized(): dist.destroy_process_group() def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file,use_distributed=True, load_strict=False): # All ranks load checkpoint so optimizer state is consistent across DDP processes. # (Optimizer state includes per-parameter Adam momentum/variance which are NOT # broadcast — only model weights are broadcast. Without this, non-rank-0 processes # would have fresh Adam state after restart.) gc.collect() _empty_cache(DEVICE_TYPE) if gpu_id == 0: utils.print_memory_usage("Before Loading Model") checkpoint = torch.load(model_file, map_location='cpu', weights_only=False) if use_distributed: Deformddpm.module.load_state_dict(checkpoint['model_state_dict'], strict=load_strict) else: Deformddpm.load_state_dict(checkpoint['model_state_dict'], strict=load_strict) # Restore optimizer state when available (needed for mid-epoch resume). # Selective loading: load states for parameters with matching shapes, skip mismatched ones # (e.g., UpsampleConv replaced ConvTranspose3d — different kernel shapes). # After one epoch, the saved checkpoint will have correct state for ALL parameters. 
if 'optimizer_state_dict' in checkpoint and not args.reset_optimizer: saved_opt = checkpoint['optimizer_state_dict'] saved_state = saved_opt.get('state', {}) param_list = [p for group in optimizer.param_groups for p in group['params']] # Check if all shapes match (fast path: full load) all_match = True skipped = 0 for idx, s in saved_state.items(): if int(idx) < len(param_list): p = param_list[int(idx)] for k, v in s.items(): if isinstance(v, torch.Tensor) and v.dim() > 0 and v.shape != p.shape: all_match = False break if not all_match: break if all_match: optimizer.load_state_dict(saved_opt) else: # Selective load: restore param_groups settings (lr, betas, etc.) for saved_g, group in zip(saved_opt['param_groups'], optimizer.param_groups): for k, v in saved_g.items(): if k != 'params': group[k] = v # Restore per-parameter state only where shapes match for idx, s in saved_state.items(): idx_int = int(idx) if idx_int < len(param_list): p = param_list[idx_int] shapes_ok = all( v.shape == p.shape for k, v in s.items() if isinstance(v, torch.Tensor) and v.dim() > 0 ) if shapes_ok: # Cast state tensors to match parameter dtype/device new_state = {} for k, v in s.items(): if isinstance(v, torch.Tensor): new_state[k] = v.to(dtype=p.dtype, device=p.device) if v.dim() > 0 else v else: new_state[k] = v optimizer.state[p] = new_state else: skipped += 1 if gpu_id == 0: loaded = len(saved_state) - skipped print(f" [checkpoint] Selective optimizer load: {loaded} params restored, " f"{skipped} skipped (shape mismatch, fresh Adam for those)", flush=True) elif args.reset_optimizer and gpu_id == 0: print(" [checkpoint] --reset-optimizer: skipping optimizer state, starting fresh Adam", flush=True) del checkpoint if gpu_id == 0: utils.print_memory_usage("After Loading Checkpoint on GPU") if use_distributed: # Broadcast model weights from rank 0 to ensure exact consistency dist.barrier() for param in Deformddpm.parameters(): dist.broadcast(param.data, src=0) # get the epoch number from 
the filename basename = os.path.basename(model_file) epoch_from_file = int(basename[:6]) if '_step' in basename: # Mid-epoch checkpoint: resume at same epoch (don't +1) initial_epoch = epoch_from_file else: # End-of-epoch checkpoint: start next epoch initial_epoch = epoch_from_file + 1 return initial_epoch, Deformddpm, optimizer if __name__ == "__main__": if "LOCAL_RANK" in os.environ: # Multi-node: launched by torchrun / srun use_distributed = True local_rank = int(os.environ["LOCAL_RANK"]) world_size = int(os.environ["WORLD_SIZE"]) print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}") try: main_train(local_rank, world_size) except Exception as e: import traceback print(f"\n{'='*60}\nRANK {os.environ.get('RANK')} FAILED:\n{'='*60}", flush=True) traceback.print_exc() raise elif use_distributed: # Single-node multi-GPU: use mp.spawn world_size = _device_count(DEVICE_TYPE) print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}") mp.spawn(main_train,args = (world_size,),nprocs = world_size) else: main_train(0,1)