| """ |
| Test that SafeConvTranspose3d (V1, decomposed) is functionally identical to |
| torch.nn.ConvTranspose3d for the exact channel configurations used in OM_net |
| (the production network in OM_train_3modes.py). |
| |
| OM_net.feat_channels = [1, 12, 32, 64, 128, 512] |
| Up-layers use SafeConvTranspose3d(ch, ch, 4, 2, 1) for ch in [512, 128, 64, 32, 12]. |
| |
| Tests: |
| 1. Forward output matches within float32 tolerance (asserted < 1e-4; ~5e-7 noted as typical) |
| 2. Backward gradients (input, weight, bias) match |
| 3. state_dict is interchangeable (load ConvTranspose3d weights into Safe and vice versa) |
| 4. Multi-step optimization trajectories stay close |
| 5. Full OM_net up-path simulation with chained layers |
| 6. No-bias variant |
| 7. Numerical gradient check (float64) |
| 8. Batch dimension invariance |
| 9. Determinism |
| |
| Usage: |
| python tests/test_safe_conv_transpose_equiv.py |
| python -m pytest tests/test_safe_conv_transpose_equiv.py -v (if pytest available) |
| """ |
|
|
| import copy |
| import torch |
| import torch.nn as nn |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
| from Diffusion.safe_conv_transpose import SafeConvTranspose3d |
|
|
|
|
| |
# Channel widths exercised by the tests, listed from widest to narrowest.
OM_NET_UP_CHANNELS = [512, 128, 64, 32, 12]

# Transposed-convolution hyper-parameters passed to every layer under test.
K = 4  # kernel size
S = 2  # stride
P = 1  # padding
|
|
|
|
| |
| |
| |
|
|
def _make_pair(in_c, out_c, bias=True):
    """Build a reference/safe transposed-conv pair with identical parameters.

    Seeds the global RNG with 42, constructs ``nn.ConvTranspose3d`` followed
    by ``SafeConvTranspose3d`` (both using the module-level ``K``, ``S``,
    ``P``), then overwrites the safe layer's weight -- and its bias, when one
    exists -- with the reference layer's values.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        bias: Whether both layers are created with a bias term.

    Returns:
        Tuple ``(ref, safe)`` of the reference and safe layers.
    """
    torch.manual_seed(42)
    ref = nn.ConvTranspose3d(in_c, out_c, K, S, P, bias=bias)
    safe = SafeConvTranspose3d(in_c, out_c, K, S, P, bias=bias)
    # Copy inside no_grad rather than through ``.data``: the in-place write
    # stays visible to autograd's bookkeeping while still not being recorded
    # in any graph.
    with torch.no_grad():
        safe.weight.copy_(ref.weight)
        if bias and ref.bias is not None:
            safe.bias.copy_(ref.bias)
    return ref, safe
