| import os, sys, contextlib
|
|
|
| ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
| sys.path.append(ROOT_DIR)
|
|
|
| import gc
|
| import torch
|
| import torchvision
|
| from torch import nn
|
| from torchvision.utils import save_image
|
| from torch.utils.data import DataLoader
|
|
|
| from torch.optim import Adam, SGD
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
|
|
| import torch.nn.functional as F
|
| import Diffusion.losses as losses
|
| import random
|
| import glob
|
| import numpy as np
|
| import utils
|
| from tqdm import tqdm
|
|
|
|
|
| from Dataloader.dataLoader import *
|
|
|
| from Dataloader.dataloader_utils import thresh_img
|
| import yaml
|
| import argparse
|
|
|
|
|
| try:
|
| import intel_extension_for_pytorch as ipex
|
| except ImportError:
|
| ipex = None
|
| try:
|
| import oneccl_bindings_for_pytorch
|
| except (ImportError, Exception) as e:
|
| print(f"WARNING: Failed to import oneccl_bindings_for_pytorch: {e}")
|
|
|
|
|
| import torch.multiprocessing as mp
|
| from torch.utils.data.distributed import DistributedSampler
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
| import torch.distributed as dist
|
|
|
|
|
| def _device_available(device_type):
|
| if device_type == 'xpu':
|
| return hasattr(torch, 'xpu') and torch.xpu.is_available()
|
| return torch.cuda.is_available()
|
|
|
| def _device_count(device_type):
|
| if device_type == 'xpu':
|
| return torch.xpu.device_count() if hasattr(torch, 'xpu') else 0
|
| return torch.cuda.device_count()
|
|
|
def _set_device(rank, device_type):
    """Make device index ``rank`` current for this process ('xpu' or CUDA)."""
    backend = torch.xpu if device_type == 'xpu' else torch.cuda
    backend.set_device(rank)
|
|
|
| def _empty_cache(device_type):
|
| if device_type == 'xpu' and hasattr(torch, 'xpu'):
|
| torch.xpu.empty_cache()
|
| elif torch.cuda.is_available():
|
| torch.cuda.empty_cache()
|
|
|
def ddp_setup(rank, world_size):
    """Initialise the default process group and bind this process to its device.

    Args:
        rank: Unique identifier of each process (local_rank when launched by torchrun).
        world_size: Total number of processes.

    The backend is "ccl" for XPU and "nccl" otherwise. Under torchrun
    (``LOCAL_RANK`` set) rank and world size come from the environment. For
    ``mp.spawn`` launches a localhost rendezvous is used; MASTER_ADDR /
    MASTER_PORT values already exported by the user are kept, so several jobs
    on one node can use distinct ports. Otherwise localhost:12355 is used.
    """
    backend = "ccl" if DEVICE_TYPE == "xpu" else "nccl"
    if "LOCAL_RANK" in os.environ:
        # torchrun provides MASTER_ADDR/PORT, RANK and WORLD_SIZE.
        dist.init_process_group(backend=backend)
        _set_device(int(os.environ["LOCAL_RANK"]), DEVICE_TYPE)
    else:
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "12355")
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
        _set_device(rank, DEVICE_TYPE)
|
|
|
# --- Numerical constants -----------------------------------------------------
# NOTE(review): EPS is not referenced in the code visible in this file.
EPS = 1e-5
# Offset added to the registration target mask and inside the sqrt(...)
# factors that scale the diffusion and registration losses by cond_ratio.
MSK_EPS = 0.01

# --- Stochastic training choices ---------------------------------------------
# Probability that a diffusion / registration forward pass receives the text
# embedding (otherwise text=None is passed).
TEXT_EMBED_PROB = 0.5
# Probability of random resampling (otherwise a random axis permutation) for
# ndims > 2 inputs; also reused as the probability of intensity-thresholding x0.
AUG_RESAMPLE_PROB = 0.5

# Diffusion loss weights: [angle loss (losses.NCC), distance loss
# (losses.MRSE), DDF regulariser (losses.Grad)].
LOSS_WEIGHTS_DIFF = [4.0, 2.0, 8.0]

# Registration loss weights: [image similarity (losses.MSLNCC),
# image MSE (losses.LMSE), DDF regulariser (losses.Grad)].
LOSS_WEIGHTS_REGIST = [1.0, 0.01, 1e2]
# The paired (registration) loader uses batchsize // DIFF_REG_BATCH_RATIO (min 1).
DIFF_REG_BATCH_RATIO = 2

# Contrastive (image embedding vs. text embedding) loss weight.
# NOTE(review): the training loop picks its own per-step weight (1e-4 after a
# diffusion update, 1e-1 when that update was skipped) instead of reading
# this value.
LOSS_WEIGHT_CONTRASTIVE = 1e-1
# Registration is attempted on steps where step % REGISTRATION_STEP_RATIO == 0.
REGISTRATION_STEP_RATIO = 1
# The contrastive update runs on steps where step % CONTRASTIVE_STEP_RATIO == 0.
CONTRASTIVE_STEP_RATIO = 1
# Diffusion and registration updates run only while the most recent
# unweighted contrastive loss is below this threshold.
ACCEPT_THRESH_CONTRASTIVE = 0.1
# Registration runs only when the most recent diffusion angle loss is below this.
ACCEPT_THRESH_ANGLE = -0.8
# Mid-epoch checkpoint interval in steps (only active with
# --max-steps-before-restart > 0). Stored as a float; only used with %.
MID_EPOCH_SAVE_STEPS = 1e4

