Add README_PDP.md
# PDP: Parameter-free Differentiable Pruning

Implementation of **"PDP: Parameter-free Differentiable Pruning is All You Need"** (NeurIPS 2023).

**Paper:** https://arxiv.org/abs/2305.11203

## Core Idea

PDP generates soft pruning masks **without any extra trainable parameters**. The mask is derived directly from weight magnitudes using a dynamic threshold `t` and temperature `τ`:

```
m(w) = exp(w²/τ) / (exp(w²/τ) + exp(t²/τ))
```

Because `m(w)` itself depends on `w`, the gradient with respect to each weight includes an **additional boosting term** that is largest for weights near the pruning boundary (`|w| ≈ t`), accelerating them toward a clear keep/prune decision.
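
To make the formula concrete, here is a minimal pure-Python sketch; it is not the code in `pdp.py`, and the function names are illustrative. It uses the identity that the two-way softmax above equals a sigmoid of `(w² − t²)/τ`, and it assumes the masked weight is `m(w)·w`, which gives the derivative `m + (2w²/τ)·m·(1 − m)`. The second term is the boost, and it vanishes as `m` saturates at 0 or 1.

```python
import math


def soft_mask(w, t, tau):
    """Two-way softmax over (w^2, t^2) / tau, written as a numerically stable sigmoid."""
    z = (w * w - t * t) / tau
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def masked_weight_grad(w, t, tau):
    """d(m(w) * w)/dw, assuming the masked weight is m(w) * w (an assumption of this sketch)."""
    m = soft_mask(w, t, tau)
    boost = (2.0 * w * w / tau) * m * (1.0 - m)  # extra term from differentiating the mask
    return m + boost


t, tau = 0.01, 1e-4
for w in (0.0, 0.005, 0.01, 0.02, 0.05):
    print(f"w={w:.3f}  m={soft_mask(w, t, tau):.4f}  grad={masked_weight_grad(w, t, tau):.4f}")
```

Under these assumptions, a weight exactly at the threshold (`w = t = 0.01`, `τ = 1e-4`) gets `m = 0.5` and a gradient factor of 1.0: 0.5 from the mask plus 0.5 from the boost.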

## Key Properties

| Feature | PDP | Other Differentiable Pruning |
|---------|-----|-------------------------------|
| Extra parameters | **None** | Yes (mask params, thresholds, etc.) |
| Differentiable | ✅ Yes | ✅ Yes (most) |
| Training complexity | **Low** | High |
| Inference speedup | ✅ Yes | Varies |

## Files

| File | Description |
|------|-------------|
| `pdp.py` | Core PDP module: `PDPPruner` class + mask/threshold functions |
| `train_pdp.py` | Full training script for CIFAR-10 + ResNet18 |
| `test_pdp.py` | Unit tests verifying boundary conditions, monotonicity, gradient flow |

## Quick Start

### Train on CIFAR-10 with 85% sparsity
```bash
python train_pdp.py \
    --target_sparsity 0.85 \
    --s 16 \
    --epsilon 0.015 \
    --tau 1e-4 \
    --epochs 100 \
    --lr 0.1 \
    --batch_size 128
```

### Use PDP in your own code
```python
import torch
import torch.nn as nn
from pdp import PDPPruner

# Placeholders: supply your own nn.Module, DataLoader, and epoch count.
model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

pruner = PDPPruner(
    model=model,
    target_sparsity=0.85,  # 85% sparsity
    s=16,                  # Warmup epochs before pruning
    epsilon=0.015,         # Gradual pruning rate per epoch
    tau=1e-4,              # Temperature (default works well)
)
pruner.attach()

for epoch in range(epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        pruner.step(epoch)  # Recompute thresholds after each optimizer step

# After training, hard-prune for inference
pruner.hard_prune()
```

## Algorithm

From Appendix D of the paper:

1. **Warmup** (epochs 0 to `s-1`): Train normally, no pruning.
2. **At epoch `s`**: Compute per-layer target sparsity by globally sorting all weights by magnitude.
3. **After epoch `s`**: Gradually increase target sparsity by `ε` per epoch.
4. **Forward pass**: Apply soft mask `m(w)` using current threshold `t`.
5. **After optimizer step**: Recompute `t` from current weight distribution.
6. **After training**: Binarize masks for inference (hard prune).
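
The steps above can be sketched in plain Python. This is one reading of the list, not the code in `pdp.py`, and every function name here is hypothetical. It makes three assumptions: `ε` is an absolute per-epoch increment capped at the target (another reading would scale it by the target), the threshold `t` is the midpoint between the largest pruned and the smallest kept magnitude, and per-layer targets are the fraction of each layer's weights that fall under a single global threshold.

```python
from typing import Dict, List, Sequence


def sparsity_at_epoch(epoch: int, s: int, epsilon: float, target: float) -> float:
    """Steps 1 and 3: no pruning through epoch s, then +epsilon per epoch up to target."""
    return min(target, max(0, epoch - s) * epsilon)


def threshold(weights: Sequence[float], sparsity: float) -> float:
    """Step 5 (assumed form): midpoint between largest pruned and smallest kept magnitude."""
    mags = sorted(abs(w) for w in weights)
    k = round(sparsity * len(mags))  # number of weights to prune
    if k <= 0:
        return 0.0  # nothing to prune
    if k >= len(mags):
        return mags[-1]  # everything is at or below the threshold
    return 0.5 * (mags[k - 1] + mags[k])


def per_layer_sparsity(layers: Dict[str, Sequence[float]], global_sparsity: float) -> Dict[str, float]:
    """Step 2: sort all weights globally, then measure each layer's share below the cut."""
    all_weights = [w for ws in layers.values() for w in ws]
    t = threshold(all_weights, global_sparsity)
    return {name: sum(abs(w) <= t for w in ws) / len(ws) for name, ws in layers.items()}


def hard_mask(weights: Sequence[float], t: float) -> List[float]:
    """Step 6: binarize, keeping only weights whose magnitude exceeds the threshold."""
    return [1.0 if abs(w) > t else 0.0 for w in weights]


layers = {
    "conv1": [0.5, -0.02, 0.3, 0.01, -0.4],
    "fc": [0.05, -0.001, 0.2, 0.03, -0.1],
}
print(per_layer_sparsity(layers, 0.5))
print([sparsity_at_epoch(e, s=16, epsilon=0.015, target=0.85) for e in (15, 16, 17, 73)])
```

In this toy example the layer with smaller weights (`fc`) receives a higher per-layer sparsity than `conv1`, which is the point of sorting globally rather than pruning each layer by the same fraction.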

## Hyperparameters

| Param | Default | Description |
|-------|---------|-------------|
| `target_sparsity` | 0.85 | Global target fraction of weights to prune |
| `s` | 16 | Epochs before pruning starts (warmup) |
| `ε` (`epsilon`) | 0.015 | Gradual pruning rate (1.5% of target per epoch) |
| `τ` (`tau`) | 1e-4 | Temperature controlling mask softness |

Paper-reported results with these settings:
- **ResNet18 / ImageNet**: 69.0% top-1 at 85.5% sparsity
- **ResNet50 / ImageNet**: 75.3% top-1 at 89.8% sparsity
- **MobileNet-v1 / ImageNet**: 68.2% top-1 at 86.6% sparsity
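
To get a feel for what the temperature does, the short sketch below evaluates the mask formula for weights just below, at, and just above a threshold at three values of `τ`. The threshold `t = 0.01` and the sample weights are made-up illustration values, not numbers from the paper.

```python
import math


def soft_mask(w, t, tau):
    """m(w) from the formula above, written as a numerically stable sigmoid of (w^2 - t^2) / tau."""
    z = (w * w - t * t) / tau
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


t = 0.01  # illustrative threshold
for tau in (1e-3, 1e-4, 1e-5):
    row = "  ".join(f"m({w:.3f})={soft_mask(w, t, tau):.3f}" for w in (0.008, 0.010, 0.012))
    print(f"tau={tau:.0e}  {row}")
```

A smaller `τ` pushes masks on either side of the threshold toward 0 or 1 (a harder decision), while a larger `τ` keeps them close to 0.5. The mask at `w = t` stays at 0.5 regardless of `τ`.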

## Implementation Details

- **Monkey-patching**: `PDPPruner.attach()` replaces the forward methods of Conv/Linear layers to apply the soft mask. This preserves the full autograd graph, making pruning differentiable.
- **No extra parameters**: Unlike STR/CS/OptG, PDP adds zero learnable parameters.
- **Memory efficient**: No mask gradients to store (since masks are not parameters).
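
The monkey-patching pattern can be illustrated without any framework. The toy below is only an analogy for what `PDPPruner.attach()` is described as doing: `ToyLinear`, `attach`, and `detach` are hypothetical names, there is no autograd here, and the real module operates on PyTorch layers.

```python
import math
import types


class ToyLinear:
    """Stand-in for a Linear layer computing sum(w_i * x_i). Not a torch module."""

    def __init__(self, weight):
        self.weight = list(weight)

    def forward(self, x):
        return sum(w * xi for w, xi in zip(self.weight, x))


def soft_mask(w, t, tau):
    """Soft mask m(w), written as a numerically stable sigmoid of (w^2 - t^2) / tau."""
    z = (w * w - t * t) / tau
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def attach(layer, t, tau):
    """Override forward on this instance so it uses soft-masked weights."""

    def masked_forward(self, x):
        masked = [soft_mask(w, t, tau) * w for w in self.weight]
        return sum(w * xi for w, xi in zip(masked, x))

    # The instance attribute shadows the class method; the stored weights are untouched.
    layer.forward = types.MethodType(masked_forward, layer)


def detach(layer):
    """Remove the instance-level override, restoring the class's original forward."""
    del layer.forward


layer = ToyLinear([0.5, 0.001])
x = [1.0, 1.0]
dense = layer.forward(x)
attach(layer, t=0.01, tau=1e-4)
masked = layer.forward(x)
detach(layer)
print(dense, masked, layer.forward(x))
```

The mask is recomputed from the current weights on every call rather than stored, which mirrors the "no extra parameters" point: the only state is the weights themselves plus `t` and `τ`.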

## Tests

```bash
python test_pdp.py
```

Tests verify:
- `m(t) = 0.5` (equal chance at threshold)
- `m(w)` monotonicity
- Gradient flow through the soft mask
- Threshold computation accuracy
- End-to-end attach/prune/detach cycle