| """ |
| PDP: Parameter-free Differentiable Pruning |
| Implementation based on the NeurIPS 2023 paper: |
| "PDP: Parameter-free Differentiable Pruning is All You Need" |
| https://arxiv.org/abs/2305.11203 |
| """ |
|
|
import math
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
|
def pdp_soft_mask(weight: torch.Tensor, threshold: float, tau: float) -> torch.Tensor:
    """
    Compute the PDP soft pruning mask.

        m(w) = exp(w^2 / tau) / (exp(w^2 / tau) + exp(t^2 / tau))
             = sigmoid((w^2 - t^2) / tau)

    The two forms are algebraically identical (divide numerator and denominator
    by exp(w^2 / tau)). The sigmoid form is numerically stable for any magnitude
    of (w^2 - t^2) / tau and avoids materialising intermediate full-size tensors.

    Args:
        weight: The weight tensor. Gradients flow back to it through the mask.
        threshold: The threshold t for this layer/entity. A tensor that
            broadcasts against ``weight`` also works.
        tau: Temperature hyperparameter controlling mask softness; must be > 0.

    Returns:
        Soft mask tensor with the same shape as ``weight``, values in [0, 1].
        The value is exactly 0.5 where |w| == t.

    Raises:
        ValueError: If ``tau`` is not strictly positive (the mask would be NaN).
    """
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")
    return torch.sigmoid((weight ** 2 - threshold ** 2) / tau)
|
|
|
|
def compute_threshold(weight: torch.Tensor, sparsity_ratio: float) -> float:
    """
    Compute the threshold t for a given sparsity ratio.

    t is set halfway between the largest pruned magnitude and the smallest
    unpruned magnitude, where the ``floor(sparsity_ratio * n)`` smallest
    magnitudes are considered pruned. That count is clamped to [1, n - 1], so
    at least one weight is pruned and at least one is kept.

    Args:
        weight: Weight tensor of any shape. It is flattened and its absolute
            value taken internally, so an already-absolute, flattened tensor
            gives the same result.
        sparsity_ratio: Target sparsity ratio, nominally in [0, 1).

    Returns:
        Threshold value t >= 0. The function returns 0.0 for a non-positive
        ratio or an empty tensor. For a ratio >= 1 it returns a value just
        above the largest magnitude.
    """
    # Flatten so torch.sort orders all elements (sorting an N-d tensor sorts
    # only along its last dim), and take abs so signed input is handled too.
    magnitudes = weight.detach().abs().flatten()
    n = magnitudes.numel()
    if sparsity_ratio <= 0 or n == 0:
        return 0.0
    if sparsity_ratio >= 1.0:
        return magnitudes.max().item() + 1e-6

    k = max(1, min(n - 1, int(math.floor(sparsity_ratio * n))))
    sorted_vals = torch.sort(magnitudes).values
    pruned_max = sorted_vals[k - 1].item()
    # k == n only when n == 1; then the single value serves as both bounds.
    unpruned_min = sorted_vals[k].item() if k < n else pruned_max
    return max((pruned_max + unpruned_min) / 2.0, 0.0)
