""" Tests for SafeConvTranspose3d — verifies mathematical equivalence with nn.ConvTranspose3d. Tests cover: 1. Forward pass: output correctness (V1: ~5e-7 precision, V2: bit-for-bit) 2. Backward pass: identical gradients w.r.t. input, weight, and bias 3. Checkpoint loading: weight shapes match nn.ConvTranspose3d 4. Various channel configurations matching the codebase usage 5. torch.autograd.gradcheck for numerical Jacobian verification """ import torch import torch.nn as nn import torch.nn.functional as F import sys import os sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) from Diffusion.safe_conv_transpose import ( SafeConvTranspose3d, SafeConvTranspose3d_v2, replace_conv_transpose3d, ) def _make_pair(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=True): """Create nn.ConvTranspose3d and both Safe variants with identical weights.""" torch.manual_seed(42) ref = nn.ConvTranspose3d(in_c, out_c, kernel_size, stride, padding, bias=bias) safe1 = SafeConvTranspose3d(in_c, out_c, kernel_size, stride, padding, bias=bias) safe1.weight.data.copy_(ref.weight.data) if bias: safe1.bias.data.copy_(ref.bias.data) safe2 = SafeConvTranspose3d_v2(in_c, out_c, kernel_size, stride, padding, bias=bias) safe2.weight.data.copy_(ref.weight.data) if bias: safe2.bias.data.copy_(ref.bias.data) return ref, safe1, safe2 # ============================================================================= # Basic shape tests # ============================================================================= def test_weight_shape(): """Weight and bias shapes must match nn.ConvTranspose3d exactly.""" for in_c, out_c in [(16, 16), (32, 32), (64, 64), (128, 128), (256, 256), (16, 32)]: ref = nn.ConvTranspose3d(in_c, out_c, 4, 2, 1) s1 = SafeConvTranspose3d(in_c, out_c, 4, 2, 1) s2 = SafeConvTranspose3d_v2(in_c, out_c, 4, 2, 1) assert ref.weight.shape == s1.weight.shape == s2.weight.shape, \ f"Weight shape mismatch for {in_c}->{out_c}" assert ref.bias.shape == s1.bias.shape == s2.bias.shape, \ f"Bias shape mismatch for {in_c}->{out_c}" print("PASS: test_weight_shape") def test_output_shape(): """Output shape must be [B, C_out, 2*D, 2*H, 2*W] for stride=2.""" for in_size in [2, 4, 8, 16]: ref = nn.ConvTranspose3d(16, 16, 4, 2, 1) safe1 = SafeConvTranspose3d(16, 16, 4, 2, 1) safe2 = SafeConvTranspose3d_v2(16, 16, 4, 2, 1) x = torch.randn(1, 16, in_size, in_size, in_size) expected = (1, 16, 2 * in_size, 2 * in_size, 2 * in_size) assert ref(x).shape == expected assert safe1(x).shape == expected assert safe2(x).shape == expected print("PASS: test_output_shape") # ============================================================================= # Forward precision tests # ============================================================================= def test_forward_v1(): """V1 (decomposed) forward must be close to nn.ConvTranspose3d (~5e-7 precision).""" configs = [ (16, 16, (2, 16, 4, 4, 4)), (32, 32, (1, 32, 8, 8, 8)), (64, 64, (1, 64, 4, 4, 4)), (128, 128, (1, 128, 2, 2, 2)), (256, 256, (1, 256, 2, 2, 2)), ] for in_c, out_c, shape in configs: ref, safe1, _ = _make_pair(in_c, out_c) x = torch.randn(shape) with torch.no_grad(): y_ref = ref(x) y_safe = safe1(x) max_diff = (y_ref - y_safe).abs().max().item() assert max_diff < 1e-5, f"V1 forward diff {max_diff} for {in_c}->{out_c}" print(f" {in_c:3d}->{out_c:3d} input={shape}: max_diff={max_diff:.2e}") print("PASS: test_forward_v1") def test_forward_v2(): """V2 (custom autograd) forward must be bit-for-bit identical.""" configs = [ (16, 16, (2, 16, 4, 4, 4)), (32, 32, (1, 32, 8, 8, 8)), (64, 64, (1, 64, 4, 4, 4)), (128, 128, (1, 128, 2, 2, 2)), ] for in_c, out_c, shape in configs: ref, _, safe2 = _make_pair(in_c, out_c) x = torch.randn(shape) with torch.no_grad(): y_ref = ref(x) y_safe = safe2(x) max_diff = (y_ref - y_safe).abs().max().item() assert max_diff == 0.0, f"V2 forward should be bit-for-bit, got diff {max_diff}" print("PASS: test_forward_v2") def test_forward_v1_precision_analysis(): """Detailed precision analysis for V1 vs reference.""" ref, safe1, _ = _make_pair(32, 32) x = torch.randn(2, 32, 8, 8, 8) with torch.no_grad(): y_ref = ref(x) y_safe = safe1(x) diff = (y_ref - y_safe).abs() print(f" 32->32, [2,32,8,8,8]:") print(f" max absolute diff: {diff.max().item():.2e}") print(f" mean absolute diff: {diff.mean().item():.2e}") print(f" % elements > 1e-6: {(diff > 1e-6).float().mean().item()*100:.2f}%") assert diff.max().item() < 1e-4 print("PASS: test_forward_v1_precision_analysis") # ============================================================================= # Backward tests # ============================================================================= def _test_backward(version, label): """Test backward for grad_input, grad_weight, grad_bias with non-trivial upstream gradient.""" for C_in, C_out, D_in, B in [(4, 4, 3, 2), (8, 4, 5, 1), (4, 8, 4, 3), (16, 16, 4, 2), (32, 32, 4, 1)]: torch.manual_seed(42) ct = nn.ConvTranspose3d(C_in, C_out, 4, 2, 1, bias=True) safe = (SafeConvTranspose3d if version == 1 else SafeConvTranspose3d_v2)( C_in, C_out, 4, 2, 1, bias=True ) safe.weight.data.copy_(ct.weight.data) safe.bias.data.copy_(ct.bias.data) torch.manual_seed(123) x_ref = torch.randn(B, C_in, D_in, D_in, D_in, requires_grad=True) x_safe = x_ref.detach().clone().requires_grad_(True) torch.manual_seed(456) grad_y = torch.randn(B, C_out, 2 * D_in, 2 * D_in, 2 * D_in) ct(x_ref).backward(grad_y) safe(x_safe).backward(grad_y) dx = (x_ref.grad - x_safe.grad).abs().max().item() dw = (ct.weight.grad - safe.weight.grad).abs().max().item() db = (ct.bias.grad - safe.bias.grad).abs().max().item() assert dx < 1e-4, f"V{version} grad_input diff {dx} for {C_in}->{C_out}" assert dw < 1e-3, f"V{version} grad_weight diff {dw} for {C_in}->{C_out}" assert db < 1e-3, f"V{version} grad_bias diff {db} for {C_in}->{C_out}" print(f" {C_in:2d}->{C_out:2d} D={D_in} B={B}: dx={dx:.2e} dw={dw:.2e} db={db:.2e}") print(f"PASS: test_backward_{label}") def test_backward_v1(): _test_backward(1, "v1") def test_backward_v2(): _test_backward(2, "v2") def test_optimization_step(): """Run 3 SGD steps and verify parameters stay close.""" torch.manual_seed(42) ref = nn.ConvTranspose3d(16, 16, 4, 2, 1) safe1 = SafeConvTranspose3d(16, 16, 4, 2, 1) safe1.weight.data.copy_(ref.weight.data) safe1.bias.data.copy_(ref.bias.data) safe2 = SafeConvTranspose3d_v2(16, 16, 4, 2, 1) safe2.weight.data.copy_(ref.weight.data) safe2.bias.data.copy_(ref.bias.data) opt_ref = torch.optim.SGD(ref.parameters(), lr=0.01) opt_s1 = torch.optim.SGD(safe1.parameters(), lr=0.01) opt_s2 = torch.optim.SGD(safe2.parameters(), lr=0.01) for step in range(3): torch.manual_seed(step * 100) x = torch.randn(1, 16, 4, 4, 4) for opt, mod in [(opt_ref, ref), (opt_s1, safe1), (opt_s2, safe2)]: opt.zero_grad() mod(x).sum().backward() opt.step() w1 = (ref.weight.data - safe1.weight.data).abs().max().item() w2 = (ref.weight.data - safe2.weight.data).abs().max().item() print(f" After 3 SGD steps: V1 drift={w1:.2e}, V2 drift={w2:.2e}") assert w1 < 1e-4 assert w2 < 1e-4 print("PASS: test_optimization_step") # ============================================================================= # Checkpoint and replacement tests # ============================================================================= def test_no_bias(): """bias=False must work correctly.""" ref, safe1, safe2 = _make_pair(16, 16, bias=False) x = torch.randn(1, 16, 4, 4, 4) with torch.no_grad(): y_ref = ref(x) y_s1 = safe1(x) y_s2 = safe2(x) assert safe1.bias is None and safe2.bias is None assert (y_ref - y_s1).abs().max().item() < 1e-5 assert (y_ref - y_s2).abs().max().item() == 0.0 print("PASS: test_no_bias") def test_checkpoint_loading(): """state_dict from nn.ConvTranspose3d must load into Safe variants.""" ref = nn.ConvTranspose3d(32, 32, 4, 2, 1) sd = ref.state_dict() safe1 = SafeConvTranspose3d(32, 32, 4, 2, 1) safe1.load_state_dict(sd) safe2 = SafeConvTranspose3d_v2(32, 32, 4, 2, 1) safe2.load_state_dict(sd) assert (safe1.weight.data - ref.weight.data).abs().max().item() == 0.0 assert (safe2.weight.data - ref.weight.data).abs().max().item() == 0.0 print("PASS: test_checkpoint_loading") def test_replace_utility(): """Test recursive replacement utility.""" class Decoder(nn.Module): def __init__(self): super().__init__() self.up1 = nn.ConvTranspose3d(64, 32, 4, 2, 1) self.up2 = nn.ConvTranspose3d(32, 16, 4, 2, 1) self.conv = nn.Conv3d(16, 3, 3, 1, 1) # should NOT be replaced def forward(self, x): return self.conv(self.up2(self.up1(x))) model = Decoder() x = torch.randn(1, 64, 4, 4, 4) with torch.no_grad(): y_before = model(x).clone() replace_conv_transpose3d(model) assert isinstance(model.up1, SafeConvTranspose3d) assert isinstance(model.up2, SafeConvTranspose3d) assert isinstance(model.conv, nn.Conv3d) with torch.no_grad(): y_after = model(x) max_diff = (y_before - y_after).abs().max().item() assert max_diff < 1e-4, f"Replace utility diff {max_diff}" print(f" Replace utility: max diff = {max_diff:.2e}") print("PASS: test_replace_utility") def test_replace_v2(): """Replacement with V2 should be bit-for-bit in forward.""" class Decoder(nn.Module): def __init__(self): super().__init__() self.up1 = nn.ConvTranspose3d(64, 32, 4, 2, 1) self.up2 = nn.ConvTranspose3d(32, 16, 4, 2, 1) def forward(self, x): return self.up2(self.up1(x)) model = Decoder() x = torch.randn(1, 64, 4, 4, 4) with torch.no_grad(): y_before = model(x).clone() replace_conv_transpose3d(model, target_cls=SafeConvTranspose3d_v2) assert isinstance(model.up1, SafeConvTranspose3d_v2) assert isinstance(model.up2, SafeConvTranspose3d_v2) with torch.no_grad(): y_after = model(x) assert (y_before - y_after).abs().max().item() == 0.0 print("PASS: test_replace_v2") def test_asymmetric_channels(): """in_channels != out_channels.""" ref, safe1, safe2 = _make_pair(64, 32) x = torch.randn(1, 64, 4, 4, 4) with torch.no_grad(): y_ref = ref(x) y_s1 = safe1(x) y_s2 = safe2(x) assert y_ref.shape == y_s1.shape == y_s2.shape assert (y_ref - y_s1).abs().max().item() < 1e-5 assert (y_ref - y_s2).abs().max().item() == 0.0 print("PASS: test_asymmetric_channels") # ============================================================================= # Numerical gradient verification # ============================================================================= def test_gradcheck_v1(): """Numerical Jacobian check for V1.""" safe1 = SafeConvTranspose3d(2, 2, 4, 2, 1, bias=True).double() x = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True) result = torch.autograd.gradcheck(safe1, (x,), eps=1e-6, atol=1e-4, rtol=1e-3) assert result print("PASS: test_gradcheck_v1") def test_gradcheck_v2(): """Numerical Jacobian check for V2.""" safe2 = SafeConvTranspose3d_v2(2, 2, 4, 2, 1, bias=True).double() x = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True) result = torch.autograd.gradcheck(safe2, (x,), eps=1e-6, atol=1e-4, rtol=1e-3) assert result print("PASS: test_gradcheck_v2") # ============================================================================= # Training loss equivalence test # ============================================================================= def test_training_loss_equivalence(): """Build a small encoder-decoder network with ConvTranspose3d layers, train for several steps, then replace with SafeConvTranspose3d and verify the loss values are identical (V2) or near-identical (V1).""" class MiniUNet(nn.Module): """Small UNet-like network with 3 ConvTranspose3d layers.""" def __init__(self): super().__init__() self.enc1 = nn.Conv3d(1, 16, 4, 2, 1) self.enc2 = nn.Conv3d(16, 32, 4, 2, 1) self.enc3 = nn.Conv3d(32, 64, 4, 2, 1) self.dec3 = nn.ConvTranspose3d(64, 32, 4, 2, 1) self.dec2 = nn.ConvTranspose3d(32, 16, 4, 2, 1) self.dec1 = nn.ConvTranspose3d(16, 1, 4, 2, 1) self.act = nn.ReLU() def forward(self, x): e1 = self.act(self.enc1(x)) e2 = self.act(self.enc2(e1)) e3 = self.act(self.enc3(e2)) d3 = self.act(self.dec3(e3)) d2 = self.act(self.dec2(d3)) d1 = self.dec1(d2) return d1 import copy torch.manual_seed(42) model_ref = MiniUNet() # Create Safe V1 version by replacing ConvTranspose3d layers model_v1 = copy.deepcopy(model_ref) replace_conv_transpose3d(model_v1, target_cls=SafeConvTranspose3d) # Create Safe V2 version model_v2 = copy.deepcopy(model_ref) replace_conv_transpose3d(model_v2, target_cls=SafeConvTranspose3d_v2) # Verify the replacement happened assert isinstance(model_v1.dec1, SafeConvTranspose3d) assert isinstance(model_v2.dec1, SafeConvTranspose3d_v2) opt_ref = torch.optim.Adam(model_ref.parameters(), lr=1e-3) opt_v1 = torch.optim.Adam(model_v1.parameters(), lr=1e-3) opt_v2 = torch.optim.Adam(model_v2.parameters(), lr=1e-3) criterion = nn.MSELoss() n_steps = 5 print(f" Training {n_steps} steps, comparing loss at each step:") for step in range(n_steps): torch.manual_seed(step * 777) x = torch.randn(2, 1, 16, 16, 16) target = torch.randn(2, 1, 16, 16, 16) # Reference (ConvTranspose3d) opt_ref.zero_grad() loss_ref = criterion(model_ref(x), target) loss_ref.backward() opt_ref.step() # V1 (SafeConvTranspose3d) opt_v1.zero_grad() loss_v1 = criterion(model_v1(x), target) loss_v1.backward() opt_v1.step() # V2 (SafeConvTranspose3d_v2) opt_v2.zero_grad() loss_v2 = criterion(model_v2(x), target) loss_v2.backward() opt_v2.step() diff_v1 = abs(loss_ref.item() - loss_v1.item()) diff_v2 = abs(loss_ref.item() - loss_v2.item()) print(f" step {step}: loss_ref={loss_ref.item():.6f} " f"loss_v1={loss_v1.item():.6f} (diff={diff_v1:.2e}) " f"loss_v2={loss_v2.item():.6f} (diff={diff_v2:.2e})") assert diff_v1 < 1e-4, f"V1 loss diverged at step {step}: diff={diff_v1}" assert diff_v2 < 1e-6, f"V2 loss diverged at step {step}: diff={diff_v2}" # Check final weight divergence after training w_diff_v1 = max( (p1.data - p2.data).abs().max().item() for p1, p2 in zip(model_ref.parameters(), model_v1.parameters()) ) w_diff_v2 = max( (p1.data - p2.data).abs().max().item() for p1, p2 in zip(model_ref.parameters(), model_v2.parameters()) ) print(f" After {n_steps} steps — max weight drift: V1={w_diff_v1:.2e}, V2={w_diff_v2:.2e}") assert w_diff_v1 < 1e-3, f"V1 weight drift too large: {w_diff_v1}" assert w_diff_v2 < 1e-4, f"V2 weight drift too large: {w_diff_v2}" print("PASS: test_training_loss_equivalence") # ============================================================================= # Main # ============================================================================= if __name__ == '__main__': print("=" * 70) print("Testing SafeConvTranspose3d implementations") print("=" * 70) tests = [ ("Weight shapes", test_weight_shape), ("Output shapes", test_output_shape), ("Forward V1 (decomposed)", test_forward_v1), ("Forward V2 (custom autograd)", test_forward_v2), ("Forward V1 precision analysis", test_forward_v1_precision_analysis), ("Backward V1", test_backward_v1), ("Backward V2", test_backward_v2), ("Optimization step", test_optimization_step), ("No bias", test_no_bias), ("Checkpoint loading", test_checkpoint_loading), ("Replace utility (V1)", test_replace_utility), ("Replace utility (V2)", test_replace_v2), ("Asymmetric channels", test_asymmetric_channels), ("Gradcheck V1", test_gradcheck_v1), ("Gradcheck V2", test_gradcheck_v2), ("Training loss equivalence", test_training_loss_equivalence), ] failed = [] for name, fn in tests: print(f"\n--- {name} ---") try: fn() except Exception as e: print(f"FAIL: {name}: {e}") failed.append(name) print("\n" + "=" * 70) if failed: print(f"FAILED ({len(failed)}/{len(tests)}): {', '.join(failed)}") else: print(f"ALL {len(tests)} TESTS PASSED") print("=" * 70)