Omini3D / tests /test_safe_conv_transpose.py
maxmo2009's picture
Sync from local: code + epoch-110 checkpoint, clean README
2af0e94 verified
"""
Tests for SafeConvTranspose3d — verifies mathematical equivalence with nn.ConvTranspose3d.
Tests cover:
1. Forward pass: output correctness (V1: ~5e-7 precision, V2: bit-for-bit)
2. Backward pass: identical gradients w.r.t. input, weight, and bias
3. Checkpoint loading: weight shapes match nn.ConvTranspose3d
4. Various channel configurations matching the codebase usage
5. torch.autograd.gradcheck for numerical Jacobian verification
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from Diffusion.safe_conv_transpose import (
SafeConvTranspose3d,
SafeConvTranspose3d_v2,
replace_conv_transpose3d,
)
def _make_pair(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=True):
    """Build a reference nn.ConvTranspose3d plus both Safe variants sharing its parameters.

    The global RNG is seeded with 42 before anything is constructed. Each Safe
    module is created and then has the reference weight (and bias, when
    ``bias`` is truthy) copied into it.

    Returns:
        Tuple ``(ref, safe1, safe2)`` where ``safe1`` is a SafeConvTranspose3d
        and ``safe2`` is a SafeConvTranspose3d_v2.
    """
    torch.manual_seed(42)
    ref = nn.ConvTranspose3d(in_c, out_c, kernel_size, stride, padding, bias=bias)

    variants = []
    # V1 is built and filled before V2 is constructed, so RNG use keeps its order.
    for safe_cls in (SafeConvTranspose3d, SafeConvTranspose3d_v2):
        module = safe_cls(in_c, out_c, kernel_size, stride, padding, bias=bias)
        module.weight.data.copy_(ref.weight.data)
        if bias:
            module.bias.data.copy_(ref.bias.data)
        variants.append(module)

    safe1, safe2 = variants
    return ref, safe1, safe2
# =============================================================================
# Basic shape tests
# =============================================================================
def test_weight_shape():
    """Check that both Safe variants expose weight/bias shaped like nn.ConvTranspose3d.

    Each (in_channels, out_channels) pair is built with kernel 4, stride 2, padding 1.
    """
    channel_pairs = [(16, 16), (32, 32), (64, 64), (128, 128), (256, 256), (16, 32)]
    for in_c, out_c in channel_pairs:
        # Construction order stays: reference, V1, V2.
        ref, s1, s2 = (
            cls(in_c, out_c, 4, 2, 1)
            for cls in (nn.ConvTranspose3d, SafeConvTranspose3d, SafeConvTranspose3d_v2)
        )

        weights_agree = (
            ref.weight.shape == s1.weight.shape
            and s1.weight.shape == s2.weight.shape
        )
        assert weights_agree, f"Weight shape mismatch for {in_c}->{out_c}"

        biases_agree = (
            ref.bias.shape == s1.bias.shape
            and s1.bias.shape == s2.bias.shape
        )
        assert biases_agree, f"Bias shape mismatch for {in_c}->{out_c}"
    print("PASS: test_weight_shape")
def test_output_shape():
    """Check that every spatial dimension doubles for kernel 4, stride 2, padding 1.

    For a cubic input of side ``in_size`` with 16 channels, the reference and
    both Safe variants must each return a tensor of shape
    ``(1, 16, 2*in_size, 2*in_size, 2*in_size)``.
    """
    for in_size in (2, 4, 8, 16):
        ref = nn.ConvTranspose3d(16, 16, 4, 2, 1)
        safe1 = SafeConvTranspose3d(16, 16, 4, 2, 1)
        safe2 = SafeConvTranspose3d_v2(16, 16, 4, 2, 1)
        x = torch.randn(1, 16, in_size, in_size, in_size)

        side = 2 * in_size
        expected = (1, 16, side, side, side)
        for module in (ref, safe1, safe2):
            assert module(x).shape == expected
    print("PASS: test_output_shape")
# =============================================================================
# Forward precision tests
# =============================================================================
def test_forward_v1():
    """Compare the V1 forward output against nn.ConvTranspose3d.

    For each (in_channels, out_channels, input_shape) case the maximum
    absolute difference between the two outputs must be below 1e-5.
    """
    cases = [
        (16, 16, (2, 16, 4, 4, 4)),
        (32, 32, (1, 32, 8, 8, 8)),
        (64, 64, (1, 64, 4, 4, 4)),
        (128, 128, (1, 128, 2, 2, 2)),
        (256, 256, (1, 256, 2, 2, 2)),
    ]
    for in_c, out_c, shape in cases:
        ref, safe1, _unused_v2 = _make_pair(in_c, out_c)
        x = torch.randn(shape)
        with torch.no_grad():
            y_ref, y_safe = ref(x), safe1(x)
        max_diff = torch.abs(y_ref - y_safe).max().item()
        assert max_diff < 1e-5, f"V1 forward diff {max_diff} for {in_c}->{out_c}"
        print(f" {in_c:3d}->{out_c:3d} input={shape}: max_diff={max_diff:.2e}")
    print("PASS: test_forward_v1")
def test_forward_v2():
    """Require the V2 forward output to equal nn.ConvTranspose3d exactly.

    The maximum absolute difference must be exactly 0.0 for every case.
    """
    cases = [
        (16, 16, (2, 16, 4, 4, 4)),
        (32, 32, (1, 32, 8, 8, 8)),
        (64, 64, (1, 64, 4, 4, 4)),
        (128, 128, (1, 128, 2, 2, 2)),
    ]
    for in_c, out_c, shape in cases:
        ref, _unused_v1, safe2 = _make_pair(in_c, out_c)
        x = torch.randn(shape)
        with torch.no_grad():
            y_ref, y_safe = ref(x), safe2(x)
        max_diff = torch.abs(y_ref - y_safe).max().item()
        assert max_diff == 0.0, f"V2 forward should be bit-for-bit, got diff {max_diff}"
    print("PASS: test_forward_v2")
def test_forward_v1_precision_analysis():
    """Print error statistics for V1 against the reference on one 32->32 case.

    Reports the max and mean absolute difference and the share of elements
    differing by more than 1e-6, then requires the max difference to stay
    below 1e-4.
    """
    ref, safe1, _unused_v2 = _make_pair(32, 32)
    x = torch.randn(2, 32, 8, 8, 8)
    with torch.no_grad():
        y_ref, y_safe = ref(x), safe1(x)
    diff = torch.abs(y_ref - y_safe)

    print(" 32->32, [2,32,8,8,8]:")
    print(f" max absolute diff: {diff.max().item():.2e}")
    print(f" mean absolute diff: {diff.mean().item():.2e}")
    print(f" % elements > 1e-6: {(diff > 1e-6).float().mean().item()*100:.2f}%")
    assert diff.max().item() < 1e-4
    print("PASS: test_forward_v1_precision_analysis")
# =============================================================================
# Backward tests
# =============================================================================
def _test_backward(version, label):
    """Compare input, weight and bias gradients of a Safe variant with the reference.

    Args:
        version: 1 selects SafeConvTranspose3d; any other value selects
            SafeConvTranspose3d_v2.
        label: suffix used in the final PASS message.

    A random upstream gradient is fed to ``backward`` on both modules.
    """
    safe_cls = SafeConvTranspose3d if version == 1 else SafeConvTranspose3d_v2
    cases = [
        (4, 4, 3, 2),
        (8, 4, 5, 1),
        (4, 8, 4, 3),
        (16, 16, 4, 2),
        (32, 32, 4, 1),
    ]
    for C_in, C_out, D_in, B in cases:
        torch.manual_seed(42)
        ct = nn.ConvTranspose3d(C_in, C_out, 4, 2, 1, bias=True)
        safe = safe_cls(C_in, C_out, 4, 2, 1, bias=True)
        safe.weight.data.copy_(ct.weight.data)
        safe.bias.data.copy_(ct.bias.data)

        # Same input values for both modules, held in separate leaf tensors.
        torch.manual_seed(123)
        x_ref = torch.randn(B, C_in, D_in, D_in, D_in, requires_grad=True)
        x_safe = x_ref.detach().clone().requires_grad_(True)

        torch.manual_seed(456)
        out_side = 2 * D_in
        grad_y = torch.randn(B, C_out, out_side, out_side, out_side)

        ct(x_ref).backward(grad_y)
        safe(x_safe).backward(grad_y)

        dx = torch.abs(x_ref.grad - x_safe.grad).max().item()
        dw = torch.abs(ct.weight.grad - safe.weight.grad).max().item()
        db = torch.abs(ct.bias.grad - safe.bias.grad).max().item()
        assert dx < 1e-4, f"V{version} grad_input diff {dx} for {C_in}->{C_out}"
        assert dw < 1e-3, f"V{version} grad_weight diff {dw} for {C_in}->{C_out}"
        assert db < 1e-3, f"V{version} grad_bias diff {db} for {C_in}->{C_out}"
        print(f" {C_in:2d}->{C_out:2d} D={D_in} B={B}: dx={dx:.2e} dw={dw:.2e} db={db:.2e}")
    print(f"PASS: test_backward_{label}")
def test_backward_v1():
    """Run the shared backward comparison for SafeConvTranspose3d (V1)."""
    _test_backward(version=1, label="v1")
def test_backward_v2():
    """Run the shared backward comparison for SafeConvTranspose3d_v2 (V2)."""
    _test_backward(version=2, label="v2")
def test_optimization_step():
    """Run 3 SGD steps on the reference and both Safe variants and compare weights.

    All three modules start from identical parameters and see the same input
    at every step. Afterwards the maximum absolute weight difference to the
    reference must stay below 1e-4 for both variants.
    """
    # _make_pair seeds the RNG with 42 and builds ref, V1, V2 in that order with
    # copied parameters -- the same setup this test previously spelled out by hand.
    ref, safe1, safe2 = _make_pair(16, 16)
    modules = (ref, safe1, safe2)
    optimizers = [torch.optim.SGD(m.parameters(), lr=0.01) for m in modules]

    for step in range(3):
        torch.manual_seed(step * 100)
        x = torch.randn(1, 16, 4, 4, 4)
        for opt, mod in zip(optimizers, modules):
            opt.zero_grad()
            mod(x).sum().backward()
            opt.step()

    w1 = (ref.weight.data - safe1.weight.data).abs().max().item()
    w2 = (ref.weight.data - safe2.weight.data).abs().max().item()
    print(f" After 3 SGD steps: V1 drift={w1:.2e}, V2 drift={w2:.2e}")
    assert w1 < 1e-4, f"V1 weight drift too large after 3 SGD steps: {w1}"
    assert w2 < 1e-4, f"V2 weight drift too large after 3 SGD steps: {w2}"
    print("PASS: test_optimization_step")
# =============================================================================
# Checkpoint and replacement tests
# =============================================================================
def test_no_bias():
    """Check both Safe variants when constructed with bias=False.

    Neither variant may carry a bias. V1 must match the reference output to
    within 1e-5 and V2 must match it exactly.
    """
    ref, safe1, safe2 = _make_pair(16, 16, bias=False)
    x = torch.randn(1, 16, 4, 4, 4)
    with torch.no_grad():
        y_ref, y_s1, y_s2 = ref(x), safe1(x), safe2(x)

    assert safe1.bias is None and safe2.bias is None
    err_v1 = torch.abs(y_ref - y_s1).max().item()
    err_v2 = torch.abs(y_ref - y_s2).max().item()
    assert err_v1 < 1e-5
    assert err_v2 == 0.0
    print("PASS: test_no_bias")
def test_checkpoint_loading():
    """Load an nn.ConvTranspose3d state_dict into both Safe variants.

    After ``load_state_dict`` the weight and the bias of each variant must be
    exactly equal to the reference module's parameters.
    """
    ref = nn.ConvTranspose3d(32, 32, 4, 2, 1)
    sd = ref.state_dict()

    safe1 = SafeConvTranspose3d(32, 32, 4, 2, 1)
    safe1.load_state_dict(sd)
    safe2 = SafeConvTranspose3d_v2(32, 32, 4, 2, 1)
    safe2.load_state_dict(sd)

    for tag, safe in (("V1", safe1), ("V2", safe2)):
        w_diff = (safe.weight.data - ref.weight.data).abs().max().item()
        assert w_diff == 0.0, f"{tag} weight differs after load_state_dict: {w_diff}"
        # The state_dict also carries the bias, so check it too.
        b_diff = (safe.bias.data - ref.bias.data).abs().max().item()
        assert b_diff == 0.0, f"{tag} bias differs after load_state_dict: {b_diff}"
    print("PASS: test_checkpoint_loading")
def test_replace_utility():
    """Check that replace_conv_transpose3d swaps ConvTranspose3d layers in a model.

    With the default target class, both upsampling layers must become
    SafeConvTranspose3d, the plain Conv3d must remain an nn.Conv3d, and the
    model output must stay within 1e-4 of the output before replacement.
    """

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.up1 = nn.ConvTranspose3d(64, 32, 4, 2, 1)
            self.up2 = nn.ConvTranspose3d(32, 16, 4, 2, 1)
            self.conv = nn.Conv3d(16, 3, 3, 1, 1)  # should NOT be replaced

        def forward(self, x):
            return self.conv(self.up2(self.up1(x)))

    # Seed so a tolerance failure can be reproduced with the same weights and input.
    torch.manual_seed(0)
    model = Decoder()
    x = torch.randn(1, 64, 4, 4, 4)
    with torch.no_grad():
        y_before = model(x).clone()

    replace_conv_transpose3d(model)
    assert isinstance(model.up1, SafeConvTranspose3d)
    assert isinstance(model.up2, SafeConvTranspose3d)
    assert isinstance(model.conv, nn.Conv3d)

    with torch.no_grad():
        y_after = model(x)
    max_diff = (y_before - y_after).abs().max().item()
    assert max_diff < 1e-4, f"Replace utility diff {max_diff}"
    print(f" Replace utility: max diff = {max_diff:.2e}")
    print("PASS: test_replace_utility")
def test_replace_v2():
    """Check replacement with target_cls=SafeConvTranspose3d_v2.

    Both layers must become SafeConvTranspose3d_v2 and the model output must
    be exactly equal to the output before replacement.
    """

    class Decoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.up1 = nn.ConvTranspose3d(64, 32, 4, 2, 1)
            self.up2 = nn.ConvTranspose3d(32, 16, 4, 2, 1)

        def forward(self, x):
            hidden = self.up1(x)
            return self.up2(hidden)

    model = Decoder()
    x = torch.randn(1, 64, 4, 4, 4)
    with torch.no_grad():
        y_before = model(x).clone()

    replace_conv_transpose3d(model, target_cls=SafeConvTranspose3d_v2)
    for layer in (model.up1, model.up2):
        assert isinstance(layer, SafeConvTranspose3d_v2)

    with torch.no_grad():
        y_after = model(x)
    assert torch.abs(y_before - y_after).max().item() == 0.0
    print("PASS: test_replace_v2")
def test_asymmetric_channels():
    """Check a 64->32 layer, where in_channels differs from out_channels.

    All three outputs must share a shape. V1 must match the reference to
    within 1e-5 and V2 must match it exactly.
    """
    ref, safe1, safe2 = _make_pair(64, 32)
    x = torch.randn(1, 64, 4, 4, 4)
    with torch.no_grad():
        y_ref, y_s1, y_s2 = ref(x), safe1(x), safe2(x)

    shapes_agree = y_ref.shape == y_s1.shape and y_s1.shape == y_s2.shape
    assert shapes_agree
    assert torch.abs(y_ref - y_s1).max().item() < 1e-5
    assert torch.abs(y_ref - y_s2).max().item() == 0.0
    print("PASS: test_asymmetric_channels")
# =============================================================================
# Numerical gradient verification
# =============================================================================
def test_gradcheck_v1():
    """Run torch.autograd.gradcheck on SafeConvTranspose3d (V1) in float64.

    Uses a small 2->2 layer and a 3x3x3 input. The check is with respect to
    the input tensor only.
    """
    # Seed so the random parameters and input are the same on every run.
    torch.manual_seed(0)
    safe1 = SafeConvTranspose3d(2, 2, 4, 2, 1, bias=True).double()
    x = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)
    result = torch.autograd.gradcheck(safe1, (x,), eps=1e-6, atol=1e-4, rtol=1e-3)
    assert result
    print("PASS: test_gradcheck_v1")
def test_gradcheck_v2():
    """Run torch.autograd.gradcheck on SafeConvTranspose3d_v2 (V2) in float64.

    Uses a small 2->2 layer and a 3x3x3 input. The check is with respect to
    the input tensor only.
    """
    # Seed so the random parameters and input are the same on every run.
    torch.manual_seed(0)
    safe2 = SafeConvTranspose3d_v2(2, 2, 4, 2, 1, bias=True).double()
    x = torch.randn(1, 2, 3, 3, 3, dtype=torch.float64, requires_grad=True)
    result = torch.autograd.gradcheck(safe2, (x,), eps=1e-6, atol=1e-4, rtol=1e-3)
    assert result
    print("PASS: test_gradcheck_v2")
# =============================================================================
# Training loss equivalence test
# =============================================================================
def test_training_loss_equivalence():
    """Train three copies of a tiny encoder-decoder side by side and compare them.

    One copy keeps nn.ConvTranspose3d. The other two are deep copies whose
    ConvTranspose3d layers are replaced by SafeConvTranspose3d (V1) and
    SafeConvTranspose3d_v2 (V2). Each model takes 5 Adam steps on the same
    data. At every step the loss gap to the reference must be below 1e-4 (V1)
    and 1e-6 (V2). After training the maximum parameter drift must be below
    1e-3 (V1) and 1e-4 (V2).
    """
    import copy

    class MiniUNet(nn.Module):
        """Three strided Conv3d encoders followed by three ConvTranspose3d decoders."""

        def __init__(self):
            super().__init__()
            self.enc1 = nn.Conv3d(1, 16, 4, 2, 1)
            self.enc2 = nn.Conv3d(16, 32, 4, 2, 1)
            self.enc3 = nn.Conv3d(32, 64, 4, 2, 1)
            self.dec3 = nn.ConvTranspose3d(64, 32, 4, 2, 1)
            self.dec2 = nn.ConvTranspose3d(32, 16, 4, 2, 1)
            self.dec1 = nn.ConvTranspose3d(16, 1, 4, 2, 1)
            self.act = nn.ReLU()

        def forward(self, x):
            hidden = x
            # Every layer except the last is followed by the activation.
            for layer in (self.enc1, self.enc2, self.enc3, self.dec3, self.dec2):
                hidden = self.act(layer(hidden))
            return self.dec1(hidden)

    torch.manual_seed(42)
    model_ref = MiniUNet()

    # V1 copy: swap in SafeConvTranspose3d.
    model_v1 = copy.deepcopy(model_ref)
    replace_conv_transpose3d(model_v1, target_cls=SafeConvTranspose3d)

    # V2 copy: swap in SafeConvTranspose3d_v2.
    model_v2 = copy.deepcopy(model_ref)
    replace_conv_transpose3d(model_v2, target_cls=SafeConvTranspose3d_v2)

    # Confirm that the swap took place.
    assert isinstance(model_v1.dec1, SafeConvTranspose3d)
    assert isinstance(model_v2.dec1, SafeConvTranspose3d_v2)

    models = (model_ref, model_v1, model_v2)
    optimizers = tuple(torch.optim.Adam(m.parameters(), lr=1e-3) for m in models)
    criterion = nn.MSELoss()

    n_steps = 5
    print(f" Training {n_steps} steps, comparing loss at each step:")
    for step in range(n_steps):
        torch.manual_seed(step * 777)
        x = torch.randn(2, 1, 16, 16, 16)
        target = torch.randn(2, 1, 16, 16, 16)

        # One full optimizer step per model, in the order reference, V1, V2.
        losses = []
        for model, opt in zip(models, optimizers):
            opt.zero_grad()
            loss = criterion(model(x), target)
            loss.backward()
            opt.step()
            losses.append(loss)
        loss_ref, loss_v1, loss_v2 = losses

        diff_v1 = abs(loss_ref.item() - loss_v1.item())
        diff_v2 = abs(loss_ref.item() - loss_v2.item())
        print(f" step {step}: loss_ref={loss_ref.item():.6f} "
              f"loss_v1={loss_v1.item():.6f} (diff={diff_v1:.2e}) "
              f"loss_v2={loss_v2.item():.6f} (diff={diff_v2:.2e})")
        assert diff_v1 < 1e-4, f"V1 loss diverged at step {step}: diff={diff_v1}"
        assert diff_v2 < 1e-6, f"V2 loss diverged at step {step}: diff={diff_v2}"

    def _max_drift(other):
        # Largest absolute elementwise gap between paired parameters.
        return max(
            (p1.data - p2.data).abs().max().item()
            for p1, p2 in zip(model_ref.parameters(), other.parameters())
        )

    # Parameter divergence once training has finished.
    w_diff_v1 = _max_drift(model_v1)
    w_diff_v2 = _max_drift(model_v2)
    print(f" After {n_steps} steps — max weight drift: V1={w_diff_v1:.2e}, V2={w_diff_v2:.2e}")
    assert w_diff_v1 < 1e-3, f"V1 weight drift too large: {w_diff_v1}"
    assert w_diff_v2 < 1e-4, f"V2 weight drift too large: {w_diff_v2}"
    print("PASS: test_training_loss_equivalence")
# =============================================================================
# Main
# =============================================================================
if __name__ == '__main__':
    # Script entry point: run every test in order, collect the failures,
    # print a summary, and exit non-zero if anything failed.
    print("=" * 70)
    print("Testing SafeConvTranspose3d implementations")
    print("=" * 70)
    tests = [
        ("Weight shapes", test_weight_shape),
        ("Output shapes", test_output_shape),
        ("Forward V1 (decomposed)", test_forward_v1),
        ("Forward V2 (custom autograd)", test_forward_v2),
        ("Forward V1 precision analysis", test_forward_v1_precision_analysis),
        ("Backward V1", test_backward_v1),
        ("Backward V2", test_backward_v2),
        ("Optimization step", test_optimization_step),
        ("No bias", test_no_bias),
        ("Checkpoint loading", test_checkpoint_loading),
        ("Replace utility (V1)", test_replace_utility),
        ("Replace utility (V2)", test_replace_v2),
        ("Asymmetric channels", test_asymmetric_channels),
        ("Gradcheck V1", test_gradcheck_v1),
        ("Gradcheck V2", test_gradcheck_v2),
        ("Training loss equivalence", test_training_loss_equivalence),
    ]
    failed = []
    for name, fn in tests:
        print(f"\n--- {name} ---")
        try:
            fn()
        except Exception as e:
            # Keep going so one failure does not hide the remaining results.
            print(f"FAIL: {name}: {e}")
            failed.append(name)
    print("\n" + "=" * 70)
    if failed:
        print(f"FAILED ({len(failed)}/{len(tests)}): {', '.join(failed)}")
    else:
        print(f"ALL {len(tests)} TESTS PASSED")
    print("=" * 70)
    # Report failure through the exit status so CI and shell callers can
    # detect it; without this the script exits 0 even when tests failed.
    if failed:
        sys.exit(1)