"""
Tests for MSLNCC (Multi-Scale Local Normalized Cross-Correlation) loss.

Verifies:
  1. Loss is monotonically non-decreasing with larger spatial translations
     (with and without a label mask).
  2. Gradients w.r.t. both inputs are present, finite and non-zero
     (with and without a label mask).
  3. Consistency with single-scale LNCC when only one scale is used.
  4. A multi-scale configuration yields a different loss value than a
     single-scale one.

Run:
    python -m pytest tests/test_mslncc.py -v
    # or directly:
    python tests/test_mslncc.py
"""

import os
import sys

import torch

# Ensure project root is importable
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)

from Diffusion.losses import LNCC, MSLNCC  # noqa: E402  (needs ROOT on sys.path)


# ---------- helpers ----------

SIZE = 64  # must be divisible by max downscale factor (4 for ratio=0.25)

torch.manual_seed(42)

# Fixed reference images — reused across tests
REF_IMG = torch.rand(1, 1, SIZE, SIZE, SIZE)


def translate_image(img, shift):
    """Translate image along the last axis by `shift` voxels (zero-fill).

    A positive shift moves content towards higher indices, a negative shift
    towards lower indices. Voxels that have no source are left at zero.
    Returns a new tensor; `img` is not modified.
    """
    out = torch.zeros_like(img)
    if shift == 0:
        out.copy_(img)
    elif shift > 0:
        out[..., shift:] = img[..., :-shift]
    else:
        out[..., :shift] = img[..., -shift:]
    return out


def _central_cube_label():
    """Return a (1, 1, SIZE, SIZE, SIZE) mask that is 1 in the central cube."""
    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[:, :, lo:hi, lo:hi, lo:hi] = 1.0
    return label


def _check_translation_sweep(loss_fn, prefix, label=None):
    """Assert the loss never decreases as REF_IMG is shifted further.

    `prefix` is the leading text of the assertion messages. `label` is only
    forwarded to `loss_fn` when given, so the unlabeled call signature is
    exactly `loss_fn(I, J)`.
    """
    extra = {} if label is None else {"label": label}
    translations = [0, 2, 4, 8, 16]
    losses = [
        loss_fn(REF_IMG, translate_image(REF_IMG, shift), **extra).item()
        for shift in translations
    ]

    # loss should be monotonically non-decreasing between consecutive shifts
    steps = list(zip(translations, losses))
    for (t_prev, loss_prev), (t_next, loss_next) in zip(steps, steps[1:]):
        assert loss_next >= loss_prev, (
            f"{prefix} did not increase: translation {t_prev}->{t_next}, "
            f"loss {loss_prev:.6f}->{loss_next:.6f}"
        )

    # first and last should be clearly different
    assert losses[-1] > losses[0] + 1e-4, (
        f"{prefix} range too small: {losses[0]:.6f} to {losses[-1]:.6f}"
    )

    print(f" translations: {translations}")
    print(f" losses: {[f'{value:.6f}' for value in losses]}")


# ---------- Test 1: loss increases with translation ----------

def test_loss_increases_with_translation():
    """MSLNCC loss (negative NCC, so higher = worse match) should be
    monotonically non-decreasing as the translation between I and J grows."""
    loss_fn = MSLNCC(smooth=True, central=True)
    _check_translation_sweep(loss_fn, "Loss")


# ---------- Test 2: gradients are properly computed ----------

def test_gradient_flows():
    """Verify gradients of the total loss are present, finite and non-zero
    for both I and J."""
    loss_fn = MSLNCC(smooth=True, central=True)

    I = REF_IMG.clone().requires_grad_(True)
    J = translate_image(REF_IMG, 4).clone().requires_grad_(True)

    loss = loss_fn(I, J)
    loss.backward()

    # Check I gradient
    assert I.grad is not None, "No gradient for I"
    assert torch.isfinite(I.grad).all(), "Non-finite gradient for I"
    assert I.grad.abs().sum() > 0, "Zero gradient for I"

    # Check J gradient
    assert J.grad is not None, "No gradient for J"
    assert torch.isfinite(J.grad).all(), "Non-finite gradient for J"
    assert J.grad.abs().sum() > 0, "Zero gradient for J"

    print(f" I grad norm: {I.grad.norm():.6f}")
    print(f" J grad norm: {J.grad.norm():.6f}")


