"""
test_3modes_equivalence.py — Verify that the OMorpher-based
Scripts/OM_train_3modes.py produces identical network outputs, losses,
gradients, and weight updates as the original DeformDDPM-based
OM_train_3modes.py.

Each mode (diffusion, contrastive, registration) is run for one training
step on both pipelines, fed identical pre-computed inputs, started from
shared initial weights, and seeded deterministically. Every intermediate
tensor is compared so that any divergence is caught where it first appears.

The test_* functions report their outcome through a bool return value, so
run this file as a script; the process exit code reflects the result.

Usage:
    python tests/test_3modes_equivalence.py
"""
import os
import sys
import copy
import random

import numpy as np
import torch
import torch.nn.functional as F

# Repository root = parent of the directory holding this file; it has to be
# importable before the project-local imports that follow.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)

from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
import Diffusion.losses as losses
from Diffusion.losses import Grad, LNCC, LMSE
from OMorpher import OMorpher
import utils

# ========================== Test Config ==========================
IMG_SIZE = 64  # must be >= 64 for multi-scale DDF generation
BATCHSIZE = 2
TIMESTEPS = 10
NDIMS = 3
V_SCALE = 5e-5
NOISE_SCALE = 0.1
NET_NAME = "recmulmodmutattnnet"  # supports contrastive (has img_embd)
LR = 1e-5
DEVICE = "cpu"

# Loss constants (from 3modes)
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16]
LOSS_WEIGHTS_REGIST = [1.0, 0.3, 64]
LOSS_WEIGHT_CONTRASTIVE = 1.0
MSK_EPS = 0.01

# Absolute / relative tolerances for tensor comparisons.
ATOL = 1e-5
RTOL = 1e-4


def make_config():
    """Build the configuration dict shared by both pipelines.

    All values come from the module-level test constants above. A new dict
    is created on each call, so callers are free to mutate their copy.
    """
    return dict(
        data_name="test",
        net_name=NET_NAME,
        ndims=NDIMS,
        img_size=IMG_SIZE,
        batchsize=BATCHSIZE,
        timesteps=TIMESTEPS,
        v_scale=V_SCALE,
        noise_scale=NOISE_SCALE,
        num_input_chn=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        resample_mode="bilinear",
        lr=LR,
        epoch=1,
        epoch_per_save=1,
        condition_type="slice",
        device=DEVICE,
    )


def seed_all(seed=42):
    """Seed the Python, NumPy and torch RNGs, plus every CUDA device if present."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ========================== Builders ==========================
def build_original(config):
    """Build the original pipeline: DeformDDPM + STN + losses + optimizer.

    Returns:
        (ddpm, ddf_stn, optimizer, loss_reg, loss_reg1, loss_dist, loss_ang,
        loss_imgsim, loss_imgmse). The optimizer is Adam over
        ``ddpm.parameters()``.
    """
    Net = get_net(config["net_name"])
    network = Net(
        n_steps=config["timesteps"],
        ndims=config["ndims"],
        num_input_chn=config["num_input_chn"],
        res=config["img_size"],
    )
    ddpm = DeformDDPM(
        network=network,
        n_steps=config["timesteps"],
        image_chw=[1] + [config["img_size"]] * config["ndims"],
        device=config["device"],
        batch_size=config["batchsize"],
        img_pad_mode=config["img_pad_mode"],
        v_scale=config["v_scale"],
    )
    ddf_stn = STN(
        img_sz=config["img_size"],
        ndims=config["ndims"],
        padding_mode=config["padding_mode"],
        device=config["device"],
    )
    ddpm.to(config["device"])
    ddf_stn.to(config["device"])
    loss_reg = Grad(penalty=['l1', 'negdetj', 'range'], ndims=config["ndims"],
                    outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = Grad(penalty=['l1', 'negdetj', 'range'], ndims=config["ndims"],
                     outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = losses.MRSE(img_sz=config["img_size"])
    loss_ang = losses.NCC(img_sz=config["img_size"])
    loss_imgsim = LNCC()
    loss_imgmse = LMSE()
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=config["lr"])
    return (ddpm, ddf_stn, optimizer, loss_reg, loss_reg1, loss_dist,
            loss_ang, loss_imgsim, loss_imgmse)


def build_omorpher(config):
    """Build the Scripts pipeline: OMorpher + losses + optimizer.

    Returns:
        (om, optimizer, loss_reg, loss_reg1, loss_imgsim, loss_imgmse). The
        optimizer is Adam over ``om.network.parameters()``.
    """
    om = OMorpher(config=config, device=config["device"])
    loss_reg = Grad(penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
                    outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = Grad(penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
                     outrange_thresh=0.6, outrange_weight=1e3)
    loss_imgsim = LNCC()
    loss_imgmse = LMSE()
    optimizer = torch.optim.Adam(om.network.parameters(), lr=config["lr"])
    return om, optimizer, loss_reg, loss_reg1, loss_imgsim, loss_imgmse


def sync_weights(ddpm, om):
    """Copy network weights from DeformDDPM.network → OMorpher.network."""
    om.network.load_state_dict(ddpm.network.state_dict())


def params_flat(module):
    """Return a detached 1-D copy of all of ``module``'s parameters, concatenated."""
    return torch.cat([p.detach().clone().flatten() for p in module.parameters()])


