"""
Equivalence tests: OM_reg_flexres.py (DeformDDPM) vs OM_reg_flexres_om.py (OMorpher).

Verifies that OMorpher.predict() + OMorpher.apply_def() produce the same DDFs
and warped images as DeformDDPM.diff_recover() + apply_ddf(), given identical
network weights and inputs. "Same" means equal within the small absolute
tolerances asserted in each test (1e-5 / 1e-6), or bit-exact where
``array_equal`` / ``torch.equal`` is used.

These tests do NOT need real data or a trained checkpoint — they use
random weights and synthetic volumes.

Run:
    source activate ~/rds/rds-airr-p51-TWhPgQVLKbA/Env/pub_env/pytorch-xpu
    python tests/test_flexres_equivalence.py
"""

import os
import random
import sys
import tempfile
import traceback

import numpy as np
import torch
import torch.nn.functional as F

# Make the repository root importable when this file is run as a script.
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)

from OMorpher import OMorpher
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN, DefRec_MutAttnNet

# ---------- shared config ----------
# Use 2D + 128 for speed (3D + 128 works but is slow)
NDIMS = 2
IMG_SIZE = 128
TIMESTEPS = 10
NET_NAME = "recmutattnnet"
DEVICE = "cpu"
V_SCALE = 5e-5
NOISE_SCALE = 0.1

BASE_CONFIG = {
    "net_name": NET_NAME,
    "ndims": NDIMS,
    "img_size": IMG_SIZE,
    "timesteps": TIMESTEPS,
    "v_scale": V_SCALE,
    "noise_scale": NOISE_SCALE,
    "condition_type": "none",
    "num_input_chn": 1,
    "img_pad_mode": "zeros",
    "ddf_pad_mode": "border",
    "padding_mode": "border",
    "resample_mode": "bilinear",
    "batchsize": 1,
    "data_name": "test",
    "start_noise_step": 0,
}