|
|
|
|
def _make_masked_forward(module: nn.Module, pruner: "PDPPruner", param_name: str):
    """
    Build a replacement ``forward`` for ``module`` that applies the PDP soft mask.

    On every call the returned callable reads the current threshold from
    ``pruner.thresholds[param_name]`` (default 0.0) and the temperature from
    ``pruner.tau``, so threshold updates made by the pruner take effect on the
    next forward pass.
    - While the threshold is <= 0 it simply calls the module's original forward.
    - Otherwise it runs the layer with ``pdp_soft_mask(weight, t, tau) * weight``.
      The product is built inside the autograd graph, so gradients reach the
      weights through both the mask and the weight itself.

    Supported modules are ``nn.Conv1d``/``nn.Conv2d``/``nn.Conv3d`` (and their
    subclasses) and ``nn.Linear``. For any other module, the module's current
    ``forward`` is returned unchanged.

    Args:
        module: The layer whose forward is being replaced.
        pruner: The owning pruner; supplies ``thresholds`` and ``tau``.
        param_name: Key of this layer's weight in ``pruner.thresholds``.

    Returns:
        A callable taking a single input tensor, suitable for assignment to
        ``module.forward``.
    """
    orig_forward = module.forward

    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        def run_with_weight(x, weight):
            # nn.ConvNd.forward itself delegates to _conv_forward. Unlike a bare
            # F.convNd(..., module.padding, ...) call, it honours non-zero
            # padding_mode values such as "reflect", "replicate" or "circular".
            return module._conv_forward(x, weight, module.bias)
    elif isinstance(module, nn.Linear):
        def run_with_weight(x, weight):
            return F.linear(x, weight, module.bias)
    else:
        return orig_forward

    def forward(x):
        t = pruner.thresholds.get(param_name, 0.0)
        if t <= 0:
            # No threshold yet (warm-up or unpruned layer): run the layer as-is.
            return orig_forward(x)
        weight = module.weight
        mask = pdp_soft_mask(weight, t, pruner.tau)
        return run_with_weight(x, mask * weight)

    return forward
