# neural-pruning-impl / test_pdp.py
# Commit 7ffed9a (verified): "Add test_pdp.py" — ESPR3SS0
"""
Unit tests for PDP implementation to verify correctness against paper formulas.
"""
import torch
import math
from pdp import pdp_soft_mask, compute_threshold, PDPPruner
import torch.nn as nn
def test_soft_mask_boundary_conditions():
    """
    Check the limiting behaviour of the PDP soft mask m(w) (paper, Section 3.1).

    - |w| == t  -> m(w) == 0.5  (keeping and pruning are equally likely)
    - |w| >> t  -> m(w) -> 1    (the weight is kept)
    - |w| << t  -> m(w) -> 0    (the weight is pruned)

    A very small temperature ``tau`` is used so that the mask is close to a
    hard step at ``t`` and the far-from-threshold cases are saturated.
    """
    tau = 1e-4
    t = 0.6

    # Exactly on the threshold: m should be 0.5.
    m_eq = pdp_soft_mask(torch.tensor([t]), t, tau)
    assert abs(m_eq.item() - 0.5) < 1e-3, f"m(t) should be ~0.5, got {m_eq.item()}"

    # Far above the threshold: m should saturate towards 1.
    m_big = pdp_soft_mask(torch.tensor([10.0]), t, tau)
    assert m_big.item() > 0.99, f"m(>>t) should be ~1, got {m_big.item()}"

    # Far below the threshold: m should saturate towards 0.
    m_small = pdp_soft_mask(torch.tensor([0.001]), t, tau)
    assert m_small.item() < 0.01, f"m(<<t) should be ~0, got {m_small.item()}"

    print("✅ test_soft_mask_boundary_conditions passed")
def test_soft_mask_monotonicity():
    """
    Larger |w| -> larger m(w) (higher chance to keep).

    Evaluates the mask on an increasing grid of non-negative weights and
    checks that consecutive mask values never decrease, allowing 1e-6 of
    floating-point slack.
    """
    tau = 1e-4
    t = 0.6
    weights = torch.linspace(0.0, 2.0, 100)
    masks = pdp_soft_mask(weights, t, tau)

    # Vectorized form of `masks[i] <= masks[i + 1] + 1e-6` for every i.
    # Written as `<=` (not `diff < -eps`) so a NaN mask value still fails.
    ok = masks[:-1] <= masks[1:] + 1e-6
    bad = (~ok).nonzero().flatten()
    # The message expression is only evaluated when the assertion fails,
    # so indexing bad[0] is safe here.
    assert bad.numel() == 0, (
        "m(w) must be monotonically increasing; first violation between "
        f"w={weights[bad[0]].item():.4f} and w={weights[bad[0] + 1].item():.4f}"
    )
    print("✅ test_soft_mask_monotonicity passed")
def test_gradient_flow():
    """
    The soft mask must let gradients flow back to the weights.

    Paper Eq. 2: Δw = m(w)·Δŵ + (2w²/τ)·m(w)·(1-m(w))·Δŵ

    With ŵ = m(w)·w this means dŵ/dw = m(w) + (2w²/τ)·m(w)·(1-m(w)).
    The test compares autograd's gradient against that closed form at a
    weight just below the threshold. It uses a mild tau so the mask is not
    saturated (underflowed) near the boundary.
    """
    tau = 1e-1  # larger tau to avoid numerical underflow near boundary
    t = 0.6
    w = torch.tensor([0.59], requires_grad=True)  # very close to the threshold

    # Forward: apply the soft mask, then backprop a scalar loss.
    masked = pdp_soft_mask(w, t, tau) * w
    masked.sum().backward()

    # The gradient must exist and be non-zero.
    assert w.grad is not None, "Gradient should flow through PDP mask"
    actual = w.grad.item()
    assert abs(actual) > 0, f"Gradient should be non-zero, got {actual}"

    # Near the boundary (m ≈ 0.5) the m·(1-m) factor, and hence the extra
    # gradient term, is close to its maximum.
    m_val = pdp_soft_mask(w.detach(), t, tau).item()
    expected = m_val + 2 * (w.item() ** 2) / tau * m_val * (1 - m_val)

    # Autograd computes the exact derivative of the ops used, so only
    # float32 rounding should separate actual from expected. A 0.1% relative
    # tolerance is ample for that while still catching a wrong coefficient
    # in the extra term (the previous absolute 0.1 allowed ~4.5% error).
    assert math.isclose(actual, expected, rel_tol=1e-3, abs_tol=1e-6), \
        f"Expected grad ~{expected:.4f}, got {actual:.4f}"
    print("✅ test_gradient_flow passed")
def test_threshold_computation():
    """
    Test that compute_threshold yields the requested sparsity.

    The fraction of weights with value <= the returned threshold must match
    the target ``sparsity`` to within ``tolerance``.
    """
    torch.manual_seed(42)
    weights = torch.randn(1000).abs()
    sparsity = 0.3
    # 1% of the weights, i.e. at most 10 of the 1000 elements off target.
    tolerance = 0.01

    t = compute_threshold(weights, sparsity)
    actual_sparsity = (weights <= t).float().mean().item()

    assert abs(actual_sparsity - sparsity) < tolerance, \
        f"Threshold sparsity {actual_sparsity} far from target {sparsity}"
    print("✅ test_threshold_computation passed")
def test_pdp_pruner_end_to_end():
    """
    Full end-to-end test: model, prune, hard prune.

    Builds a small conv net and attaches a PDPPruner. Runs a few SGD updates,
    calling ``pruner.step`` with the loop index as the epoch, then hard-prunes.
    Finally checks that every conv/fc weight tensor contains exact zeros.
    """

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
            self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
            self.fc = nn.Linear(32, 10)

        def forward(self, x):
            x = torch.relu(self.conv1(x))
            x = torch.relu(self.conv2(x))
            x = x.mean(dim=(2, 3))  # global average pool over H, W
            return self.fc(x)

    # Seed so model initialisation and the random inputs are reproducible.
    torch.manual_seed(42)
    model = Net()
    pruner = PDPPruner(model, target_sparsity=0.5, s=2, epsilon=0.1, tau=1e-4)
    pruner.attach()
    try:
        # Simulate 4 epochs, each consisting of a single plain SGD step.
        for epoch in range(4):
            x = torch.randn(4, 3, 8, 8)
            loss = model(x).sum()
            loss.backward()
            with torch.no_grad():
                for p in model.parameters():
                    if p.grad is not None:
                        p -= 0.01 * p.grad
                        p.grad.zero_()
            pruner.step(epoch)

        # Hard prune
        pruner.hard_prune()
        sparsity = pruner.get_sparsity()
        assert sparsity > 0, f"After hard prune, sparsity should be > 0, got {sparsity}"

        # Check weights are actually zero
        for name, param in model.named_parameters():
            if "weight" in name and any(k in name for k in ["conv", "fc"]):
                zeros = (param.data == 0).float().sum().item()
                assert zeros > 0, f"No weights pruned in {name}"
    finally:
        # Always detach the pruner, even when an assertion above fails.
        pruner.detach()

    print(f"✅ test_pdp_pruner_end_to_end passed (final sparsity={sparsity:.4f})")
if __name__ == "__main__":
    # Run every test in declaration order. Each one prints its own success
    # line, and the first failing assertion aborts the run.
    for _run_test in (
        test_soft_mask_boundary_conditions,
        test_soft_mask_monotonicity,
        test_gradient_flow,
        test_threshold_computation,
        test_pdp_pruner_end_to_end,
    ):
        _run_test()
    print("\n🎉 All PDP tests passed!")