ESPR3SS0 commited on
Commit
7ffed9a
·
verified ·
1 Parent(s): e4e6e3c

Add test_pdp.py

Browse files
Files changed (1) hide show
  1. test_pdp.py +159 -0
test_pdp.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for PDP implementation to verify correctness against paper formulas.
3
+ """
4
+
5
+ import torch
6
+ import math
7
+ from pdp import pdp_soft_mask, compute_threshold, PDPPruner
8
+ import torch.nn as nn
9
+
10
+
11
+ def test_soft_mask_boundary_conditions():
12
+ """
13
+ From the paper (Section 3.1):
14
+ - m(w=0) should be 0.5 when |w| = t (equal chance)
15
+ - m(w=0) -> 0 when w=0 (actually z(0)=1 so m(0)=0... wait let me check)
16
+ - m(w->inf) -> 1
17
+ """
18
+ tau = 1e-4
19
+ t = 0.6
20
+
21
+ # w = t -> equal chance, m should be 0.5
22
+ w_eq = torch.tensor([t])
23
+ m_eq = pdp_soft_mask(w_eq, t, tau)
24
+ assert abs(m_eq.item() - 0.5) < 1e-3, f"m(t) should be ~0.5, got {m_eq.item()}"
25
+
26
+ # w >> t -> m -> 1
27
+ w_big = torch.tensor([10.0])
28
+ m_big = pdp_soft_mask(w_big, t, tau)
29
+ assert m_big.item() > 0.99, f"m(>>t) should be ~1, got {m_big.item()}"
30
+
31
+ # w << t -> m -> 0
32
+ w_small = torch.tensor([0.001])
33
+ m_small = pdp_soft_mask(w_small, t, tau)
34
+ assert m_small.item() < 0.01, f"m(<<t) should be ~0, got {m_small.item()}"
35
+
36
+ print("✅ test_soft_mask_boundary_conditions passed")
37
+
38
+
39
+ def test_soft_mask_monotonicity():
40
+ """
41
+ Larger |w| -> larger m(w) (higher chance to keep).
42
+ """
43
+ tau = 1e-4
44
+ t = 0.6
45
+ weights = torch.linspace(0.0, 2.0, 100)
46
+ masks = pdp_soft_mask(weights, t, tau)
47
+
48
+ # Check monotonicity
49
+ for i in range(len(weights) - 1):
50
+ assert masks[i] <= masks[i + 1] + 1e-6, "m(w) must be monotonically increasing"
51
+
52
+ print("✅ test_soft_mask_monotonicity passed")
53
+
54
+
55
+ def test_gradient_flow():
56
+ """
57
+ The soft mask must allow gradients to flow through.
58
+ Paper Eq. 2: Δw = m(w)·Δŵ + (2w²/τ)·m(w)·(1-m(w))·Δŵ
59
+ We verify this with a mild tau so values near the boundary aren't underflowed.
60
+ """
61
+ tau = 1e-1 # larger tau to avoid numerical underflow near boundary
62
+ t = 0.6
63
+ w = torch.tensor([0.59], requires_grad=True) # very close to boundary
64
+
65
+ # Forward: apply soft mask
66
+ masked = pdp_soft_mask(w, t, tau) * w
67
+ loss = masked.sum()
68
+ loss.backward()
69
+
70
+ # Check gradient exists and is non-zero
71
+ assert w.grad is not None, "Gradient should flow through PDP mask"
72
+ assert w.grad.abs().item() > 0, f"Gradient should be non-zero, got {w.grad.item()}"
73
+
74
+ # Near boundary (m≈0.5), the extra gradient term should be maximized
75
+ # dŵ/dw = m(w) + (2w²/τ)·m(w)·(1-m(w))
76
+ m_val = pdp_soft_mask(w.detach(), t, tau).item()
77
+ expected = m_val + 2 * (w.item() ** 2) / tau * m_val * (1 - m_val)
78
+ actual = w.grad.item()
79
+ # Allow tolerance for autograd numerical differences
80
+ assert abs(actual - expected) < 0.1, \
81
+ f"Expected grad ~{expected:.4f}, got {actual:.4f}"
82
+
83
+ print("✅ test_gradient_flow passed")
84
+
85
+
86
+ def test_threshold_computation():
87
+ """
88
+ Test that compute_threshold yields correct sparsity.
89
+ """
90
+ torch.manual_seed(42)
91
+ weights = torch.randn(1000).abs()
92
+ sparsity = 0.3
93
+ t = compute_threshold(weights, sparsity)
94
+
95
+ below = (weights <= t).float().sum().item()
96
+ actual_sparsity = below / weights.numel()
97
+ # Should be close to target (within 1 element)
98
+ assert abs(actual_sparsity - sparsity) < 0.01, \
99
+ f"Threshold sparsity {actual_sparsity} far from target {sparsity}"
100
+
101
+ print("✅ test_threshold_computation passed")
102
+
103
+
104
+ def test_pdp_pruner_end_to_end():
105
+ """
106
+ Full end-to-end test: model, prune, hard prune.
107
+ """
108
+ class Net(nn.Module):
109
+ def __init__(self):
110
+ super().__init__()
111
+ self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
112
+ self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
113
+ self.fc = nn.Linear(32, 10)
114
+
115
+ def forward(self, x):
116
+ x = torch.relu(self.conv1(x))
117
+ x = torch.relu(self.conv2(x))
118
+ x = x.mean(dim=(2, 3))
119
+ return self.fc(x)
120
+
121
+ model = Net()
122
+ pruner = PDPPruner(model, target_sparsity=0.5, s=2, epsilon=0.1, tau=1e-4)
123
+ pruner.attach()
124
+
125
+ # Simulate 4 training steps
126
+ for epoch in range(4):
127
+ x = torch.randn(4, 3, 8, 8)
128
+ y = model(x)
129
+ loss = y.sum()
130
+ loss.backward()
131
+ with torch.no_grad():
132
+ for p in model.parameters():
133
+ if p.grad is not None:
134
+ p -= 0.01 * p.grad
135
+ p.grad.zero_()
136
+ pruner.step(epoch)
137
+
138
+ # Hard prune
139
+ pruner.hard_prune()
140
+ sparsity = pruner.get_sparsity()
141
+ assert sparsity > 0, f"After hard prune, sparsity should be > 0, got {sparsity}"
142
+
143
+ # Check weights are actually zero
144
+ for name, param in model.named_parameters():
145
+ if "weight" in name and any(k in name for k in ["conv", "fc"]):
146
+ zeros = (param.data == 0).float().sum().item()
147
+ assert zeros > 0, f"No weights pruned in {name}"
148
+
149
+ pruner.detach()
150
+ print(f"✅ test_pdp_pruner_end_to_end passed (final sparsity={sparsity:.4f})")
151
+
152
+
153
+ if __name__ == "__main__":
154
+ test_soft_mask_boundary_conditions()
155
+ test_soft_mask_monotonicity()
156
+ test_gradient_flow()
157
+ test_threshold_computation()
158
+ test_pdp_pruner_end_to_end()
159
+ print("\n🎉 All PDP tests passed!")