# PDP: Parameter-free Differentiable Pruning

Implementation of **"PDP: Parameter-free Differentiable Pruning is All You Need"** (NeurIPS 2023).

**Paper:** https://arxiv.org/abs/2305.11203

## Core Idea

PDP generates soft pruning masks **without any extra trainable parameters**. The mask is derived directly from weight magnitudes using a dynamic threshold `t` and temperature `τ`:

```
m(w) = exp(w²/τ) / (exp(w²/τ) + exp(t²/τ))
```
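
If you divide the numerator and denominator by exp(w²/τ), the mask becomes a logistic sigmoid of the squared-magnitude gap, m(w) = σ((w² − t²)/τ). This gives m(t) = 0.5, and the mask grows with |w|. Evaluating the fraction literally overflows a double once w²/τ exceeds about 709 (with τ = 1e-4, that happens for |w| above roughly 0.27). The plain-Python sketch below therefore uses a numerically stable sigmoid. The name `pdp_mask` is hypothetical and is not taken from `pdp.py`.

```python
import math


def pdp_mask(w: float, t: float, tau: float = 1e-4) -> float:
    """Soft mask m(w) = sigmoid((w**2 - t**2) / tau), evaluated without overflow."""
    z = (w * w - t * t) / tau
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # z < 0 here, so exp cannot overflow
    return ez / (1.0 + ez)


if __name__ == "__main__":
    t = 0.05
    for w in (0.0, 0.04, 0.05, 0.06, -0.06, 1.0):
        print(f"w={w:+.2f}  m(w)={pdp_mask(w, t):.6f}")
```

With τ = 1e-4 the transition is sharp. For t = 0.05, a weight of 0.04 gets m ≈ 1.2e-4, and a weight of 0.06 gets m ≈ 0.99998.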

The gradient includes an **additional boosting term** for weights near the pruning boundary, accelerating them toward a clear keep/prune decision.
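
Here is where the boosting term comes from. A layer computes with the masked weight w·m(w), so by the product rule its gradient is m(w) + w·m′(w). Write m(w) = σ((w² − t²)/τ) and treat `t` as a constant (an assumption of this sketch). The second term then becomes (2w²/τ)·m(w)·(1 − m(w)). That term is close to zero for weights far from the threshold and largest around |w| ≈ t. The hypothetical helpers below compute both terms and check their sum against a central finite difference.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def masked_weight(w: float, t: float, tau: float) -> float:
    # Effective weight seen by the layer: w * m(w)
    return w * sigmoid((w * w - t * t) / tau)


def masked_weight_grad(w: float, t: float, tau: float) -> tuple:
    """Analytic d/dw [w * m(w)] split into (plain, boost), with t held constant."""
    m = sigmoid((w * w - t * t) / tau)
    plain = m                                    # what a fixed mask would contribute
    boost = (2.0 * w * w / tau) * m * (1.0 - m)  # extra term, largest near |w| = t
    return plain, boost


if __name__ == "__main__":
    t, tau, h = 0.05, 1e-4, 1e-7
    for w in (0.02, 0.049, 0.05, 0.051, 0.08):
        plain, boost = masked_weight_grad(w, t, tau)
        fd = (masked_weight(w + h, t, tau) - masked_weight(w - h, t, tau)) / (2 * h)
        print(f"w={w:.3f}  m={plain:.4f}  boost={boost:.3f}  total={plain + boost:.3f}  fd={fd:.3f}")
```

At the threshold itself (w = t = 0.05, τ = 1e-4), m = 0.5 and the boost term is 50 × 0.25 = 12.5. That is 25 times larger than the plain gradient from a fixed mask.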

## Key Properties

| Feature | PDP | Other Differentiable Pruning |
|---------|-----|-------------------------------|
| Extra parameters | **None** | Yes (mask params, thresholds, etc.) |
| Differentiable | ✅ Yes | ✅ Yes (most) |
| Training complexity | **Low** | High |
| Inference speedup | ✅ Yes | Varies |

## Files

| File | Description |
|------|-------------|
| `pdp.py` | Core PDP module: `PDPPruner` class + mask/threshold functions |
| `train_pdp.py` | Full training script for CIFAR-10 + ResNet18 |
| `test_pdp.py` | Unit tests verifying boundary conditions, monotonicity, gradient flow |

## Quick Start

### Train on CIFAR-10 with 85% sparsity
```bash
python train_pdp.py \
    --target_sparsity 0.85 \
    --s 16 \
    --epsilon 0.015 \
    --tau 1e-4 \
    --epochs 100 \
    --lr 0.1 \
    --batch_size 128
```

### Use PDP in your own code
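```python
import torch
import torch.nn as nn
from pdp import PDPPruner

# MyModel, dataloader and num_epochs are placeholders for your own code
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

pruner = PDPPruner(
    model=model,
    target_sparsity=0.85,   # 85% sparsity
    s=16,                   # Warmup epochs before pruning
    epsilon=0.015,          # Gradual pruning rate per epoch
    tau=1e-4,               # Temperature (default works well)
)
pruner.attach()             # Patch Conv/Linear forwards to apply the soft mask

for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        pruner.step(epoch)  # Recompute thresholds after each optimizer step

# After training, hard-prune for inference
pruner.hard_prune()
```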

## Algorithm

From Appendix D of the paper:

1. **Warmup** (epochs 0 to `s-1`): Train normally, no pruning.
2. **At epoch `s`**: Compute per-layer target sparsity by globally sorting all weights by magnitude.
3. **After epoch `s`**: Gradually increase target sparsity by `ε` per epoch.
4. **Forward pass**: Apply soft mask `m(w)` using current threshold `t`.
5. **After optimizer step**: Recompute `t` from current weight distribution.
6. **After training**: Binarize masks for inference (hard prune).
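
Steps 2 and 5 can be sketched in plain Python on a toy model described as a dict of weight lists. This is an illustration under assumptions, not the code in `pdp.py`, and the function names are hypothetical. Step 2 ranks all weights globally by magnitude and counts, per layer, how many fall into the smallest `target` share. Step 5 picks `t` as the midpoint between the largest pruned and smallest kept magnitude in a layer. That midpoint rule is a simplifying assumption that puts pruned weights at m < 0.5 and kept weights at m > 0.5.

```python
from typing import Dict, List


def per_layer_sparsity(layers: Dict[str, List[float]], target: float) -> Dict[str, float]:
    """Step 2: sort every weight in the model by magnitude and give each layer
    the fraction of its weights that fall in the globally smallest `target` share."""
    ranked = sorted((abs(w), name) for name, ws in layers.items() for w in ws)
    k = int(round(target * len(ranked)))  # total number of weights to prune
    pruned = {name: 0 for name in layers}
    for _, name in ranked[:k]:
        pruned[name] += 1
    return {name: pruned[name] / len(ws) for name, ws in layers.items()}


def layer_threshold(weights: List[float], sparsity: float) -> float:
    """Step 5: choose t so that a `sparsity` fraction of the layer's magnitudes lie below it.
    Assumption: t is the midpoint between the largest pruned and smallest kept magnitude."""
    mags = sorted(abs(w) for w in weights)
    k = int(round(sparsity * len(mags)))  # weights to prune in this layer
    if k == 0:
        return 0.0  # nothing pruned yet: every nonzero weight has m(w) > 0.5
    if k >= len(mags):
        return mags[-1]
    return 0.5 * (mags[k - 1] + mags[k])


if __name__ == "__main__":
    layers = {"conv1": [0.9, -0.1, 0.5, 0.05], "fc": [0.2, -0.3, 0.01, 0.7]}
    alloc = per_layer_sparsity(layers, target=0.625)
    print(alloc)  # {'conv1': 0.5, 'fc': 0.75}
    print({name: layer_threshold(ws, alloc[name]) for name, ws in layers.items()})
```

Even with a single global target, the two layers end up with different sparsities (50% and 75%). Layers whose weights are mostly small absorb more of the pruning budget.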

## Hyperparameters

| Param | Default | Description |
|-------|---------|-------------|
| `target_sparsity` | 0.85 | Global target fraction of weights to prune |
| `s` | 16 | Epochs before pruning starts (warmup) |
| `ε` | 0.015 | Gradual pruning rate (1.5% of target per epoch) |
| `τ` | 1e-4 | Temperature controlling mask softness |
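
The schedule implied by `s` and `ε` can be sketched as a function of the (0-indexed) epoch. The function name is hypothetical. It follows the table's reading of `ε` as a fraction of the target added per epoch, and keeps sparsity at zero through epoch `s`, when the per-layer allocation is computed. Both choices are assumptions, not taken from `train_pdp.py`. Under this reading the defaults reach the full 85% at epoch 83, since 0.015 × 67 > 1.

```python
def sparsity_at(epoch: int, target: float = 0.85, s: int = 16, epsilon: float = 0.015) -> float:
    """Hypothetical global sparsity schedule.

    Assumes epsilon is a fraction of the target added per epoch after epoch s.
    If epsilon is instead an absolute increment, use min(target, epsilon * (epoch - s)).
    """
    if epoch <= s:
        return 0.0  # warmup (epochs 0..s-1) and the allocation epoch s
    ramp = min(1.0, epsilon * (epoch - s))  # share of the target reached so far
    return target * ramp


if __name__ == "__main__":
    for epoch in (0, 15, 16, 17, 50, 82, 83, 99):
        print(epoch, round(sparsity_at(epoch), 4))
```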

Paper-reported results with these settings:
- **ResNet18 / ImageNet**: 69.0% top-1 at 85.5% sparsity
- **ResNet50 / ImageNet**: 75.3% top-1 at 89.8% sparsity
- **MobileNet-v1 / ImageNet**: 68.2% top-1 at 86.6% sparsity

## Implementation Details

- **Monkey-patching**: `PDPPruner.attach()` replaces the forward methods of Conv/Linear layers to apply the soft mask. This preserves the full autograd graph, making pruning differentiable.
- **No extra parameters**: Unlike STR/CS/OptG, PDP adds zero learnable parameters.
- **Memory efficient**: Masks are computed from the weights in each forward pass instead of being stored as parameters, so there are no mask gradients or optimizer states to keep.
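
The monkey-patching approach can be sketched on a single `nn.Linear`. This is a minimal illustration under stated assumptions, not the code in `pdp.py`. `attach_soft_mask` and the `pdp_t` attribute are hypothetical names. The threshold is a fixed float rather than being recomputed after each optimizer step. Conv layers are left out; they would need `F.conv2d` with the layer's own stride, padding and dilation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_mask(w: torch.Tensor, t: float, tau: float) -> torch.Tensor:
    # m(w) = exp(w²/τ) / (exp(w²/τ) + exp(t²/τ)) = sigmoid((w² - t²) / τ)
    return torch.sigmoid((w * w - t * t) / tau)


def attach_soft_mask(layer: nn.Linear, t: float, tau: float = 1e-4) -> None:
    """Hypothetical helper: patch this layer instance's forward to use w * m(w)."""
    layer.pdp_t = t  # plain float attribute, not registered as a Parameter

    def masked_forward(x: torch.Tensor) -> torch.Tensor:
        w = layer.weight
        masked_w = w * soft_mask(w, layer.pdp_t, tau)  # stays in the autograd graph
        return F.linear(x, masked_w, layer.bias)

    # An instance attribute shadows nn.Linear.forward, so layer(x) now calls masked_forward
    layer.forward = masked_forward


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = nn.Linear(4, 2)
    n_params = sum(p.numel() for p in layer.parameters())
    attach_soft_mask(layer, t=0.25)
    layer(torch.randn(3, 4)).sum().backward()
    print(n_params == sum(p.numel() for p in layer.parameters()))  # no new parameters
    print(layer.weight.grad is not None)                            # gradient reaches the weights
```

The threshold is a plain attribute and the mask is computed on the fly, so `layer.parameters()` is unchanged. Gradients still reach `layer.weight` through `w * m(w)`.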

## Tests

```bash
python test_pdp.py
```

Tests verify:
- `m(t) = 0.5` (a weight exactly at the threshold is half-masked)
- `m(w)` increases monotonically with `|w|`
- Gradient flow through the soft mask
- Threshold computation accuracy
- End-to-end attach/prune/detach cycle