File size: 13,293 Bytes
3a3ad1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
"""
PDP: Parameter-free Differentiable Pruning
Implementation based on the NeurIPS 2023 paper:
"PDP: Parameter-free Differentiable Pruning is All You Need"
https://arxiv.org/abs/2305.11203
"""

import math
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def pdp_soft_mask(weight: torch.Tensor, threshold: float, tau: float) -> torch.Tensor:
    """
    Compute the PDP soft pruning mask.

    m(w) = exp(w^2 / tau) / (exp(w^2 / tau) + exp(t^2 / tau))
         = sigmoid((w^2 - t^2) / tau)

    The sigmoid form is algebraically identical to the two-way softmax above.
    It evaluates a single exponential and allocates no helper tensor for the
    threshold logit. It also saturates cleanly to 0/1 instead of producing NaN
    when ``w^2 / tau`` overflows to inf (e.g. float16 weights with a small
    ``tau``, where the softmax form computes ``inf - inf``).

    Args:
        weight: The weight tensor.
        threshold: The threshold t for this layer/entity.
        tau: Temperature hyperparameter controlling mask softness; must be > 0.

    Returns:
        Soft mask tensor with the same shape and dtype as ``weight``, with
        values in [0, 1]. Entries with |w| == t map to 0.5; |w| > t maps
        above 0.5 and |w| < t below it.

    Raises:
        ValueError: If ``tau`` is not strictly positive.
    """
    if tau <= 0:
        raise ValueError(f"tau must be > 0, got {tau}")
    # Difference of the two softmax logits; sigmoid of it is the "keep" prob.
    return torch.sigmoid((weight ** 2 - threshold ** 2) / tau)


def compute_threshold(weight: torch.Tensor, sparsity_ratio: float) -> float:
    """
    Compute the threshold t for a given sparsity ratio.

    t is set halfway between the largest pruned weight and the smallest
    unpruned weight, where the pruned set is the ``floor(sparsity_ratio * n)``
    smallest values (clamped to [1, n - 1] so at least one value is pruned and,
    when n > 1, at least one survives).

    Args:
        weight: Absolute weight tensor. Any shape is accepted; it is flattened.
        sparsity_ratio: Target sparsity ratio in [0, 1).

    Returns:
        Threshold value t >= 0. Returns 0.0 for a non-positive ratio or an
        empty tensor, and ``max(weight) + 1e-6`` for a ratio >= 1.
    """
    if sparsity_ratio <= 0:
        return 0.0

    # Flatten so multi-dim input is ranked globally, not per last-dim row.
    flat = weight.reshape(-1)
    n = flat.numel()
    if n == 0:
        return 0.0
    if sparsity_ratio >= 1.0:
        return flat.max().item() + 1e-6

    k = max(1, min(n - 1, int(math.floor(sparsity_ratio * n))))
    sorted_vals, _ = torch.sort(flat)
    pruned_max = sorted_vals[k - 1].item()
    # Only n == 1 gives k == n; there is no next value, so reuse the only one.
    unpruned_min = sorted_vals[k].item() if k < n else pruned_max
    return max((pruned_max + unpruned_min) / 2.0, 0.0)


