| """ |
MetaPruning Inference Pipeline.

1. Load the trained metanetwork.
2. Take a target model (this script supports ResNet-56 and ResNet-110).
3. Convert the model to a graph -> run the metanetwork forward pass -> convert the graph back into a model.
4. Finetune the transformed model.
5. Prune using a magnitude-based criterion (L2 norm of conv filters).
6. Evaluate.

Paper: "Meta Pruning via Graph Metanetworks" (arXiv:2506.12041)
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from datasets import load_dataset |
| from torchvision import transforms |
| import argparse |
| import os |
|
|
| from graph import resnet_to_graph, create_transformed_model, Graph |
| from gnn import Metanetwork |
|
|
|
|
| |
| |
| |
|
|
class _BatchImageTransform:
    """Picklable callable that applies an image transform to a Hugging Face batch.

    Defined at module level (rather than as a closure) so the dataset can be
    pickled into ``DataLoader`` worker processes on spawn-based platforms.
    """

    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, examples):
        # ``examples`` is a dict of lists with the raw "img" and "label" columns.
        return {
            "pixel_values": [
                self.image_transform(img.convert("RGB")) for img in examples["img"]
            ],
            "labels": examples["label"],
        }


def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Build CIFAR-10 train/test ``DataLoader``s from the ``uoft-cs/cifar10`` dataset.

    Transforms are attached with ``with_transform`` so they run lazily each
    time an example is fetched. Running them inside ``Dataset.map`` (as this
    function previously did) evaluates ``RandomCrop``/``RandomHorizontalFlip``
    exactly once and caches the result, so every epoch would see the same
    "augmented" images.

    Args:
        batch_size: Number of examples per batch for both loaders.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        ``(train_loader, test_loader)``. Each batch is a dict with
        ``"pixel_values"`` (normalised image tensors) and ``"labels"``.
    """
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_transform(
        _BatchImageTransform(transform_train)
    )
    ds_test = load_dataset("uoft-cs/cifar10", split="test").with_transform(
        _BatchImageTransform(transform_test)
    )

    # Pinned host memory only helps (and only avoids a warning) when CUDA exists.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        ds_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        ds_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, test_loader
|
|
|
|
| |
| |
| |
|
|
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_kwargs = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": False,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
|
|
|
class BasicBlock(nn.Module):
    """Two-convolution residual block (conv-BN-ReLU, conv-BN, add shortcut, ReLU).

    The shortcut is an empty ``nn.Sequential`` (identity) unless the stride or
    the channel count changes, in which case it is a 1x1 convolution followed
    by batch normalisation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        main_path = self.bn2(self.conv2(hidden))
        main_path += self.shortcut(x)
        return F.relu(main_path)
|
|
|
|
class ResNet(nn.Module):
    """Three-stage ResNet with 16/32/64 base channels and a linear classifier.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute and
            accept ``(in_planes, planes, stride)``.
        num_blocks: Sequence of three ints, the block count for each stage.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running input width for the next block; advanced by _make_layer.
        self.in_planes = 16

        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks into one stage; only the first block uses ``stride``."""
        stage_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in stage_strides:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # Average-pool with a kernel equal to the feature-map width.
        feats = F.avg_pool2d(feats, feats.size(3))
        flat = feats.view(feats.size(0), -1)
        return self.linear(flat)
|
|
|
|
def ResNet56(num_classes=10):
    """Build a ``ResNet`` with nine ``BasicBlock``s in each of its three stages."""
    blocks_per_stage = [9] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
|
|
|
|
def ResNet110(num_classes=10):
    """Build a ``ResNet`` with eighteen ``BasicBlock``s in each of its three stages."""
    blocks_per_stage = [18] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
|
|
|
|
| |
| |
| |
|
|
def train_epoch(model, loader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Module to optimise; switched to training mode.
        loader: Iterable of dict batches with ``"pixel_values"`` and ``"labels"``.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all examples seen, where the
        per-batch loss is weighted by batch size.
    """
    model.train()

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item() * inputs.size(0)
        predictions = logits.max(dim=1).indices
        num_seen += targets.size(0)
        num_correct += (predictions == targets).sum().item()

    mean_loss = loss_sum / num_seen
    accuracy = 100.0 * num_correct / num_seen
    return mean_loss, accuracy
