Omini3D / tests /test_omorpher.py
maxmo2009's picture
Sync from local: code + epoch-110 checkpoint, clean README
2af0e94 verified
"""
Tests for the OMorpher module.
Split into two groups:
- Basic tests: verify shapes, value ranges, and API behaviour (no checkpoint needed)
- Alignment tests: cross-validate against DeformDDPM / OM_reg_flexres (shared weights)
Run:
python -m pytest tests/test_omorpher.py -v
# or directly:
python tests/test_omorpher.py
"""
import os
import sys
import math
import numpy as np
import torch
import torch.nn.functional as F
# Ensure project root is importable
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
from OMorpher import OMorpher
from Diffusion.diffuser import DeformDDPM
from Diffusion.networks import get_net, STN
# ---------- shared test configuration ----------
# Everything is kept tiny so the CPU-only suite stays fast.
NDIMS = 3
IMG_SIZE = 32  # tiny for speed
TIMESTEPS = 10
NET_NAME = "recmutattnnet"
DEVICE = "cpu"

# Baseline OMorpher config. Tests derive variants from it via
# ``_make_omorpher(**overrides)`` or ``{**BASE_CONFIG, ...}``.
BASE_CONFIG = dict(
    net_name=NET_NAME,
    ndims=NDIMS,
    img_size=IMG_SIZE,
    timesteps=TIMESTEPS,
    v_scale=5e-5,
    noise_scale=0.1,
    condition_type="none",
    num_input_chn=1,
    img_pad_mode="zeros",
    ddf_pad_mode="border",
    padding_mode="border",
    resample_mode="bilinear",
    batchsize=1,
    data_name="test",
    start_noise_step=0,
)
def _make_omorpher(**overrides):
    """Build an OMorpher on DEVICE (no checkpoint) from BASE_CONFIG updated with ``overrides``."""
    config = dict(BASE_CONFIG)
    config.update(overrides)
    return OMorpher(config=config, checkpoint_path=None, device=DEVICE)
def _rand_vol(B=1, S=None):
    """Return a uniform [0, 1) tensor of shape [B, 1] + [S] * NDIMS.

    A falsy ``S`` (None or 0) falls back to IMG_SIZE.
    """
    side = S if S else IMG_SIZE
    shape = [B, 1] + [side] * NDIMS
    return torch.rand(shape)
# ================================================================
# 1. Basic tests
# ================================================================
class TestInstantiation:
    """Test 1: OMorpher built from a config dict with no checkpoint."""

    def test_creates_network_and_stns(self):
        om = _make_omorpher()
        # The network plus every spatial-transformer attribute must exist.
        for attr_name in ("network", "stn_full", "stn_ctl", "img_stn", "msk_stn"):
            assert getattr(om, attr_name) is not None

    def test_device(self):
        om = _make_omorpher()
        assert om.device == torch.device(DEVICE)

    def test_repr(self):
        text = repr(_make_omorpher())
        for needle in ("OMorpher", NET_NAME):
            assert needle in text
