PDP: Parameter-free Differentiable Pruning
Implementation of "PDP: Parameter-free Differentiable Pruning is All You Need" (NeurIPS 2023).
Paper: https://arxiv.org/abs/2305.11203
Core Idea
PDP generates soft pruning masks without any extra trainable parameters. The mask is derived directly from weight magnitudes using a dynamic threshold t and temperature τ:
m(w) = exp(w²/τ) / (exp(w²/τ) + exp(t²/τ))
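With the default τ = 1e-4, evaluating this formula literally overflows for ordinary weights (w = 1 already gives exp(10000)). A minimal sketch, assuming nothing beyond the formula itself: dividing numerator and denominator by exp(w²/τ) gives the equivalent logistic form m(w) = 1 / (1 + exp((t² − w²)/τ)), which can be evaluated stably. The function name soft_mask is illustrative and may not match what pdp.py exposes.

```python
import math

def soft_mask(w: float, t: float, tau: float = 1e-4) -> float:
    """Soft mask m(w) = sigmoid((w^2 - t^2) / tau), evaluated without overflow."""
    z = (w * w - t * t) / tau
    if z >= 0:
        # exp(-z) is at most 1 here, so it cannot overflow
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)  # z < 0, so exp(z) lies in [0, 1)
    return e / (1.0 + e)

t = 0.05
print(soft_mask(t, t))     # weight exactly at the threshold
print(soft_mask(1.0, t))   # weight far above the threshold
print(soft_mask(0.0, t))   # weight far below the threshold
```

The two branches only differ in which side of the fraction carries the exponential, so the exponent passed to math.exp is never positive.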
The gradient includes an additional boosting term for weights near the pruning boundary, accelerating them toward a clear keep/prune decision.
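One way to see where the extra term comes from, assuming the weight used in the forward pass is ŵ = w·m(w) (this README only says the soft mask is applied): by the product rule, dŵ/dw = m(w) + w·m'(w), with m'(w) = (2w/τ)·m(w)·(1 − m(w)). The second term vanishes where the mask has saturated at 0 or 1 and is largest near |w| ≈ t. The sketch below checks the analytic expression against a finite difference; all function names are illustrative.

```python
import math

def soft_mask(w, t, tau):
    # Logistic form of m(w); clamp the exponent so math.exp cannot overflow.
    z = max(-700.0, min(700.0, (w * w - t * t) / tau))
    return 1.0 / (1.0 + math.exp(-z))

def masked_weight(w, t, tau):
    # Assumed forward-pass weight: w * m(w)
    return w * soft_mask(w, t, tau)

def masked_weight_grad(w, t, tau):
    """Return (d(w*m)/dw, boost) where boost = w * m'(w)."""
    m = soft_mask(w, t, tau)
    boost = w * (2.0 * w / tau) * m * (1.0 - m)
    return m + boost, boost

t, tau = 0.05, 1e-4
for w in (0.01, 0.05, 0.2):
    grad, boost = masked_weight_grad(w, t, tau)
    h = 1e-7  # step for the central finite difference
    numeric = (masked_weight(w + h, t, tau) - masked_weight(w - h, t, tau)) / (2 * h)
    print(f"w={w:.2f} grad={grad:.4f} boost={boost:.4f} numeric={numeric:.4f}")
```

At w = t the boost term is far larger than m(w) itself, while for weights well inside the keep or prune region it is negligible.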
Key Properties
| Feature | PDP | Other Differentiable Pruning |
|---|---|---|
| Extra parameters | None | Yes (mask params, thresholds, etc.) |
| Differentiable | ✅ Yes | ✅ Yes (most) |
| Training complexity | Low | High |
| Inference speedup | ✅ Yes | Varies |
Files
| File | Description |
|---|---|
| pdp.py | Core PDP module: PDPPruner class + mask/threshold functions |
| train_pdp.py | Full training script for CIFAR-10 + ResNet18 |
| test_pdp.py | Unit tests verifying boundary conditions, monotonicity, gradient flow |
Quick Start
Train on CIFAR-10 with 85% sparsity
python train_pdp.py \
    --target_sparsity 0.85 \
    --s 16 \
    --epsilon 0.015 \
    --tau 1e-4 \
    --epochs 100 \
    --lr 0.1 \
    --batch_size 128
Use PDP in your own code
from pdp import PDPPruner
import torch.nn as nn

model = MyModel()  # your own nn.Module; optimizer, dataloader and epochs are defined elsewhere
pruner = PDPPruner(
    model=model,
    target_sparsity=0.85,  # 85% sparsity
    s=16,                  # Warmup epochs before pruning
    epsilon=0.015,         # Gradual pruning rate per epoch
    tau=1e-4,              # Temperature (default works well)
)
pruner.attach()

for epoch in range(epochs):
    for batch in dataloader:
        optimizer.zero_grad()  # Clear gradients left over from the previous step
        loss = model(...)
        loss.backward()
        optimizer.step()
        pruner.step(epoch)  # Recompute thresholds after each optimizer step

# After training, hard-prune for inference
pruner.hard_prune()
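What hard pruning amounts to can be sketched on a plain list of weights, under the assumption that binarizing the mask means keeping a weight when m(w) ≥ 0.5, which by the mask formula is the same as |w| ≥ |t|. The helper below is hypothetical and does not reproduce PDPPruner.hard_prune().

```python
def hard_prune(weights, t):
    """Zero every weight whose soft mask would fall below 0.5, i.e. |w| < |t|."""
    return [w if abs(w) >= abs(t) else 0.0 for w in weights]

weights = [0.30, -0.02, 0.07, -0.41, 0.01]
pruned = hard_prune(weights, t=0.05)

# Fraction of weights that ended up exactly zero
sparsity = sum(w == 0.0 for w in pruned) / len(pruned)
print(pruned, sparsity)
```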
Algorithm
From Appendix D of the paper:
- Warmup (epochs 0 to s-1): Train normally, no pruning.
- At epoch s: Compute per-layer target sparsity by globally sorting all weights by magnitude.
- After epoch s: Gradually increase target sparsity by ε per epoch.
- Forward pass: Apply soft mask m(w) using current threshold t.
- After optimizer step: Recompute t from current weight distribution.
- After training: Binarize masks for inference (hard prune).
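The schedule and threshold steps above can be sketched in plain Python. This is one reading of the list, not the code in pdp.py: the sparsity target is assumed to grow by ε per epoch after warmup and to be capped at the final target, the per-layer split is assumed to be the fraction of each layer that falls below a single global magnitude cutoff, and t is assumed to sit midway between the largest pruned and smallest kept magnitude so that the soft mask crosses 0.5 between them. All function names are illustrative.

```python
def sparsity_at(epoch, target=0.85, s=16, epsilon=0.015):
    """Assumed schedule: 0 during warmup, then +epsilon per epoch, capped at target."""
    if epoch < s:
        return 0.0
    return min(target, epsilon * (epoch - s + 1))

def threshold_for(weights, sparsity):
    """Threshold t such that roughly `sparsity` of the weights have |w| < t."""
    mags = sorted(abs(w) for w in weights)
    k = round(sparsity * len(mags))        # number of weights to prune
    if k <= 0:
        return 0.0                         # nothing is pruned
    if k >= len(mags):
        return mags[-1]                    # every magnitude is at or below t
    return 0.5 * (mags[k - 1] + mags[k])   # midpoint across the prune/keep boundary

def per_layer_sparsity(layers, global_sparsity):
    """Fraction of each layer that falls below the global magnitude cutoff."""
    all_weights = [w for ws in layers.values() for w in ws]
    cutoff = threshold_for(all_weights, global_sparsity)
    return {name: sum(abs(w) < cutoff for w in ws) / len(ws)
            for name, ws in layers.items()}

layers = {"conv1": [0.30, -0.25, 0.07, -0.41], "fc": [0.01, 0.15, -0.09, 0.02]}
print(per_layer_sparsity(layers, 0.5))

for epoch in (0, 16, 40, 99):
    r = sparsity_at(epoch)
    print(epoch, round(r, 3), threshold_for(layers["fc"], r))
```

With a global sort, a layer dominated by small weights (fc in the toy data) receives a higher per-layer sparsity than one with large weights, even though the global fraction is the same.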
Hyperparameters
| Param | Default | Description |
|---|---|---|
| target_sparsity | 0.85 | Global target fraction of weights to prune |
| s | 16 | Epochs before pruning starts (warmup) |
| ε (epsilon) | 0.015 | Gradual pruning rate (1.5% of target per epoch) |
| τ (tau) | 1e-4 | Temperature controlling mask softness |
Paper-reported results with these settings:
- ResNet18 / ImageNet: 69.0% top-1 at 85.5% sparsity
- ResNet50 / ImageNet: 75.3% top-1 at 89.8% sparsity
- MobileNet-v1 / ImageNet: 68.2% top-1 at 86.6% sparsity
Implementation Details
- Monkey-patching: PDPPruner.attach() replaces the forward methods of Conv/Linear layers to apply the soft mask. This preserves the full autograd graph, making pruning differentiable.
- No extra parameters: Unlike STR/CS/OptG, PDP adds zero learnable parameters.
- Memory efficient: No mask gradients to store (since masks are not parameters).
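The attach/detach pattern can be illustrated without PyTorch. The sketch below is a toy stand-in, not the code in pdp.py: ToyLinear, attach_mask and detach_mask are hypothetical names, the layer is a dot product over a Python list, and the threshold is fixed instead of being recomputed. It does not involve autograd; it only shows that the wrapper masks the weights on the fly inside forward and that the layer gains no new learnable state.

```python
import math

class ToyLinear:
    """Stand-in for a Linear layer: forward is a dot product with the weights."""
    def __init__(self, weight):
        self.weight = list(weight)

    def forward(self, x):
        return sum(w * xi for w, xi in zip(self.weight, x))

def soft_mask(w, t, tau):
    # Logistic form of m(w); clamp the exponent so math.exp cannot overflow.
    z = max(-700.0, min(700.0, (w * w - t * t) / tau))
    return 1.0 / (1.0 + math.exp(-z))

def attach_mask(layer, t, tau=1e-4):
    """Replace layer.forward with a version that uses soft-masked weights."""
    def masked_forward(x):
        masked = [w * soft_mask(w, t, tau) for w in layer.weight]
        return sum(w * xi for w, xi in zip(masked, x))
    layer.forward = masked_forward  # instance attribute shadows the class method

def detach_mask(layer):
    """Restore the original forward by removing the instance-level override."""
    layer.__dict__.pop("forward", None)

layer = ToyLinear([0.30, -0.02, 0.07])
x = [1.0, 1.0, 1.0]
dense = layer.forward(x)
attach_mask(layer, t=0.05)
masked = layer.forward(x)   # the -0.02 weight is suppressed by the mask
detach_mask(layer)
print(dense, masked, layer.forward(x))
```

Because the mask is recomputed from the weights on every call, the only stored state is the original weight list, which mirrors the memory argument made above.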
Tests
python test_pdp.py
Tests verify:
- m(t) = 0.5 (equal chance at threshold)
- m(w) monotonicity
- Gradient flow through the soft mask
- Threshold computation accuracy
- End-to-end attach/prune/detach cycle