"""
Unit tests for the PDP implementation, verifying correctness against the paper's formulas.
"""
|
|
import math

import torch
import torch.nn as nn

from pdp import pdp_soft_mask, compute_threshold, PDPPruner
|
|
|
|
def test_soft_mask_boundary_conditions():
    """
    Check the soft mask at, far above, and far below the threshold.

    Properties asserted, following the paper (Section 3.1):
    - m(w) is ~0.5 when |w| = t (equal chance of keeping or pruning)
    - m(w) -> 1 when |w| >> t
    - m(w) -> 0 when |w| << t
    """
    tau = 1e-4
    t = 0.6

    # Exactly at the threshold the mask should sit at the midpoint.
    w_eq = torch.tensor([t])
    m_eq = pdp_soft_mask(w_eq, t, tau)
    assert abs(m_eq.item() - 0.5) < 1e-3, f"m(t) should be ~0.5, got {m_eq.item()}"

    # Far above the threshold the weight should be (almost) fully kept.
    w_big = torch.tensor([10.0])
    m_big = pdp_soft_mask(w_big, t, tau)
    assert m_big.item() > 0.99, f"m(>>t) should be ~1, got {m_big.item()}"

    # Far below the threshold the weight should be (almost) fully pruned.
    w_small = torch.tensor([0.001])
    m_small = pdp_soft_mask(w_small, t, tau)
    assert m_small.item() < 0.01, f"m(<<t) should be ~0, got {m_small.item()}"

    print("✅ test_soft_mask_boundary_conditions passed")
|
|
|
|
def test_soft_mask_monotonicity():
    """
    Larger |w| -> larger m(w) (higher chance to keep).

    Evaluates the mask on 100 evenly spaced weights in [0, 2] and checks that
    each value is no larger than its successor, within a 1e-6 slack.
    """
    tau = 1e-4
    t = 0.6
    weights = torch.linspace(0.0, 2.0, 100)
    masks = pdp_soft_mask(weights, t, tau)

    # Compare every element with its successor in a single vectorized pass.
    # NOTE(review): assumes pdp_soft_mask returns a 1-D tensor aligned with
    # `weights` -- confirm against the pdp module.
    non_decreasing = masks[:-1] <= masks[1:] + 1e-6
    assert bool(non_decreasing.all()), "m(w) must be monotonically increasing"

    print("✅ test_soft_mask_monotonicity passed")
|
|
|
|
def test_gradient_flow():
    """
    Gradients must propagate through the soft mask.

    Paper Eq. 2: Δw = m(w)·Δŵ + (2w²/τ)·m(w)·(1-m(w))·Δŵ
    The autograd gradient of sum(m(w) * w) is compared with that closed form.
    A mild tau is used so values near the boundary aren't underflowed.
    """
    tau = 1e-1
    t = 0.6
    w = torch.tensor([0.59], requires_grad=True)

    # Forward through mask * weight, then backpropagate the summed output.
    (pdp_soft_mask(w, t, tau) * w).sum().backward()

    grad = w.grad
    assert grad is not None, "Gradient should flow through PDP mask"
    assert grad.abs().item() > 0, f"Gradient should be non-zero, got {grad.item()}"

    # Closed-form gradient evaluated at the same point, outside autograd.
    mask_at_w = pdp_soft_mask(w.detach(), t, tau).item()
    analytic = mask_at_w + 2 * (w.item() ** 2) / tau * mask_at_w * (1 - mask_at_w)
    autograd_value = grad.item()

    assert abs(autograd_value - analytic) < 0.1, \
        f"Expected grad ~{analytic:.4f}, got {autograd_value:.4f}"

    print("✅ test_gradient_flow passed")
|
|
|
|
def test_threshold_computation():
    """
    Check that the threshold from compute_threshold hits the requested sparsity.

    The fraction of the 1000 sampled magnitudes at or below the threshold must
    be within 0.01 of the target.
    """
    sparsity = 0.3
    torch.manual_seed(42)  # fixed seed keeps the sampled magnitudes reproducible
    magnitudes = torch.randn(1000).abs()
    t = compute_threshold(magnitudes, sparsity)

    # Fraction of entries that the threshold would prune.
    pruned_count = (magnitudes <= t).float().sum().item()
    actual_sparsity = pruned_count / magnitudes.numel()

    deviation = abs(actual_sparsity - sparsity)
    assert deviation < 0.01, \
        f"Threshold sparsity {actual_sparsity} far from target {sparsity}"

    print("✅ test_threshold_computation passed")
|
|
|
|
def test_pdp_pruner_end_to_end():
    """
    Full end-to-end test: model, prune, hard prune.

    Builds a small conv net, attaches a PDPPruner, runs four manual SGD steps
    (calling ``pruner.step(epoch)`` after each), hard-prunes, and then checks
    that the reported sparsity is positive and that every conv/fc weight
    tensor contains at least one exact zero.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
            self.fc = nn.Linear(32, 10)

        def forward(self, x):
            x = torch.relu(self.conv1(x))
            x = torch.relu(self.conv2(x))
            x = x.mean(dim=(2, 3))  # average over the two spatial dims
            return self.fc(x)

    # Seed the RNG so weight init and random inputs are reproducible, matching
    # the seeding already done in test_threshold_computation.
    torch.manual_seed(0)

    model = Net()
    pruner = PDPPruner(model, target_sparsity=0.5, s=2, epsilon=0.1, tau=1e-4)
    pruner.attach()

    try:
        # A few plain SGD steps with the pruner attached.
        for epoch in range(4):
            x = torch.randn(4, 3, 8, 8)
            y = model(x)
            loss = y.sum()
            loss.backward()
            with torch.no_grad():
                for p in model.parameters():
                    if p.grad is not None:
                        p -= 0.01 * p.grad
                        p.grad.zero_()
            pruner.step(epoch)

        # Finalize the mask and check the overall sparsity.
        pruner.hard_prune()
        sparsity = pruner.get_sparsity()
        assert sparsity > 0, f"After hard prune, sparsity should be > 0, got {sparsity}"

        # Every conv/fc weight tensor should have lost at least one entry.
        for name, param in model.named_parameters():
            if "weight" in name and any(k in name for k in ("conv", "fc")):
                zeros = (param.data == 0).float().sum().item()
                assert zeros > 0, f"No weights pruned in {name}"
    finally:
        # Detach even when an assertion above fails.
        pruner.detach()

    print(f"✅ test_pdp_pruner_end_to_end passed (final sparsity={sparsity:.4f})")
|
|
|
|
if __name__ == "__main__":
    # Run every test in order; the first failing assertion aborts the run.
    for run_test in (
        test_soft_mask_boundary_conditions,
        test_soft_mask_monotonicity,
        test_gradient_flow,
        test_threshold_computation,
        test_pdp_pruner_end_to_end,
    ):
        run_test()
    print("\n🎉 All PDP tests passed!")
|
|