def grads_flat(module):
    """Return a detached 1-D copy of all gradients; a parameter with no grad contributes zeros."""
    gs = []
    for p in module.parameters():
        if p.grad is not None:
            gs.append(p.grad.detach().clone().flatten())
        else:
            gs.append(torch.zeros_like(p.flatten()))
    return torch.cat(gs)


def assert_close(name, a, b, atol=ATOL, rtol=RTOL):
    """Compare two tensors (or Python numbers), print PASS/FAIL, return a bool.

    A shape mismatch is reported as a failure instead of letting
    ``torch.allclose`` broadcast (e.g. ``(1,)`` vs ``(N,)``) or raise.
    Differing dtypes are promoted to a common floating dtype, so integer
    inputs also produce a meaningful mean diff.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    if a.shape != b.shape:
        print(f"  FAIL {name}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
        return False
    dtype = torch.promote_types(a.dtype, b.dtype)
    if not dtype.is_floating_point:
        dtype = torch.float64
    a, b = a.to(dtype), b.to(dtype)
    if torch.allclose(a, b, atol=atol, rtol=rtol):
        print(f"  PASS {name}")
        return True
    diff = (a - b).abs()
    print(f"  FAIL {name}: max_diff={diff.max().item():.6e}, mean_diff={diff.mean().item():.6e}")
    return False


# ========================== Shared Data Generators ==========================
def make_shared_data():
    """Create deterministic dummy tensors for all three modes.

    ``t`` and ``t_contra`` are fixed length-2 timestep tensors, so they
    assume BATCHSIZE == 2.
    """
    seed_all(123)
    S = IMG_SIZE
    x0 = torch.rand(BATCHSIZE, 1, S, S, S, dtype=torch.float32)
    embd = torch.randn(BATCHSIZE, 1024, dtype=torch.float32)
    blind_mask = torch.ones(1, 1, S, S, S, dtype=torch.float32)  # deterministic
    t = torch.tensor([3, 7])  # fixed timesteps

    # Paired data (half batchsize)
    B2 = max(1, BATCHSIZE // 2)
    x1 = torch.rand(B2, 1, S, S, S, dtype=torch.float32)
    y1 = torch.rand(B2, 1, S, S, S, dtype=torch.float32)
    embd_y = torch.randn(B2, 1024, dtype=torch.float32)
    t_contra = torch.tensor([2, 5])
    return x0, embd, blind_mask, t, x1, y1, embd_y, t_contra


# ========================== Step Helpers ==========================
# The RNG is reseeded with this value immediately before each pipeline's
# forward pass. Stochastic layers active in train() mode (if the network has
# any) then see the same RNG state in both pipelines, instead of the second
# pipeline seeing a state advanced by the first.
_FWD_SEED = 1234

# Intensity threshold defining the label mask passed to the LNCC image loss.
THRESH_IMGSIM = 0.01

# Number of final reverse-diffusion iterations that keep autograd enabled.
_TRAINABLE_ITERS = 2


def _banner(title):
    """Print a section header for one test."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _regist_schedule():
    """Return the fixed registration timestep schedule.

    One inner list per reverse iteration (t descending), with one entry per
    paired sample (half the batch size, at least 1).
    """
    n_pairs = max(1, BATCHSIZE // 2)
    return [[t_val for _ in range(n_pairs)]
            for t_val in sorted([9, 7, 5, 3, 2, 1], reverse=True)]


def _optimizer_step(optimizer, loss, grad_module=None, clip_params=None, max_norm=None):
    """Zero grads, backprop ``loss``, optionally clip, snapshot grads, then step.

    Gradients are snapshotted after clipping and before ``optimizer.step()``,
    so they match what the optimizer applies. Returns the flat gradient
    snapshot of ``grad_module``, or None when ``grad_module`` is None.
    """
    optimizer.zero_grad()
    loss.backward()
    if max_norm is not None:
        torch.nn.utils.clip_grad_norm_(clip_params, max_norm=max_norm)
    grads = grads_flat(grad_module) if grad_module is not None else None
    optimizer.step()
    return grads


def _diffusion_forward(network, stn, loss_reg, loss_dist, loss_ang, *, noisy_img,
                       dvf_gt, cond_img, cond_ratio, x0, t, embd, blind_mask):
    """Run the diffusion-mode forward pass and losses for one pipeline.

    ``stn``, ``loss_dist`` and ``loss_ang`` are the pipeline-specific
    callables: STN/MRSE/NCC for the original pipeline, or
    ``om.stn_full`` / ``om._loss_dist`` / ``om._loss_ang`` for OMorpher.

    Returns:
        dict with ``pre_dvf``, ``trm_pred``, ``loss_gen_a``, ``loss_gen_d``,
        ``loss_ddf`` and the weighted ``loss_tot``.
    """
    seed_all(_FWD_SEED)
    pre_dvf = network(x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd)
    loss_ddf = loss_reg(pre_dvf, img=x0)
    trm_pred = stn(pre_dvf, dvf_gt)
    loss_gen_d = loss_dist(pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_gen_a = loss_ang(pred=trm_pred, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_tot = (LOSS_WEIGHTS_DIFF[0] * loss_gen_a
                + LOSS_WEIGHTS_DIFF[1] * loss_gen_d
                + LOSS_WEIGHTS_DIFF[2] * loss_ddf)
    # Scale by sqrt(1 + eps - cond_ratio): a larger cond_ratio gives a smaller weight.
    loss_tot = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio) * loss_tot
    return {
        "pre_dvf": pre_dvf,
        "trm_pred": trm_pred,
        "loss_gen_a": loss_gen_a,
        "loss_gen_d": loss_gen_d,
        "loss_ddf": loss_ddf,
        "loss_tot": loss_tot,
    }


def _contrastive_forward(network, x_in, y_in, t, text, embd):
    """Run a contrastive forward pass.

    Returns:
        (img_embd, loss), where loss = weight * (1 - mean cosine similarity
        between ``network.img_embd`` and ``embd``); or None when the network
        exposes no (or a None) ``img_embd`` after the forward call.
    """
    seed_all(_FWD_SEED)
    network(x=x_in, y=y_in, t=t, text=text)
    img_embd = getattr(network, 'img_embd', None)
    if img_embd is None:
        return None
    loss = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd, dim=-1).mean())
    return img_embd, loss


