# Omini3D / tests/test_3modes_opt_equivalence.py
# Commit 2af0e94 (verified) by maxmo2009: "Sync from local: code + epoch-110 checkpoint, clean README"
"""
test_3modes_opt_equivalence.py — Verify that the optimized pipeline
(diffuser_opt.DeformDDPM + networks_opt + losses_opt.LNCC/MSLNCC/LMSE) produces
the same network outputs, losses, gradients, and weight updates as the original
(diffuser.DeformDDPM + networks + losses.LNCC/MSLNCC/LMSE), up to the ATOL/RTOL
tolerances defined below (i.e. numerically equivalent, not bit-for-bit).
Tests all three training modes:
1. Diffusion (single-step forward)
2. Contrastive (text-image alignment)
3. Registration (multi-step diff_recover loop)
Uses dummy tensors — no real dataset required.
Usage:
python -m pytest tests/test_3modes_opt_equivalence.py -v
python tests/test_3modes_opt_equivalence.py
"""
import os
import sys
import copy
import random
import numpy as np
import torch
import torch.nn.functional as F
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT_DIR)
# Original
from Diffusion.diffuser import DeformDDPM as OrigDeformDDPM
from Diffusion.losses import LNCC as OrigLNCC, MSLNCC as OrigMSLNCC, LMSE as OrigLMSE
from Diffusion.losses import Grad, MRSE, NCC
# Optimized
from Diffusion.diffuser_opt import DeformDDPM as OptDeformDDPM
from Diffusion.losses_opt import LNCC as OptLNCC, MSLNCC as OptMSLNCC, LMSE as OptLMSE
from Diffusion.networks_opt import get_net_opt, OptSTN
from Diffusion.networks import get_net, STN
# ========================== Test Config ==========================
# Side length of the dummy cubic volumes; also the config "img_size".
IMG_SIZE = 64
# Batch size of the diffusion/contrastive inputs; registration inputs use max(1, BATCHSIZE // 2).
BATCHSIZE = 2
# Number of diffusion steps (config "timesteps", passed as n_steps).
TIMESTEPS = 10
# Spatial dimensionality of the volumes (shapes are built as S x S x S).
NDIMS = 3
# Passed to DeformDDPM as v_scale.
V_SCALE = 5e-5
# Stored in the config as "noise_scale"; _build_ddpm does not read it.
NOISE_SCALE = 0.1
# Network name resolved through get_net / get_net_opt.
NET_NAME = "recmulmodmutattnnet"
# Adam learning rate.
LR = 1e-5
DEVICE = "cpu"
# Diffusion loss weights: [NCC (loss_ang), MRSE (loss_dist), Grad regulariser (loss_reg)].
LOSS_WEIGHTS_DIFF = [2.0, 2.0, 4.0]
# Registration loss weights: [LNCC similarity, LMSE, Grad regulariser (loss_reg1)].
LOSS_WEIGHTS_REGIST = [1.0, 0.05, 128]
# Weight of the (1 - cosine similarity) contrastive loss.
LOSS_WEIGHT_CONTRASTIVE = 1.0
# Small epsilon used in the mask targets (1 + MSK_EPS) and the sqrt loss weighting.
MSK_EPS = 0.01
# Tolerances used by assert_close (torch.allclose semantics).
ATOL = 1e-5
RTOL = 1e-4
def make_config():
    """Return the training-config dict shared by every test.

    Values come from the module-level test constants. ``_build_ddpm`` reads
    a subset of these keys; the rest (e.g. ``data_name``, ``ddf_pad_mode``)
    look like they mirror the real training config — verify before relying on them.
    """
    cfg = dict(
        data_name="test",
        net_name=NET_NAME,
        ndims=NDIMS,
        img_size=IMG_SIZE,
        batchsize=BATCHSIZE,
        timesteps=TIMESTEPS,
        v_scale=V_SCALE,
        noise_scale=NOISE_SCALE,
        num_input_chn=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        resample_mode="bilinear",
        lr=LR,
        epoch=1,
        epoch_per_save=1,
        condition_type="slice",
        device=DEVICE,
    )
    return cfg
