File size: 7,974 Bytes
"""
Tests for MSLNCC (Multi-Scale Local Normalized Cross-Correlation) loss.

Verifies:
1. Loss increases monotonically with larger spatial translations.
2. Gradients flow correctly through all scale branches.
3. Consistency with single-scale LNCC when only one scale is used.
4. Label masking works at all scales.
5. Multi-scale loss differs from the single-scale loss.

Run:
    python -m pytest tests/test_mslncc.py -v
    # or directly:
    python tests/test_mslncc.py
"""
import os
import sys
import torch
# Ensure project root is importable
ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)
from Diffusion.losses import LNCC, MSLNCC
# ---------- helpers ----------
# Edge length, in voxels, of the cubic volume used by every test below.
SIZE = 64  # must be divisible by max downscale factor (4 for ratio=0.25)
# Seed torch's global RNG before drawing REF_IMG so the volume is reproducible.
torch.manual_seed(42)
# Fixed reference image of shape (1, 1, SIZE, SIZE, SIZE) — reused across tests
REF_IMG = torch.rand(1, 1, SIZE, SIZE, SIZE)
def translate_image(img, shift):
    """Return a copy of `img` shifted by `shift` voxels along its last axis.

    A positive `shift` moves content toward higher indices, a negative one
    toward lower indices. Positions vacated by the shift are zero-filled and
    content pushed past the edge is dropped. The input tensor is not modified.
    """
    shifted = torch.zeros_like(img)
    if shift > 0:
        # Drop the trailing `shift` voxels; the leading ones stay zero.
        shifted[..., shift:] = img[..., :-shift]
    elif shift < 0:
        # Drop the leading `-shift` voxels; the trailing ones stay zero.
        shifted[..., :shift] = img[..., -shift:]
    else:
        shifted.copy_(img)
    return shifted
# ---------- Test 1: loss increases with translation ----------
def test_loss_increases_with_translation():
    """Check that the MSLNCC loss never drops as the shift between I and J grows.

    The loss is a negative NCC (higher = worse match), so comparing REF_IMG
    against copies of itself translated by 0, 2, 4, 8 and 16 voxels should
    give a non-decreasing sequence whose two ends differ by more than 1e-4.
    """
    criterion = MSLNCC(smooth=True, central=True)
    translations = [0, 2, 4, 8, 16]
    losses = [
        criterion(REF_IMG, translate_image(REF_IMG, t)).item()
        for t in translations
    ]

    # Walk adjacent pairs: a larger shift must not produce a smaller loss.
    for i, (previous, current) in enumerate(zip(losses, losses[1:]), start=1):
        assert current >= previous, (
            f"Loss did not increase: translation {translations[i-1]}->{translations[i]}, "
            f"loss {losses[i-1]:.6f}->{losses[i]:.6f}"
        )

    # Guard against a flat curve that satisfies ">=" trivially.
    assert losses[-1] > losses[0] + 1e-4, (
        f"Loss range too small: {losses[0]:.6f} to {losses[-1]:.6f}"
    )

    print(f" translations: {translations}")
    print(f" losses: {[f'{l:.6f}' for l in losses]}")
# ---------- Test 2: gradients are properly computed ----------
def test_gradient_flows():
    """Check that backprop through MSLNCC reaches both inputs.

    Both I (the reference) and J (the reference shifted by 4 voxels) must
    receive a gradient that exists, is finite everywhere and is not all zero.
    """
    criterion = MSLNCC(smooth=True, central=True)

    # Independent leaf tensors so each input accumulates its own .grad.
    I = REF_IMG.clone().requires_grad_(True)
    shifted = translate_image(REF_IMG, 4)
    J = shifted.clone().requires_grad_(True)

    criterion(I, J).backward()

    # Gradient w.r.t. the first input.
    assert I.grad is not None, "No gradient for I"
    assert torch.isfinite(I.grad).all(), "Non-finite gradient for I"
    assert I.grad.abs().sum() > 0, "Zero gradient for I"

    # Gradient w.r.t. the second input.
    assert J.grad is not None, "No gradient for J"
    assert torch.isfinite(J.grad).all(), "Non-finite gradient for J"
    assert J.grad.abs().sum() > 0, "Zero gradient for J"

    print(f" I grad norm: {I.grad.norm():.6f}")
    print(f" J grad norm: {J.grad.norm():.6f}")
def test_gradient_with_label():
    """Check that backprop through MSLNCC still reaches both inputs when a
    label mask is passed.

    The mask is 1 inside the central cube spanning SIZE//4 to 3*SIZE//4 on
    every spatial axis and 0 elsewhere.
    """
    criterion = MSLNCC(smooth=True, central=True)
    I = REF_IMG.clone().requires_grad_(True)
    shifted = translate_image(REF_IMG, 4)
    J = shifted.clone().requires_grad_(True)

    # Binary mask selecting the central cube of the volume.
    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[:, :, lo:hi, lo:hi, lo:hi] = 1.0

    criterion(I, J, label=label).backward()

    assert I.grad is not None and torch.isfinite(I.grad).all(), "Bad gradient for I with label"
    assert J.grad is not None and torch.isfinite(J.grad).all(), "Bad gradient for J with label"
    assert I.grad.abs().sum() > 0, "Zero gradient for I with label"
    assert J.grad.abs().sum() > 0, "Zero gradient for J with label"

    print(f" I grad norm (masked): {I.grad.norm():.6f}")
    print(f" J grad norm (masked): {J.grad.norm():.6f}")
