# Uploaded by ESPR3SS0 (commit 4d76378, verified): "Add metapruning/inference.py"
"""
MetaPruning Inference Pipeline.
1. Load trained metanetwork
2. Take any target model
3. Convert -> metanetwork feedforward -> transform back
4. Finetune the transformed model
5. Prune using magnitude-based criterion
6. Evaluate
Paper: "Meta Pruning via Graph Metanetworks" (arXiv:2506.12041)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision import transforms
import argparse
import os
from graph import resnet_to_graph, create_transformed_model, Graph
from gnn import Metanetwork
# ---------------------------------------------------------------------------
# Data loading (same as training script)
# ---------------------------------------------------------------------------
# CIFAR-10 per-channel normalisation statistics (mean, std), shared by both splits.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


class _CifarBatchTransform:
    """Apply a torchvision image transform to a batch of HF ``datasets`` rows.

    Defined at module level (not as a closure) so the dataset remains picklable
    when DataLoader workers are started with the ``spawn`` start method.
    """

    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, examples):
        images = [self.image_transform(img.convert("RGB")) for img in examples["img"]]
        return {"pixel_values": images, "labels": examples["label"]}


def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Build CIFAR-10 train/test DataLoaders from the HF ``uoft-cs/cifar10`` dataset.

    Transforms are attached lazily with ``Dataset.with_transform``. The random
    crop / horizontal flip is therefore re-sampled every time an image is
    fetched, i.e. every epoch. Running it once through ``Dataset.map`` would
    freeze one augmented copy per image and cache the float tensors on disk.

    Args:
        batch_size: Samples per batch for both loaders.
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(train_loader, test_loader)``. Each batch is a dict with
        ``"pixel_values"`` (normalised image tensors) and ``"labels"``.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
    ])

    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_transform(
        _CifarBatchTransform(transform_train)
    )
    ds_test = load_dataset("uoft-cs/cifar10", split="test").with_transform(
        _CifarBatchTransform(transform_test)
    )

    # Pinned host memory only speeds up host->GPU copies; on CPU-only machines
    # it just triggers a warning.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, test_loader
# ---------------------------------------------------------------------------
# Model definitions (from training script)
# ---------------------------------------------------------------------------
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With ``stride=1`` the spatial size is preserved.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
class BasicBlock(nn.Module):
    """Residual block: conv3x3-BN-ReLU-conv3x3-BN, plus a shortcut, then ReLU.

    The shortcut is the identity, unless the stride or channel count changes.
    In that case it is a 1x1 strided conv followed by BatchNorm (projection).
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        # Submodules are created in this exact order so parameter init (RNG)
        # and state_dict keys stay stable.
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden)) + self.shortcut(x)
        return F.relu(residual)
class ResNet(nn.Module):
    """CIFAR-style ResNet.

    Layout: 3x3 stem, three stages of 16/32/64 planes, global average pooling,
    then a linear classifier.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an
            ``expansion`` attribute.
        num_blocks: Number of blocks in each of the three stages.
        num_classes: Output size of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: updates ``self.in_planes`` so the next stage is wired correctly.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling to 1x1 over both spatial dims. A fixed kernel
        # equal to the width would fail or drop rows on non-square feature maps.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)
def ResNet56(num_classes=10):
    """Build the 56-layer CIFAR ResNet: three stages of nine ``BasicBlock``s each."""
    blocks_per_stage = [9] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
def ResNet110(num_classes=10):
    """Build the 110-layer CIFAR ResNet: three stages of eighteen ``BasicBlock``s each."""
    blocks_per_stage = [18] * 3
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
# ---------------------------------------------------------------------------
# Training helpers
# ---------------------------------------------------------------------------
def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Each batch must be a mapping with ``"pixel_values"`` and ``"labels"``.

    Returns:
        ``(loss, accuracy)``: per-sample mean loss, and accuracy in percent.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch value is a
        # per-sample mean even when the last batch is smaller.
        loss_sum += batch_loss.item() * inputs.size(0)
        predictions = logits.max(dim=1).indices
        n_seen += targets.size(0)
        n_correct += predictions.eq(targets).sum().item()
    return loss_sum / n_seen, 100.0 * n_correct / n_seen
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` in eval mode without gradient tracking.

    Each batch must be a mapping with ``"pixel_values"`` and ``"labels"``.

    Returns:
        ``(loss, accuracy)``: per-sample mean loss, and accuracy in percent.
    """
    model.eval()
    running_loss = 0.0
    hits = 0
    count = 0
    for batch in loader:
        x = batch["pixel_values"].to(device)
        y = batch["labels"].to(device)
        logits = model(x)
        running_loss += criterion(logits, y).item() * x.size(0)
        count += y.size(0)
        hits += logits.max(dim=1).indices.eq(y).sum().item()
    return running_loss / count, 100.0 * hits / count
