| """
|
| Tests for the OMorpher module.
|
|
|
| Split into two groups:
|
| - Basic tests: verify shapes, value ranges, and API behaviour (no checkpoint needed)
|
- Alignment tests: cross-validate against DeformDDPM (shared weights) and an inline reference re-implementation of OM_reg_flexres.apply_ddf
|
|
|
| Run:
|
| python -m pytest tests/test_omorpher.py -v
|
| # or directly:
|
| python tests/test_omorpher.py
|
| """
|
|
|
| import os
|
| import sys
|
| import math
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
|
|
# Resolve the repository root (the parent of this file's directory) and put it
# first on sys.path so the project-local imports below (OMorpher, Diffusion.*)
# resolve without installing the package.
_TESTS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT = os.path.dirname(_TESTS_DIR)
sys.path.insert(0, ROOT)
|
|
|
| from OMorpher import OMorpher
|
| from Diffusion.diffuser import DeformDDPM
|
| from Diffusion.networks import get_net, STN
|
|
|
|
|
|
|
# Module-wide test settings: CPU device, 3-D volumes of edge 32 and a short
# 10-step diffusion schedule.
NDIMS = 3
IMG_SIZE = 32
TIMESTEPS = 10
NET_NAME = "recmutattnnet"
DEVICE = "cpu"

# Baseline OMorpher configuration; tests override individual keys on top of it.
BASE_CONFIG = dict(
    net_name=NET_NAME,
    ndims=NDIMS,
    img_size=IMG_SIZE,
    timesteps=TIMESTEPS,
    v_scale=5e-5,
    noise_scale=0.1,
    condition_type="none",
    num_input_chn=1,
    img_pad_mode="zeros",
    ddf_pad_mode="border",
    padding_mode="border",
    resample_mode="bilinear",
    batchsize=1,
    data_name="test",
    start_noise_step=0,
)
|
|
|
|
|
def _make_omorpher(**overrides):
    """Build an OMorpher on DEVICE from BASE_CONFIG with ``overrides`` applied.

    No checkpoint is loaded (``checkpoint_path=None``), so the network keeps
    whatever weights OMorpher initialises it with.
    """
    config = dict(BASE_CONFIG)
    config.update(overrides)
    return OMorpher(config=config, checkpoint_path=None, device=DEVICE)
|
|
|
|
|
| def _rand_vol(B=1, S=None):
|
| S = S or IMG_SIZE
|
| return torch.rand([B, 1] + [S] * NDIMS)
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestInstantiation:
    """Test 1: constructing OMorpher from a config dict with no checkpoint."""

    def test_creates_network_and_stns(self):
        om = _make_omorpher()
        # The network and all four spatial-transformer attributes must be set.
        for attr_name in ("network", "stn_full", "stn_ctl", "img_stn", "msk_stn"):
            assert getattr(om, attr_name) is not None

    def test_device(self):
        expected = torch.device(DEVICE)
        assert _make_omorpher().device == expected

    def test_repr(self):
        text = repr(_make_omorpher())
        for needle in ("OMorpher", NET_NAME):
            assert needle in text
