# Omini3D / tests/test_3modes_equivalence.py
# Author: maxmo2009
# Commit 2af0e94 (verified): "Sync from local: code + epoch-110 checkpoint, clean README"
"""
test_3modes_equivalence.py — Verify that the OMorpher-based Scripts/OM_train_3modes.py
produces network outputs, losses, gradients, and weight updates identical to those of
the original DeformDDPM-based OM_train_3modes.py.
Runs one training step of each mode (diffusion, contrastive, registration) with
identical pre-computed inputs, shared initial weights, and deterministic seeding.
Compares every intermediate tensor to catch divergences.
Usage:
python tests/test_3modes_equivalence.py
"""
import os
import sys
import copy
import random
import numpy as np
import torch
import torch.nn.functional as F
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
import Diffusion.losses as losses
from Diffusion.losses import Grad, LNCC, LMSE
from OMorpher import OMorpher
import utils
# ========================== Test Config ==========================
# -- Volume geometry and diffusion schedule -----------------------------
IMG_SIZE: int = 64  # must be >= 64 for multi-scale DDF generation
NDIMS: int = 3
BATCHSIZE: int = 2
TIMESTEPS: int = 10
V_SCALE: float = 5e-5
NOISE_SCALE: float = 0.1

# -- Model / optimisation -------------------------------------------------
NET_NAME: str = "recmulmodmutattnnet"  # supports contrastive (has img_embd)
LR: float = 1e-5
DEVICE: str = "cpu"

# -- Loss constants (from 3modes) ------------------------------------------
# Diffusion weights multiply, in order: angle loss (NCC), distance loss
# (MRSE), DDF regulariser.
LOSS_WEIGHTS_DIFF: list = [2.0, 1.0, 16]
# Registration weights multiply, in order: LNCC similarity, LMSE, DDF
# regulariser.
LOSS_WEIGHTS_REGIST: list = [1.0, 0.3, 64]
LOSS_WEIGHT_CONTRASTIVE: float = 1.0
# Epsilon inside the sqrt(...) condition-ratio loss scaling.
MSK_EPS: float = 0.01

# -- Default tolerances for assert_close ----------------------------------
ATOL: float = 1e-5
RTOL: float = 1e-4
def make_config():
    """Build the configuration dict shared by both pipelines in every test.

    Size/schedule/model values come from the module-level test constants;
    the remaining entries are fixed settings for this suite.  A new dict is
    returned on each call.
    """
    return dict(
        data_name="test",
        net_name=NET_NAME,
        ndims=NDIMS,
        img_size=IMG_SIZE,
        batchsize=BATCHSIZE,
        timesteps=TIMESTEPS,
        v_scale=V_SCALE,
        noise_scale=NOISE_SCALE,
        num_input_chn=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        resample_mode="bilinear",
        lr=LR,
        epoch=1,
        epoch_per_save=1,
        condition_type="slice",
        device=DEVICE,
    )
def seed_all(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (plus all CUDA devices if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ========================== Builders ==========================
def build_original(config):
    """Assemble the original DeformDDPM training pipeline.

    Returns, in order: the DeformDDPM wrapper, a DDF spatial transformer
    (STN), an Adam optimizer over all DeformDDPM parameters, and the loss
    objects ``loss_reg`` (Grad, outrange_thresh=0.2), ``loss_reg1`` (Grad,
    outrange_thresh=0.6), ``loss_dist`` (MRSE), ``loss_ang`` (NCC),
    ``loss_imgsim`` (LNCC) and ``loss_imgmse`` (LMSE).

    Objects are constructed in a fixed order so that, under a fixed seed,
    network initialisation is reproducible.
    """
    net_cls = get_net(config["net_name"])
    n_steps = config["timesteps"]
    ndims = config["ndims"]
    size = config["img_size"]
    device = config["device"]

    network = net_cls(
        n_steps=n_steps,
        ndims=ndims,
        num_input_chn=config["num_input_chn"],
        res=size,
    )
    ddpm = DeformDDPM(
        network=network,
        n_steps=n_steps,
        image_chw=[1] + [size] * ndims,
        device=device,
        batch_size=config["batchsize"],
        img_pad_mode=config["img_pad_mode"],
        v_scale=config["v_scale"],
    )
    ddf_stn = STN(img_sz=size, ndims=ndims, padding_mode=config["padding_mode"], device=device)
    for module in (ddpm, ddf_stn):
        module.to(device)

    def grad_penalty(thresh):
        # Fresh penalty list per instance, as in the original construction.
        return Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,
                    outrange_thresh=thresh, outrange_weight=1e3)

    loss_reg = grad_penalty(0.2)
    loss_reg1 = grad_penalty(0.6)
    loss_dist = losses.MRSE(img_sz=size)
    loss_ang = losses.NCC(img_sz=size)
    loss_imgsim = LNCC()
    loss_imgmse = LMSE()
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=config["lr"])
    return (ddpm, ddf_stn, optimizer, loss_reg, loss_reg1,
            loss_dist, loss_ang, loss_imgsim, loss_imgmse)