# ---------------------------------------------------------------------------
# Simple magnitude-based structured pruning
# ---------------------------------------------------------------------------
def _conv_to_following_bn(model):
    """Map each Conv2d's qualified name to the BatchNorm2d registered right after it.

    Pairing relies on ``named_modules`` registration order. A BatchNorm2d is
    paired with the most recent unpaired Conv2d when its ``num_features``
    equals that conv's ``out_channels``. This matches the conv->bn layout of
    the ResNet/BasicBlock definitions in this file. NOTE(review): other
    architectures (or a differently built transformed model) may need an
    explicit mapping.
    """
    pairs = {}
    pending = None  # (name, conv) still waiting for its BatchNorm
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            pending = (name, module)
        elif isinstance(module, nn.BatchNorm2d) and pending is not None:
            conv_name, conv = pending
            if module.num_features == conv.out_channels:
                pairs[conv_name] = module
            pending = None
    return pairs


def prune_model(model, target_sparsity=0.5, *, prune_bn=True):
    """Zero out the globally weakest conv output channels, in place.

    Each Conv2d output channel is scored by the L2 norm of its filter weights.
    The ``int(target_sparsity * N)`` lowest-scoring channels across all convs
    (N = total number of conv output channels) are masked to zero. This
    simulates structured pruning without changing tensor shapes. For real
    channel removal, use torch-pruning with DepGraph.

    Args:
        model: Module whose ``nn.Conv2d`` layers are pruned in place.
        target_sparsity: Fraction of conv output channels to prune, in ``[0, 1]``.
        prune_bn: Also zero the affine weight/bias of the BatchNorm2d following
            each conv for pruned channels. Without this, BN's shift keeps
            emitting a constant on a "pruned" channel.

    Raises:
        ValueError: If ``target_sparsity`` is outside ``[0, 1]``.
    """
    if not 0.0 <= target_sparsity <= 1.0:
        raise ValueError(f"target_sparsity must be in [0, 1], got {target_sparsity}")

    convs = [(name, m) for name, m in model.named_modules() if isinstance(m, nn.Conv2d)]
    if not convs:
        print("[Prune] No Conv2d layers found; nothing to prune")
        return

    with torch.no_grad():
        norms = [conv.weight.reshape(conv.out_channels, -1).norm(dim=1) for _, conv in convs]
        all_norms = torch.cat(norms)
        num_prune = int(target_sparsity * all_norms.numel())

        # Select exactly the num_prune smallest channels. A "<= k-th value"
        # threshold would be off by one and over-prune on ties.
        prune_flat = torch.zeros_like(all_norms, dtype=torch.bool)
        if num_prune > 0:
            prune_flat[torch.argsort(all_norms)[:num_prune]] = True
        masks = torch.split(prune_flat, [conv.out_channels for _, conv in convs])

        bn_for_conv = _conv_to_following_bn(model) if prune_bn else {}
        for (name, conv), mask in zip(convs, masks):
            if not mask.any():
                continue
            conv.weight[mask] = 0
            if conv.bias is not None:
                conv.bias[mask] = 0
            bn = bn_for_conv.get(name)
            if bn is not None:
                if bn.weight is not None:
                    bn.weight[mask] = 0
                if bn.bias is not None:
                    bn.bias[mask] = 0

    print(f"[Prune] Applied simple magnitude pruning (target={target_sparsity:.2f})")
