| """
|
test_3modes_equivalence.py — Verify that the OMorpher-based Scripts/OM_train_3modes.py
produces the same network outputs, losses, gradients, and weight updates as the
original DeformDDPM-based OM_train_3modes.py.

Runs one training step of each mode (diffusion, contrastive, registration) with
identical pre-computed inputs, shared initial weights, and deterministic seeding,
then a full diffusion → contrastive → registration sequence and a checkpoint
cross-loading check. Key intermediate tensors are compared to catch divergences.
|
|
|
| Usage:
|
| python tests/test_3modes_equivalence.py
|
| """
|
|
|
| import os
|
| import sys
|
| import copy
|
| import random
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| sys.path.insert(0, ROOT_DIR)
|
|
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
| import Diffusion.losses as losses
|
| from Diffusion.losses import Grad, LNCC, LMSE
|
| from OMorpher import OMorpher
|
| import utils
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Shared experiment configuration used by both pipelines.
# ---------------------------------------------------------------------------
IMG_SIZE, BATCHSIZE, TIMESTEPS, NDIMS = 64, 2, 10, 3
V_SCALE, NOISE_SCALE = 5e-5, 0.1
NET_NAME = "recmulmodmutattnnet"
LR = 1e-5
DEVICE = "cpu"

# Loss weights, indexed [0], [1], [2] by the training steps below.
# Diffusion step: [angle loss, distance loss, DDF regulariser].
LOSS_WEIGHTS_DIFF = [2.0, 1.0, 16]
# Registration step: [image similarity, image MSE, DDF regulariser].
LOSS_WEIGHTS_REGIST = [1.0, 0.3, 64]
LOSS_WEIGHT_CONTRASTIVE = 1.0
# Small epsilon added inside the sqrt() loss-scaling factors.
MSK_EPS = 0.01

# Tolerances used by assert_close (torch.allclose semantics).
ATOL, RTOL = 1e-5, 1e-4
|
|
|
|
|
def make_config():
    """Return a fresh config dict shared by the DeformDDPM and OMorpher builders.

    Values come from the module-level constants; the remaining entries are
    fixed padding/resampling/training options used by both pipelines.
    """
    return dict(
        data_name="test",
        net_name=NET_NAME,
        ndims=NDIMS,
        img_size=IMG_SIZE,
        batchsize=BATCHSIZE,
        timesteps=TIMESTEPS,
        v_scale=V_SCALE,
        noise_scale=NOISE_SCALE,
        num_input_chn=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        resample_mode="bilinear",
        lr=LR,
        epoch=1,
        epoch_per_save=1,
        condition_type="slice",
        device=DEVICE,
    )