def build_omorpher(config):
    """Assemble the OMorpher (Scripts) training pipeline.

    Returns:
        tuple: ``(om, optimizer, loss_reg, loss_reg1, loss_imgsim, loss_imgmse)``.
        The optimizer is Adam over ``om.network`` parameters only; the two
        Grad regularisers differ only in ``outrange_thresh`` (0.2 vs 0.6).
    """
    om = OMorpher(config=config, device=config["device"])

    def make_regulariser(outrange_thresh):
        return Grad(penalty=['l1', 'negdetj', 'range'], ndims=om.ndims,
                    outrange_thresh=outrange_thresh, outrange_weight=1e3)

    loss_reg, loss_reg1 = make_regulariser(0.2), make_regulariser(0.6)
    loss_imgsim, loss_imgmse = LNCC(), LMSE()
    optimizer = torch.optim.Adam(om.network.parameters(), lr=config["lr"])
    return om, optimizer, loss_reg, loss_reg1, loss_imgsim, loss_imgmse
def sync_weights(ddpm, om):
    """Load ``ddpm.network``'s state dict into ``om.network`` (strict load)."""
    source_state = ddpm.network.state_dict()
    om.network.load_state_dict(source_state)
def params_flat(module):
    """Return a detached 1-D copy of all parameters of *module*, in ``parameters()`` order."""
    pieces = [torch.flatten(param.detach().clone()) for param in module.parameters()]
    return torch.cat(pieces)
def grads_flat(module):
    """Concatenate all parameter gradients of *module* into one detached 1-D tensor.

    Parameters whose ``.grad`` is None contribute zeros of the same size, so
    the result lines up element-wise with ``params_flat(module)``.
    """
    def _flat_grad(param):
        if param.grad is None:
            return torch.zeros_like(param.flatten())
        return param.grad.detach().clone().flatten()

    return torch.cat([_flat_grad(param) for param in module.parameters()])
