| """
|
| Equivalence tests: OM_reg_flexres.py (DeformDDPM) vs OM_reg_flexres_om.py (OMorpher).
|
|
|
| Verifies that OMorpher.predict() + OMorpher.apply_def() produce the *exact same*
|
| DDFs and warped images as DeformDDPM.diff_recover() + apply_ddf(), given
|
| identical network weights and inputs.
|
|
|
| These tests do NOT need real data or a trained checkpoint — they use random
|
| weights and synthetic volumes.
|
|
|
| Run:
|
| source activate ~/rds/rds-airr-p51-TWhPgQVLKbA/Env/pub_env/pytorch-xpu
|
| python tests/test_flexres_equivalence.py
|
| """
|
|
|
| import os
|
| import sys
|
| import traceback
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
# Repository root = parent of the directory containing this test file. It is
# put at the front of the import path so the project imports that follow
# (OMorpher, Diffusion.*) resolve when the file is run directly as a script.
ROOT = os.path.dirname(
    os.path.dirname(
        os.path.abspath(__file__)
    )
)
sys.path[:0] = [ROOT]
|
|
|
| from OMorpher import OMorpher
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
|
|
|
|
|
|
|
|
|
# Model / diffusion hyper-parameters shared by every test in this module
# (2-D, 128 px images, 10 diffusion steps, CPU).
NDIMS = 2
IMG_SIZE = 128
TIMESTEPS = 10
NET_NAME = "recmutattnnet"
DEVICE = "cpu"
V_SCALE = 5e-5
NOISE_SCALE = 0.1

# OMorpher configuration. The padding modes and v_scale agree with the
# DeformDDPM constructor arguments used in ``_build_matching_pair`` so both
# models are configured the same way.
BASE_CONFIG = dict(
    net_name=NET_NAME,
    ndims=NDIMS,
    img_size=IMG_SIZE,
    timesteps=TIMESTEPS,
    v_scale=V_SCALE,
    noise_scale=NOISE_SCALE,
    condition_type="none",
    num_input_chn=1,
    img_pad_mode="zeros",
    ddf_pad_mode="border",
    padding_mode="border",
    resample_mode="bilinear",
    batchsize=1,
    data_name="test",
    start_noise_step=0,
)
|
|
|
|
|
def _build_matching_pair():
    """Build an OMorpher and a DeformDDPM that share identical network weights.

    A freshly (randomly) initialised network is created with
    ``get_net(NET_NAME)`` and wrapped in a ``DeformDDPM`` configured from the
    module constants. An ``OMorpher`` is then built from ``BASE_CONFIG``
    without a checkpoint, and its network weights are overwritten with the
    DDPM's. Any output difference between the two therefore comes from the
    inference code path, not from the weights.

    The DDPM and the OMorpher network are both switched to eval mode.

    Returns:
        tuple: ``(om, ddpm)`` - the ``OMorpher`` instance and the
        ``DeformDDPM`` instance.
    """
    Net = get_net(NET_NAME)
    network = Net(
        n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE,
    )
    ddpm = DeformDDPM(
        network=network,
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=DEVICE,
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=V_SCALE,
        inf_mode=True,
    )
    ddpm.eval()

    om = OMorpher(config=BASE_CONFIG, checkpoint_path=None, device=DEVICE)

    # Replace OMorpher's own random initialisation with the DDPM's weights
    # (strict load: both networks must have the same parameter names).
    om.network.load_state_dict(ddpm.network.state_dict())
    om.network.eval()

    return om, ddpm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class TestApplyDDFEquivalence:
