"""
MetaPruning Inference Pipeline.

1. Load trained metanetwork
2. Take any target model
3. Convert -> metanetwork feedforward -> transform back
4. Finetune the transformed model
5. Prune using magnitude-based criterion
6. Evaluate

Paper: "Meta Pruning via Graph Metanetworks" (arXiv:2506.12041)
"""

import argparse
import functools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms

from gnn import Metanetwork
from graph import resnet_to_graph, create_transformed_model, Graph


# ---------------------------------------------------------------------------
# Data loading (CIFAR-10 via Hugging Face datasets)
# ---------------------------------------------------------------------------

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def _apply_image_transform(examples, image_transform):
    """Turn a batch of dataset rows into model inputs.

    Kept at module level (and bound with ``functools.partial``) so that the
    dataset stays picklable for DataLoader worker processes.

    Args:
        examples: Batch dict with an ``"img"`` list of PIL images and a
            ``"label"`` list of class indices.
        image_transform: Callable mapping one RGB PIL image to a tensor.

    Returns:
        Dict with ``"pixel_values"`` (list of tensors) and ``"labels"``.
    """
    return {
        "pixel_values": [image_transform(img.convert("RGB")) for img in examples["img"]],
        "labels": examples["label"],
    }


def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Build the CIFAR-10 train and test DataLoaders.

    The image transforms are attached with ``with_transform`` so they run
    every time a sample is fetched. Using ``Dataset.map`` here would evaluate
    the random crop/flip exactly once and cache the result, which silently
    disables data augmentation for all later epochs.

    Args:
        batch_size: Mini-batch size for both loaders.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(train_loader, test_loader)``; each batch is a dict with keys
        ``"pixel_values"`` and ``"labels"``.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_transform(
        functools.partial(_apply_image_transform, image_transform=transform_train)
    )
    ds_test = load_dataset("uoft-cs/cifar10", split="test").with_transform(
        functools.partial(_apply_image_transform, image_transform=transform_test)
    )

    # Pinned memory only helps host->GPU copies; requesting it without CUDA
    # just produces a warning.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, test_loader


