ESPR3SS0 committed
Commit ca31dfe · verified · 1 Parent(s): 4d76378

Add metapruning/README.md

Files changed (1):
  1. metapruning/README.md +135 -0
metapruning/README.md ADDED
@@ -0,0 +1,135 @@

# MetaPruning: Meta Pruning via Graph Metanetworks

Implementation of **"Meta Pruning via Graph Metanetworks: A Universal Meta-Learning Framework for Network Pruning"**

**Paper:** https://arxiv.org/abs/2506.12041

## Core Idea

Unlike prior "learning to prune" methods that train a pruner for each model, MetaPruning **trains once, prunes forever**. A **GNN metanetwork** takes an entire neural network as input (converted to a graph where neurons are nodes and weights are edges) and outputs a transformed network that is easier to prune. After one feedforward pass through the metanetwork plus standard finetuning, you get SOTA pruning, with no per-model special training.

## Architecture

### Network ↔ Graph Bijection

```
Network Graph
Node         = each output channel/neuron in Conv/Linear layers
Edge         = connection between channels (conv, linear, residual skip)
Node feature = [weight_mean, weight_std, BN_weight, BN_bias, BN_run_mean, BN_run_var]
Edge feature = flattened conv kernel (padded to uniform size)
```
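
As a rough illustration of the mapping above, the following sketch converts a tiny MLP to a graph and back. It is a simplification under stated assumptions: `mlp_to_graph` and `graph_to_mlp` are hypothetical helpers (not the repository's `resnet_to_graph` / `graph_to_resnet`), input neurons are also treated as nodes, and each edge carries a single scalar weight rather than a flattened conv kernel.

```python
# Hypothetical helpers for illustration only; they are not part of graph.py.

def mlp_to_graph(weights):
    """Convert a list of weight matrices into (layer sizes, edge dict).

    weights[l][o][i] is the weight from input neuron i to output neuron o
    of layer l. Nodes are numbered layer by layer, input layer first.
    """
    sizes = [len(weights[0][0])] + [len(w) for w in weights]
    offsets = [sum(sizes[:k]) for k in range(len(sizes))]
    edges = {}
    for l, w in enumerate(weights):
        for o, row in enumerate(w):
            for i, value in enumerate(row):
                src = offsets[l] + i        # node id of the input neuron
                dst = offsets[l + 1] + o    # node id of the output neuron
                edges[(src, dst)] = value   # edge feature = scalar weight
    return sizes, edges


def graph_to_mlp(sizes, edges):
    """Inverse mapping: rebuild the weight matrices from the edge dict."""
    offsets = [sum(sizes[:k]) for k in range(len(sizes))]
    weights = []
    for l in range(len(sizes) - 1):
        layer = [
            [edges[(offsets[l] + i, offsets[l + 1] + o)] for i in range(sizes[l])]
            for o in range(sizes[l + 1])
        ]
        weights.append(layer)
    return weights


weights = [
    [[0.1, -0.2], [0.3, 0.4], [0.5, -0.6]],  # layer 0: 2 inputs -> 3 outputs
    [[1.0, 2.0, 3.0]],                       # layer 1: 3 inputs -> 1 output
]
sizes, edges = mlp_to_graph(weights)
restored = graph_to_mlp(sizes, edges)        # round trip recovers the weights
```

Because every weight maps to exactly one edge and back, the round trip is lossless, which is the property the metanetwork relies on when it turns its output graph into a new network.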

### GNN Metanetwork (Appendix B.1)

```
Node/Edge Encoder (MLP) → hidden_dim
  ↓
N × PNAConv Message Passing Layers
  • Message:     m_ij = MLP^1(v_i) ⊙ MLP^2(v_j) ⊙ e_ij
  • Message':    m'_ji = MLP^1(v_j) ⊙ MLP^2(v_i) ⊙ (e_ij ⊙ EdgeInvertor)
  • Aggregation: PNA([MEAN, STD, MAX, MIN]) via MLP_Aggr
  • Update:      v_i += aggr(m_ij) + aggr(m'_ji)
  • Edge:        e_ij += MLP^1(v_i) ⊙ MLP^2(v_j) ⊙ e_ij + ...
  ↓
Node/Edge Decoder (MLP) → original dims
  ↓
Residual: v_out = α·v_pred + v_in, e_out = β·e_pred + e_in
  (α = β = 0.01, learns deltas only)
```
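
The aggregation and residual steps in the diagram can be sketched in plain Python as follows. This is a minimal illustration, not the code in `gnn.py`: the function names are hypothetical, the learned `MLP_Aggr` that would follow the concatenation is omitted, and the use of the population standard deviation is an assumption.

```python
import statistics


def pna_aggregate(messages):
    """Concatenate per-dimension MEAN, STD, MAX, MIN over incoming messages."""
    dims = list(zip(*messages))  # transpose: one tuple per feature dimension
    mean = [statistics.fmean(d) for d in dims]
    std = [statistics.pstdev(d) for d in dims]  # population std (assumption)
    return mean + std + [max(d) for d in dims] + [min(d) for d in dims]


def residual_output(x_in, x_pred, scale=0.01):
    """x_out = scale * x_pred + x_in, applied elementwise."""
    return [scale * p + x for x, p in zip(x_in, x_pred)]


messages = [[1.0, 4.0], [3.0, 0.0]]  # two incoming messages, feature dim 2
aggregated = pna_aggregate(messages)  # length 4 * 2 = 8
v_out = residual_output([0.5, -0.5], [2.0, 4.0])  # small step away from v_in
```

With a scale of 0.01 the output stays close to the input features, which matches the "learns deltas only" remark: the metanetwork nudges the original weights rather than replacing them.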

### Meta-Training Loop

1. Select a **data model** (pre-trained network)
2. Convert network → graph
3. Feed graph through GNN metanetwork → transformed graph
4. Convert transformed graph → new network
5. Compute **accuracy loss** (subset of training data) + **sparsity loss** (L1 weight penalty)
6. Backprop through metanetwork (data model params frozen)
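
The objective in step 5 can be sketched as a cross-entropy term plus a weighted L1 term. This is only a sketch under assumptions: the helper names are hypothetical, the two terms are assumed to be combined as `accuracy_loss + pruner_reg * sparsity_loss` (based on `pruner_reg` being described as the sparsity loss weight), and normalizing the L1 penalty by the number of weights is a choice made here for illustration.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy for one example, computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def l1_penalty(weights):
    """Mean absolute weight value (the normalization is an assumption)."""
    return sum(abs(w) for w in weights) / len(weights)


def meta_loss(batch_logits, labels, weights, pruner_reg=10.0):
    """Accuracy loss on a batch plus the weighted sparsity loss."""
    accuracy_loss = sum(
        cross_entropy(z, y) for z, y in zip(batch_logits, labels)
    ) / len(labels)
    return accuracy_loss + pruner_reg * l1_penalty(weights)


loss = meta_loss(
    batch_logits=[[2.0, 0.0], [0.0, 0.0]],  # outputs of the transformed network
    labels=[0, 1],
    weights=[0.5, -0.5, 0.0, 0.0],          # flattened transformed weights
)
```

In the real loop the weights are outputs of the metanetwork, so minimizing this loss pushes the metanetwork toward transformed networks that stay accurate while having many near-zero weights.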

### Inference Pipeline

```
Target Model → Graph → [Metanetwork] → Transformed Graph → New Model
  ↓
Finetune (standard SGD, 100-200 epochs)
  ↓
Prune (DepGraph / magnitude criterion)
  ↓
Pruned Model (ready for inference)
```
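
For the magnitude criterion in the prune step, a minimal sketch of channel selection is shown below. It assumes that the sparsity value is the fraction of channels removed and that channels are ranked by the L1 norm of their weights; `select_channels` is a hypothetical helper, and it handles a single layer only, ignoring the cross-layer dependencies that DepGraph resolves.

```python
def select_channels(channel_weights, prune_sparsity):
    """Return sorted indices of the channels to keep, ranked by L1 norm."""
    norms = [sum(abs(w) for w in ch) for ch in channel_weights]
    # Keep at least one channel so the layer never becomes empty.
    num_keep = max(1, round(len(norms) * (1.0 - prune_sparsity)))
    ranked = sorted(range(len(norms)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:num_keep])


conv = [
    [0.9, -0.8, 0.7],   # channel 0, L1 norm 2.4
    [0.01, 0.02, 0.0],  # channel 1, L1 norm 0.03
    [-0.5, 0.5, 0.5],   # channel 2, L1 norm 1.5
    [0.1, 0.0, -0.1],   # channel 3, L1 norm 0.2
]
kept = select_channels(conv, prune_sparsity=0.5)
pruned = [conv[i] for i in kept]  # weights of the surviving channels
```

A structural pruner must also remove the matching input channels of the next layer and the corresponding BatchNorm entries, which is why a dependency-aware tool is preferable in practice.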

## Files

| File | Description |
|------|-------------|
| `graph.py` | Network ↔ Graph bijection (`resnet_to_graph`, `graph_to_resnet`, `create_transformed_model`) |
| `gnn.py` | PNAConv GNN (`Metanetwork`, `PNAConvLayer`, `EdgeInvertor`) |
| `train_metanetwork.py` | Meta-training loop on CIFAR-10 |
| `inference.py` | Inference: metanetwork → finetune → prune → evaluate |

## Usage

### 1. Meta-Train the Metanetwork

```bash
python -m metapruning.train_metanetwork \
    --meta_epochs 100 \
    --hidden_dim 32 \
    --num_layers 3 \
    --alpha 0.01 \
    --beta 0.01 \
    --lr 1e-3 \
    --weight_decay 5e-4 \
    --pruner_reg 10.0 \
    --num_data_models 1 \
    --pretrain_data_models \
    --pretrain_epochs 100
```

This creates `checkpoints_metapruning/metanetwork.pt`.

### 2. Prune Any Target Model

```bash
python -m metapruning.inference \
    --metanetwork_path checkpoints_metapruning/metanetwork.pt \
    --target_model resnet56 \
    --finetune_epochs 100 \
    --prune_sparsity 0.5 \
    --lr 0.01
```

## Paper Results

| Task | Base Acc | Pruned Acc (Δ) | Pruned FLOPs |
|------|----------|----------------|--------------|
| ResNet56 / CIFAR-10 | 93.51% | **93.64%** (+0.13%) | 65.6% |
| VGG19 / CIFAR-100 | 73.65% | **69.75%** (−3.90%) | 88.83% |
| ResNet50 / ImageNet | 76.14% | **76.13%** (−0.01%) | 57.2% |

## Key Properties

- **Transferable**: Metanetwork trained on ResNet56/CIFAR-10 → prunes ResNet110, VGG, ViT on different datasets
- **One-shot pruning**: Single metanetwork feedforward + finetuning, no iterative pruning
- **Universal**: Applies to any CNN or ViT via graph bijection
- **Low cost (amortized)**: Expensive meta-training once, cheap pruning forever

118
+ ## Hyperparameters
119
+
120
+ | Param | Default | Description |
121
+ |-------|---------|-------------|
122
+ | `hidden_dim` | 32 | GNN hidden dimension |
123
+ | `num_layers` | 3 | Message passing layers |
124
+ | `alpha` | 0.01 | Node feature residual coefficient |
125
+ | `beta` | 0.01 | Edge feature residual coefficient |
126
+ | `meta_epochs` | 100 | Meta-training epochs |
127
+ | `lr` | 1e-3 | Metanetwork learning rate |
128
+ | `weight_decay` | 5e-4 | AdamW weight decay |
129
+ | `pruner_reg` | 10.0 | Sparsity loss weight |
130
+
## Notes

- **Graph→Model differentiability**: The current `graph_to_resnet` modifies weights in place (`module.weight.data += delta`), which bypasses autograd, so no gradient reaches the metanetwork through this path. For fully differentiable meta-training, the GNN outputs should be used directly as the weights in the forward pass, for example through `torch.func.functional_call` or a custom `torch.autograd.Function`. Note that wrapping the outputs in a fresh `torch.nn.Parameter` creates a new leaf tensor and also detaches them from the graph. The current implementation demonstrates the architecture; production use needs an end-to-end differentiable path.
- **For proper pruning**: Use `torch_pruning` (DepGraph) for structural pruning with dependency groups. The inference script includes a simple magnitude-based channel pruner as a placeholder.
- **Full paper**: Appendix B.1 contains the complete GNN architecture equations, and Appendix D contains per-task hyperparameters.
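
One possible way to address the differentiability note above is to run the target module with externally supplied weight tensors instead of editing `.data`. The sketch below uses `torch.func.functional_call` (available in PyTorch 2.0 and later) for that purpose. It is an illustration under assumptions: `delta_net` is a hypothetical stand-in for the metanetwork, a single `nn.Linear` stands in for the data model, and `transformed_forward` is not a function from this repository.

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)

target = nn.Linear(4, 2)  # stands in for the frozen data model
for p in target.parameters():
    p.requires_grad_(False)

# Hypothetical stand-in for the metanetwork: predicts a delta for the weight.
delta_net = nn.Linear(4, 4, bias=False)


def transformed_forward(x, alpha=0.01):
    w_in = target.weight                    # shape (2, 4), frozen
    w_out = w_in + alpha * delta_net(w_in)  # stays in the autograd graph
    params = {"weight": w_out, "bias": target.bias}
    # Run the module with the supplied tensors instead of its own parameters.
    return functional_call(target, params, (x,))


x = torch.randn(8, 4)
loss = transformed_forward(x).pow(2).mean()
loss.backward()

grad_reaches_metanetwork = delta_net.weight.grad is not None
target_stays_frozen = target.weight.grad is None
```

After `backward()`, the gradient lands on `delta_net` while the target's own parameters remain untouched, which mirrors step 6 of the meta-training loop (data model parameters frozen).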