def assert_close(name, a, b, atol=ATOL, rtol=RTOL):
    """Print PASS/FAIL for an element-wise closeness check and return the verdict.

    Args:
        name: Label printed alongside the result.
        a, b: Tensors or Python/NumPy scalars.  Non-tensor values are
            converted to float64 tensors; tensors of differing dtypes are
            promoted to a common dtype; both sides are compared on CPU.
        atol, rtol: Tolerances forwarded to ``torch.allclose``.

    Returns:
        bool: True iff ``a`` and ``b`` have identical shapes and are close.
    """
    def _to_tensor(x):
        # torch.allclose rejects mismatched dtypes (e.g. int vs float), so
        # plain scalars are normalised to float64 before comparing.
        if isinstance(x, torch.Tensor):
            return x.detach()
        return torch.as_tensor(x, dtype=torch.float64)

    a, b = _to_tensor(a), _to_tensor(b)
    # torch.allclose broadcasts: without this check a [1] tensor would
    # "match" an [N] tensor of equal values and hide a real divergence.
    if a.shape != b.shape:
        print(f" FAIL {name}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
        return False
    common = torch.promote_types(a.dtype, b.dtype)
    a = a.to(device="cpu", dtype=common)
    b = b.to(device="cpu", dtype=common)
    if torch.allclose(a, b, atol=atol, rtol=rtol):
        print(f" PASS {name}")
        return True
    # Report in float64 so integer/bool inputs can still produce a mean.
    diff = (a.double() - b.double()).abs()
    print(f" FAIL {name}: max_diff={diff.max().item():.6e}, mean_diff={diff.mean().item():.6e}")
    return False
# ========================== Shared Data Generators ==========================
def make_shared_data():
    """Create deterministic dummy inputs shared by all three training modes.

    Reseeds the global RNGs (seed 123) and returns
    ``(x0, embd, blind_mask, t, x1, y1, embd_y, t_contra)``:
    full-batch ``x0``/``embd``/``blind_mask``/``t`` for the diffusion and
    contrastive modes, and half-batch ``x1``/``y1``/``embd_y`` for
    registration.
    """
    seed_all(123)
    vol = (IMG_SIZE,) * 3
    half = max(1, BATCHSIZE // 2)  # paired data uses half the batch
    f32 = torch.float32

    # The random draws must stay in exactly this order to reproduce the data.
    x0 = torch.rand(BATCHSIZE, 1, *vol, dtype=f32)
    embd = torch.randn(BATCHSIZE, 1024, dtype=f32)
    blind_mask = torch.ones(1, 1, *vol, dtype=f32)  # all ones: masking is a no-op
    t = torch.tensor([3, 7])  # fixed diffusion timesteps
    x1 = torch.rand(half, 1, *vol, dtype=f32)
    y1 = torch.rand(half, 1, *vol, dtype=f32)
    embd_y = torch.randn(half, 1024, dtype=f32)
    t_contra = torch.tensor([2, 5])
    return x0, embd, blind_mask, t, x1, y1, embd_y, t_contra
# ========================== Test: Mode 1 (Diffusion) ==========================
def test_mode1_diffusion():
    """Mode 1 (diffusion): compare one training step of both pipelines.

    Both pipelines receive the same noisy image, ground-truth DVF and
    condition image, then run network -> diffusion losses -> backward ->
    Adam step.  The predicted DVF, transformed prediction, each loss term,
    the gradients and the post-step weights are compared.

    Returns:
        bool: True if every comparison passed.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 1 — Diffusion Training Step")
    print("=" * 60)
    config = make_config()
    x0, embd, blind_mask, t, _, _, _, _ = make_shared_data()

    # Build with identical weights: same seed for both builders, then an
    # explicit copy so the networks start bit-identical.
    seed_all(42)
    ddpm, ddf_stn, opt_o, loss_reg_o, _, loss_dist_o, loss_ang_o, _, _ = build_original(config)
    seed_all(42)
    om, opt_m, loss_reg_m, _, _, _ = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Pre-compute shared tensors using OMorpher (source of truth).
    seed_all(200)
    noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
    # Use 'none' proc_type for a deterministic condition image.
    cond_img, _, cond_ratio = om._proc_cond_img(x0, proc_type='none')

    # --- Original pipeline ---
    ddpm.network.train()
    # Reseed right before each pipeline's forward pass so any stochastic
    # train-mode layer (e.g. dropout) draws identical random numbers in both;
    # otherwise the second pipeline would run on a different RNG state.
    seed_all(300)
    pre_dvf_o = ddpm.network(x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd)
    loss_ddf_o = loss_reg_o(pre_dvf_o, img=x0)
    trm_pred_o = ddf_stn(pre_dvf_o, dvf_gt)
    loss_gen_d_o = loss_dist_o(pred=trm_pred_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_gen_a_o = loss_ang_o(pred=trm_pred_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_tot_o = (LOSS_WEIGHTS_DIFF[0] * loss_gen_a_o + LOSS_WEIGHTS_DIFF[1] * loss_gen_d_o
                  + LOSS_WEIGHTS_DIFF[2] * loss_ddf_o)
    # Scale shrinks as cond_ratio grows: sqrt(1 + eps - cond_ratio).
    loss_tot_o = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio) * loss_tot_o
    opt_o.zero_grad()
    loss_tot_o.backward()
    grad_o = grads_flat(ddpm.network)
    opt_o.step()

    # --- OMorpher pipeline ---
    om.network.train()
    seed_all(300)
    pre_dvf_m = om.network(x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd)
    loss_ddf_m = loss_reg_m(pre_dvf_m, img=x0)
    trm_pred_m = om.stn_full(pre_dvf_m, dvf_gt)
    loss_gen_d_m = om._loss_dist(pred=trm_pred_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_gen_a_m = om._loss_ang(pred=trm_pred_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    loss_tot_m = (LOSS_WEIGHTS_DIFF[0] * loss_gen_a_m + LOSS_WEIGHTS_DIFF[1] * loss_gen_d_m
                  + LOSS_WEIGHTS_DIFF[2] * loss_ddf_m)
    loss_tot_m = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio) * loss_tot_m
    opt_m.zero_grad()
    loss_tot_m.backward()
    grad_m = grads_flat(om.network)
    opt_m.step()

    # --- Compare ---
    ok &= assert_close("pre_dvf", pre_dvf_o.detach(), pre_dvf_m.detach())
    ok &= assert_close("trm_pred", trm_pred_o.detach(), trm_pred_m.detach())
    ok &= assert_close("loss_gen_a", loss_gen_a_o.item(), loss_gen_a_m.item())
    ok &= assert_close("loss_gen_d", loss_gen_d_o.item(), loss_gen_d_m.item())
    ok &= assert_close("loss_ddf", loss_ddf_o.item(), loss_ddf_m.item())
    ok &= assert_close("loss_tot", loss_tot_o.item(), loss_tot_m.item())
    ok &= assert_close("gradients", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))
    print(f"\nMode 1 Diffusion: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
# ========================== Test: Mode 2 (Contrastive) ==========================
def test_mode2_contrastive():
    """Mode 2 (contrastive): compare one image-embedding alignment step.

    Both networks see the same masked image, condition image and text
    embedding.  The loss is ``1 - cosine(img_embd, embd)``, gradients are
    clipped to norm 0.1 and an Adam step is taken.  If the network exposes
    no ``img_embd`` the test prints SKIP and returns True.

    Returns:
        bool: True if every comparison passed (or the test was skipped).
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 2 — Contrastive Training Step")
    print("=" * 60)
    config = make_config()
    x0, embd, blind_mask, _, _, _, _, t_contra = make_shared_data()
    seed_all(42)
    ddpm, _, opt_o, *_ = build_original(config)
    seed_all(42)
    om, opt_m, *_ = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Shared cond_img
    cond_img, _, _ = om._proc_cond_img(x0, proc_type='none')
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img.detach()
    text_in = embd.detach()

    # --- Original ---
    ddpm.network.train()
    # Same seed before each pipeline's forward so stochastic train-mode
    # layers (e.g. dropout) draw identical random numbers.
    seed_all(301)
    _ = ddpm.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    if not hasattr(ddpm.network, 'img_embd') or ddpm.network.img_embd is None:
        print(" SKIP: network has no img_embd attribute")
        return True
    img_embd_o = ddpm.network.img_embd
    loss_c_o = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd_o, embd, dim=-1).mean())
    opt_o.zero_grad()
    loss_c_o.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.1)
    grad_o = grads_flat(ddpm.network)
    opt_o.step()

    # --- OMorpher ---
    om.network.train()
    seed_all(301)
    _ = om.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    img_embd_m = om.network.img_embd
    loss_c_m = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd_m, embd, dim=-1).mean())
    opt_m.zero_grad()
    loss_c_m.backward()
    torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.1)
    grad_m = grads_flat(om.network)
    opt_m.step()

    # --- Compare ---
    ok &= assert_close("img_embd", img_embd_o.detach(), img_embd_m.detach())
    ok &= assert_close("loss_contrastive", loss_c_o.item(), loss_c_m.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))
    print(f"\nMode 2 Contrastive: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
# ========================== Test: Mode 3 (Registration) ==========================
def test_mode3_registration():
    """Mode 3 (registration): compare one reverse-diffusion registration step.

    The original pipeline calls ``DeformDDPM.forward`` with a fixed timestep
    schedule.  The OMorpher side reproduces the loop explicitly: predict a
    DVF per timestep, compose it into the running DDF, and re-warp the moving
    image ``x1``.  Only the last ``k`` timesteps run with autograd enabled.
    Registration losses are then back-propagated, gradients are clipped to
    norm 0.4 and an Adam step is taken.

    Returns:
        bool: True if every comparison passed.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 3 — Registration Training Step")
    print("=" * 60)
    config = make_config()
    _, _, _, _, x1, y1, embd_y, _ = make_shared_data()
    seed_all(42)
    ddpm, _, opt_o, _, loss_reg1_o, _, _, loss_imgsim_o, loss_imgmse_o = build_original(config)
    seed_all(42)
    om, opt_m, _, loss_reg1_m, loss_imgsim_m, loss_imgmse_m = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))
    thresh_imgsim = 0.01

    # Shared proc_cond_img
    y1_proc, _, cond_ratio = om._proc_cond_img(y1, proc_type='none')
    # Fixed descending timestep schedule, one entry per sample of the half batch.
    T_regist = sorted([9, 7, 5, 3, 2, 1], reverse=True)
    T_regist_batched = [[t_val for _ in range(max(1, BATCHSIZE // 2))] for t_val in T_regist]

    # --- Original: call DeformDDPM.diff_recover via forward ---
    ddpm.train()
    # Same seed before each pipeline's forward so stochastic train-mode
    # layers (e.g. dropout) draw identical random numbers.
    seed_all(302)
    [ddf_o, _], [img_rec_o, _, _], _ = ddpm(
        img_org=x1, cond_imgs=y1_proc, T=[None, T_regist_batched], proc_type=[], text=embd_y,
    )
    loss_sim_o = loss_imgsim_o(img_rec_o, y1, label=(y1 > thresh_imgsim))
    loss_mse_o = loss_imgmse_o(img_rec_o, y1)
    loss_ddf_o = loss_reg1_o(ddf_o, img=y1)
    loss_regist_o = (LOSS_WEIGHTS_REGIST[0] * loss_sim_o +
                     LOSS_WEIGHTS_REGIST[1] * loss_mse_o +
                     LOSS_WEIGHTS_REGIST[2] * loss_ddf_o)
    loss_regist_o = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist_o
    opt_o.zero_grad()
    loss_regist_o.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.4)
    grad_o = grads_flat(ddpm.network)
    opt_o.step()

    # --- OMorpher: reverse_diffuse_train logic ---
    om.network.train()
    B = x1.shape[0]
    S = om.img_size
    ddf_comp = torch.zeros([B, om.ndims] + [S] * om.ndims, dtype=torch.float32, device=om.device)
    img_rec_m = x1.clone().detach()
    k = 2  # number of trailing timesteps that receive gradients
    trainable_iters = T_regist_batched[-1:-k - 1:-1]
    seed_all(302)
    for i in T_regist_batched:
        t = torch.tensor(np.array([i])).to(om.device)
        if i in trainable_iters:
            pre_dvf = om.network(x=img_rec_m, y=y1_proc, t=t, rec_num=2, text=embd_y)
        else:
            with torch.no_grad():
                pre_dvf = om.network(x=img_rec_m, y=y1_proc, t=t, rec_num=2, text=embd_y)
        # Compose this step's DVF with the accumulated DDF, then re-warp x1.
        ddf_comp = om.stn_full(ddf_comp, pre_dvf) + pre_dvf
        img_rec_m = om.img_stn(x1.clone().detach(), ddf_comp)
    loss_sim_m = loss_imgsim_m(img_rec_m, y1, label=(y1 > thresh_imgsim))
    loss_mse_m = loss_imgmse_m(img_rec_m, y1)
    loss_ddf_m = loss_reg1_m(ddf_comp, img=y1)
    loss_regist_m = (LOSS_WEIGHTS_REGIST[0] * loss_sim_m +
                     LOSS_WEIGHTS_REGIST[1] * loss_mse_m +
                     LOSS_WEIGHTS_REGIST[2] * loss_ddf_m)
    loss_regist_m = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist_m
    opt_m.zero_grad()
    loss_regist_m.backward()
    torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.4)
    grad_m = grads_flat(om.network)
    opt_m.step()

    # --- Compare ---
    ok &= assert_close("ddf_comp", ddf_o.detach(), ddf_comp.detach())
    ok &= assert_close("img_rec", img_rec_o.detach(), img_rec_m.detach())
    ok &= assert_close("loss_sim", loss_sim_o.item(), loss_sim_m.item())
    ok &= assert_close("loss_mse", loss_mse_o.item(), loss_mse_m.item())
    ok &= assert_close("loss_ddf_reg", loss_ddf_o.item(), loss_ddf_m.item())
    ok &= assert_close("loss_regist", loss_regist_o.item(), loss_regist_m.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_m)
    ok &= assert_close("weights_after", params_flat(ddpm.network), params_flat(om.network))
    print(f"\nMode 3 Registration: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
# ========================== Test: Full Sequence ==========================
def test_full_sequence():
    """Run diffusion -> contrastive -> registration back to back on both pipelines.

    The same optimizer instances are reused across the three steps, so weights
    and optimizer state carry over.  This catches divergences that only appear
    once updates accumulate.  Weights are compared after every step.

    Returns:
        bool: True if every comparison passed.
    """
    print("\n" + "=" * 60)
    print("TEST: Full Step Sequence (Diffusion → Contrastive → Registration)")
    print("=" * 60)
    config = make_config()
    x0, embd, blind_mask, t, x1, y1, embd_y, t_contra = make_shared_data()
    seed_all(42)
    (ddpm, ddf_stn, opt_o, loss_reg_o, loss_reg1_o, loss_dist_o, loss_ang_o,
     loss_imgsim_o, loss_imgmse_o) = build_original(config)
    seed_all(42)
    om, opt_m, loss_reg_m, loss_reg1_m, loss_imgsim_m, loss_imgmse_m = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = assert_close("init_weights", params_flat(ddpm.network), params_flat(om.network))

    # Shared pre-computed tensors
    seed_all(200)
    noisy_img, dvf_gt, _ = om._get_random_ddf(x0, t)
    cond_img_diff, _, cond_ratio_diff = om._proc_cond_img(x0, proc_type='none')
    y1_proc, _, cond_ratio_reg = om._proc_cond_img(y1, proc_type='none')
    T_regist = sorted([9, 7, 5, 3, 2, 1], reverse=True)
    T_regist_batched = [[tv for _ in range(max(1, BATCHSIZE // 2))] for tv in T_regist]

    # In every step each pipeline is reseeded with the same value right
    # before its forward pass, so stochastic train-mode layers (e.g. dropout)
    # draw identical random numbers in both pipelines.

    # ========== Step 1: Diffusion (both) ==========
    ddpm.network.train()
    om.network.train()
    # Original
    seed_all(300)
    pdvf_o = ddpm.network(x=noisy_img * blind_mask, y=cond_img_diff, t=t, rec_num=2, text=embd)
    ld_o = loss_reg_o(pdvf_o, img=x0)
    tp_o = ddf_stn(pdvf_o, dvf_gt)
    lgd_o = loss_dist_o(pred=tp_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lga_o = loss_ang_o(pred=tp_o, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lt_o = (LOSS_WEIGHTS_DIFF[0] * lga_o + LOSS_WEIGHTS_DIFF[1] * lgd_o + LOSS_WEIGHTS_DIFF[2] * ld_o)
    lt_o = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio_diff) * lt_o
    opt_o.zero_grad()
    lt_o.backward()
    opt_o.step()
    # OMorpher
    seed_all(300)
    pdvf_m = om.network(x=noisy_img * blind_mask, y=cond_img_diff, t=t, rec_num=2, text=embd)
    ld_m = loss_reg_m(pdvf_m, img=x0)
    tp_m = om.stn_full(pdvf_m, dvf_gt)
    lgd_m = om._loss_dist(pred=tp_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lga_m = om._loss_ang(pred=tp_m, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
    lt_m = (LOSS_WEIGHTS_DIFF[0] * lga_m + LOSS_WEIGHTS_DIFF[1] * lgd_m + LOSS_WEIGHTS_DIFF[2] * ld_m)
    lt_m = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio_diff) * lt_m
    opt_m.zero_grad()
    lt_m.backward()
    opt_m.step()
    ok &= assert_close("weights_after_diffusion", params_flat(ddpm.network), params_flat(om.network))

    # ========== Step 2: Contrastive (both) ==========
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img_diff.detach()
    text_in = embd.detach()
    # Original
    seed_all(301)
    _ = ddpm.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    has_embd = hasattr(ddpm.network, 'img_embd') and ddpm.network.img_embd is not None
    if has_embd:
        ie_o = ddpm.network.img_embd
        lc_o = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(ie_o, embd, dim=-1).mean())
        opt_o.zero_grad()
        lc_o.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.1)
        opt_o.step()
    # OMorpher: the forward runs whenever the original forward above runs,
    # so both networks see the same train-mode forward passes (matters if
    # the network keeps running statistics, e.g. BatchNorm).
    seed_all(301)
    _ = om.network(x=x_in, y=y_in, t=t_contra, text=text_in)
    if has_embd:
        ie_m = om.network.img_embd
        lc_m = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(ie_m, embd, dim=-1).mean())
        opt_m.zero_grad()
        lc_m.backward()
        torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.1)
        opt_m.step()
        ok &= assert_close("loss_contrastive_seq", lc_o.item(), lc_m.item())
    ok &= assert_close("weights_after_contrastive", params_flat(ddpm.network), params_flat(om.network))

    # ========== Step 3: Registration (both) ==========
    # Original
    seed_all(302)
    [ddf_o, _], [rec_o, _, _], _ = ddpm(
        img_org=x1, cond_imgs=y1_proc, T=[None, T_regist_batched], proc_type=[], text=embd_y,
    )
    ls_o = loss_imgsim_o(rec_o, y1, label=(y1 > 0.01))
    lms_o = loss_imgmse_o(rec_o, y1)
    ldr_o = loss_reg1_o(ddf_o, img=y1)
    lr_o = (LOSS_WEIGHTS_REGIST[0] * ls_o + LOSS_WEIGHTS_REGIST[1] * lms_o + LOSS_WEIGHTS_REGIST[2] * ldr_o)
    lr_o = torch.sqrt(cond_ratio_reg + MSK_EPS) * lr_o
    opt_o.zero_grad()
    lr_o.backward()
    torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.4)
    opt_o.step()
    # OMorpher
    B = x1.shape[0]
    S = om.img_size
    ddf_m = torch.zeros([B, om.ndims] + [S] * om.ndims, dtype=torch.float32, device=om.device)
    rec_m = x1.clone().detach()
    k = 2  # number of trailing timesteps that receive gradients
    trainable_iters = T_regist_batched[-1:-k - 1:-1]
    seed_all(302)
    for i in T_regist_batched:
        tt = torch.tensor(np.array([i])).to(om.device)
        if i in trainable_iters:
            pdvf = om.network(x=rec_m, y=y1_proc, t=tt, rec_num=2, text=embd_y)
        else:
            with torch.no_grad():
                pdvf = om.network(x=rec_m, y=y1_proc, t=tt, rec_num=2, text=embd_y)
        # Compose this step's DVF with the accumulated DDF, then re-warp x1.
        ddf_m = om.stn_full(ddf_m, pdvf) + pdvf
        rec_m = om.img_stn(x1.clone().detach(), ddf_m)
    ls_m = loss_imgsim_m(rec_m, y1, label=(y1 > 0.01))
    lms_m = loss_imgmse_m(rec_m, y1)
    ldr_m = loss_reg1_m(ddf_m, img=y1)
    lr_m = (LOSS_WEIGHTS_REGIST[0] * ls_m + LOSS_WEIGHTS_REGIST[1] * lms_m + LOSS_WEIGHTS_REGIST[2] * ldr_m)
    lr_m = torch.sqrt(cond_ratio_reg + MSK_EPS) * lr_m
    opt_m.zero_grad()
    lr_m.backward()
    torch.nn.utils.clip_grad_norm_(om.network.parameters(), max_norm=0.4)
    opt_m.step()
    ok &= assert_close("loss_regist_seq", lr_o.item(), lr_m.item())
    ok &= assert_close("weights_after_registration", params_flat(ddpm.network), params_flat(om.network))
    print(f"\nFull Sequence: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
# ========================== Test: Checkpoint Compatibility ==========================
def test_checkpoint_compat():
    """Checkpoints saved by one pipeline load correctly into the other.

    Two directions are checked:
      * original -> OMorpher: a full ``DeformDDPM.state_dict()`` has its
        leading ``module.``/``network.`` prefixes stripped, is filtered to
        the network's keys, and is loaded into a freshly built OMorpher;
      * OMorpher -> original: a checkpoint holding only ``network.*`` keys is
        loaded non-strictly into a freshly built DeformDDPM.
    Each receiving model is built from a different seed than the source, so
    the weight comparison fails if a load silently copies nothing.  Missing
    network keys are also reported as failures.

    Returns:
        bool: True if both directions reproduce the source network weights.
    """
    print("\n" + "=" * 60)
    print("TEST: Checkpoint Cross-Compatibility")
    print("=" * 60)
    import tempfile
    config = make_config()
    seed_all(42)
    ddpm, *_ = build_original(config)
    seed_all(42)
    om, *_ = build_omorpher(config)
    sync_weights(ddpm, om)
    ok = True
    with tempfile.TemporaryDirectory() as tmpdir:
        # Save DeformDDPM checkpoint (original format: network.* plus any other DeformDDPM keys)
        path_o = os.path.join(tmpdir, "orig.pth")
        torch.save({'model_state_dict': ddpm.state_dict(), 'epoch': 0}, path_o)
        # Save OMorpher checkpoint (network.* prefix only)
        path_m = os.path.join(tmpdir, "om.pth")
        sd_m = {f"network.{k}": v for k, v in om.network.state_dict().items()}
        torch.save({'model_state_dict': sd_m, 'epoch': 0}, path_m)

        # --- Load original -> OMorpher ---
        # A seed different from 42 guarantees om2 does not already hold the
        # source weights, so the comparison below really tests the load.
        seed_all(7)
        om2, *_ = build_omorpher(config)
        ckpt = torch.load(path_o, map_location='cpu')
        cleaned = {}
        for key, value in ckpt['model_state_dict'].items():
            # Strip only *leading* wrapper prefixes (DataParallel "module.",
            # DeformDDPM "network.").  A blanket str.replace("module.", "")
            # would also corrupt inner names such as "attn_module.weight".
            for prefix in ("module.", "network.", "module."):
                if key.startswith(prefix):
                    key = key[len(prefix):]
            cleaned[key] = value
        net_keys = set(om2.network.state_dict().keys())
        res = om2.network.load_state_dict(
            {k: v for k, v in cleaned.items() if k in net_keys}, strict=False)
        if res.missing_keys:
            print(f" FAIL orig→OMorpher: {len(res.missing_keys)} network keys missing from "
                  f"checkpoint, e.g. {list(res.missing_keys)[:3]}")
            ok = False
        ok &= assert_close("orig→OMorpher", params_flat(om2.network), params_flat(ddpm.network))

        # --- Load OMorpher -> DeformDDPM ---
        # Previously ddpm2 was rebuilt with seed 42, i.e. already identical to
        # om, so this check passed even if the load copied nothing.
        seed_all(7)
        ddpm2, *_ = build_original(config)
        res = ddpm2.load_state_dict(torch.load(path_m, map_location='cpu')['model_state_dict'],
                                    strict=False)
        # strict=False tolerates DeformDDPM's non-network keys being absent,
        # but every network.* key must be present and none may be unexpected.
        missing_net = [k for k in res.missing_keys if k.startswith("network.")]
        if missing_net or res.unexpected_keys:
            print(f" FAIL OMorpher→orig: missing={missing_net[:3]} "
                  f"unexpected={list(res.unexpected_keys)[:3]}")
            ok = False
        ok &= assert_close("OMorpher→orig", params_flat(ddpm2.network), params_flat(om.network))
    print(f"\nCheckpoint Compat: {'ALL PASSED' if ok else 'SOME FAILED'}")
    return ok
# ========================== Main ==========================
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("3-Modes Equivalence Test Suite")
    print(f"IMG_SIZE={IMG_SIZE}, BATCHSIZE={BATCHSIZE}, TIMESTEPS={TIMESTEPS}, NET={NET_NAME}")
    print(banner)

    # Each test prints its own details and returns a bool; run them in order.
    suite = [
        ("Mode 1: Diffusion", test_mode1_diffusion),
        ("Mode 2: Contrastive", test_mode2_contrastive),
        ("Mode 3: Registration", test_mode3_registration),
        ("Full Sequence", test_full_sequence),
        ("Checkpoint Compat", test_checkpoint_compat),
    ]
    results = {label: run() for label, run in suite}

    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    for label, passed in results.items():
        print(f" [{'PASS' if passed else 'FAIL'}] {label}")
    all_ok = all(results.values())
    print(f"\nOverall: {'ALL TESTS PASSED' if all_ok else 'SOME TESTS FAILED'}")
    # Non-zero exit code lets CI detect failures.
    sys.exit(0 if all_ok else 1)