Add metapruning/README.md
- metapruning/README.md +135 -0
metapruning/README.md
ADDED
@@ -0,0 +1,135 @@
# MetaPruning: Meta Pruning via Graph Metanetworks

Implementation of **"Meta Pruning via Graph Metanetworks: A Universal Meta-Learning Framework for Network Pruning"**

**Paper:** https://arxiv.org/abs/2506.12041

## Core Idea

Unlike prior "learning to prune" methods, which train a pruner for each model, MetaPruning **trains once, prunes forever**. A **GNN metanetwork** takes an entire neural network as input, converted to a graph in which neurons are nodes and weights are edges, and outputs a transformed network that is easier to prune. After one feedforward pass through the metanetwork and standard finetuning, the paper reports state-of-the-art pruning results with no per-model special training.

## Architecture

### Network ↔ Graph Bijection

```
Network Graph
Node = each output channel/neuron in Conv/Linear layers
Edge = connection between channels (conv, linear, residual skip)
Node feature = [weight_mean, weight_std, BN_weight, BN_bias, BN_run_mean, BN_run_var]
Edge feature = flattened conv kernel (padded to uniform size)
```

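A minimal sketch of the mapping above for a toy two-layer fully connected network, written in plain Python so it runs without PyTorch. The names `network_to_graph` and `graph_to_network` are hypothetical and are not the `resnet_to_graph` / `graph_to_resnet` functions from `graph.py`. As simplifying assumptions, the BatchNorm entries of the node feature are omitted, input units get placeholder features, and each linear weight is treated as a one-element kernel.

```python
from statistics import mean, pstdev


def network_to_graph(layers, edge_dim=3):
    """Convert weight matrices (each [out][in], lists of floats) to (nodes, edges)."""
    num_inputs = len(layers[0][0])
    # Input units get placeholder features; every output neuron becomes a node
    # described by the mean and (population) std of its incoming weights.
    nodes = [[0.0, 0.0] for _ in range(num_inputs)]
    edges = {}  # (src_node, dst_node) -> padded edge feature
    prev_ids = list(range(num_inputs))
    for weights in layers:
        cur_ids = []
        for row in weights:
            dst = len(nodes)
            nodes.append([mean(row), pstdev(row)])
            cur_ids.append(dst)
            for src, w in zip(prev_ids, row):
                # Zero-pad the one-element "kernel" to a uniform edge size.
                edges[(src, dst)] = [w] + [0.0] * (edge_dim - 1)
        prev_ids = cur_ids
    return nodes, edges


def graph_to_network(edges, layer_sizes):
    """Inverse mapping: rebuild the weight matrices from the edge features."""
    layers, offset = [], 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        dst_start = offset + n_in
        layers.append(
            [[edges[(offset + i, dst_start + j)][0] for i in range(n_in)]
             for j in range(n_out)]
        )
        offset = dst_start
    return layers


layers = [[[1.0, 2.0], [3.0, 5.0]], [[0.5, -0.5]]]
nodes, edges = network_to_graph(layers)
restored = graph_to_network(edges, [2, 2, 1])
```

The round trip through `graph_to_network` returns the original weights, which is the property the bijection needs so that a transformed graph can be turned back into a network.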
### GNN Metanetwork (Appendix B.1)

```
Node/Edge Encoder (MLP) → hidden_dim
↓
N × PNAConv Message Passing Layers
• Message: m_ij = MLP^1(v_i) ⊙ MLP^2(v_j) ⊙ e_ij
• Message': m'_ji = MLP^1(v_j) ⊙ MLP^2(v_i) ⊙ (e_ij ⊙ EdgeInvertor)
• Aggregation: PNA([MEAN, STD, MAX, MIN]) via MLP_Aggr
• Update: v_i += aggr(m_ij) + aggr(m'_ji)
• Edge: e_ij += MLP^1(v_i) ⊙ MLP^2(v_j) ⊙ e_ij + ...
↓
Node/Edge Decoder (MLP) → original dims
↓
Residual: v_out = α·v_pred + v_in, e_out = β·e_pred + e_in
(α = β = 0.01, learns deltas only)
```

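The aggregation and residual steps can be illustrated with a small plain-Python sketch. It is not the `PNAConvLayer` from `gnn.py`: the function names are hypothetical, the standard deviation is assumed to be the population form, and the learned `MLP_Aggr` that would consume the concatenated statistics is left out.

```python
from statistics import mean, pstdev


def pna_aggregate(messages):
    """Concatenate [MEAN, STD, MAX, MIN] of the incoming messages, per feature."""
    columns = list(zip(*messages))  # one tuple per feature dimension
    return (
        [mean(c) for c in columns]
        + [pstdev(c) for c in columns]
        + [max(c) for c in columns]
        + [min(c) for c in columns]
    )


def residual_output(v_in, v_pred, alpha=0.01):
    """v_out = alpha * v_pred + v_in, so only a small delta is applied."""
    return [alpha * p + x for p, x in zip(v_pred, v_in)]


# Two incoming messages with two features each.
messages = [[1.0, 4.0], [3.0, 0.0]]
aggregated = pna_aggregate(messages)  # length is 4 * feature_dim

# Decoder prediction scaled by alpha and added to the input feature.
v_out = residual_output([0.5, -0.2], [10.0, 10.0])
```

With `alpha = 0.01` the output stays close to the input feature, which matches the "learns deltas only" remark in the diagram.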
### Meta-Training Loop

1. Select a **data model** (a pre-trained network).
2. Convert the network → graph.
3. Feed the graph through the GNN metanetwork → transformed graph.
4. Convert the transformed graph → new network.
5. Compute the **accuracy loss** (on a subset of the training data) plus the **sparsity loss** (L1 weight penalty).
6. Backpropagate through the metanetwork, keeping the data model's parameters frozen.

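Step 5 of the loop above can be sketched as a single scalar objective. This is an illustration in plain Python, not the code in `train_metanetwork.py`: the function names are hypothetical, and combining the two terms as `accuracy_loss + pruner_reg * sparsity_loss` is an assumption based on `pruner_reg` being described as the sparsity loss weight.

```python
import math


def cross_entropy(logits, label):
    """Numerically stable cross-entropy for one sample."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def l1_sparsity(weights):
    """L1 penalty: sum of absolute weight values."""
    return sum(abs(w) for w in weights)


def meta_loss(batch_logits, labels, weights, pruner_reg=10.0):
    """Mean cross-entropy over the batch plus the weighted L1 penalty (assumed form)."""
    accuracy_loss = sum(
        cross_entropy(z, y) for z, y in zip(batch_logits, labels)
    ) / len(labels)
    return accuracy_loss + pruner_reg * l1_sparsity(weights)


# Two samples with two classes each, and three weights of the transformed network.
loss = meta_loss([[2.0, 0.0], [0.0, 0.0]], [0, 1], [0.1, -0.2, 0.0])
```

In actual meta-training the logits come from the transformed network, so the gradient of this loss flows back into the metanetwork while the data model stays frozen.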
### Inference Pipeline

```
Target Model → Graph → [Metanetwork] → Transformed Graph → New Model
↓
Finetune (standard SGD, 100-200 epochs)
↓
Prune (DepGraph / magnitude criterion)
↓
Pruned Model (ready for inference)
```

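The magnitude criterion in the prune step can be sketched as ranking output channels by L1 norm and keeping the largest ones. The helper names below are hypothetical and this is not the pruner in `inference.py`; it also ignores layer dependencies, so a real structural pruner must additionally remove the matching input channels of the following layer.

```python
def channel_l1_norms(weight):
    """weight: list of output channels, each a flat list of that channel's weights."""
    return [sum(abs(w) for w in channel) for channel in weight]


def select_channels_to_keep(weight, prune_sparsity):
    """Return the sorted indices of the channels with the largest L1 norms."""
    norms = channel_l1_norms(weight)
    num_keep = max(1, round(len(weight) * (1.0 - prune_sparsity)))
    ranked = sorted(range(len(weight)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:num_keep])


def prune_channels(weight, keep):
    """Drop every output channel whose index is not in `keep`."""
    return [weight[i] for i in keep]


# Four output channels; channels 1 and 3 have near-zero magnitude.
conv_weight = [[0.9, -0.8], [0.01, 0.02], [0.5, 0.4], [0.0, -0.03]]
keep = select_channels_to_keep(conv_weight, prune_sparsity=0.5)
pruned = prune_channels(conv_weight, keep)
```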
## Files

| File | Description |
|------|-------------|
| `graph.py` | Network ↔ Graph bijection (`resnet_to_graph`, `graph_to_resnet`, `create_transformed_model`) |
| `gnn.py` | PNAConv GNN (`Metanetwork`, `PNAConvLayer`, `EdgeInvertor`) |
| `train_metanetwork.py` | Meta-training loop on CIFAR-10 |
| `inference.py` | Inference: metanetwork → finetune → prune → evaluate |

## Usage

### 1. Meta-Train the Metanetwork

```bash
python -m metapruning.train_metanetwork \
    --meta_epochs 100 \
    --hidden_dim 32 \
    --num_layers 3 \
    --alpha 0.01 \
    --beta 0.01 \
    --lr 1e-3 \
    --weight_decay 5e-4 \
    --pruner_reg 10.0 \
    --num_data_models 1 \
    --pretrain_data_models \
    --pretrain_epochs 100
```

This creates `checkpoints_metapruning/metanetwork.pt`.

### 2. Prune Any Target Model

```bash
python -m metapruning.inference \
    --metanetwork_path checkpoints_metapruning/metanetwork.pt \
    --target_model resnet56 \
    --finetune_epochs 100 \
    --prune_sparsity 0.5 \
    --lr 0.01
```

## Paper Results

| Task | Base Acc | Pruned Acc (Δ) | Pruned FLOPs |
|------|----------|----------------|--------------|
| ResNet56 / CIFAR-10 | 93.51% | **93.64%** (+0.13%) | 65.6% |
| VGG19 / CIFAR-100 | 73.65% | **69.75%** (−3.90%) | 88.83% |
| ResNet50 / ImageNet | 76.14% | **76.13%** (−0.01%) | 57.2% |

## Key Properties

- **Transferable**: Metanetwork trained on ResNet56/CIFAR-10 → prunes ResNet110, VGG, ViT on different datasets
- **One-shot pruning**: Single metanetwork feedforward + finetuning, no iterative pruning
- **Universal**: Applies to any CNN or ViT via graph bijection
- **Low cost (amortized)**: Expensive meta-training once, cheap pruning forever

## Hyperparameters

| Param | Default | Description |
|-------|---------|-------------|
| `hidden_dim` | 32 | GNN hidden dimension |
| `num_layers` | 3 | Message passing layers |
| `alpha` | 0.01 | Node feature residual coefficient |
| `beta` | 0.01 | Edge feature residual coefficient |
| `meta_epochs` | 100 | Meta-training epochs |
| `lr` | 1e-3 | Metanetwork learning rate |
| `weight_decay` | 5e-4 | AdamW weight decay |
| `pruner_reg` | 10.0 | Sparsity loss weight |

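For reference, the defaults in the table could be declared with `argparse` roughly as follows. This is a sketch that mirrors the flag names used in the usage commands; `build_parser` is a hypothetical helper and the actual argument handling in `train_metanetwork.py` may differ.

```python
import argparse


def build_parser():
    """Declare the meta-training flags with the defaults from the table."""
    parser = argparse.ArgumentParser(description="Meta-train the metanetwork (sketch)")
    parser.add_argument("--hidden_dim", type=int, default=32, help="GNN hidden dimension")
    parser.add_argument("--num_layers", type=int, default=3, help="Message passing layers")
    parser.add_argument("--alpha", type=float, default=0.01, help="Node feature residual coefficient")
    parser.add_argument("--beta", type=float, default=0.01, help="Edge feature residual coefficient")
    parser.add_argument("--meta_epochs", type=int, default=100, help="Meta-training epochs")
    parser.add_argument("--lr", type=float, default=1e-3, help="Metanetwork learning rate")
    parser.add_argument("--weight_decay", type=float, default=5e-4, help="AdamW weight decay")
    parser.add_argument("--pruner_reg", type=float, default=10.0, help="Sparsity loss weight")
    return parser


# Override one value on the command line; the rest fall back to the defaults.
args = build_parser().parse_args(["--num_layers", "4"])
```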
## Notes

- **Graph→Model differentiability**: The current `graph_to_resnet` modifies weights in place (`module.weight.data += delta`), which bypasses autograd, so no gradient reaches the metanetwork through the transformed weights. For fully differentiable meta-training, the forward pass should instead consume the GNN outputs as tensors that remain attached to the autograd graph, for example via `torch.func.functional_call` or a custom `torch.autograd.Function`. Note that wrapping a GNN output in a new `nn.Parameter` creates a leaf tensor and cuts the gradient path in the same way. The current implementation demonstrates the architecture; production use needs this end-to-end differentiable path.
- **Proper pruning**: Use `torch_pruning` (DepGraph) for structural pruning with dependency groups. The inference script includes a simple magnitude-based channel pruner as a placeholder.
- **Full paper**: Appendix B.1 contains the complete GNN architecture equations, and Appendix D contains per-task hyperparameters.
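One way to keep the graph-to-model step differentiable is sketched below with `torch.func.functional_call`, assuming PyTorch 2.0 or later. A single `nn.Linear` stands in for the data model and one learnable tensor, `delta_pred`, stands in for the metanetwork output; both are illustrative assumptions and not the repository's `graph_to_resnet`.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)

# Frozen data model: its own parameters never receive gradients.
data_model = nn.Linear(4, 2)
for p in data_model.parameters():
    p.requires_grad_(False)

# Stand-in for the metanetwork: a learnable tensor that predicts a weight delta.
delta_pred = nn.Parameter(torch.randn(2, 4))
alpha = 0.01

# Build the transformed weights as ordinary tensors attached to the autograd
# graph, instead of writing into `module.weight.data`.
new_params = {
    "weight": data_model.weight + alpha * delta_pred,
    "bias": data_model.bias,
}

# Run the data model's forward pass with the substituted tensors.
x = torch.randn(8, 4)
out = functional_call(data_model, new_params, (x,))
loss = out.pow(2).mean()
loss.backward()  # the gradient flows into delta_pred, not into data_model
```

After `backward()`, `delta_pred.grad` is populated while `data_model.weight.grad` stays `None`, which is the behavior meta-training needs: the metanetwork learns and the data model stays frozen.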