# Exit code used for intentional restarts (proactive step limit, and after
# each epoch on distributed XPU); presumably the job wrapper relaunches on it.
EXIT_CODE_RESTART = 42
|
|
|
|
|
|
|
# Command-line interface. Parsed at import time because the module-level
# device setup reads ``args.config``.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config", "-C",
    type=str,
    required=False,
    default="Config/config_all.yaml",
    help="Path for the config file",
)

# Data / batch overrides.
parser.add_argument(
    "--dummy-samples", type=int, default=0,
    help="Use dummy random data for testing (0=use real data)",
)
parser.add_argument(
    "--batchsize", type=int, default=0,
    help="Override batch size from config (0=use config value)",
)

# Restart / checkpoint / optimisation behaviour.
parser.add_argument(
    "--max-steps-before-restart", type=int, default=0,
    help=(
        "Proactive restart: exit after N training steps to reset XPU memory leak. "
        "0=disabled (rely on OOM crash + auto-resubmit). "
        "Recommended: 20 for XPU (survives ~26 steps max)."
    ),
)
parser.add_argument(
    "--no-save", action="store_true", default=False,
    help="Disable all checkpoint saving (for diagnostic/validation runs)",
)
parser.add_argument(
    "--reset-optimizer", action="store_true",
    help="Skip optimizer state loading from checkpoint (use when architecture changed)",
)
parser.add_argument(
    "--eval-only", action="store_true",
    help="Forward pass only: compute and print losses without backward/optimizer (no memory leak)",
)

args = parser.parse_args()
|
|
|
|
|
# Only the device type is needed at import time; main_train re-reads the full
# config inside every process.
with open(args.config, 'r') as _f:
    _cfg = yaml.safe_load(_f)
DEVICE_TYPE = _cfg.get('device', 'cuda')

# Use multi-process training only when more than one accelerator of that
# type is visible (the device count is queried only if the backend is usable).
_accel_ok = _device_available(DEVICE_TYPE)
use_distributed = _accel_ok and _device_count(DEVICE_TYPE) > 1
|
|
|
|
|
|
|
|
|
| class _DummyIndiv(torch.utils.data.Dataset):
|
| def __init__(self, n, sz, embd_dim=1024):
|
| self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| def __len__(self): return self.n
|
| def __getitem__(self, i):
|
| return np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64), np.random.randn(self.embd_dim).astype(np.float32)
|
|
|
| class _DummyPair(torch.utils.data.Dataset):
|
| def __init__(self, n, sz, embd_dim=1024):
|
| self.n, self.sz, self.embd_dim = n, sz, embd_dim
|
| def __len__(self): return self.n
|
| def __getitem__(self, i):
|
| return (np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| np.random.rand(1, self.sz, self.sz, self.sz).astype(np.float64),
|
| np.random.randn(self.embd_dim).astype(np.float32),
|
| np.random.randn(self.embd_dim).astype(np.float32))
|
|
|
|
|
# Names of the per-epoch loss accumulators. They are also the keys stored in
# the 'epoch_stats' entry of mid-epoch checkpoints, so keep them stable for
# resuming from existing checkpoints.
_EPOCH_LOSS_KEYS = (
    'epoch_loss_tot',
    'epoch_loss_gen_d',
    'epoch_loss_gen_a',
    'epoch_loss_reg',
    'epoch_loss_regist',
    'epoch_loss_imgsim',
    'epoch_loss_imgmse',
    'epoch_loss_ddfreg',
    'epoch_loss_contrastive',
)


def _rng_snapshot():
    """Return this process's torch / numpy / python (and accelerator) RNG states."""
    snap = {
        'rng_torch': torch.get_rng_state(),
        'rng_numpy': np.random.get_state(),
        'rng_python': random.getstate(),
    }
    if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
        snap['rng_xpu'] = torch.xpu.get_rng_state()
    elif torch.cuda.is_available():
        snap['rng_cuda'] = torch.cuda.get_rng_state()
    return snap


def _restore_resume_rng(rng_states, gpu_id, initial_step):
    """Restore RNG states from a mid-epoch checkpoint.

    Only rank 0 writes checkpoints, so every rank receives rank 0's states.
    Ranks with ``gpu_id > 0`` are then re-seeded with a rank- and
    step-dependent offset so their random streams differ from rank 0's.
    """
    if 'rng_torch' in rng_states:
        torch.set_rng_state(rng_states['rng_torch'])
    if 'rng_numpy' in rng_states:
        np.random.set_state(rng_states['rng_numpy'])
    if 'rng_python' in rng_states:
        random.setstate(rng_states['rng_python'])
    if 'rng_xpu' in rng_states and DEVICE_TYPE == 'xpu':
        torch.xpu.set_rng_state(rng_states['rng_xpu'])
    elif 'rng_cuda' in rng_states and torch.cuda.is_available():
        torch.cuda.set_rng_state(rng_states['rng_cuda'])

    if gpu_id > 0:
        rank_seed = gpu_id * 100003 + initial_step * 31
        torch.manual_seed(torch.initial_seed() + rank_seed)
        # int(): avoid uint32 scalar overflow warnings; the value mod 2**31 is unchanged.
        np.random.seed((int(np.random.get_state()[1][0]) + rank_seed) % (2**31))
        random.seed(random.getrandbits(32) + rank_seed)
        if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
            torch.xpu.manual_seed(torch.initial_seed() + rank_seed)
        elif torch.cuda.is_available():
            torch.cuda.manual_seed(torch.initial_seed() + rank_seed)