|
|
|
|
|
def seed_all(seed=42):
    """Seed Python's, NumPy's and PyTorch's global RNGs (and every CUDA device if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
|
def build_original(config):
    """Assemble the reference (DeformDDPM-based) training pipeline.

    Objects are created in a fixed order (network, DeformDDPM, STN, losses,
    optimiser). Callers seed the global RNG beforehand, presumably because
    network initialisation draws from it.

    Args:
        config: dict as produced by ``make_config``.

    Returns:
        tuple ``(ddpm, ddf_stn, optimizer, loss_reg, loss_reg1, loss_dist,
        loss_ang, loss_imgsim, loss_imgmse)``.
    """
    device = config["device"]
    ndims = config["ndims"]
    img_size = config["img_size"]
    n_steps = config["timesteps"]

    net_cls = get_net(config["net_name"])
    backbone = net_cls(
        n_steps=n_steps,
        ndims=ndims,
        num_input_chn=config["num_input_chn"],
        res=img_size,
    )
    ddpm = DeformDDPM(
        network=backbone,
        n_steps=n_steps,
        image_chw=[1] + [img_size] * ndims,
        device=device,
        batch_size=config["batchsize"],
        img_pad_mode=config["img_pad_mode"],
        v_scale=config["v_scale"],
    )
    ddf_stn = STN(img_sz=img_size, ndims=ndims,
                  padding_mode=config["padding_mode"], device=device)
    for module in (ddpm, ddf_stn):
        module.to(device)

    def _grad_loss(thresh):
        # Fresh penalty list per instance, exactly as the two literals did.
        return Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,
                    outrange_thresh=thresh, outrange_weight=1e3)

    loss_reg, loss_reg1 = _grad_loss(0.2), _grad_loss(0.6)
    loss_dist = losses.MRSE(img_sz=img_size)
    loss_ang = losses.NCC(img_sz=img_size)
    loss_imgsim, loss_imgmse = LNCC(), LMSE()
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=config["lr"])

    return (ddpm, ddf_stn, optimizer, loss_reg, loss_reg1,
            loss_dist, loss_ang, loss_imgsim, loss_imgmse)
|
|
|
|
|
def build_omorpher(config):
    """Assemble the refactored (OMorpher-based) training pipeline.

    The network lives on the OMorpher instance (``om.network``). Only the
    regulariser and image-similarity losses that the tests drive explicitly,
    plus an Adam optimiser over ``om.network``'s parameters, are created here.

    Args:
        config: dict as produced by ``make_config``.

    Returns:
        tuple ``(om, optimizer, loss_reg, loss_reg1, loss_imgsim, loss_imgmse)``.
    """
    om = OMorpher(config=config, device=config["device"])

    def _grad_loss(thresh):
        # Fresh penalty list per instance, exactly as the two literals did.
        return Grad(penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
                    outrange_thresh=thresh, outrange_weight=1e3)

    loss_reg, loss_reg1 = _grad_loss(0.2), _grad_loss(0.6)
    loss_imgsim, loss_imgmse = LNCC(), LMSE()
    optimizer = torch.optim.Adam(om.network.parameters(), lr=config["lr"])

    return om, optimizer, loss_reg, loss_reg1, loss_imgsim, loss_imgmse
|
|
|
|
|
def sync_weights(ddpm, om):
    """Overwrite ``om.network``'s state (params + buffers) with ``ddpm.network``'s."""
    source_state = ddpm.network.state_dict()
    om.network.load_state_dict(source_state)
|
|
|
|
|
def params_flat(module):
    """Return a detached 1-D copy of all of ``module``'s parameters, concatenated in order."""
    pieces = [param.detach().clone().reshape(-1) for param in module.parameters()]
    return torch.cat(pieces)
|
|
|
|
|
def grads_flat(module):
    """Return all of ``module``'s gradients as one detached 1-D tensor.

    Parameters that have no gradient yet (``p.grad is None``) contribute
    zeros of the same size, so the result always aligns with ``params_flat``.
    """
    def _grad_or_zeros(param):
        if param.grad is None:
            return torch.zeros_like(param.flatten())
        return param.grad.detach().clone().flatten()

    return torch.cat([_grad_or_zeros(param) for param in module.parameters()])
