# Provenance (page header captured with this file): ESPR3SS0 -- commit 3a3ad1b, "Add pdp.py" (verified).
"""
PDP: Parameter-free Differentiable Pruning
Implementation based on the NeurIPS 2023 paper:
"PDP: Parameter-free Differentiable Pruning is All You Need"
https://arxiv.org/abs/2305.11203
"""
import math
from typing import Callable, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
def pdp_soft_mask(weight: torch.Tensor, threshold: float, tau: float) -> torch.Tensor:
    """
    Compute the PDP soft pruning mask.

    m(w) = exp(w^2 / tau) / (exp(w^2 / tau) + exp(t^2 / tau))

    This two-way softmax is algebraically identical to
    sigmoid((w^2 - t^2) / tau). ``torch.sigmoid`` evaluates that form in a
    numerically stable way, without materialising extra full-size tensors for
    the scalar threshold term.

    Args:
        weight: The weight tensor.
        threshold: The threshold t for this layer/entity.
        tau: Temperature hyperparameter controlling mask softness (> 0).

    Returns:
        Soft mask tensor with the same shape as ``weight``. Values lie in [0, 1],
        equal 0.5 where |w| == t, and approach 1 (or 0) as |w| moves above (or
        below) t. The result stays differentiable with respect to ``weight``.
    """
    return torch.sigmoid((weight.pow(2) - threshold ** 2) / tau)
def compute_threshold(weight: torch.Tensor, sparsity_ratio: float) -> float:
    """
    Compute the magnitude threshold t that prunes a given fraction of weights.

    Let k = floor(sparsity_ratio * n), capped at n - 1 so that at least one
    weight survives. t is set halfway between the k-th and (k+1)-th smallest
    magnitudes, i.e. between the largest pruned and the smallest unpruned weight.

    Args:
        weight: Weight tensor of any shape. It is flattened and its magnitudes
            are taken internally, so an already-absolute flat tensor (as the
            pruner passes) gives the same result.
        sparsity_ratio: Target sparsity ratio. Values <= 0 prune nothing, and
            values >= 1 return a threshold just above the largest magnitude.

    Returns:
        Threshold value t >= 0. Returns 0.0 (meaning "no pruning") for an empty
        tensor, or when the ratio rounds down to zero pruned weights.
    """
    magnitudes = weight.detach().abs().flatten()
    n = magnitudes.numel()
    if sparsity_ratio <= 0 or n == 0:
        return 0.0
    if sparsity_ratio >= 1.0:
        return magnitudes.max().item() + 1e-6
    # Keep at least one weight unpruned for ratios < 1.
    k = min(n - 1, int(math.floor(sparsity_ratio * n)))
    if k == 0:
        # Fewer than one weight's worth of sparsity was requested: prune nothing
        # rather than rounding up to one pruned weight.
        return 0.0
    sorted_vals, _ = torch.sort(magnitudes)
    pruned_max = sorted_vals[k - 1].item()
    unpruned_min = sorted_vals[k].item()
    # Both values are magnitudes (>= 0), so the midpoint is non-negative.
    return (pruned_max + unpruned_min) / 2.0
