""" Test that SafeConvTranspose3d (V1, decomposed) is functionally identical to torch.nn.ConvTranspose3d for the exact channel configurations used in OM_net (the production network in OM_train_3modes.py). OM_net.feat_channels = [1, 12, 32, 64, 128, 512] Up-layers use SafeConvTranspose3d(ch, ch, 4, 2, 1) for ch in [512, 128, 64, 32, 12]. Tests: 1. Forward output matches within float32 tolerance (~5e-7) 2. Backward gradients (input, weight, bias) match 3. state_dict is interchangeable (load ConvTranspose3d weights into Safe and vice versa) 4. Multi-step optimization trajectories stay close 5. Full OM_net up-path simulation with chained layers 6. No-bias variant 7. Numerical gradient check (float64) 8. Batch dimension invariance 9. Determinism Usage: python tests/test_safe_conv_transpose_equiv.py python -m pytest tests/test_safe_conv_transpose_equiv.py -v (if pytest available) """ import copy import torch import torch.nn as nn import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) from Diffusion.safe_conv_transpose import SafeConvTranspose3d # Exact channel sizes from OM_net.feat_channels (reversed for decoder) OM_NET_UP_CHANNELS = [512, 128, 64, 32, 12] # Kernel/stride/padding used in OM_net (networks.py line 1058) K, S, P = 4, 2, 1 # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_pair(in_c, out_c, bias=True): """Create nn.ConvTranspose3d and SafeConvTranspose3d with identical weights.""" torch.manual_seed(42) ref = nn.ConvTranspose3d(in_c, out_c, K, S, P, bias=bias) safe = SafeConvTranspose3d(in_c, out_c, K, S, P, bias=bias) safe.weight.data.copy_(ref.weight.data) if bias and ref.bias is not None: safe.bias.data.copy_(ref.bias.data) return ref, safe # --------------------------------------------------------------------------- # 1. 
# ---------------------------------------------------------------------------
# 1. Forward precision — exact OM_net channel configs
# ---------------------------------------------------------------------------

def test_forward_om_net_channels():
    """Forward output of SafeConvTranspose3d matches nn.ConvTranspose3d
    within float32 tolerance for each OM_net up-layer channel size.

    Raises:
        AssertionError: If output shapes differ or the maximum absolute
            difference reaches 1e-4 for any channel size.
    """
    for ch in OM_NET_UP_CHANNELS:
        ref, safe = _make_pair(ch, ch)
        # Spatial size 4 is the smallest the decoder sees (bottleneck at 128/(2^5)=4)
        x = torch.randn(1, ch, 4, 4, 4)

        with torch.no_grad():
            y_ref = ref(x)
            y_safe = safe(x)

        assert y_ref.shape == y_safe.shape, f"Shape mismatch: {y_ref.shape} vs {y_safe.shape}"
        max_diff = (y_ref - y_safe).abs().max().item()
        assert max_diff < 1e-4, f"ch={ch}: forward max diff {max_diff:.2e} exceeds 1e-4"
        print(f" ch={ch:3d}: max_diff={max_diff:.2e}")

    print("PASS: test_forward_om_net_channels")


