| """
|
Tests for MSLNCC (Multi-Scale Local Normalized Cross-Correlation) loss.

Verifies:

1. Loss increases monotonically with larger spatial translations.

2. Gradients are finite and non-zero for both inputs (only the final loss is
   differentiated; individual scale branches are not checked separately).

3. Consistency with single-scale LNCC when only one scale is used.

4. Label masking: the monotonicity and gradient checks also hold when a
   label mask is supplied.

5. Adding coarser scales changes the loss relative to the single-scale case.

Run:

    python -m pytest tests/test_mslncc.py -v

    # or directly:

    python tests/test_mslncc.py
|
| """
|
|
|
| import os
|
| import sys
|
|
|
| import torch
|
|
|
|
|
# Repository root: the parent of the directory containing this file. It is
# prepended to sys.path so the project package ``Diffusion`` (imported below)
# resolves no matter which working directory the tests are launched from.
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.insert(0, ROOT)
|
|
|
| from Diffusion.losses import LNCC, MSLNCC
|
|
|
|
|
|
|
# Edge length, in voxels, of the cubic test volume.
SIZE = 64

# Seed the global torch RNG so REF_IMG (and anything drawn from the global
# generator afterwards) is reproducible from run to run.
torch.manual_seed(42)

# Shared reference volume of uniform noise in [0, 1), shape (1, 1, SIZE, SIZE, SIZE).
REF_IMG = torch.rand((1, 1) + (SIZE,) * 3)
|
|
|
|
|
def translate_image(img, shift):
    """Return a copy of ``img`` shifted by ``shift`` elements along its last axis.

    A positive ``shift`` moves content toward higher indices and a negative
    one toward lower indices. Positions vacated by the shift are zero-filled,
    and a shift at least as long as the axis yields an all-zero result. The
    input tensor is left unmodified.
    """
    shifted = torch.zeros_like(img)
    if shift > 0:
        shifted[..., shift:] = img[..., :-shift]
    elif shift < 0:
        shifted[..., :shift] = img[..., -shift:]
    else:
        # ``img[..., :-0]`` would be an empty slice, so zero is a plain copy.
        shifted.copy_(img)
    return shifted
|
|
|
|
|
|
|
|
|
def test_loss_increases_with_translation():
    """Check that shifting J further from I never lowers the MSLNCC loss.

    The loss is a negated NCC, so a worse alignment should score higher.
    REF_IMG is compared against copies of itself shifted along the last axis
    by growing amounts. Consecutive losses must be non-decreasing, and the
    largest shift must be measurably worse than no shift at all.
    """
    loss_fn = MSLNCC(smooth=True, central=True)
    translations = [0, 2, 4, 8, 16]

    # One forward pass per shift, in the same order as ``translations``.
    losses = [
        loss_fn(REF_IMG, translate_image(REF_IMG, shift)).item()
        for shift in translations
    ]

    # NOTE(review): REF_IMG is i.i.d. uniform noise. Once the shift exceeds
    # whatever smoothing the loss applies internally, neighbouring losses may
    # be nearly equal, and this strict ordering could then hinge on tiny
    # numerical differences. Confirm it stays stable.
    steps = zip(translations, translations[1:], losses, losses[1:])
    for t_prev, t_next, loss_prev, loss_next in steps:
        assert loss_next >= loss_prev, (
            f"Loss did not increase: translation {t_prev}->{t_next}, "
            f"loss {loss_prev:.6f}->{loss_next:.6f}"
        )

    assert losses[-1] > losses[0] + 1e-4, (
        f"Loss range too small: {losses[0]:.6f} to {losses[-1]:.6f}"
    )
    print(f" translations: {translations}")
    print(f" losses: {[f'{value:.6f}' for value in losses]}")
|
|
|
|
|
|
|
|
|
def test_gradient_flows():
    """Check that backpropagating the MSLNCC loss yields usable gradients.

    Both inputs are leaf tensors with ``requires_grad`` enabled. After
    ``backward()``, each must have a gradient that exists, is finite
    everywhere and is not identically zero. Only the final loss value is
    differentiated; individual scale branches are not inspected separately.
    """
    loss_fn = MSLNCC(smooth=True, central=True)
    I = REF_IMG.clone().requires_grad_(True)
    J = translate_image(REF_IMG, 4).clone().requires_grad_(True)

    loss_fn(I, J).backward()

    # Same three checks, first for I and then for J.
    for name, tensor in (("I", I), ("J", J)):
        grad = tensor.grad
        assert grad is not None, f"No gradient for {name}"
        assert torch.isfinite(grad).all(), f"Non-finite gradient for {name}"
        assert grad.abs().sum() > 0, f"Zero gradient for {name}"

    print(f" I grad norm: {I.grad.norm():.6f}")
    print(f" J grad norm: {J.grad.norm():.6f}")
|
|
|
|
|
def test_gradient_with_label():
    """Check gradient health when MSLNCC is given a label mask.

    The mask is a centred cube of ones spanning the middle half of each of
    the last three axes, with zeros elsewhere. Gradients for both inputs
    must exist, be finite and not be identically zero.
    """
    loss_fn = MSLNCC(smooth=True, central=True)
    I = REF_IMG.clone().requires_grad_(True)
    J = translate_image(REF_IMG, 4).clone().requires_grad_(True)

    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[..., lo:hi, lo:hi, lo:hi] = 1.0

    loss_fn(I, J, label=label).backward()

    pairs = (("I", I), ("J", J))
    # Existence/finiteness is checked for both tensors before non-zero-ness.
    for name, tensor in pairs:
        assert tensor.grad is not None and torch.isfinite(tensor.grad).all(), (
            f"Bad gradient for {name} with label"
        )
    for name, tensor in pairs:
        assert tensor.grad.abs().sum() > 0, f"Zero gradient for {name} with label"

    print(f" I grad norm (masked): {I.grad.norm():.6f}")
    print(f" J grad norm (masked): {J.grad.norm():.6f}")
