# Omini3D / tests/test_mslncc.py
# Commit 2af0e94 (maxmo2009): "Sync from local: code + epoch-110 checkpoint, clean README"
"""
Tests for MSLNCC (Multi-Scale Local Normalized Cross-Correlation) loss.
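Verifies:
1. Loss increases monotonically with larger spatial translations, with and
   without a label mask.
2. Gradients are finite and non-zero for both inputs, with and without a
   label mask.
3. Consistency with single-scale LNCC when only one scale is used.
4. Multi-scale loss differs from single-scale loss, i.e. the downsampled
   branches actually contribute.
5. Label masking is exercised with a central-cube mask (in the gradient and
   monotonicity tests); per-scale masking is not checked separately.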
Run:
python -m pytest tests/test_mslncc.py -v
# or directly:
python tests/test_mslncc.py
"""
import os
import sys
import torch
# Ensure project root is importable
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
from Diffusion.losses import LNCC, MSLNCC
# ---------- helpers ----------
# Edge length of the cubic test volume. It must be divisible by the largest
# downscale factor the multi-scale loss applies (4, for scale ratio 0.25).
SIZE = 64

# Seed the global torch RNG once, at import time, so the shared reference
# volume below is identical on every run.
torch.manual_seed(42)

# Shared random reference volume of shape (1, 1, SIZE, SIZE, SIZE); the tests
# compare it against translated copies of itself.
REF_IMG = torch.rand(1, 1, *([SIZE] * 3))
def translate_image(img, shift, dim=-1):
    """Return a copy of ``img`` shifted by ``shift`` voxels along one axis.

    Voxels shifted in from outside the volume are zero-filled. Voxels shifted
    past the far edge are dropped (no wrap-around, unlike ``torch.roll``).

    Args:
        img: Tensor to translate; it is not modified.
        shift: Integer number of voxels. Positive values move content toward
            higher indices and negative values toward lower indices. A
            magnitude at least as large as the axis length yields all zeros.
        dim: Axis to translate along. Defaults to the last axis, which is the
            behavior of the original single-axis helper.

    Returns:
        A new tensor with the same shape, dtype and device as ``img``.
    """
    out = torch.zeros_like(img)
    if shift == 0:
        out.copy_(img)
        return out
    # Put a full slice on every axis before ``dim`` so the shift can target
    # any axis, not only the last one.
    lead = (slice(None),) * (dim % img.dim())
    if shift > 0:
        out[lead + (slice(shift, None),)] = img[lead + (slice(None, -shift),)]
    else:
        out[lead + (slice(None, shift),)] = img[lead + (slice(-shift, None),)]
    return out
# ---------- Test 1: loss increases with translation ----------
def test_loss_increases_with_translation():
    """Larger translations of J relative to I must never lower the MSLNCC loss.

    MSLNCC is a negative NCC, so a worse match gives a higher value. The loss
    is evaluated for translations of 0, 2, 4, 8 and 16 voxels along the last
    axis. The sequence must be non-decreasing, and the 16-voxel loss must
    exceed the 0-voxel loss by more than 1e-4.
    """
    criterion = MSLNCC(smooth=True, central=True)
    shifts = [0, 2, 4, 8, 16]
    values = [criterion(REF_IMG, translate_image(REF_IMG, s)).item() for s in shifts]

    # Every adjacent pair must be non-decreasing.
    for k, (prev, curr) in enumerate(zip(values, values[1:]), start=1):
        assert curr >= prev, (
            f"Loss did not increase: translation {shifts[k-1]}->{shifts[k]}, "
            f"loss {prev:.6f}->{curr:.6f}"
        )

    # The end points must be separated by more than numerical noise.
    assert values[-1] > values[0] + 1e-4, (
        f"Loss range too small: {values[0]:.6f} to {values[-1]:.6f}"
    )
    print(f" translations: {shifts}")
    print(f" losses: {[f'{v:.6f}' for v in values]}")
# ---------- Test 2: gradients are properly computed ----------
def test_gradient_flows():
    """Backpropagating the MSLNCC loss must give finite, non-zero gradients.

    I is a copy of the reference volume and J is the reference translated by
    4 voxels. Both are made to require gradients before the loss is computed.
    """
    criterion = MSLNCC(smooth=True, central=True)
    I = REF_IMG.clone().requires_grad_(True)
    J = translate_image(REF_IMG, 4).clone().requires_grad_(True)
    criterion(I, J).backward()

    # Run the same three checks on I first, then on J.
    for tag, tensor in (("I", I), ("J", J)):
        grad = tensor.grad
        assert grad is not None, f"No gradient for {tag}"
        assert torch.isfinite(grad).all(), f"Non-finite gradient for {tag}"
        assert grad.abs().sum() > 0, f"Zero gradient for {tag}"

    print(f" I grad norm: {I.grad.norm():.6f}")
    print(f" J grad norm: {J.grad.norm():.6f}")
def test_gradient_with_label():
    """Gradients must stay finite and non-zero when a label mask is passed.

    The mask is 1 inside the cube spanning [SIZE//4, 3*SIZE//4) on each of the
    three spatial axes and 0 elsewhere.
    """
    criterion = MSLNCC(smooth=True, central=True)
    I = REF_IMG.clone().requires_grad_(True)
    J = translate_image(REF_IMG, 4).clone().requires_grad_(True)

    # Label: central cube
    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[:, :, lo:hi, lo:hi, lo:hi] = 1.0

    criterion(I, J, label=label).backward()

    pairs = (("I", I), ("J", J))
    # Existence/finiteness for both tensors first, then non-zero magnitude.
    for tag, tensor in pairs:
        assert tensor.grad is not None and torch.isfinite(tensor.grad).all(), (
            f"Bad gradient for {tag} with label"
        )
    for tag, tensor in pairs:
        assert tensor.grad.abs().sum() > 0, f"Zero gradient for {tag} with label"

    print(f" I grad norm (masked): {I.grad.norm():.6f}")
    print(f" J grad norm (masked): {J.grad.norm():.6f}")
# ---------- Test 3: single-scale consistency with LNCC ----------
def test_single_scale_matches_lncc():
    """With a single scale (ratio 1, weight 1), MSLNCC must reproduce LNCC.

    Both losses compare the reference volume with a 4-voxel translation of it.
    They must agree to within an absolute tolerance of 1e-6.
    """
    reference_fn = LNCC(smooth=True, central=True)
    one_scale_fn = MSLNCC(smooth=True, central=True,
                          scale_ratios=[1], scale_weights=[1])
    moved = translate_image(REF_IMG, 4)

    loss_lncc = reference_fn(REF_IMG, moved).item()
    loss_mslncc = one_scale_fn(REF_IMG, moved).item()
    mismatch = abs(loss_lncc - loss_mslncc)
    assert mismatch < 1e-6, (
        f"Single-scale MSLNCC ({loss_mslncc:.8f}) != LNCC ({loss_lncc:.8f})"
    )
    print(f" LNCC: {loss_lncc:.8f}")
    print(f" MSLNCC: {loss_mslncc:.8f}")
# ---------- Test 4: multi-scale produces different loss than single-scale ----------
def test_multiscale_differs_from_single():
    """Adding coarser scales must change the loss value.

    A single-scale MSLNCC (ratio 1) and a three-scale MSLNCC (ratios 1, 0.5,
    0.25 with weights 1, 0.5, 0.25) are evaluated on the reference volume
    against an 8-voxel translation. The two results must differ by more than
    1e-6, which shows the downsampled branches contribute.
    """
    one_scale = MSLNCC(smooth=True, central=True,
                       scale_ratios=[1], scale_weights=[1])
    three_scales = MSLNCC(smooth=True, central=True,
                          scale_ratios=[1, 0.5, 0.25],
                          scale_weights=[1, 0.5, 0.25])
    moved = translate_image(REF_IMG, 8)

    loss_single = one_scale(REF_IMG, moved).item()
    loss_multi = three_scales(REF_IMG, moved).item()
    gap = abs(loss_single - loss_multi)
    assert gap > 1e-6, (
        f"Multi-scale loss ({loss_multi:.8f}) is identical to single-scale ({loss_single:.8f})"
    )
    print(f" single-scale: {loss_single:.8f}")
    print(f" multi-scale: {loss_multi:.8f}")
# ---------- Test 5: loss increases with translation (with label) ----------
def test_loss_increases_with_translation_labeled():
    """Masked variant of the translation-monotonicity test.

    The label is 1 inside the central cube [SIZE//4, 3*SIZE//4) on each
    spatial axis and 0 elsewhere. The same 0/2/4/8/16-voxel translations must
    give a non-decreasing loss, with the end points more than 1e-4 apart.
    """
    criterion = MSLNCC(smooth=True, central=True)
    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[:, :, lo:hi, lo:hi, lo:hi] = 1.0

    shifts = [0, 2, 4, 8, 16]
    values = [
        criterion(REF_IMG, translate_image(REF_IMG, s), label=label).item()
        for s in shifts
    ]

    # Every adjacent pair must be non-decreasing.
    for k, (prev, curr) in enumerate(zip(values, values[1:]), start=1):
        assert curr >= prev, (
            f"Labeled loss did not increase: translation {shifts[k-1]}->{shifts[k]}, "
            f"loss {prev:.6f}->{curr:.6f}"
        )

    assert values[-1] > values[0] + 1e-4, (
        f"Labeled loss range too small: {values[0]:.6f} to {values[-1]:.6f}"
    )
    print(f" translations: {shifts}")
    print(f" losses: {[f'{v:.6f}' for v in values]}")
# ---------- runner ----------
if __name__ == "__main__":
    # Minimal stand-alone runner so the file also works without pytest.
    # Each test is called in turn. Assertion failures and unexpected
    # exceptions are both counted as failures, and the process exits with
    # status 1 if any test failed.
    import traceback

    tests = [
        ("Loss increases with translation", test_loss_increases_with_translation),
        ("Gradient flows (no label)", test_gradient_flows),
        ("Gradient flows (with label)", test_gradient_with_label),
        ("Single-scale matches LNCC", test_single_scale_matches_lncc),
        ("Multi-scale differs from single", test_multiscale_differs_from_single),
        ("Loss increases with translation (labeled)", test_loss_increases_with_translation_labeled),
    ]
    passed = 0
    failed = 0
    for name, fn in tests:
        print(f"\n[TEST] {name}")
        try:
            fn()
        except AssertionError as e:
            print(f" FAILED: {e}")
            failed += 1
        except Exception as e:
            # Unexpected error, e.g. inside the loss itself. The one-line
            # message rarely says where it came from, so also print the full
            # traceback (to stderr).
            print(f" ERROR: {type(e).__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            print(" PASSED")
            passed += 1
    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
    if failed:
        sys.exit(1)