|
|
|
|
| |
| |
| |
|
|
def test_forward_om_net_channels():
    """Compare forward outputs of the two layers for every up-layer width.

    For each entry of ``OM_NET_UP_CHANNELS`` a weight-matched pair is run on
    a random ``(1, ch, 4, 4, 4)`` input; shapes must be equal and the largest
    absolute element-wise difference must stay below 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        ref, safe = _make_pair(ch, ch)
        # Small spatial extent keeps the 512-channel case cheap.
        sample = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            y_ref, y_safe = ref(sample), safe(sample)
        assert y_ref.shape == y_safe.shape, f"Shape mismatch: {y_ref.shape} vs {y_safe.shape}"
        max_diff = torch.max(torch.abs(y_ref - y_safe)).item()
        assert max_diff < 1e-4, f"ch={ch}: forward max diff {max_diff:.2e} exceeds 1e-4"
        print(f" ch={ch:3d}: max_diff={max_diff:.2e}")
    print("PASS: test_forward_om_net_channels")
|
|
|
|
def test_forward_larger_spatial():
    """Forward parity with a channel-dependent spatial extent.

    The cubic input side is ``512 // ch`` clamped to the range [2, 8], so
    narrower layers are exercised on larger volumes. The largest absolute
    output difference must stay below 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        ref, safe = _make_pair(ch, ch)
        # Clamp 512 // ch into [2, 8].
        spatial = max(2, min(8, 512 // ch))
        volume = torch.randn(1, ch, spatial, spatial, spatial)
        with torch.no_grad():
            out_ref, out_safe = ref(volume), safe(volume)
        max_diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert max_diff < 1e-4, f"ch={ch} spatial={spatial}: max diff {max_diff:.2e}"
        print(f" ch={ch:3d} spatial={spatial}: max_diff={max_diff:.2e}")
    print("PASS: test_forward_larger_spatial")
|
|
|
|
| |
| |
| |
|
|
def test_backward_om_net_channels():
    """Compare input, weight and bias gradients of the two layers.

    For every width in ``OM_NET_UP_CHANNELS`` a weight-matched pair is fed
    the same input and the same upstream gradient. The resulting gradients
    must agree within 1e-4 (input) and 1e-3 (weight, bias).
    """
    for ch in OM_NET_UP_CHANNELS:
        # Shared helper: seeds the RNG with 42 and copies weight + bias,
        # exactly what this test previously did inline.
        ref, safe = _make_pair(ch, ch)

        # Cubic input side: 256 // ch clamped to [2, 4].
        spatial = min(4, max(2, 256 // ch))
        torch.manual_seed(123)
        x_ref = torch.randn(1, ch, spatial, spatial, spatial, requires_grad=True)
        # Independent leaf with the same values so each layer gets its own .grad.
        x_safe = x_ref.detach().clone().requires_grad_(True)

        # Upstream gradient; the output side is 2 * spatial for K=4, S=2, P=1.
        torch.manual_seed(456)
        grad_y = torch.randn(1, ch, 2 * spatial, 2 * spatial, 2 * spatial)

        ref(x_ref).backward(grad_y)
        safe(x_safe).backward(grad_y)

        dx = (x_ref.grad - x_safe.grad).abs().max().item()
        dw = (ref.weight.grad - safe.weight.grad).abs().max().item()
        db = (ref.bias.grad - safe.bias.grad).abs().max().item()

        assert dx < 1e-4, f"ch={ch}: grad_input diff {dx:.2e}"
        assert dw < 1e-3, f"ch={ch}: grad_weight diff {dw:.2e}"
        assert db < 1e-3, f"ch={ch}: grad_bias diff {db:.2e}"
        print(f" ch={ch:3d}: dx={dx:.2e} dw={dw:.2e} db={db:.2e}")
    print("PASS: test_backward_om_net_channels")
|
|
|
|
| |
| |
| |
|
|
def test_state_dict_interchangeable():
    """Round-trip state dicts between the two layer types in both directions.

    A state dict taken from ``nn.ConvTranspose3d`` is loaded into a fresh
    ``SafeConvTranspose3d`` and vice versa; in each direction the loaded
    layer must reproduce the source layer's output to within 1e-4.
    """
    for ch in OM_NET_UP_CHANNELS:
        # Direction 1: reference -> safe.
        ref, _ = _make_pair(ch, ch)
        safe_loaded = SafeConvTranspose3d(ch, ch, K, S, P, bias=True)
        safe_loaded.load_state_dict(ref.state_dict())

        x = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            diff_fwd = torch.max(torch.abs(ref(x) - safe_loaded(x))).item()
        assert diff_fwd < 1e-4, f"ch={ch}: forward diff after load {diff_fwd:.2e}"

        # Direction 2: safe -> reference, reusing the same input.
        _, safe = _make_pair(ch, ch)
        ref2 = nn.ConvTranspose3d(ch, ch, K, S, P, bias=True)
        ref2.load_state_dict(safe.state_dict())
        with torch.no_grad():
            diff_rev = torch.max(torch.abs(safe(x) - ref2(x))).item()
        assert diff_rev < 1e-4, f"ch={ch}: reverse load diff {diff_rev:.2e}"
        print(f" ch={ch:3d}: fwd_load_diff={diff_fwd:.2e} rev_load_diff={diff_rev:.2e}")
    print("PASS: test_state_dict_interchangeable")
|
|
|
|
| |
| |
| |
|
|
def test_optimization_drift():
    """Train both layers with Adam in lockstep and bound the parameter drift.

    For ch in (32, 128) a weight-matched pair is optimised for 10 steps on
    identical, per-step-seeded inputs with ``loss = output.sum()``. After the
    last step the maximum absolute difference between the two layers' weights
    and between their biases must each be below 1e-3. Loss values themselves
    are not compared.
    """
    n_steps = 10
    for ch in [32, 128]:
        # Shared helper: seeds the RNG with 42 and copies weight + bias.
        ref, safe = _make_pair(ch, ch)

        opt_ref = torch.optim.Adam(ref.parameters(), lr=1e-3)
        opt_safe = torch.optim.Adam(safe.parameters(), lr=1e-3)

        # Cubic input side: 256 // ch clamped to [2, 4].
        spatial = min(4, max(2, 256 // ch))

        for step in range(n_steps):
            # Fresh but reproducible input for every step.
            torch.manual_seed(step * 100)
            x = torch.randn(1, ch, spatial, spatial, spatial)

            # Same update recipe for both layers, reference first.
            for layer, optimizer in ((ref, opt_ref), (safe, opt_safe)):
                optimizer.zero_grad()
                layer(x).sum().backward()
                optimizer.step()

        w_drift = (ref.weight.detach() - safe.weight.detach()).abs().max().item()
        b_drift = (ref.bias.detach() - safe.bias.detach()).abs().max().item()
        assert w_drift < 1e-3, f"ch={ch}: weight drift {w_drift:.2e} after {n_steps} steps"
        assert b_drift < 1e-3, f"ch={ch}: bias drift {b_drift:.2e} after {n_steps} steps"
        print(f" ch={ch:3d}: weight_drift={w_drift:.2e} bias_drift={b_drift:.2e} ({n_steps} steps)")
    print("PASS: test_optimization_drift")
|
|
|
|
| |
| |
| |
|
|
def test_chained_up_path():
    """Run a five-stage up-sampling chain through both layer types.

    Each stage is a ``ch -> ch`` transposed convolution (512, 128, 64, 32,
    12); between stages a 1x1x1 ``nn.Conv3d`` maps to the next width. The
    safe chain uses ``SafeConvTranspose3d`` with parameters copied from the
    reference chain. Starting from a ``(1, 512, 4, 4, 4)`` input both chains
    must end at ``(1, 12, 128, 128, 128)`` with max abs diff < 1e-2 and mean
    abs diff < 1e-4.
    """
    torch.manual_seed(42)

    channels = OM_NET_UP_CHANNELS
    last = len(channels) - 1

    ref_layers, safe_layers = [], []
    ref_adaptors, safe_adaptors = [], []

    for idx, ch in enumerate(channels):
        # Up-sampling stage: reference first, then its weight-matched twin.
        ref_up = nn.ConvTranspose3d(ch, ch, K, S, P)
        safe_up = SafeConvTranspose3d(ch, ch, K, S, P)
        safe_up.weight.data.copy_(ref_up.weight.data)
        safe_up.bias.data.copy_(ref_up.bias.data)
        ref_layers.append(ref_up)
        safe_layers.append(safe_up)

        if idx == last:
            continue

        # 1x1x1 adaptor to the next stage's width (identical in both chains).
        next_ch = channels[idx + 1]
        ref_conv = nn.Conv3d(ch, next_ch, 1, 1, 0)
        safe_conv = nn.Conv3d(ch, next_ch, 1, 1, 0)
        safe_conv.weight.data.copy_(ref_conv.weight.data)
        safe_conv.bias.data.copy_(ref_conv.bias.data)
        ref_adaptors.append(ref_conv)
        safe_adaptors.append(safe_conv)

    x = torch.randn(1, 512, 4, 4, 4)
    x_ref = x.clone()
    x_safe = x.clone()

    with torch.no_grad():
        for idx, (ref_up, safe_up) in enumerate(zip(ref_layers, safe_layers)):
            x_ref = ref_up(x_ref)
            x_safe = safe_up(x_safe)
            if idx < last:
                x_ref = ref_adaptors[idx](x_ref)
                x_safe = safe_adaptors[idx](x_safe)

    # Five doublings of a side-4 cube give side 128.
    expected_shape = (1, 12, 128, 128, 128)
    assert x_ref.shape == expected_shape, f"Unexpected ref shape: {x_ref.shape}"
    assert x_safe.shape == expected_shape, f"Unexpected safe shape: {x_safe.shape}"

    abs_err = (x_ref - x_safe).abs()
    max_diff = abs_err.max().item()
    mean_diff = abs_err.mean().item()
    # Looser max bound than the single-layer tests: error compounds over stages.
    assert max_diff < 1e-2, f"Chained path max diff {max_diff:.2e}"
    assert mean_diff < 1e-4, f"Chained path mean diff {mean_diff:.2e}"
    print(f" 5-layer chain: max_diff={max_diff:.2e} mean_diff={mean_diff:.2e}")
    print("PASS: test_chained_up_path")
|
|
|
|
| |
| |
| |
|
|
def test_no_bias():
    """Forward parity when both layers are built with ``bias=False``.

    Also checks that the safe layer exposes ``bias`` as ``None`` in that
    configuration.
    """
    for ch in [64, 128]:
        ref, safe = _make_pair(ch, ch, bias=False)
        assert safe.bias is None
        sample = torch.randn(1, ch, 4, 4, 4)
        with torch.no_grad():
            out_ref, out_safe = ref(sample), safe(sample)
        diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert diff < 1e-4, f"ch={ch} no-bias: diff {diff:.2e}"
        print(f" ch={ch:3d} no-bias: max_diff={diff:.2e}")
    print("PASS: test_no_bias")
|
|
|
|
| |
| |
| |
|
|
def test_gradcheck():
    """Verify SafeConvTranspose3d's analytic gradients against finite differences.

    Runs ``torch.autograd.gradcheck`` in float64 on a 2-channel layer with a
    ``(1, 2, 3, 3, 3)`` input; the small sizes keep the numerical Jacobian
    tractable.
    """
    # Seed so the layer's initial weights and the input are the same on every
    # run, independent of which tests executed before this one.
    torch.manual_seed(42)
    safe = SafeConvTranspose3d(2, 2, K, S, P, bias=True).double()
    x = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)
    result = torch.autograd.gradcheck(safe, (x,), eps=1e-6, atol=1e-4, rtol=1e-3)
    assert result
    print("PASS: test_gradcheck")
|
|
|
|
| |
| |
| |
|
|
def test_batch_sizes():
    """Forward parity for batch sizes 1, 2 and 4 on one 64-channel pair.

    For each batch size the two outputs must have equal shapes and a largest
    absolute difference below 1e-4.
    """
    ch = 64
    ref, safe = _make_pair(ch, ch)
    for batch_size in (1, 2, 4):
        batch = torch.randn(batch_size, ch, 4, 4, 4)
        with torch.no_grad():
            out_ref, out_safe = ref(batch), safe(batch)
        assert out_ref.shape == out_safe.shape
        max_diff = torch.max(torch.abs(out_ref - out_safe)).item()
        assert max_diff < 1e-4, f"batch={batch_size}: max diff {max_diff:.2e}"
        print(f" batch={batch_size}: max_diff={max_diff:.2e}")
    print("PASS: test_batch_sizes")
|
|
|
|
| |
| |
| |
|
|
def test_determinism():
    """Two forward passes on the same input must be bit-identical.

    Uses a 64-channel ``SafeConvTranspose3d`` and compares the two outputs
    with ``torch.equal``.
    """
    ch = 64
    safe = SafeConvTranspose3d(ch, ch, K, S, P)
    sample = torch.randn(1, ch, 4, 4, 4)
    with torch.no_grad():
        first = safe(sample).clone()
        second = safe(sample).clone()
    assert torch.equal(first, second), "SafeConvTranspose3d is not deterministic"
    print("PASS: test_determinism")
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Stand-alone runner: executes every test in order, keeps going after a
    # failure, prints a summary, and exits non-zero if anything failed so
    # that shells and CI jobs can detect the failure.
    print("=" * 70)
    print("SafeConvTranspose3d equivalence tests (OM_net channel configs)")
    print(f"OM_net up-layer channels: {OM_NET_UP_CHANNELS}")
    print(f"ConvTranspose params: kernel={K}, stride={S}, padding={P}")
    print("=" * 70)

    tests = [
        ("Forward (OM_net channels)", test_forward_om_net_channels),
        ("Forward (larger spatial)", test_forward_larger_spatial),
        ("Backward (OM_net channels)", test_backward_om_net_channels),
        ("state_dict compatibility", test_state_dict_interchangeable),
        ("Optimization drift (10 steps)", test_optimization_drift),
        ("Chained up-path (5 layers)", test_chained_up_path),
        ("No bias", test_no_bias),
        ("Numerical gradcheck", test_gradcheck),
        ("Batch size invariance", test_batch_sizes),
        ("Determinism", test_determinism),
    ]

    failed = []
    for name, fn in tests:
        print(f"\n--- {name} ---")
        try:
            fn()
        except Exception as e:
            # Top-level boundary: record the failure and continue with the
            # remaining tests instead of aborting the whole run.
            print(f"FAIL: {name}: {e}")
            failed.append(name)

    print("\n" + "=" * 70)
    if failed:
        print(f"FAILED ({len(failed)}/{len(tests)}): {', '.join(failed)}")
    else:
        print(f"ALL {len(tests)} TESTS PASSED")
    print("=" * 70)

    # Previously the script always exited with status 0, even on failure.
    if failed:
        sys.exit(1)
|
|