# Omini3D / tests / test_safe_conv_transpose_equiv.py
# maxmo2009's picture
# Sync from local: code + epoch-110 checkpoint, clean README
# 2af0e94 verified
"""
Test that SafeConvTranspose3d (V1, decomposed) is functionally identical to
torch.nn.ConvTranspose3d for the exact channel configurations used in OM_net
(the production network in OM_train_3modes.py).
OM_net.feat_channels = [1, 12, 32, 64, 128, 512]
Up-layers use SafeConvTranspose3d(ch, ch, 4, 2, 1) for ch in [512, 128, 64, 32, 12].
Tests:
1. Forward output matches within float32 tolerance (typically ~5e-7; the tests assert < 1e-4)
2. Backward gradients (input, weight, bias) match
3. state_dict is interchangeable (load ConvTranspose3d weights into Safe and vice versa)
4. Multi-step optimization trajectories stay close
5. Full OM_net up-path simulation with chained layers
6. No-bias variant
7. Numerical gradient check (float64)
8. Batch dimension invariance
9. Determinism
Usage:
python tests/test_safe_conv_transpose_equiv.py
python -m pytest tests/test_safe_conv_transpose_equiv.py -v (if pytest available)
"""
import copy
import torch
import torch.nn as nn
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from Diffusion.safe_conv_transpose import SafeConvTranspose3d
# Decoder-side channel widths: OM_net.feat_channels reversed, without the
# single-channel input entry.
OM_NET_UP_CHANNELS = [512, 128, 64, 32, 12]

# Transposed-convolution settings used in OM_net (networks.py line 1058):
# kernel size, stride and padding.
K = 4
S = 2
P = 1
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_pair(in_c, out_c, bias=True):
    """Build a reference layer and a SafeConvTranspose3d sharing parameters.

    The global RNG is seeded with 42, then ``nn.ConvTranspose3d`` and
    ``SafeConvTranspose3d`` are constructed (in that order) with the
    module-level ``K``/``S``/``P`` settings. The safe layer's weight, and its
    bias when one exists, are overwritten with the reference values.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        bias: Whether both layers carry a bias term.

    Returns:
        Tuple ``(ref, safe)``.
    """
    torch.manual_seed(42)
    ctor_args = (in_c, out_c, K, S, P)
    reference = nn.ConvTranspose3d(*ctor_args, bias=bias)
    decomposed = SafeConvTranspose3d(*ctor_args, bias=bias)

    decomposed.weight.data.copy_(reference.weight.data)
    has_bias = bias and reference.bias is not None
    if has_bias:
        decomposed.bias.data.copy_(reference.bias.data)
    return reference, decomposed
# ---------------------------------------------------------------------------
# 1. Forward precision — exact OM_net channel configs
# ---------------------------------------------------------------------------
def test_forward_om_net_channels():
    """Compare forward outputs for every OM_net up-layer channel width.

    For each width a weight-matched (reference, safe) pair is evaluated on a
    random ``(1, ch, 4, 4, 4)`` input; shapes must agree and the largest
    absolute element-wise difference must stay below 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        ref, safe = _make_pair(ch, ch)
        # Spatial size 4 is the smallest the decoder sees
        # (bottleneck at 128 / 2**5 = 4).
        sample = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            out_ref, out_safe = ref(sample), safe(sample)

        assert out_ref.shape == out_safe.shape, (
            f"Shape mismatch: {out_ref.shape} vs {out_safe.shape}"
        )
        max_diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert max_diff < 1e-4, (
            f"ch={ch}: forward max diff {max_diff:.2e} exceeds 1e-4"
        )
        print(f" ch={ch:3d}: max_diff={max_diff:.2e}")
    print("PASS: test_forward_om_net_channels")
def test_forward_larger_spatial():
    """Compare forward outputs on channel-dependent spatial sizes.

    The cubic input side is ``512 // ch`` clamped to ``[2, 8]``: 2 for
    ch=512, 4 for ch=128 and 8 for the narrower layers, which keeps memory
    reasonable for the widest layer. The largest absolute difference between
    the reference and safe outputs must stay below 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        ref, safe = _make_pair(ch, ch)

        # Clamp 512 // ch into [2, 8]; keeps memory reasonable for ch=512.
        spatial = 512 // ch
        spatial = max(2, spatial)
        spatial = min(8, spatial)

        sample = torch.randn(1, ch, spatial, spatial, spatial)
        with torch.no_grad():
            out_ref, out_safe = ref(sample), safe(sample)

        max_diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert max_diff < 1e-4, (
            f"ch={ch} spatial={spatial}: max diff {max_diff:.2e}"
        )
        print(f" ch={ch:3d} spatial={spatial}: max_diff={max_diff:.2e}")
    print("PASS: test_forward_larger_spatial")