# ---------- Test 3: single-scale consistency with LNCC ----------
def test_single_scale_matches_lncc():
    """Check that MSLNCC restricted to the full-resolution scale equals LNCC.

    With scale_ratios=[1] and scale_weights=[1] the two losses, evaluated on
    the reference and a copy shifted by 4 voxels, must agree to within 1e-6.
    """
    lncc_fn = LNCC(smooth=True, central=True)
    mslncc_fn = MSLNCC(
        smooth=True,
        central=True,
        scale_ratios=[1],
        scale_weights=[1],
    )

    moving = translate_image(REF_IMG, 4)
    loss_lncc = lncc_fn(REF_IMG, moving).item()
    loss_mslncc = mslncc_fn(REF_IMG, moving).item()

    gap = abs(loss_lncc - loss_mslncc)
    assert gap < 1e-6, (
        f"Single-scale MSLNCC ({loss_mslncc:.8f}) != LNCC ({loss_lncc:.8f})"
    )

    print(f" LNCC: {loss_lncc:.8f}")
    print(f" MSLNCC: {loss_mslncc:.8f}")
# ---------- Test 4: multi-scale produces different loss than single-scale ----------
def test_multiscale_differs_from_single():
    """Check that adding coarser scales changes the loss value.

    A three-scale MSLNCC (ratios 1, 0.5, 0.25) must differ by more than 1e-6
    from a one-scale MSLNCC on the same pair of images; identical values
    would suggest the downsampled branches contribute nothing.
    """
    single_fn = MSLNCC(
        smooth=True,
        central=True,
        scale_ratios=[1],
        scale_weights=[1],
    )
    multi_fn = MSLNCC(
        smooth=True,
        central=True,
        scale_ratios=[1, 0.5, 0.25],
        scale_weights=[1, 0.5, 0.25],
    )

    moving = translate_image(REF_IMG, 8)
    loss_single = single_fn(REF_IMG, moving).item()
    loss_multi = multi_fn(REF_IMG, moving).item()

    gap = abs(loss_single - loss_multi)
    assert gap > 1e-6, (
        f"Multi-scale loss ({loss_multi:.8f}) is identical to single-scale ({loss_single:.8f})"
    )

    print(f" single-scale: {loss_single:.8f}")
    print(f" multi-scale: {loss_multi:.8f}")
# ---------- Test 5: loss increases with translation (with label) ----------
def test_loss_increases_with_translation_labeled():
    """Check that the masked MSLNCC loss never drops as the shift grows.

    Mirrors test_loss_increases_with_translation, but every evaluation
    receives a label mask that is 1 inside the central cube spanning
    SIZE//4 to 3*SIZE//4 on each spatial axis and 0 elsewhere.
    """
    criterion = MSLNCC(smooth=True, central=True)

    # Binary mask selecting the central cube of the volume.
    lo, hi = SIZE // 4, 3 * SIZE // 4
    label = torch.zeros(1, 1, SIZE, SIZE, SIZE)
    label[:, :, lo:hi, lo:hi, lo:hi] = 1.0

    translations = [0, 2, 4, 8, 16]
    losses = [
        criterion(REF_IMG, translate_image(REF_IMG, t), label=label).item()
        for t in translations
    ]

    # Walk adjacent pairs: a larger shift must not produce a smaller loss.
    for i, (previous, current) in enumerate(zip(losses, losses[1:]), start=1):
        assert current >= previous, (
            f"Labeled loss did not increase: translation {translations[i-1]}->{translations[i]}, "
            f"loss {losses[i-1]:.6f}->{losses[i]:.6f}"
        )

    # Guard against a flat curve that satisfies ">=" trivially.
    assert losses[-1] > losses[0] + 1e-4, (
        f"Labeled loss range too small: {losses[0]:.6f} to {losses[-1]:.6f}"
    )

    print(f" translations: {translations}")
    print(f" losses: {[f'{l:.6f}' for l in losses]}")
# ---------- runner ----------
if __name__ == "__main__":
    # Minimal stand-alone runner so the file works without pytest: run every
    # test in order, report each outcome, and exit non-zero on any failure.
    tests = (
        ("Loss increases with translation", test_loss_increases_with_translation),
        ("Gradient flows (no label)", test_gradient_flows),
        ("Gradient flows (with label)", test_gradient_with_label),
        ("Single-scale matches LNCC", test_single_scale_matches_lncc),
        ("Multi-scale differs from single", test_multiscale_differs_from_single),
        ("Loss increases with translation (labeled)", test_loss_increases_with_translation_labeled),
    )

    passed = failed = 0
    for name, fn in tests:
        print(f"\n[TEST] {name}")
        try:
            fn()
            print(f" PASSED")
            passed += 1
        except Exception as e:
            # A failed assertion is reported as FAILED; any other exception
            # is reported as ERROR together with its type name.
            if isinstance(e, AssertionError):
                print(f" FAILED: {e}")
            else:
                print(f" ERROR: {type(e).__name__}: {e}")
            failed += 1

    print(f"\n{'='*40}")
    print(f"Results: {passed} passed, {failed} failed out of {passed + failed}")
    if failed > 0:
        sys.exit(1)
|