def _save_mid_epoch_checkpoint(model, optimizer, model_save_path, suffix_pth,
                               epoch, step, epoch_stats, tag):
    """Write the single resumable mid-epoch checkpoint into ``<model_save_path>/tmp``.

    The file is first written under a ``.partial`` name and renamed into
    place. Older ``tmp/*.pth`` files are removed only after that. As a result,
    an interrupted save (for example the OOM crashes this checkpoint exists
    for) never leaves the directory without a checkpoint, and never leaves a
    truncated ``.pth`` that a resume would try to load.
    """
    tmp_dir = os.path.join(model_save_path, "tmp")
    os.makedirs(tmp_dir, exist_ok=True)
    mid_save = os.path.join(tmp_dir, f"{epoch:06d}_step{step:04d}{suffix_pth}")
    partial = mid_save + ".partial"
    state = model.module.state_dict() if use_distributed else model.state_dict()
    torch.save({
        'model_state_dict': state,
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'step': step,
        'epoch_stats': epoch_stats,
    }, partial)
    os.replace(partial, mid_save)
    keep = os.path.basename(mid_save)
    stale = (glob.glob(os.path.join(tmp_dir, "*.pth"))
             + glob.glob(os.path.join(tmp_dir, "*.partial")))
    for old_f in stale:
        if os.path.basename(old_f) != keep:
            os.remove(old_f)
    print(f" [{tag}] Saved checkpoint at epoch {epoch} step {step}: {mid_save}", flush=True)


def _save_epoch_checkpoint(model, optimizer, save_dir, epoch):
    """Write the end-of-epoch checkpoint (model, optimizer, epoch) to ``save_dir``."""
    print(f"saved in {save_dir}")
    state = model.module.state_dict() if use_distributed else model.state_dict()
    torch.save({
        'model_state_dict': state,
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }, save_dir)