def _make_masked_forward(module: nn.Module, pruner: "PDPPruner", param_name: str):
    """
    Build a replacement ``forward`` for ``module`` that applies the PDP soft mask.

    The returned closure looks up the threshold for ``param_name`` in
    ``pruner.thresholds`` on every call, so threshold updates made by the pruner
    take effect without re-patching. While that threshold is <= 0 the module's
    original forward is called unchanged. Otherwise the layer is evaluated with
    ``pdp_soft_mask(module.weight, t, pruner.tau) * module.weight``. The mask is
    computed from the live parameter outside any ``no_grad`` block, so gradients
    flow through it back to the weights.

    Conv layers are evaluated through the module's own ``_conv_forward`` helper,
    which is exactly what ``nn.Conv1d/2d/3d.forward`` calls. This keeps
    non-default ``padding_mode`` values ('reflect', 'replicate', 'circular') and
    string padding working. Calling ``F.convNd`` directly with
    ``module.padding`` silently used zero padding for those modes.

    Args:
        module: The layer whose forward is being replaced.
        pruner: Object exposing ``thresholds`` (dict) and ``tau`` (float).
        param_name: Key into ``pruner.thresholds`` for this layer's weight.

    Returns:
        A callable taking the layer input. For module types other than
        Conv1d/2d/3d and Linear, ``module.forward`` is returned unchanged.
    """
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        def apply_layer(x, weight):
            # NOTE: ``_conv_forward(input, weight, bias)`` is a private torch
            # helper shared by Conv1d/2d/3d; it applies padding_mode correctly.
            return module._conv_forward(x, weight, module.bias)
    elif isinstance(module, nn.Linear):
        def apply_layer(x, weight):
            return F.linear(x, weight, module.bias)
    else:
        return module.forward

    orig_forward = module.forward

    def forward(x):
        t = pruner.thresholds.get(param_name, 0.0)
        if t <= 0:
            # Pruning not active for this layer yet: behave exactly as before.
            return orig_forward(x)
        mask = pdp_soft_mask(module.weight, t, pruner.tau)
        return apply_layer(x, mask * module.weight)

    return forward