def seed_all(seed=42):
    """Seed Python's, NumPy's and PyTorch's global RNGs with *seed*.

    CUDA generators are seeded as well when a GPU is visible, so two calls
    with the same seed reproduce the same subsequent random draws.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# ========================== Builders ==========================
def _build_ddpm(DDPMClass, config, use_opt_net=False):
    """Assemble one pipeline (model, warper, losses, optimizer) for the tests.

    Args:
        DDPMClass: DeformDDPM class to instantiate (original or optimized).
        config: dict as produced by ``make_config``.
        use_opt_net: use ``get_net_opt`` / ``OptSTN`` instead of
            ``get_net`` / ``STN``.

    Returns:
        ``(ddpm, ddf_stn, optimizer, loss_reg, loss_reg1, loss_dist, loss_ang)``;
        the optimizer is Adam over all ``ddpm`` parameters.
    """
    net_factory, stn_cls = (get_net_opt, OptSTN) if use_opt_net else (get_net, STN)
    Net = net_factory(config["net_name"])

    n_steps = config["timesteps"]
    ndims = config["ndims"]
    side = config["img_size"]
    network = Net(n_steps=n_steps, ndims=ndims, num_input_chn=config["num_input_chn"], res=side)

    ddpm = DDPMClass(
        network=network,
        n_steps=n_steps,
        image_chw=[1] + [side] * ndims,
        device=config["device"],
        batch_size=config["batchsize"],
        img_pad_mode=config["img_pad_mode"],
        v_scale=config["v_scale"],
    )
    ddf_stn = stn_cls(img_sz=side, ndims=ndims,
                      padding_mode=config["padding_mode"], device=config["device"])
    for module in (ddpm, ddf_stn):
        module.to(config["device"])

    # Two Grad regularisers that differ only in their out-of-range threshold.
    loss_reg = Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,
                    outrange_thresh=0.2, outrange_weight=1e3)
    loss_reg1 = Grad(penalty=['l1', 'negdetj', 'range'], ndims=ndims,
                     outrange_thresh=0.6, outrange_weight=1e3)
    loss_dist = MRSE(img_sz=side)
    loss_ang = NCC(img_sz=side)
    optimizer = torch.optim.Adam(ddpm.parameters(), lr=config["lr"])
    return ddpm, ddf_stn, optimizer, loss_reg, loss_reg1, loss_dist, loss_ang
def params_flat(module):
    """Return a detached copy of all parameters of *module* concatenated into one 1-D tensor."""
    pieces = [param.detach().clone().reshape(-1) for param in module.parameters()]
    return torch.cat(pieces)
def grads_flat(module):
    """Concatenate the gradients of all parameters of *module* into one 1-D tensor.

    A parameter with no gradient contributes zeros of its size, so the result
    lines up element-for-element with ``params_flat(module)``.
    """
    def _one(param):
        if param.grad is None:
            return torch.zeros_like(param.flatten())
        return param.grad.detach().clone().flatten()

    return torch.cat([_one(p) for p in module.parameters()])
def assert_close(name, a, b, atol=ATOL, rtol=RTOL):
    """Compare two values elementwise, print PASS/FAIL, and return the verdict.

    Args:
        name: label printed next to the verdict.
        a, b: tensors or Python numbers (e.g. ``loss.item()``).
        atol, rtol: tolerances forwarded to ``torch.allclose``.

    Returns:
        True when ``a`` and ``b`` have the same shape and are close, else False.
    """
    # Promote both sides to float64. Otherwise mixed int/float Python scalars
    # (or tensors of different dtypes) make torch.allclose raise, and Python
    # floats would be rounded through float32 before comparing.
    a = torch.as_tensor(a).detach().to(torch.float64)
    b = torch.as_tensor(b).detach().to(torch.float64)
    # torch.allclose broadcasts, so e.g. shape (1,) vs (N,) could "pass";
    # an equivalence check needs exactly matching shapes.
    if a.shape != b.shape:
        print(f" FAIL {name}: shape mismatch {tuple(a.shape)} vs {tuple(b.shape)}")
        return False
    if torch.allclose(a, b, atol=atol, rtol=rtol):
        print(f" PASS {name}")
        return True
    diff = (a - b).abs()
    print(f" FAIL {name}: max_diff={diff.max().item():.6e}, mean_diff={diff.mean().item():.6e}")
    return False
# ========================== Shared Data ==========================
def make_shared_data():
    """Build the deterministic dummy inputs shared by the training-mode tests.

    Returns:
        ``(x0, embd, blind_mask, t, x1, y1, embd_y, t_contra)``:
        ``x0``/``embd``/``t``/``t_contra`` have batch size ``BATCHSIZE``;
        ``x1``/``y1``/``embd_y`` have batch size ``max(1, BATCHSIZE // 2)``;
        ``blind_mask`` is an all-ones ``(1, 1, S, S, S)`` volume.
    """
    seed_all(123)
    S = IMG_SIZE
    x0 = torch.rand(BATCHSIZE, 1, S, S, S, dtype=torch.float32)
    embd = torch.randn(BATCHSIZE, 1024, dtype=torch.float32)
    blind_mask = torch.ones(1, 1, S, S, S, dtype=torch.float32)
    # One diffusion timestep per sample, kept in [0, TIMESTEPS). Length follows
    # BATCHSIZE; for BATCHSIZE=2, TIMESTEPS=10 this is exactly [3, 7].
    t = (torch.arange(BATCHSIZE) * 4 + 3) % TIMESTEPS
    B2 = max(1, BATCHSIZE // 2)
    x1 = torch.rand(B2, 1, S, S, S, dtype=torch.float32)
    y1 = torch.rand(B2, 1, S, S, S, dtype=torch.float32)
    embd_y = torch.randn(B2, 1024, dtype=torch.float32)
    # Contrastive-step timesteps; [2, 5] for the default config.
    t_contra = (torch.arange(BATCHSIZE) * 3 + 2) % TIMESTEPS
    return x0, embd, blind_mask, t, x1, y1, embd_y, t_contra
# ========================== Test: Loss Functions ==========================
def test_loss_equivalence():
    """Verify optimized LNCC/MSLNCC/LMSE match the originals in value and gradient.

    Inputs are random volumes from a fixed seed. LNCC is checked with and
    without a label mask, MSLNCC and LMSE with the mask. The gradients w.r.t.
    the first input are compared for LNCC and MSLNCC.

    Returns:
        True when every comparison passes. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: Loss Function Equivalence (LNCC, MSLNCC, LMSE)")
    print("=" * 60)
    seed_all(321)  # fixed inputs so a failure is reproducible
    S = IMG_SIZE
    I = torch.rand(1, 1, S, S, S)
    J = torch.rand(1, 1, S, S, S)
    label = (J > 0.3).float()
    orig_lncc, opt_lncc = OrigLNCC(), OptLNCC()
    orig_mslncc, opt_mslncc = OrigMSLNCC(), OptMSLNCC()
    orig_lmse, opt_lmse = OrigLMSE(), OptLMSE()

    ok = True
    # Forward values.
    ok &= assert_close("LNCC_forward",
                       orig_lncc(I, J, label=label).item(), opt_lncc(I, J, label=label).item())
    ok &= assert_close("LNCC_no_label", orig_lncc(I, J).item(), opt_lncc(I, J).item())
    ok &= assert_close("MSLNCC_forward",
                       orig_mslncc(I, J, label=label).item(), opt_mslncc(I, J, label=label).item())
    # LMSE is used by the registration mode, so check it here too.
    ok &= assert_close("LMSE_forward",
                       orig_lmse(I, J, label=label).item(), opt_lmse(I, J, label=label).item())

    # Gradients w.r.t. the first input.
    grad_cases = (
        ("LNCC_grad", orig_lncc, opt_lncc, {}),
        ("MSLNCC_grad", orig_mslncc, opt_mslncc, {"label": label}),
    )
    for case_name, orig_fn, opt_fn, kwargs in grad_cases:
        I_o = I.clone().requires_grad_(True)
        I_p = I.clone().requires_grad_(True)
        orig_fn(I_o, J, **kwargs).backward()
        opt_fn(I_p, J, **kwargs).backward()
        ok &= assert_close(case_name, I_o.grad, I_p.grad)

    print(f"\nLoss Equivalence: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "Loss equivalence failed (see FAIL lines above)"
    return ok
# ========================== Test: DeformDDPM Equivalence ==========================
def test_ddpm_equivalence():
    """Verify optimized DeformDDPM helper methods match the original.

    Checks ``proc_cond_img`` for each listed ``proc_type`` (image and ratio
    outputs) and ``_random_ddf_generate``, each call starting from an
    identical RNG state.

    Returns:
        True when every comparison passes. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: DeformDDPM Method Equivalence")
    print("=" * 60)
    config = make_config()
    seed_all(42)
    orig, *_ = _build_ddpm(OrigDeformDDPM, config)
    seed_all(42)
    opt, *_ = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
    # Sync weights
    opt.load_state_dict(orig.state_dict(), strict=False)
    ok = assert_close("init_weights", params_flat(orig.network), params_flat(opt.network))
    S = IMG_SIZE
    img = torch.rand(BATCHSIZE, 1, S, S, S)

    # proc_cond_img for each proc_type (the middle return value is not compared).
    for ptype in ['none', 'uncon', 'adding', 'downsample', 'slice', 'slice1', 'independ']:
        seed_all(200)
        r1, _, c1 = orig.proc_cond_img(img, proc_type=ptype)
        seed_all(200)
        r2, _, c2 = opt.proc_cond_img(img, proc_type=ptype)
        ok &= assert_close(f"proc_cond_{ptype}_img", r1, r2)
        ok &= assert_close(f"proc_cond_{ptype}_ratio", c1, c2)

    # _random_ddf_generate
    seed_all(300)
    ddf1, dddf1 = orig._random_ddf_generate(rec_num=1, mul_num=[torch.tensor([3]), torch.tensor([2])], select_num=2)
    seed_all(300)
    ddf2, dddf2 = opt._random_ddf_generate(rec_num=1, mul_num=[torch.tensor([3]), torch.tensor([2])], select_num=2)
    ok &= assert_close("random_ddf_ddf", ddf1, ddf2)
    ok &= assert_close("random_ddf_dddf", dddf1, dddf2)

    print(f"\nDeformDDPM Equivalence: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "DeformDDPM method equivalence failed (see FAIL lines above)"
    return ok
# ========================== Test: Mode 1 (Diffusion) ==========================
def test_mode1_diffusion():
    """Run one diffusion training step on both pipelines and compare.

    Both pipelines get the same noisy input, ground-truth DVF and conditioning
    image (drawn once from the original pipeline). Each step starts from the
    same RNG state. Compares the predicted DVF, total loss, gradients, and the
    weights after one Adam step.

    Returns:
        True when every comparison passes. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 1 — Diffusion Training Step")
    print("=" * 60)
    config = make_config()
    x0, embd, blind_mask, t, _, _, _, _ = make_shared_data()
    seed_all(42)
    orig, stn_o, opt_o, lr_o, _, ld_o, la_o = _build_ddpm(OrigDeformDDPM, config)
    seed_all(42)
    optm, stn_p, opt_p, lr_p, _, ld_p, la_p = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
    optm.load_state_dict(orig.state_dict(), strict=False)
    ok = assert_close("init_weights", params_flat(orig.network), params_flat(optm.network))

    # Pre-compute shared tensors
    seed_all(200)
    noisy_img, dvf_gt, _ = orig._get_random_ddf(x0, t)
    cond_img, _, cond_ratio = orig.proc_cond_img(x0, proc_type='none')

    def run_step(ddpm, stn, optimizer, loss_reg, loss_dist, loss_ang):
        """One diffusion update; returns (pred_dvf, total_loss, grads before the step)."""
        # Same RNG state for both pipelines. Without this, the second forward
        # sees whatever RNG state the first one left behind.
        seed_all(400)
        ddpm.network.train()
        pre_dvf = ddpm.network(x=noisy_img * blind_mask, y=cond_img, t=t, rec_num=2, text=embd)
        loss_ddf = loss_reg(pre_dvf, img=x0)
        trm = stn(pre_dvf, dvf_gt)
        loss_d = loss_dist(pred=trm, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
        loss_a = loss_ang(pred=trm, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
        loss_tot = (LOSS_WEIGHTS_DIFF[0] * loss_a + LOSS_WEIGHTS_DIFF[1] * loss_d
                    + LOSS_WEIGHTS_DIFF[2] * loss_ddf)
        loss_tot = torch.sqrt(torch.tensor(1. + MSK_EPS) - cond_ratio) * loss_tot
        optimizer.zero_grad()
        loss_tot.backward()
        grads = grads_flat(ddpm.network)
        optimizer.step()
        return pre_dvf, loss_tot, grads

    pre_dvf_o, lt_o, grad_o = run_step(orig, stn_o, opt_o, lr_o, ld_o, la_o)
    pre_dvf_p, lt_p, grad_p = run_step(optm, stn_p, opt_p, lr_p, ld_p, la_p)

    ok &= assert_close("pre_dvf", pre_dvf_o.detach(), pre_dvf_p.detach())
    ok &= assert_close("loss_tot", lt_o.item(), lt_p.item())
    ok &= assert_close("gradients", grad_o, grad_p)
    # NOTE: Adam's first step moves each weight by roughly LR, on the order of
    # ATOL, so this check is weak; "gradients" above is the sensitive one.
    ok &= assert_close("weights_after", params_flat(orig.network), params_flat(optm.network))
    print(f"\nMode 1 Diffusion: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "Mode 1 diffusion equivalence failed (see FAIL lines above)"
    return ok
# ========================== Test: Mode 2 (Contrastive) ==========================
def test_mode2_contrastive():
    """Run one contrastive (image-text alignment) step on both pipelines and compare.

    The loss is ``LOSS_WEIGHT_CONTRASTIVE * (1 - cos(img_embd, embd))``, with
    gradients clipped to norm 0.05 before the Adam step. Prints SKIP and
    returns True when the network does not expose ``img_embd``.

    Returns:
        True when every comparison passes. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 2 — Contrastive Training Step")
    print("=" * 60)
    config = make_config()
    x0, embd, blind_mask, _, _, _, _, t_contra = make_shared_data()
    seed_all(42)
    orig, _, opt_o, *_ = _build_ddpm(OrigDeformDDPM, config)
    seed_all(42)
    optm, _, opt_p, *_ = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
    optm.load_state_dict(orig.state_dict(), strict=False)
    ok = assert_close("init_weights", params_flat(orig.network), params_flat(optm.network))
    cond_img, _, _ = orig.proc_cond_img(x0, proc_type='none')
    x_in = (x0 * blind_mask).detach()
    y_in = cond_img.detach()

    def forward(ddpm):
        """Train-mode forward from a fixed RNG state; returns the network's img_embd (or None)."""
        seed_all(500)
        ddpm.network.train()
        ddpm.network(x=x_in, y=y_in, t=t_contra, text=embd.detach())
        return getattr(ddpm.network, 'img_embd', None)

    def update(ddpm, optimizer, img_embd):
        """Contrastive loss, backward, grad clipping, Adam step; returns (loss, clipped grads)."""
        loss = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd, dim=-1).mean())
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.05)
        grads = grads_flat(ddpm.network)
        optimizer.step()
        return loss, grads

    ie_o = forward(orig)
    if ie_o is None:
        print(" SKIP: network has no img_embd")
        return True
    lc_o, grad_o = update(orig, opt_o, ie_o)

    ie_p = forward(optm)
    assert ie_p is not None, "optimized network exposes no img_embd"
    lc_p, grad_p = update(optm, opt_p, ie_p)

    ok &= assert_close("img_embd", ie_o.detach(), ie_p.detach())
    ok &= assert_close("loss_contrastive", lc_o.item(), lc_p.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_p)
    ok &= assert_close("weights_after", params_flat(orig.network), params_flat(optm.network))
    print(f"\nMode 2 Contrastive: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "Mode 2 contrastive equivalence failed (see FAIL lines above)"
    return ok
# ========================== Test: Mode 3 (Registration) ==========================
def test_mode3_registration():
    """Run one registration step (DeformDDPM forward / diff_recover) on both pipelines and compare.

    Compares the composed DDF, the recovered image, each loss term (LNCC,
    LMSE, Grad), the total loss, the gradients clipped to norm 0.2, and the
    weights after one Adam step. Each forward starts from the same RNG state.

    Returns:
        True when every comparison passes. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: Mode 3 — Registration Training Step (diff_recover)")
    print("=" * 60)
    config = make_config()
    _, _, _, _, x1, y1, embd_y, _ = make_shared_data()
    seed_all(42)
    orig, _, opt_o, _, lr1_o, _, _ = _build_ddpm(OrigDeformDDPM, config)
    seed_all(42)
    optm, _, opt_p, _, lr1_p, _, _ = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
    optm.load_state_dict(orig.state_dict(), strict=False)
    ok = assert_close("init_weights", params_flat(orig.network), params_flat(optm.network))

    # Shared
    y1_proc, _, cond_ratio = orig.proc_cond_img(y1, proc_type='none')
    T_regist = sorted([9, 7, 5, 3, 2, 1], reverse=True)
    T_batched = [[tv for _ in range(max(1, BATCHSIZE // 2))] for tv in T_regist]
    thresh_imgsim = 0.01
    msk_tgt = torch.tensor(1.0) + MSK_EPS

    def run_step(ddpm, optimizer, loss_reg1, lncc, lmse):
        """One registration update; returns (ddf, rec, sim, mse, ddf_loss, total, clipped grads)."""
        # Same RNG state for both pipelines, in case the forward samples randomly.
        seed_all(600)
        ddpm.train()
        # Deep-copy the schedule so an in-place edit by one forward can't leak into the other.
        [ddf, _], [rec, _, _], _ = ddpm(
            img_org=x1, cond_imgs=y1_proc, T=[None, copy.deepcopy(T_batched)],
            proc_type=[], text=embd_y,
        )
        loss_sim = lncc(rec, y1, label=msk_tgt * (y1 > thresh_imgsim))
        loss_mse = lmse(rec, y1, label=msk_tgt * (y1 >= 0.0))
        loss_ddf = loss_reg1(ddf, img=y1)
        loss_regist = (LOSS_WEIGHTS_REGIST[0] * loss_sim + LOSS_WEIGHTS_REGIST[1] * loss_mse
                       + LOSS_WEIGHTS_REGIST[2] * loss_ddf)
        loss_regist = torch.sqrt(cond_ratio + MSK_EPS) * loss_regist
        optimizer.zero_grad()
        loss_regist.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.2)
        grads = grads_flat(ddpm.network)
        optimizer.step()
        return ddf, rec, loss_sim, loss_mse, loss_ddf, loss_regist, grads

    ddf_o, rec_o, ls_o, lm_o, ld_o, lr_o, grad_o = run_step(orig, opt_o, lr1_o, OrigLNCC(), OrigLMSE())
    ddf_p, rec_p, ls_p, lm_p, ld_p, lr_p, grad_p = run_step(optm, opt_p, lr1_p, OptLNCC(), OptLMSE())

    ok &= assert_close("ddf_comp", ddf_o.detach(), ddf_p.detach())
    ok &= assert_close("img_rec", rec_o.detach(), rec_p.detach())
    ok &= assert_close("loss_sim", ls_o.item(), ls_p.item())
    ok &= assert_close("loss_mse", lm_o.item(), lm_p.item())
    ok &= assert_close("loss_ddf", ld_o.item(), ld_p.item())
    ok &= assert_close("loss_regist", lr_o.item(), lr_p.item())
    ok &= assert_close("gradients_clipped", grad_o, grad_p)
    ok &= assert_close("weights_after", params_flat(orig.network), params_flat(optm.network))
    print(f"\nMode 3 Registration: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "Mode 3 registration equivalence failed (see FAIL lines above)"
    return ok
# ========================== Test: Full Sequence ==========================
def test_full_sequence():
    """Run diffusion -> contrastive -> registration steps on both pipelines and compare.

    After each step, both the gradients (clipped where the step clips) and the
    updated weights are compared. The gradients are the sensitive signal. One
    Adam step at this LR moves each weight by roughly LR, which is on the order
    of ATOL, so comparing weights alone could not detect diverging updates.

    Returns:
        True when every comparison passes. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: Full Step Sequence (Diffusion → Contrastive → Registration)")
    print("=" * 60)
    config = make_config()
    x0, embd, blind_mask, t, x1, y1, embd_y, t_contra = make_shared_data()
    seed_all(42)
    orig, stn_o, opt_o, lr_o, lr1_o, ld_o, la_o = _build_ddpm(OrigDeformDDPM, config)
    seed_all(42)
    optm, stn_p, opt_p, lr_p, lr1_p, ld_p, la_p = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
    optm.load_state_dict(orig.state_dict(), strict=False)
    ok = assert_close("init_weights", params_flat(orig.network), params_flat(optm.network))

    # Shared tensors
    seed_all(200)
    noisy_img, dvf_gt, _ = orig._get_random_ddf(x0, t)
    cond_diff, _, cr_diff = orig.proc_cond_img(x0, proc_type='none')
    y1_proc, _, cr_reg = orig.proc_cond_img(y1, proc_type='none')
    T_regist = sorted([9, 7, 5, 3, 2, 1], reverse=True)
    T_batched = [[tv for _ in range(max(1, BATCHSIZE // 2))] for tv in T_regist]
    x_in = (x0 * blind_mask).detach()
    y_in = cond_diff.detach()
    text_in = embd.detach()
    msk_tgt = torch.tensor(1.0) + MSK_EPS

    def diffusion_step(ddpm, stn, optimizer, loss_reg, loss_dist, loss_ang):
        """One diffusion update from a fixed RNG state; returns the gradients."""
        seed_all(400)
        ddpm.network.train()
        pdvf = ddpm.network(x=noisy_img * blind_mask, y=cond_diff, t=t, rec_num=2, text=embd)
        l_reg = loss_reg(pdvf, img=x0)
        warped = stn(pdvf, dvf_gt)
        l_dist = loss_dist(pred=warped, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
        l_ang = loss_ang(pred=warped, inv_lab=dvf_gt, ddf_stn=None, mask=blind_mask)
        loss = (LOSS_WEIGHTS_DIFF[0] * l_ang + LOSS_WEIGHTS_DIFF[1] * l_dist
                + LOSS_WEIGHTS_DIFF[2] * l_reg)
        loss = torch.sqrt(torch.tensor(1. + MSK_EPS) - cr_diff) * loss
        optimizer.zero_grad()
        loss.backward()
        grads = grads_flat(ddpm.network)
        optimizer.step()
        return grads

    def contrastive_step(ddpm, optimizer):
        """One contrastive update; returns clipped grads, or None if there is no img_embd."""
        seed_all(500)
        ddpm.network(x=x_in, y=y_in, t=t_contra, text=text_in)
        img_embd = getattr(ddpm.network, 'img_embd', None)
        if img_embd is None:
            return None
        loss = LOSS_WEIGHT_CONTRASTIVE * (1 - F.cosine_similarity(img_embd, embd, dim=-1).mean())
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.05)
        grads = grads_flat(ddpm.network)
        optimizer.step()
        return grads

    def registration_step(ddpm, optimizer, loss_reg1, lncc, lmse):
        """One registration update from a fixed RNG state; returns clipped grads."""
        seed_all(600)
        ddpm.train()
        # Deep-copy the schedule so an in-place edit by one forward can't leak into the other.
        [ddf, _], [rec, _, _], _ = ddpm(
            img_org=x1, cond_imgs=y1_proc, T=[None, copy.deepcopy(T_batched)],
            proc_type=[], text=embd_y)
        l_sim = lncc(rec, y1, label=msk_tgt * (y1 > 0.01))
        l_mse = lmse(rec, y1, label=msk_tgt * (y1 >= 0.0))
        l_ddf = loss_reg1(ddf, img=y1)
        loss = (LOSS_WEIGHTS_REGIST[0] * l_sim + LOSS_WEIGHTS_REGIST[1] * l_mse
                + LOSS_WEIGHTS_REGIST[2] * l_ddf)
        loss = torch.sqrt(cr_reg + MSK_EPS) * loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), max_norm=0.2)
        grads = grads_flat(ddpm.network)
        optimizer.step()
        return grads

    # ===== Step 1: Diffusion =====
    grad_o = diffusion_step(orig, stn_o, opt_o, lr_o, ld_o, la_o)
    grad_p = diffusion_step(optm, stn_p, opt_p, lr_p, ld_p, la_p)
    ok &= assert_close("grads_diffusion", grad_o, grad_p)
    ok &= assert_close("after_diffusion", params_flat(orig.network), params_flat(optm.network))

    # ===== Step 2: Contrastive (skipped when the network has no img_embd) =====
    grad_o = contrastive_step(orig, opt_o)
    if grad_o is not None:
        grad_p = contrastive_step(optm, opt_p)
        if grad_p is None:
            print(" FAIL grads_contrastive: optimized network has no img_embd")
            ok = False
        else:
            ok &= assert_close("grads_contrastive", grad_o, grad_p)
        ok &= assert_close("after_contrastive", params_flat(orig.network), params_flat(optm.network))

    # ===== Step 3: Registration =====
    grad_o = registration_step(orig, opt_o, lr1_o, OrigLNCC(), OrigLMSE())
    grad_p = registration_step(optm, opt_p, lr1_p, OptLNCC(), OptLMSE())
    ok &= assert_close("grads_registration", grad_o, grad_p)
    ok &= assert_close("after_registration", params_flat(orig.network), params_flat(optm.network))

    print(f"\nFull Sequence: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "Full-sequence equivalence failed (see FAIL lines above)"
    return ok
# ========================== Test: Checkpoint Compatibility ==========================
def test_checkpoint_compat():
    """Check that checkpoints load across the original and optimized pipelines.

    The model that receives each checkpoint is built from a different seed than
    the model that produced it. A passing comparison therefore shows the weights
    came from the checkpoint, not from an identical random initialization.

    Returns:
        True when both directions match. Also asserts, so pytest reports failures.
    """
    print("\n" + "=" * 60)
    print("TEST: Checkpoint Cross-Compatibility")
    print("=" * 60)
    import tempfile
    config = make_config()
    seed_all(42)
    orig, *_ = _build_ddpm(OrigDeformDDPM, config)
    seed_all(42)
    optm, *_ = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
    optm.load_state_dict(orig.state_dict(), strict=False)
    fresh_seed = 7  # must differ from 42, otherwise loading is not actually tested
    ok = True
    with tempfile.TemporaryDirectory() as tmpdir:
        # Original checkpoint -> fresh optimized model.
        path_o = os.path.join(tmpdir, "orig.pth")
        torch.save({'model_state_dict': orig.state_dict(), 'epoch': 0}, path_o)
        seed_all(fresh_seed)
        opt2, *_ = _build_ddpm(OptDeformDDPM, config, use_opt_net=True)
        ckpt = torch.load(path_o, map_location='cpu')
        opt2.load_state_dict(ckpt['model_state_dict'], strict=False)
        ok &= assert_close("orig→opt", params_flat(opt2.network), params_flat(orig.network))

        # Optimized checkpoint -> fresh original model.
        path_p = os.path.join(tmpdir, "opt.pth")
        torch.save({'model_state_dict': optm.state_dict(), 'epoch': 0}, path_p)
        seed_all(fresh_seed)
        orig2, *_ = _build_ddpm(OrigDeformDDPM, config)
        ckpt2 = torch.load(path_p, map_location='cpu')
        orig2.load_state_dict(ckpt2['model_state_dict'], strict=False)
        ok &= assert_close("opt→orig", params_flat(orig2.network), params_flat(optm.network))

    print(f"\nCheckpoint Compat: {'ALL PASSED' if ok else 'SOME FAILED'}")
    # pytest ignores return values, so assert; the return feeds the __main__ summary.
    assert ok, "Checkpoint cross-compatibility failed (see FAIL lines above)"
    return ok
# ========================== Main ==========================
if __name__ == "__main__":
    import traceback

    print("=" * 60)
    print("3-Modes Optimized vs Original Equivalence Test Suite")
    print(f"IMG_SIZE={IMG_SIZE}, BATCHSIZE={BATCHSIZE}, TIMESTEPS={TIMESTEPS}, NET={NET_NAME}")
    print("=" * 60)
    suite = [
        ("Loss Equivalence", test_loss_equivalence),
        ("DeformDDPM Methods", test_ddpm_equivalence),
        ("Mode 1: Diffusion", test_mode1_diffusion),
        ("Mode 2: Contrastive", test_mode2_contrastive),
        ("Mode 3: Registration", test_mode3_registration),
        ("Full Sequence", test_full_sequence),
        ("Checkpoint Compat", test_checkpoint_compat),
    ]
    results = {}
    for name, test_fn in suite:
        # Top-level runner boundary: a test that raises (a failed assert, or e.g.
        # a shape error inside a comparison) is recorded as FAIL with its traceback,
        # and the remaining tests still run so the summary is always printed.
        try:
            results[name] = bool(test_fn())
        except Exception:
            traceback.print_exc()
            results[name] = False
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for name, passed in results.items():
        status = "PASS" if passed else "FAIL"
        print(f" [{status}] {name}")
    all_ok = all(results.values())
    print(f"\nOverall: {'ALL TESTS PASSED' if all_ok else 'SOME TESTS FAILED'}")
    sys.exit(0 if all_ok else 1)