PDP: Parameter-free Differentiable Pruning

Implementation of "PDP: Parameter-free Differentiable Pruning is All You Need" (NeurIPS 2023).

Paper: https://arxiv.org/abs/2305.11203

Core Idea

PDP generates soft pruning masks without any extra trainable parameters. The mask is derived directly from weight magnitudes using a dynamic threshold t and temperature τ:

m(w) = exp(w²/τ) / (exp(w²/τ) + exp(t²/τ))

Because the masked weight is w·m(w), its gradient with respect to w is m(w) + w·∂m/∂w. The second term acts as a boost for weights near the pruning boundary (|w| ≈ t), pushing them toward a clear keep/prune decision.
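Evaluated literally, the mask formula can overflow, because w²/τ becomes very large when τ = 1e-4. Below is a minimal sketch of a numerically stable evaluation, using the algebraically equivalent form m(w) = sigmoid((w² − t²)/τ). The function name `soft_mask` is illustrative and is not taken from pdp.py.

```python
import math


def soft_mask(w: float, t: float, tau: float = 1e-4) -> float:
    """Soft mask m(w) in [0, 1] for a single weight w and threshold t."""
    # Dividing numerator and denominator of m(w) by exp(w^2 / tau) gives
    # m(w) = 1 / (1 + exp(-(w^2 - t^2) / tau)), i.e. a sigmoid.
    z = (w * w - t * t) / tau
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)  # z < 0 here, so exp cannot overflow
    return e / (1.0 + e)


print(soft_mask(0.05, t=0.05))  # exactly at the threshold
print(soft_mask(1.0, t=0.05))   # far above the threshold: mask saturates toward 1
print(soft_mask(0.0, t=0.05))   # far below the threshold: mask is close to 0
```

With a small τ the mask is nearly binary everywhere except in a narrow band around |w| = t, which is where the soft transition matters.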

Key Properties

| Feature | PDP | Other Differentiable Pruning |
| --- | --- | --- |
| Extra parameters | None | Yes (mask params, thresholds, etc.) |
| Differentiable | ✅ Yes | ✅ Yes (most) |
| Training complexity | Low | High |
| Inference speedup | ✅ Yes | Varies |

Files

| File | Description |
| --- | --- |
| pdp.py | Core PDP module: PDPPruner class + mask/threshold functions |
| train_pdp.py | Full training script for CIFAR-10 + ResNet18 |
| test_pdp.py | Unit tests verifying boundary conditions, monotonicity, gradient flow |

Quick Start

Train on CIFAR-10 with 85% sparsity

python train_pdp.py \
    --target_sparsity 0.85 \
    --s 16 \
    --epsilon 0.015 \
    --tau 1e-4 \
    --epochs 100 \
    --lr 0.1 \
    --batch_size 128

Use PDP in your own code

from pdp import PDPPruner
import torch.nn as nn

model = MyModel()
pruner = PDPPruner(
    model=model,
    target_sparsity=0.85,   # 85% sparsity
    s=16,                   # Warmup epochs before pruning
    epsilon=0.015,          # Gradual pruning rate per epoch
    tau=1e-4,               # Temperature (default works well)
)
pruner.attach()

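for epoch in range(epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()                      # Clear gradients from the previous step
        loss = criterion(model(inputs), targets)   # criterion: your loss, e.g. nn.CrossEntropyLoss()
        loss.backward()
        optimizer.step()
        pruner.step(epoch)   # Recompute thresholds after each optimizer step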

# After training, hard-prune for inference
pruner.hard_prune()

Algorithm

From Appendix D of the paper:

  1. Warmup (epochs 0 to s-1): Train normally, no pruning.
  2. At epoch s: Compute per-layer target sparsity by globally sorting all weights by magnitude.
  3. After epoch s: Gradually increase target sparsity by ε per epoch.
  4. Forward pass: Apply soft mask m(w) using current threshold t.
  5. After optimizer step: Recompute t from current weight distribution.
  6. After training: Binarize masks for inference (hard prune).
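Steps 1 to 3 and 5 can be sketched in plain Python. This is an illustration under stated assumptions, not the code in pdp.py: the sparsity ramp is assumed to grow by ε × target per epoch after warmup, the threshold is assumed to sit midway between the largest pruned and the smallest kept magnitude, and all function names are hypothetical.

```python
from typing import Dict, Sequence


def current_sparsity(epoch: int, target: float, s: int = 16, eps: float = 0.015) -> float:
    """Sparsity requested at `epoch`: 0 during warmup, then a linear ramp to `target`."""
    # Assumption: the ramp grows by eps * target per epoch and is capped at target.
    progress = min(1.0, max(0, epoch - s) * eps)
    return progress * target


def per_layer_sparsity(layers: Dict[str, Sequence[float]], global_sparsity: float) -> Dict[str, float]:
    """Derive per-layer sparsity from one global magnitude cutoff."""
    all_mags = sorted(abs(w) for ws in layers.values() for w in ws)
    k = int(round(global_sparsity * len(all_mags)))  # weights to prune globally
    if k == 0:
        return {name: 0.0 for name in layers}
    cutoff = all_mags[k - 1]  # largest magnitude that is pruned globally
    return {name: sum(abs(w) <= cutoff for w in ws) / len(ws) for name, ws in layers.items()}


def threshold(weights: Sequence[float], sparsity: float) -> float:
    """Threshold t for one layer, recomputed from its current weights."""
    mags = sorted(abs(w) for w in weights)
    k = int(round(sparsity * len(mags)))  # weights to prune in this layer
    if k <= 0:
        return 0.0       # nothing to prune yet
    if k >= len(mags):
        return mags[-1]  # degenerate case: every weight is below or at t
    # Midpoint between the largest pruned and the smallest kept magnitude.
    return 0.5 * (mags[k - 1] + mags[k])


layers = {"conv1": [0.9, -0.8, 0.05, 0.7], "fc": [0.01, -0.02, 0.6, 0.03]}
ratios = per_layer_sparsity(layers, global_sparsity=0.5)
print(ratios)                                     # layers with smaller weights get pruned more
print(threshold(layers["conv1"], ratios["conv1"]))
print([round(current_sparsity(e, 0.85), 4) for e in (10, 17, 50, 100)])
```

Under the assumed ramp and the defaults s = 16 and ε = 0.015, the full target is reached roughly 67 epochs after warmup ends, which fits inside the 100-epoch run from the Quick Start.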

Hyperparameters

| Param | Default | Description |
| --- | --- | --- |
| target_sparsity | 0.85 | Global target fraction of weights to prune |
| s | 16 | Epochs before pruning starts (warmup) |
| ε | 0.015 | Gradual pruning rate (1.5% of target per epoch) |
| τ | 1e-4 | Temperature controlling mask softness |

Paper-reported results with these settings:

  • ResNet18 / ImageNet: 69.0% top-1 at 85.5% sparsity
  • ResNet50 / ImageNet: 75.3% top-1 at 89.8% sparsity
  • MobileNet-v1 / ImageNet: 68.2% top-1 at 86.6% sparsity

Implementation Details

  • Monkey-patching: PDPPruner.attach() replaces the forward methods of Conv/Linear layers to apply the soft mask. This preserves the full autograd graph, making pruning differentiable.
  • No extra parameters: Unlike STR/CS/OptG, PDP adds zero learnable parameters.
  • Memory efficient: No mask gradients to store (since masks are not parameters).
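To make the monkey-patching idea concrete, here is a minimal sketch for a single `nn.Linear` layer. It is not the actual `PDPPruner.attach()` code: the helper `attach_soft_mask` and the fixed threshold are assumptions for illustration, and a real pruner would also handle Conv layers and recompute t during training.

```python
import types

import torch
import torch.nn as nn
import torch.nn.functional as F


def attach_soft_mask(layer: nn.Linear, t: float, tau: float = 1e-4) -> None:
    """Patch `layer.forward` so it uses soft-masked weights (hypothetical helper)."""

    def masked_forward(self, x):
        # sigmoid((w^2 - t^2) / tau) equals the exp-ratio form of m(w).
        mask = torch.sigmoid((self.weight ** 2 - t ** 2) / tau)
        # The mask is computed from self.weight, so autograd tracks it;
        # no new nn.Parameter is created.
        return F.linear(x, self.weight * mask, self.bias)

    layer.forward = types.MethodType(masked_forward, layer)


layer = nn.Linear(4, 2)
n_params_before = sum(p.numel() for p in layer.parameters())
attach_soft_mask(layer, t=0.1)

out = layer(torch.randn(3, 4))
out.sum().backward()

n_params_after = sum(p.numel() for p in layer.parameters())
print(n_params_before == n_params_after)  # parameter count is unchanged
print(layer.weight.grad.shape)            # gradients reach the raw weights
```

Patching the instance's `forward` rather than wrapping the module keeps the model's structure and state dict keys unchanged, so checkpoints stay compatible with the unpruned model.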

Tests

python test_pdp.py

Tests verify:

  • m(t) = 0.5 (equal chance at threshold)
  • m(w) monotonicity
  • Gradient flow through the soft mask
  • Threshold computation accuracy
  • End-to-end attach/prune/detach cycle
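The first three checks do not need PyTorch. The sketch below shows what such checks could look like in plain Python. It is not the contents of test_pdp.py, and the function names are hypothetical. The gradient check compares the closed-form derivative dm/dw = (2w/τ)·m·(1 − m) with a central finite difference.

```python
import math


def soft_mask(w: float, t: float, tau: float) -> float:
    """m(w) as a sigmoid of (w^2 - t^2) / tau, written with tanh for stability."""
    z = (w * w - t * t) / tau
    return 0.5 * (1.0 + math.tanh(0.5 * z))


def check_half_at_threshold() -> None:
    assert soft_mask(0.3, t=0.3, tau=1e-4) == 0.5


def check_monotonic_in_magnitude() -> None:
    masks = [soft_mask(i / 10, t=0.5, tau=1e-2) for i in range(11)]
    assert all(a <= b for a, b in zip(masks, masks[1:]))


def check_gradient_matches_finite_difference() -> None:
    w, t, tau, h = 0.52, 0.5, 0.05, 1e-6
    m = soft_mask(w, t, tau)
    analytic = (2.0 * w / tau) * m * (1.0 - m)
    numeric = (soft_mask(w + h, t, tau) - soft_mask(w - h, t, tau)) / (2.0 * h)
    assert abs(analytic - numeric) < 1e-5


for check in (check_half_at_threshold,
              check_monotonic_in_magnitude,
              check_gradient_matches_finite_difference):
    check()
    print(f"{check.__name__}: ok")
```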