def compute_model_sparsity(model):
    """Return the fraction of parameter entries in ``model`` that are exactly zero.

    Only ``model.parameters()`` are counted; buffers (e.g. BatchNorm running
    statistics) are ignored. Returns 0.0 for a model without parameters
    instead of dividing by zero.
    """
    total = 0
    zeros = 0
    for p in model.parameters():
        total += p.numel()
        zeros += (p.detach() == 0).sum().item()
    return zeros / total if total else 0.0
# ---------------------------------------------------------------------------
# Main inference pipeline
# ---------------------------------------------------------------------------
def _parse_args():
    """Define and parse the command-line options of the inference pipeline."""
    parser = argparse.ArgumentParser(description="MetaPruning Inference Pipeline")
    parser.add_argument("--metanetwork_path", type=str, required=True,
                        help="Path to trained metanetwork checkpoint")
    parser.add_argument("--target_model", type=str, default="resnet56",
                        choices=["resnet56", "resnet110"])
    parser.add_argument("--target_checkpoint", type=str, default=None,
                        help="Path to trained weights for the target model (a raw "
                             "state_dict, or a dict holding one under 'model_state_dict' "
                             "or 'state_dict'). Without it the target model is randomly "
                             "initialised.")
    parser.add_argument("--finetune_epochs", type=int, default=100,
                        help="Finetune epochs after metanetwork (paper uses 100-200)")
    parser.add_argument("--prune_sparsity", type=float, default=0.5,
                        help="Target pruning sparsity")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--milestones", type=int, nargs="+", default=[60, 90])
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=42)
    return parser.parse_args()


def _build_target_model(name, checkpoint_path, device):
    """Instantiate the requested CIFAR ResNet on ``device``; load weights if a path is given."""
    builders = {"resnet56": ResNet56, "resnet110": ResNet110}
    model = builders[name](num_classes=10).to(device)
    if checkpoint_path is None:
        print("WARNING: --target_checkpoint not given; the target model is randomly "
              "initialised, so baseline and pruning accuracies are not meaningful.")
        return model
    state = torch.load(checkpoint_path, map_location=device)
    # Accept either a bare state_dict or a training checkpoint that wraps one.
    if isinstance(state, dict):
        for key in ("model_state_dict", "state_dict"):
            if isinstance(state.get(key), dict):
                state = state[key]
                break
    model.load_state_dict(state)
    print(f"Loaded target weights from {checkpoint_path}")
    return model


def _load_metanetwork(path, device):
    """Rebuild the metanetwork from a checkpoint with ``config`` and ``metanetwork_state_dict``."""
    ckpt = torch.load(path, map_location=device)
    config = ckpt["config"]
    metanetwork = Metanetwork(
        node_in_dim=config["node_in_dim"],
        edge_in_dim=config["edge_in_dim"],
        node_out_dim=config["node_out_dim"],
        edge_out_dim=config["edge_out_dim"],
        hidden_dim=config["hidden_dim"],
        num_layers=config["num_layers"],
        alpha=config["alpha"],
        beta=config["beta"],
    ).to(device)
    metanetwork.load_state_dict(ckpt["metanetwork_state_dict"])
    metanetwork.eval()
    print(f"Loaded metanetwork (hidden_dim={config['hidden_dim']}, layers={config['num_layers']})")
    return metanetwork


def _finetune(model, train_loader, test_loader, criterion, args, device):
    """Finetune ``model`` in place with SGD + MultiStepLR.

    Returns:
        ``(best_acc, final_acc)`` test accuracies in percent, or
        ``(None, None)`` when ``args.finetune_epochs`` is 0.
    """
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)
    best_acc = None
    final_acc = None
    for epoch in range(1, args.finetune_epochs + 1):
        _, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
        _, val_acc = evaluate(model, test_loader, criterion, device)
        scheduler.step()
        final_acc = val_acc
        best_acc = val_acc if best_acc is None else max(best_acc, val_acc)
        # Log every 20 epochs, and always the last epoch.
        if epoch % 20 == 0 or epoch == args.finetune_epochs:
            print(f"  Epoch {epoch:3d}: train_acc={train_acc:.2f}%, val_acc={val_acc:.2f}%")
    return best_acc, final_acc