class TestStandardization:
    """Test 2: ``_standardize_img`` returns the expected shapes and a [0, 1] range."""

    def test_numpy_input(self):
        om = _make_omorpher()
        raw_vol = np.random.rand(40, 50, 60).astype(np.float32) * 1000.0
        model_in, fullres, orig_shape = om._standardize_img(raw_vol, keep_raw=True)

        # Model-resolution tensor: resized to the config cube, intensities in [0, 1].
        assert model_in.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        assert model_in.min() >= 0.0
        assert model_in.max() <= 1.0 + 1e-6

        # keep_raw=True also yields a full-resolution tensor, cube-padded to the
        # largest input dimension (60).
        assert fullres is not None
        assert isinstance(fullres, torch.Tensor)
        assert fullres.ndim == 5
        assert tuple(fullres.shape[2:]) == (60, 60, 60)

        # orig_shape reports the cube-padded size, so its first three entries match.
        assert orig_shape[0] == orig_shape[1] == orig_shape[2]

    def test_torch_passthrough(self):
        om = _make_omorpher()
        already_sized = torch.rand(1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        standardized, _raw, _shape = om._standardize_img(already_sized)
        assert standardized.shape == already_sized.shape
class TestLabelStandardization:
    """``_standardize_label``: output shapes, channel handling and the ``None`` placeholder."""

    @staticmethod
    def _om_with_image():
        """Return a fresh OMorpher whose init image (a random IMG_SIZE cube) is set."""
        om = _make_omorpher()
        om.set_init_img(_rand_vol().numpy()[0, 0])
        return om

    def test_3d_label(self):
        om = self._om_with_image()
        mask = (np.random.rand(40, 50, 60) > 0.5).astype(np.float32)
        model_t, fullres_t = om._standardize_label(mask)
        assert model_t.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        assert fullres_t.ndim == 5
        # Cube-padded to the largest input dimension: max(40, 50, 60) = 60.
        assert tuple(fullres_t.shape[2:]) == (60, 60, 60)
        assert isinstance(model_t, torch.Tensor)
        assert isinstance(fullres_t, torch.Tensor)

    def test_none_placeholder(self):
        om = self._om_with_image()
        model_t, fullres_t = om._standardize_label(None)
        assert model_t.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        # A missing label becomes tensors filled with -1.
        assert torch.all(model_t == -1)
        assert torch.all(fullres_t == -1)

    def test_4d_label(self):
        om = self._om_with_image()
        two_channel = (np.random.rand(30, 30, 30, 2) > 0.5).astype(np.float32)
        model_t, fullres_t = om._standardize_label(two_channel)
        # A trailing channel axis is moved to channels-first: 2 channels.
        assert model_t.shape == (1, 2, IMG_SIZE, IMG_SIZE, IMG_SIZE)
        assert fullres_t.shape[1] == 2
class TestZeroDDFRoundtrip:
    """Test 3: ``_apply_ddf`` with an all-zero DDF reproduces the input (atol 1e-5)."""

    def test_identity(self):
        om = _make_omorpher()
        source = _rand_vol()
        no_motion = torch.zeros([1, NDIMS] + [IMG_SIZE] * 3)
        warped = om._apply_ddf(source, no_motion, padding_mode="border")
        assert torch.allclose(source, warped, atol=1e-5)
class TestSizeMismatch:
    """Test 4: ``apply_def`` with a DDF smaller than the image keeps the image's size."""

    def test_upscale(self):
        om = _make_omorpher()
        large_img = _rand_vol(S=64)
        coarse_ddf = torch.zeros([1, NDIMS] + [IMG_SIZE] * 3)
        warped = om.apply_def(img=large_img, ddf=coarse_ddf)
        assert tuple(warped.shape[2:]) == (64, 64, 64)
class TestPredict:
    """Test 6: ``predict`` with random weights produces a correctly shaped DDF.

    Uses a 64-voxel cube because RecMutAttnNet has 5 hierarchy levels: a
    32-voxel input shrinks 32→16→8→4→2→1, and that 1x1x1 bottleneck breaks
    InstanceNorm (no running stats).
    """

    _SIZE = 64

    def _om_and_image(self):
        """Return a fresh OMorpher configured for _SIZE and a matching random image."""
        om = _make_omorpher(img_size=self._SIZE)
        image = torch.rand([1, 1] + [self._SIZE] * NDIMS)
        return om, image

    def test_predict_shape(self):
        om, image = self._om_and_image()
        om.set_init_img(image.numpy()[0, 0])
        om.predict(T=[0, 2])
        size = self._SIZE
        assert om.get_def().shape == (1, NDIMS, size, size, size)

    def test_predict_intermediate(self):
        om, image = self._om_and_image()
        om.set_init_img(image.numpy()[0, 0])
        om.predict(T=[0, 4], t_save=[3, 1])
        # Requesting saved steps returns a dict rather than a single tensor.
        assert isinstance(om.get_def(t_list=[3, 1]), dict)

    def test_chaining(self):
        om, image = self._om_and_image()
        # set_init_img and predict both return the instance (fluent API).
        chained = om.set_init_img(image.numpy()[0, 0]).predict(T=[0, 2])
        assert chained is om
class TestFinetune:
    """Test 7: ``finetune_step`` with dummy data.

    Uses IMG_SIZE=64 because at 32 the bottleneck hits 1x1x1 and
    InstanceNorm fails in training mode.
    """

    def test_finetune_roundtrip(self):
        ft_size = 64
        om = _make_omorpher(img_size=ft_size, batchsize=1)
        om.finetune_setup(lr=1e-3)
        try:
            vol = _rand_vol(S=ft_size)
            losses = om.finetune_step(vol)
            assert "loss_total" in losses
            assert "loss_grad" in losses
            assert isinstance(losses["loss_total"], float)
        finally:
            # Pair every finetune_setup with a teardown, even when an assertion
            # above fails, so whatever state setup installed is always undone.
            om.finetune_teardown()
class TestSetters:
    """Input setters: ``set_init_def``, ``set_cond_img``, ``set_cond_txt``, ``set_init_img``."""

    @staticmethod
    def _om_with_image():
        """Return a fresh OMorpher whose init image (a random IMG_SIZE cube) is set."""
        om = _make_omorpher()
        om.set_init_img(_rand_vol().numpy()[0, 0])
        return om

    def test_set_init_def_random(self):
        om = self._om_with_image()
        om.set_init_def(None)  # None asks for a randomly generated DDF
        generated = om._init_ddf
        assert generated is not None
        assert torch.any(generated != 0)

    def test_set_init_def_provided(self):
        om = self._om_with_image()
        om.set_init_def(np.zeros([1, NDIMS] + [IMG_SIZE] * NDIMS))
        assert torch.all(om._init_ddf == 0)

    def test_set_cond_img_default(self):
        om = self._om_with_image()
        om.set_cond_img(None)
        cond = om._cond_img
        assert cond is not None
        assert cond.shape == (1, 1, IMG_SIZE, IMG_SIZE, IMG_SIZE)

    def test_set_cond_txt_numpy(self):
        om = _make_omorpher()
        embedding = np.random.randn(1024).astype(np.float32)
        om.set_cond_txt(embedding)
        # A 1-D embedding is stored with a leading batch axis.
        assert om._cond_txt is not None
        assert om._cond_txt.shape == (1, 1024)

    def test_set_init_img_with_ddf(self):
        om = _make_omorpher()
        image = np.random.rand(40, 40, 40).astype(np.float32)
        init_ddf = np.zeros([1, NDIMS, IMG_SIZE, IMG_SIZE, IMG_SIZE], dtype=np.float32)
        # An (image, ddf) tuple sets both the init image and the init DDF.
        om.set_init_img((image, init_ddf))
        assert om._init_img is not None
        assert om._init_ddf is not None
# ================================================================
# 2. Cross-validation / alignment tests
# ================================================================
def _build_shared_weights():
    """Return ``(om, ddpm)``: an OMorpher and a DeformDDPM sharing identical weights.

    Both networks are put in eval mode. ``inf_mode`` is forced to False on the
    OMorpher config to match the DeformDDPM default.
    """
    om_cfg = dict(BASE_CONFIG, inf_mode=False)  # match DeformDDPM default
    net_cls = get_net(NET_NAME)
    ddpm = DeformDDPM(
        network=net_cls(n_steps=TIMESTEPS, ndims=NDIMS, num_input_chn=1, res=IMG_SIZE),
        n_steps=TIMESTEPS,
        image_chw=[1] + [IMG_SIZE] * NDIMS,
        device=DEVICE,
        batch_size=1,
        img_pad_mode="zeros",
        ddf_pad_mode="border",
        padding_mode="border",
        v_scale=om_cfg["v_scale"],
    )
    ddpm.eval()

    om = OMorpher(config=om_cfg, checkpoint_path=None, device=DEVICE)
    # OMorpher builds its own network (no checkpoint); overwrite its weights
    # with DeformDDPM's so both models compute with the same parameters.
    om.network.load_state_dict(ddpm.network.state_dict())
    om.network.eval()
    return om, ddpm
class TestDDFScaleAlignment:
    """Test 10: ``_get_ddf_scale`` agrees with ``DeformDDPM._get_ddf_scale``."""

    def test_all_timesteps(self):
        om, ddpm = _build_shared_weights()
        # NOTE(review): most of these exceed TIMESTEPS (10); presumably intentional
        # to exercise large-t scaling — confirm.
        for step in (1, 5, 10, 20, 40, 60, 80):
            t = torch.tensor([step])
            rec_om, ddf_mul_om, dvf_mul_om = om._get_ddf_scale(t)
            rec_ref, ddf_mul_ref, dvf_mul_ref = ddpm._get_ddf_scale(t)
            assert rec_om == rec_ref, f"rec_num mismatch at t={step}"
            assert torch.equal(ddf_mul_om, ddf_mul_ref), f"mul_num_ddf mismatch at t={step}"
            assert torch.equal(dvf_mul_om, dvf_mul_ref), f"mul_num_dvf mismatch at t={step}"
class TestRandomDDFAlignment:
    """Test 9: ``_get_random_ddf`` matches ``DeformDDPM._get_random_ddf`` with the same seed.

    Uses ndims=2, IMG_SIZE=128 so that ctl_sz=32, scale_num=5,
    len(ctl_szs_all)=5 > select_num=4 — avoiding a known unbound-variable
    bug in the original DeformDDPM._random_ddf_generate at smaller sizes.
    """

    @staticmethod
    def _seed_all(seed):
        """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
        import random

        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def test_same_seed(self):
        align_size = 128
        align_ndims = 2
        cfg = {**BASE_CONFIG, "img_size": align_size, "ndims": align_ndims}
        Net = get_net(NET_NAME)
        network = Net(n_steps=TIMESTEPS, ndims=align_ndims, num_input_chn=1, res=align_size)
        ddpm = DeformDDPM(
            network=network, n_steps=TIMESTEPS,
            image_chw=[1] + [align_size] * align_ndims, device=DEVICE,
            batch_size=1, img_pad_mode="zeros", ddf_pad_mode="border",
            padding_mode="border", v_scale=cfg["v_scale"],
        )
        ddpm.eval()
        om = OMorpher(config=cfg, checkpoint_path=None, device=DEVICE)
        # Share weights so any difference comes from the DDF generator itself.
        om.network.load_state_dict(ddpm.network.state_dict())
        om.network.eval()

        img = torch.rand([1, 1] + [align_size] * align_ndims).to(DEVICE)
        t = torch.tensor([5])

        # Re-seed every RNG identically before each call so both
        # implementations draw the same random numbers.
        self._seed_all(42)
        warped_om, dvf_om, ddf_om = om._get_random_ddf(img, t)

        self._seed_all(42)
        warped_ddpm, dvf_ddpm, ddf_ddpm = ddpm._get_random_ddf(img, t)

        assert torch.allclose(ddf_om, ddf_ddpm, atol=1e-5), "DDFs do not match"
        assert torch.allclose(dvf_om, dvf_ddpm, atol=1e-5), "DVFs do not match"
        assert torch.allclose(warped_om, warped_ddpm, atol=1e-5), "Warped images do not match"
class TestConditioningAlignment:
    """Test 11: ``_proc_cond_img`` matches ``DeformDDPM.proc_cond_img`` for each proc_type.

    Both implementations process the same volume right after every RNG has
    been re-seeded identically, so any random draws in the conditioning path
    line up.
    """

    @staticmethod
    def _seed_all(seed):
        """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``."""
        import random

        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

    def _test_proc_type(self, proc_type):
        om, ddpm = _build_shared_weights()
        img = _rand_vol().to(DEVICE)

        self._seed_all(99)
        out_om, mask_om, cr_om = om._proc_cond_img(img, proc_type=proc_type)

        self._seed_all(99)
        out_ddpm, mask_ddpm, cr_ddpm = ddpm.proc_cond_img(img, proc_type=proc_type)

        # NOTE(review): only the processed image is compared. The second and
        # third return values (mask_*, cr_*) are unpacked but never checked;
        # their types are not visible here — consider comparing them too.
        assert torch.allclose(out_om, out_ddpm, atol=1e-5), f"Proc image mismatch for {proc_type}"

    def test_uncon(self):
        self._test_proc_type("uncon")

    def test_none(self):
        self._test_proc_type("none")

    def test_adding(self):
        self._test_proc_type("adding")

    def test_independ(self):
        self._test_proc_type("independ")

    def test_slice(self):
        self._test_proc_type("slice")

    def test_downsample(self):
        self._test_proc_type("downsample")
class TestApplyDDFAlignment:
    """Test 8: ``_apply_ddf`` matches ``OM_reg_flexres.apply_ddf``.

    The reference below reimplements ``OM_reg_flexres.apply_ddf`` inline:
    1. multiply the DDF by the per-axis volume size;
    2. add it to an integer voxel-index grid;
    3. map the result to [-1, 1] (the ``align_corners=True`` convention);
    4. flip the last axis into ``grid_sample``'s coordinate order.
    """

    def test_vs_flexres(self):
        om = _make_omorpher()
        vol = _rand_vol().to(DEVICE)
        ddf = torch.randn([1, NDIMS] + [IMG_SIZE] * NDIMS, device=DEVICE) * 0.01

        # OMorpher version
        out_om = om._apply_ddf(vol, ddf, padding_mode="border")

        # Inline reimplementation of OM_reg_flexres.apply_ddf for comparison.
        # Use NDIMS (not a literal 3) so the reference stays consistent with
        # the volume and DDF built above.
        ndims = NDIMS
        img_sz = list(vol.shape[2:])
        # Per-axis volume size, broadcastable against the [1, ndims, ...] DDF.
        vol_sz = torch.reshape(
            torch.tensor(img_sz, dtype=torch.float32, device=DEVICE),
            [1, ndims] + [1] * ndims,
        )
        # Integer voxel coordinates, one channel per axis.
        ref_grid = torch.reshape(
            torch.stack(
                torch.meshgrid(
                    [torch.arange(s, device=DEVICE) for s in img_sz], indexing="ij"
                ),
                0,
            ),
            [1, ndims] + img_sz,
        )
        # (s - 1) / 2 per axis: dividing by it and subtracting 1 maps index 0 -> -1
        # and index s-1 -> +1.
        half_extent = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in img_sz], dtype=torch.float32, device=DEVICE),
            [1] + [1] * ndims + [ndims],
        )
        grid = torch.flip(
            (ddf * vol_sz + ref_grid).permute([0] + list(range(2, 2 + ndims)) + [1])
            / half_extent
            - 1,
            dims=[-1],
        )
        out_ref = F.grid_sample(vol, grid.float(), mode="bilinear",
                                padding_mode="border", align_corners=True)
        assert torch.allclose(out_om, out_ref, atol=1e-6), "apply_ddf output mismatch"
# ================================================================
# Runner
# ================================================================
def run_all():
import traceback
test_classes = [
TestInstantiation,
TestStandardization,
TestLabelStandardization,
TestZeroDDFRoundtrip,
TestSizeMismatch,
TestPredict,
TestFinetune,
TestSetters,
TestDDFScaleAlignment,
TestRandomDDFAlignment,
TestConditioningAlignment,
TestApplyDDFAlignment,
]
passed = 0
failed = 0
errors = []
for cls in test_classes:
inst = cls()
for name in sorted(dir(inst)):
if not name.startswith("test"):
continue
full_name = f"{cls.__name__}.{name}"
try:
getattr(inst, name)()
passed += 1
print(f" PASS {full_name}")
except Exception as e:
failed += 1
errors.append((full_name, e))
print(f" FAIL {full_name}: {e}")
traceback.print_exc()
print(f"\n{'='*60}")
print(f"Results: {passed} passed, {failed} failed out of {passed+failed}")
if errors:
print("Failures:")
for name, e in errors:
print(f" - {name}: {e}")
return failed == 0
if __name__ == "__main__":
    # Exit status 0 only when every test passed, so shells/CI can detect failures.
    all_passed = run_all()
    sys.exit(int(not all_passed))