# ---------------------------------------------------------------------------
# Model definitions (from training script)
# ---------------------------------------------------------------------------

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a residual shortcut."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 conv + BatchNorm projects the input.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """Three-stage ResNet (16/32/64 channels) with a linear classifier."""

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16

        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Pool with a window equal to the feature-map width.
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet56(num_classes=10):
    """ResNet with 9 BasicBlocks per stage."""
    return ResNet(BasicBlock, [9, 9, 9], num_classes=num_classes)


def ResNet110(num_classes=10):
    """ResNet with 18 BasicBlocks per stage."""
    return ResNet(BasicBlock, [18, 18, 18], num_classes=num_classes)


# ---------------------------------------------------------------------------
# Training helpers
# ---------------------------------------------------------------------------

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader``.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        inputs, targets = batch["pixel_values"].to(device), batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return total_loss / total, 100.0 * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        inputs, targets = batch["pixel_values"].to(device), batch["labels"].to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return total_loss / total, 100.0 * correct / total


# ---------------------------------------------------------------------------
# Simple magnitude-based structured pruning
# ---------------------------------------------------------------------------

def prune_model(model, target_sparsity=0.5):
    """Zero out the lowest-magnitude conv output channels in place.

    Channels from every ``nn.Conv2d`` are ranked together by the L2 norm of
    their filter weights, and the ``int(target_sparsity * num_channels)``
    smallest ones are set to zero (weights and, if present, bias). Channels
    whose norm ties with the cut-off are pruned as well, so slightly more
    than the requested number can be removed. Tensor shapes are unchanged.

    For a proper implementation, use torch-pruning with DepGraph.

    Args:
        model: Module whose conv layers are pruned in place.
        target_sparsity: Fraction of conv output channels to zero, in [0, 1].

    Raises:
        ValueError: If ``target_sparsity`` is outside [0, 1].
    """
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError(
            f"target_sparsity must be in [0, 1], got {target_sparsity}"
        )

    conv_layers = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, nn.Conv2d)
    ]

    # L2 norm per output channel for each conv layer. ``reshape`` (not
    # ``view``) so non-contiguous weights are handled too.
    channel_norms = {
        name: module.weight.data.reshape(module.out_channels, -1).norm(dim=1)
        for name, module in conv_layers
    }

    num_channels = sum(norms.numel() for norms in channel_norms.values())
    num_prune = int(target_sparsity * num_channels)

    if num_prune > 0:
        # Global threshold = norm of the num_prune-th smallest channel.
        # Indexing with ``num_prune`` (instead of ``num_prune - 1``) would
        # prune one channel too many and overflow at target_sparsity == 1.
        all_norms = torch.cat(list(channel_norms.values()))
        threshold = torch.sort(all_norms)[0][num_prune - 1].item()

        for name, module in conv_layers:
            prune_mask = channel_norms[name] <= threshold
            if prune_mask.any():
                module.weight.data[prune_mask] = 0
                if module.bias is not None:
                    module.bias.data[prune_mask] = 0

    print(f"[Prune] Applied simple magnitude pruning (target={target_sparsity:.2f})")


def compute_model_sparsity(model):
    """Return the fraction of exactly-zero entries across all parameters.

    Returns 0.0 for a model without parameters.
    """
    total = 0
    zeros = 0
    for p in model.parameters():
        total += p.numel()
        zeros += (p.data == 0).sum().item()
    if total == 0:
        return 0.0
    return zeros / total


# ---------------------------------------------------------------------------
# Main inference pipeline
# ---------------------------------------------------------------------------

def _extract_state_dict(checkpoint):
    """Return the model state dict stored in ``checkpoint``.

    Accepts either a bare state dict or a dict wrapping it under a common
    key. NOTE(review): the wrapper key names are assumptions - confirm
    against the script that saved the target model.
    """
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "state_dict"):
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint


def main():
    """Run the full pipeline: load, transform, finetune, prune, report."""
    parser = argparse.ArgumentParser(description="MetaPruning Inference Pipeline")
    parser.add_argument("--metanetwork_path", type=str, required=True,
                        help="Path to trained metanetwork checkpoint")
    parser.add_argument("--target_model", type=str, default="resnet56",
                        choices=["resnet56", "resnet110"])
    parser.add_argument("--target_checkpoint", type=str, default=None,
                        help="Optional path to trained weights for the target model")
    parser.add_argument("--finetune_epochs", type=int, default=100,
                        help="Finetune epochs after metanetwork (paper uses 100-200)")
    parser.add_argument("--prune_sparsity", type=float, default=0.5,
                        help="Target pruning sparsity")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--milestones", type=int, nargs="+", default=[60, 90])
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load data
    train_loader, test_loader = get_cifar10_loaders(args.batch_size, args.num_workers)

    # Load target model
    if args.target_model == "resnet56":
        model = ResNet56(num_classes=10).to(device)
    else:
        model = ResNet110(num_classes=10).to(device)
    if args.target_checkpoint is not None:
        target_ckpt = torch.load(args.target_checkpoint, map_location=device)
        model.load_state_dict(_extract_state_dict(target_ckpt))
    else:
        # Without trained weights the baseline and the metanetwork input are
        # just a random initialization.
        print("WARNING: no --target_checkpoint given; target model is randomly initialized")
    print(f"Loaded target model: {args.target_model}")

    # Load metanetwork
    ckpt = torch.load(args.metanetwork_path, map_location=device)
    config = ckpt["config"]
    metanetwork = Metanetwork(
        node_in_dim=config["node_in_dim"],
        edge_in_dim=config["edge_in_dim"],
        node_out_dim=config["node_out_dim"],
        edge_out_dim=config["edge_out_dim"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        alpha=config["alpha"],
        beta=config["beta"],
    ).to(device)
    metanetwork.load_state_dict(ckpt["metanetwork_state_dict"])
    metanetwork.eval()
    print(f"Loaded metanetwork (hidden_dim={config['hidden_dim']}, layers={config['num_layers']})")

    # Baseline: evaluate untransformed model
    criterion = nn.CrossEntropyLoss()
    _, base_acc = evaluate(model, test_loader, criterion, device)
    print(f"\nBaseline model accuracy (before metanetwork): {base_acc:.2f}%")

    # Step 1: Convert to graph
    print("\n[Step 1] Converting model to graph...")
    graph = resnet_to_graph(model, max_kernel_size=3)
    print(f" Nodes: {graph.node_features.size(0)}, Edges: {graph.edge_features.size(0)}")

    # Step 2: Feed through metanetwork
    print("[Step 2] Metanetwork feedforward...")
    with torch.no_grad():
        graph.node_features = graph.node_features.to(device)
        graph.edge_features = graph.edge_features.to(device)
        graph.edge_index = graph.edge_index.to(device)
        gnn_output = metanetwork(
            graph.node_features,
            graph.edge_index,
            graph.edge_features,
        )

    # Step 3: Convert back to transformed model
    print("[Step 3] Converting transformed graph back to model...")
    transformed_model = create_transformed_model(graph, gnn_output, model).to(device)

    # Evaluate after metanetwork (before finetuning)
    _, meta_acc = evaluate(transformed_model, test_loader, criterion, device)
    print(f" Accuracy after metanetwork (before finetune): {meta_acc:.2f}%")

    # Step 4: Finetune transformed model
    print(f"\n[Step 4] Finetuning for {args.finetune_epochs} epochs...")
    optimizer = optim.SGD(transformed_model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones,
                                               gamma=0.1)

    # NOTE: "best" is selected on the test split, so it is an optimistic
    # number, and the weights pruned in Step 5 are the last-epoch weights,
    # not the best-epoch ones.
    best_acc = 0.0
    for epoch in range(args.finetune_epochs):
        train_loss, train_acc = train_epoch(transformed_model, train_loader,
                                            optimizer, criterion, device)
        val_loss, val_acc = evaluate(transformed_model, test_loader, criterion, device)
        scheduler.step()

        if val_acc > best_acc:
            best_acc = val_acc

        if (epoch + 1) % 20 == 0:
            print(f" Epoch {epoch+1:3d}: train_acc={train_acc:.2f}%, val_acc={val_acc:.2f}%")

    print(f" Best finetuned accuracy: {best_acc:.2f}%")

    # Step 5: Prune
    print(f"\n[Step 5] Pruning (target sparsity={args.prune_sparsity:.2f})...")
    prune_model(transformed_model, target_sparsity=args.prune_sparsity)
    sparsity = compute_model_sparsity(transformed_model)
    print(f" Actual model sparsity: {sparsity:.4f}")

    # Evaluate pruned model
    _, pruned_acc = evaluate(transformed_model, test_loader, criterion, device)
    print(f" Accuracy after pruning: {pruned_acc:.2f}%")

    # Summary
    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    print(f"Baseline accuracy: {base_acc:.2f}%")
    print(f"After metanetwork: {meta_acc:.2f}%")
    print(f"After finetuning: {best_acc:.2f}%")
    print(f"After pruning: {pruned_acc:.2f}%")
    print(f"Sparsity: {sparsity:.4f}")
    print(f"Accuracy drop: {base_acc - pruned_acc:.2f}%")
    print("=" * 50)


if __name__ == "__main__":
    main()