""" Unit tests for PDP implementation to verify correctness against paper formulas. """ import torch import math from pdp import pdp_soft_mask, compute_threshold, PDPPruner import torch.nn as nn def test_soft_mask_boundary_conditions(): """ From the paper (Section 3.1): - m(w=0) should be 0.5 when |w| = t (equal chance) - m(w=0) -> 0 when w=0 (actually z(0)=1 so m(0)=0... wait let me check) - m(w->inf) -> 1 """ tau = 1e-4 t = 0.6 # w = t -> equal chance, m should be 0.5 w_eq = torch.tensor([t]) m_eq = pdp_soft_mask(w_eq, t, tau) assert abs(m_eq.item() - 0.5) < 1e-3, f"m(t) should be ~0.5, got {m_eq.item()}" # w >> t -> m -> 1 w_big = torch.tensor([10.0]) m_big = pdp_soft_mask(w_big, t, tau) assert m_big.item() > 0.99, f"m(>>t) should be ~1, got {m_big.item()}" # w << t -> m -> 0 w_small = torch.tensor([0.001]) m_small = pdp_soft_mask(w_small, t, tau) assert m_small.item() < 0.01, f"m(< larger m(w) (higher chance to keep). """ tau = 1e-4 t = 0.6 weights = torch.linspace(0.0, 2.0, 100) masks = pdp_soft_mask(weights, t, tau) # Check monotonicity for i in range(len(weights) - 1): assert masks[i] <= masks[i + 1] + 1e-6, "m(w) must be monotonically increasing" print("✅ test_soft_mask_monotonicity passed") def test_gradient_flow(): """ The soft mask must allow gradients to flow through. Paper Eq. 2: Δw = m(w)·Δŵ + (2w²/τ)·m(w)·(1-m(w))·Δŵ We verify this with a mild tau so values near the boundary aren't underflowed. """ tau = 1e-1 # larger tau to avoid numerical underflow near boundary t = 0.6 w = torch.tensor([0.59], requires_grad=True) # very close to boundary # Forward: apply soft mask masked = pdp_soft_mask(w, t, tau) * w loss = masked.sum() loss.backward() # Check gradient exists and is non-zero assert w.grad is not None, "Gradient should flow through PDP mask" assert w.grad.abs().item() > 0, f"Gradient should be non-zero, got {w.grad.item()}" # Near boundary (m≈0.5), the extra gradient term should be maximized # dŵ/dw = m(w) + (2w²/τ)·m(w)·(1-m(w)) m_val = pdp_soft_mask(w.detach(), t, tau).item() expected = m_val + 2 * (w.item() ** 2) / tau * m_val * (1 - m_val) actual = w.grad.item() # Allow tolerance for autograd numerical differences assert abs(actual - expected) < 0.1, \ f"Expected grad ~{expected:.4f}, got {actual:.4f}" print("✅ test_gradient_flow passed") def test_threshold_computation(): """ Test that compute_threshold yields correct sparsity. """ torch.manual_seed(42) weights = torch.randn(1000).abs() sparsity = 0.3 t = compute_threshold(weights, sparsity) below = (weights <= t).float().sum().item() actual_sparsity = below / weights.numel() # Should be close to target (within 1 element) assert abs(actual_sparsity - sparsity) < 0.01, \ f"Threshold sparsity {actual_sparsity} far from target {sparsity}" print("✅ test_threshold_computation passed") def test_pdp_pruner_end_to_end(): """ Full end-to-end test: model, prune, hard prune. """ class Net(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3, padding=1) self.conv2 = nn.Conv2d(16, 32, 3, padding=1) self.fc = nn.Linear(32, 10) def forward(self, x): x = torch.relu(self.conv1(x)) x = torch.relu(self.conv2(x)) x = x.mean(dim=(2, 3)) return self.fc(x) model = Net() pruner = PDPPruner(model, target_sparsity=0.5, s=2, epsilon=0.1, tau=1e-4) pruner.attach() # Simulate 4 training steps for epoch in range(4): x = torch.randn(4, 3, 8, 8) y = model(x) loss = y.sum() loss.backward() with torch.no_grad(): for p in model.parameters(): if p.grad is not None: p -= 0.01 * p.grad p.grad.zero_() pruner.step(epoch) # Hard prune pruner.hard_prune() sparsity = pruner.get_sparsity() assert sparsity > 0, f"After hard prune, sparsity should be > 0, got {sparsity}" # Check weights are actually zero for name, param in model.named_parameters(): if "weight" in name and any(k in name for k in ["conv", "fc"]): zeros = (param.data == 0).float().sum().item() assert zeros > 0, f"No weights pruned in {name}" pruner.detach() print(f"✅ test_pdp_pruner_end_to_end passed (final sparsity={sparsity:.4f})") if __name__ == "__main__": test_soft_mask_boundary_conditions() test_soft_mask_monotonicity() test_gradient_flow() test_threshold_computation() test_pdp_pruner_end_to_end() print("\n🎉 All PDP tests passed!")