"""
Tests for the OMorpher module.

Split into two groups:
  - Basic tests: verify shapes, value ranges, and API behaviour (no checkpoint needed)
  - Alignment tests: cross-validate against DeformDDPM / OM_reg_flexres (shared weights)

Run:
    python -m pytest tests/test_omorpher.py -v
    # or directly:
    python tests/test_omorpher.py
"""
import os
import sys
import math

import numpy as np
import torch
import torch.nn.functional as F

# Ensure project root is importable
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)

from OMorpher import OMorpher
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN

# ---------- helpers ----------

NDIMS = 3
IMG_SIZE = 32  # tiny for speed
TIMESTEPS = 10
NET_NAME = "recmutattnnet"
DEVICE = "cpu"

BASE_CONFIG = {
    "net_name": NET_NAME,
    "ndims": NDIMS,
    "img_size": IMG_SIZE,
    "timesteps": TIMESTEPS,
    "v_scale": 5e-5,
    "noise_scale": 0.1,
    "condition_type": "none",
    "num_input_chn": 1,
    "img_pad_mode": "zeros",
    "ddf_pad_mode": "border",
    "padding_mode": "border",
    "resample_mode": "bilinear",
    "batchsize": 1,
    "data_name": "test",
    "start_noise_step": 0,
}


def _make_omorpher(**overrides):
    """Build a checkpoint-free OMorpher on DEVICE.

    The config is a copy of BASE_CONFIG with *overrides* layered on top, so
    BASE_CONFIG itself is never mutated.
    """
    cfg = {**BASE_CONFIG, **overrides}
    return OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)


def _rand_vol(B=1, S=None, ndims=None):
    """Return a uniform-random single-channel volume of shape [B, 1] + [S] * ndims.

    Args:
        B: batch size.
        S: spatial edge length; defaults to IMG_SIZE when None.
        ndims: number of spatial dimensions; defaults to NDIMS when None.

    Only ``None`` selects a default, so an explicitly passed size is always
    honoured (the previous ``S or IMG_SIZE`` silently replaced falsy sizes).
    """
    if S is None:
        S = IMG_SIZE
    if ndims is None:
        ndims = NDIMS
    return torch.rand([B, 1] + [S] * ndims)