|
|
|
|
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Args:
        model: Module to evaluate; switched to eval mode.
        loader: Iterable of dict batches with ``"pixel_values"`` and ``"labels"``.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all examples seen, where the
        per-batch loss is weighted by batch size.
    """
    with torch.no_grad():
        model.eval()

        loss_sum = 0.0
        num_correct = 0
        num_seen = 0

        for batch in loader:
            inputs = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            loss_sum += batch_loss.item() * inputs.size(0)
            predictions = logits.max(dim=1).indices
            num_seen += targets.size(0)
            num_correct += (predictions == targets).sum().item()

        return loss_sum / num_seen, 100.0 * num_correct / num_seen
|
|
|
|
| |
| |
| |
|
|
def prune_model(model, target_sparsity=0.5):
    """Zero out the lowest-magnitude conv output channels, in place.

    Channels from every ``nn.Conv2d`` in ``model`` are ranked together by the
    L2 norm of their filter weights; the ``int(target_sparsity * N)`` channels
    with the smallest norm (``N`` = total conv output channels) have their
    weights, and bias if present, set to zero. Channels whose norm ties with
    the cut-off are pruned as well, so slightly more than the target may be
    removed. The ranking is global, so nothing prevents every channel of a
    single layer from being zeroed.

    This only masks weights; tensor shapes are unchanged. For structural
    pruning, use torch-pruning with DepGraph.

    Args:
        model: Module whose ``nn.Conv2d`` layers are pruned in place.
        target_sparsity: Fraction of conv output channels to zero, in [0, 1].

    Raises:
        ValueError: If ``target_sparsity`` is outside [0, 1].
    """
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError(
            f"target_sparsity must be in [0, 1], got {target_sparsity}"
        )

    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    # flatten(1) gives one row per output channel, even for non-contiguous weights.
    channel_norms = [
        layer.weight.detach().flatten(1).norm(dim=1) for layer in conv_layers
    ]

    total_channels = sum(norms.numel() for norms in channel_norms)
    num_to_prune = min(int(target_sparsity * total_channels), total_channels)

    if num_to_prune > 0:
        sorted_norms = torch.sort(torch.cat(channel_norms)).values
        # The cut-off is the LARGEST norm that must go (index k-1). Using
        # index k with a strict ">" keep-test pruned one channel too many
        # and indexed out of range at target_sparsity == 1.0.
        threshold = sorted_norms[num_to_prune - 1].item()

        with torch.no_grad():
            for layer, norms in zip(conv_layers, channel_norms):
                pruned = ~(norms > threshold)
                if pruned.any():
                    layer.weight[pruned] = 0
                    if layer.bias is not None:
                        layer.bias[pruned] = 0

    print(f"[Prune] Applied simple magnitude pruning (target={target_sparsity:.2f})")
|
|
|
|
def compute_model_sparsity(model):
    """Return the fraction of parameter elements in ``model`` that equal zero.

    All tensors yielded by ``model.parameters()`` are counted (conv weights,
    batch-norm affine terms, linear layers, ...).

    Returns:
        A float in [0, 1]; 0.0 when the model has no parameters at all
        (previously this case raised ``ZeroDivisionError``).
    """
    total = 0
    zeros = 0
    for param in model.parameters():
        total += param.numel()
        zeros += (param.detach() == 0).sum().item()

    if total == 0:
        return 0.0
    return zeros / total
|
|
|
|
| |
| |
| |
|
|
def main():
    """Run the MetaPruning inference pipeline from the command line.

    Steps: load CIFAR-10 and the target ResNet, load the trained metanetwork,
    convert the model to a graph, run the metanetwork on it, rebuild a model
    from the output, finetune, apply magnitude pruning, and print a summary.

    Compared with the earlier version of this entry point:
      * ``--target_checkpoint`` (optional) loads weights into the target
        model; without it the target is randomly initialised and a warning
        is printed.
      * The weights of the best finetuning epoch are restored before pruning
        so the reported "after finetuning" accuracy belongs to the model that
        is actually pruned.
      * ``--prune_sparsity`` is validated before any training starts.
    """
    parser = argparse.ArgumentParser(description="MetaPruning Inference Pipeline")
    parser.add_argument("--metanetwork_path", type=str, required=True,
                        help="Path to trained metanetwork checkpoint")
    parser.add_argument("--target_model", type=str, default="resnet56",
                        choices=["resnet56", "resnet110"])
    parser.add_argument("--target_checkpoint", type=str, default=None,
                        help="Optional path to pretrained weights for the target model")
    parser.add_argument("--finetune_epochs", type=int, default=100,
                        help="Finetune epochs after metanetwork (paper uses 100-200)")
    parser.add_argument("--prune_sparsity", type=float, default=0.5,
                        help="Target pruning sparsity")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--milestones", type=int, nargs="+", default=[60, 90])
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    # Fail fast: a bad sparsity would otherwise only surface after finetuning.
    if not 0.0 <= args.prune_sparsity <= 1.0:
        parser.error("--prune_sparsity must be in [0, 1]")

    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Data
    train_loader, test_loader = get_cifar10_loaders(args.batch_size, args.num_workers)

    # Target model
    if args.target_model == "resnet56":
        model = ResNet56(num_classes=10).to(device)
    else:
        model = ResNet110(num_classes=10).to(device)

    if args.target_checkpoint:
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        target_state = torch.load(args.target_checkpoint, map_location=device)
        # NOTE(review): the checkpoint layout is not defined in this file. A bare
        # state_dict is accepted, as is a dict wrapping it under one of the keys
        # below - confirm against the script that saved the checkpoint.
        if isinstance(target_state, dict):
            for key in ("model_state_dict", "state_dict"):
                if key in target_state:
                    target_state = target_state[key]
                    break
        model.load_state_dict(target_state)
        print(f"Loaded target model: {args.target_model} "
              f"(weights from {args.target_checkpoint})")
    else:
        print(f"Loaded target model: {args.target_model}")
        print("[Warning] No --target_checkpoint given: the target model is randomly "
              "initialised, so the baseline accuracy will be near chance.")

    # Metanetwork
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(args.metanetwork_path, map_location=device)
    config = ckpt["config"]

    metanetwork = Metanetwork(
        node_in_dim=config["node_in_dim"],
        edge_in_dim=config["edge_in_dim"],
        node_out_dim=config["node_out_dim"],
        edge_out_dim=config["edge_out_dim"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        alpha=config["alpha"],
        beta=config["beta"],
    ).to(device)
    metanetwork.load_state_dict(ckpt["metanetwork_state_dict"])
    metanetwork.eval()
    print(f"Loaded metanetwork (hidden_dim={config['hidden_dim']}, layers={config['num_layers']})")

    # Baseline
    criterion = nn.CrossEntropyLoss()
    _, base_acc = evaluate(model, test_loader, criterion, device)
    print(f"\nBaseline model accuracy (before metanetwork): {base_acc:.2f}%")

    # Step 1: model -> graph
    print("\n[Step 1] Converting model to graph...")
    graph = resnet_to_graph(model, max_kernel_size=3)
    print(f"  Nodes: {graph.node_features.size(0)}, Edges: {graph.edge_features.size(0)}")

    # Step 2: metanetwork forward pass (inference only, so no autograd graph)
    print("[Step 2] Metanetwork feedforward...")
    with torch.no_grad():
        graph.node_features = graph.node_features.to(device)
        graph.edge_features = graph.edge_features.to(device)
        graph.edge_index = graph.edge_index.to(device)

        gnn_output = metanetwork(
            graph.node_features,
            graph.edge_index,
            graph.edge_features,
        )

    # Step 3: graph -> model
    print("[Step 3] Converting transformed graph back to model...")
    transformed_model = create_transformed_model(graph, gnn_output, model).to(device)

    _, meta_acc = evaluate(transformed_model, test_loader, criterion, device)
    print(f"  Accuracy after metanetwork (before finetune): {meta_acc:.2f}%")

    # Step 4: finetune
    print(f"\n[Step 4] Finetuning for {args.finetune_epochs} epochs...")
    optimizer = optim.SGD(transformed_model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)

    # With zero finetune epochs the "best" model is simply the metanetwork output.
    best_acc = meta_acc
    best_state = None
    for epoch in range(args.finetune_epochs):
        _, train_acc = train_epoch(transformed_model, train_loader, optimizer, criterion, device)
        _, val_acc = evaluate(transformed_model, test_loader, criterion, device)
        scheduler.step()

        # The first epoch is always recorded so best_acc reflects finetuned weights.
        if best_state is None or val_acc > best_acc:
            best_acc = val_acc
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in transformed_model.state_dict().items()
            }

        is_last_epoch = epoch + 1 == args.finetune_epochs
        if (epoch + 1) % 20 == 0 or is_last_epoch:
            print(f"  Epoch {epoch+1:3d}: train_acc={train_acc:.2f}%, val_acc={val_acc:.2f}%")

    # Prune the weights that achieved best_acc, not whatever the last epoch left.
    if best_state is not None:
        transformed_model.load_state_dict(best_state)

    print(f"  Best finetuned accuracy: {best_acc:.2f}%")

    # Step 5: prune
    print(f"\n[Step 5] Pruning (target sparsity={args.prune_sparsity:.2f})...")
    prune_model(transformed_model, target_sparsity=args.prune_sparsity)
    sparsity = compute_model_sparsity(transformed_model)
    print(f"  Actual model sparsity: {sparsity:.4f}")

    # Step 6: evaluate the pruned model
    _, pruned_acc = evaluate(transformed_model, test_loader, criterion, device)
    print(f"  Accuracy after pruning: {pruned_acc:.2f}%")

    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    print(f"Baseline accuracy:       {base_acc:.2f}%")
    print(f"After metanetwork:       {meta_acc:.2f}%")
    print(f"After finetuning:        {best_acc:.2f}%")
    print(f"After pruning:           {pruned_acc:.2f}%")
    print(f"Sparsity:                {sparsity:.4f}")
    print(f"Accuracy drop:           {base_acc - pruned_acc:.2f}%")
    print("=" * 50)


if __name__ == "__main__":
    main()
|
|