def main():
    """Run the pipeline: metanetwork transform -> finetune -> magnitude prune -> evaluate."""
    args = _parse_args()
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load data
    train_loader, test_loader = get_cifar10_loaders(args.batch_size, args.num_workers)

    # Load target model (and its trained weights, when provided)
    model = _build_target_model(args.target_model, args.target_checkpoint, device)
    print(f"Loaded target model: {args.target_model}")

    # Load metanetwork
    metanetwork = _load_metanetwork(args.metanetwork_path, device)

    # Baseline: evaluate untransformed model
    criterion = nn.CrossEntropyLoss()
    _, base_acc = evaluate(model, test_loader, criterion, device)
    print(f"\nBaseline model accuracy (before metanetwork): {base_acc:.2f}%")

    # Step 1: Convert to graph
    print("\n[Step 1] Converting model to graph...")
    graph = resnet_to_graph(model, max_kernel_size=3)
    print(f"  Nodes: {graph.node_features.size(0)}, Edges: {graph.edge_features.size(0)}")

    # Step 2: Feed through metanetwork
    print("[Step 2] Metanetwork feedforward...")
    with torch.no_grad():
        graph.node_features = graph.node_features.to(device)
        graph.edge_features = graph.edge_features.to(device)
        graph.edge_index = graph.edge_index.to(device)
        gnn_output = metanetwork(
            graph.node_features,
            graph.edge_index,
            graph.edge_features,
        )

    # Step 3: Convert back to transformed model
    print("[Step 3] Converting transformed graph back to model...")
    transformed_model = create_transformed_model(graph, gnn_output, model).to(device)
    _, meta_acc = evaluate(transformed_model, test_loader, criterion, device)
    print(f"  Accuracy after metanetwork (before finetune): {meta_acc:.2f}%")

    # Step 4: Finetune transformed model
    print(f"\n[Step 4] Finetuning for {args.finetune_epochs} epochs...")
    best_acc, final_acc = _finetune(transformed_model, train_loader, test_loader,
                                    criterion, args, device)
    if final_acc is None:
        # No finetuning happened: the model entering pruning is the metanetwork output.
        best_acc = final_acc = meta_acc
    # Pruning is applied to the final-epoch weights, so the final accuracy is
    # the pre-pruning reference; the best epoch is reported for information only.
    print(f"  Final finetuned accuracy: {final_acc:.2f}% (best epoch: {best_acc:.2f}%)")

    # Step 5: Prune
    print(f"\n[Step 5] Pruning (target sparsity={args.prune_sparsity:.2f})...")
    prune_model(transformed_model, target_sparsity=args.prune_sparsity)
    sparsity = compute_model_sparsity(transformed_model)
    print(f"  Actual model sparsity: {sparsity:.4f}")
    _, pruned_acc = evaluate(transformed_model, test_loader, criterion, device)
    print(f"  Accuracy after pruning: {pruned_acc:.2f}%")

    # Summary
    print("\n" + "=" * 50)
    print("SUMMARY")
    print("=" * 50)
    print(f"Baseline accuracy:  {base_acc:.2f}%")
    print(f"After metanetwork:  {meta_acc:.2f}%")
    print(f"After finetuning:   {final_acc:.2f}% (best epoch: {best_acc:.2f}%)")
    print(f"After pruning:      {pruned_acc:.2f}%")
    print(f"Sparsity:           {sparsity:.4f}")
    print(f"Accuracy drop:      {base_acc - pruned_acc:.2f}%")
    print("=" * 50)
if __name__ == "__main__":
    # Execute the pipeline only when run as a script, not when imported.
    main()