def _original_reverse_diffuse(ddpm, x1, y1_proc, embd_y, schedule):
    """Run DeformDDPM's registration recovery via its forward; return (ddf, img_rec)."""
    seed_all(_FWD_SEED)
    [ddf, _], [img_rec, _, _], _ = ddpm(
        img_org=x1,
        cond_imgs=y1_proc,
        T=[None, schedule],
        proc_type=[],
        text=embd_y,
    )
    return ddf, img_rec


def _omorpher_reverse_diffuse(om, x1, y1_proc, embd_y, schedule, k=_TRAINABLE_ITERS):
    """Manual OMorpher registration loop (reverse_diffuse_train logic).

    Composes per-step DVFs into ``ddf_comp`` and re-warps the clean ``x1``
    each iteration. Only the last ``k`` schedule entries run with autograd
    enabled; earlier ones run under ``torch.no_grad()``.

    Returns:
        (ddf_comp, img_rec).
    """
    seed_all(_FWD_SEED)
    B = x1.shape[0]
    S = om.img_size
    ddf_comp = torch.zeros([B, om.ndims] + [S] * om.ndims, dtype=torch.float32, device=om.device)
    img_rec = x1.clone().detach()
    trainable_iters = schedule[-1:-k - 1:-1]  # last k entries, reversed
    for i in schedule:
        t = torch.tensor(np.array([i])).to(om.device)
        if i in trainable_iters:
            pre_dvf = om.network(x=img_rec, y=y1_proc, t=t, rec_num=2, text=embd_y)
        else:
            with torch.no_grad():
                pre_dvf = om.network(x=img_rec, y=y1_proc, t=t, rec_num=2, text=embd_y)
        ddf_comp = om.stn_full(ddf_comp, pre_dvf) + pre_dvf
        img_rec = om.img_stn(x1.clone().detach(), ddf_comp)
    return ddf_comp, img_rec


