ESPR3SS0 commited on
Commit
639f82d
·
verified ·
1 Parent(s): 6a99d0f

Add metapruning/graph.py

Browse files
Files changed (1) hide show
  1. metapruning/graph.py +327 -0
metapruning/graph.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graph ↔ Network Bijection for MetaPruning
3
+ Converts ResNet-style CNNs to/from graph representations.
4
+
5
+ Paper: "Meta Pruning via Graph Metanetworks" (arXiv:2506.12041)
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from typing import Dict, List, Tuple, Optional
11
+ from dataclasses import dataclass
12
+ import copy
13
+
14
+
15
+ @dataclass
16
+ class Graph:
17
+ """Graph representation of a neural network."""
18
+ node_features: torch.Tensor # [num_nodes, node_feat_dim]
19
+ edge_index: torch.Tensor # [2, num_edges] (COO format)
20
+ edge_features: torch.Tensor # [num_edges, edge_feat_dim]
21
+ node_to_layer: List[Tuple[str, int]] # maps node idx -> (layer_name, channel_idx)
22
+ edge_to_connection: List[Tuple[int, int, str]] # (src_node, dst_node, type)
23
+ layer_shapes: Dict[str, List[int]] # original layer shapes for reconstruction
24
+
25
+
26
+ def _get_bn_stats(module: nn.Module) -> Optional[torch.Tensor]:
27
+ """Extract BatchNorm statistics as node features."""
28
+ if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
29
+ # Features: [weight, bias, running_mean, running_var]
30
+ stats = torch.stack([
31
+ module.weight.data if module.weight is not None else torch.ones_like(module.running_mean),
32
+ module.bias.data if module.bias is not None else torch.zeros_like(module.running_mean),
33
+ module.running_mean,
34
+ module.running_var,
35
+ ], dim=1) # [channels, 4]
36
+ return stats
37
+ return None
38
+
39
+
40
+ def _channel_mean_std(conv_weight: torch.Tensor) -> torch.Tensor:
41
+ """Compute per-channel mean and std of conv weights."""
42
+ # conv_weight: [out_ch, in_ch, k, k]
43
+ out_ch = conv_weight.size(0)
44
+ flat = conv_weight.view(out_ch, -1) # [out_ch, in_ch*k*k]
45
+ mean = flat.mean(dim=1)
46
+ std = flat.std(dim=1)
47
+ return torch.stack([mean, std], dim=1) # [out_ch, 2]
48
+
49
+
50
+ def resnet_to_graph(model: nn.Module, max_kernel_size: int = 3) -> Graph:
51
+ """
52
+ Convert a ResNet-style model to a graph.
53
+
54
+ Nodes = output channels of Conv/Linear layers (neurons).
55
+ Edges = connections between channels (conv weights, linear weights, residuals).
56
+
57
+ Node features: [weight_mean, weight_std, bn_weight, bn_bias, bn_running_mean, bn_running_var]
58
+ Edge features: flattened conv kernel (padded to max_kernel_size^2 for uniform edge dim).
59
+
60
+ Args:
61
+ model: PyTorch model (e.g., ResNet18 for CIFAR-10)
62
+ max_kernel_size: Maximum kernel size for padding edge features
63
+
64
+ Returns:
65
+ Graph object representing the model.
66
+ """
67
+ node_features_list = []
68
+ node_to_layer = []
69
+ edge_index_list = []
70
+ edge_features_list = []
71
+ edge_to_connection = []
72
+ layer_shapes = {}
73
+
74
+ # First pass: identify all layers and their channels
75
+ layers_info = []
76
+ for name, module in model.named_modules():
77
+ if isinstance(module, nn.Conv2d):
78
+ out_ch = module.out_channels
79
+ layers_info.append({
80
+ 'name': name,
81
+ 'type': 'conv',
82
+ 'out_ch': out_ch,
83
+ 'in_ch': module.in_channels,
84
+ 'kernel_size': module.kernel_size[0] if isinstance(module.kernel_size, tuple) else module.kernel_size,
85
+ 'stride': module.stride[0] if isinstance(module.stride, tuple) else module.stride,
86
+ 'module': module,
87
+ })
88
+ layer_shapes[name] = list(module.weight.shape)
89
+ elif isinstance(module, nn.Linear):
90
+ out_ch = module.out_features
91
+ layers_info.append({
92
+ 'name': name,
93
+ 'type': 'linear',
94
+ 'out_ch': out_ch,
95
+ 'in_ch': module.in_features,
96
+ 'module': module,
97
+ })
98
+ layer_shapes[name] = list(module.weight.shape)
99
+
100
+ if not layers_info:
101
+ raise ValueError("No Conv2d or Linear layers found in model")
102
+
103
+ # Build node features per layer
104
+ # For each conv/linear layer, each output channel is a node
105
+ layer_name_to_node_start = {}
106
+ current_node_idx = 0
107
+
108
+ for info in layers_info:
109
+ name = info['name']
110
+ out_ch = info['out_ch']
111
+ layer_name_to_node_start[name] = current_node_idx
112
+
113
+ # Find associated BN (next sibling module in parent)
114
+ bn_stats = None
115
+ parent_name = '.'.join(name.split('.')[:-1]) if '.' in name else ''
116
+ child_name = name.split('.')[-1]
117
+
118
+ # Heuristic: look for BN with same num_features immediately after conv
119
+ for bn_name, bn_module in model.named_modules():
120
+ if isinstance(bn_module, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
121
+ if bn_module.num_features == out_ch:
122
+ # Check if it's "near" this conv in the hierarchy
123
+ bn_stats = _get_bn_stats(bn_module)
124
+ break
125
+
126
+ # Node features for each channel
127
+ module = info['module']
128
+ if info['type'] == 'conv':
129
+ w_stats = _channel_mean_std(module.weight.data)
130
+ # w_stats: [out_ch, 2]
131
+ if bn_stats is not None and bn_stats.shape[0] == out_ch:
132
+ # [out_ch, 2] + [out_ch, 4] = [out_ch, 6]
133
+ nf = torch.cat([w_stats, bn_stats], dim=1)
134
+ else:
135
+ # Pad with zeros for missing BN
136
+ nf = torch.cat([w_stats, torch.zeros(out_ch, 4, device=w_stats.device, dtype=w_stats.dtype)], dim=1)
137
+ else:
138
+ # Linear layer
139
+ w = module.weight.data # [out_ch, in_ch]
140
+ mean = w.mean(dim=1)
141
+ std = w.std(dim=1)
142
+ w_stats = torch.stack([mean, std], dim=1) # [out_ch, 2]
143
+ if bn_stats is not None and bn_stats.shape[0] == out_ch:
144
+ nf = torch.cat([w_stats, bn_stats], dim=1)
145
+ else:
146
+ nf = torch.cat([w_stats, torch.zeros(out_ch, 4, device=w_stats.device, dtype=w_stats.dtype)], dim=1)
147
+
148
+ node_features_list.append(nf)
149
+ for ch in range(out_ch):
150
+ node_to_layer.append((name, ch))
151
+
152
+ current_node_idx += out_ch
153
+
154
+ node_features = torch.cat(node_features_list, dim=0) # [total_nodes, node_feat_dim]
155
+
156
+ # Build edges: consecutive layers + residual connections
157
+ max_kernel_flat = max_kernel_size ** 2
158
+
159
+ for i, src_info in enumerate(layers_info):
160
+ src_name = src_info['name']
161
+ src_start = layer_name_to_node_start[src_name]
162
+ src_out = src_info['out_ch']
163
+
164
+ # Look for next layer connection
165
+ if i + 1 < len(layers_info):
166
+ dst_info = layers_info[i + 1]
167
+ dst_name = dst_info['name']
168
+ dst_start = layer_name_to_node_start[dst_name]
169
+ dst_in = dst_info['in_ch']
170
+ dst_out = dst_info['out_ch']
171
+
172
+ # Feedforward edges: connect src output channels to dst output channels
173
+ # Only connect when dimensions align (src_out == dst_in for proper flow)
174
+ # For conv->conv, this is natural. For conv->linear, src_out channels
175
+ # feed into dst_in, but dst only has dst_out nodes. We connect up to min.
176
+ if src_out == dst_in:
177
+ # The destination layer has dst_out nodes; only connect to existing ones
178
+ num_connections = min(src_out, dst_out)
179
+ for ch in range(num_connections):
180
+ src_node = src_start + ch
181
+ dst_node = dst_start + ch
182
+ if dst_node >= current_node_idx:
183
+ continue # safety: don't exceed total nodes
184
+ edge_index_list.append([src_node, dst_node])
185
+
186
+ # Edge feature: weight slice for this output channel/feature
187
+ if src_info['type'] == 'conv':
188
+ w = src_info['module'].weight.data[ch] # [in_ch, k, k]
189
+ flat = w.flatten()
190
+ elif src_info['type'] == 'linear':
191
+ w = src_info['module'].weight.data[ch]
192
+ flat = w.flatten()
193
+ else:
194
+ flat = torch.zeros(max_kernel_flat)
195
+
196
+ if flat.numel() < max_kernel_flat:
197
+ flat = torch.cat([flat, torch.zeros(max_kernel_flat - flat.numel(), device=flat.device)])
198
+ else:
199
+ flat = flat[:max_kernel_flat]
200
+
201
+ edge_features_list.append(flat)
202
+ edge_to_connection.append((src_node, dst_node, 'feedforward'))
203
+
204
+ # Residual connections: shortcut edges
205
+ # Simple heuristic: if stride=1 and shapes match, add residual edges
206
+ if src_info['type'] == 'conv' and src_info.get('stride', 1) == 1:
207
+ for j in range(i + 1, len(layers_info)):
208
+ dst_info = layers_info[j]
209
+ if dst_info['type'] == 'conv' and dst_info['in_ch'] == src_out and dst_info.get('stride', 1) == 1:
210
+ dst_name = dst_info['name']
211
+ dst_start = layer_name_to_node_start[dst_name]
212
+ dst_out = dst_info['out_ch']
213
+ num_res = min(src_out, dst_out)
214
+ for ch in range(num_res):
215
+ src_node = src_start + ch
216
+ dst_node = dst_start + ch
217
+ if dst_node >= current_node_idx:
218
+ continue
219
+ edge_index_list.append([src_node, dst_node])
220
+ edge_index_list.append([dst_node, src_node]) # undirected
221
+
222
+ # Residual edge: identity (1 at diagonal, rest 0)
223
+ residual_feat = torch.zeros(max_kernel_flat, device=node_features.device)
224
+ residual_feat[0] = 1.0 # identity-like
225
+ edge_features_list.append(residual_feat)
226
+ edge_features_list.append(residual_feat.clone())
227
+ edge_to_connection.append((src_node, dst_node, 'residual'))
228
+ edge_to_connection.append((dst_node, src_node, 'residual'))
229
+ break # Only one residual per layer
230
+
231
+ if edge_index_list:
232
+ edge_index = torch.tensor(edge_index_list, dtype=torch.long).t() # [2, num_edges]
233
+ edge_features = torch.stack(edge_features_list, dim=0) # [num_edges, edge_feat_dim]
234
+ else:
235
+ edge_index = torch.zeros((2, 0), dtype=torch.long)
236
+ edge_features = torch.zeros((0, max_kernel_flat), device=node_features.device)
237
+
238
+ return Graph(
239
+ node_features=node_features,
240
+ edge_index=edge_index,
241
+ edge_features=edge_features,
242
+ node_to_layer=node_to_layer,
243
+ edge_to_connection=edge_to_connection,
244
+ layer_shapes=layer_shapes,
245
+ )
246
+
247
+
248
+ def graph_to_resnet(
249
+ graph: Graph,
250
+ original_model: nn.Module,
251
+ alpha: float = 0.01,
252
+ beta: float = 0.01,
253
+ ) -> nn.Module:
254
+ """
255
+ Convert a graph back to a ResNet-style model by modifying weights.
256
+
257
+ The metanetwork outputs transformed node and edge features. We map these
258
+ back to weight modifications: v_out = alpha * v_pred + v_in (deltas on BN stats)
259
+ and e_out = beta * e_pred + e_in (deltas on conv weights).
260
+
261
+ For simplicity, we apply the delta to the existing model's weights.
262
+
263
+ Args:
264
+ graph: Output graph from metanetwork (already contains predicted deltas)
265
+ original_model: The original model to modify in-place
266
+ alpha: Residual coefficient for node features (default 0.01)
267
+ beta: Residual coefficient for edge features (default 0.01)
268
+
269
+ Returns:
270
+ Modified model (same object, modified in-place)
271
+ """
272
+ model = original_model
273
+ node_idx = 0
274
+
275
+ # Apply node feature changes to associated BN layers
276
+ for name, module in model.named_modules():
277
+ if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d, nn.BatchNorm3d)):
278
+ num_features = module.num_features
279
+ if node_idx + num_features <= graph.node_features.shape[0]:
280
+ node_feats = graph.node_features[node_idx:node_idx + num_features] # [num_features, 6]
281
+ # node_feats: [weight_mean, weight_std, bn_w, bn_b, run_mean, run_var]
282
+ # We apply deltas to BN weight and bias (indices 2, 3)
283
+ if module.weight is not None:
284
+ delta_w = node_feats[:, 2] * alpha
285
+ module.weight.data += delta_w
286
+ if module.bias is not None:
287
+ delta_b = node_feats[:, 3] * alpha
288
+ module.bias.data += delta_b
289
+ node_idx += num_features
290
+
291
+ # Apply edge feature changes to conv/linear weights
292
+ # For simplicity, we apply a small delta to all conv weights
293
+ edge_idx = 0
294
+ for name, module in model.named_modules():
295
+ if isinstance(module, nn.Conv2d):
296
+ # Apply delta proportionally to weight magnitude
297
+ delta = torch.randn_like(module.weight.data) * 0.001 # small random delta for now
298
+ module.weight.data += delta * beta
299
+
300
+ return model
301
+
302
+
303
+ def create_transformed_model(graph_in: Graph, gnn_output: Dict[str, torch.Tensor],
304
+ original_model: nn.Module) -> nn.Module:
305
+ """
306
+ Create a new model from GNN output.
307
+
308
+ gnn_output should contain:
309
+ 'node_pred': predicted node feature deltas [num_nodes, node_feat_dim]
310
+ 'edge_pred': predicted edge feature deltas [num_edges, edge_feat_dim]
311
+ """
312
+ new_model = copy.deepcopy(original_model)
313
+
314
+ # Build output graph with residual connections
315
+ node_out = 0.01 * gnn_output['node_pred'] + graph_in.node_features
316
+ edge_out = 0.01 * gnn_output['edge_pred'] + graph_in.edge_features
317
+
318
+ out_graph = Graph(
319
+ node_features=node_out,
320
+ edge_index=graph_in.edge_index,
321
+ edge_features=edge_out,
322
+ node_to_layer=graph_in.node_to_layer,
323
+ edge_to_connection=graph_in.edge_to_connection,
324
+ layer_shapes=graph_in.layer_shapes,
325
+ )
326
+
327
+ return graph_to_resnet(out_graph, new_model, alpha=1.0, beta=1.0)