|
|
|
|
|
class TestStandardization:
    """Test 2: _standardize_img yields a model-sized tensor in [0, 1] and a 5-D full-res copy."""

    def test_numpy_input(self):
        om = _make_omorpher()
        raw = 1000.0 * np.random.rand(40, 50, 60).astype(np.float32)
        model_in, fullres, orig_shape = om._standardize_img(raw, keep_raw=True)

        assert model_in.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        assert model_in.min() >= 0.0
        assert model_in.max() <= 1.0 + 1e-6

        assert fullres is not None
        assert isinstance(fullres, torch.Tensor)
        # The full-resolution copy is 5-D and cubic, with edge equal to the
        # largest input dimension (60).
        assert fullres.ndim == 5
        assert fullres.shape[2] == fullres.shape[3] == fullres.shape[4] == 60

        assert orig_shape[0] == orig_shape[1] == orig_shape[2]

    def test_torch_passthrough(self):
        om = _make_omorpher()
        vol = torch.rand(1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        standardized, _raw, _shape = om._standardize_img(vol)
        # A tensor already at model resolution keeps its shape.
        assert standardized.shape == vol.shape
|
|
|
|
|
class TestLabelStandardization:
    """_standardize_label: shapes for 3-D and 4-D labels, and the -1 placeholder for None."""

    @staticmethod
    def _omorpher_with_image():
        """Return a default OMorpher that already has a random initial image set."""
        om = _make_omorpher()
        om.set_init_img(_rand_vol().numpy()[0, 0])
        return om

    def test_3d_label(self):
        om = self._omorpher_with_image()
        mask = (np.random.rand(40, 50, 60) > 0.5).astype(np.float32)
        model_t, fullres_t = om._standardize_label(mask)
        assert model_t.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        assert fullres_t.ndim == 5
        assert fullres_t.shape[2] == fullres_t.shape[3] == fullres_t.shape[4] == 60
        assert isinstance(model_t, torch.Tensor)
        assert isinstance(fullres_t, torch.Tensor)

    def test_none_placeholder(self):
        om = self._omorpher_with_image()
        model_t, fullres_t = om._standardize_label(None)
        assert model_t.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        # A missing label is represented by tensors filled with -1.
        assert torch.all(model_t == -1)
        assert torch.all(fullres_t == -1)

    def test_4d_label(self):
        om = self._omorpher_with_image()
        # The trailing axis of size 2 ends up as the channel dimension.
        two_channel = (np.random.rand(30, 30, 30, 2) > 0.5).astype(np.float32)
        model_t, fullres_t = om._standardize_label(two_channel)
        assert model_t.shape == (1, 2, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        assert fullres_t.shape[1] == 2
|
|
|
|
|
class TestZeroDDFRoundtrip:
    """Test 3: warping with an all-zero DDF leaves the volume (numerically) unchanged."""

    def test_identity(self):
        om = _make_omorpher()
        vol = _rand_vol()
        identity_ddf = torch.zeros((1, NDIMS) + (IMG_SIZE,) * 3)
        warped = om._apply_ddf(vol, identity_ddf, padding_mode="border")
        assert torch.allclose(vol, warped, atol=1e-5)
|
|
|
|
|
class TestSizeMismatch:
    """Test 4: apply_def upscales a model-resolution DDF to a larger image's size."""

    def test_upscale(self):
        om = _make_omorpher()
        large = _rand_vol(S=64)
        coarse_ddf = torch.zeros((1, NDIMS) + (IMG_SIZE,) * 3)
        out = om.apply_def(img=large, ddf=coarse_ddf)
        # Output keeps the image's spatial size, not the DDF's.
        assert tuple(out.shape[2:]) == (64, 64, 64)
|
|
|
|
|
class TestPredict:
    """Test 6: predict with random weights yields a DDF of the expected shape.

    Runs at size 64 because RecMutAttnNet has 5 hierarchy levels: at 32 the
    path 32->16->8->4->2->1 leaves a 1-voxel bottleneck, which breaks
    InstanceNorm (no running stats).
    """

    _SZ = 64

    def _prepared(self):
        """Return an OMorpher at size _SZ with a random initial image already set."""
        om = _make_omorpher(img_size=self._SZ)
        img = torch.rand([1, 1] + [self._SZ] * NDIMS)
        om.set_init_img(img.numpy()[0, 0])
        return om

    def test_predict_shape(self):
        om = self._prepared()
        om.predict(T=[0, 2])
        edge = self._SZ
        assert om.get_def().shape == (1, NDIMS, edge, edge, edge)

    def test_predict_intermediate(self):
        om = self._prepared()
        om.predict(T=[0, 4], t_save=[3, 1])
        # Requesting several saved timesteps returns them as a dict.
        assert isinstance(om.get_def(t_list=[3, 1]), dict)

    def test_chaining(self):
        edge = self._SZ
        om = _make_omorpher(img_size=edge)
        img = torch.rand([1, 1] + [edge] * NDIMS)
        # set_init_img(...).predict(...) must chain back to the same instance.
        chained = om.set_init_img(img.numpy()[0, 0]).predict(T=[0, 2])
        assert chained is om
|
|
|
|
|
class TestFinetune:
    """Test 7: one finetune_step on a dummy volume reports float losses.

    Uses IMG_SIZE=64 because at 32 the bottleneck hits 1x1x1 and
    InstanceNorm fails in training mode.
    """

    def test_finetune_roundtrip(self):
        ft_size = 64
        om = _make_omorpher(img_size=ft_size, batchsize=1)
        om.finetune_setup(lr=1e-3)
        # Run teardown even when an assertion below fails, in case
        # finetune_setup changes state that outlives this test.
        try:
            vol = _rand_vol(S=ft_size)
            losses = om.finetune_step(vol)
            assert "loss_total" in losses
            assert "loss_grad" in losses
            assert isinstance(losses["loss_total"], float)
        finally:
            om.finetune_teardown()
|
|
|
|
|
class TestSetters:
    """Input setters: set_init_def, set_cond_img, set_cond_txt and set_init_img."""

    @staticmethod
    def _omorpher_with_image():
        """Return a default OMorpher that already has a random initial image set."""
        om = _make_omorpher()
        om.set_init_img(_rand_vol().numpy()[0, 0])
        return om

    def test_set_init_def_random(self):
        om = self._omorpher_with_image()
        om.set_init_def(None)
        # Passing None must produce a DDF that is not entirely zero.
        assert om._init_ddf is not None
        assert torch.any(om._init_ddf != 0)

    def test_set_init_def_provided(self):
        om = self._omorpher_with_image()
        om.set_init_def(np.zeros([1, NDIMS] + [IMG_SIZE] * NDIMS))
        assert torch.all(om._init_ddf == 0)

    def test_set_cond_img_default(self):
        om = self._omorpher_with_image()
        om.set_cond_img(None)
        cond = om._cond_img
        assert cond is not None
        assert cond.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)

    def test_set_cond_txt_numpy(self):
        om = _make_omorpher()
        om.set_cond_txt(np.random.randn(1024).astype(np.float32))
        # A 1-D embedding comes back with a leading batch axis.
        assert om._cond_txt is not None
        assert om._cond_txt.shape == (1, 1024)

    def test_set_init_img_with_ddf(self):
        om = _make_omorpher()
        image = np.random.rand(40, 40, 40).astype(np.float32)
        zero_ddf = np.zeros([1, NDIMS, IMG_SIZE, IMG_SIZE, IMG_SIZE], dtype=np.float32)
        # set_init_img also accepts an (image, ddf) tuple.
        om.set_init_img((image, zero_ddf))
        assert om._init_img is not None
        assert om._init_ddf is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_shared_weights():
    """Return an ``(OMorpher, DeformDDPM)`` pair whose networks hold identical weights.

    A fresh network from ``get_net(NET_NAME)`` is wrapped in a DeformDDPM.
    An OMorpher built from BASE_CONFIG with ``inf_mode=False`` then loads
    that network's state dict. Both models are switched to eval mode.
    """
    cfg = dict(BASE_CONFIG, inf_mode=False)
    network = get_net(NET_NAME)(
        n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE,
    )
    ddpm_kwargs = dict(
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=DEVICE,
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=cfg["v_scale"],
    )
    ddpm = DeformDDPM(network=network, **ddpm_kwargs)
    ddpm.eval()

    om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
    # Copy the DDPM's randomly initialised parameters into OMorpher's network.
    om.network.load_state_dict(ddpm.network.state_dict())
    om.network.eval()
    return om, ddpm
|
|
|
|
|
class TestDDFScaleAlignment:
    """Test 10: OMorpher._get_ddf_scale agrees with DeformDDPM._get_ddf_scale."""

    def test_all_timesteps(self):
        om, ddpm = _build_shared_weights()
        # NOTE(review): several of these exceed TIMESTEPS (10); the test relies on
        # both implementations accepting such t values - confirm this is intended.
        for step in (1, 5, 10, 20, 40, 60, 80):
            t = torch.tensor([step])
            rec_om, ddf_mul_om, dvf_mul_om = om._get_ddf_scale(t)
            rec_ref, ddf_mul_ref, dvf_mul_ref = ddpm._get_ddf_scale(t)
            assert rec_om == rec_ref, f"rec_num mismatch at t={step}"
            assert torch.equal(ddf_mul_om, ddf_mul_ref), f"mul_num_ddf mismatch at t={step}"
            assert torch.equal(dvf_mul_om, dvf_mul_ref), f"mul_num_dvf mismatch at t={step}"
|
|
|
|
|
class TestRandomDDFAlignment:
    """Test 9: _get_random_ddf matches DeformDDPM._get_random_ddf with the same seed.

    Uses ndims=2, IMG_SIZE=128 so that ctl_sz=32, scale_num=5, and
    len(ctl_szs_all)=5 > select_num=4. This avoids a known unbound-variable
    bug in the original DeformDDPM._random_ddf_generate at smaller sizes.
    """

    @staticmethod
    def _seed_all(seed):
        """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
        import random

        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def test_same_seed(self):
        align_size = 128
        align_ndims = 2
        cfg = {**BASE_CONFIG, "img_size": align_size, "ndims": align_ndims}
        Net = get_net(NET_NAME)
        network = Net(n_steps=TIMESTEPS, ndims=align_ndims, num_input_chn=1, res=align_size)
        ddpm = DeformDDPM(
            network=network, n_steps=TIMESTEPS,
            image_chw=[1] + [align_size] * align_ndims, device=DEVICE,
            batch_size=1, img_pad_mode="zeros", ddf_pad_mode="border",
            padding_mode="border", v_scale=cfg["v_scale"],
        )
        ddpm.eval()

        om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
        om.network.load_state_dict(ddpm.network.state_dict())
        om.network.eval()

        img = torch.rand([1, 1] + [align_size] * align_ndims).to(DEVICE)
        t = torch.tensor([5])

        # Re-seed before each call so that, if both implementations consume
        # randomness in the same order, they generate the same field.
        self._seed_all(42)
        warped_om, dvf_om, ddf_om = om._get_random_ddf(img, t)

        self._seed_all(42)
        warped_ddpm, dvf_ddpm, ddf_ddpm = ddpm._get_random_ddf(img, t)

        assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), "DDFs do not match"
        assert torch.allclose(dvf_om, dvf_ddpm, atol=1e-5), "DVFs do not match"
        assert torch.allclose(warped_om, warped_ddpm, atol=1e-5), "Warped images do not match"
|
|
|
|
|
class TestConditioningAlignment:
    """Test 11: OMorpher._proc_cond_img agrees with DeformDDPM.proc_cond_img.

    Both implementations are called after identical torch / NumPy / stdlib
    seeding for each supported ``proc_type``.
    """

    @staticmethod
    def _seed_all(seed):
        """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
        import random

        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def _test_proc_type(self, proc_type):
        """Run both implementations on the same image and compare their outputs."""
        om, ddpm = _build_shared_weights()
        img = _rand_vol().to(DEVICE)

        self._seed_all(99)
        out_om, mask_om, _cr_om = om._proc_cond_img(img, proc_type=proc_type)

        self._seed_all(99)
        out_ddpm, mask_ddpm, _cr_ddpm = ddpm.proc_cond_img(img, proc_type=proc_type)

        assert torch.allclose(out_om, out_ddpm, atol=1e-5), f"Proc image mismatch for {proc_type}"
        # Also compare the masks whenever both sides return tensors; their type
        # for each proc_type is not fixed by this file, so other values are skipped.
        if isinstance(mask_om, torch.Tensor) and isinstance(mask_ddpm, torch.Tensor):
            assert torch.allclose(mask_om.float(), mask_ddpm.float(), atol=1e-5), (
                f"Mask mismatch for {proc_type}"
            )

    def test_uncon(self):
        self._test_proc_type("uncon")

    def test_none(self):
        self._test_proc_type("none")

    def test_adding(self):
        self._test_proc_type("adding")

    def test_independ(self):
        self._test_proc_type("independ")

    def test_slice(self):
        self._test_proc_type("slice")

    def test_downsample(self):
        self._test_proc_type("downsample")
|
|
|
|
|
class TestApplyDDFAlignment:
    """Test 8: OMorpher._apply_ddf matches an inline reference of OM_reg_flexres.apply_ddf.

    The reference builds the sampling grid by hand. It scales the DDF by the
    image size and adds it to an integer index grid. It then maps indices
    0..s-1 onto [-1, 1] and flips the last axis into grid_sample's coordinate
    order.
    """

    def test_vs_flexres(self):
        om = _make_omorpher()
        vol = _rand_vol().to(DEVICE)
        ddf = torch.randn([1, NDIMS] + [IMG_SIZE] * NDIMS, device=DEVICE) * 0.01

        out_om = om._apply_ddf(vol, ddf, padding_mode="border")

        # Derive the spatial rank from the volume instead of hard-coding 3, so the
        # reference stays consistent with NDIMS.
        ndims = vol.dim() - 2
        img_sz = list(vol.shape[2:])
        max_sz = torch.reshape(
            torch.tensor(img_sz, dtype=torch.float32, device=DEVICE),
            [1, ndims] + [1] * ndims)
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                [torch.arange(s, device=DEVICE) for s in img_sz], indexing='ij'), 0),
            [1, ndims] + img_sz)
        img_shape = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in img_sz], dtype=torch.float32, device=DEVICE),
            [1] + [1] * ndims + [ndims])
        grid = torch.flip(
            (ddf * max_sz + ref_grid).permute(
                [0] + list(range(2, 2 + ndims)) + [1]) / img_shape - 1,
            dims=[-1])
        out_ref = F.grid_sample(vol, grid.float(), mode="bilinear",
                                padding_mode="border", align_corners=True)

        assert torch.allclose(out_om, out_ref, atol=1e-6), "apply_ddf output mismatch"
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_all():
    """Run every test class in this module without pytest.

    One instance is created per class, and each callable attribute whose name
    starts with ``test`` is invoked on it. Failures are reported with their
    exception type, because bare ``assert`` statements carry no message.
    A traceback is printed for each failure.

    Returns:
        True if every test passed, False otherwise.
    """
    import traceback

    test_classes = [
        TestInstantiation,
        TestStandardization,
        TestLabelStandardization,
        TestZeroDDFRoundtrip,
        TestSizeMismatch,
        TestPredict,
        TestFinetune,
        TestSetters,
        TestDDFScaleAlignment,
        TestRandomDDFAlignment,
        TestConditioningAlignment,
        TestApplyDDFAlignment,
    ]
    passed = 0
    failed = 0
    errors = []
    for cls in test_classes:
        inst = cls()
        for name in sorted(dir(inst)):
            if not name.startswith("test"):
                continue
            method = getattr(inst, name)
            # Skip non-method attributes that merely share the "test" prefix.
            if not callable(method):
                continue
            full_name = f"{cls.__name__}.{name}"
            try:
                method()
            except Exception as e:
                failed += 1
                # Include the type: str(AssertionError()) from a bare assert is "".
                description = f"{type(e).__name__}: {e}"
                errors.append((full_name, description))
                print(f" FAIL {full_name}: {description}")
                traceback.print_exc()
            else:
                passed += 1
                print(f" PASS {full_name}")
    print(f"\n{'='*60}")
    print(f"Results: {passed} passed, {failed} failed out of {passed+failed}")
    if errors:
        print("Failures:")
        for name, description in errors:
            print(f" - {name}: {description}")
    return failed == 0
|
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 only when run_all() reports that every test passed.
    sys.exit(int(not run_all()))
|
|
|