def test_forward_larger_spatial():
    """Test at a larger spatial size — more summation, so numerical
    differences accumulate more.

    The spatial size is capped per channel count to bound memory; it is 8 for
    the smaller channel counts and shrinks to 2 for ch=512.

    Raises:
        AssertionError: If the maximum absolute difference reaches 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        ref, safe = _make_pair(ch, ch)
        spatial = min(8, max(2, 512 // ch))  # keep memory reasonable for ch=512
        x = torch.randn(1, ch, spatial, spatial, spatial)

        with torch.no_grad():
            y_ref = ref(x)
            y_safe = safe(x)

        max_diff = (y_ref - y_safe).abs().max().item()
        assert max_diff < 1e-4, f"ch={ch} spatial={spatial}: max diff {max_diff:.2e}"
        print(f" ch={ch:3d} spatial={spatial}: max_diff={max_diff:.2e}")

    print("PASS: test_forward_larger_spatial")


# ---------------------------------------------------------------------------
# 2. Backward gradients — OM_net channel configs
# ---------------------------------------------------------------------------

def test_backward_om_net_channels():
    """Gradients w.r.t. input, weight, and bias must match between
    nn.ConvTranspose3d and SafeConvTranspose3d.

    Both layers receive the same input values and the same upstream gradient;
    the resulting gradients are compared by maximum absolute difference.

    Raises:
        AssertionError: If the input gradient differs by 1e-4 or more, or the
            weight/bias gradient differs by 1e-3 or more.
    """
    for ch in OM_NET_UP_CHANNELS:
        # Shared helper: seeds the RNG with 42 and copies weight + bias.
        ref, safe = _make_pair(ch, ch, bias=True)

        spatial = min(4, max(2, 256 // ch))
        torch.manual_seed(123)
        x_ref = torch.randn(1, ch, spatial, spatial, spatial, requires_grad=True)
        # Independent leaf tensor with the same values, so each layer
        # accumulates its own input gradient.
        x_safe = x_ref.detach().clone().requires_grad_(True)

        torch.manual_seed(456)
        # The K=4, S=2, P=1 transpose conv doubles each spatial dimension.
        grad_y = torch.randn(1, ch, 2 * spatial, 2 * spatial, 2 * spatial)

        ref(x_ref).backward(grad_y)
        safe(x_safe).backward(grad_y)

        dx = (x_ref.grad - x_safe.grad).abs().max().item()
        dw = (ref.weight.grad - safe.weight.grad).abs().max().item()
        db = (ref.bias.grad - safe.bias.grad).abs().max().item()

        assert dx < 1e-4, f"ch={ch}: grad_input diff {dx:.2e}"
        assert dw < 1e-3, f"ch={ch}: grad_weight diff {dw:.2e}"
        assert db < 1e-3, f"ch={ch}: grad_bias diff {db:.2e}"
        print(f" ch={ch:3d}: dx={dx:.2e} dw={dw:.2e} db={db:.2e}")

    print("PASS: test_backward_om_net_channels")
# ---------------------------------------------------------------------------
# 3. state_dict compatibility
# ---------------------------------------------------------------------------

def test_state_dict_interchangeable():
    """Weights saved from nn.ConvTranspose3d load into SafeConvTranspose3d
    and produce the same output (and vice versa).

    Raises:
        AssertionError: If, after loading a state_dict in either direction,
            the outputs differ by 1e-4 or more.
        RuntimeError: Propagated from ``load_state_dict`` if the two modules'
            state_dict keys or shapes do not match.
    """
    for ch in OM_NET_UP_CHANNELS:
        # Forward direction: ConvTranspose3d -> Safe
        ref, _ = _make_pair(ch, ch)
        sd = ref.state_dict()

        safe_loaded = SafeConvTranspose3d(ch, ch, K, S, P, bias=True)
        safe_loaded.load_state_dict(sd)

        x = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            y_ref = ref(x)
            y_loaded = safe_loaded(x)

        diff_fwd = (y_ref - y_loaded).abs().max().item()
        assert diff_fwd < 1e-4, f"ch={ch}: forward diff after load {diff_fwd:.2e}"

        # Reverse direction: Safe -> ConvTranspose3d
        _, safe = _make_pair(ch, ch)
        sd_safe = safe.state_dict()

        ref2 = nn.ConvTranspose3d(ch, ch, K, S, P, bias=True)
        ref2.load_state_dict(sd_safe)

        with torch.no_grad():
            y_safe = safe(x)
            y_ref2 = ref2(x)

        diff_rev = (y_safe - y_ref2).abs().max().item()
        assert diff_rev < 1e-4, f"ch={ch}: reverse load diff {diff_rev:.2e}"
        print(f" ch={ch:3d}: fwd_load_diff={diff_fwd:.2e} rev_load_diff={diff_rev:.2e}")

    print("PASS: test_state_dict_interchangeable")
# ---------------------------------------------------------------------------
# 4. Multi-step optimization drift
# ---------------------------------------------------------------------------

def test_optimization_drift():
    """After N Adam steps, weights and losses stay close.

    Both layers start from identical parameters and are trained with separate
    Adam optimizers (lr=1e-3) on the same per-step inputs; the final weight
    and bias tensors are compared.

    Raises:
        AssertionError: If weight or bias drift reaches 1e-3 after the steps.
    """
    for ch in [32, 128]:  # representative subset to save time
        # Shared helper: seeds the RNG with 42 and copies weight + bias.
        ref, safe = _make_pair(ch, ch)

        opt_ref = torch.optim.Adam(ref.parameters(), lr=1e-3)
        opt_safe = torch.optim.Adam(safe.parameters(), lr=1e-3)

        spatial = min(4, max(2, 256 // ch))
        n_steps = 10
        for step in range(n_steps):
            # Per-step seed so both layers see the same reproducible input.
            torch.manual_seed(step * 100)
            x = torch.randn(1, ch, spatial, spatial, spatial)

            opt_ref.zero_grad()
            loss_ref = ref(x).sum()
            loss_ref.backward()
            opt_ref.step()

            opt_safe.zero_grad()
            loss_safe = safe(x).sum()
            loss_safe.backward()
            opt_safe.step()

        w_drift = (ref.weight.data - safe.weight.data).abs().max().item()
        b_drift = (ref.bias.data - safe.bias.data).abs().max().item()

        assert w_drift < 1e-3, f"ch={ch}: weight drift {w_drift:.2e} after {n_steps} steps"
        assert b_drift < 1e-3, f"ch={ch}: bias drift {b_drift:.2e} after {n_steps} steps"
        print(f" ch={ch:3d}: weight_drift={w_drift:.2e} bias_drift={b_drift:.2e} ({n_steps} steps)")

    print("PASS: test_optimization_drift")


# ---------------------------------------------------------------------------
# 5. Chained up-path (simulates OM_net decoder)
# ---------------------------------------------------------------------------

def test_chained_up_path():
    """Simulate the OM_net decoder path: chain 5 SafeConvTranspose3d layers
    matching the actual channel progression and verify outputs match a chain
    of nn.ConvTranspose3d layers with the same weights.

    OM_net decoder: 512->512, 128->128, 64->64, 32->32, 12->12
    with 1x1 conv adaptors between layers to reduce channels.

    Raises:
        AssertionError: If either chain's output shape is unexpected, the
            maximum absolute difference reaches 1e-2, or the mean absolute
            difference reaches 1e-4.
    """
    torch.manual_seed(42)

    ref_layers = nn.ModuleList()
    safe_layers = nn.ModuleList()
    ref_adaptors = nn.ModuleList()
    safe_adaptors = nn.ModuleList()

    channels = OM_NET_UP_CHANNELS  # [512, 128, 64, 32, 12]
    last = len(channels) - 1
    for i, ch in enumerate(channels):
        ref_up = nn.ConvTranspose3d(ch, ch, K, S, P)
        safe_up = SafeConvTranspose3d(ch, ch, K, S, P)
        with torch.no_grad():
            safe_up.weight.copy_(ref_up.weight)
            safe_up.bias.copy_(ref_up.bias)
        ref_layers.append(ref_up)
        safe_layers.append(safe_up)

        # Channel reduction after up (except last layer)
        if i < last:
            next_ch = channels[i + 1]
            ref_conv = nn.Conv3d(ch, next_ch, 1, 1, 0)
            safe_conv = nn.Conv3d(ch, next_ch, 1, 1, 0)
            with torch.no_grad():
                safe_conv.weight.copy_(ref_conv.weight)
                safe_conv.bias.copy_(ref_conv.bias)
            ref_adaptors.append(ref_conv)
            safe_adaptors.append(safe_conv)

    # Forward through chain: start at bottleneck spatial=4, ch=512
    start_spatial = 4
    x = torch.randn(1, channels[0], start_spatial, start_spatial, start_spatial)
    x_ref = x.clone()
    x_safe = x.clone()

    with torch.no_grad():
        for i, (ref_up, safe_up) in enumerate(zip(ref_layers, safe_layers)):
            x_ref = ref_up(x_ref)
            x_safe = safe_up(x_safe)
            if i < last:
                x_ref = ref_adaptors[i](x_ref)
                x_safe = safe_adaptors[i](x_safe)

    # After 5 upsample stages: 4 -> 8 -> 16 -> 32 -> 64 -> 128
    # (each K=4, S=2, P=1 stage doubles the spatial size).
    out_spatial = start_spatial * 2 ** len(channels)
    expected_shape = (1, channels[-1], out_spatial, out_spatial, out_spatial)
    assert x_ref.shape == expected_shape, f"Unexpected ref shape: {x_ref.shape}"
    assert x_safe.shape == expected_shape, f"Unexpected safe shape: {x_safe.shape}"

    max_diff = (x_ref - x_safe).abs().max().item()
    mean_diff = (x_ref - x_safe).abs().mean().item()

    # Accumulated error over 5 layers — allow more tolerance
    assert max_diff < 1e-2, f"Chained path max diff {max_diff:.2e}"
    assert mean_diff < 1e-4, f"Chained path mean diff {mean_diff:.2e}"
    print(f" 5-layer chain: max_diff={max_diff:.2e} mean_diff={mean_diff:.2e}")
    print("PASS: test_chained_up_path")
# ---------------------------------------------------------------------------
# 6. No bias variant
# ---------------------------------------------------------------------------

def test_no_bias():
    """bias=False must work correctly.

    Raises:
        AssertionError: If the safe layer still has a bias, or its output
            differs from the reference by 1e-4 or more.
    """
    for ch in [64, 128]:
        ref, safe = _make_pair(ch, ch, bias=False)
        assert safe.bias is None

        x = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            y_ref = ref(x)
            y_safe = safe(x)

        diff = (y_ref - y_safe).abs().max().item()
        assert diff < 1e-4, f"ch={ch} no-bias: diff {diff:.2e}"
        print(f" ch={ch:3d} no-bias: max_diff={diff:.2e}")

    print("PASS: test_no_bias")


# ---------------------------------------------------------------------------
# 7. Numerical gradient check (float64)
# ---------------------------------------------------------------------------

def test_gradcheck():
    """Numerical Jacobian verification for SafeConvTranspose3d.

    Uses small channels to keep gradcheck tractable.

    Raises:
        AssertionError: If gradcheck reports a mismatch (gradcheck itself may
            also raise on failure).
    """
    # Seed so the randomly initialised module and input are reproducible.
    torch.manual_seed(42)
    safe = SafeConvTranspose3d(2, 2, K, S, P, bias=True).double()
    x = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)

    result = torch.autograd.gradcheck(safe, (x,), eps=1e-6, atol=1e-4, rtol=1e-3)
    assert result
    print("PASS: test_gradcheck")


# ---------------------------------------------------------------------------
# 8. Batch dimension invariance
# ---------------------------------------------------------------------------

def test_batch_sizes():
    """SafeConvTranspose3d must match nn.ConvTranspose3d for every batch size.

    One layer pair (ch=64) is evaluated on inputs with batch sizes 1, 2 and 4.

    Raises:
        AssertionError: If output shapes differ or the maximum absolute
            difference reaches 1e-4 for any batch size.
    """
    ch = 64
    ref, safe = _make_pair(ch, ch)

    for batch_size in [1, 2, 4]:
        x = torch.randn(batch_size, ch, 4, 4, 4)
        with torch.no_grad():
            y_ref = ref(x)
            y_safe = safe(x)

        assert y_ref.shape == y_safe.shape
        max_diff = (y_ref - y_safe).abs().max().item()
        assert max_diff < 1e-4, f"batch={batch_size}: max diff {max_diff:.2e}"
        print(f" batch={batch_size}: max_diff={max_diff:.2e}")

    print("PASS: test_batch_sizes")
# ---------------------------------------------------------------------------
# 9. Determinism — same input always gives same output
# ---------------------------------------------------------------------------

def test_determinism():
    """SafeConvTranspose3d must be deterministic (no stochastic ops).

    Runs the same layer twice on the same input and requires bit-identical
    outputs.

    Raises:
        AssertionError: If the two outputs are not exactly equal.
    """
    ch = 64
    safe = SafeConvTranspose3d(ch, ch, K, S, P)
    x = torch.randn(1, ch, 4, 4, 4)

    with torch.no_grad():
        y1 = safe(x).clone()
        y2 = safe(x).clone()

    assert torch.equal(y1, y2), "SafeConvTranspose3d is not deterministic"
    print("PASS: test_determinism")


# ---------------------------------------------------------------------------
# Main runner (no pytest needed)
# ---------------------------------------------------------------------------

if __name__ == '__main__':
    print("=" * 70)
    print("SafeConvTranspose3d equivalence tests (OM_net channel configs)")
    print(f"OM_net up-layer channels: {OM_NET_UP_CHANNELS}")
    print(f"ConvTranspose params: kernel={K}, stride={S}, padding={P}")
    print("=" * 70)

    tests = [
        ("Forward (OM_net channels)", test_forward_om_net_channels),
        ("Forward (larger spatial)", test_forward_larger_spatial),
        ("Backward (OM_net channels)", test_backward_om_net_channels),
        ("state_dict compatibility", test_state_dict_interchangeable),
        ("Optimization drift (10 steps)", test_optimization_drift),
        ("Chained up-path (5 layers)", test_chained_up_path),
        ("No bias", test_no_bias),
        ("Numerical gradcheck", test_gradcheck),
        ("Batch size invariance", test_batch_sizes),
        ("Determinism", test_determinism),
    ]

    failed = []
    for name, fn in tests:
        print(f"\n--- {name} ---")
        # Top-level boundary: record the failure and keep running the
        # remaining tests instead of aborting on the first error.
        try:
            fn()
        except Exception as e:
            print(f"FAIL: {name}: {e}")
            failed.append(name)

    print("\n" + "=" * 70)
    if failed:
        print(f"FAILED ({len(failed)}/{len(tests)}): {', '.join(failed)}")
    else:
        print(f"ALL {len(tests)} TESTS PASSED")
    print("=" * 70)

    # Report failure through the exit status so shells/CI can detect it.
    if failed:
        sys.exit(1)