|
|
|
|
|
def assert_close(name, a, b, atol=ATOL, rtol=RTOL):
    """Print PASS/FAIL for an approximate-equality check and return the outcome.

    Args:
        name: label printed next to the result.
        a, b: tensors or Python/NumPy scalars to compare. Non-tensor inputs
            are converted with ``torch.as_tensor`` as float64, so mixed
            int/float scalars compare instead of raising.
        atol, rtol: tolerances forwarded to ``torch.allclose``.

    Returns:
        True if the shapes match and every element is within tolerance,
        otherwise False.
    """
    if not isinstance(a, torch.Tensor):
        a = torch.as_tensor(a, dtype=torch.float64)
    if not isinstance(b, torch.Tensor):
        b = torch.as_tensor(b, dtype=torch.float64)

    # torch.allclose broadcasts, so e.g. shape (1,) vs (N,) could "pass";
    # for an equivalence test a shape mismatch is itself a failure.
    if a.shape != b.shape:
        print(f" FAIL {name}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
        return False

    # torch.allclose requires identical dtypes; promote both to a common one.
    if a.dtype != b.dtype:
        common = torch.promote_types(a.dtype, b.dtype)
        a, b = a.to(common), b.to(common)

    if torch.allclose(a, b, atol=atol, rtol=rtol):
        print(f" PASS {name}")
        return True

    # Report in float64 so integer inputs do not break .mean().
    diff = (a.double() - b.double()).abs()
    print(f" FAIL {name}: max_diff={diff.max().item():.6e}, mean_diff={diff.mean().item():.6e}")
    return False
|
|
|
|
|
|
|
|
|
def make_shared_data():
    """Build the fixed (seed 123) input tensors shared by all tests.

    Returns:
        tuple ``(x0, embd, blind_mask, t, x1, y1, embd_y, t_contra)``:
        a full batch of random volumes with 1024-d embeddings, an all-ones
        mask, diffusion timesteps, a half-size (at least 1) registration
        pair with its embeddings, and contrastive timesteps.

    NOTE(review): ``t`` and ``t_contra`` are hard-coded to two entries, so
    they match the image batch only while BATCHSIZE == 2.
    """
    seed_all(123)
    volume = (IMG_SIZE, IMG_SIZE, IMG_SIZE)

    # RNG draws happen in a fixed order; keep it stable for reproducibility.
    x0 = torch.rand(BATCHSIZE, 1, *volume, dtype=torch.float32)
    embd = torch.randn(BATCHSIZE, 1024, dtype=torch.float32)
    blind_mask = torch.ones(1, 1, *volume, dtype=torch.float32)
    t = torch.tensor([3, 7])

    half_batch = max(1, BATCHSIZE // 2)
    x1 = torch.rand(half_batch, 1, *volume, dtype=torch.float32)
    y1 = torch.rand(half_batch, 1, *volume, dtype=torch.float32)
    embd_y = torch.randn(half_batch, 1024, dtype=torch.float32)

    t_contra = torch.tensor([2, 5])
    return x0, embd, blind_mask, t, x1, y1, embd_y, t_contra
|
|
|
|
|
|
|
|
|
def test_mode1_diffusion():
    """Compare one diffusion-mode training step between the two pipelines.

    Both pipelines receive the same noisy image and ground-truth DVF, sampled
    once through the OMorpher helpers, and start from synchronised weights.
    The network output, warped prediction, individual losses, gradients and
    post-step weights are then compared.

    Returns:
        bool: True when every comparison passes.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 1 — Diffusion Training Step")
    print("=" * 60)

    config = make_config()
    x0, embd, blind_mask, t, _, _, _, _ = make_shared_data()

    seed_all(42)
    ddpm, ddf_stn, opt_o, loss_reg_o, _, loss_dist_o, loss_ang_o, _, _ = build_original(config)
    seed_all(42)
    om, opt_m, loss_reg_m, _, _, _ = build_omorpher(config)
    sync_weights(ddpm, om)

    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Sample the shared inputs once so both pipelines consume identical tensors.
    seed_all(200)
    noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
    cond_img, _, cond_ratio = om._proc_cond_img(x0, proc_type='none')

    # Each pipeline is re-seeded identically before its forward pass so that
    # any RNG use in train mode (e.g. dropout) draws the same numbers.
    forward_seed = 300

    # --- Original pipeline (DeformDDPM network + standalone STN/losses) ---
    ddpm.network.train()
    seed_all(forward_seed)
    pre_dvf_o = ddpm.network(x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd)

    loss_ddf_o = loss_reg_o(pre_dvf_o, img=x0)
    trm_pred_o = ddf_stn(pre_dvf_o, dvf_gt)
    loss_gen_d_o = loss_dist_o(pred=trm_pred_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_gen_a_o = loss_ang_o(pred=trm_pred_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)

    loss_tot_o = (LOSS_WEIGHTS_DIFF[0] * loss_gen_a_o + LOSS_WEIGHTS_DIFF[1] * loss_gen_d_o
                  + LOSS_WEIGHTS_DIFF[2] * loss_ddf_o)
    loss_tot_o = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio) * loss_tot_o

    opt_o.zero_grad()
    loss_tot_o.backward()
    grad_o = grads_flat(ddpm.network)
    opt_o.step()

    # --- OMorpher pipeline (om.network + OMorpher's STN/losses) ---
    om.network.train()
    seed_all(forward_seed)
    pre_dvf_m = om.network(x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd)

    loss_ddf_m = loss_reg_m(pre_dvf_m, img=x0)
    trm_pred_m = om.stn_full(pre_dvf_m, dvf_gt)
    loss_gen_d_m = om._loss_dist(pred=trm_pred_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_gen_a_m = om._loss_ang(pred=trm_pred_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)

    loss_tot_m = (LOSS_WEIGHTS_DIFF[0] * loss_gen_a_m + LOSS_WEIGHTS_DIFF[1] * loss_gen_d_m
                  + LOSS_WEIGHTS_DIFF[2] * loss_ddf_m)
    loss_tot_m = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio) * loss_tot_m

    opt_m.zero_grad()
    loss_tot_m.backward()
    grad_m = grads_flat(om.network)
    opt_m.step()

    ok &= assert_close("pre_dvf", pre_dvf_o.detach(), pre_dvf_m.detach())
    ok &= assert_close("trm_pred", trm_pred_o.detach(), trm_pred_m.detach())
    ok &= assert_close("loss_gen_a", loss_gen_a_o.item(), loss_gen_a_m.item())
    ok &= assert_close("loss_gen_d", loss_gen_d_o.item(), loss_gen_d_m.item())
    ok &= assert_close("loss_ddf", loss_ddf_o.item(), loss_ddf_m.item())
    ok &= assert_close("loss_tot", loss_tot_o.item(), loss_tot_m.item())
    ok &= assert_close("gradients", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nMode 1 Diffusion: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
|
|
|
|
|
|
|
|
|
def test_mode2_contrastive():
    """Compare one contrastive-mode training step between the two pipelines.

    The masked image and condition image are fed through each network. The
    ``img_embd`` it exposes is pulled towards the text embedding with a
    cosine-similarity loss, gradients are clipped (max_norm=0.1), and one
    optimiser step is taken.

    Returns:
        bool: True when every comparison passes, or when the network exposes
        no ``img_embd`` (the test is then skipped).
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 2 — Contrastive Training Step")
    print("=" * 60)

    config = make_config()
    x0, embd, blind_mask, _, _, _, _, t_contra = make_shared_data()

    seed_all(42)
    ddpm, _, opt_o, *_ = build_original(config)
    seed_all(42)
    om, opt_m, *_ = build_omorpher(config)
    sync_weights(ddpm, om)

    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    cond_img, _, _ = om._proc_cond_img(x0, proc_type='none')
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img.detach()
    text_in = embd.detach()

    # Identical re-seeding before each forward keeps any train-mode RNG use
    # (e.g. dropout) in lock-step between the two pipelines.
    forward_seed = 300

    # --- Original pipeline ---
    ddpm.network.train()
    seed_all(forward_seed)
    _ = ddpm.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    if not hasattr(ddpm.network, 'img_embd') or ddpm.network.img_embd is None:
        print(" SKIP: network has no img_embd attribute")
        return True
    img_embd_o = ddpm.network.img_embd
    loss_c_o = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd_o, embd, dim=-1).mean())

    opt_o.zero_grad()
    loss_c_o.backward()
    # Clips over ddpm.parameters(), matching the optimiser's parameter set.
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.1)
    grad_o = grads_flat(ddpm.network)
    opt_o.step()

    # --- OMorpher pipeline ---
    om.network.train()
    seed_all(forward_seed)
    _ = om.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    img_embd_m = om.network.img_embd
    loss_c_m = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd_m, embd, dim=-1).mean())

    opt_m.zero_grad()
    loss_c_m.backward()
    torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.1)
    grad_m = grads_flat(om.network)
    opt_m.step()

    ok &= assert_close("img_embd", img_embd_o.detach(), img_embd_m.detach())
    ok &= assert_close("loss_contrastive", loss_c_o.item(), loss_c_m.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nMode 2 Contrastive: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
|
|
|
|
|
|
|
|
|
def test_mode3_registration():
    """Compare one registration-mode training step between the two pipelines.

    The original pipeline runs its reverse-diffusion registration through
    ``ddpm(...)``. The OMorpher side writes the same loop out explicitly:
    - the DDF is composed over the timesteps in ``T_regist``;
    - gradients flow only through the last ``k = 2`` network calls;
    - the moving image is re-warped after every step.

    Similarity/MSE/regulariser losses, clipped gradients (max_norm=0.4) and
    post-step weights are compared.

    Returns:
        bool: True when every comparison passes.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 3 — Registration Training Step")
    print("=" * 60)

    config = make_config()
    _, _, _, _, x1, y1, embd_y, _ = make_shared_data()

    seed_all(42)
    ddpm, _, opt_o, _, loss_reg1_o, _, _, loss_imgsim_o, loss_imgmse_o = build_original(config)
    seed_all(42)
    om, opt_m, _, loss_reg1_m, loss_imgsim_m, loss_imgmse_m = build_omorpher(config)
    sync_weights(ddpm, om)

    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))
    thresh_imgsim = 0.01
    # Same seed before each pipeline so train-mode RNG use stays in lock-step.
    forward_seed = 300

    y1_proc, _, cond_ratio = om._proc_cond_img(y1, proc_type='none')

    # Timesteps largest-first, each repeated once per registration sample.
    T_regist = sorted([9, 7, 5, 3, 2, 1], reverse=True)
    T_regist_batched = [[t_val for _ in range(max(1, BATCHSIZE // 2))] for t_val in T_regist]

    # --- Original pipeline: DeformDDPM runs the reverse loop internally ---
    ddpm.train()
    seed_all(forward_seed)
    [ddf_o, _], [img_rec_o, _, _], _ = ddpm(
        img_org=x1, cond_imgs=y1_proc, T=[None, T_regist_batched], proc_type=[], text=embd_y,
    )
    loss_sim_o = loss_imgsim_o(img_rec_o, y1, label=(y1 > thresh_imgsim))
    loss_mse_o = loss_imgmse_o(img_rec_o, y1)
    loss_ddf_o = loss_reg1_o(ddf_o, img=y1)
    loss_regist_o = (LOSS_WEIGHTS_REGIST[0] * loss_sim_o +
                     LOSS_WEIGHTS_REGIST[1] * loss_mse_o +
                     LOSS_WEIGHTS_REGIST[2] * loss_ddf_o)
    loss_regist_o = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist_o

    opt_o.zero_grad()
    loss_regist_o.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.4)
    grad_o = grads_flat(ddpm.network)
    opt_o.step()

    # --- OMorpher pipeline: the same reverse loop written out explicitly ---
    om.network.train()
    seed_all(forward_seed)
    B = x1.shape[0]
    S = om.img_size
    ddf_comp = torch.zeros([B, om.ndims] + [S] * om.ndims, dtype=torch.float32, device=om.device)
    img_rec_m = x1.clone().detach()

    # Only the last k timesteps are trained; earlier ones run under no_grad.
    k = 2
    trainable_iters = T_regist_batched[-1:-k - 1:-1]

    for i in T_regist_batched:
        t = torch.tensor(np.array([i])).to(om.device)
        if i in trainable_iters:
            pre_dvf = om.network(x=img_rec_m, y=y1_proc, t=t, rec_num=2, text=embd_y)
        else:
            with torch.no_grad():
                pre_dvf = om.network(x=img_rec_m, y=y1_proc, t=t, rec_num=2, text=embd_y)
        # Compose the new increment onto the accumulated DDF, then re-warp the
        # original moving image with the full composed field.
        ddf_comp = om.stn_full(ddf_comp, pre_dvf) + pre_dvf
        img_rec_m = om.img_stn(x1.clone().detach(), ddf_comp)

    loss_sim_m = loss_imgsim_m(img_rec_m, y1, label=(y1 > thresh_imgsim))
    loss_mse_m = loss_imgmse_m(img_rec_m, y1)
    loss_ddf_m = loss_reg1_m(ddf_comp, img=y1)
    loss_regist_m = (LOSS_WEIGHTS_REGIST[0] * loss_sim_m +
                     LOSS_WEIGHTS_REGIST[1] * loss_mse_m +
                     LOSS_WEIGHTS_REGIST[2] * loss_ddf_m)
    loss_regist_m = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist_m

    opt_m.zero_grad()
    loss_regist_m.backward()
    torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.4)
    grad_m = grads_flat(om.network)
    opt_m.step()

    ok &= assert_close("ddf_comp", ddf_o.detach(), ddf_comp.detach())
    ok &= assert_close("img_rec", img_rec_o.detach(), img_rec_m.detach())
    ok &= assert_close("loss_sim", loss_sim_o.item(), loss_sim_m.item())
    ok &= assert_close("loss_mse", loss_mse_o.item(), loss_mse_m.item())
    ok &= assert_close("loss_ddf_reg", loss_ddf_o.item(), loss_ddf_m.item())
    ok &= assert_close("loss_regist", loss_regist_o.item(), loss_regist_m.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nMode 3 Registration: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
|
|
|
|
|
|
|
|
|
def test_full_sequence():
    """Run diffusion → contrastive → registration on both pipelines in sequence.

    Weights are compared after each step, so any divergence is localised to
    the step that introduced it. The contrastive step is skipped (for both
    pipelines) when the network exposes no ``img_embd``.

    Returns:
        bool: True when every comparison passes.
    """
    print("\n" + "=" * 60)
    print("TEST: Full Step Sequence (Diffusion → Contrastive → Registration)")
    print("=" * 60)

    config = make_config()
    x0, embd, blind_mask, t, x1, y1, embd_y, t_contra = make_shared_data()

    seed_all(42)
    (ddpm, ddf_stn, opt_o, loss_reg_o, loss_reg1_o, loss_dist_o, loss_ang_o,
     loss_imgsim_o, loss_imgmse_o) = build_original(config)
    seed_all(42)
    om, opt_m, loss_reg_m, loss_reg1_m, loss_imgsim_m, loss_imgmse_m = build_omorpher(config)
    sync_weights(ddpm, om)

    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Shared inputs for all three steps, sampled once.
    seed_all(200)
    noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
    cond_img_diff, _, cond_ratio_diff = om._proc_cond_img(x0, proc_type='none')
    y1_proc, _, cond_ratio_reg = om._proc_cond_img(y1, proc_type='none')

    T_regist = sorted([9, 7, 5, 3, 2, 1], reverse=True)
    T_regist_batched = [[tv for _ in range(max(1, BATCHSIZE // 2))] for tv in T_regist]

    ddpm.network.train()
    om.network.train()

    # Each step re-seeds identically before each pipeline's forward so any
    # train-mode RNG use (e.g. dropout) stays in lock-step.
    diff_seed, contra_seed, regist_seed = 300, 301, 302

    # ---- Step 1: diffusion ----
    seed_all(diff_seed)
    pdvf_o = ddpm.network(x=noisy_img * blind_mask, y=cond_img_diff, t=t, rec_num=2, text=embd)
    ld_o = loss_reg_o(pdvf_o, img=x0)
    tp_o = ddf_stn(pdvf_o, dvf_gt)
    lgd_o = loss_dist_o(pred=tp_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lga_o = loss_ang_o(pred=tp_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lt_o = (LOSS_WEIGHTS_DIFF[0] * lga_o + LOSS_WEIGHTS_DIFF[1] * lgd_o + LOSS_WEIGHTS_DIFF[2] * ld_o)
    lt_o = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio_diff) * lt_o
    opt_o.zero_grad()
    lt_o.backward()
    opt_o.step()

    seed_all(diff_seed)
    pdvf_m = om.network(x=noisy_img * blind_mask, y=cond_img_diff, t=t, rec_num=2, text=embd)
    ld_m = loss_reg_m(pdvf_m, img=x0)
    tp_m = om.stn_full(pdvf_m, dvf_gt)
    lgd_m = om._loss_dist(pred=tp_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lga_m = om._loss_ang(pred=tp_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lt_m = (LOSS_WEIGHTS_DIFF[0] * lga_m + LOSS_WEIGHTS_DIFF[1] * lgd_m + LOSS_WEIGHTS_DIFF[2] * ld_m)
    lt_m = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio_diff) * lt_m
    opt_m.zero_grad()
    lt_m.backward()
    opt_m.step()

    ok &= assert_close("weights_after_diffusion", params_flat(ddpm.network), params_flat(om.network))

    # ---- Step 2: contrastive ----
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img_diff.detach()
    text_in = embd.detach()

    seed_all(contra_seed)
    _ = ddpm.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    has_embd = hasattr(ddpm.network, 'img_embd') and ddpm.network.img_embd is not None
    if has_embd:
        ie_o = ddpm.network.img_embd
        lc_o = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(ie_o, embd, dim=-1).mean())
        opt_o.zero_grad()
        lc_o.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.1)
        opt_o.step()

    # Both forwards always run, keeping module state symmetric between
    # pipelines; the loss/step only runs when an embedding exists
    # (previously lc_o was undefined here and the test crashed).
    seed_all(contra_seed)
    _ = om.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    if has_embd:
        ie_m = om.network.img_embd
        lc_m = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(ie_m, embd, dim=-1).mean())
        opt_m.zero_grad()
        lc_m.backward()
        torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.1)
        opt_m.step()

        ok &= assert_close("loss_contrastive_seq", lc_o.item(), lc_m.item())
        ok &= assert_close("weights_after_contrastive", params_flat(ddpm.network), params_flat(om.network))
    else:
        print(" SKIP contrastive step: network has no img_embd attribute")

    # ---- Step 3: registration ----
    seed_all(regist_seed)
    [ddf_o, _], [rec_o, _, _], _ = ddpm(
        img_org=x1, cond_imgs=y1_proc, T=[None, T_regist_batched], proc_type=[], text=embd_y,
    )
    ls_o = loss_imgsim_o(rec_o, y1, label=(y1 > 0.01))
    lms_o = loss_imgmse_o(rec_o, y1)
    ldr_o = loss_reg1_o(ddf_o, img=y1)
    lr_o = (LOSS_WEIGHTS_REGIST[0] * ls_o + LOSS_WEIGHTS_REGIST[1] * lms_o + LOSS_WEIGHTS_REGIST[2] * ldr_o)
    lr_o = torch.sqrt(cond_ratio_reg + MSK_EPS) * lr_o
    opt_o.zero_grad()
    lr_o.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.4)
    opt_o.step()

    seed_all(regist_seed)
    B = x1.shape[0]
    S = om.img_size
    ddf_m = torch.zeros([B, om.ndims] + [S] * om.ndims, dtype=torch.float32, device=om.device)
    rec_m = x1.clone().detach()
    # Only the last k timesteps are trained; earlier ones run under no_grad.
    k = 2
    trainable_iters = T_regist_batched[-1:-k - 1:-1]
    for i in T_regist_batched:
        tt = torch.tensor(np.array([i])).to(om.device)
        if i in trainable_iters:
            pdvf = om.network(x=rec_m, y=y1_proc, t=tt, rec_num=2, text=embd_y)
        else:
            with torch.no_grad():
                pdvf = om.network(x=rec_m, y=y1_proc, t=tt, rec_num=2, text=embd_y)
        ddf_m = om.stn_full(ddf_m, pdvf) + pdvf
        rec_m = om.img_stn(x1.clone().detach(), ddf_m)

    ls_m = loss_imgsim_m(rec_m, y1, label=(y1 > 0.01))
    lms_m = loss_imgmse_m(rec_m, y1)
    ldr_m = loss_reg1_m(ddf_m, img=y1)
    lr_m = (LOSS_WEIGHTS_REGIST[0] * ls_m + LOSS_WEIGHTS_REGIST[1] * lms_m + LOSS_WEIGHTS_REGIST[2] * ldr_m)
    lr_m = torch.sqrt(cond_ratio_reg + MSK_EPS) * lr_m
    opt_m.zero_grad()
    lr_m.backward()
    torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.4)
    opt_m.step()

    ok &= assert_close("loss_regist_seq", lr_o.item(), lr_m.item())
    ok &= assert_close("weights_after_registration", params_flat(ddpm.network), params_flat(om.network))

    print(f"\nFull Sequence: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
|
|
|
|
|
|
|
|
|
def test_checkpoint_compat():
    """Check that checkpoints saved by one pipeline load correctly into the other.

    The two checkpoint formats are:
    - original: the full DeformDDPM ``state_dict``, with network keys under
      ``network.``;
    - OMorpher: only ``om.network``'s state, re-prefixed with ``network.``.

    Each checkpoint is loaded into a freshly built model of the other kind,
    seeded differently from the source. A silent load failure therefore
    leaves mismatching weights and fails the comparison instead of passing
    vacuously.

    Returns:
        bool: True when both directions load completely and weights match.
    """
    print("\n" + "=" * 60)
    print("TEST: Checkpoint Cross-Compatibility")
    print("=" * 60)
    import tempfile

    config = make_config()
    seed_all(42)
    ddpm, *_ = build_original(config)
    seed_all(42)
    om, *_ = build_omorpher(config)
    sync_weights(ddpm, om)

    # Receivers must NOT start from the source's seed: with seed 42 their fresh
    # weights would already equal the source's and a failed load would pass.
    receiver_seed = 1234

    ok = True
    with tempfile.TemporaryDirectory() as tmpdir:
        path_o = os.path.join(tmpdir, "orig.pth")
        torch.save({'model_state_dict': ddpm.state_dict(), 'epoch': 0}, path_o)

        path_m = os.path.join(tmpdir, "om.pth")
        sd_m = {f"network.{k}": v for k, v in om.network.state_dict().items()}
        torch.save({'model_state_dict': sd_m, 'epoch': 0}, path_m)

        # --- original checkpoint -> OMorpher ---
        seed_all(receiver_seed)
        om2, *_ = build_omorpher(config)
        ckpt = torch.load(path_o, map_location='cpu')
        cleaned = {}
        for key, value in ckpt['model_state_dict'].items():
            # Strip wrapper prefixes only at the start; str.replace would also
            # mangle names that merely contain "module." (e.g. "submodule.").
            for prefix in ("module.", "network."):
                if key.startswith(prefix):
                    key = key[len(prefix):]
            cleaned[key] = value
        net_keys = set(om2.network.state_dict().keys())
        missing = net_keys.difference(cleaned)
        if missing:
            print(f" FAIL orig→OMorpher: {len(missing)} network keys missing from checkpoint")
            ok = False
        om2.network.load_state_dict({k: v for k, v in cleaned.items() if k in net_keys}, strict=False)
        ok &= assert_close("orig→OMorpher", params_flat(om2.network), params_flat(ddpm.network))

        # --- OMorpher checkpoint -> original ---
        seed_all(receiver_seed)
        ddpm2, *_ = build_original(config)
        load_result = ddpm2.load_state_dict(
            torch.load(path_m, map_location='cpu')['model_state_dict'], strict=False)
        # strict=False tolerates non-network DeformDDPM entries, but every
        # network.* key must have been provided by the OMorpher checkpoint.
        missing_net = [k for k in load_result.missing_keys if k.startswith("network.")]
        if missing_net:
            print(f" FAIL OMorpher→orig: {len(missing_net)} network keys missing from checkpoint")
            ok = False
        ok &= assert_close("OMorpher→orig", params_flat(ddpm2.network), params_flat(om.network))

    print(f"\nCheckpoint Compat: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    print("=" * 60)
    print("3-Modes Equivalence Test Suite")
    print(f"IMG_SIZE={IMG_SIZE}, BATCHSIZE={BATCHSIZE}, TIMESTEPS={TIMESTEPS}, NET={NET_NAME}")
    print("=" * 60)

    # (label, test function) pairs, executed in this order.
    suite = (
        ("Mode 1: Diffusion", test_mode1_diffusion),
        ("Mode 2: Contrastive", test_mode2_contrastive),
        ("Mode 3: Registration", test_mode3_registration),
        ("Full Sequence", test_full_sequence),
        ("Checkpoint Compat", test_checkpoint_compat),
    )
    results = {label: run_test() for label, run_test in suite}

    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for label, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] {label}")
    all_ok = all(results.values())

    print(f"\nOverall: {'ALL TESTS PASSED' if all_ok else 'SOME TESTS FAILED'}")
    sys.exit(0 if all_ok else 1)
|
|
|