|
|
|
|
class PDPPruner:
    """
    Parameter-free Differentiable Pruning (PDP) pruner.

    Applies soft pruning masks during training so the task loss directly guides
    pruning decisions. After training, call hard_prune() for inference.

    Usage:
        pruner = PDPPruner(model, target_sparsity=0.855, s=16, epsilon=0.015, tau=1e-4)
        pruner.attach()
        for epoch in range(num_epochs):
            for batch in dataloader:
                optimizer.zero_grad()
                loss = model(...)
                loss.backward()
                optimizer.step()
                pruner.step(epoch)
        pruner.hard_prune()
        pruner.detach()  # optional: stop applying the soft mask in forward
    """

    # Module types whose ``weight`` parameter is eligible for pruning.
    _PRUNABLE_TYPES = (nn.Conv2d, nn.Conv1d, nn.Conv3d, nn.Linear)

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float,
        s: int = 16,
        epsilon: float = 0.015,
        tau: float = 1e-4,
        excluded_modules: Optional[List[str]] = None,
    ):
        """
        Args:
            model: The model to prune.
            target_sparsity: Global target sparsity ratio (e.g. 0.855 for 85.5%).
            s: Warmup epochs before computing target sparsity (default 16).
            epsilon: Gradual pruning rate per epoch (default 0.015 = 1.5%).
            tau: Temperature hyperparameter for soft mask (default 1e-4).
            excluded_modules: Module class names (exact ``type(m).__name__``)
                whose weights must not be pruned. Defaults to the norm layers
                BatchNorm1d/BatchNorm2d/LayerNorm.
        """
        self.model = model
        self.target_sparsity = target_sparsity
        self.s = s
        self.epsilon = epsilon
        self.tau = tau
        self.excluded_modules = (
            list(excluded_modules)
            if excluded_modules is not None
            else ["BatchNorm2d", "LayerNorm", "BatchNorm1d"]
        )

        # Prunable weight parameters, keyed by "<module name>.weight".
        self.prunable_params: Dict[str, nn.Parameter] = {}
        # Per-layer sparsity at the full global target; fixed once computed.
        self.target_layer_sparsity: Dict[str, float] = {}
        # Per-layer sparsity currently in effect (target ratios scaled by the
        # gradual schedule). Derived from target_layer_sparsity on each step.
        self.layer_sparsity: Dict[str, float] = {}
        # Per-layer soft-mask thresholds t, read by the patched forwards.
        self.thresholds: Dict[str, float] = {}
        self.sparsity_computed = False
        self.current_effective_sparsity = 0.0
        # Forward callables replaced by attach(), restored by detach().
        self._orig_forwards: Dict[str, Callable] = {}

        self._find_prunable_params()

    def _find_prunable_params(self):
        """Collect Conv/Linear weights whose module class is not excluded."""
        excluded = set(self.excluded_modules)
        for name, module in self.model.named_modules():
            if not isinstance(module, self._PRUNABLE_TYPES):
                continue
            if type(module).__name__ in excluded:
                continue
            if getattr(module, "weight", None) is not None:
                self.prunable_params[f"{name}.weight"] = module.weight

    def _compute_layer_sparsities(self, epoch: Optional[int] = None):
        """
        Compute per-layer target sparsity from a global magnitude ranking.

        This is the PDP-base strategy from the paper. All prunable magnitudes are
        ranked together, and the k = floor(target_sparsity * N) smallest define
        the global cut. Each layer's ratio is the fraction of its weights at or
        below that cut, capped at 0.999. When k == 0, every layer gets 0.0.

        Sets both ``target_layer_sparsity`` and ``layer_sparsity`` and marks the
        sparsities as computed. If there are no prunable params, it does nothing.

        Args:
            epoch: Only used in the log message.
        """
        if not self.prunable_params:
            return

        all_weights = torch.cat(
            [param.detach().abs().flatten() for param in self.prunable_params.values()]
        )
        n_total = all_weights.numel()
        k = max(0, min(n_total, int(math.floor(self.target_sparsity * n_total))))

        ratios: Dict[str, float] = {}
        if k == 0:
            # Nothing should be pruned; avoid picking the minimum as a cut,
            # which would still prune the smallest weight(s).
            ratios = {name: 0.0 for name in self.prunable_params}
        else:
            # The k-th smallest magnitude (index k - 1) is the largest pruned one.
            global_threshold = torch.sort(all_weights).values[k - 1].item()
            for name, param in self.prunable_params.items():
                numel = param.numel()
                if numel == 0:
                    ratios[name] = 0.0
                    continue
                below = (param.detach().abs() <= global_threshold).sum().item()
                ratios[name] = min(below / numel, 0.999)

        self.target_layer_sparsity = ratios
        self.layer_sparsity = dict(ratios)
        self.sparsity_computed = True
        where = "" if epoch is None else f" at epoch {epoch}"
        print(f"[PDP] Computed per-layer sparsities{where}. "
              f"Global target: {self.target_sparsity:.4f}")

    def _compute_thresholds(self):
        """Recompute per-layer thresholds t from the current weights and ratios."""
        for name, param in self.prunable_params.items():
            ratio = self.layer_sparsity.get(name, 0.0)
            if ratio <= 0:
                self.thresholds[name] = 0.0
                continue
            self.thresholds[name] = compute_threshold(param.data.abs().flatten(), ratio)

    def attach(self):
        """
        Monkey-patch forward methods of prunable modules to apply soft masks.

        Modules that are already patched are skipped. Re-patching them would
        record the masked forward as the "original" and make detach()
        ineffective.
        """
        for name, module in self.model.named_modules():
            if not isinstance(module, self._PRUNABLE_TYPES):
                continue
            param_name = f"{name}.weight"
            if param_name not in self.prunable_params or param_name in self._orig_forwards:
                continue
            self._orig_forwards[param_name] = module.forward
            module.forward = _make_masked_forward(module, self, param_name)
        print(f"[PDP] Attached masked forwards to {len(self.prunable_params)} prunable layers.")

    def detach(self):
        """
        Restore the forward methods that attach() replaced.

        If the recorded original was the module's own class-level ``forward``,
        the instance override is deleted rather than replaced by a bound method.
        The module is thus left exactly as it was before attach().
        """
        for name, module in self.model.named_modules():
            param_name = f"{name}.weight"
            orig = self._orig_forwards.get(param_name)
            if orig is None:
                continue
            is_class_forward = (
                getattr(orig, "__self__", None) is module
                and getattr(orig, "__func__", None) is getattr(type(module), "forward", None)
            )
            if is_class_forward and "forward" in vars(module):
                del module.forward
            else:
                module.forward = orig
        self._orig_forwards.clear()
        print("[PDP] Detached all masked forwards.")

    def step(self, epoch: int):
        """
        Call this after each optimizer.step() (or at each epoch boundary).

        Before epoch ``s`` this does nothing. On the first call with
        ``epoch >= s``, the per-layer target sparsities are computed. Then the
        effective global sparsity is set to
        ``min(target_sparsity, epsilon * (epoch - s + 1))``. Each layer's
        current ratio is its *target* ratio scaled by effective / target, so
        repeated calls do not compound. Finally the thresholds are recomputed.
        """
        if epoch < self.s:
            return

        # Compute on the first eligible call rather than only at epoch == s,
        # so skipped epochs or a resume past s still trigger pruning.
        if not self.sparsity_computed:
            self._compute_layer_sparsities(epoch)

        steps_since_s = epoch - self.s + 1
        self.current_effective_sparsity = min(
            self.target_sparsity,
            self.epsilon * steps_since_s,
        )

        if self.target_sparsity > 0:
            scale = self.current_effective_sparsity / self.target_sparsity
            self.layer_sparsity = {
                name: min(1.0, ratio * scale)
                for name, ratio in self.target_layer_sparsity.items()
            }

        self._compute_thresholds()

    def get_sparsity(self) -> float:
        """
        Return the fraction of prunable weights with |w| <= their layer threshold.

        Layers whose threshold is 0 count as fully unpruned. Returns 0.0 when
        there are no prunable weights.
        """
        total = 0
        pruned = 0
        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            total += param.numel()
            if t > 0:
                pruned += (param.data.abs() <= t).sum().item()
        return pruned / total if total > 0 else 0.0

    def hard_prune(self):
        """
        After training, apply hard pruning masks for inference.

        Recomputes thresholds at the full per-layer target sparsity and sets
        weights with |w| <= t to exactly zero. If the per-layer sparsities were
        never computed (e.g. training stopped before epoch ``s``), they are
        computed from the current weights first.

        If masked forwards are still attached, the soft mask keeps being applied
        on top of the zeroed weights, because the thresholds stay > 0. Call
        detach() to run the plain forward.

        Returns:
            The resulting sparsity as reported by get_sparsity().
        """
        if not self.sparsity_computed:
            self._compute_layer_sparsities()

        self.layer_sparsity = dict(self.target_layer_sparsity)
        self._compute_thresholds()

        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            if t > 0:
                param.data.masked_fill_(param.data.abs() <= t, 0.0)

        final_sparsity = self.get_sparsity()
        print(f"[PDP] Hard pruning applied. Final sparsity: {final_sparsity:.4f}")
        return final_sparsity

    def state_dict(self) -> dict:
        """Serialize pruner state (dicts are copied, not shared with the pruner)."""
        return {
            "target_sparsity": self.target_sparsity,
            "s": self.s,
            "epsilon": self.epsilon,
            "tau": self.tau,
            "sparsity_computed": self.sparsity_computed,
            "target_layer_sparsity": dict(self.target_layer_sparsity),
            "layer_sparsity": dict(self.layer_sparsity),
            "thresholds": dict(self.thresholds),
            "current_effective_sparsity": self.current_effective_sparsity,
        }

    def load_state_dict(self, state: dict):
        """
        Restore pruner state produced by state_dict().

        Older checkpoints lack ``target_layer_sparsity``; their stored
        ``layer_sparsity`` holds already-scaled ratios that cannot be inverted.
        For those, the per-layer targets are recomputed on the next step() or
        hard_prune().
        """
        self.target_sparsity = state["target_sparsity"]
        self.s = state["s"]
        self.epsilon = state["epsilon"]
        self.tau = state["tau"]
        self.sparsity_computed = state["sparsity_computed"]
        self.layer_sparsity = dict(state["layer_sparsity"])
        self.thresholds = dict(state["thresholds"])
        self.current_effective_sparsity = state["current_effective_sparsity"]
        if "target_layer_sparsity" in state:
            self.target_layer_sparsity = dict(state["target_layer_sparsity"])
        else:
            self.target_layer_sparsity = {}
            self.sparsity_computed = False
|
|