Basic tests # ================================================================ class TestInstantiation: """Test 1: OMorpher with config dict + no checkpoint.""" def test_creates_network_and_stns(self): om = _make_omorpher() assert om.network is not None assert om.stn_full is not None assert om.stn_ctl is not None assert om.img_stn is not None assert om.msk_stn is not None def test_device(self): om = _make_omorpher() assert om.device == torch.device(DEVICE) def test_repr(self): om = _make_omorpher() r = repr(om) assert "OMorpher" in r assert NET_NAME in r class TestStandardization: """Test 2: _standardize_img produces correct shape and range.""" def test_numpy_input(self): om = _make_omorpher() vol = np.random.rand(40, 50, 60).astype(np.float32) * 1000.0 tensor, fullres, orig_shape = om._standardize_img(vol, keep_raw=True) assert tensor.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE) assert tensor.min() >= 0.0 assert tensor.max() <= 1.0 + 1e-6 assert fullres is not None assert isinstance(fullres, torch.Tensor) # fullres should be [1, 1, 60, 60, 60] (cube-padded from max dim) assert fullres.ndim == 5 assert fullres.shape[2] == fullres.shape[3] == fullres.shape[4] == 60 # orig_shape should be the cube-padded size assert orig_shape[0] == orig_shape[1] == orig_shape[2] def test_torch_passthrough(self): om = _make_omorpher() vol = torch.rand(1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE) tensor, raw, _ = om._standardize_img(vol) assert tensor.shape == vol.shape class TestLabelStandardization: """Test: _standardize_label produces correct shapes and handles None.""" def test_3d_label(self): om = _make_omorpher() om.set_init_img(_rand_vol().numpy()[0, 0]) label = (np.random.rand(40, 50, 60) > 0.5).astype(np.float32) model_t, fullres_t = om._standardize_label(label) assert model_t.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE) assert fullres_t.ndim == 5 # cube-padded: max(40,50,60)=60 assert fullres_t.shape[2] == fullres_t.shape[3] == fullres_t.shape[4] == 60 assert isinstance(model_t, 
torch.Tensor) assert isinstance(fullres_t, torch.Tensor) def test_none_placeholder(self): om = _make_omorpher() om.set_init_img(_rand_vol().numpy()[0, 0]) model_t, fullres_t = om._standardize_label(None) assert model_t.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE) assert torch.all(model_t == -1) assert torch.all(fullres_t == -1) def test_4d_label(self): om = _make_omorpher() om.set_init_img(_rand_vol().numpy()[0, 0]) label = (np.random.rand(30, 30, 30, 2) > 0.5).astype(np.float32) model_t, fullres_t = om._standardize_label(label) # 4D → channels-first: 2 channels assert model_t.shape == (1, 2, IMG_SIZE, IMG_SIZE, IMG_SIZE) assert fullres_t.shape[1] == 2 class TestZeroDDFRoundtrip: """Test 3: apply_def with zero DDF returns approx original.""" def test_identity(self): om = _make_omorpher() vol = _rand_vol() zero_ddf = torch.zeros(1, NDIMS, IMG_SIZE, IMG_SIZE, IMG_SIZE) warped = om._apply_ddf(vol, zero_ddf, padding_mode="border") assert torch.allclose(vol, warped, atol=1e-5) class TestSizeMismatch: """Test 4: apply_def auto-upscales DDF.""" def test_upscale(self): om = _make_omorpher() big_vol = _rand_vol(S=64) small_ddf = torch.zeros(1, NDIMS, IMG_SIZE, IMG_SIZE, IMG_SIZE) result = om.apply_def(img=big_vol, ddf=small_ddf) assert list(result.shape[2:]) == [64, 64, 64] class TestPredict: """Test 6: predict with random weights produces correct DDF shape. Uses IMG_SIZE=64 because RecMutAttnNet has 5 hierarchy levels: 32→16→8→4→2→1 bottleneck breaks InstanceNorm (no running stats). 
""" def test_predict_shape(self): sz = 64 om = _make_omorpher(img_size=sz) img = torch.rand([1, 1] + [sz] * NDIMS) om.set_init_img(img.numpy()[0, 0]) om.predict(T=[0, 2]) ddf = om.get_def() assert ddf.shape == (1, NDIMS, sz, sz, sz) def test_predict_intermediate(self): sz = 64 om = _make_omorpher(img_size=sz) img = torch.rand([1, 1] + [sz] * NDIMS) om.set_init_img(img.numpy()[0, 0]) om.predict(T=[0, 4], t_save=[3, 1]) intermediates = om.get_def(t_list=[3, 1]) assert isinstance(intermediates, dict) def test_chaining(self): sz = 64 om = _make_omorpher(img_size=sz) img = torch.rand([1, 1] + [sz] * NDIMS) result = om.set_init_img(img.numpy()[0, 0]).predict(T=[0, 2]) assert result is om class TestFinetune: """Test 7: finetune_step with dummy data. Uses IMG_SIZE=64 because at 32 the bottleneck hits 1x1x1 and InstanceNorm fails in training mode. """ def test_finetune_roundtrip(self): ft_size = 64 om = _make_omorpher(img_size=ft_size, batchsize=1) om.finetune_setup(lr=1e-3) vol = _rand_vol(S=ft_size) losses = om.finetune_step(vol) assert "loss_total" in losses assert "loss_grad" in losses assert isinstance(losses["loss_total"], float) om.finetune_teardown() class TestSetters: """Test input setters.""" def test_set_init_def_random(self): om = _make_omorpher() om.set_init_img(_rand_vol().numpy()[0, 0]) om.set_init_def(None) # should generate random assert om._init_ddf is not None assert not torch.all(om._init_ddf == 0) def test_set_init_def_provided(self): om = _make_omorpher() om.set_init_img(_rand_vol().numpy()[0, 0]) custom_ddf = np.zeros([1, NDIMS] + [IMG_SIZE] * NDIMS) om.set_init_def(custom_ddf) assert torch.all(om._init_ddf == 0) def test_set_cond_img_default(self): om = _make_omorpher() om.set_init_img(_rand_vol().numpy()[0, 0]) om.set_cond_img(None) assert om._cond_img is not None assert om._cond_img.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE) def test_set_cond_txt_numpy(self): om = _make_omorpher() emb = np.random.randn(1024).astype(np.float32) 
om.set_cond_txt(emb) assert om._cond_txt is not None assert om._cond_txt.shape == (1, 1024) def test_set_init_img_with_ddf(self): om = _make_omorpher() vol = np.random.rand(40, 40, 40).astype(np.float32) ddf = np.zeros([1, NDIMS, IMG_SIZE, IMG_SIZE, IMG_SIZE], dtype=np.float32) om.set_init_img((vol, ddf)) assert om._init_img is not None assert om._init_ddf is not None # ================================================================ # 2. Cross-validation / alignment tests # ================================================================ def _build_shared_weights(): """Build matching OMorpher + DeformDDPM with identical random weights.""" cfg = {**BASE_CONFIG, "inf_mode": False} # match DeformDDPM default Net = get_net(NET_NAME) network = Net( n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE, ) ddpm = DeformDDPM( network=network, n_steps=TIMESTEPS, image_chw=[1] + [IMG_SIZE] * NDIMS, device=DEVICE, batch_size=1, img_pad_mode="zeros", ddf_pad_mode="border", padding_mode="border", v_scale=cfg["v_scale"], ) ddpm.eval() om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE) # Copy weights from the DeformDDPM's network to OMorpher's network om.network.load_state_dict(ddpm.network.state_dict()) om.network.eval() return om, ddpm class TestDDFScaleAlignment: """Test 10: _get_ddf_scale matches DeformDDPM._get_ddf_scale.""" def test_all_timesteps(self): om, ddpm = _build_shared_weights() for t_val in [1, 5, 10, 20, 40, 60, 80]: t = torch.tensor([t_val]) r1, m1, v1 = om._get_ddf_scale(t) r2, m2, v2 = ddpm._get_ddf_scale(t) assert r1 == r2, f"rec_num mismatch at t={t_val}" assert torch.equal(m1, m2), f"mul_num_ddf mismatch at t={t_val}" assert torch.equal(v1, v2), f"mul_num_dvf mismatch at t={t_val}" class TestRandomDDFAlignment: """Test 9: _get_random_ddf matches DeformDDPM._get_random_ddf with same seed. 
Uses ndims=2, IMG_SIZE=128 so that ctl_sz=32, scale_num=5, len(ctl_szs_all)=5 > select_num=4 — avoiding a known unbound-variable bug in the original DeformDDPM._random_ddf_generate at smaller sizes. """ def test_same_seed(self): align_size = 128 align_ndims = 2 cfg = {**BASE_CONFIG, "img_size": align_size, "ndims": align_ndims} Net = get_net(NET_NAME) network = Net(n_steps=TIMESTEPS, ndims=align_ndims, num_input_chn=1, res=align_size) ddpm = DeformDDPM( network=network, n_steps=TIMESTEPS, image_chw=[1] + [align_size] * align_ndims, device=DEVICE, batch_size=1, img_pad_mode="zeros", ddf_pad_mode="border", padding_mode="border", v_scale=cfg["v_scale"], ) ddpm.eval() om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE) om.network.load_state_dict(ddpm.network.state_dict()) om.network.eval() img = torch.rand([1, 1] + [align_size] * align_ndims).to(DEVICE) t = torch.tensor([5]) # OMorpher torch.manual_seed(42) np.random.seed(42) random_mod = __import__("random") random_mod.seed(42) warped_om, dvf_om, ddf_om = om._get_random_ddf(img, t) # DeformDDPM torch.manual_seed(42) np.random.seed(42) random_mod.seed(42) warped_ddpm, dvf_ddpm, ddf_ddpm = ddpm._get_random_ddf(img, t) assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), "DDFs do not match" assert torch.allclose(dvf_om, dvf_ddpm, atol=1e-5), "DVFs do not match" assert torch.allclose(warped_om, warped_ddpm, atol=1e-5), "Warped images do not match" class TestConditioningAlignment: """Test 11: _proc_cond_img matches DeformDDPM.proc_cond_img.""" def _test_proc_type(self, proc_type): om, ddpm = _build_shared_weights() img = _rand_vol().to(DEVICE) torch.manual_seed(99) np.random.seed(99) random_mod = __import__("random") random_mod.seed(99) out_om, mask_om, cr_om = om._proc_cond_img(img, proc_type=proc_type) torch.manual_seed(99) np.random.seed(99) random_mod.seed(99) out_ddpm, mask_ddpm, cr_ddpm = ddpm.proc_cond_img(img, proc_type=proc_type) assert torch.allclose(out_om, out_ddpm, atol=1e-5), f"Proc image mismatch 
for {proc_type}" def test_uncon(self): self._test_proc_type("uncon") def test_none(self): self._test_proc_type("none") def test_adding(self): self._test_proc_type("adding") def test_independ(self): self._test_proc_type("independ") def test_slice(self): self._test_proc_type("slice") def test_downsample(self): self._test_proc_type("downsample") class TestApplyDDFAlignment: """Test 8: _apply_ddf matches OM_reg_flexres.apply_ddf.""" def test_vs_flexres(self): om = _make_omorpher() vol = _rand_vol().to(DEVICE) ddf = torch.randn(1, NDIMS, IMG_SIZE, IMG_SIZE, IMG_SIZE, device=DEVICE) * 0.01 # OMorpher version out_om = om._apply_ddf(vol, ddf, padding_mode="border") # Inline reimplementation of OM_reg_flexres.apply_ddf for comparison ndims = 3 img_sz = list(vol.shape[2:]) max_sz = torch.reshape( torch.tensor(img_sz, dtype=torch.float32, device=DEVICE), [1, ndims] + [1] * ndims) ref_grid = torch.reshape( torch.stack(torch.meshgrid( [torch.arange(s, device=DEVICE) for s in img_sz], indexing='ij'), 0), [1, ndims] + img_sz) img_shape = torch.reshape( torch.tensor([(s - 1) / 2. 
for s in img_sz], dtype=torch.float32, device=DEVICE), [1] + [1] * ndims + [ndims]) grid = torch.flip( (ddf * max_sz + ref_grid).permute( [0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1, dims=[-1]) out_ref = F.grid_sample(vol, grid.float(), mode="bilinear", padding_mode="border", align_corners=True) assert torch.allclose(out_om, out_ref, atol=1e-6), "apply_ddf output mismatch" # ================================================================ # Runner # ================================================================ def run_all(): import traceback test_classes = [ TestInstantiation, TestStandardization, TestLabelStandardization, TestZeroDDFRoundtrip, TestSizeMismatch, TestPredict, TestFinetune, TestSetters, TestDDFScaleAlignment, TestRandomDDFAlignment, TestConditioningAlignment, TestApplyDDFAlignment, ] passed = 0 failed = 0 errors = [] for cls in test_classes: inst = cls() for name in sorted(dir(inst)): if not name.startswith("test"): continue full_name = f"{cls.__name__}.{name}" try: getattr(inst, name)() passed += 1 print(f" PASS {full_name}") except Exception as e: failed += 1 errors.append((full_name, e)) print(f" FAIL {full_name}: {e}") traceback.print_exc() print(f"\n{'='*60}") print(f"Results: {passed} passed, {failed} failed out of {passed+failed}") if errors: print("Failures:") for name, e in errors: print(f" - {name}: {e}") return failed == 0 if __name__ == "__main__": success = run_all() sys.exit(0 if success else 1)