def main_train(rank=0, world_size=1, train_mode_ratio=1, thresh_imgsim=0.01):
    """Train DeformDDPM, alternating diffusion, contrastive and registration updates.

    Each step:
      1. trains the diffusion model on an augmented batch from the individual
         loader, while the contrastive gate is open;
      2. runs a contrastive update between the image embedding and the text
         embedding;
      3. runs a registration update on a paired batch, when both the
         contrastive gate and the angle gate are open.

    Gate values are taken from rank 0, so all DDP ranks take the same branches
    and issue the same collective calls.

    Args:
        rank: process index from ``mp.spawn``. It is replaced by ``LOCAL_RANK``
            when ``RANK`` is present in the environment (torchrun launch).
        world_size: number of processes, forwarded to ``ddp_setup``.
        train_mode_ratio: accepted for compatibility; unused.
        thresh_imgsim: only target voxels brighter than this contribute to the
            registration image-similarity loss.

    Raises:
        ValueError: when more than five non-finite diffusion losses occur in
            one epoch.

    Calls ``sys.exit(EXIT_CODE_RESTART)`` for proactive restarts
    (``--max-steps-before-restart``) and after every epoch on distributed XPU.
    """
    if use_distributed:
        ddp_setup(rank, world_size)

    if torch.distributed.is_initialized() and rank == 0:
        print(f"World size: {torch.distributed.get_world_size()}")
        print(f"Communication backend: {torch.distributed.get_backend()}")
        print(f"PYTORCH_ALLOC_CONF: {os.environ.get('PYTORCH_ALLOC_CONF', 'not set')}")
        if DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
            props = torch.xpu.get_device_properties(0)
            print(f"XPU device: {props.name}, total memory: {props.total_memory / 1024**3:.2f} GiB")

    # Under torchrun, RANK (global) decides rank-0-only work and LOCAL_RANK
    # selects the device on this node.
    if "RANK" in os.environ:
        gpu_id = int(os.environ["RANK"])
        rank = int(os.environ["LOCAL_RANK"])
    else:
        gpu_id = rank

    with open(args.config, 'r') as cfg_file:
        hyp_parameters = yaml.safe_load(cfg_file)
    if args.batchsize > 0:
        hyp_parameters['batchsize'] = args.batchsize
    if gpu_id == 0:
        print(hyp_parameters)

    dev = hyp_parameters["device"]
    noise_scale = hyp_parameters['noise_scale']
    epoch_per_save = hyp_parameters['epoch_per_save']
    data_name = hyp_parameters['data_name']
    net_name = hyp_parameters['net_name']
    Net = get_net(net_name)

    suffix_pth = f'_{data_name}_{net_name}.pth'
    model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
    model_dir = model_save_path

    # ------------------------------------------------------------------ data
    if args.dummy_samples > 0:
        dataset = _DummyIndiv(args.dummy_samples, hyp_parameters['img_size'])
        datasetp = _DummyPair(args.dummy_samples, hyp_parameters['img_size'])
    else:
        dataset = OMDataset_indiv(transform=None)
        datasetp = OMDataset_pair(transform=None)

    if use_distributed:
        sampler = DistributedSampler(dataset, shuffle=True)
        sampler_p = DistributedSampler(datasetp, shuffle=True)
    else:
        sampler = None
        sampler_p = None

    train_loader = DataLoader(
        dataset,
        batch_size=hyp_parameters['batchsize'],
        shuffle=(sampler is None),
        drop_last=True,
        sampler=sampler,
    )
    train_loader_p = DataLoader(
        datasetp,
        batch_size=max(1, hyp_parameters['batchsize'] // DIFF_REG_BATCH_RATIO),
        shuffle=(sampler_p is None),
        drop_last=True,
        sampler=sampler_p,
    )

    # ----------------------------------------------------------------- model
    network = Net(
        n_steps=hyp_parameters["timesteps"],
        ndims=hyp_parameters["ndims"],
        num_input_chn=hyp_parameters["num_input_chn"],
        res=hyp_parameters['img_size']
    )

    if DEVICE_TYPE == 'xpu' and hasattr(network, 'use_checkpoint'):
        network.use_checkpoint = True
        if gpu_id == 0:
            print(" [init] Gradient checkpointing enabled for XPU", flush=True)

    Deformddpm = DeformDDPM(
        network=network,
        n_steps=hyp_parameters["timesteps"],
        image_chw=[1] + [hyp_parameters["img_size"]] * hyp_parameters["ndims"],
        device=dev,
        batch_size=hyp_parameters["batchsize"],
        img_pad_mode=hyp_parameters["img_pad_mode"],
        v_scale=hyp_parameters["v_scale"],
    )

    ddf_stn = STN(
        img_sz=hyp_parameters["img_size"],
        ndims=hyp_parameters["ndims"],
        padding_mode=hyp_parameters["padding_mode"],
        device=dev,
    )

    if use_distributed:
        device = f"{DEVICE_TYPE}:{rank}"
        if gpu_id == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
            total_mem = torch.xpu.get_device_properties(rank).total_memory
            print(f" [init] XPU device memory: {total_mem/1024**3:.1f} GiB, no pre-allocation (relying on empty_cache between phases)", flush=True)
        Deformddpm.to(device)
        Deformddpm = DDP(Deformddpm, device_ids=[rank], find_unused_parameters=True)
        ddf_stn.to(device)
    else:
        Deformddpm.to(dev)
        ddf_stn.to(dev)

    # ---------------------------------------------------- losses / optimiser
    loss_reg = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"], outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = losses.Grad(penalty=['l1', 'negdetj', 'range'], ndims=hyp_parameters["ndims"], outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = losses.MRSE(img_sz=hyp_parameters["img_size"])
    loss_ang = losses.NCC(img_sz=hyp_parameters["img_size"])
    loss_imgsim = losses.MSLNCC()
    loss_imgmse = losses.LMSE()

    optimizer = Adam(Deformddpm.parameters(), lr=hyp_parameters["lr"])

    # ---------------------------------------------------------------- resume
    os.makedirs(model_dir, exist_ok=True)
    tmp_files = sorted(glob.glob(os.path.join(model_dir, "tmp", "*.pth")))
    model_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    initial_step = 0
    _resume_epoch_stats = None
    _resume_rng = None

    # Mid-epoch (tmp/) checkpoints are only used in proactive-restart mode.
    if tmp_files and not args.eval_only and args.max_steps_before_restart > 0:
        latest = tmp_files[-1]
        if gpu_id == 0:
            print(f" [resume] Found mid-epoch checkpoint: {latest}")
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
        initial_step = int(os.path.basename(latest).split('_step')[1].split('_')[0].split('.')[0])
        _ckpt = torch.load(latest, map_location='cpu', weights_only=False)
        _resume_epoch_stats = _ckpt.get('epoch_stats', None)
        del _ckpt
        if gpu_id == 0:
            print(f" [resume] Resuming epoch {initial_epoch} from step {initial_step}"
                  f"{' (with epoch_stats)' if _resume_epoch_stats else ''}", flush=True)
    elif model_files:
        if gpu_id == 0:
            print(model_files)
        latest = model_files[-1]
        initial_epoch, Deformddpm, optimizer = ddp_load_dict(gpu_id, Deformddpm, optimizer, latest, use_distributed=use_distributed)
    else:
        initial_epoch = 0

    if gpu_id == 0:
        print('len_train_data: ', len(dataset))

    # Unwrapped model for helper methods (proc_cond_img) that DDP does not expose.
    ddpm_core = Deformddpm.module if use_distributed else Deformddpm

    max_steps_restart = args.max_steps_before_restart
    steps_since_start = 0
    # Unweighted contrastive loss of the latest contrastive update (rank 0's
    # value when distributed); gates the diffusion and registration updates.
    loss_contra_gate = 0.0
    # Angle loss of the latest diffusion update (rank 0's value when
    # distributed); gates registration. The first executed step always runs
    # the diffusion branch, so this placeholder is never compared in practice.
    loss_gen_a_gate = float('inf')
    loss_contra = None

    for epoch in range(initial_epoch, hyp_parameters["epoch"]):
        if use_distributed and sampler is not None:
            sampler.set_epoch(epoch)
            sampler_p.set_epoch(epoch)

        epoch_losses = dict.fromkeys(_EPOCH_LOSS_KEYS, 0.0)
        total_contra = 0
        total_reg_restored = None
        total_contra_restored = None

        # Continue the accumulators and RNG streams of an interrupted epoch.
        if _resume_epoch_stats is not None and epoch == initial_epoch:
            for key in _EPOCH_LOSS_KEYS:
                epoch_losses[key] = _resume_epoch_stats.get(key, 0.0)
            total_reg_restored = _resume_epoch_stats.get('total_reg', None)
            total_contra_restored = _resume_epoch_stats.get('total_contra', None)
            loss_nan_step = _resume_epoch_stats.get('loss_nan_step', 0)
            _resume_rng = {k: _resume_epoch_stats[k] for k in
                           ('rng_torch', 'rng_numpy', 'rng_python', 'rng_xpu', 'rng_cuda')
                           if k in _resume_epoch_stats}
            if gpu_id == 0:
                print(f" [resume] Restored epoch stats from checkpoint (loss_tot={epoch_losses['epoch_loss_tot']:.4f})", flush=True)
            _resume_epoch_stats = None
        else:
            loss_nan_step = 0

        Deformddpm.train()

        total = min(len(train_loader), len(train_loader_p))
        # Number of registration steps; decremented for every gated-off one so
        # the epoch summary averages over executed registration steps only.
        total_reg = total // REGISTRATION_STEP_RATIO
        if total_reg_restored is not None:
            total_reg = total_reg_restored
        if total_contra_restored is not None:
            total_contra = total_contra_restored

        for step, (batch, batch_p) in tqdm(enumerate(zip(train_loader, train_loader_p)), total=total):
            # Fast-forward over steps already completed before a mid-epoch restart.
            if epoch == initial_epoch and initial_step > 0 and step <= initial_step:
                if step == initial_step and _resume_rng is not None:
                    _restore_resume_rng(_resume_rng, gpu_id, initial_step)
                    _resume_rng = None
                    if gpu_id == 0:
                        print(f" [resume] RNG states restored at step {step} (per-rank re-seeded)", flush=True)
                continue

            # Drop references from the previous step before new allocations.
            x1 = y1 = ddf_comp = img_rec = img_diff = None
            ddf_rand = y1_proc = msk_tgt = img_save = None
            loss_regist = loss_sim = loss_mse = loss_ddf1 = None

            if rank == 0 and DEVICE_TYPE == 'xpu' and hasattr(torch, 'xpu'):
                torch.xpu.reset_peak_memory_stats(rank)
                free_mem, total_mem_dev = torch.xpu.mem_get_info(rank)
                used_gib = (total_mem_dev - free_mem) / 1024**3
                if used_gib > 24:
                    alloc = torch.xpu.memory_allocated() / 1024**3
                    reserved = torch.xpu.memory_reserved() / 1024**3
                    free_gib = free_mem / 1024**3
                    print(f" [mem WARNING] gpu_id={gpu_id} epoch {epoch} step {step}: "
                          f"{used_gib:.1f} GiB used ({alloc:.1f} alloc / {reserved:.1f} reserved), "
                          f"{free_gib:.1f} GiB free", flush=True)

            # ------------------------------------------- diffusion update
            [x0, embd] = batch
            x0 = x0.to(dev).type(torch.float32)
            embd_dev = embd.to(dev).type(torch.float32)
            embd_in = embd_dev if np.random.uniform(0, 1) < TEXT_EMBED_PROB else None

            n = x0.size()[0]

            blind_mask = utils.get_random_deformed_mask(x0.shape[2:], apply_possibility=0.6).to(dev)

            if hyp_parameters["ndims"] > 2:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = utils.random_resample(x0, deform_scale=0)
                else:
                    [x0] = utils.random_permute([x0], select_dims=[-1, -2, -3])

            if noise_scale > 0:
                if np.random.uniform(0, 1) < AUG_RESAMPLE_PROB:
                    x0 = thresh_img(x0, [0, 2 * noise_scale])
                x0 = x0 * (np.random.normal(1, noise_scale * 1)) + np.random.normal(0, noise_scale * 1)

            t = torch.randint(0, hyp_parameters["timesteps"], (n,)).to(dev)

            proc_type = random.choice(['adding', 'downsample', 'slice', 'slice1', 'none', 'uncon', 'uncon', 'uncon'])
            cond_img, _, cond_ratio = ddpm_core.proc_cond_img(x0, proc_type=proc_type)

            if loss_contra_gate < ACCEPT_THRESH_CONTRASTIVE:
                pre_dvf_I, dvf_I = Deformddpm(img_org=x0, t=t, cond_imgs=cond_img, mask=blind_mask, proc_type=[], text=embd_in)

                loss_ddf = loss_reg(pre_dvf_I, img=x0)
                trm_pred = ddf_stn(pre_dvf_I, dvf_I)
                loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)
                loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_I, ddf_stn=None, mask=blind_mask)

                loss_tot = 0
                loss_tot += LOSS_WEIGHTS_DIFF[0] * loss_gen_a + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
                loss_tot += LOSS_WEIGHTS_DIFF[2] * loss_ddf
                loss_tot = torch.sqrt(1. + MSK_EPS - cond_ratio) * loss_tot

                if torch.isnan(x0).any():
                    print(f"*** Encountered NaN in input image x0 at epoch {epoch}, step {step}.")
                if loss_ddf > 0.001:
                    print(f"*** High diffusion DDF loss at epoch {epoch}, step {step}: {loss_ddf.item()}.")

                # Skip the update on a non-finite loss. The decision is reduced
                # with MAX across ranks: a NaN on any rank must skip everywhere,
                # otherwise DDP would all-reduce NaN gradients into every replica.
                is_nan = not bool(torch.isfinite(loss_tot).all())
                if use_distributed:
                    nan_flag = torch.tensor([1.0 if is_nan else 0.0], device=f"{DEVICE_TYPE}:{rank}")
                    dist.all_reduce(nan_flag, op=dist.ReduceOp.MAX)
                    is_nan = nan_flag.item() > 0
                if is_nan:
                    if gpu_id == 0:
                        print(f"*** Encountered NaN or Inf loss at epoch {epoch}, step {step}. Skipping this batch.")
                    loss_nan_step += 1
                    # Abort as soon as the budget is exceeded (a check placed
                    # after a finite step would never fire if every step is NaN).
                    if loss_nan_step > 5:
                        print(f"*** Too many NaN or Inf losses ({loss_nan_step} times) at epoch {epoch}, step {step}. Stopping training.")
                        raise ValueError("Too many NaN losses detected in loss_tot. Code terminated.")
                    continue

                if not args.eval_only:
                    optimizer.zero_grad()
                    loss_tot.backward()
                    optimizer.step()

                epoch_losses['epoch_loss_tot'] += loss_tot.item() / total
                epoch_losses['epoch_loss_gen_d'] += loss_gen_d.item() / total
                epoch_losses['epoch_loss_gen_a'] += loss_gen_a.item() / total
                epoch_losses['epoch_loss_reg'] += loss_ddf.item() / total

                if args.eval_only and gpu_id == 0 and (step + 1) % 20 == 0:
                    n_seen = step + 1
                    avg_a = epoch_losses['epoch_loss_gen_a'] * total / n_seen
                    avg_d = epoch_losses['epoch_loss_gen_d'] * total / n_seen
                    avg_r = epoch_losses['epoch_loss_reg'] * total / n_seen
                    print(f" [eval] step {step}: running_avg ang={avg_a:.4f} "
                          f"dist={avg_d:.4f} regul={avg_r:.6f}", flush=True)

                loss_gen_a_val = loss_gen_a.item()

                gc.collect()
                if DEVICE_TYPE == 'xpu':
                    torch.xpu.synchronize()
                _empty_cache(DEVICE_TYPE)

                # Share rank 0's angle loss so every rank makes the same
                # registration decision.
                if use_distributed:
                    loss_gen_a_sync = torch.tensor([loss_gen_a_val], device=f"{DEVICE_TYPE}:{rank}")
                    dist.broadcast(loss_gen_a_sync, src=0)
                    loss_gen_a_gate = loss_gen_a_sync.item()
                else:
                    loss_gen_a_gate = loss_gen_a_val

                contra_weight = 1e-4
            else:
                # Diffusion is gated off: weight the contrastive objective up.
                contra_weight = 1e-1
                if gpu_id == 0:
                    print(f" [train] step {step}: Skipping backward (contra_gate={loss_contra_gate:.4f})", flush=True)

            # ----------------------------------------- contrastive update
            if step % CONTRASTIVE_STEP_RATIO == 0:
                n_contra = x0.size()[0]
                t_contra = torch.randint(0, hyp_parameters["timesteps"], (n_contra,)).to(dev)
                img_embd = Deformddpm(img_org=(x0 * blind_mask).detach(), cond_imgs=cond_img.detach(), T=t_contra, output_embedding=True, text=None)
                loss_contra_preweight = F.relu(1 - F.cosine_similarity(img_embd, embd_dev, dim=-1) - 0.25).mean()
                loss_contra = contra_weight * loss_contra_preweight

                if not args.eval_only:
                    optimizer.zero_grad()
                    loss_contra.backward()
                    torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=contra_weight * 1)
                    optimizer.step()
                loss_contra_val = loss_contra.item()
                epoch_losses['epoch_loss_contrastive'] += loss_contra_val / total * CONTRASTIVE_STEP_RATIO

                # Keep the gate as a plain float: storing the tensor would keep
                # its autograd graph alive across steps.
                contra_pre_val = loss_contra_preweight.item()
                if use_distributed:
                    loss_contra_sync = torch.tensor([contra_pre_val], device=f"{DEVICE_TYPE}:{rank}")
                    dist.broadcast(loss_contra_sync, src=0)
                    loss_contra_gate = loss_contra_sync.item()
                else:
                    loss_contra_gate = contra_pre_val

            # Free diffusion-phase tensors before the registration forward pass.
            del cond_img, blind_mask
            gc.collect()
            if DEVICE_TYPE == 'xpu':
                torch.xpu.synchronize()
            _empty_cache(DEVICE_TYPE)

            # ---------------------------------------- registration update
            do_regist = (step % REGISTRATION_STEP_RATIO == 0
                         and loss_contra_gate < ACCEPT_THRESH_CONTRASTIVE
                         and loss_gen_a_gate < ACCEPT_THRESH_ANGLE)
            if do_regist:
                [x1, y1, _, embd_y] = batch_p
                if np.random.uniform(0, 1) < TEXT_EMBED_PROB:
                    embd_y = embd_y.to(dev).type(torch.float32)
                else:
                    embd_y = None

                x1 = x1.to(dev).type(torch.float32)
                y1 = y1.to(dev).type(torch.float32)
                n = x1.size()[0]
                [x1, y1] = utils.random_permute([x1, y1], select_dims=[-1, -2, -3])
                if noise_scale > 0:
                    [x1, y1] = thresh_img([x1, y1], [0, 2 * noise_scale])
                    random_scale = np.random.normal(1, noise_scale * 1)
                    random_shift = np.random.normal(0, noise_scale * 1)
                    x1 = x1 * random_scale + random_shift
                    y1 = y1 * random_scale + random_shift

                # Random decreasing schedule of distinct timesteps drawn from
                # the upper part of [0, timesteps).
                scale_regist = np.random.uniform(0.0, 0.5)
                select_timestep = np.random.randint(12, 32)
                t_pool = range(int(hyp_parameters["timesteps"] * scale_regist), hyp_parameters["timesteps"])
                # random.sample raises ValueError if k exceeds the pool size.
                T_regist = sorted(random.sample(t_pool, min(select_timestep, len(t_pool))), reverse=True)
                # One copy of each timestep per sample of the paired batch.
                T_regist = [[t_r for _ in range(n)] for t_r in T_regist]

                proc_type = random.choice(['downsample', 'slice', 'slice1', 'none', 'none'])
                y1_proc, msk_tgt, cond_ratio = ddpm_core.proc_cond_img(y1, proc_type=proc_type)
                msk_tgt = msk_tgt + MSK_EPS
                [ddf_comp, ddf_rand], [img_rec, img_diff, img_save], _ = Deformddpm(img_org=x1, cond_imgs=y1_proc, T=[None, T_regist], proc_type=[], text=embd_y)
                loss_sim = loss_imgsim(img_rec, y1, label=msk_tgt * (y1 > thresh_imgsim))
                loss_mse = loss_imgmse(img_rec, y1, label=msk_tgt * (y1 >= 0.0))
                loss_ddf1 = loss_reg1(ddf_comp, img=y1)

                loss_regist = 0
                loss_regist += LOSS_WEIGHTS_REGIST[0] * loss_sim
                loss_regist += LOSS_WEIGHTS_REGIST[1] * loss_mse
                loss_regist += LOSS_WEIGHTS_REGIST[2] * loss_ddf1

                # The registration inputs are x1 / y1.
                if torch.isnan(x1).any() or torch.isnan(y1).any():
                    print(f"*** Encountered NaN in registration input x1/y1 at epoch {epoch}, step {step}.")
                if loss_ddf1 > 0.002:
                    print(f"*** High registration DDF loss at epoch {epoch}, step {step}: {loss_ddf1.item()}.")

                loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
                if not args.eval_only:
                    optimizer.zero_grad()
                    loss_regist.backward()
                    torch.nn.utils.clip_grad_norm_(Deformddpm.parameters(), max_norm=0.02)
                    optimizer.step()

                epoch_losses['epoch_loss_regist'] += loss_regist.item()
                epoch_losses['epoch_loss_imgsim'] += loss_sim.item()
                epoch_losses['epoch_loss_imgmse'] += loss_mse.item()
                epoch_losses['epoch_loss_ddfreg'] += loss_ddf1.item()
            else:
                loss_sim = torch.tensor(0.0)
                loss_mse = torch.tensor(0.0)
                loss_ddf1 = torch.tensor(0.0)
                loss_regist = torch.tensor(0.0)
                if step % REGISTRATION_STEP_RATIO == 0:
                    total_reg -= 1

            if step % 10 == 0:
                print('step:', step, ':', loss_tot.item(), '=', loss_gen_a.item(), '+', loss_gen_d.item(), '+', loss_ddf.item())
                print(f'- loss_regist: {loss_regist} = {loss_sim} (imgsim) + {loss_mse} (imgmse) + {loss_ddf1} (ddf)')
                print(f'- loss_contra: {loss_contra}')

            # ------------------------------------------- checkpoints / restart
            saved_mid = False
            if (max_steps_restart > 0 and step > 0 and step % MID_EPOCH_SAVE_STEPS == 0
                    and gpu_id == 0 and not args.no_save):
                stats = {**epoch_losses, 'total_reg': total_reg, 'total_contra': total_contra,
                         'loss_nan_step': loss_nan_step, **_rng_snapshot()}
                _save_mid_epoch_checkpoint(Deformddpm, optimizer, model_save_path, suffix_pth,
                                           epoch, step, stats, 'mid-epoch')
                saved_mid = True

            steps_since_start += 1
            if max_steps_restart > 0 and steps_since_start >= max_steps_restart:
                if not saved_mid and gpu_id == 0 and not args.no_save:
                    stats = {**epoch_losses, 'total_reg': total_reg, 'total_contra': total_contra,
                             'loss_nan_step': loss_nan_step, **_rng_snapshot()}
                    _save_mid_epoch_checkpoint(Deformddpm, optimizer, model_save_path, suffix_pth,
                                               epoch, step, stats, 'restart')
                if gpu_id == 0:
                    print(f" [restart] Proactive restart after {steps_since_start} steps "
                          f"(limit {max_steps_restart}). Exiting with code {EXIT_CODE_RESTART}.", flush=True)
                _empty_cache(DEVICE_TYPE)
                gc.collect()
                if use_distributed and dist.is_initialized():
                    # Wait until rank 0 has finished writing the checkpoint.
                    dist.barrier()
                    dist.destroy_process_group()
                sys.exit(EXIT_CODE_RESTART)

        # ------------------------------------------------------- epoch end
        if gpu_id == 0:
            el = epoch_losses
            reg_den = max(total_reg, 1)
            print('==================')
            print(epoch, ':', el['epoch_loss_tot'], '=', el['epoch_loss_gen_a'], '+', el['epoch_loss_gen_d'], '+', el['epoch_loss_reg'], ' (ang+dist+regul)')
            print(f" loss_contrastive: {el['epoch_loss_contrastive']}")
            print(f" loss_regist: {el['epoch_loss_regist']/reg_den} = {el['epoch_loss_imgsim']/reg_den} (imgsim) + {el['epoch_loss_imgmse']/reg_den} (imgmse) + {el['epoch_loss_ddfreg']/reg_den} (ddf)")
            print('==================')

        # Distributed XPU relaunches after every epoch (below). That epoch must
        # be checkpointed; otherwise the relaunch resumes from the last saved
        # epoch and repeats the next one indefinitely when epoch_per_save > 1.
        xpu_epoch_restart = DEVICE_TYPE == 'xpu' and use_distributed
        if (0 == epoch % epoch_per_save or xpu_epoch_restart) and not args.no_save:
            save_dir = model_save_path + str(epoch).rjust(6, '0') + suffix_pth
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            if not use_distributed or gpu_id == 0:
                _save_epoch_checkpoint(Deformddpm, optimizer, save_dir, epoch)

        # The epoch is complete, so mid-epoch snapshots are obsolete.
        if gpu_id == 0 and not args.no_save:
            tmp_pths = glob.glob(os.path.join(model_dir, "tmp", "*.pth"))
            if tmp_pths:
                for f in tmp_pths:
                    os.remove(f)
                print(f" [cleanup] Cleared {len(tmp_pths)} tmp/ mid-epoch checkpoints", flush=True)

        initial_step = 0

        if xpu_epoch_restart:
            if gpu_id == 0:
                print(f" [xpu-restart] Epoch {epoch} done. Restarting to reset CCL state.", flush=True)
            _empty_cache(DEVICE_TYPE)
            gc.collect()
            if dist.is_initialized():
                dist.barrier()
                dist.destroy_process_group()
            sys.exit(EXIT_CODE_RESTART)

    _empty_cache(DEVICE_TYPE)
    gc.collect()
    if use_distributed and dist.is_initialized():
        dist.destroy_process_group()