class PDPPruner:
    """
    Parameter-free Differentiable Pruning (PDP) pruner.

    Applies soft pruning masks during training so the task loss directly guides
    pruning decisions. After training, call hard_prune() for inference.

    Usage:
        pruner = PDPPruner(model, target_sparsity=0.855, s=16, epsilon=0.015, tau=1e-4)
        pruner.attach()
        for epoch in range(num_epochs):
            for batch in dataloader:
                loss = model(...)
                loss.backward()
                optimizer.step()
                pruner.step(epoch)
        pruner.hard_prune()
        pruner.detach()  # restore plain forwards; otherwise the soft mask stays active

    Schedule:
        At the first step() with ``epoch >= s`` the per-layer target sparsities
        are derived once from a global magnitude ranking and stored in
        ``base_layer_sparsity``. On every step() the active per-layer sparsity
        is ``base * min(target, epsilon * (epoch - s + 1)) / target``. Calling
        step() many times within one epoch therefore yields the same
        per-layer sparsity; only the thresholds are re-derived from the current
        weights.
    """

    # Layer types whose ``weight`` is pruned.
    _PRUNABLE_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float,
        s: int = 16,
        epsilon: float = 0.015,
        tau: float = 1e-4,
        excluded_modules: Optional[List[str]] = None,
    ):
        """
        Args:
            model: The model to prune.
            target_sparsity: Global target sparsity ratio (e.g. 0.855 for 85.5%).
            s: Warmup epochs before computing target sparsity (default 16).
            epsilon: Gradual pruning rate per epoch (default 0.015 = 1.5%).
            tau: Temperature hyperparameter for soft mask (default 1e-4).
            excluded_modules: List of module class names (``type(m).__name__``)
                whose weights are never pruned.
        """
        self.model = model
        self.target_sparsity = target_sparsity
        self.s = s
        self.epsilon = epsilon
        self.tau = tau
        self.excluded_modules = excluded_modules or ["BatchNorm2d", "LayerNorm", "BatchNorm1d"]

        # Maps param_name -> nn.Parameter
        self.prunable_params: Dict[str, nn.Parameter] = {}
        # Maps param_name -> float (per-layer sparsity at the full global target;
        # fixed once computed and used as the base of the gradual schedule)
        self.base_layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> float (currently active, schedule-scaled sparsity)
        self.layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> float (current threshold t)
        self.thresholds: Dict[str, float] = {}
        # Whether target sparsities have been computed
        self.sparsity_computed = False
        # Current effective global sparsity (gradual schedule)
        self.current_effective_sparsity = 0.0
        # Store original forward methods to restore later
        self._orig_forwards: Dict[str, Callable] = {}

        self._find_prunable_params()

    def _find_prunable_params(self):
        """Register Conv/Linear weights to prune, skipping excluded class names."""
        excluded = set(self.excluded_modules)
        for name, module in self.model.named_modules():
            if not isinstance(module, self._PRUNABLE_TYPES):
                continue
            if type(module).__name__ in excluded:
                continue
            if getattr(module, "weight", None) is None:
                continue
            self.prunable_params[f"{name}.weight"] = module.weight

    def _compute_layer_sparsities(self):
        """
        Compute per-layer target sparsity by sorting all weights globally by magnitude.
        This is the PDP-base strategy from the paper.

        Exactly ``floor(target_sparsity * n_total)`` weights fall at or below the
        global threshold, except when magnitudes tie at the threshold. Each
        layer's share of them becomes its base sparsity, capped at 99.9%.
        """
        all_weights = [param.data.abs().flatten() for param in self.prunable_params.values()]
        if not all_weights:
            return

        all_weights_cat = torch.cat(all_weights)
        n_total = all_weights_cat.numel()
        # Number of weights to prune globally.
        k = max(0, min(n_total, int(math.floor(self.target_sparsity * n_total))))

        # Global threshold = largest magnitude among the k smallest (None: prune nothing).
        global_threshold = None
        if k > 0:
            sorted_vals, _ = torch.sort(all_weights_cat)
            global_threshold = sorted_vals[k - 1].item()

        base: Dict[str, float] = {}
        for name, param in self.prunable_params.items():
            w_abs = param.data.abs()
            if global_threshold is None or w_abs.numel() == 0:
                ratio = 0.0
            else:
                below = (w_abs <= global_threshold).sum().item()
                ratio = below / w_abs.numel()
            base[name] = min(ratio, 0.999)  # cap at 99.9%

        self.base_layer_sparsity = base
        self.layer_sparsity = dict(base)
        self.sparsity_computed = True
        print(f"[PDP] Computed per-layer sparsities at epoch {self.s}. "
              f"Global target: {self.target_sparsity:.4f}")

    def _apply_schedule(self):
        """Set ``layer_sparsity`` to the base targets scaled by schedule progress.

        Always derived from ``base_layer_sparsity``, so repeated calls do not compound.
        """
        if self.target_sparsity > 0:
            progress = min(1.0, self.current_effective_sparsity / self.target_sparsity)
        else:
            progress = 1.0
        self.layer_sparsity = {
            name: min(1.0, base * progress)
            for name, base in self.base_layer_sparsity.items()
        }

    def _compute_thresholds(self):
        """Recompute per-layer thresholds t based on current weight distribution."""
        for name, param in self.prunable_params.items():
            ratio = self.layer_sparsity.get(name, 0.0)
            if ratio <= 0:
                self.thresholds[name] = 0.0
                continue
            w_abs = param.data.abs().flatten()
            self.thresholds[name] = compute_threshold(w_abs, ratio)

    def attach(self):
        """Monkey-patch forward methods of prunable modules to apply soft masks.

        Layers that are already patched are skipped, so calling attach() twice
        does not lose the original forward needed by detach().
        """
        for name, module in self.model.named_modules():
            if not isinstance(module, self._PRUNABLE_TYPES):
                continue
            param_name = f"{name}.weight"
            if param_name not in self.prunable_params or param_name in self._orig_forwards:
                continue
            self._orig_forwards[param_name] = module.forward
            module.forward = _make_masked_forward(module, self, param_name)
        print(f"[PDP] Attached masked forwards to {len(self.prunable_params)} prunable layers.")

    def detach(self):
        """Restore original forward methods."""
        for name, module in self.model.named_modules():
            if isinstance(module, self._PRUNABLE_TYPES):
                param_name = f"{name}.weight"
                if param_name in self._orig_forwards:
                    module.forward = self._orig_forwards[param_name]
        self._orig_forwards.clear()
        print("[PDP] Detached all masked forwards.")

    def step(self, epoch: int):
        """
        Call this after each optimizer.step() (or at each epoch boundary).
        Recomputes thresholds and updates gradual sparsity schedule.

        Args:
            epoch: Current epoch index. Nothing happens while ``epoch < s``.
        """
        # Warmup: first s epochs, no pruning
        if epoch < self.s:
            return

        # One-time per-layer target computation. Triggering on any epoch >= s
        # (not only epoch == s) keeps pruning working when training resumes past s.
        if not self.sparsity_computed:
            self._compute_layer_sparsities()

        # Gradual schedule: +epsilon absolute global sparsity per epoch since s.
        steps_since_s = epoch - self.s + 1
        self.current_effective_sparsity = min(
            self.target_sparsity,
            self.epsilon * steps_since_s
        )
        self._apply_schedule()

        # Recompute thresholds based on current weight distribution
        self._compute_thresholds()

    def get_sparsity(self) -> float:
        """Return the fraction of prunable weights with |w| <= their layer threshold."""
        total = 0
        pruned = 0
        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            total += param.numel()
            if t > 0:
                pruned += (param.data.abs() <= t).sum().item()
        return pruned / total if total > 0 else 0.0

    def hard_prune(self):
        """
        After training, apply hard pruning masks for inference.
        Sets pruned weights to exactly zero, using the full per-layer targets.

        If per-layer targets were never computed (step() never reached epoch s),
        they are computed now from the current weights. The masked forwards
        installed by attach() remain active until detach() is called.

        Returns:
            The resulting sparsity as reported by get_sparsity().
        """
        if not self.sparsity_computed:
            self._compute_layer_sparsities()

        # Jump the schedule to the full target.
        self.current_effective_sparsity = self.target_sparsity
        self.layer_sparsity = dict(self.base_layer_sparsity)
        self._compute_thresholds()

        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            if t > 0:
                mask = (param.data.abs() > t).to(param.dtype)
                param.data.mul_(mask)

        final_sparsity = self.get_sparsity()
        print(f"[PDP] Hard pruning applied. Final sparsity: {final_sparsity:.4f}")
        return final_sparsity

    def state_dict(self) -> dict:
        """Serialize pruner state (dicts are copied, not shared with the pruner)."""
        return {
            "target_sparsity": self.target_sparsity,
            "s": self.s,
            "epsilon": self.epsilon,
            "tau": self.tau,
            "sparsity_computed": self.sparsity_computed,
            "layer_sparsity": dict(self.layer_sparsity),
            "base_layer_sparsity": dict(self.base_layer_sparsity),
            "thresholds": dict(self.thresholds),
            "current_effective_sparsity": self.current_effective_sparsity,
        }

    def load_state_dict(self, state: dict):
        """Restore pruner state.

        Accepts older state dicts without ``base_layer_sparsity``. In that case
        the base targets are recovered by undoing the schedule scaling.
        """
        self.target_sparsity = state["target_sparsity"]
        self.s = state["s"]
        self.epsilon = state["epsilon"]
        self.tau = state["tau"]
        self.sparsity_computed = state["sparsity_computed"]
        self.layer_sparsity = dict(state["layer_sparsity"])
        self.thresholds = dict(state["thresholds"])
        self.current_effective_sparsity = state["current_effective_sparsity"]

        if "base_layer_sparsity" in state:
            self.base_layer_sparsity = dict(state["base_layer_sparsity"])
            return

        progress = 0.0
        if self.target_sparsity > 0:
            progress = self.current_effective_sparsity / self.target_sparsity
        if progress > 0:
            self.base_layer_sparsity = {
                name: min(0.999, value / progress)
                for name, value in self.layer_sparsity.items()
            }
        else:
            self.base_layer_sparsity = dict(self.layer_sparsity)