# ---------------------------------------------------------------------------
# 2. Backward gradients — OM_net channel configs
# ---------------------------------------------------------------------------
def test_backward_om_net_channels():
    """Check that backward gradients match the reference layer.

    For every OM_net up-layer channel width, the same upstream gradient is
    back-propagated through ``nn.ConvTranspose3d`` and ``SafeConvTranspose3d``
    holding identical parameters. Gradients w.r.t. the input must agree
    within 1e-4; gradients w.r.t. weight and bias within 1e-3.
    """
    for ch in OM_NET_UP_CHANNELS:
        # Shared helper: seeds the RNG with 42, builds both layers with bias
        # and copies weight + bias from the reference into the safe layer.
        ref, safe = _make_pair(ch, ch)

        # Shrink the spatial side for wide layers (2 for ch >= 128, else 4).
        spatial = min(4, max(2, 256 // ch))

        torch.manual_seed(123)
        x_ref = torch.randn(1, ch, spatial, spatial, spatial, requires_grad=True)
        # Independent leaf with the same values so each layer gets its own .grad.
        x_safe = x_ref.detach().clone().requires_grad_(True)

        # kernel=4, stride=2, padding=1 doubles every spatial dimension.
        torch.manual_seed(456)
        grad_y = torch.randn(1, ch, 2 * spatial, 2 * spatial, 2 * spatial)

        ref(x_ref).backward(grad_y)
        safe(x_safe).backward(grad_y)

        dx = (x_ref.grad - x_safe.grad).abs().max().item()
        dw = (ref.weight.grad - safe.weight.grad).abs().max().item()
        db = (ref.bias.grad - safe.bias.grad).abs().max().item()

        assert dx < 1e-4, f"ch={ch}: grad_input diff {dx:.2e}"
        assert dw < 1e-3, f"ch={ch}: grad_weight diff {dw:.2e}"
        assert db < 1e-3, f"ch={ch}: grad_bias diff {db:.2e}"
        print(f" ch={ch:3d}: dx={dx:.2e} dw={dw:.2e} db={db:.2e}")
    print("PASS: test_backward_om_net_channels")
# ---------------------------------------------------------------------------
# 3. state_dict compatibility
# ---------------------------------------------------------------------------
def test_state_dict_interchangeable():
    """Check that state_dicts load across the two layer types.

    Direction 1: a ``nn.ConvTranspose3d`` state_dict is loaded into a fresh
    ``SafeConvTranspose3d``. Direction 2: a ``SafeConvTranspose3d`` state_dict
    is loaded into a fresh ``nn.ConvTranspose3d``. In both directions the
    loaded layer must reproduce the source layer's output within 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        # --- ConvTranspose3d -> Safe ---------------------------------------
        source_ref, _ = _make_pair(ch, ch)
        safe_loaded = SafeConvTranspose3d(ch, ch, K, S, P, bias=True)
        safe_loaded.load_state_dict(source_ref.state_dict())

        x = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            expected = source_ref(x)
            got = safe_loaded(x)
        diff_fwd = torch.max(torch.abs(expected - got)).item()
        assert diff_fwd < 1e-4, (
            f"ch={ch}: forward diff after load {diff_fwd:.2e}"
        )

        # --- Safe -> ConvTranspose3d (same input reused) ---------------------
        _, source_safe = _make_pair(ch, ch)
        ref_loaded = nn.ConvTranspose3d(ch, ch, K, S, P, bias=True)
        ref_loaded.load_state_dict(source_safe.state_dict())

        with torch.no_grad():
            expected_rev = source_safe(x)
            got_rev = ref_loaded(x)
        diff_rev = torch.max(torch.abs(expected_rev - got_rev)).item()
        assert diff_rev < 1e-4, f"ch={ch}: reverse load diff {diff_rev:.2e}"

        print(f" ch={ch:3d}: fwd_load_diff={diff_fwd:.2e} rev_load_diff={diff_rev:.2e}")
    print("PASS: test_state_dict_interchangeable")
# ---------------------------------------------------------------------------
# 4. Multi-step optimization drift
# ---------------------------------------------------------------------------
def test_optimization_drift():
    """After N Adam steps, parameters and per-step losses stay close.

    Two weight-matched layers (reference and safe) are trained with separate
    Adam optimizers (lr=1e-3) on identical inputs, using ``output.sum()`` as
    the loss. At every step the two loss values must agree (rtol=1e-3,
    atol=1e-2); after the final step the largest absolute weight and bias
    differences must each be below 1e-3.
    """
    n_steps = 10
    for ch in [32, 128]:  # representative subset to save time
        # Shared helper: seeds the RNG with 42 and syncs weight + bias.
        ref, safe = _make_pair(ch, ch)
        opt_ref = torch.optim.Adam(ref.parameters(), lr=1e-3)
        opt_safe = torch.optim.Adam(safe.parameters(), lr=1e-3)
        spatial = min(4, max(2, 256 // ch))

        for step in range(n_steps):
            # Re-seed per step so both layers see the same reproducible batch.
            torch.manual_seed(step * 100)
            x = torch.randn(1, ch, spatial, spatial, spatial)

            opt_ref.zero_grad()
            loss_ref = ref(x).sum()
            loss_ref.backward()
            opt_ref.step()

            opt_safe.zero_grad()
            loss_safe = safe(x).sum()
            loss_safe.backward()
            opt_safe.step()

            # The loss at step k reflects the parameters after k-1 updates,
            # so comparing it tracks divergence along the whole trajectory.
            assert torch.allclose(
                loss_safe.detach(), loss_ref.detach(), rtol=1e-3, atol=1e-2
            ), (
                f"ch={ch}: loss mismatch at step {step}: "
                f"ref={loss_ref.item():.6e} safe={loss_safe.item():.6e}"
            )

        w_drift = (ref.weight.data - safe.weight.data).abs().max().item()
        b_drift = (ref.bias.data - safe.bias.data).abs().max().item()
        assert w_drift < 1e-3, f"ch={ch}: weight drift {w_drift:.2e} after {n_steps} steps"
        assert b_drift < 1e-3, f"ch={ch}: bias drift {b_drift:.2e} after {n_steps} steps"
        print(f" ch={ch:3d}: weight_drift={w_drift:.2e} bias_drift={b_drift:.2e} ({n_steps} steps)")
    print("PASS: test_optimization_drift")
# ---------------------------------------------------------------------------
# 5. Chained up-path (simulates OM_net decoder)
# ---------------------------------------------------------------------------
def test_chained_up_path():
    """Run a five-stage decoder chain through both layer types.

    Each stage applies an up-layer (512->512, 128->128, 64->64, 32->32,
    12->12); between stages a 1x1 ``nn.Conv3d`` adaptor reduces the channel
    count to the next width. One chain uses ``nn.ConvTranspose3d``, the other
    ``SafeConvTranspose3d``, with identical parameters throughout. Starting
    from a ``(1, 512, 4, 4, 4)`` input, both chains must produce a
    ``(1, 12, 128, 128, 128)`` output, with max absolute difference below 1e-2
    and mean absolute difference below 1e-4.
    """
    torch.manual_seed(42)
    ref_layers, safe_layers = nn.ModuleList(), nn.ModuleList()
    ref_adaptors, safe_adaptors = nn.ModuleList(), nn.ModuleList()

    widths = OM_NET_UP_CHANNELS  # [512, 128, 64, 32, 12]
    # Pair every width with its successor; None marks the last stage,
    # which has no channel-reducing adaptor after it.
    successors = widths[1:] + [None]
    for ch, next_ch in zip(widths, successors):
        ref_up = nn.ConvTranspose3d(ch, ch, K, S, P)
        safe_up = SafeConvTranspose3d(ch, ch, K, S, P)
        safe_up.weight.data.copy_(ref_up.weight.data)
        safe_up.bias.data.copy_(ref_up.bias.data)
        ref_layers.append(ref_up)
        safe_layers.append(safe_up)

        if next_ch is None:
            continue

        # Channel reduction after the up-layer.
        ref_conv = nn.Conv3d(ch, next_ch, 1, 1, 0)
        safe_conv = nn.Conv3d(ch, next_ch, 1, 1, 0)
        safe_conv.weight.data.copy_(ref_conv.weight.data)
        safe_conv.bias.data.copy_(ref_conv.bias.data)
        ref_adaptors.append(ref_conv)
        safe_adaptors.append(safe_conv)

    # Forward through the chain: start at bottleneck spatial=4, ch=512.
    x = torch.randn(1, 512, 4, 4, 4)
    x_ref = x.clone()
    x_safe = x.clone()
    n_adaptors = len(ref_adaptors)
    with torch.no_grad():
        for stage, (ref_up, safe_up) in enumerate(zip(ref_layers, safe_layers)):
            x_ref = ref_up(x_ref)
            x_safe = safe_up(x_safe)
            if stage < n_adaptors:
                x_ref = ref_adaptors[stage](x_ref)
                x_safe = safe_adaptors[stage](x_safe)

    # After 5 upsample stages: 4 -> 8 -> 16 -> 32 -> 64 -> 128
    expected_shape = (1, 12, 128, 128, 128)
    assert x_ref.shape == expected_shape, f"Unexpected ref shape: {x_ref.shape}"
    assert x_safe.shape == expected_shape, f"Unexpected safe shape: {x_safe.shape}"

    abs_err = (x_ref - x_safe).abs()
    max_diff = abs_err.max().item()
    mean_diff = abs_err.mean().item()
    # Accumulated error over 5 layers — allow more tolerance.
    assert max_diff < 1e-2, f"Chained path max diff {max_diff:.2e}"
    assert mean_diff < 1e-4, f"Chained path mean diff {mean_diff:.2e}"
    print(f" 5-layer chain: max_diff={max_diff:.2e} mean_diff={mean_diff:.2e}")
    print("PASS: test_chained_up_path")
# ---------------------------------------------------------------------------
# 6. No bias variant
# ---------------------------------------------------------------------------
def test_no_bias():
    """Check the ``bias=False`` configuration.

    For channel widths 64 and 128 the safe layer must expose ``bias is None``
    and its forward output must match the reference layer within 1e-4.
    """
    for ch in (64, 128):
        ref, safe = _make_pair(ch, ch, bias=False)
        assert safe.bias is None

        sample = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            out_ref, out_safe = ref(sample), safe(sample)

        diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert diff < 1e-4, f"ch={ch} no-bias: diff {diff:.2e}"
        print(f" ch={ch:3d} no-bias: max_diff={diff:.2e}")
    print("PASS: test_no_bias")
# ---------------------------------------------------------------------------
# 7. Numerical gradient check (float64)
# ---------------------------------------------------------------------------
def test_gradcheck():
    """Verify SafeConvTranspose3d's input gradient numerically in float64.

    A 2-channel layer and a ``(1, 2, 3, 3, 3)`` input keep the finite
    difference Jacobian small enough for ``torch.autograd.gradcheck``.
    """
    layer = SafeConvTranspose3d(2, 2, K, S, P, bias=True).double()
    probe = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)

    tolerances = dict(eps=1e-6, atol=1e-4, rtol=1e-3)
    passed = torch.autograd.gradcheck(layer, (probe,), **tolerances)
    assert passed
    print("PASS: test_gradcheck")
# ---------------------------------------------------------------------------
# 8. Batch dimension invariance
# ---------------------------------------------------------------------------
def test_batch_sizes():
    """Compare reference and safe outputs for batch sizes 1, 2 and 4.

    A single weight-matched 64-channel pair is reused; for each batch size
    the output shapes must agree and the largest absolute difference must
    stay below 1e-4.
    """
    ch = 64
    ref, safe = _make_pair(ch, ch)
    for batch_size in (1, 2, 4):
        batch = torch.randn(batch_size, ch, 4, 4, 4)
        with torch.no_grad():
            out_ref, out_safe = ref(batch), safe(batch)

        assert out_ref.shape == out_safe.shape
        max_diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert max_diff < 1e-4, f"batch={batch_size}: max diff {max_diff:.2e}"
        print(f" batch={batch_size}: max_diff={max_diff:.2e}")
    print("PASS: test_batch_sizes")
# ---------------------------------------------------------------------------
# 9. Determinism — same input always gives same output
# ---------------------------------------------------------------------------
def test_determinism():
    """Check that two forward passes on the same input are bit-identical."""
    ch = 64
    layer = SafeConvTranspose3d(ch, ch, K, S, P)
    sample = torch.randn(1, ch, 4, 4, 4)

    with torch.no_grad():
        first_pass = layer(sample).clone()
        second_pass = layer(sample).clone()

    # torch.equal demands exact equality, not closeness.
    assert torch.equal(first_pass, second_pass), "SafeConvTranspose3d is not deterministic"
    print("PASS: test_determinism")
# ---------------------------------------------------------------------------
# Main runner (no pytest needed)
# ---------------------------------------------------------------------------
def _run_all_tests():
    """Run every test in order, print a summary, and return an exit status.

    Each test runs inside its own ``try`` so that one failure does not stop
    the remaining tests; failures are reported by name in the summary.

    Returns:
        0 when every test passed, 1 when at least one test raised.
    """
    print("=" * 70)
    print("SafeConvTranspose3d equivalence tests (OM_net channel configs)")
    print(f"OM_net up-layer channels: {OM_NET_UP_CHANNELS}")
    print(f"ConvTranspose params: kernel={K}, stride={S}, padding={P}")
    print("=" * 70)

    tests = [
        ("Forward (OM_net channels)", test_forward_om_net_channels),
        ("Forward (larger spatial)", test_forward_larger_spatial),
        ("Backward (OM_net channels)", test_backward_om_net_channels),
        ("state_dict compatibility", test_state_dict_interchangeable),
        ("Optimization drift (10 steps)", test_optimization_drift),
        ("Chained up-path (5 layers)", test_chained_up_path),
        ("No bias", test_no_bias),
        ("Numerical gradcheck", test_gradcheck),
        ("Batch size invariance", test_batch_sizes),
        ("Determinism", test_determinism),
    ]

    failed = []
    for name, fn in tests:
        print(f"\n--- {name} ---")
        try:
            fn()
        except Exception as e:
            # Top-level boundary: record the failure and keep running the
            # remaining tests so the summary lists every broken test.
            print(f"FAIL: {name}: {e}")
            failed.append(name)

    print("\n" + "=" * 70)
    if failed:
        print(f"FAILED ({len(failed)}/{len(tests)}): {', '.join(failed)}")
    else:
        print(f"ALL {len(tests)} TESTS PASSED")
    print("=" * 70)
    return 1 if failed else 0


# Main runner (no pytest needed)
if __name__ == '__main__':
    # Propagate failures to the shell / CI through the process exit status.
    sys.exit(_run_all_tests())