|
|
|
def ddp_load_dict(gpu_id, Deformddpm, optimizer, model_file, use_distributed=True, load_strict=False):
    """Load model (and, when compatible, optimizer) state from a checkpoint.

    Args:
        gpu_id: global rank; only rank 0 prints.
        Deformddpm: the model; DDP-wrapped when ``use_distributed`` is True.
        optimizer: optimizer over ``Deformddpm.parameters()``.
        model_file: checkpoint path whose basename starts with a 6-digit epoch.
        use_distributed: whether ``Deformddpm`` is DDP-wrapped.
        load_strict: forwarded to ``load_state_dict(strict=...)``.

    Returns:
        ``(initial_epoch, Deformddpm, optimizer)``. ``initial_epoch`` is the
        checkpoint's own epoch for mid-epoch (``_step``) files and the
        following epoch for end-of-epoch files.

    Optimizer state is skipped with ``--reset-optimizer``. When the saved state
    does not fit the current parameters (different group layout or tensor
    shapes), hyper-parameters are copied and only the per-parameter entries
    with matching shapes are restored.
    """
    gc.collect()
    _empty_cache(DEVICE_TYPE)
    if gpu_id == 0:
        utils.print_memory_usage("Before Loading Model")

    # weights_only=False: mid-epoch checkpoints store numpy/python RNG states
    # in 'epoch_stats'. The restricted unpickler (the torch.load default since
    # PyTorch 2.6) rejects those. This runs arbitrary pickle code, so only load
    # checkpoints written by this training script.
    checkpoint = torch.load(model_file, map_location='cpu', weights_only=False)
    model = Deformddpm.module if use_distributed else Deformddpm
    model.load_state_dict(checkpoint['model_state_dict'], strict=load_strict)

    if 'optimizer_state_dict' in checkpoint and not args.reset_optimizer:
        saved_opt = checkpoint['optimizer_state_dict']
        saved_state = saved_opt.get('state', {})
        saved_groups = saved_opt.get('param_groups', [])
        param_list = [p for group in optimizer.param_groups for p in group['params']]

        def _shapes_ok(state, param):
            # 0-dim entries (e.g. Adam's 'step') carry no shape information.
            return all(v.shape == param.shape for v in state.values()
                       if isinstance(v, torch.Tensor) and v.dim() > 0)

        # optimizer.load_state_dict requires identical group/parameter counts,
        # so check the layout as well as the tensor shapes.
        same_layout = (len(saved_groups) == len(optimizer.param_groups)
                       and all(len(sg['params']) == len(g['params'])
                               for sg, g in zip(saved_groups, optimizer.param_groups)))
        all_match = same_layout and all(
            _shapes_ok(s, param_list[int(idx)])
            for idx, s in saved_state.items() if int(idx) < len(param_list)
        )

        if all_match:
            optimizer.load_state_dict(saved_opt)
        else:
            # Copy hyper-parameters (lr, betas, ...) group by group.
            for saved_g, group in zip(saved_groups, optimizer.param_groups):
                for k, v in saved_g.items():
                    if k != 'params':
                        group[k] = v

            restored = 0
            skipped = 0
            for idx, s in saved_state.items():
                idx_int = int(idx)
                if idx_int >= len(param_list):
                    continue
                p = param_list[idx_int]
                if _shapes_ok(s, p):
                    # Move per-parameter buffers to the parameter's device and
                    # dtype; 0-dim tensors and non-tensors are kept unchanged.
                    optimizer.state[p] = {
                        k: (v.to(dtype=p.dtype, device=p.device)
                            if isinstance(v, torch.Tensor) and v.dim() > 0 else v)
                        for k, v in s.items()
                    }
                    restored += 1
                else:
                    skipped += 1
            if gpu_id == 0:
                print(f" [checkpoint] Selective optimizer load: {restored} params restored, "
                      f"{skipped} skipped (shape mismatch, fresh Adam for those)", flush=True)
    elif args.reset_optimizer and gpu_id == 0:
        print(" [checkpoint] --reset-optimizer: skipping optimizer state, starting fresh Adam", flush=True)
    del checkpoint
    if gpu_id == 0:
        utils.print_memory_usage("After Loading Checkpoint on GPU")

    if use_distributed:
        # Make every replica start from rank 0's weights.
        dist.barrier()
        for param in Deformddpm.parameters():
            dist.broadcast(param.data, src=0)

    basename = os.path.basename(model_file)
    epoch_from_file = int(basename[:6])
    # A '_step' file is a mid-epoch snapshot, so resume inside that epoch.
    # A plain epoch file marks a finished epoch, so continue with the next one.
    initial_epoch = epoch_from_file if '_step' in basename else epoch_from_file + 1

    return initial_epoch, Deformddpm, optimizer
|
|
|
|
|
|
|
if __name__ == "__main__":
    launched_by_torchrun = "LOCAL_RANK" in os.environ
    if launched_by_torchrun:
        # torchrun has already started one process per device; join it.
        use_distributed = True
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        print(f"torchrun launch: LOCAL_RANK={local_rank}, RANK={os.environ.get('RANK')}, WORLD_SIZE={world_size}")
        try:
            main_train(local_rank, world_size)
        except Exception:
            # Tag the traceback with the failing rank, then propagate.
            import traceback
            banner = '=' * 60
            print(f"\n{banner}\nRANK {os.environ.get('RANK')} FAILED:\n{banner}", flush=True)
            traceback.print_exc()
            raise
    elif not use_distributed:
        main_train(0, 1)
    else:
        # Single launcher process: spawn one worker per visible device.
        world_size = _device_count(DEVICE_TYPE)
        print(f"Distributed {DEVICE_TYPE.upper()} device number = {world_size}")
        mp.spawn(main_train, args=(world_size,), nprocs=world_size)