| """ |
| Tests for SafeConvTranspose3d — verifies mathematical equivalence with nn.ConvTranspose3d. |
| |
| Tests cover: |
| 1. Forward pass: output correctness (V1: ~5e-7 precision, V2: bit-for-bit) |
| 2. Backward pass: identical gradients w.r.t. input, weight, and bias |
| 3. Checkpoint loading: weight shapes match nn.ConvTranspose3d |
| 4. Various channel configurations matching the codebase usage |
| 5. torch.autograd.gradcheck for numerical Jacobian verification |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import sys |
| import os |
|
|
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
| from Diffusion.safe_conv_transpose import ( |
| SafeConvTranspose3d, |
| SafeConvTranspose3d_v2, |
| replace_conv_transpose3d, |
| ) |
|
|
|
|
def _make_pair(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=True):
    """Build a reference ``nn.ConvTranspose3d`` plus one instance of each Safe variant.

    The global torch RNG is seeded with 42 before the reference layer is
    constructed. The reference weight (and the bias, when ``bias`` is true)
    is then copied into each Safe layer so all three share parameters.

    Returns:
        Tuple ``(ref, safe1, safe2)`` where ``safe1`` is a
        ``SafeConvTranspose3d`` and ``safe2`` a ``SafeConvTranspose3d_v2``.
    """
    torch.manual_seed(42)
    conv_args = (in_c, out_c, kernel_size, stride, padding)
    ref = nn.ConvTranspose3d(*conv_args, bias=bias)

    clones = []
    # Construct-then-copy one variant at a time, V1 before V2.
    for safe_cls in (SafeConvTranspose3d, SafeConvTranspose3d_v2):
        clone = safe_cls(*conv_args, bias=bias)
        clone.weight.data.copy_(ref.weight.data)
        if bias:
            clone.bias.data.copy_(ref.bias.data)
        clones.append(clone)

    safe1, safe2 = clones
    return ref, safe1, safe2
|
|
|
|
| |
| |
| |
|
|
def test_weight_shape():
    """Check that both Safe variants expose weight/bias shapes equal to nn.ConvTranspose3d."""
    channel_pairs = [(16, 16), (32, 32), (64, 64), (128, 128), (256, 256), (16, 32)]
    layer_types = (nn.ConvTranspose3d, SafeConvTranspose3d, SafeConvTranspose3d_v2)

    for in_c, out_c in channel_pairs:
        # kernel_size=4, stride=2, padding=1 for every layer under test.
        ref, s1, s2 = [layer_cls(in_c, out_c, 4, 2, 1) for layer_cls in layer_types]

        weights_agree = ref.weight.shape == s1.weight.shape == s2.weight.shape
        assert weights_agree, f"Weight shape mismatch for {in_c}->{out_c}"

        biases_agree = ref.bias.shape == s1.bias.shape == s2.bias.shape
        assert biases_agree, f"Bias shape mismatch for {in_c}->{out_c}"

    print("PASS: test_weight_shape")
|
|
|
|
def test_output_shape():
    """Check that a cubic input of side N yields output [1, 16, 2N, 2N, 2N] from all three layers."""
    for in_size in [2, 4, 8, 16]:
        ref = nn.ConvTranspose3d(16, 16, 4, 2, 1)
        safe1 = SafeConvTranspose3d(16, 16, 4, 2, 1)
        safe2 = SafeConvTranspose3d_v2(16, 16, 4, 2, 1)

        spatial = (in_size,) * 3
        x = torch.randn(1, 16, *spatial)
        expected = (1, 16) + tuple(2 * side for side in spatial)

        # Reference first, then V1, then V2.
        for layer in (ref, safe1, safe2):
            assert layer(x).shape == expected

    print("PASS: test_output_shape")
|
|
|
|
| |
| |
| |
|
|
def test_forward_v1():
    """Compare the V1 forward output with nn.ConvTranspose3d.

    For each channel width / input shape, the largest absolute element-wise
    difference must stay below 1e-5.
    """
    # (channels, batch, cubic side) -- in and out channel counts are equal here.
    cases = [(16, 2, 4), (32, 1, 8), (64, 1, 4), (128, 1, 2), (256, 1, 2)]

    for channels, batch, side in cases:
        in_c = out_c = channels
        shape = (batch, channels, side, side, side)

        ref, safe1, _ = _make_pair(in_c, out_c)
        x = torch.randn(shape)
        with torch.no_grad():
            y_ref = ref(x)
            y_safe = safe1(x)
            max_diff = (y_ref - y_safe).abs().max().item()

        assert max_diff < 1e-5, f"V1 forward diff {max_diff} for {in_c}->{out_c}"
        print(f" {in_c:3d}->{out_c:3d} input={shape}: max_diff={max_diff:.2e}")

    print("PASS: test_forward_v1")
|
|
|
|
def test_forward_v2():
    """Require the V2 forward output to equal nn.ConvTranspose3d exactly (max abs diff == 0.0)."""
    # (channels, batch, cubic side) -- in and out channel counts are equal here.
    cases = [(16, 2, 4), (32, 1, 8), (64, 1, 4), (128, 1, 2)]

    for channels, batch, side in cases:
        ref, _, safe2 = _make_pair(channels, channels)
        x = torch.randn((batch, channels, side, side, side))
        with torch.no_grad():
            y_ref = ref(x)
            y_safe = safe2(x)
            max_diff = (y_ref - y_safe).abs().max().item()

        assert max_diff == 0.0, f"V2 forward should be bit-for-bit, got diff {max_diff}"

    print("PASS: test_forward_v2")
|
|
|
|
def test_forward_v1_precision_analysis():
    """Print max/mean absolute error statistics of V1 against the reference.

    Uses a 32->32 layer on a [2, 32, 8, 8, 8] input and fails when the
    maximum absolute difference reaches 1e-4.
    """
    ref, safe1, _ = _make_pair(32, 32)
    x = torch.randn(2, 32, 8, 8, 8)

    with torch.no_grad():
        diff = (ref(x) - safe1(x)).abs()

    # Report before asserting so the statistics are visible even on failure.
    print(f" 32->32, [2,32,8,8,8]:")
    print(f" max absolute diff: {diff.max().item():.2e}")
    print(f" mean absolute diff: {diff.mean().item():.2e}")
    print(f" % elements > 1e-6: {(diff > 1e-6).float().mean().item()*100:.2f}%")

    worst = diff.max().item()
    assert worst < 1e-4
    print("PASS: test_forward_v1_precision_analysis")
|
|
|
|
| |
| |
| |
|
|
def _test_backward(version, label):
    """Compare gradients of one Safe variant against nn.ConvTranspose3d.

    Args:
        version: 1 selects ``SafeConvTranspose3d``; any other value selects
            ``SafeConvTranspose3d_v2``.
        label: suffix used in the final "PASS" line.

    For each configuration both layers share weights and bias, receive the
    same input and the same random upstream gradient, and the maximum
    absolute difference of the input, weight and bias gradients is checked.
    """
    safe_cls = SafeConvTranspose3d if version == 1 else SafeConvTranspose3d_v2

    def _max_gap(a, b):
        return (a - b).abs().max().item()

    # (C_in, C_out, cubic input side, batch size)
    cases = [(4, 4, 3, 2), (8, 4, 5, 1), (4, 8, 4, 3), (16, 16, 4, 2), (32, 32, 4, 1)]

    for C_in, C_out, D_in, B in cases:
        torch.manual_seed(42)
        ct = nn.ConvTranspose3d(C_in, C_out, 4, 2, 1, bias=True)
        safe = safe_cls(C_in, C_out, 4, 2, 1, bias=True)
        safe.weight.data.copy_(ct.weight.data)
        safe.bias.data.copy_(ct.bias.data)

        torch.manual_seed(123)
        x_ref = torch.randn(B, C_in, D_in, D_in, D_in, requires_grad=True)
        # Independent leaf with the same values, so each layer gets its own .grad.
        x_safe = x_ref.detach().clone().requires_grad_(True)

        torch.manual_seed(456)
        out_side = 2 * D_in
        grad_y = torch.randn(B, C_out, out_side, out_side, out_side)

        for layer, inp in ((ct, x_ref), (safe, x_safe)):
            layer(inp).backward(grad_y)

        dx = _max_gap(x_ref.grad, x_safe.grad)
        dw = _max_gap(ct.weight.grad, safe.weight.grad)
        db = _max_gap(ct.bias.grad, safe.bias.grad)

        assert dx < 1e-4, f"V{version} grad_input diff {dx} for {C_in}->{C_out}"
        assert dw < 1e-3, f"V{version} grad_weight diff {dw} for {C_in}->{C_out}"
        assert db < 1e-3, f"V{version} grad_bias diff {db} for {C_in}->{C_out}"
        print(f" {C_in:2d}->{C_out:2d} D={D_in} B={B}: dx={dx:.2e} dw={dw:.2e} db={db:.2e}")

    print(f"PASS: test_backward_{label}")
|
|
|
|
def test_backward_v1():
    """Run the shared backward-pass comparison for SafeConvTranspose3d (version 1)."""
    _test_backward(version=1, label="v1")
|
|
def test_backward_v2():
    """Run the shared backward-pass comparison for SafeConvTranspose3d_v2 (version 2)."""
    _test_backward(version=2, label="v2")
|
|
|
|
def test_optimization_step():
    """Take three SGD steps on identical data and check weights stay within 1e-4 of the reference."""
    torch.manual_seed(42)
    ref = nn.ConvTranspose3d(16, 16, 4, 2, 1)

    # Build V1 then V2, each initialised from the reference parameters.
    safe_layers = []
    for safe_cls in (SafeConvTranspose3d, SafeConvTranspose3d_v2):
        layer = safe_cls(16, 16, 4, 2, 1)
        layer.weight.data.copy_(ref.weight.data)
        layer.bias.data.copy_(ref.bias.data)
        safe_layers.append(layer)
    safe1, safe2 = safe_layers

    # One independent SGD optimizer per module.
    trainables = [
        (torch.optim.SGD(module.parameters(), lr=0.01), module)
        for module in (ref, safe1, safe2)
    ]

    for step in range(3):
        torch.manual_seed(step * 100)
        x = torch.randn(1, 16, 4, 4, 4)

        for opt, mod in trainables:
            opt.zero_grad()
            loss = mod(x).sum()
            loss.backward()
            opt.step()

    w1 = (ref.weight.data - safe1.weight.data).abs().max().item()
    w2 = (ref.weight.data - safe2.weight.data).abs().max().item()
    print(f" After 3 SGD steps: V1 drift={w1:.2e}, V2 drift={w2:.2e}")
    assert w1 < 1e-4
    assert w2 < 1e-4
    print("PASS: test_optimization_step")
|
|
|
|
| |
| |
| |
|
|
def test_no_bias():
    """Check the bias=False configuration: no bias attribute value and matching outputs."""
    ref, safe1, safe2 = _make_pair(16, 16, bias=False)
    x = torch.randn(1, 16, 4, 4, 4)

    with torch.no_grad():
        y_ref, y_s1, y_s2 = ref(x), safe1(x), safe2(x)

    assert all(layer.bias is None for layer in (safe1, safe2))

    # V1 is held to a tolerance, V2 to exact equality.
    err_v1 = (y_ref - y_s1).abs().max().item()
    assert err_v1 < 1e-5
    err_v2 = (y_ref - y_s2).abs().max().item()
    assert err_v2 == 0.0

    print("PASS: test_no_bias")
|
|
|
|
def test_checkpoint_loading():
    """state_dict from nn.ConvTranspose3d must load into both Safe variants.

    The reference layer is built with its default bias, so its state dict
    carries both ``weight`` and ``bias``. After ``load_state_dict`` every
    one of those tensors must equal the reference exactly; checking only the
    weight would let a dropped or mis-mapped bias go unnoticed.
    """
    ref = nn.ConvTranspose3d(32, 32, 4, 2, 1)
    sd = ref.state_dict()

    for safe_cls in (SafeConvTranspose3d, SafeConvTranspose3d_v2):
        safe = safe_cls(32, 32, 4, 2, 1)
        # Default strict loading: raises on missing or unexpected keys.
        safe.load_state_dict(sd)

        for param_name in ("weight", "bias"):
            loaded = getattr(safe, param_name).data
            expected = getattr(ref, param_name).data
            diff = (loaded - expected).abs().max().item()
            assert diff == 0.0, (
                f"{safe_cls.__name__}.{param_name} differs from checkpoint by {diff}"
            )

    print("PASS: test_checkpoint_loading")
|
|
|
|
def test_replace_utility():
    """Exercise replace_conv_transpose3d with its default target class.

    After the call, both ConvTranspose3d children of a small decoder must be
    SafeConvTranspose3d instances, the Conv3d child must still be an
    nn.Conv3d, and the model output must stay within 1e-4 of the output
    recorded before replacement.
    """

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.up1 = nn.ConvTranspose3d(64, 32, 4, 2, 1)
            self.up2 = nn.ConvTranspose3d(32, 16, 4, 2, 1)
            self.conv = nn.Conv3d(16, 3, 3, 1, 1)

        def forward(self, x):
            h = self.up1(x)
            h = self.up2(h)
            return self.conv(h)

    model = Decoder()
    x = torch.randn(1, 64, 4, 4, 4)
    with torch.no_grad():
        y_before = model(x).clone()

    replace_conv_transpose3d(model)

    expected_types = (
        ("up1", SafeConvTranspose3d),
        ("up2", SafeConvTranspose3d),
        ("conv", nn.Conv3d),
    )
    for attr_name, expected_cls in expected_types:
        assert isinstance(getattr(model, attr_name), expected_cls)

    with torch.no_grad():
        y_after = model(x)
    max_diff = (y_before - y_after).abs().max().item()
    assert max_diff < 1e-4, f"Replace utility diff {max_diff}"
    print(f" Replace utility: max diff = {max_diff:.2e}")
    print("PASS: test_replace_utility")
|
|
|
|
def test_replace_v2():
    """Exercise replace_conv_transpose3d with target_cls=SafeConvTranspose3d_v2.

    Both layers of a two-layer decoder must become SafeConvTranspose3d_v2
    and the forward output must equal the pre-replacement output exactly.
    """

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.up1 = nn.ConvTranspose3d(64, 32, 4, 2, 1)
            self.up2 = nn.ConvTranspose3d(32, 16, 4, 2, 1)

        def forward(self, x):
            h = self.up1(x)
            return self.up2(h)

    model = Decoder()
    x = torch.randn(1, 64, 4, 4, 4)
    with torch.no_grad():
        y_before = model(x).clone()

    replace_conv_transpose3d(model, target_cls=SafeConvTranspose3d_v2)
    for layer in (model.up1, model.up2):
        assert isinstance(layer, SafeConvTranspose3d_v2)

    with torch.no_grad():
        y_after = model(x)
    worst = (y_before - y_after).abs().max().item()
    assert worst == 0.0
    print("PASS: test_replace_v2")
|
|
|
|
def test_asymmetric_channels():
    """Check a 64->32 layer: equal output shapes, V1 within 1e-5, V2 exactly equal."""
    ref, safe1, safe2 = _make_pair(64, 32)
    x = torch.randn(1, 64, 4, 4, 4)

    with torch.no_grad():
        y_ref, y_s1, y_s2 = ref(x), safe1(x), safe2(x)

    shapes_agree = y_ref.shape == y_s1.shape == y_s2.shape
    assert shapes_agree

    err_v1 = (y_ref - y_s1).abs().max().item()
    assert err_v1 < 1e-5
    err_v2 = (y_ref - y_s2).abs().max().item()
    assert err_v2 == 0.0

    print("PASS: test_asymmetric_channels")
|
|
|
|
| |
| |
| |
|
|
def test_gradcheck_v1():
    """Run torch.autograd.gradcheck on a float64 SafeConvTranspose3d with respect to its input."""
    layer = SafeConvTranspose3d(2, 2, 4, 2, 1, bias=True).double()
    sample = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)
    # Only `sample` is passed as an input, so only its Jacobian is checked.
    passed = torch.autograd.gradcheck(layer, (sample,), eps=1e-6, atol=1e-4, rtol=1e-3)
    assert passed
    print("PASS: test_gradcheck_v1")
|
|
|
|
def test_gradcheck_v2():
    """Run torch.autograd.gradcheck on a float64 SafeConvTranspose3d_v2 with respect to its input."""
    layer = SafeConvTranspose3d_v2(2, 2, 4, 2, 1, bias=True).double()
    sample = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)
    # Only `sample` is passed as an input, so only its Jacobian is checked.
    passed = torch.autograd.gradcheck(layer, (sample,), eps=1e-6, atol=1e-4, rtol=1e-3)
    assert passed
    print("PASS: test_gradcheck_v2")
|
|
|
|
| |
| |
| |
|
|
def test_training_loss_equivalence():
    """Train three copies of a small encoder-decoder and compare them step by step.

    One copy keeps its nn.ConvTranspose3d layers; the other two are deep
    copies whose transposed convolutions are replaced with
    SafeConvTranspose3d and SafeConvTranspose3d_v2 respectively. All three
    are trained with Adam on the same data for five steps. At every step the
    loss gap to the reference must be below 1e-4 (V1) and 1e-6 (V2), and the
    final maximum parameter gap below 1e-3 (V1) and 1e-4 (V2).
    """
    import copy

    class MiniUNet(nn.Module):
        """Three strided Conv3d encoders followed by three ConvTranspose3d decoders."""

        def __init__(self):
            super().__init__()
            self.enc1 = nn.Conv3d(1, 16, 4, 2, 1)
            self.enc2 = nn.Conv3d(16, 32, 4, 2, 1)
            self.enc3 = nn.Conv3d(32, 64, 4, 2, 1)
            self.dec3 = nn.ConvTranspose3d(64, 32, 4, 2, 1)
            self.dec2 = nn.ConvTranspose3d(32, 16, 4, 2, 1)
            self.dec1 = nn.ConvTranspose3d(16, 1, 4, 2, 1)
            self.act = nn.ReLU()

        def forward(self, x):
            hidden = x
            # Every layer except the last is followed by the activation.
            for layer in (self.enc1, self.enc2, self.enc3, self.dec3, self.dec2):
                hidden = self.act(layer(hidden))
            return self.dec1(hidden)

    def _train_step(model, optimizer, inputs, labels):
        """Run one zero_grad / forward / backward / step cycle and return the loss tensor."""
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        return loss

    def _max_param_gap(model_a, model_b):
        """Largest absolute difference over parameters paired in iteration order."""
        return max(
            (p1.data - p2.data).abs().max().item()
            for p1, p2 in zip(model_a.parameters(), model_b.parameters())
        )

    torch.manual_seed(42)
    model_ref = MiniUNet()

    model_v1 = copy.deepcopy(model_ref)
    replace_conv_transpose3d(model_v1, target_cls=SafeConvTranspose3d)

    model_v2 = copy.deepcopy(model_ref)
    replace_conv_transpose3d(model_v2, target_cls=SafeConvTranspose3d_v2)

    # Sanity check that the replacement actually happened.
    assert isinstance(model_v1.dec1, SafeConvTranspose3d)
    assert isinstance(model_v2.dec1, SafeConvTranspose3d_v2)

    opt_ref, opt_v1, opt_v2 = [
        torch.optim.Adam(model.parameters(), lr=1e-3)
        for model in (model_ref, model_v1, model_v2)
    ]

    criterion = nn.MSELoss()
    n_steps = 5

    print(f" Training {n_steps} steps, comparing loss at each step:")
    for step in range(n_steps):
        torch.manual_seed(step * 777)
        x = torch.randn(2, 1, 16, 16, 16)
        target = torch.randn(2, 1, 16, 16, 16)

        loss_ref = _train_step(model_ref, opt_ref, x, target)
        loss_v1 = _train_step(model_v1, opt_v1, x, target)
        loss_v2 = _train_step(model_v2, opt_v2, x, target)

        diff_v1 = abs(loss_ref.item() - loss_v1.item())
        diff_v2 = abs(loss_ref.item() - loss_v2.item())
        print(f" step {step}: loss_ref={loss_ref.item():.6f} "
              f"loss_v1={loss_v1.item():.6f} (diff={diff_v1:.2e}) "
              f"loss_v2={loss_v2.item():.6f} (diff={diff_v2:.2e})")

        assert diff_v1 < 1e-4, f"V1 loss diverged at step {step}: diff={diff_v1}"
        assert diff_v2 < 1e-6, f"V2 loss diverged at step {step}: diff={diff_v2}"

    w_diff_v1 = _max_param_gap(model_ref, model_v1)
    w_diff_v2 = _max_param_gap(model_ref, model_v2)
    print(f" After {n_steps} steps — max weight drift: V1={w_diff_v1:.2e}, V2={w_diff_v2:.2e}")
    assert w_diff_v1 < 1e-3, f"V1 weight drift too large: {w_diff_v1}"
    assert w_diff_v2 < 1e-4, f"V2 weight drift too large: {w_diff_v2}"
    print("PASS: test_training_loss_equivalence")
|
|
|
|
| |
| |
| |
|
|
if __name__ == '__main__':
    # Script entry point: run every test in order, report a summary, and
    # exit with a non-zero status when anything failed so that shells and
    # CI jobs can detect the failure.
    import traceback

    print("=" * 70)
    print("Testing SafeConvTranspose3d implementations")
    print("=" * 70)

    tests = [
        ("Weight shapes", test_weight_shape),
        ("Output shapes", test_output_shape),
        ("Forward V1 (decomposed)", test_forward_v1),
        ("Forward V2 (custom autograd)", test_forward_v2),
        ("Forward V1 precision analysis", test_forward_v1_precision_analysis),
        ("Backward V1", test_backward_v1),
        ("Backward V2", test_backward_v2),
        ("Optimization step", test_optimization_step),
        ("No bias", test_no_bias),
        ("Checkpoint loading", test_checkpoint_loading),
        ("Replace utility (V1)", test_replace_utility),
        ("Replace utility (V2)", test_replace_v2),
        ("Asymmetric channels", test_asymmetric_channels),
        ("Gradcheck V1", test_gradcheck_v1),
        ("Gradcheck V2", test_gradcheck_v2),
        ("Training loss equivalence", test_training_loss_equivalence),
    ]

    failed = []
    for name, fn in tests:
        print(f"\n--- {name} ---")
        try:
            fn()
        except Exception as e:
            # Top-level boundary: keep running the remaining tests, but show
            # where the failure happened. A bare `assert` carries no message,
            # so the one-line summary alone can be empty.
            print(f"FAIL: {name}: {e}")
            traceback.print_exc()
            failed.append(name)

    print("\n" + "=" * 70)
    if failed:
        print(f"FAILED ({len(failed)}/{len(tests)}): {', '.join(failed)}")
    else:
        print(f"ALL {len(tests)} TESTS PASSED")
    print("=" * 70)

    if failed:
        # Previously the script always exited with status 0.
        sys.exit(1)
|
|