def test_gradient_with_label():
    """Verify gradients are present, finite and non-zero when a label mask
    is provided."""
    loss_fn = MSLNCC(smooth=True, central=True)

    I = REF_IMG.clone().requires_grad_(True)
    J = translate_image(REF_IMG, 4).clone().requires_grad_(True)

    # Label: central cube
    label = _central_cube_label()

    loss = loss_fn(I, J, label=label)
    loss.backward()

    assert I.grad is not None and torch.isfinite(I.grad).all(), "Bad gradient for I with label"
    assert J.grad is not None and torch.isfinite(J.grad).all(), "Bad gradient for J with label"
    assert I.grad.abs().sum() > 0, "Zero gradient for I with label"
    assert J.grad.abs().sum() > 0, "Zero gradient for J with label"

    print(f" I grad norm (masked): {I.grad.norm():.6f}")
    print(f" J grad norm (masked): {J.grad.norm():.6f}")


# ---------- Test 3: single-scale consistency with LNCC ----------

def test_single_scale_matches_lncc():
    """MSLNCC with scale_ratios=[1] should produce the same loss as LNCC."""
    lncc_fn = LNCC(smooth=True, central=True)
    mslncc_fn = MSLNCC(smooth=True, central=True,
                       scale_ratios=[1], scale_weights=[1])

    J = translate_image(REF_IMG, 4)

    loss_lncc = lncc_fn(REF_IMG, J).item()
    loss_mslncc = mslncc_fn(REF_IMG, J).item()

    assert abs(loss_lncc - loss_mslncc) < 1e-6, (
        f"Single-scale MSLNCC ({loss_mslncc:.8f}) != LNCC ({loss_lncc:.8f})"
    )

    print(f" LNCC: {loss_lncc:.8f}")
    print(f" MSLNCC: {loss_mslncc:.8f}")


# ---------- Test 4: multi-scale produces different loss than single-scale ----------

def test_multiscale_differs_from_single():
    """Multi-scale loss should differ from single-scale (coarser scales see
    different structure), confirming downsampled branches contribute."""
    single_fn = MSLNCC(smooth=True, central=True,
                       scale_ratios=[1], scale_weights=[1])
    multi_fn = MSLNCC(smooth=True, central=True,
                      scale_ratios=[1, 0.5, 0.25], scale_weights=[1, 0.5, 0.25])

    J = translate_image(REF_IMG, 8)

    loss_single = single_fn(REF_IMG, J).item()
    loss_multi = multi_fn(REF_IMG, J).item()

    assert abs(loss_single - loss_multi) > 1e-6, (
        f"Multi-scale loss ({loss_multi:.8f}) is identical to single-scale ({loss_single:.8f})"
    )

    print(f" single-scale: {loss_single:.8f}")
    print(f" multi-scale: {loss_multi:.8f}")


# ---------- Test 5: loss increases with translation (with label) ----------

def test_loss_increases_with_translation_labeled():
    """Same as test_loss_increases_with_translation but with a label mask."""
    loss_fn = MSLNCC(smooth=True, central=True)
    _check_translation_sweep(loss_fn, "Labeled loss", label=_central_cube_label())


# ---------- runner ----------

def main():
    """Run every test in order, print a summary, and exit 1 on any failure."""
    tests = [
        ("Loss increases with translation", test_loss_increases_with_translation),
        ("Gradient flows (no label)", test_gradient_flows),
        ("Gradient flows (with label)", test_gradient_with_label),
        ("Single-scale matches LNCC", test_single_scale_matches_lncc),
        ("Multi-scale differs from single", test_multiscale_differs_from_single),
        ("Loss increases with translation (labeled)", test_loss_increases_with_translation_labeled),
    ]

    passed = 0
    failed = 0
    for name, fn in tests:
        print(f"\n[TEST] {name}")
        try:
            fn()
        except AssertionError as e:
            print(f" FAILED: {e}")
            failed += 1
        except Exception as e:
            # Top-level boundary: report the error and keep running the rest.
            print(f" ERROR: {type(e).__name__}: {e}")
            failed += 1
        else:
            print(" PASSED")
            passed += 1

    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    main()