|
| """OMorpher._apply_ddf matches the standalone apply_ddf from OM_reg_flexres.py."""
|
|
|
| @staticmethod
|
| def _apply_ddf_reference(volume_tensor, ddf, padding_mode="border", resample_mode="bilinear"):
|
| """Exact copy of apply_ddf() from OM_reg_flexres.py."""
|
| device = ddf.device
|
| ndims = 3
|
| img_sz = list(volume_tensor.shape[2:])
|
| max_sz = torch.reshape(
|
| torch.tensor(img_sz, dtype=torch.float32, device=device),
|
| [1, ndims] + [1] * ndims)
|
| ref_grid = torch.reshape(
|
| torch.stack(torch.meshgrid(
|
| [torch.arange(s, device=device) for s in img_sz], indexing="ij"), 0),
|
| [1, ndims] + img_sz)
|
| img_shape = torch.reshape(
|
| torch.tensor([(s - 1) / 2.0 for s in img_sz], dtype=torch.float32, device=device),
|
| [1] + [1] * ndims + [ndims])
|
| grid = torch.flip(
|
| (ddf * max_sz + ref_grid).permute(
|
| [0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1,
|
| dims=[-1])
|
| return F.grid_sample(volume_tensor, grid.float(), mode=resample_mode,
|
| padding_mode=padding_mode, align_corners=True)
|
|
|
| def test_same_resolution_3d(self):
|
| """apply_def at model resolution matches reference."""
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
| vol = torch.rand(1, 1, 32, 32, 32, device=DEVICE)
|
| ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01
|
|
|
| out_om = om._apply_ddf(vol, ddf, padding_mode="border")
|
| out_ref = self._apply_ddf_reference(vol, ddf, padding_mode="border")
|
| assert torch.allclose(out_om, out_ref, atol=1e-6), (
|
| f"Max diff: {(out_om - out_ref).abs().max().item()}"
|
| )
|
|
|
| def test_upscaled_ddf_3d(self):
|
| """apply_def with DDF upscaling matches reference when DDF is manually upscaled."""
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
| vol = torch.rand(1, 1, 64, 64, 64, device=DEVICE)
|
| ddf_small = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01
|
|
|
|
|
| out_om = om.apply_def(img=vol, ddf=ddf_small, padding_mode="border")
|
|
|
|
|
| ddf_big = F.interpolate(ddf_small, size=[64, 64, 64],
|
| mode="trilinear", align_corners=False)
|
| out_ref = self._apply_ddf_reference(vol, ddf_big, padding_mode="border")
|
|
|
| assert torch.allclose(out_om, out_ref, atol=1e-6), (
|
| f"Max diff: {(out_om - out_ref).abs().max().item()}"
|
| )
|
|
|
| def test_mask_nearest_3d(self):
|
| """apply_def with nearest-neighbor resampling matches reference."""
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
| mask = (torch.rand(1, 1, 32, 32, 32, device=DEVICE) > 0.5).float()
|
| ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01
|
|
|
| out_om = om._apply_ddf(mask, ddf, padding_mode="zeros", resample_mode="nearest")
|
| out_ref = self._apply_ddf_reference(mask, ddf, padding_mode="zeros",
|
| resample_mode="nearest")
|
| assert torch.allclose(out_om, out_ref, atol=1e-6)
|
|
|
|
|
|
|
|
|
|
|
|
|
| class TestCenterPadEquivalence:
|
| """OMorpher._center_pad_to_cube matches the standalone version."""
|
|
|
| @staticmethod
|
| def _center_pad_reference(volume):
|
| """Exact copy from OM_reg_flexres.py."""
|
| max_dim = max(volume.shape[:3])
|
| pad_width = []
|
| for s in volume.shape[:3]:
|
| total_pad = max_dim - s
|
| pad_before = total_pad // 2
|
| pad_after = total_pad - pad_before
|
| pad_width.append((pad_before, pad_after))
|
| for _ in range(volume.ndim - 3):
|
| pad_width.append((0, 0))
|
| return np.pad(volume, pad_width, mode="constant", constant_values=0)
|
|
|
| def test_anisotropic(self):
|
| vol = np.random.rand(30, 40, 50).astype(np.float32)
|
| out_om = OMorpher._center_pad_to_cube(vol)
|
| out_ref = self._center_pad_reference(vol)
|
| assert np.array_equal(out_om, out_ref)
|
|
|
| def test_isotropic(self):
|
| vol = np.random.rand(40, 40, 40).astype(np.float32)
|
| out_om = OMorpher._center_pad_to_cube(vol)
|
| out_ref = self._center_pad_reference(vol)
|
| assert np.array_equal(out_om, out_ref)
|
|
|
| def test_4d(self):
|
| vol = np.random.rand(30, 40, 50, 3).astype(np.float32)
|
| out_om = OMorpher._center_pad_to_cube(vol)
|
| out_ref = self._center_pad_reference(vol)
|
| assert np.array_equal(out_om, out_ref)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class TestLabelStandardizationEquivalence:
|
| """OMorpher._standardize_label matches the manual label pipeline."""
|
|
|
| @staticmethod
|
| def _center_pad_reference(volume):
|
| max_dim = max(volume.shape[:3])
|
| pad_width = []
|
| for s in volume.shape[:3]:
|
| total_pad = max_dim - s
|
| pad_before = total_pad // 2
|
| pad_after = total_pad - pad_before
|
| pad_width.append((pad_before, pad_after))
|
| for _ in range(volume.ndim - 3):
|
| pad_width.append((0, 0))
|
| return np.pad(volume, pad_width, mode="constant", constant_values=0)
|
|
|
| def test_3d_label(self):
|
| """Single-channel label matches manual resize + tensorify."""
|
| from skimage.transform import resize
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
| om.set_init_img(torch.rand(1, 1, 32, 32, 32))
|
|
|
| lab = (np.random.rand(30, 40, 50) > 0.5).astype(np.float32)
|
| model_t, fullres_t = om._standardize_label(lab)
|
|
|
|
|
| lab_padded = self._center_pad_reference(lab)
|
| lab_model_ref = resize(lab_padded, [32, 32, 32],
|
| anti_aliasing=False, preserve_range=True, order=0)
|
| lab_model_ref = lab_model_ref[None, :, :, :]
|
| model_ref = torch.tensor(lab_model_ref[None], dtype=torch.float32)
|
| fullres_ref = torch.tensor(lab_padded[None, None, ...], dtype=torch.float32)
|
|
|
| assert torch.allclose(model_t.cpu(), model_ref, atol=1e-6), (
|
| f"Model-res label max diff: {(model_t.cpu() - model_ref).abs().max().item()}"
|
| )
|
| assert torch.allclose(fullres_t.cpu(), fullres_ref, atol=1e-6), (
|
| f"Fullres label max diff: {(fullres_t.cpu() - fullres_ref).abs().max().item()}"
|
| )
|
|
|
| def test_none_placeholder(self):
|
| """None label produces -1 filled tensors matching manual placeholder."""
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
| fullres_shape = [48, 48, 48]
|
| om._init_img_raw = torch.zeros([1, 1] + fullres_shape)
|
|
|
| model_t, fullres_t = om._standardize_label(None)
|
|
|
| assert model_t.shape == (1, 1, 32, 32, 32)
|
| assert fullres_t.shape == (1, 1, 48, 48, 48)
|
| assert torch.all(model_t == -1)
|
| assert torch.all(fullres_t == -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class TestDiffRecoverEquivalence:
|
| """OMorpher.predict() matches DeformDDPM.diff_recover() for the iterative
|
| reverse-diffusion loop."""
|
|
|
| def test_no_initial_noise(self):
|
| """T=[None, timesteps] — no forward diffusion, full reverse loop.
|
|
|
| This is the exact mode used in OM_reg_flexres.py.
|
| """
|
| om, ddpm = _build_matching_pair()
|
| img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
|
| cond = img.clone().detach()
|
|
|
|
|
| with torch.no_grad():
|
| [ddf_comp_ddpm, ddf_rand_ddpm], \
|
| [img_rec_ddpm, img_diff_ddpm, _], \
|
| [msk_rec_ddpm, msk_diff_ddpm, _] = ddpm.diff_recover(
|
| img_org=img,
|
| cond_imgs=cond,
|
| msk_org=None,
|
| T=[None, TIMESTEPS],
|
| v_scale=V_SCALE,
|
| t_save=None,
|
| proc_type="none",
|
| )
|
|
|
|
|
| om.set_init_img(img)
|
| om.set_cond_img(cond)
|
| om.predict(T=[None, TIMESTEPS], proc_type="none")
|
| ddf_comp_om = om.get_def()
|
|
|
|
|
|
|
| img_rec_om = om.img_stn(img.clone().detach(), ddf_comp_om)
|
|
|
|
|
| assert ddf_comp_om.shape == ddf_comp_ddpm.shape, (
|
| f"DDF shape mismatch: {ddf_comp_om.shape} vs {ddf_comp_ddpm.shape}"
|
| )
|
| assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
|
| f"DDF max diff: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
|
| )
|
| assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
|
| f"Reconstructed image max diff: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
|
| )
|
|
|
| def test_with_initial_noise(self):
|
| """T=[5, timesteps] — forward diffusion at t=5, then reverse loop.
|
|
|
| Tests the augmentation path where the image is first deformed
|
| randomly before recovery.
|
| """
|
| om, ddpm = _build_matching_pair()
|
| img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
|
| cond = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
|
|
|
| t_start = 5
|
|
|
|
|
|
|
| torch.manual_seed(77)
|
| np.random.seed(77)
|
| import random as random_mod
|
| random_mod.seed(77)
|
| _, _, ddf_rand = ddpm._get_random_ddf(img, torch.tensor([t_start]))
|
|
|
|
|
| with torch.no_grad():
|
| [ddf_comp_ddpm, _], [img_rec_ddpm, _, _], _ = ddpm.diff_recover(
|
| img_org=img,
|
| cond_imgs=cond,
|
| msk_org=None,
|
| T=[t_start, TIMESTEPS],
|
| ddf_rand=ddf_rand.clone(),
|
| t_save=None,
|
| proc_type="none",
|
| )
|
|
|
|
|
|
|
| om.set_init_img(img)
|
| om._init_ddf = ddf_rand.clone()
|
| om.set_cond_img(cond)
|
|
|
|
|
| om.predict(T=[t_start, TIMESTEPS], proc_type="none")
|
| ddf_comp_om = om.get_def()
|
| img_rec_om = om.img_stn(img.clone().detach(), ddf_comp_om)
|
|
|
|
|
| assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
|
| f"DDF max diff with initial noise: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
|
| )
|
| assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
|
| f"Image max diff with initial noise: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
|
| )
|
|
|
| def test_with_conditioning_types(self):
|
| """Test equivalence across different proc_types used in OM_reg_flexres.py."""
|
| om, ddpm = _build_matching_pair()
|
| img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
|
| cond = img.clone().detach()
|
|
|
| for proc_type in ["none", "uncon", "slice"]:
|
|
|
| torch.manual_seed(42)
|
| np.random.seed(42)
|
| import random as random_mod
|
| random_mod.seed(42)
|
|
|
| with torch.no_grad():
|
| [ddf_comp_ddpm, _], _, _ = ddpm.diff_recover(
|
| img_org=img, cond_imgs=cond, msk_org=None,
|
| T=[None, TIMESTEPS], proc_type=proc_type,
|
| )
|
|
|
| torch.manual_seed(42)
|
| np.random.seed(42)
|
| random_mod.seed(42)
|
|
|
| om.set_init_img(img)
|
| om.set_cond_img(cond)
|
| om.predict(T=[None, TIMESTEPS], proc_type=proc_type)
|
| ddf_comp_om = om.get_def()
|
|
|
| assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
|
| f"DDF mismatch for proc_type={proc_type}: "
|
| f"max diff = {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
|
| )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class TestFullResWarpEquivalence:
|
| """OMorpher.apply_def(fullres_img, model_ddf) matches the manual
|
| upscale + apply_ddf from OM_reg_flexres.py."""
|
|
|
| @staticmethod
|
| def _apply_ddf_reference(volume_tensor, ddf, padding_mode="border", resample_mode="bilinear"):
|
| device = ddf.device
|
| ndims = 3
|
| img_sz = list(volume_tensor.shape[2:])
|
| max_sz = torch.reshape(
|
| torch.tensor(img_sz, dtype=torch.float32, device=device),
|
| [1, ndims] + [1] * ndims)
|
| ref_grid = torch.reshape(
|
| torch.stack(torch.meshgrid(
|
| [torch.arange(s, device=device) for s in img_sz], indexing="ij"), 0),
|
| [1, ndims] + img_sz)
|
| img_shape = torch.reshape(
|
| torch.tensor([(s - 1) / 2.0 for s in img_sz], dtype=torch.float32, device=device),
|
| [1] + [1] * ndims + [ndims])
|
| grid = torch.flip(
|
| (ddf * max_sz + ref_grid).permute(
|
| [0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1,
|
| dims=[-1])
|
| return F.grid_sample(volume_tensor, grid.float(), mode=resample_mode,
|
| padding_mode=padding_mode, align_corners=True)
|
|
|
| def test_fullres_warp(self):
|
| """Simulate the exact OM_reg_flexres.py full-res warping pipeline."""
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
|
|
| model_sz = 32
|
| full_sz = 64
|
|
|
|
|
| ddf_model = torch.randn(1, 3, model_sz, model_sz, model_sz, device=DEVICE) * 0.02
|
| fullres_img = torch.rand(1, 1, full_sz, full_sz, full_sz, device=DEVICE)
|
|
|
|
|
| ddf_fullres_ref = F.interpolate(
|
| ddf_model, size=[full_sz] * 3, mode="trilinear", align_corners=False,
|
| )
|
| img_rec_ref = self._apply_ddf_reference(fullres_img, ddf_fullres_ref)
|
|
|
|
|
| img_rec_om = om.apply_def(img=fullres_img, ddf=ddf_model, padding_mode="border")
|
|
|
| assert torch.allclose(img_rec_om, img_rec_ref, atol=1e-6), (
|
| f"Full-res warp max diff: {(img_rec_om - img_rec_ref).abs().max().item()}"
|
| )
|
|
|
| def test_fullres_mask_nearest(self):
|
| """Mask warping with nearest-neighbor at full resolution."""
|
| cfg = {**BASE_CONFIG, "ndims": 3, "img_size": 32}
|
| om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
|
|
|
| model_sz = 32
|
| full_sz = 48
|
|
|
| ddf_model = torch.randn(1, 3, model_sz, model_sz, model_sz, device=DEVICE) * 0.02
|
| fullres_mask = (torch.rand(1, 1, full_sz, full_sz, full_sz, device=DEVICE) > 0.5).float()
|
|
|
|
|
| ddf_fullres = F.interpolate(
|
| ddf_model, size=[full_sz] * 3, mode="trilinear", align_corners=False,
|
| )
|
| msk_ref = self._apply_ddf_reference(
|
| fullres_mask, ddf_fullres, padding_mode="zeros", resample_mode="nearest",
|
| )
|
|
|
|
|
| msk_om = om.apply_def(
|
| img=fullres_mask, ddf=ddf_model,
|
| padding_mode="zeros", resample_mode="nearest",
|
| )
|
|
|
| assert torch.allclose(msk_om, msk_ref, atol=1e-6), (
|
| f"Mask full-res max diff: {(msk_om - msk_ref).abs().max().item()}"
|
| )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestCheckpointLoadEquivalence:
    """OMorpher loads from a DeformDDPM-format checkpoint and produces
    the same results."""

    def test_round_trip(self):
        """Save a DeformDDPM checkpoint, load it in OMorpher, verify outputs match.

        The checkpoint is written into a ``TemporaryDirectory``, so it is
        removed even when an assertion fails. The whole comparison runs inside
        that context, so the file exists for as long as OMorpher might need it.
        """
        import random as random_mod
        import tempfile

        def seed_all(seed):
            # Same RNG state for both inference paths.
            torch.manual_seed(seed)
            np.random.seed(seed)
            random_mod.seed(seed)

        seed_all(0)  # reproducible weights and inputs
        Net = get_net(NET_NAME)
        network = Net(n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE)
        ddpm = DeformDDPM(
            network=network, n_steps=TIMESTEPS,
            image_chw=[1] + [IMG_SIZE] * NDIMS, device=DEVICE,
            batch_size=1, img_pad_mode="zeros", ddf_pad_mode="border",
            padding_mode="border", v_scale=V_SCALE,
        )
        ddpm.eval()

        with tempfile.TemporaryDirectory() as tmp_dir:
            ckpt_path = os.path.join(tmp_dir, "test_ckpt.pth")
            # Save the full DeformDDPM module state (not just the network)
            # under "model_state_dict".
            torch.save({
                "model_state_dict": ddpm.state_dict(),
                "optimizer_state_dict": None,
                "epoch": 0,
            }, ckpt_path)

            om = OMorpher(config=BASE_CONFIG, checkpoint_path=ckpt_path, device=DEVICE)

            # Weights must match key-for-key; the key-set check catches a
            # silently dropped or renamed parameter.
            om_state = om.network.state_dict()
            ddpm_state = ddpm.network.state_dict()
            assert set(om_state) == set(ddpm_state), (
                f"State-dict key mismatch: "
                f"missing={sorted(set(ddpm_state) - set(om_state))}, "
                f"unexpected={sorted(set(om_state) - set(ddpm_state))}"
            )
            for k, v in om_state.items():
                assert torch.equal(v, ddpm_state[k]), f"Weight mismatch at {k}"

            img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
            cond = img.clone()

            # Reference path.
            seed_all(42)
            with torch.no_grad():
                [ddf_ddpm, _], _, _ = ddpm.diff_recover(
                    img_org=img, cond_imgs=cond, msk_org=None,
                    T=[None, TIMESTEPS], proc_type="none",
                )

            # OMorpher path with the same RNG stream.
            seed_all(42)
            om.set_init_img(img)
            om.set_cond_img(cond)
            om.predict(T=[None, TIMESTEPS], proc_type="none")
            ddf_om = om.get_def()

            assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), (
                f"Post-checkpoint DDF max diff: {(ddf_om - ddf_ddpm).abs().max().item()}"
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestAugEquivalence:
    """Check that the OMorpher augmentation sequence reproduces
    DeformDDPM.diff_recover()."""

    @staticmethod
    def _reseed(seed):
        """Seed torch, NumPy and the stdlib ``random`` module with *seed*."""
        import random as random_mod

        torch.manual_seed(seed)
        np.random.seed(seed)
        random_mod.seed(seed)

    def test_aug_roundtrip(self):
        """Full augmentation iteration: same seed + same weights →
        OMorpher produces same img_rec, msk_rec, img_diff, msk_diff
        as diff_recover().

        This mirrors the exact OM_aug.py flow:
        1. Self-condition on input image
        2. Forward-diffuse at noise_step → get (img_diff, ddf_rand)
        3. Warp mask with ddf_rand → get msk_diff
        4. Reverse-diffuse from ddf_rand → get ddf_comp
        5. Warp image/mask with ddf_comp → get img_rec, msk_rec
        """
        om, ddpm = _build_matching_pair()
        shape = [1, 1] + [IMG_SIZE] * NDIMS
        img = torch.rand(shape, device=DEVICE)
        mask = (torch.rand(shape, device=DEVICE) > 0.5).float()
        noise_step = 5

        # One shared random DDF drives the forward diffusion on both sides.
        self._reseed(99)
        _, _, ddf_rand = ddpm._get_random_ddf(
            img, torch.tensor([noise_step], device=DEVICE),
        )

        # Reference path (cond_imgs=None: self-conditioning).
        self._reseed(42)
        with torch.no_grad():
            (
                [ddf_comp_ddpm, _],
                [img_rec_ddpm, img_diff_ddpm, _],
                [msk_rec_ddpm, msk_diff_ddpm, _],
            ) = ddpm.diff_recover(
                img_org=img,
                cond_imgs=None,
                msk_org=mask,
                T=[noise_step, TIMESTEPS],
                ddf_rand=ddf_rand.clone(),
                t_save=None,
                proc_type="none",
            )

        # OMorpher path, replaying the same RNG stream.
        self._reseed(42)
        om.set_init_img(img)
        om.set_cond_img(img)  # explicit self-conditioning
        om.set_init_def(ddf=ddf_rand.clone().detach())
        om.predict(T=[noise_step, TIMESTEPS], proc_type="none")

        ddf_comp_om = om.get_def()
        img_rec_om = om.apply_def(img=img, ddf=ddf_comp_om, padding_mode="zeros")
        msk_rec_om = om.apply_def(
            img=mask, ddf=ddf_comp_om,
            padding_mode="zeros", resample_mode="nearest",
        )
        # Forward-diffused ("noisy") image and mask from the shared ddf_rand.
        img_diff_om = om.img_stn(img.clone().detach(), ddf_rand)
        msk_diff_om = om.msk_stn(mask.clone().detach(), ddf_rand)

        comparisons = [
            ("DDF", ddf_comp_om, ddf_comp_ddpm),
            ("img_rec", img_rec_om, img_rec_ddpm),
            ("msk_rec", msk_rec_om, msk_rec_ddpm),
            ("img_diff", img_diff_om, img_diff_ddpm),
            ("msk_diff", msk_diff_om, msk_diff_ddpm),
        ]
        for label, got, expected in comparisons:
            assert torch.allclose(got, expected, atol=1e-5), (
                f"{label} max diff: {(got - expected).abs().max().item()}"
            )

    def test_noisy_mask(self):
        """om.apply_def(mask, ddf_rand, zeros, nearest) matches msk_stn(mask, ddf_rand)."""
        om, ddpm = _build_matching_pair()
        shape = [1, 1] + [IMG_SIZE] * NDIMS
        mask = (torch.rand(shape, device=DEVICE) > 0.5).float()
        ddf = torch.randn([1, NDIMS] + [IMG_SIZE] * NDIMS, device=DEVICE) * 0.01

        expected = ddpm.msk_stn(mask, ddf)
        got = om.apply_def(
            img=mask, ddf=ddf,
            padding_mode="zeros", resample_mode="nearest",
        )

        assert torch.allclose(got, expected, atol=1e-6), (
            f"Noisy mask max diff: {(got - expected).abs().max().item()}"
        )

    def test_self_conditioning(self):
        """Self-conditioning: set_cond_img(img) matches diff_recover default cond_imgs=None."""
        om, ddpm = _build_matching_pair()
        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)

        # Reference: let diff_recover fall back to its default conditioning.
        self._reseed(42)
        with torch.no_grad():
            [ddf_ddpm, _], _, _ = ddpm.diff_recover(
                img_org=img, cond_imgs=None, msk_org=None,
                T=[None, TIMESTEPS], proc_type="none",
            )

        # OMorpher: condition explicitly on the input image.
        self._reseed(42)
        om.set_init_img(img)
        om.set_cond_img(img)
        om.predict(T=[None, TIMESTEPS], proc_type="none")
        ddf_om = om.get_def()

        assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), (
            f"Self-cond DDF max diff: {(ddf_om - ddf_ddpm).abs().max().item()}"
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_all():
    """Run every ``test*`` method of the equivalence test classes without pytest.

    Each class is instantiated once and its callable ``test*`` attributes are
    invoked in alphabetical order. Any exception (including a failed
    ``assert``) marks that test as failed: its traceback is printed and the
    run continues with the next test.

    Returns:
        bool: ``True`` if every test passed, ``False`` otherwise.
    """
    test_classes = [
        TestApplyDDFEquivalence,
        TestCenterPadEquivalence,
        TestLabelStandardizationEquivalence,
        TestDiffRecoverEquivalence,
        TestFullResWarpEquivalence,
        TestCheckpointLoadEquivalence,
        TestAugEquivalence,
    ]
    passed = 0
    failed = 0
    errors = []

    for cls in test_classes:
        inst = cls()
        for name in sorted(dir(inst)):
            if not name.startswith("test"):
                continue
            test_method = getattr(inst, name)
            if not callable(test_method):
                # e.g. a class attribute that merely starts with "test".
                continue
            full_name = f"{cls.__name__}.{name}"
            try:
                test_method()
            except Exception as e:
                failed += 1
                errors.append((full_name, e))
                # A bare ``assert`` has an empty message; include the exception
                # type so the report still says what went wrong.
                print(f" FAIL {full_name}: {type(e).__name__}: {e}")
                traceback.print_exc()
            else:
                passed += 1
                print(f" PASS {full_name}")

    print(f"\n{'=' * 60}")
    print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
    if errors:
        print("Failures:")
        for name, e in errors:
            print(f" - {name}: {type(e).__name__}: {e}")
    return failed == 0
|
|
|
|
|
if __name__ == "__main__":
    # Shell/CI-friendly exit status: 0 when every test passed, 1 otherwise.
    all_passed = run_all()
    raise SystemExit(0 if all_passed else 1)
|
|
|