# ---------- shared helpers ----------
def _seed_all(seed):
    """Seed torch, NumPy and the stdlib ``random`` module with the same value.

    The two code paths under comparison must consume identical random
    streams, so every generator the tests reseed is reset together, in the
    same order the tests originally used (torch, numpy, random).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _build_matching_pair():
    """Build an OMorpher and a DeformDDPM that share identical network weights.

    The DeformDDPM is created first around a freshly initialised network; its
    network ``state_dict`` is then copied into the OMorpher's network. Both
    are put in eval mode.

    Returns:
        tuple: ``(om, ddpm)`` — the OMorpher and the DeformDDPM instance.
    """
    Net = get_net(NET_NAME)
    network = Net(
        n_steps=TIMESTEPS,
        ndims=NDIMS,
        num_input_chn=1,
        res=IMG_SIZE,
    )
    ddpm = DeformDDPM(
        network=network,
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=DEVICE,
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=V_SCALE,
        inf_mode=True,  # matches OM_reg_flexres.py
    )
    ddpm.eval()

    om = OMorpher(config=BASE_CONFIG, checkpoint_path=None, device=DEVICE)
    # Copy weights from DeformDDPM's network to OMorpher's network
    om.network.load_state_dict(ddpm.network.state_dict())
    om.network.eval()

    return om, ddpm


def _build_om_3d(img_size=32):
    """Create a checkpoint-free 3D OMorpher at the given model resolution.

    Uses ``BASE_CONFIG`` with ``ndims`` overridden to 3 and ``img_size``
    overridden to ``img_size``.
    """
    cfg = {**BASE_CONFIG, "ndims": 3, "img_size": img_size}
    return OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)


def _reference_apply_ddf(volume_tensor, ddf, padding_mode="border",
                         resample_mode="bilinear"):
    """Reference warp; described by the original tests as an exact copy of
    apply_ddf() from OM_reg_flexres.py.

    Hard-coded to 3 spatial dimensions. ``ddf`` is multiplied by the volume
    size before being added to the voxel index grid, so it is presumably
    expressed as a fraction of each axis length (confirm against
    OM_reg_flexres.py).

    Args:
        volume_tensor: tensor whose dims from index 2 onward are the three
            spatial dims (its ``shape[2:]`` defines the sampling grid size).
        ddf: displacement field, channel-first, with 3 displacement channels
            and the same spatial size as ``volume_tensor``.
        padding_mode: forwarded to ``F.grid_sample``.
        resample_mode: forwarded to ``F.grid_sample`` as ``mode``.

    Returns:
        The resampled volume from ``F.grid_sample(..., align_corners=True)``.
    """
    device = ddf.device
    ndims = 3
    img_sz = list(volume_tensor.shape[2:])
    # Per-axis size, broadcastable against a channel-first DDF.
    max_sz = torch.reshape(
        torch.tensor(img_sz, dtype=torch.float32, device=device),
        [1, ndims] + [1] * ndims)
    # Identity grid of voxel indices, channel-first.
    ref_grid = torch.reshape(
        torch.stack(torch.meshgrid(
            [torch.arange(s, device=device) for s in img_sz],
            indexing="ij"), 0),
        [1, ndims] + img_sz)
    # Half-extent (s - 1) / 2 per axis, broadcastable channel-last.
    img_shape = torch.reshape(
        torch.tensor([(s - 1) / 2.0 for s in img_sz],
                     dtype=torch.float32, device=device),
        [1] + [1] * ndims + [ndims])
    # Move channels last, normalise indices to [-1, 1], then reverse the
    # coordinate order (grid_sample expects the last spatial axis first).
    grid = torch.flip(
        (ddf * max_sz + ref_grid).permute(
            [0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1,
        dims=[-1])
    return F.grid_sample(volume_tensor, grid.float(),
                         mode=resample_mode,
                         padding_mode=padding_mode,
                         align_corners=True)


def _reference_center_pad(volume):
    """Reference zero-padding of the first three axes to a cube; described by
    the original tests as an exact copy from OM_reg_flexres.py.

    Each of the first three axes is padded to the largest of the three, with
    the extra element going after the data when the padding is odd. Any
    further axes are left unpadded.
    """
    max_dim = max(volume.shape[:3])
    pad_width = []
    for s in volume.shape[:3]:
        total_pad = max_dim - s
        pad_before = total_pad // 2
        pad_after = total_pad - pad_before
        pad_width.append((pad_before, pad_after))
    # Trailing (e.g. channel) axes are never padded.
    for _ in range(volume.ndim - 3):
        pad_width.append((0, 0))
    return np.pad(volume, pad_width, mode="constant", constant_values=0)


# ================================================================
# Test 1: apply_ddf equivalence
#
# OMorpher.apply_def(img, ddf) vs standalone apply_ddf() from
# OM_reg_flexres.py at multiple resolutions.
# ================================================================
class TestApplyDDFEquivalence:
    """OMorpher._apply_ddf matches the standalone apply_ddf from OM_reg_flexres.py."""

    # Kept as a static method so existing callers of the class attribute work.
    _apply_ddf_reference = staticmethod(_reference_apply_ddf)

    def test_same_resolution_3d(self):
        """apply_def at model resolution matches reference."""
        om = _build_om_3d(32)

        vol = torch.rand(1, 1, 32, 32, 32, device=DEVICE)
        ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01

        out_om = om._apply_ddf(vol, ddf, padding_mode="border")
        out_ref = self._apply_ddf_reference(vol, ddf, padding_mode="border")

        assert torch.allclose(out_om, out_ref, atol=1e-6), (
            f"Max diff: {(out_om - out_ref).abs().max().item()}"
        )

    def test_upscaled_ddf_3d(self):
        """apply_def with DDF upscaling matches reference when DDF is manually upscaled."""
        om = _build_om_3d(32)

        vol = torch.rand(1, 1, 64, 64, 64, device=DEVICE)
        ddf_small = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01

        # OMorpher auto-upscales
        out_om = om.apply_def(img=vol, ddf=ddf_small, padding_mode="border")

        # Reference: manually upscale then apply
        ddf_big = F.interpolate(ddf_small, size=[64, 64, 64],
                                mode="trilinear", align_corners=False)
        out_ref = self._apply_ddf_reference(vol, ddf_big, padding_mode="border")

        assert torch.allclose(out_om, out_ref, atol=1e-6), (
            f"Max diff: {(out_om - out_ref).abs().max().item()}"
        )

    def test_mask_nearest_3d(self):
        """apply_def with nearest-neighbor resampling matches reference."""
        om = _build_om_3d(32)

        mask = (torch.rand(1, 1, 32, 32, 32, device=DEVICE) > 0.5).float()
        ddf = torch.randn(1, 3, 32, 32, 32, device=DEVICE) * 0.01

        out_om = om._apply_ddf(mask, ddf, padding_mode="zeros",
                               resample_mode="nearest")
        out_ref = self._apply_ddf_reference(mask, ddf, padding_mode="zeros",
                                            resample_mode="nearest")

        assert torch.allclose(out_om, out_ref, atol=1e-6)


# ================================================================
# Test 2: center_pad_to_cube equivalence
# ================================================================
class TestCenterPadEquivalence:
    """OMorpher._center_pad_to_cube matches the standalone version."""

    _center_pad_reference = staticmethod(_reference_center_pad)

    def test_anisotropic(self):
        """Three different axis lengths are padded identically."""
        vol = np.random.rand(30, 40, 50).astype(np.float32)
        out_om = OMorpher._center_pad_to_cube(vol)
        out_ref = self._center_pad_reference(vol)
        assert np.array_equal(out_om, out_ref)

    def test_isotropic(self):
        """An already-cubic volume is handled identically."""
        vol = np.random.rand(40, 40, 40).astype(np.float32)
        out_om = OMorpher._center_pad_to_cube(vol)
        out_ref = self._center_pad_reference(vol)
        assert np.array_equal(out_om, out_ref)

    def test_4d(self):
        """A volume with a trailing fourth axis is handled identically."""
        vol = np.random.rand(30, 40, 50, 3).astype(np.float32)
        out_om = OMorpher._center_pad_to_cube(vol)
        out_ref = self._center_pad_reference(vol)
        assert np.array_equal(out_om, out_ref)


# ================================================================
# Test 2b: Label standardization equivalence
#
# OMorpher._standardize_label matches the manual resize + tensor
# creation from OM_reg_flexres.py.
# ================================================================
class TestLabelStandardizationEquivalence:
    """OMorpher._standardize_label matches the manual label pipeline."""

    _center_pad_reference = staticmethod(_reference_center_pad)

    def test_3d_label(self):
        """Single-channel label matches manual resize + tensorify."""
        # Imported lazily so the rest of the module runs without scikit-image.
        from skimage.transform import resize

        om = _build_om_3d(32)
        om.set_init_img(torch.rand(1, 1, 32, 32, 32))

        lab = (np.random.rand(30, 40, 50) > 0.5).astype(np.float32)
        model_t, fullres_t = om._standardize_label(lab)

        # Reference: manual pipeline from OM_reg_flexres.py
        lab_padded = self._center_pad_reference(lab)
        lab_model_ref = resize(lab_padded, [32, 32, 32],
                               anti_aliasing=False, preserve_range=True,
                               order=0)
        lab_model_ref = lab_model_ref[None, :, :, :]  # [1, D, H, W]
        model_ref = torch.tensor(lab_model_ref[None], dtype=torch.float32)
        fullres_ref = torch.tensor(lab_padded[None, None, ...],
                                   dtype=torch.float32)

        assert torch.allclose(model_t.cpu(), model_ref, atol=1e-6), (
            f"Model-res label max diff: {(model_t.cpu() - model_ref).abs().max().item()}"
        )
        assert torch.allclose(fullres_t.cpu(), fullres_ref, atol=1e-6), (
            f"Fullres label max diff: {(fullres_t.cpu() - fullres_ref).abs().max().item()}"
        )

    def test_none_placeholder(self):
        """None label produces -1 filled tensors matching manual placeholder."""
        om = _build_om_3d(32)
        fullres_shape = [48, 48, 48]
        # Set the raw image directly; the test expects the full-res
        # placeholder to take its spatial shape from this tensor.
        om._init_img_raw = torch.zeros([1, 1] + fullres_shape)

        model_t, fullres_t = om._standardize_label(None)

        assert model_t.shape == (1, 1, 32, 32, 32)
        assert fullres_t.shape == (1, 1, 48, 48, 48)
        assert torch.all(model_t == -1)
        assert torch.all(fullres_t == -1)


# ================================================================
# Test 3: Full diff_recover loop equivalence
#
# Most critical test: verifies that OMorpher.predict() produces the
# same DDF as DeformDDPM.diff_recover() given identical inputs,
# weights, and deterministic seeding.
# ================================================================
class TestDiffRecoverEquivalence:
    """OMorpher.predict() matches DeformDDPM.diff_recover() for the
    iterative reverse-diffusion loop."""

    def test_no_initial_noise(self):
        """T=[None, timesteps] — no forward diffusion, full reverse loop.

        This is the exact mode used in OM_reg_flexres.py.
        """
        om, ddpm = _build_matching_pair()

        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        cond = img.clone().detach()  # self-conditioning (common in inference)

        # --- DeformDDPM path (original) ---
        # Unpack shape is kept as in the original test so a change in the
        # structure returned by diff_recover() still fails here.
        with torch.no_grad():
            [ddf_comp_ddpm, _], \
                [img_rec_ddpm, _, _], \
                [_, _, _] = ddpm.diff_recover(
                    img_org=img,
                    cond_imgs=cond,
                    msk_org=None,
                    T=[None, TIMESTEPS],
                    v_scale=V_SCALE,
                    t_save=None,
                    proc_type="none",
                )

        # --- OMorpher path (new) ---
        om.set_init_img(img)
        om.set_cond_img(cond)
        om.predict(T=[None, TIMESTEPS], proc_type="none")
        ddf_comp_om = om.get_def()

        # Reconstruct image from DDF the same way the original does:
        #   img_rec = img_stn(img_org, ddf_comp)
        img_rec_om = om.img_stn(img.clone().detach(), ddf_comp_om)

        # --- Compare ---
        assert ddf_comp_om.shape == ddf_comp_ddpm.shape, (
            f"DDF shape mismatch: {ddf_comp_om.shape} vs {ddf_comp_ddpm.shape}"
        )
        assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
            f"DDF max diff: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
        )
        assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
            f"Reconstructed image max diff: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
        )

    def test_with_initial_noise(self):
        """T=[5, timesteps] — forward diffusion at t=5, then reverse loop.

        Tests the augmentation path where the image is first deformed
        randomly before recovery.
        """
        om, ddpm = _build_matching_pair()

        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        cond = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        t_start = 5

        # We need the same random DDF for both paths.
        # Generate it once and pass it in.
        _seed_all(77)
        _, _, ddf_rand = ddpm._get_random_ddf(
            img, torch.tensor([t_start], device=DEVICE),
        )

        # --- DeformDDPM path ---
        with torch.no_grad():
            [ddf_comp_ddpm, _], [img_rec_ddpm, _, _], _ = ddpm.diff_recover(
                img_org=img,
                cond_imgs=cond,
                msk_org=None,
                T=[t_start, TIMESTEPS],
                ddf_rand=ddf_rand.clone(),
                t_save=None,
                proc_type="none",
            )

        # --- OMorpher path ---
        # Set init image and pre-computed initial DDF
        om.set_init_img(img)
        om._init_ddf = ddf_rand.clone()
        om.set_cond_img(cond)

        # predict with T that triggers the "init_ddf is not zero" branch
        om.predict(T=[t_start, TIMESTEPS], proc_type="none")
        ddf_comp_om = om.get_def()
        img_rec_om = om.img_stn(img.clone().detach(), ddf_comp_om)

        # --- Compare ---
        assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
            f"DDF max diff with initial noise: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
        )
        assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
            f"Image max diff with initial noise: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
        )

    def test_with_conditioning_types(self):
        """Test equivalence across different proc_types used in OM_reg_flexres.py."""
        om, ddpm = _build_matching_pair()

        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        cond = img.clone().detach()

        for proc_type in ["none", "uncon", "slice"]:
            # Use same random seed for both paths
            _seed_all(42)
            with torch.no_grad():
                [ddf_comp_ddpm, _], _, _ = ddpm.diff_recover(
                    img_org=img,
                    cond_imgs=cond,
                    msk_org=None,
                    T=[None, TIMESTEPS],
                    proc_type=proc_type,
                )

            _seed_all(42)
            om.set_init_img(img)
            om.set_cond_img(cond)
            om.predict(T=[None, TIMESTEPS], proc_type=proc_type)
            ddf_comp_om = om.get_def()

            assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
                f"DDF mismatch for proc_type={proc_type}: "
                f"max diff = {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
            )


# ================================================================
# Test 4: Full-resolution warping equivalence
#
# Verifies the key operation in OM_reg_flexres.py:
#   1. Run diffusion at model_res → get ddf_comp
#   2. Upscale DDF to full_res
#   3. Apply to full-res image
# ================================================================
class TestFullResWarpEquivalence:
    """OMorpher.apply_def(fullres_img, model_ddf) matches the manual
    upscale + apply_ddf from OM_reg_flexres.py."""

    _apply_ddf_reference = staticmethod(_reference_apply_ddf)

    def test_fullres_warp(self):
        """Simulate the exact OM_reg_flexres.py full-res warping pipeline."""
        om = _build_om_3d(32)

        model_sz = 32
        full_sz = 64

        # Synthetic model-res DDF (as produced by predict)
        ddf_model = torch.randn(1, 3, model_sz, model_sz, model_sz,
                                device=DEVICE) * 0.02
        fullres_img = torch.rand(1, 1, full_sz, full_sz, full_sz,
                                 device=DEVICE)

        # --- OM_reg_flexres.py path ---
        ddf_fullres_ref = F.interpolate(
            ddf_model,
            size=[full_sz] * 3,
            mode="trilinear",
            align_corners=False,
        )
        img_rec_ref = self._apply_ddf_reference(fullres_img, ddf_fullres_ref)

        # --- OMorpher path (auto-upscales DDF) ---
        img_rec_om = om.apply_def(img=fullres_img, ddf=ddf_model,
                                  padding_mode="border")

        assert torch.allclose(img_rec_om, img_rec_ref, atol=1e-6), (
            f"Full-res warp max diff: {(img_rec_om - img_rec_ref).abs().max().item()}"
        )

    def test_fullres_mask_nearest(self):
        """Mask warping with nearest-neighbor at full resolution."""
        om = _build_om_3d(32)

        model_sz = 32
        full_sz = 48

        ddf_model = torch.randn(1, 3, model_sz, model_sz, model_sz,
                                device=DEVICE) * 0.02
        fullres_mask = (torch.rand(1, 1, full_sz, full_sz, full_sz,
                                   device=DEVICE) > 0.5).float()

        # Reference
        ddf_fullres = F.interpolate(
            ddf_model,
            size=[full_sz] * 3,
            mode="trilinear",
            align_corners=False,
        )
        msk_ref = self._apply_ddf_reference(
            fullres_mask, ddf_fullres,
            padding_mode="zeros", resample_mode="nearest",
        )

        # OMorpher
        msk_om = om.apply_def(
            img=fullres_mask, ddf=ddf_model,
            padding_mode="zeros", resample_mode="nearest",
        )

        assert torch.allclose(msk_om, msk_ref, atol=1e-6), (
            f"Mask full-res max diff: {(msk_om - msk_ref).abs().max().item()}"
        )


# ================================================================
# Test 5: Checkpoint loading equivalence
#
# Verifies that OMorpher strips DDP/DeformDDPM prefixes correctly
# and produces the same outputs as a DeformDDPM loaded from the
# same checkpoint.
# ================================================================
class TestCheckpointLoadEquivalence:
    """OMorpher loads from a DeformDDPM-format checkpoint and produces
    the same results."""

    def test_round_trip(self):
        """Save a DeformDDPM checkpoint, load it in OMorpher, verify outputs match."""
        Net = get_net(NET_NAME)
        network = Net(n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1,
                      res=IMG_SIZE)
        ddpm = DeformDDPM(
            network=network,
            n_steps=TIMESTEPS,
            image_chw=[1] + [IMG_SIZE] * NDIMS,
            device=DEVICE,
            batch_size=1,
            img_pad_mode="zeros",
            ddf_pad_mode="border",
            padding_mode="border",
            v_scale=V_SCALE,
        )
        ddpm.eval()

        # The context manager removes the directory and the checkpoint even
        # when an assertion below fails, so failed runs leave nothing behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            # Save checkpoint in standard format (with DeformDDPM wrapper keys)
            ckpt_path = os.path.join(tmp_dir, "test_ckpt.pth")
            torch.save({
                "model_state_dict": ddpm.state_dict(),
                "optimizer_state_dict": None,
                "epoch": 0,
            }, ckpt_path)

            # Load in OMorpher
            om = OMorpher(config=BASE_CONFIG, checkpoint_path=ckpt_path,
                          device=DEVICE)

            # Verify weights match
            ddpm_weights = ddpm.network.state_dict()
            for k, v in om.network.state_dict().items():
                assert torch.equal(v, ddpm_weights[k]), f"Weight mismatch at {k}"

            # Verify inference output matches
            img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
            cond = img.clone()

            with torch.no_grad():
                [ddf_ddpm, _], _, _ = ddpm.diff_recover(
                    img_org=img,
                    cond_imgs=cond,
                    msk_org=None,
                    T=[None, TIMESTEPS],
                    proc_type="none",
                )

            om.set_init_img(img)
            om.set_cond_img(cond)
            om.predict(T=[None, TIMESTEPS], proc_type="none")
            ddf_om = om.get_def()

            assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), (
                f"Post-checkpoint DDF max diff: {(ddf_om - ddf_ddpm).abs().max().item()}"
            )


# ================================================================
# Test 6: Augmentation equivalence (OM_aug.py)
#
# Verifies that the OMorpher-based augmentation sequence from
# OM_aug_om.py produces the same outputs as the DeformDDPM-based
# diff_recover() used in OM_aug.py.
# ================================================================
class TestAugEquivalence:
    """OMorpher augmentation sequence matches DeformDDPM.diff_recover()."""

    def test_aug_roundtrip(self):
        """Full augmentation iteration: same seed + same weights → OMorpher
        produces same img_rec, msk_rec, img_diff, msk_diff as diff_recover().

        This mirrors the exact OM_aug.py flow:
          1. Self-condition on input image
          2. Forward-diffuse at noise_step → get (img_diff, ddf_rand)
          3. Warp mask with ddf_rand → get msk_diff
          4. Reverse-diffuse from ddf_rand → get ddf_comp
          5. Warp image/mask with ddf_comp → get img_rec, msk_rec
        """
        om, ddpm = _build_matching_pair()

        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)
        mask = (torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE) > 0.5).float()
        noise_step = 5

        # --- Generate the same random DDF for both paths ---
        _seed_all(99)
        _, _, ddf_rand = ddpm._get_random_ddf(
            img, torch.tensor([noise_step], device=DEVICE),
        )

        # --- DeformDDPM path (OM_aug.py flow) ---
        # diff_recover with pre-computed ddf_rand and self-conditioning
        _seed_all(42)
        with torch.no_grad():
            [ddf_comp_ddpm, _], \
                [img_rec_ddpm, img_diff_ddpm, _], \
                [msk_rec_ddpm, msk_diff_ddpm, _] = ddpm.diff_recover(
                    img_org=img,
                    cond_imgs=None,  # defaults to img_org.clone().detach()
                    msk_org=mask,
                    T=[noise_step, TIMESTEPS],
                    ddf_rand=ddf_rand.clone(),
                    t_save=None,
                    proc_type="none",
                )

        # --- OMorpher path (OM_aug_om.py flow) ---
        _seed_all(42)
        om.set_init_img(img)
        om.set_cond_img(img)  # self-conditioning

        # Set random DDF as initial DDF
        om.set_init_def(ddf=ddf_rand.clone().detach())

        # Run reverse diffusion
        om.predict(
            T=[noise_step, TIMESTEPS],
            proc_type="none",
        )
        ddf_comp_om = om.get_def()

        img_rec_om = om.apply_def(img=img, ddf=ddf_comp_om,
                                  padding_mode="zeros")
        msk_rec_om = om.apply_def(
            img=mask, ddf=ddf_comp_om,
            padding_mode="zeros", resample_mode="nearest",
        )

        # Forward-diffused image: img_stn(img, ddf_rand) — same for both paths
        img_diff_om = om.img_stn(img.clone().detach(), ddf_rand)
        msk_diff_om = om.msk_stn(mask.clone().detach(), ddf_rand)

        # --- Compare DDFs ---
        assert torch.allclose(ddf_comp_om, ddf_comp_ddpm, atol=1e-5), (
            f"DDF max diff: {(ddf_comp_om - ddf_comp_ddpm).abs().max().item()}"
        )

        # --- Compare recovered images ---
        assert torch.allclose(img_rec_om, img_rec_ddpm, atol=1e-5), (
            f"img_rec max diff: {(img_rec_om - img_rec_ddpm).abs().max().item()}"
        )
        assert torch.allclose(msk_rec_om, msk_rec_ddpm, atol=1e-5), (
            f"msk_rec max diff: {(msk_rec_om - msk_rec_ddpm).abs().max().item()}"
        )

        # --- Compare noisy images ---
        assert torch.allclose(img_diff_om, img_diff_ddpm, atol=1e-5), (
            f"img_diff max diff: {(img_diff_om - img_diff_ddpm).abs().max().item()}"
        )
        assert torch.allclose(msk_diff_om, msk_diff_ddpm, atol=1e-5), (
            f"msk_diff max diff: {(msk_diff_om - msk_diff_ddpm).abs().max().item()}"
        )

    def test_noisy_mask(self):
        """om.apply_def(mask, ddf_rand, zeros, nearest) matches msk_stn(mask, ddf_rand)."""
        om, ddpm = _build_matching_pair()

        mask = (torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE) > 0.5).float()
        ddf = torch.randn([1, NDIMS] + [IMG_SIZE] * NDIMS, device=DEVICE) * 0.01

        msk_ddpm = ddpm.msk_stn(mask, ddf)
        msk_om = om.apply_def(
            img=mask, ddf=ddf,
            padding_mode="zeros", resample_mode="nearest",
        )

        assert torch.allclose(msk_om, msk_ddpm, atol=1e-6), (
            f"Noisy mask max diff: {(msk_om - msk_ddpm).abs().max().item()}"
        )

    def test_self_conditioning(self):
        """Self-conditioning: set_cond_img(img) matches diff_recover default cond_imgs=None."""
        om, ddpm = _build_matching_pair()

        img = torch.rand([1, 1] + [IMG_SIZE] * NDIMS, device=DEVICE)

        # DeformDDPM with cond_imgs=None (self-conditioning)
        _seed_all(42)
        with torch.no_grad():
            [ddf_ddpm, _], _, _ = ddpm.diff_recover(
                img_org=img,
                cond_imgs=None,
                msk_org=None,
                T=[None, TIMESTEPS],
                proc_type="none",
            )

        # OMorpher with explicit set_cond_img(img)
        _seed_all(42)
        om.set_init_img(img)
        om.set_cond_img(img)
        om.predict(T=[None, TIMESTEPS], proc_type="none")
        ddf_om = om.get_def()

        assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), (
            f"Self-cond DDF max diff: {(ddf_om - ddf_ddpm).abs().max().item()}"
        )


# ================================================================
# Runner
# ================================================================
def run_all():
    """Run every ``test*`` method of every test class without pytest.

    Each class is instantiated once and its test methods are run in sorted
    name order. Any exception (including AssertionError) is counted as a
    failure, printed with its traceback, and does not stop the run.

    Returns:
        bool: True when no test failed.
    """
    test_classes = [
        TestApplyDDFEquivalence,
        TestCenterPadEquivalence,
        TestLabelStandardizationEquivalence,
        TestDiffRecoverEquivalence,
        TestFullResWarpEquivalence,
        TestCheckpointLoadEquivalence,
        TestAugEquivalence,
    ]

    passed = 0
    failed = 0
    errors = []

    for cls in test_classes:
        inst = cls()
        for name in sorted(dir(inst)):
            if not name.startswith("test"):
                continue
            full_name = f"{cls.__name__}.{name}"
            # Broad catch is deliberate: this is the top-level runner
            # boundary and every failure is reported, never swallowed.
            try:
                getattr(inst, name)()
                passed += 1
                print(f"  PASS  {full_name}")
            except Exception as e:
                failed += 1
                errors.append((full_name, e))
                print(f"  FAIL  {full_name}: {e}")
                traceback.print_exc()

    print(f"\n{'=' * 60}")
    print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
    if errors:
        print("Failures:")
        for name, e in errors:
            print(f"  - {name}: {e}")
    return failed == 0


if __name__ == "__main__":
    success = run_all()
    sys.exit(0 if success else 1)