def _make_masked_forward(module: nn.Module, pruner: "PDPPruner", param_name: str):
"""
Monkey-patch module.forward to apply the PDP soft mask during forward pass.
This preserves the computation graph for differentiable backpropagation.
"""
if isinstance(module, nn.Conv2d):
orig_forward = module.forward
def forward(x):
t = pruner.thresholds.get(param_name, 0.0)
if t <= 0:
return orig_forward(x)
mask = pdp_soft_mask(module.weight, t, pruner.tau)
masked_weight = mask * module.weight
return F.conv2d(
x, masked_weight, module.bias,
module.stride, module.padding,
module.dilation, module.groups
)
return forward
elif isinstance(module, nn.Conv1d):
orig_forward = module.forward
def forward(x):
t = pruner.thresholds.get(param_name, 0.0)
if t <= 0:
return orig_forward(x)
mask = pdp_soft_mask(module.weight, t, pruner.tau)
masked_weight = mask * module.weight
return F.conv1d(
x, masked_weight, module.bias,
module.stride, module.padding,
module.dilation, module.groups
)
return forward
elif isinstance(module, nn.Conv3d):
orig_forward = module.forward
def forward(x):
t = pruner.thresholds.get(param_name, 0.0)
if t <= 0:
return orig_forward(x)
mask = pdp_soft_mask(module.weight, t, pruner.tau)
masked_weight = mask * module.weight
return F.conv3d(
x, masked_weight, module.bias,
module.stride, module.padding,
module.dilation, module.groups
)
return forward
elif isinstance(module, nn.Linear):
orig_forward = module.forward
def forward(x):
t = pruner.thresholds.get(param_name, 0.0)
if t <= 0:
return orig_forward(x)
mask = pdp_soft_mask(module.weight, t, pruner.tau)
masked_weight = mask * module.weight
return F.linear(x, masked_weight, module.bias)
return forward
else:
return module.forward
class PDPPruner:
    """
    Parameter-free Differentiable Pruning (PDP) pruner.

    Applies soft pruning masks during training so the task loss directly guides
    pruning decisions. After training, call hard_prune() for inference.

    Usage:
        pruner = PDPPruner(model, target_sparsity=0.855, s=16, epsilon=0.015, tau=1e-4)
        pruner.attach()
        for epoch in range(num_epochs):
            for batch in dataloader:
                loss = model(...)
                loss.backward()
                optimizer.step()
                pruner.step(epoch)
        pruner.hard_prune()
        pruner.detach()  # drop the soft masks so surviving weights are used as-is

    NOTE(review): only the ``forward`` of each Conv/Linear module is patched. A
    module whose weight is read directly by its parent (e.g. the ``out_proj`` of
    ``nn.MultiheadAttention``) is still hard-pruned by hard_prune(), but never
    sees a soft mask during training. Confirm that is acceptable for your model.
    """

    # Module types whose ``weight`` is pruned and whose forward gets masked.
    _PRUNABLE_TYPES = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)

    def __init__(
        self,
        model: nn.Module,
        target_sparsity: float,
        s: int = 16,
        epsilon: float = 0.015,
        tau: float = 1e-4,
        excluded_modules: Optional[List[str]] = None,
    ):
        """
        Args:
            model: The model to prune.
            target_sparsity: Global target sparsity ratio in [0, 1] (e.g. 0.855 for 85.5%).
            s: Warmup epochs before computing target sparsity (default 16).
            epsilon: Gradual pruning rate per epoch (default 0.015 = 1.5%).
            tau: Temperature hyperparameter for soft mask (default 1e-4, must be > 0).
            excluded_modules: Module class names (e.g. "Linear") or qualified
                module names as reported by ``named_modules()`` (e.g.
                "classifier") whose weights must not be pruned.

        Raises:
            ValueError: If target_sparsity is outside [0, 1] or tau is not positive.
        """
        if not 0.0 <= target_sparsity <= 1.0:
            raise ValueError(f"target_sparsity must be in [0, 1], got {target_sparsity}")
        if tau <= 0:
            raise ValueError(f"tau must be positive, got {tau}")
        self.model = model
        self.target_sparsity = target_sparsity
        self.s = s
        self.epsilon = epsilon
        self.tau = tau
        if excluded_modules is None:
            excluded_modules = ["BatchNorm2d", "LayerNorm", "BatchNorm1d"]
        self.excluded_modules = list(excluded_modules)
        # Maps param_name -> nn.Parameter
        self.prunable_params: Dict[str, nn.Parameter] = {}
        # Maps param_name -> final per-layer target sparsity (fixed once computed)
        self.base_layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> per-layer sparsity currently in effect (scheduled)
        self.layer_sparsity: Dict[str, float] = {}
        # Maps param_name -> float (current threshold t)
        self.thresholds: Dict[str, float] = {}
        # Whether target sparsities have been computed
        self.sparsity_computed = False
        # Current effective global sparsity (gradual schedule)
        self.current_effective_sparsity = 0.0
        # Original forward methods, restored by detach()
        self._orig_forwards: Dict[str, Callable] = {}
        self._find_prunable_params()

    def _find_prunable_params(self):
        """Register the weight of every Conv1d/2d/3d and Linear module that is not excluded."""
        excluded = set(self.excluded_modules)
        for name, module in self.model.named_modules():
            if not isinstance(module, self._PRUNABLE_TYPES):
                continue
            if type(module).__name__ in excluded or name in excluded:
                continue
            weight = getattr(module, "weight", None)
            if weight is not None:
                self.prunable_params[f"{name}.weight"] = weight

    def _compute_layer_sparsities(self, epoch: Optional[int] = None):
        """
        Compute per-layer target sparsity by ranking all prunable weights globally by magnitude.

        This is the PDP-base strategy from the paper. The
        ``floor(target_sparsity * N)`` smallest-magnitude weights across all
        layers are marked for pruning. Each layer's target is the fraction of its
        own weights among them, capped at 99.9% so that no layer is removed
        entirely. The result is stored in ``base_layer_sparsity`` and copied into
        ``layer_sparsity``.

        Args:
            epoch: Epoch at which this runs. Used only for logging; defaults to ``self.s``.
        """
        if not self.prunable_params:
            return
        magnitudes = torch.cat([p.detach().abs().flatten() for p in self.prunable_params.values()])
        n_total = magnitudes.numel()
        k = min(n_total, int(math.floor(self.target_sparsity * n_total)))
        global_threshold = None
        if k > 0:
            sorted_vals, _ = torch.sort(magnitudes)
            # k-th smallest magnitude (0-based index k - 1). Ties aside, exactly
            # k weights are <= this value.
            global_threshold = sorted_vals[k - 1].item()
        for name, param in self.prunable_params.items():
            numel = param.numel()
            if global_threshold is None or numel == 0:
                ratio = 0.0
            else:
                below = (param.detach().abs() <= global_threshold).sum().item()
                ratio = below / numel
            self.base_layer_sparsity[name] = min(ratio, 0.999)  # cap at 99.9%
        self.layer_sparsity = dict(self.base_layer_sparsity)
        self.sparsity_computed = True
        print(f"[PDP] Computed per-layer sparsities at epoch {self.s if epoch is None else epoch}. "
              f"Global target: {self.target_sparsity:.4f}")

    def _update_schedule(self, epoch: int):
        """
        Set the scheduled per-layer sparsities for ``epoch`` (assumed >= self.s).

        The global sparsity grows by ``epsilon`` per epoch after warmup, capped
        at ``target_sparsity``. Each layer's sparsity is its base target scaled
        by the same fraction. The result is always derived from
        ``base_layer_sparsity``, so repeated calls do not compound.
        """
        steps_since_s = epoch - self.s + 1
        self.current_effective_sparsity = min(self.target_sparsity, self.epsilon * steps_since_s)
        if self.target_sparsity > 0:
            scale = self.current_effective_sparsity / self.target_sparsity
        else:
            scale = 1.0
        self.layer_sparsity = {
            name: min(1.0, base * scale) for name, base in self.base_layer_sparsity.items()
        }

    def _compute_thresholds(self):
        """Recompute per-layer thresholds t from the current weights and scheduled sparsities."""
        for name, param in self.prunable_params.items():
            ratio = self.layer_sparsity.get(name, 0.0)
            if ratio <= 0:
                self.thresholds[name] = 0.0
                continue
            w_abs = param.data.abs().flatten()
            self.thresholds[name] = compute_threshold(w_abs, ratio)

    def attach(self):
        """
        Replace the forward of every prunable module with a soft-masked version.

        Modules that are already attached are skipped. Calling attach() twice
        therefore does not wrap a masked forward in another one, which would
        leave detach() unable to restore the original.
        """
        for name, module in self.model.named_modules():
            if not isinstance(module, self._PRUNABLE_TYPES):
                continue
            param_name = f"{name}.weight"
            if param_name not in self.prunable_params or param_name in self._orig_forwards:
                continue
            self._orig_forwards[param_name] = module.forward
            module.forward = _make_masked_forward(module, self, param_name)
        print(f"[PDP] Attached masked forwards to {len(self.prunable_params)} prunable layers.")

    def detach(self):
        """Restore the forward methods saved by attach()."""
        for name, module in self.model.named_modules():
            if isinstance(module, self._PRUNABLE_TYPES):
                param_name = f"{name}.weight"
                if param_name in self._orig_forwards:
                    module.forward = self._orig_forwards[param_name]
        self._orig_forwards.clear()
        print("[PDP] Detached all masked forwards.")

    def step(self, epoch: int):
        """
        Advance the pruning schedule to ``epoch`` and refresh the thresholds.

        Call at each epoch boundary or after every optimizer.step(). Repeated
        calls with the same epoch only recompute thresholds from the current
        weights. Per-layer targets are computed on the first call with
        epoch >= s, even if no call was made exactly at epoch s.
        """
        # Warmup: first s epochs, no pruning
        if epoch < self.s:
            return
        if not self.sparsity_computed:
            self._compute_layer_sparsities(epoch)
        self._update_schedule(epoch)
        self._compute_thresholds()

    def get_sparsity(self) -> float:
        """Return the current actual sparsity (fraction of weights with |w| <= threshold)."""
        total = 0
        pruned = 0
        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            total += param.numel()
            if t > 0:
                pruned += (param.data.abs() <= t).sum().item()
        return pruned / total if total > 0 else 0.0

    def hard_prune(self):
        """
        Zero out pruned weights in place, using the full per-layer targets.

        Thresholds are recomputed from the current weights at the final
        (unscaled) per-layer sparsities, and every weight with |w| <= t is set
        to exactly zero. If the per-layer targets were never computed (step()
        never reached epoch s), they are computed now from the current weights.

        The soft masks installed by attach() stay active until detach() is
        called. Until then they keep rescaling the surviving weights.

        Returns:
            The resulting fraction of pruned weights (see get_sparsity()).
        """
        if not self.sparsity_computed:
            self._compute_layer_sparsities()
        # Use the full targets directly instead of trying to undo the schedule's scaling.
        self.layer_sparsity = dict(self.base_layer_sparsity)
        self.current_effective_sparsity = self.target_sparsity
        self._compute_thresholds()
        for name, param in self.prunable_params.items():
            t = self.thresholds.get(name, 0.0)
            if t > 0:
                param.data.masked_fill_(param.data.abs() <= t, 0.0)
        final_sparsity = self.get_sparsity()
        print(f"[PDP] Hard pruning applied. Final sparsity: {final_sparsity:.4f}")
        return final_sparsity

    def state_dict(self) -> dict:
        """Serialize pruner state. The returned dicts are copies, not live references."""
        return {
            "target_sparsity": self.target_sparsity,
            "s": self.s,
            "epsilon": self.epsilon,
            "tau": self.tau,
            "sparsity_computed": self.sparsity_computed,
            "base_layer_sparsity": dict(self.base_layer_sparsity),
            "layer_sparsity": dict(self.layer_sparsity),
            "thresholds": dict(self.thresholds),
            "current_effective_sparsity": self.current_effective_sparsity,
        }

    def load_state_dict(self, state: dict):
        """Restore pruner state. Dicts are copied so the caller's ``state`` is never mutated."""
        self.target_sparsity = state["target_sparsity"]
        self.s = state["s"]
        self.epsilon = state["epsilon"]
        self.tau = state["tau"]
        self.sparsity_computed = state["sparsity_computed"]
        self.layer_sparsity = dict(state["layer_sparsity"])
        # Older checkpoints have no "base_layer_sparsity" and carry only the
        # scheduled values. Use those as the best available base.
        self.base_layer_sparsity = dict(state.get("base_layer_sparsity", state["layer_sparsity"]))
        self.thresholds = dict(state["thresholds"])
        self.current_effective_sparsity = state["current_effective_sparsity"]