ESPR3SS0 committed on
Commit
3a3ad1b
·
verified ·
1 Parent(s): af39b15

Add pdp.py

Browse files
Files changed (1) hide show
  1. pdp.py +348 -0
pdp.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PDP: Parameter-free Differentiable Pruning
3
+ Implementation based on the NeurIPS 2023 paper:
4
+ "PDP: Parameter-free Differentiable Pruning is All You Need"
5
+ https://arxiv.org/abs/2305.11203
6
+ """
7
+
8
import math
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
13
+
14
+
15
def pdp_soft_mask(weight: torch.Tensor, threshold: float, tau: float) -> torch.Tensor:
    """
    Compute the PDP soft pruning mask.

        m(w) = exp(w^2 / tau) / (exp(w^2 / tau) + exp(t^2 / tau))

    Dividing numerator and denominator by exp(w^2 / tau) gives the
    algebraically identical form

        m(w) = sigmoid((w^2 - t^2) / tau)

    which is what is evaluated here. Subtracting before dividing by a small
    tau keeps the rounding error relative to the difference rather than to
    the (much larger) individual logits, and avoids materialising separate
    logit / max / exp tensors on every forward pass.

    Args:
        weight: The weight tensor.
        threshold: The threshold t for this layer/entity.
        tau: Temperature hyperparameter controlling mask softness.
            Must be non-zero.

    Returns:
        Soft mask tensor with the same shape as weight, with values in [0, 1];
        exactly 0.5 where |w| == t.
    """
    return torch.sigmoid((weight ** 2 - threshold ** 2) / tau)
40
+
41
+
42
def compute_threshold(weight: torch.Tensor, sparsity_ratio: float) -> float:
    """
    Compute the threshold t for a given sparsity ratio.

    t is set halfway between the largest pruned weight and the smallest
    unpruned weight, where the k = floor(sparsity_ratio * n) smallest values
    are the pruned ones (k is clamped to [1, n - 1]).

    Args:
        weight: Tensor of absolute weight values. Any shape is accepted; it
            is flattened before sorting.
        sparsity_ratio: Target sparsity ratio in [0, 1).

    Returns:
        Threshold value t >= 0. Returns 0.0 when sparsity_ratio <= 0 or the
        tensor is empty, and max(weight) + 1e-6 when sparsity_ratio >= 1.
    """
    if sparsity_ratio <= 0:
        return 0.0

    flat = weight.reshape(-1)
    n = flat.numel()
    if n == 0:
        # Nothing to prune; sorting/indexing an empty tensor would raise.
        return 0.0
    if sparsity_ratio >= 1.0:
        # Strictly above every value so that everything counts as pruned.
        return flat.max().item() + 1e-6

    k = max(1, min(n - 1, int(math.floor(sparsity_ratio * n))))
    sorted_vals, _ = torch.sort(flat)
    pruned_max = sorted_vals[k - 1].item()
    # k == n only happens for n == 1 (the outer max(1, ...) wins over n - 1).
    unpruned_min = sorted_vals[k].item() if k < n else sorted_vals[-1].item()
    return max((pruned_max + unpruned_min) / 2.0, 0.0)
66
+
67
+
68
def _make_masked_forward(module: nn.Module, pruner: "PDPPruner", param_name: str):
    """
    Build a replacement ``forward`` that applies the PDP soft mask to the weight.

    The returned callable is meant to be assigned to ``module.forward``. The
    mask is computed from ``module.weight`` inside the forward pass, so it
    stays part of the computation graph for differentiable backpropagation.

    Args:
        module: The module whose forward is being wrapped. Conv1d/Conv2d/Conv3d
            and Linear are wrapped; any other module type is left alone.
        pruner: Object providing ``thresholds`` (dict keyed by param_name) and
            ``tau``; both are read on every call, so later updates take effect.
        param_name: Key used to look up this module's threshold.

    Returns:
        A callable ``forward(x)``. For unsupported module types this is simply
        ``module.forward``.
    """
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        def apply_weight(x, weight):
            # Delegate to the module's own convolution routine rather than
            # calling F.convNd directly, so that non-zero padding modes
            # (reflect / replicate / circular) are honoured.
            # NOTE: relies on the (input, weight, bias) signature of
            # nn.ConvNd._conv_forward used by current PyTorch releases.
            return module._conv_forward(x, weight, module.bias)
    elif isinstance(module, nn.Linear):
        def apply_weight(x, weight):
            return F.linear(x, weight, module.bias)
    else:
        return module.forward

    # Captured before the caller overwrites module.forward.
    orig_forward = module.forward

    def forward(x):
        t = pruner.thresholds.get(param_name, 0.0)
        if t <= 0:
            # No threshold set yet (e.g. warmup): behave exactly as before.
            return orig_forward(x)
        mask = pdp_soft_mask(module.weight, t, pruner.tau)
        return apply_weight(x, mask * module.weight)

    return forward
131
+
132
+
133
class PDPPruner:
    """
    Parameter-free Differentiable Pruning (PDP) pruner.

    Applies soft pruning masks during training so the task loss directly guides
    pruning decisions. After training, call hard_prune() for inference.

    Usage:
        pruner = PDPPruner(model, target_sparsity=0.855, s=16, epsilon=0.015, tau=1e-4)
        pruner.attach()
        for epoch in range(num_epochs):
            for batch in dataloader:
                loss = model(...)
                loss.backward()
                optimizer.step()
                pruner.step(epoch)
        pruner.hard_prune()
    """

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float,
        s: int = 16,
        epsilon: float = 0.015,
        tau: float = 1e-4,
        excluded_modules: Optional[List[str]] = None,
    ):
        """
        Args:
            model: The model to prune.
            target_sparsity: Global target sparsity ratio (e.g. 0.855 for 85.5%).
            s: Warmup epochs before computing target sparsity (default 16).
            epsilon: Gradual pruning rate per epoch (default 0.015 = 1.5%).
            tau: Temperature hyperparameter for soft mask (default 1e-4).
            excluded_modules: List of module class names to exclude. Conv /
                Linear modules whose class name appears here are not pruned.
        """
        self.model = model
        self.target_sparsity = target_sparsity
        self.s = s
        self.epsilon = epsilon
        self.tau = tau
        self.excluded_modules = excluded_modules or ["BatchNorm2d", "LayerNorm", "BatchNorm1d"]

        # Maps param_name -> nn.Parameter
        self.prunable_params: Dict[str, nn.Parameter] = {}
        # Maps param_name -> float (sparsity currently applied to that layer,
        # i.e. the layer target scaled by the gradual schedule)
        self.layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> float (final per-layer target). Kept separate so
        # the schedule can be re-derived from it on every step instead of
        # repeatedly rescaling already-scaled values.
        self._target_layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> float (current threshold t)
        self.thresholds: Dict[str, float] = {}
        # Whether target sparsities have been computed
        self.sparsity_computed = False
        # Current effective global sparsity (gradual schedule)
        self.current_effective_sparsity = 0.0
        # Store original forward methods to restore later
        self._orig_forwards: Dict[str, Callable] = {}

        self._find_prunable_params()

    def _find_prunable_params(self):
        """Identify Conv and Linear weight parameters to prune."""
        for name, module in self.model.named_modules():
            if not isinstance(module, (nn.Conv2d, nn.Conv1d, nn.Conv3d, nn.Linear)):
                continue
            if type(module).__name__ in self.excluded_modules:
                continue
            if getattr(module, "weight", None) is not None:
                self.prunable_params[f"{name}.weight"] = module.weight

    def _compute_layer_sparsities(self):
        """
        Compute per-layer target sparsity by sorting all weights globally by magnitude.
        This is the PDP-base strategy from the paper.

        Does nothing (and leaves sparsity_computed False) when there are no
        prunable weights.
        """
        all_weights = [p.data.abs().flatten() for p in self.prunable_params.values()]
        if not all_weights:
            return

        all_weights_cat = torch.cat(all_weights)
        n_total = all_weights_cat.numel()
        if n_total == 0:
            return

        k = int(math.floor(self.target_sparsity * n_total))
        k = max(0, min(n_total - 1, k))

        # Global threshold: the k-th smallest weight magnitude
        sorted_vals, _ = torch.sort(all_weights_cat)
        global_threshold = sorted_vals[k].item()

        # Per-layer sparsity = fraction below/equal to global threshold
        targets: Dict[str, float] = {}
        for name, param in self.prunable_params.items():
            w_abs = param.data.abs()
            count = w_abs.numel()
            if count == 0:
                targets[name] = 0.0
                continue
            below = (w_abs <= global_threshold).sum().item()
            targets[name] = min(below / count, 0.999)  # cap at 99.9%

        self._target_layer_sparsity = targets
        self.layer_sparsity = dict(targets)
        self.sparsity_computed = True
        print(f"[PDP] Computed per-layer sparsities at epoch {self.s}. "
              f"Global target: {self.target_sparsity:.4f}")

    def _update_schedule(self, epoch: int):
        """
        Set current_effective_sparsity and layer_sparsity for the given epoch.

        The effective global sparsity grows by epsilon per epoch after warmup,
        capped at target_sparsity. Each layer's sparsity is its stored target
        multiplied by (effective / target), always derived from the target so
        that calling this repeatedly for the same epoch gives the same result.
        """
        steps_since_s = epoch - self.s + 1
        self.current_effective_sparsity = min(
            self.target_sparsity,
            self.epsilon * steps_since_s,
        )
        if self.target_sparsity > 0:
            scale = self.current_effective_sparsity / self.target_sparsity
        else:
            # A non-positive global target means nothing should be pruned.
            scale = 0.0
        self.layer_sparsity = {
            name: min(1.0, target * scale)
            for name, target in self._target_layer_sparsity.items()
        }

    def _compute_thresholds(self):
        """Recompute per-layer thresholds t based on current weight distribution."""
        for name, param in self.prunable_params.items():
            ratio = self.layer_sparsity.get(name, 0.0)
            if ratio <= 0:
                self.thresholds[name] = 0.0
                continue
            w_abs = param.data.abs().flatten()
            self.thresholds[name] = compute_threshold(w_abs, ratio)

    def attach(self):
        """Monkey-patch forward methods of prunable modules to apply soft masks."""
        for name, module in self.model.named_modules():
            if not isinstance(module, (nn.Conv2d, nn.Conv1d, nn.Conv3d, nn.Linear)):
                continue
            param_name = f"{name}.weight"
            if param_name not in self.prunable_params:
                continue
            if param_name in self._orig_forwards:
                # Already patched; re-saving would store the masked forward as
                # the "original" and make detach() unable to restore it.
                continue
            self._orig_forwards[param_name] = module.forward
            module.forward = _make_masked_forward(module, self, param_name)
        print(f"[PDP] Attached masked forwards to {len(self.prunable_params)} prunable layers.")

    def detach(self):
        """Restore original forward methods."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Conv1d, nn.Conv3d, nn.Linear)):
                param_name = f"{name}.weight"
                if param_name in self._orig_forwards:
                    module.forward = self._orig_forwards[param_name]
        self._orig_forwards.clear()
        print("[PDP] Detached all masked forwards.")

    def step(self, epoch: int):
        """
        Call this after each optimizer.step() (or at each epoch boundary).
        Recomputes thresholds and updates gradual sparsity schedule.

        Safe to call several times with the same epoch: the schedule depends
        only on the epoch number, not on how often step() has run.
        """
        # Warmup: first s epochs, no pruning
        if epoch < self.s:
            return

        # First call at or after epoch s: compute per-layer targets (one-time).
        # Not tied to epoch == s so that a resumed or late-starting run still
        # gets its targets.
        if not self.sparsity_computed:
            self._compute_layer_sparsities()

        # Gradual sparsity increase after warmup
        self._update_schedule(epoch)

        # Recompute thresholds based on current weight distribution
        self._compute_thresholds()

    def get_sparsity(self) -> float:
        """Return the current actual sparsity (fraction of weights below threshold)."""
        total = 0
        pruned = 0
        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            total += param.numel()
            if t > 0:
                pruned += (param.data.abs() <= t).sum().item()
        return pruned / total if total > 0 else 0.0

    def hard_prune(self):
        """
        After training, apply hard pruning masks for inference.
        Sets pruned weights to exactly zero.

        Returns:
            The resulting sparsity as reported by get_sparsity().
        """
        # Restore full target sparsities (independent of where the gradual
        # schedule currently stands).
        if self.target_sparsity > 0:
            self.layer_sparsity = dict(self._target_layer_sparsity)

        self._compute_thresholds()

        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            if t > 0:
                keep = param.data.abs() > t
                param.data.mul_(keep.to(param.dtype))

        final_sparsity = self.get_sparsity()
        print(f"[PDP] Hard pruning applied. Final sparsity: {final_sparsity:.4f}")
        return final_sparsity

    def state_dict(self) -> dict:
        """Serialize pruner state (dict values are copies, not live references)."""
        return {
            "target_sparsity": self.target_sparsity,
            "s": self.s,
            "epsilon": self.epsilon,
            "tau": self.tau,
            "sparsity_computed": self.sparsity_computed,
            "layer_sparsity": dict(self.layer_sparsity),
            "target_layer_sparsity": dict(self._target_layer_sparsity),
            "thresholds": dict(self.thresholds),
            "current_effective_sparsity": self.current_effective_sparsity,
        }

    def load_state_dict(self, state: dict):
        """Restore pruner state."""
        self.target_sparsity = state["target_sparsity"]
        self.s = state["s"]
        self.epsilon = state["epsilon"]
        self.tau = state["tau"]
        self.sparsity_computed = state["sparsity_computed"]
        self.layer_sparsity = dict(state["layer_sparsity"])
        self.thresholds = dict(state["thresholds"])
        self.current_effective_sparsity = state["current_effective_sparsity"]
        # States saved without "target_layer_sparsity" only carry the scaled
        # values; use those as the best available stand-in for the targets.
        self._target_layer_sparsity = dict(
            state.get("target_layer_sparsity", self.layer_sparsity)
        )