"""
PDP: Parameter-free Differentiable Pruning

Implementation based on the NeurIPS 2023 paper:
"PDP: Parameter-free Differentiable Pruning is All You Need"
https://arxiv.org/abs/2305.11203
"""
import math
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Module types whose ``weight`` parameter is considered for pruning.
_PRUNABLE_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)
_CONV_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d)


def pdp_soft_mask(weight: torch.Tensor, threshold: float, tau: float) -> torch.Tensor:
    """
    Compute the PDP soft pruning mask.

        m(w) = exp(w^2 / tau) / (exp(w^2 / tau) + exp(t^2 / tau))

    Args:
        weight: The weight tensor.
        threshold: The threshold t for this layer/entity.
        tau: Temperature hyperparameter controlling mask softness (must be > 0).

    Returns:
        Soft mask tensor with the same shape as ``weight``; the mask is 0.5
        where ``|w| == t`` and grows towards 1 as ``|w|`` increases.

    Raises:
        ValueError: If ``tau`` is not strictly positive.
    """
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")
    # The two-way softmax over [w^2/tau, t^2/tau] is exactly
    # sigmoid((w^2 - t^2) / tau); sigmoid does not overflow for large
    # logits and needs no explicit max-subtraction.
    return torch.sigmoid((weight ** 2 - threshold ** 2) / tau)


def compute_threshold(weight: torch.Tensor, sparsity_ratio: float) -> float:
    """
    Compute the threshold t for a given sparsity ratio.

    t is set halfway between the largest pruned weight and the smallest
    unpruned weight.

    Args:
        weight: Absolute weight tensor. It is flattened internally, so any
            shape is accepted.
        sparsity_ratio: Target sparsity ratio in [0, 1).

    Returns:
        Threshold value t >= 0. Returns 0.0 for a non-positive ratio or an
        empty tensor, and ``max(weight) + 1e-6`` for a ratio >= 1.
    """
    flat = weight.flatten()
    n = flat.numel()
    if sparsity_ratio <= 0 or n == 0:
        return 0.0
    if sparsity_ratio >= 1.0:
        return flat.max().item() + 1e-6

    # k = number of weights to prune, kept in [1, n - 1] where possible so
    # that both neighbours of the cut exist.
    k = max(1, min(n - 1, int(math.floor(sparsity_ratio * n))))
    sorted_vals, _ = torch.sort(flat)
    pruned_max = sorted_vals[k - 1].item()
    # k == n only happens for a single-element tensor.
    unpruned_min = sorted_vals[k].item() if k < n else sorted_vals[-1].item()
    return max((pruned_max + unpruned_min) / 2.0, 0.0)


def _make_masked_forward(module: nn.Module, pruner: "PDPPruner", param_name: str):
    """
    Build a replacement ``forward`` that applies the PDP soft mask.

    The mask is computed from ``module.weight`` inside the forward pass, so
    it stays part of the autograd graph. While the pruner's threshold for
    ``param_name`` is <= 0 the original forward is called unchanged.

    Args:
        module: The module to wrap (Conv1d/2d/3d or Linear).
        pruner: Pruner whose ``thresholds`` and ``tau`` are read on every call.
        param_name: Key of this module's weight in ``pruner.thresholds``.

    Returns:
        The masked forward callable, or ``module.forward`` unchanged for
        unsupported module types.
    """
    is_conv = isinstance(module, _CONV_TYPES)
    if not is_conv and not isinstance(module, nn.Linear):
        return module.forward

    orig_forward = module.forward

    def forward(x):
        t = pruner.thresholds.get(param_name, 0.0)
        if t <= 0:
            return orig_forward(x)
        masked_weight = pdp_soft_mask(module.weight, t, pruner.tau) * module.weight
        if is_conv:
            # _conv_forward applies the module's own stride/padding/dilation/
            # groups and also honours non-'zeros' padding_mode, which a direct
            # F.convNd call would silently ignore.
            return module._conv_forward(x, masked_weight, module.bias)
        return F.linear(x, masked_weight, module.bias)

    return forward