def _registration_loss(img_rec, ddf, y1, cond_ratio, loss_imgsim, loss_imgmse, loss_reg1):
    """Compute the weighted registration loss and its components.

    Returns:
        dict with ``loss_sim``, ``loss_mse``, ``loss_ddf`` and ``loss_regist``.
    """
    loss_sim = loss_imgsim(img_rec, y1, label=(y1 > THRESH_IMGSIM))
    loss_mse = loss_imgmse(img_rec, y1)
    loss_ddf = loss_reg1(ddf, img=y1)
    loss_regist = (LOSS_WEIGHTS_REGIST[0] * loss_sim
                   + LOSS_WEIGHTS_REGIST[1] * loss_mse
                   + LOSS_WEIGHTS_REGIST[2] * loss_ddf)
    loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
    return {"loss_sim": loss_sim, "loss_mse": loss_mse,
            "loss_ddf": loss_ddf, "loss_regist": loss_regist}


# NOTE(review): the test_* functions below report failures through their
# bool return value rather than by asserting. Under pytest a False return
# does not fail the test (pytest only warns about non-None returns), so run
# this file as a script; the exit code reflects the result.

# ========================== Test: Mode 1 (Diffusion) ==========================
def test_mode1_diffusion():
    """Both pipelines: same noisy_img+dvf → network → loss → grad → weight update."""
    _banner("TEST: Mode 1 — Diffusion Training Step")
    config = make_config()
    x0, embd, blind_mask, t, _, _, _, _ = make_shared_data()

    # Build with identical weights
    seed_all(42)
    ddpm, ddf_stn, opt_o, loss_reg_o, _, loss_dist_o, loss_ang_o, _, _ = build_original(config)
    seed_all(42)
    om, opt_m, loss_reg_m, _, _, _ = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Pre-compute shared tensors using OMorpher (source of truth)
    seed_all(200)
    noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
    # 'none' proc_type gives a deterministic (identity) condition image
    cond_img, _, cond_ratio = om._proc_cond_img(x0, proc_type='none')
    inputs = dict(noisy_img=noisy_img, dvf_gt=dvf_gt, cond_img=cond_img,
                  cond_ratio=cond_ratio, x0=x0, t=t, embd=embd, blind_mask=blind_mask)

    # --- Original pipeline ---
    ddpm.network.train()
    out_o = _diffusion_forward(ddpm.network, ddf_stn, loss_reg_o, loss_dist_o, loss_ang_o, **inputs)
    grad_o = _optimizer_step(opt_o, out_o["loss_tot"], ddpm.network)

    # --- OMorpher pipeline ---
    om.network.train()
    out_m = _diffusion_forward(om.network, om.stn_full, loss_reg_m, om._loss_dist, om._loss_ang, **inputs)
    grad_m = _optimizer_step(opt_m, out_m["loss_tot"], om.network)

    # --- Compare ---
    for key in ("pre_dvf", "trm_pred"):
        ok &= assert_close(key, out_o[key].detach(), out_m[key].detach())
    for key in ("loss_gen_a", "loss_gen_d", "loss_ddf", "loss_tot"):
        ok &= assert_close(key, out_o[key].item(), out_m[key].item())
    ok &= assert_close("gradients", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nMode 1 Diffusion: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok


# ========================== Test: Mode 2 (Contrastive) ==========================
def test_mode2_contrastive():
    """Both pipelines: same masked x0 + cond_img → network → img_embd → cosine loss → grad."""
    _banner("TEST: Mode 2 — Contrastive Training Step")
    config = make_config()
    x0, embd, blind_mask, _, _, _, _, t_contra = make_shared_data()

    seed_all(42)
    ddpm, _, opt_o, *_ = build_original(config)
    seed_all(42)
    om, opt_m, *_ = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Shared cond_img
    cond_img, _, _ = om._proc_cond_img(x0, proc_type='none')
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img.detach()
    text_in = embd.detach()

    # --- Original ---
    ddpm.network.train()
    res_o = _contrastive_forward(ddpm.network, x_in, y_in, t_contra, text_in, embd)
    if res_o is None:
        print("  SKIP: network has no img_embd attribute")
        return True
    img_embd_o, loss_c_o = res_o
    grad_o = _optimizer_step(opt_o, loss_c_o, ddpm.network,
                             clip_params=ddpm.parameters(), max_norm=0.1)

    # --- OMorpher ---
    om.network.train()
    res_m = _contrastive_forward(om.network, x_in, y_in, t_contra, text_in, embd)
    if res_m is None:
        print("  FAIL img_embd: OMorpher network has no img_embd attribute")
        return False
    img_embd_m, loss_c_m = res_m
    grad_m = _optimizer_step(opt_m, loss_c_m, om.network,
                             clip_params=om.network.parameters(), max_norm=0.1)

    # --- Compare ---
    ok &= assert_close("img_embd", img_embd_o.detach(), img_embd_m.detach())
    ok &= assert_close("loss_contrastive", loss_c_o.item(), loss_c_m.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nMode 2 Contrastive: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok


# ========================== Test: Mode 3 (Registration) ==========================
def test_mode3_registration():
    """Both pipelines: reverse diffusion loop → DDF → rec image → reg losses → grad."""
    _banner("TEST: Mode 3 — Registration Training Step")
    config = make_config()
    _, _, _, _, x1, y1, embd_y, _ = make_shared_data()

    seed_all(42)
    ddpm, _, opt_o, _, loss_reg1_o, _, _, loss_imgsim_o, loss_imgmse_o = build_original(config)
    seed_all(42)
    om, opt_m, _, loss_reg1_m, loss_imgsim_m, loss_imgmse_m = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Shared proc_cond_img and fixed timestep schedule
    y1_proc, _, cond_ratio = om._proc_cond_img(y1, proc_type='none')
    schedule = _regist_schedule()

    # --- Original: DeformDDPM.diff_recover via forward ---
    ddpm.train()
    ddf_o, img_rec_o = _original_reverse_diffuse(ddpm, x1, y1_proc, embd_y, schedule)
    out_o = _registration_loss(img_rec_o, ddf_o, y1, cond_ratio,
                               loss_imgsim_o, loss_imgmse_o, loss_reg1_o)
    grad_o = _optimizer_step(opt_o, out_o["loss_regist"], ddpm.network,
                             clip_params=ddpm.parameters(), max_norm=0.4)

    # --- OMorpher: reverse_diffuse_train logic ---
    om.network.train()
    ddf_m, img_rec_m = _omorpher_reverse_diffuse(om, x1, y1_proc, embd_y, schedule)
    out_m = _registration_loss(img_rec_m, ddf_m, y1, cond_ratio,
                               loss_imgsim_m, loss_imgmse_m, loss_reg1_m)
    grad_m = _optimizer_step(opt_m, out_m["loss_regist"], om.network,
                             clip_params=om.network.parameters(), max_norm=0.4)

    # --- Compare ---
    ok &= assert_close("ddf_comp", ddf_o.detach(), ddf_m.detach())
    ok &= assert_close("img_rec", img_rec_o.detach(), img_rec_m.detach())
    ok &= assert_close("loss_sim", out_o["loss_sim"].item(), out_m["loss_sim"].item())
    ok &= assert_close("loss_mse", out_o["loss_mse"].item(), out_m["loss_mse"].item())
    ok &= assert_close("loss_ddf_reg", out_o["loss_ddf"].item(), out_m["loss_ddf"].item())
    ok &= assert_close("loss_regist", out_o["loss_regist"].item(), out_m["loss_regist"].item())
    ok &= assert_close("gradients_clipped", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nMode 3 Registration: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok


# ========================== Test: Full Sequence ==========================
def test_full_sequence():
    """Run all 3 modes sequentially on both pipelines, compare weights after each step."""
    _banner("TEST: Full Step Sequence (Diffusion → Contrastive → Registration)")
    config = make_config()
    x0, embd, blind_mask, t, x1, y1, embd_y, t_contra = make_shared_data()

    seed_all(42)
    (ddpm, ddf_stn, opt_o, loss_reg_o, loss_reg1_o, loss_dist_o, loss_ang_o,
     loss_imgsim_o, loss_imgmse_o) = build_original(config)
    seed_all(42)
    om, opt_m, loss_reg_m, loss_reg1_m, loss_imgsim_m, loss_imgmse_m = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Shared pre-computed tensors
    seed_all(200)
    noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
    cond_img_diff, _, cond_ratio_diff = om._proc_cond_img(x0, proc_type='none')
    y1_proc, _, cond_ratio_reg = om._proc_cond_img(y1, proc_type='none')
    schedule = _regist_schedule()

    # ========== Step 1: Diffusion (both) ==========
    ddpm.network.train()
    om.network.train()
    diff_inputs = dict(noisy_img=noisy_img, dvf_gt=dvf_gt, cond_img=cond_img_diff,
                       cond_ratio=cond_ratio_diff, x0=x0, t=t, embd=embd, blind_mask=blind_mask)
    out_o = _diffusion_forward(ddpm.network, ddf_stn, loss_reg_o, loss_dist_o, loss_ang_o, **diff_inputs)
    _optimizer_step(opt_o, out_o["loss_tot"])
    out_m = _diffusion_forward(om.network, om.stn_full, loss_reg_m, om._loss_dist, om._loss_ang, **diff_inputs)
    _optimizer_step(opt_m, out_m["loss_tot"])
    ok &= assert_close("weights_after_diffusion", params_flat(ddpm.network), params_flat(om.network))

    # ========== Step 2: Contrastive (both) ==========
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img_diff.detach()
    text_in = embd.detach()
    res_o = _contrastive_forward(ddpm.network, x_in, y_in, t_contra, text_in, embd)
    if res_o is not None:
        _, lc_o = res_o
        _optimizer_step(opt_o, lc_o, clip_params=ddpm.parameters(), max_norm=0.1)
        res_m = _contrastive_forward(om.network, x_in, y_in, t_contra, text_in, embd)
        if res_m is None:
            print("  FAIL loss_contrastive_seq: OMorpher network has no img_embd attribute")
            ok = False
        else:
            _, lc_m = res_m
            _optimizer_step(opt_m, lc_m, clip_params=om.network.parameters(), max_norm=0.1)
            ok &= assert_close("loss_contrastive_seq", lc_o.item(), lc_m.item())
        ok &= assert_close("weights_after_contrastive", params_flat(ddpm.network), params_flat(om.network))

    # ========== Step 3: Registration (both) ==========
    ddf_o, rec_o = _original_reverse_diffuse(ddpm, x1, y1_proc, embd_y, schedule)
    reg_o = _registration_loss(rec_o, ddf_o, y1, cond_ratio_reg,
                               loss_imgsim_o, loss_imgmse_o, loss_reg1_o)
    _optimizer_step(opt_o, reg_o["loss_regist"], clip_params=ddpm.parameters(), max_norm=0.4)

    ddf_m, rec_m = _omorpher_reverse_diffuse(om, x1, y1_proc, embd_y, schedule)
    reg_m = _registration_loss(rec_m, ddf_m, y1, cond_ratio_reg,
                               loss_imgsim_m, loss_imgmse_m, loss_reg1_m)
    _optimizer_step(opt_m, reg_m["loss_regist"], clip_params=om.network.parameters(), max_norm=0.4)

    ok &= assert_close("loss_regist_seq", reg_o["loss_regist"].item(), reg_m["loss_regist"].item())
    ok &= assert_close("weights_after_registration", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nFull Sequence: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok


# ========================== Test: Checkpoint Compatibility ==========================
def test_checkpoint_compat():
    """Checkpoints saved by one version load correctly into the other."""
    import tempfile

    _banner("TEST: Checkpoint Cross-Compatibility")
    config = make_config()
    seed_all(42)
    ddpm, *_ = build_original(config)
    seed_all(42)
    om, *_ = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = True

    with tempfile.TemporaryDirectory() as tmpdir:
        # Save DeformDDPM checkpoint (original format: network.* plus any non-network keys)
        path_o = os.path.join(tmpdir, "orig.pth")
        torch.save({'model_state_dict': ddpm.state_dict(), 'epoch': 0}, path_o)

        # Save OMorpher checkpoint (network.* prefix only)
        path_m = os.path.join(tmpdir, "om.pth")
        sd_m = {f"network.{k}": v for k, v in om.network.state_dict().items()}
        torch.save({'model_state_dict': sd_m, 'epoch': 0}, path_m)

        # --- Load original → OMorpher ---
        om2, *_ = build_omorpher(config)
        ckpt = torch.load(path_o, map_location='cpu')
        cleaned = {}
        for key, val in ckpt['model_state_dict'].items():
            # Strip only a *leading* DataParallel "module." prefix; a substring
            # replace would also corrupt names such as "submodule.weight".
            if key.startswith("module."):
                key = key[len("module."):]
            if key.startswith("network."):
                key = key[len("network."):]
            cleaned[key] = val
        net_keys = set(om2.network.state_dict())
        missing = sorted(net_keys.difference(cleaned))
        if missing:
            # strict=False below would otherwise hide these (including buffers,
            # which params_flat never compares).
            print(f"  FAIL orig→OMorpher keys: {len(missing)} network key(s) missing, e.g. {missing[:3]}")
            ok = False
        om2.network.load_state_dict({k: v for k, v in cleaned.items() if k in net_keys}, strict=False)
        ok &= assert_close("orig→OMorpher", params_flat(om2.network), params_flat(ddpm.network))

        # --- Load OMorpher → DeformDDPM ---
        seed_all(42)
        ddpm2, *_ = build_original(config)
        incompatible = ddpm2.load_state_dict(
            torch.load(path_m, map_location='cpu')['model_state_dict'], strict=False)
        # Non-network keys being absent is expected; network keys are not.
        missing_net = [k for k in incompatible.missing_keys if k.startswith("network.")]
        unexpected = list(incompatible.unexpected_keys)
        if missing_net or unexpected:
            print(f"  FAIL OMorpher→orig keys: missing network keys={missing_net[:3]}, "
                  f"unexpected keys={unexpected[:3]}")
            ok = False
        ok &= assert_close("OMorpher→orig", params_flat(ddpm2.network), params_flat(om.network))

    print(f"\nCheckpoint Compat: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok


# ========================== Main ==========================
if __name__ == "__main__":
    print("=" * 60)
    print("3-Modes Equivalence Test Suite")
    print(f"IMG_SIZE={IMG_SIZE}, BATCHSIZE={BATCHSIZE}, TIMESTEPS={TIMESTEPS}, NET={NET_NAME}")
    print("=" * 60)

    results = {}
    results["Mode 1: Diffusion"] = test_mode1_diffusion()
    results["Mode 2: Contrastive"] = test_mode2_contrastive()
    results["Mode 3: Registration"] = test_mode3_registration()
    results["Full Sequence"] = test_full_sequence()
    results["Checkpoint Compat"] = test_checkpoint_compat()

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f"  [{status}] {name}")
    all_ok = all(results.values())
    print(f"\nOverall: {'ALL TESTS PASSED' if all_ok else 'SOME TESTS FAILED'}")
    sys.exit(0 if all_ok else 1)