# Omini3D / tests/test_flexres_equivalence.py
# Last synced by maxmo2009 in commit 2af0e94 (verified):
#   "Sync from local: code + epoch-110 checkpoint, clean README"
"""
Equivalence tests: OM_reg_flexres.py (DeformDDPM) vs OM_reg_flexres_om.py (OMorpher).
Verifies that OMorpher.predict() + OMorpher.apply_def() produce the *exact same*
DDFs and warped images as DeformDDPM.diff_recover() + apply_ddf(), given
identical network weights and inputs.
These tests do NOT need real data or a trained checkpoint — they use random
weights and synthetic volumes.
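Run (from the repository root):
    source activate ~/rds/rds-airr-p51-TWhPgQVLKbA/Env/pub_env/pytorch-xpu
    python tests/test_flexres_equivalence.py
The Test* classes are plain pytest-style classes, so
`pytest tests/test_flexres_equivalence.py` collects them as well.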
"""
import os
import sys
import traceback
import numpy as np
import torch
import torch.nn.functional as F
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
from OMorpher import OMorpher
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
# ---------- shared config ----------
# 2D at 128 px keeps the suite fast; 3D at 128 also works but is much slower.
NDIMS = 2
IMG_SIZE = 128
TIMESTEPS = 10
NET_NAME = "recmutattnnet"
DEVICE = "cpu"
V_SCALE = 5e-5
NOISE_SCALE = 0.1

# Config handed to every OMorpher built in this file. Individual tests derive
# variants with {**BASE_CONFIG, "ndims": ..., "img_size": ...}.
BASE_CONFIG = dict(
    net_name=NET_NAME,
    ndims=NDIMS,
    img_size=IMG_SIZE,
    timesteps=TIMESTEPS,
    v_scale=V_SCALE,
    noise_scale=NOISE_SCALE,
    condition_type="none",
    num_input_chn=1,
    img_pad_mode="zeros",
    ddf_pad_mode="border",
    padding_mode="border",
    resample_mode="bilinear",
    batchsize=1,
    data_name="test",
    start_noise_step=0,
)
def _build_matching_pair():
    """Build an OMorpher and a DeformDDPM that share identical network weights.

    A fresh ``NET_NAME`` network is wrapped in a DeformDDPM created with
    ``inf_mode=True``. An OMorpher is then built from ``BASE_CONFIG`` without a
    checkpoint, and its network is loaded with a copy of the DeformDDPM
    network's ``state_dict``. ``ddpm.eval()`` and ``om.network.eval()`` are
    called before returning.

    Returns:
        tuple: ``(om, ddpm)`` -- the OMorpher and the DeformDDPM.
    """
    Net = get_net(NET_NAME)
    network = Net(
        n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE,
    )
    ddpm = DeformDDPM(
        network=network,
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=DEVICE,
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=V_SCALE,
        inf_mode=True,  # matches OM_reg_flexres.py
    )
    ddpm.eval()
    om = OMorpher(config=BASE_CONFIG, checkpoint_path=None, device=DEVICE)
    # Copy weights from DeformDDPM's network so both paths run identical parameters.
    om.network.load_state_dict(ddpm.network.state_dict())
    om.network.eval()
    return om, ddpm
# ================================================================
# Test 1: apply_ddf equivalence
#
# OMorpher.apply_def(img, ddf) vs standalone apply_ddf() from
# OM_reg_flexres.py at multiple resolutions.
# ================================================================
class TestApplyDDFEquivalence:
    """Check OMorpher's DDF warping against the standalone ``apply_ddf``.

    ``_apply_ddf_reference`` mirrors ``apply_ddf()`` from OM_reg_flexres.py;
    every test here compares it with OMorpher on 3-D volumes.
    """

    @staticmethod
    def _apply_ddf_reference(volume_tensor, ddf, padding_mode="border", resample_mode="bilinear"):
        """Exact copy of apply_ddf() from OM_reg_flexres.py (3-D only).

        The DDF is scaled by the spatial size, added to the identity voxel
        grid, normalised with (size - 1) / 2, moved to channels-last and
        reversed along the coordinate axis before ``F.grid_sample``.
        """
        dev = ddf.device
        n_dims = 3
        spatial = list(volume_tensor.shape[2:])
        channel_first = [1, n_dims] + [1] * n_dims
        channel_last = [1] + [1] * n_dims + [n_dims]

        size_t = torch.tensor(spatial, dtype=torch.float32, device=dev).reshape(channel_first)
        axes = [torch.arange(s, device=dev) for s in spatial]
        identity = torch.stack(torch.meshgrid(axes, indexing="ij"), 0).reshape([1, n_dims] + spatial)
        half_extent = torch.tensor(
            [(s - 1) / 2.0 for s in spatial], dtype=torch.float32, device=dev,
        ).reshape(channel_last)

        voxel_pos = ddf * size_t + identity
        moved = voxel_pos.permute([0] + list(range(2, 2 + n_dims)) + [1])
        sample_grid = torch.flip(moved / half_extent - 1, dims=[-1])
        return F.grid_sample(
            volume_tensor, sample_grid.float(), mode=resample_mode,
            padding_mode=padding_mode, align_corners=True,
        )

    @staticmethod
    def _omorpher_3d():
        """Return an OMorpher configured for 32^3 3-D volumes, without a checkpoint."""
        cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
        return OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)

    def test_same_resolution_3d(self):
        """apply_def at model resolution matches reference."""
        om = self._omorpher_3d()
        vol = torch.rand(1, 1, 32, 32, 32, device=DEVICE)
        ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01
        got = om._apply_ddf(vol, ddf, padding_mode="border")
        want = self._apply_ddf_reference(vol, ddf, padding_mode="border")
        assert torch.allclose(got, want, atol=1e-6), (
            f"Max diff: {(got - want).abs().max().item()}"
        )

    def test_upscaled_ddf_3d(self):
        """apply_def with DDF upscaling matches reference when DDF is manually upscaled."""
        om = self._omorpher_3d()
        vol = torch.rand(1, 1, 64, 64, 64, device=DEVICE)
        coarse_ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01
        # OMorpher is expected to upscale the DDF on its own.
        got = om.apply_def(img=vol, ddf=coarse_ddf, padding_mode="border")
        # Reference: upscale explicitly, then warp.
        fine_ddf = F.interpolate(coarse_ddf, size=[64, 64, 64],
                                 mode="trilinear", align_corners=False)
        want = self._apply_ddf_reference(vol, fine_ddf, padding_mode="border")
        assert torch.allclose(got, want, atol=1e-6), (
            f"Max diff: {(got - want).abs().max().item()}"
        )

    def test_mask_nearest_3d(self):
        """apply_def with nearest-neighbor resampling matches reference."""
        om = self._omorpher_3d()
        mask = (torch.rand(1, 1, 32, 32, 32, device=DEVICE) > 0.5).float()
        ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01
        warp_kwargs = {"padding_mode": "zeros", "resample_mode": "nearest"}
        got = om._apply_ddf(mask, ddf, **warp_kwargs)
        want = self._apply_ddf_reference(mask, ddf, **warp_kwargs)
        assert torch.allclose(got, want, atol=1e-6)
# ================================================================
# Test 2: center_pad_to_cube equivalence
# ================================================================
class TestCenterPadEquivalence:
    """OMorpher._center_pad_to_cube must agree with the OM_reg_flexres.py padding."""

    @staticmethod
    def _center_pad_reference(volume):
        """Exact copy from OM_reg_flexres.py.

        Zero-pads the first three axes up to the largest of them, splitting the
        padding evenly (an odd remainder goes after the data). Any trailing
        axes are left unpadded.
        """
        spatial = volume.shape[:3]
        side = max(spatial)

        def split(total):
            head = total // 2
            return head, total - head

        widths = [split(side - s) for s in spatial] + [(0, 0)] * (volume.ndim - 3)
        return np.pad(volume, widths, mode="constant", constant_values=0)

    def _check(self, shape):
        """Pad a random float32 volume of *shape* both ways and require equality."""
        vol = np.random.rand(*shape).astype(np.float32)
        padded_om = OMorpher._center_pad_to_cube(vol)
        padded_ref = self._center_pad_reference(vol)
        assert np.array_equal(padded_om, padded_ref)

    def test_anisotropic(self):
        self._check((30, 40, 50))

    def test_isotropic(self):
        self._check((40, 40, 40))

    def test_4d(self):
        self._check((30, 40, 50, 3))
# ================================================================
# Test 2b: Label standardization equivalence
#
# OMorpher._standardize_label matches the manual resize + tensor
# creation from OM_reg_flexres.py.
# ================================================================
class TestLabelStandardizationEquivalence:
    """OMorpher._standardize_label must reproduce the manual label pipeline
    (center-pad, nearest-neighbour resize, tensor conversion)."""

    @staticmethod
    def _center_pad_reference(volume):
        """Zero-pad the first three axes to a cube, centring the data."""
        spatial = volume.shape[:3]
        side = max(spatial)
        widths = [((side - s) // 2, (side - s) - (side - s) // 2) for s in spatial]
        widths.extend((0, 0) for _ in range(volume.ndim - 3))
        return np.pad(volume, widths, mode="constant", constant_values=0)

    @staticmethod
    def _omorpher_3d():
        """Return an OMorpher configured for 32^3 3-D volumes, without a checkpoint."""
        cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
        return OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)

    def test_3d_label(self):
        """Single-channel label matches manual resize + tensorify."""
        from skimage.transform import resize

        om = self._omorpher_3d()
        om.set_init_img(torch.rand(1, 1, 32, 32, 32))
        label = (np.random.rand(30, 40, 50) > 0.5).astype(np.float32)
        model_t, fullres_t = om._standardize_label(label)

        # Reference pipeline, as written out by hand in OM_reg_flexres.py.
        padded = self._center_pad_reference(label)
        resized = resize(padded, [32, 32, 32],
                         anti_aliasing=False, preserve_range=True, order=0)
        model_ref = torch.tensor(resized[None, None], dtype=torch.float32)  # [1, 1, D, H, W]
        fullres_ref = torch.tensor(padded[None, None, ...], dtype=torch.float32)

        model_cpu = model_t.cpu()
        assert torch.allclose(model_cpu, model_ref, atol=1e-6), (
            f"Model-res label max diff: {(model_cpu - model_ref).abs().max().item()}"
        )
        fullres_cpu = fullres_t.cpu()
        assert torch.allclose(fullres_cpu, fullres_ref, atol=1e-6), (
            f"Fullres label max diff: {(fullres_cpu - fullres_ref).abs().max().item()}"
        )

    def test_none_placeholder(self):
        """None label produces -1 filled tensors matching manual placeholder."""
        om = self._omorpher_3d()
        om._init_img_raw = torch.zeros([1, 1, 48, 48, 48])
        model_t, fullres_t = om._standardize_label(None)
        assert model_t.shape == (1, 1, 32, 32, 32)
        assert fullres_t.shape == (1, 1, 48, 48, 48)
        assert (model_t == -1).all()
        assert (fullres_t == -1).all()
# ================================================================
# Test 3: Full diff_recover loop equivalence
#
# Most critical test: verifies that OMorpher.predict() produces the
# same DDF as DeformDDPM.diff_recover() given identical inputs,
# weights, and deterministic seeding.
# ================================================================
class TestDiffRecoverEquivalence:
    """OMorpher.predict() matches DeformDDPM.diff_recover() for the iterative
    reverse-diffusion loop.

    Each test resets torch / numpy / ``random`` to the same seed right before
    each of the two paths. Any stochastic step inside the reverse loop then
    consumes an identical RNG stream in both implementations.
    """

    @staticmethod
    def _seed_all(seed):
        """Seed torch, numpy and the stdlib ``random`` module with *seed*."""
        import random
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def test_no_initial_noise(self):
        """T=[None, timesteps] -- no forward diffusion, full reverse loop.
        This is the exact mode used in OM_reg_flexres.py.
        """
        om, ddpm = _build_matching_pair()
        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        cond = img.clone().detach()  # self-conditioning (common in inference)

        # --- DeformDDPM path (original) ---
        self._seed_all(42)
        with torch.no_grad():
            [ddf_comp_ddpm, _], [img_rec_ddpm, _, _], _ = ddpm.diff_recover(
                img_org=img,
                cond_imgs=cond,
                msk_org=None,
                T=[None, TIMESTEPS],
                v_scale=V_SCALE,
                t_save=None,
                proc_type="none",
            )

        # --- OMorpher path (new) ---
        self._seed_all(42)
        om.set_init_img(img)
        om.set_cond_img(cond)
        om.predict(T=[None, TIMESTEPS], proc_type="none")
        ddf_comp_om = om.get_def()
        # Reconstruct the image from the DDF the same way the original does:
        # img_rec = img_stn(img_org, ddf_comp)
        img_rec_om = om.img_stn(img.clone().detach(), ddf_comp_om)

        # --- Compare ---
        assert ddf_comp_om.shape == ddf_comp_ddpm.shape, (
            f"DDF shape mismatch: {ddf_comp_om.shape} vs {ddf_comp_ddpm.shape}"
        )
        assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
            f"DDF max diff: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
        )
        assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
            f"Reconstructed image max diff: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
        )

    def test_with_initial_noise(self):
        """T=[5, timesteps] -- forward diffusion at t=5, then reverse loop.
        Tests the augmentation path where the image is first deformed
        randomly before recovery.
        """
        om, ddpm = _build_matching_pair()
        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        cond = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        t_start = 5

        # Draw the forward-diffusion DDF once and hand the same tensor to both paths.
        self._seed_all(77)
        _, _, ddf_rand = ddpm._get_random_ddf(
            img, torch.tensor([t_start], device=DEVICE),
        )

        # --- DeformDDPM path ---
        self._seed_all(42)
        with torch.no_grad():
            [ddf_comp_ddpm, _], [img_rec_ddpm, _, _], _ = ddpm.diff_recover(
                img_org=img,
                cond_imgs=cond,
                msk_org=None,
                T=[t_start, TIMESTEPS],
                ddf_rand=ddf_rand.clone(),
                t_save=None,
                proc_type="none",
            )

        # --- OMorpher path ---
        self._seed_all(42)
        # Set init image first, then the pre-computed initial DDF.
        om.set_init_img(img)
        om._init_ddf = ddf_rand.clone()
        om.set_cond_img(cond)
        # predict with T that triggers the "init_ddf is not zero" branch
        om.predict(T=[t_start, TIMESTEPS], proc_type="none")
        ddf_comp_om = om.get_def()
        img_rec_om = om.img_stn(img.clone().detach(), ddf_comp_om)

        # --- Compare ---
        assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
            f"DDF max diff with initial noise: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
        )
        assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
            f"Image max diff with initial noise: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
        )

    def test_with_conditioning_types(self):
        """Test equivalence across different proc_types used in OM_reg_flexres.py."""
        om, ddpm = _build_matching_pair()
        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        cond = img.clone().detach()

        for proc_type in ("none", "uncon", "slice"):
            # Identical RNG state for both paths.
            self._seed_all(42)
            with torch.no_grad():
                [ddf_comp_ddpm, _], _, _ = ddpm.diff_recover(
                    img_org=img, cond_imgs=cond, msk_org=None,
                    T=[None, TIMESTEPS], proc_type=proc_type,
                )

            self._seed_all(42)
            om.set_init_img(img)
            om.set_cond_img(cond)
            om.predict(T=[None, TIMESTEPS], proc_type=proc_type)
            ddf_comp_om = om.get_def()

            assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
                f"DDF mismatch for proc_type={proc_type}: "
                f"max diff = {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
            )
# ================================================================
# Test 4: Full-resolution warping equivalence
#
# Verifies the key operation in OM_reg_flexres.py:
# 1. Run diffusion at model_res → get ddf_comp
# 2. Upscale DDF to full_res
# 3. Apply to full-res image
# ================================================================
class TestFullResWarpEquivalence:
    """OMorpher.apply_def(fullres_img, model_ddf) matches the manual
    upscale + apply_ddf from OM_reg_flexres.py.

    In the manual path, the model-resolution DDF is trilinearly upsampled to
    the full-resolution grid and then warped with the reference ``apply_ddf``.
    OMorpher.apply_def instead receives the model-resolution DDF directly.
    """

    @staticmethod
    def _apply_ddf_reference(volume_tensor, ddf, padding_mode="border", resample_mode="bilinear"):
        """Copy of apply_ddf() from OM_reg_flexres.py (3-D volumes)."""
        device = ddf.device
        ndims = 3
        dims = list(volume_tensor.shape[2:])

        def as_float_tensor(values, shape):
            return torch.tensor(values, dtype=torch.float32, device=device).reshape(shape)

        scale = as_float_tensor(dims, [1, ndims] + [1] * ndims)
        mesh = torch.meshgrid([torch.arange(d, device=device) for d in dims], indexing="ij")
        base = torch.stack(mesh, 0).reshape([1, ndims] + dims)
        half = as_float_tensor([(d - 1) / 2.0 for d in dims], [1] + [1] * ndims + [ndims])

        to_channels_last = [0] + list(range(2, 2 + ndims)) + [1]
        grid = torch.flip((ddf * scale + base).permute(to_channels_last) / half - 1, dims=[-1])
        return F.grid_sample(volume_tensor, grid.float(), mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    @staticmethod
    def _upsample(ddf, size):
        """Trilinearly resize a 3-D DDF to ``size`` voxels per axis (align_corners=False)."""
        return F.interpolate(ddf, size=[size] * 3, mode="trilinear", align_corners=False)

    def test_fullres_warp(self):
        """Simulate the exact OM_reg_flexres.py full-res warping pipeline."""
        om = OMorpher(config={**BASE_CONFIG, "ndims": 3, "img_size": 32},
                      checkpoint_path=None, device=DEVICE)
        model_sz, full_sz = 32, 64
        # Synthetic model-resolution DDF, standing in for predict() output.
        ddf_model = torch.randn(1, 3, *[model_sz] * 3, device=DEVICE) * 0.02
        fullres_img = torch.rand(1, 1, *[full_sz] * 3, device=DEVICE)

        # OM_reg_flexres.py path: explicit upscale, then warp.
        expected = self._apply_ddf_reference(fullres_img, self._upsample(ddf_model, full_sz))
        # OMorpher path: DDF is upscaled internally.
        actual = om.apply_def(img=fullres_img, ddf=ddf_model, padding_mode="border")

        assert torch.allclose(actual, expected, atol=1e-6), (
            f"Full-res warp max diff: {(actual - expected).abs().max().item()}"
        )

    def test_fullres_mask_nearest(self):
        """Mask warping with nearest-neighbor at full resolution."""
        om = OMorpher(config={**BASE_CONFIG, "ndims": 3, "img_size": 32},
                      checkpoint_path=None, device=DEVICE)
        model_sz, full_sz = 32, 48
        ddf_model = torch.randn(1, 3, *[model_sz] * 3, device=DEVICE) * 0.02
        fullres_mask = (torch.rand(1, 1, *[full_sz] * 3, device=DEVICE) > 0.5).float()

        expected = self._apply_ddf_reference(
            fullres_mask, self._upsample(ddf_model, full_sz),
            padding_mode="zeros", resample_mode="nearest",
        )
        actual = om.apply_def(
            img=fullres_mask, ddf=ddf_model,
            padding_mode="zeros", resample_mode="nearest",
        )

        assert torch.allclose(actual, expected, atol=1e-6), (
            f"Mask full-res max diff: {(actual - expected).abs().max().item()}"
        )
# ================================================================
# Test 5: Checkpoint loading equivalence
#
# Verifies that OMorpher strips DDP/DeformDDPM prefixes correctly
# and produces the same outputs as a DeformDDPM loaded from the
# same checkpoint.
# ================================================================
class TestCheckpointLoadEquivalence:
    """OMorpher loads from a DeformDDPM-format checkpoint and produces
    the same results."""

    @staticmethod
    def _seed_all(seed):
        """Seed torch, numpy and stdlib ``random`` so both inference paths share an RNG stream."""
        import random
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def test_round_trip(self):
        """Save a DeformDDPM checkpoint, load it in OMorpher, verify outputs match.

        The checkpoint is written inside a ``TemporaryDirectory``, so the file
        and its directory are removed even when an assertion fails.
        """
        import tempfile

        Net = get_net(NET_NAME)
        network = Net(n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE)
        ddpm = DeformDDPM(
            network=network, n_steps=TIMESTEPS,
            image_chw=[1] + [IMG_SIZE] * NDIMS, device=DEVICE,
            batch_size=1, img_pad_mode="zeros", ddf_pad_mode="border",
            padding_mode="border", v_scale=V_SCALE,
        )
        ddpm.eval()

        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save checkpoint in standard format (with DeformDDPM wrapper keys).
            ckpt_path = os.path.join(tmp_dir, "test_ckpt.pth")
            torch.save({
                "model_state_dict": ddpm.state_dict(),
                "optimizer_state_dict": None,
                "epoch": 0,
            }, ckpt_path)

            # Load in OMorpher. Everything below stays inside the context in
            # case OMorpher defers reading the file.
            om = OMorpher(config=BASE_CONFIG, checkpoint_path=ckpt_path, device=DEVICE)

            # Both networks must expose exactly the same tensors, with identical values.
            om_state = om.network.state_dict()
            ddpm_state = ddpm.network.state_dict()
            assert om_state.keys() == ddpm_state.keys(), (
                f"State-dict key mismatch: "
                f"only in OMorpher={sorted(om_state.keys() - ddpm_state.keys())}, "
                f"only in DeformDDPM={sorted(ddpm_state.keys() - om_state.keys())}"
            )
            for k, v in om_state.items():
                assert torch.equal(v, ddpm_state[k]), f"Weight mismatch at {k}"

            # Verify inference output matches under an identical RNG state.
            img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
            cond = img.clone()

            self._seed_all(42)
            with torch.no_grad():
                [ddf_ddpm, _], _, _ = ddpm.diff_recover(
                    img_org=img, cond_imgs=cond, msk_org=None,
                    T=[None, TIMESTEPS], proc_type="none",
                )

            self._seed_all(42)
            om.set_init_img(img)
            om.set_cond_img(cond)
            om.predict(T=[None, TIMESTEPS], proc_type="none")
            ddf_om = om.get_def()

            assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), (
                f"Post-checkpoint DDF max diff: {(ddf_om - ddf_ddpm).abs().max().item()}"
            )
# ================================================================
# Test 6: Augmentation equivalence (OM_aug.py)
#
# Verifies that the OMorpher-based augmentation sequence from
# OM_aug_om.py produces the same outputs as the DeformDDPM-based
# diff_recover() used in OM_aug.py.
# ================================================================
class TestAugEquivalence:
    """The OMorpher augmentation recipe (OM_aug_om.py) must reproduce
    DeformDDPM.diff_recover() as used by OM_aug.py."""

    @staticmethod
    def _seed(seed):
        """Reset torch, numpy and stdlib ``random`` to *seed*, in that order."""
        import random as random_mod
        torch.manual_seed(seed)
        np.random.seed(seed)
        random_mod.seed(seed)

    @staticmethod
    def _rand_volume():
        """Uniform-random [1, 1, *spatial] tensor at the shared test resolution."""
        return torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)

    def test_aug_roundtrip(self):
        """Full augmentation iteration: same seed + same weights ->
        OMorpher produces same img_rec, msk_rec, img_diff, msk_diff
        as diff_recover().
        Mirrors the OM_aug.py flow:
        1. Self-condition on the input image.
        2. Forward-diffuse at noise_step to get (img_diff, ddf_rand).
        3. Warp the mask with ddf_rand to get msk_diff.
        4. Reverse-diffuse from ddf_rand to get ddf_comp.
        5. Warp image and mask with ddf_comp to get img_rec and msk_rec.
        """
        om, ddpm = _build_matching_pair()
        img = self._rand_volume()
        mask = (self._rand_volume() > 0.5).float()
        noise_step = 5

        # One random DDF, shared by both paths.
        self._seed(99)
        _, _, ddf_rand = ddpm._get_random_ddf(
            img, torch.tensor([noise_step], device=DEVICE),
        )

        # DeformDDPM path (OM_aug.py): cond_imgs=None means self-conditioning.
        self._seed(42)
        with torch.no_grad():
            [ddf_comp_ddpm, _], \
                [img_rec_ddpm, img_diff_ddpm, _], \
                [msk_rec_ddpm, msk_diff_ddpm, _] = ddpm.diff_recover(
                    img_org=img,
                    cond_imgs=None,
                    msk_org=mask,
                    T=[noise_step, TIMESTEPS],
                    ddf_rand=ddf_rand.clone(),
                    t_save=None,
                    proc_type="none",
                )

        # OMorpher path (OM_aug_om.py): explicit self-conditioning + initial DDF.
        self._seed(42)
        om.set_init_img(img)
        om.set_cond_img(img)
        om.set_init_def(ddf=ddf_rand.clone().detach())
        om.predict(T=[noise_step, TIMESTEPS], proc_type="none")

        ddf_comp_om = om.get_def()
        img_rec_om = om.apply_def(img=img, ddf=ddf_comp_om, padding_mode="zeros")
        msk_rec_om = om.apply_def(
            img=mask, ddf=ddf_comp_om,
            padding_mode="zeros", resample_mode="nearest",
        )
        # Forward-diffused outputs come straight from the shared ddf_rand.
        img_diff_om = om.img_stn(img.clone().detach(), ddf_rand)
        msk_diff_om = om.msk_stn(mask.clone().detach(), ddf_rand)

        comparisons = (
            ("DDF", ddf_comp_om, ddf_comp_ddpm),
            ("img_rec", img_rec_om, img_rec_ddpm),
            ("msk_rec", msk_rec_om, msk_rec_ddpm),
            ("img_diff", img_diff_om, img_diff_ddpm),
            ("msk_diff", msk_diff_om, msk_diff_ddpm),
        )
        for label, ours, theirs in comparisons:
            assert torch.allclose(ours, theirs, atol=1e-5), (
                f"{label} max diff: {(ours - theirs).abs().max().item()}"
            )

    def test_noisy_mask(self):
        """om.apply_def(mask, ddf_rand, zeros, nearest) matches msk_stn(mask, ddf_rand)."""
        om, ddpm = _build_matching_pair()
        mask = (self._rand_volume() > 0.5).float()
        ddf = torch.randn([1, NDIMS] + [IMG_SIZE] * NDIMS, device=DEVICE) * 0.01
        expected = ddpm.msk_stn(mask, ddf)
        actual = om.apply_def(
            img=mask, ddf=ddf,
            padding_mode="zeros", resample_mode="nearest",
        )
        assert torch.allclose(actual, expected, atol=1e-6), (
            f"Noisy mask max diff: {(actual - expected).abs().max().item()}"
        )

    def test_self_conditioning(self):
        """Self-conditioning: set_cond_img(img) matches diff_recover default cond_imgs=None."""
        om, ddpm = _build_matching_pair()
        img = self._rand_volume()

        # DeformDDPM with cond_imgs=None (implicit self-conditioning).
        self._seed(42)
        with torch.no_grad():
            [ddf_ddpm, _], _, _ = ddpm.diff_recover(
                img_org=img, cond_imgs=None, msk_org=None,
                T=[None, TIMESTEPS], proc_type="none",
            )

        # OMorpher with explicit set_cond_img(img).
        self._seed(42)
        om.set_init_img(img)
        om.set_cond_img(img)
        om.predict(T=[None, TIMESTEPS], proc_type="none")
        ddf_om = om.get_def()

        assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), (
            f"Self-cond DDF max diff: {(ddf_om - ddf_ddpm).abs().max().item()}"
        )
# ================================================================
# Runner
# ================================================================
def run_all():
    """Run every ``test*`` method of this file's test classes without pytest.

    Each class is instantiated once and its ``test*`` methods run in sorted
    name order. A failing test is reported with its traceback, but the run
    continues.

    Returns:
        bool: True when no test raised.
    """
    suites = (
        TestApplyDDFEquivalence,
        TestCenterPadEquivalence,
        TestLabelStandardizationEquivalence,
        TestDiffRecoverEquivalence,
        TestFullResWarpEquivalence,
        TestCheckpointLoadEquivalence,
        TestAugEquivalence,
    )
    n_passed = 0
    failures = []
    for suite in suites:
        instance = suite()
        test_names = [attr for attr in sorted(dir(instance)) if attr.startswith("test")]
        for attr in test_names:
            label = f"{suite.__name__}.{attr}"
            try:
                getattr(instance, attr)()
                n_passed += 1
                print(f" PASS {label}")
            except Exception as exc:  # top-level runner: report and keep going
                failures.append((label, exc))
                print(f" FAIL {label}: {exc}")
                traceback.print_exc()

    n_failed = len(failures)
    print(f"\n{'=' * 60}")
    print(f"Results: {n_passed} passed, {n_failed} failed out of {n_passed + n_failed}")
    if failures:
        print("Failures:")
        for label, exc in failures:
            print(f" - {label}: {exc}")
    return not failures


if __name__ == "__main__":
    sys.exit(0 if run_all() else 1)