class PDPPruner:
    """
    Parameter-free Differentiable Pruning (PDP) pruner.

    Applies soft pruning masks during training so the task loss directly
    guides pruning decisions. After training, call hard_prune() for inference.

    Usage:
        pruner = PDPPruner(model, target_sparsity=0.855, s=16, epsilon=0.015, tau=1e-4)
        pruner.attach()
        for epoch in range(num_epochs):
            for batch in dataloader:
                loss = model(...)
                loss.backward()
                optimizer.step()
                pruner.step(epoch)
        pruner.hard_prune()
        pruner.detach()  # drop the soft mask so inference uses the pruned weights as-is
    """

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float,
        s: int = 16,
        epsilon: float = 0.015,
        tau: float = 1e-4,
        excluded_modules: Optional[List[str]] = None,
    ):
        """
        Args:
            model: The model to prune.
            target_sparsity: Global target sparsity ratio (e.g. 0.855 for 85.5%).
            s: Warmup epochs before computing target sparsity (default 16).
            epsilon: Gradual pruning rate per epoch (default 0.015 = 1.5%).
            tau: Temperature hyperparameter for soft mask (default 1e-4).
            excluded_modules: List of module class names to exclude.
        """
        self.model = model
        self.target_sparsity = target_sparsity
        self.s = s
        self.epsilon = epsilon
        self.tau = tau
        self.excluded_modules = excluded_modules or ["BatchNorm2d", "LayerNorm", "BatchNorm1d"]

        # Maps param_name -> nn.Parameter
        self.prunable_params: Dict[str, nn.Parameter] = {}
        # Maps param_name -> float (final target sparsity for that layer).
        # This is never rescaled; the gradual schedule is applied on the fly
        # when thresholds are computed.
        self.layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> float (current threshold t)
        self.thresholds: Dict[str, float] = {}
        # Whether target sparsities have been computed
        self.sparsity_computed = False
        # Current effective global sparsity (gradual schedule)
        self.current_effective_sparsity = 0.0
        # Store original forward methods to restore later
        self._orig_forwards: Dict[str, Callable] = {}

        self._find_prunable_params()

    def _find_prunable_params(self):
        """Identify Conv and Linear weight parameters to prune.

        Modules whose class name appears in ``excluded_modules`` are skipped.
        """
        excluded = set(self.excluded_modules)
        for name, module in self.model.named_modules():
            if not isinstance(module, _PRUNABLE_TYPES):
                continue
            if type(module).__name__ in excluded:
                continue
            if getattr(module, "weight", None) is None:
                continue
            self.prunable_params[f"{name}.weight"] = module.weight

    def _compute_layer_sparsities(self):
        """
        Compute per-layer target sparsity by sorting all weights globally by
        magnitude. This is the PDP-base strategy from the paper.

        The global cut is placed at the k-th smallest magnitude, where
        ``k = floor(target_sparsity * n_total)``; each layer's target is the
        fraction of its weights at or below that cut, capped at 99.9%.
        """
        non_empty = [p for p in self.prunable_params.values() if p.numel() > 0]
        if not non_empty:
            return

        all_abs = torch.cat([p.data.abs().flatten() for p in non_empty])
        n_total = all_abs.numel()
        k = max(0, min(n_total, int(math.floor(self.target_sparsity * n_total))))

        layer_sparsity: Dict[str, float] = {}
        if k == 0:
            # Nothing to prune globally.
            for name in self.prunable_params:
                layer_sparsity[name] = 0.0
        else:
            # k weights are pruned, so the cut is the k-th smallest value,
            # i.e. index k - 1 of the ascending sort.
            sorted_vals, _ = torch.sort(all_abs)
            global_threshold = sorted_vals[k - 1].item()
            for name, param in self.prunable_params.items():
                numel = param.numel()
                if numel == 0:
                    layer_sparsity[name] = 0.0
                    continue
                below = (param.data.abs() <= global_threshold).sum().item()
                layer_sparsity[name] = min(below / numel, 0.999)  # cap at 99.9%

        self.layer_sparsity = layer_sparsity
        self.sparsity_computed = True
        print(f"[PDP] Computed per-layer sparsities at epoch {self.s}. "
              f"Global target: {self.target_sparsity:.4f}")

    def _compute_thresholds(self, scale: float = 1.0):
        """Recompute per-layer thresholds t based on current weight distribution.

        Args:
            scale: Factor applied to each layer's target sparsity before the
                threshold is derived (1.0 = full target). Used by ``step`` to
                apply the gradual schedule without mutating ``layer_sparsity``.
        """
        for name, param in self.prunable_params.items():
            ratio = min(1.0, self.layer_sparsity.get(name, 0.0) * scale)
            if ratio <= 0:
                self.thresholds[name] = 0.0
                continue
            self.thresholds[name] = compute_threshold(param.data.abs().flatten(), ratio)

    def attach(self):
        """Monkey-patch forward methods of prunable modules to apply soft masks.

        Calling this more than once is safe: modules that are already patched
        are left alone so the saved original forwards are not overwritten.
        """
        for name, module in self.model.named_modules():
            if not isinstance(module, _PRUNABLE_TYPES):
                continue
            param_name = f"{name}.weight"
            if param_name not in self.prunable_params:
                continue
            if param_name in self._orig_forwards:
                continue  # already attached
            self._orig_forwards[param_name] = module.forward
            module.forward = _make_masked_forward(module, self, param_name)
        print(f"[PDP] Attached masked forwards to {len(self.prunable_params)} prunable layers.")

    def detach(self):
        """Restore original forward methods."""
        for name, module in self.model.named_modules():
            if not isinstance(module, _PRUNABLE_TYPES):
                continue
            original = self._orig_forwards.get(f"{name}.weight")
            if original is not None:
                module.forward = original
        self._orig_forwards.clear()
        print("[PDP] Detached all masked forwards.")

    def step(self, epoch: int):
        """
        Call this after each optimizer.step() (or at each epoch boundary).
        Recomputes thresholds and updates gradual sparsity schedule.

        The call is idempotent for a given epoch: the schedule is derived
        from ``epoch`` alone and the per-layer targets are never modified.

        Args:
            epoch: Current (0-based) epoch index.
        """
        # Warmup: first s epochs, no pruning
        if epoch < self.s:
            return

        # Compute per-layer target sparsities once, at the first call past
        # warmup (not only when epoch == s, so resumed runs still get them).
        if not self.sparsity_computed:
            self._compute_layer_sparsities()

        # Gradual sparsity increase after warmup: epsilon (absolute) per epoch.
        steps_since_s = epoch - self.s + 1
        self.current_effective_sparsity = min(
            self.target_sparsity, self.epsilon * steps_since_s
        )

        # Scale per-layer targets proportionally to the schedule.
        if self.target_sparsity > 0:
            scale = self.current_effective_sparsity / self.target_sparsity
        else:
            scale = 1.0

        # Recompute thresholds based on current weight distribution
        self._compute_thresholds(scale)

    def get_sparsity(self) -> float:
        """Return the current actual sparsity (fraction of weights at or below threshold)."""
        total = 0
        pruned = 0
        for name, param in self.prunable_params.items():
            total += param.numel()
            t = self.thresholds.get(name, 0.0)
            if t > 0:
                pruned += (param.data.abs() <= t).sum().item()
        return pruned / total if total > 0 else 0.0

    def hard_prune(self):
        """
        After training, apply hard pruning masks for inference.
        Sets pruned weights to exactly zero.

        Thresholds are recomputed at the full per-layer targets, regardless
        of how far the gradual schedule had progressed. If the per-layer
        targets were never computed (e.g. training ended during warmup) they
        are computed here from the current weights.

        Masked forwards installed by ``attach`` are left in place; call
        ``detach`` afterwards to run the pruned weights without the soft mask.

        Returns:
            The resulting sparsity as reported by ``get_sparsity``.
        """
        if not self.sparsity_computed:
            self._compute_layer_sparsities()

        # Full target sparsities: no schedule scaling.
        if self.target_sparsity > 0:
            self.current_effective_sparsity = self.target_sparsity
        self._compute_thresholds()

        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            if t > 0:
                param.data.masked_fill_(param.data.abs() <= t, 0.0)

        final_sparsity = self.get_sparsity()
        print(f"[PDP] Hard pruning applied. Final sparsity: {final_sparsity:.4f}")
        return final_sparsity

    def state_dict(self) -> dict:
        """Serialize pruner state.

        The returned dicts are copies, so later pruner updates do not alter a
        previously captured state.
        """
        return {
            "target_sparsity": self.target_sparsity,
            "s": self.s,
            "epsilon": self.epsilon,
            "tau": self.tau,
            "sparsity_computed": self.sparsity_computed,
            "layer_sparsity": dict(self.layer_sparsity),
            "thresholds": dict(self.thresholds),
            "current_effective_sparsity": self.current_effective_sparsity,
        }

    def load_state_dict(self, state: dict):
        """Restore pruner state from a dict produced by ``state_dict``."""
        self.target_sparsity = state["target_sparsity"]
        self.s = state["s"]
        self.epsilon = state["epsilon"]
        self.tau = state["tau"]
        self.sparsity_computed = state["sparsity_computed"]
        # Copy so the pruner does not share mutable dicts with the caller.
        self.layer_sparsity = dict(state["layer_sparsity"])
        self.thresholds = dict(state["thresholds"])
        self.current_effective_sparsity = state["current_effective_sparsity"]