|
|
|
|
|
|
|
|
|
def test_single_scale_matches_lncc():
    """Check that MSLNCC with a single scale ratio of 1 reproduces LNCC.

    Both losses are evaluated on the same (REF_IMG, shifted REF_IMG) pair,
    using identical ``smooth``/``central`` settings, and must agree to
    within 1e-6.
    """
    lncc_fn = LNCC(smooth=True, central=True)
    mslncc_fn = MSLNCC(
        smooth=True,
        central=True,
        scale_ratios=[1],
        scale_weights=[1],
    )

    moved = translate_image(REF_IMG, 4)
    loss_lncc = lncc_fn(REF_IMG, moved).item()
    loss_mslncc = mslncc_fn(REF_IMG, moved).item()

    gap = abs(loss_lncc - loss_mslncc)
    assert gap < 1e-6, (
        f"Single-scale MSLNCC ({loss_mslncc:.8f}) != LNCC ({loss_lncc:.8f})"
    )
    print(f" LNCC: {loss_lncc:.8f}")
    print(f" MSLNCC: {loss_mslncc:.8f}")
|
|
|
|
|
|
|
|
|
def test_multiscale_differs_from_single():
    """Check that adding coarser scale ratios changes the MSLNCC loss.

    A single-ratio configuration ([1]) and a three-ratio configuration
    ([1, 0.5, 0.25]) are evaluated on the same shifted pair. If the two
    losses were identical, the extra ratios would not be contributing.
    """
    single_fn = MSLNCC(
        smooth=True,
        central=True,
        scale_ratios=[1],
        scale_weights=[1],
    )
    multi_fn = MSLNCC(
        smooth=True,
        central=True,
        scale_ratios=[1, 0.5, 0.25],
        scale_weights=[1, 0.5, 0.25],
    )

    moved = translate_image(REF_IMG, 8)
    loss_single = single_fn(REF_IMG, moved).item()
    loss_multi = multi_fn(REF_IMG, moved).item()

    gap = abs(loss_single - loss_multi)
    assert gap > 1e-6, (
        f"Multi-scale loss ({loss_multi:.8f}) is identical to single-scale ({loss_single:.8f})"
    )
    print(f" single-scale: {loss_single:.8f}")
    print(f" multi-scale: {loss_multi:.8f}")
|
|
|
|
|
|
|
|
|
def test_loss_increases_with_translation_labeled():
    """Masked variant of the translation-monotonicity check.

    Identical to the unmasked check, except every loss is computed with a
    label mask: a centred cube of ones spanning the middle half of each of
    the last three axes, with zeros elsewhere.
    """
    loss_fn = MSLNCC(smooth=True, central=True)
    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[..., lo:hi, lo:hi, lo:hi] = 1.0

    translations = [0, 2, 4, 8, 16]
    losses = [
        loss_fn(REF_IMG, translate_image(REF_IMG, shift), label=label).item()
        for shift in translations
    ]

    # NOTE(review): as in the unmasked test, REF_IMG is i.i.d. noise, so the
    # strict ordering at larger shifts may rest on very small differences.
    steps = zip(translations, translations[1:], losses, losses[1:])
    for t_prev, t_next, loss_prev, loss_next in steps:
        assert loss_next >= loss_prev, (
            f"Labeled loss did not increase: translation {t_prev}->{t_next}, "
            f"loss {loss_prev:.6f}->{loss_next:.6f}"
        )
    assert losses[-1] > losses[0] + 1e-4, (
        f"Labeled loss range too small: {losses[0]:.6f} to {losses[-1]:.6f}"
    )
    print(f" translations: {translations}")
    print(f" losses: {[f'{value:.6f}' for value in losses]}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Local import: only needed when this file is executed as a script.
    import traceback

    # Minimal pytest-free runner: (display name, test callable), run in order.
    tests = [
        ("Loss increases with translation", test_loss_increases_with_translation),
        ("Gradient flows (no label)", test_gradient_flows),
        ("Gradient flows (with label)", test_gradient_with_label),
        ("Single-scale matches LNCC", test_single_scale_matches_lncc),
        ("Multi-scale differs from single", test_multiscale_differs_from_single),
        ("Loss increases with translation (labeled)", test_loss_increases_with_translation_labeled),
    ]
    passed = 0
    failed = 0
    for name, fn in tests:
        print(f"\n[TEST] {name}")
        try:
            fn()
        except AssertionError as e:
            # Expected failure mode: the assertion message carries the details.
            print(f" FAILED: {e}")
            failed += 1
        except Exception as e:
            # Anything else is a crash in the test or in the loss itself; the
            # one-line summary alone cannot show where it happened, so also
            # print the full traceback (to stderr).
            print(f" ERROR: {type(e).__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            print(" PASSED")
            passed += 1
    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
    # Non-zero exit status so shells/CI notice failures.
    if failed:
        sys.exit(1)
|
|
|