"""
PDP Training Script for CIFAR-10 with ResNet18.

Based on: PDP: Parameter-free Differentiable Pruning is All You Need
(NeurIPS 2023)
"""

import argparse
import json
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from pdp import PDPPruner


# ---------------------------------------------------------------------------
# CIFAR-10 adapted ResNet18
# ---------------------------------------------------------------------------

def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a residual shortcut."""

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a strided 1x1 conv + BN projects the input to match.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then ReLU."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet with a 3x3 stem (no max-pool) followed by four stages."""

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64

        # First conv adapted for 32x32 CIFAR-10
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use 1."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            # Track the running channel count for the next block's input.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # With 32x32 inputs the three stride-2 stages leave a 4x4 map, so a
        # 4x4 average pool reduces it to 1x1 before the linear classifier.
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18(num_classes=10):
    """Build a ResNet with ``BasicBlock`` and two blocks per stage."""
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------

# Per-channel normalization statistics passed to ``transforms.Normalize``.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2470, 0.2435, 0.2616)


class _BatchImageTransform:
    """Apply a per-image transform to a batch dict from the dataset.

    Defined at module level (rather than as a closure) so that the dataset
    stays picklable when DataLoader workers are started with ``spawn``.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, examples):
        return {
            "pixel_values": [self.transform(img.convert("RGB"))
                             for img in examples["img"]],
            "labels": examples["label"],
        }


def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Return ``(train_loader, test_loader)`` for CIFAR-10.

    Each batch is a dict with ``"pixel_values"`` (normalized image tensors)
    and ``"labels"``.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of DataLoader worker processes.
    """
    normalize = transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Transforms are applied at access time via ``with_transform``.
    # ``Dataset.map`` would run the random crop/flip once and store the
    # result, so every epoch would see the same "augmented" images.
    ds_train = load_dataset("uoft-cs/cifar10", split="train").with_transform(
        _BatchImageTransform(transform_train))
    ds_test = load_dataset("uoft-cs/cifar10", split="test").with_transform(
        _BatchImageTransform(transform_test))

    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader


# ---------------------------------------------------------------------------
# Training & evaluation helpers
# ---------------------------------------------------------------------------

def train_epoch(model, loader, optimizer, criterion, device, pruner=None,
                epoch=None):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train (switched to train mode).
        loader: Iterable of dict batches with ``"pixel_values"``/``"labels"``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss taking ``(outputs, targets)``.
        device: Device the batches are moved to.
        pruner: Optional pruner; when given together with ``epoch``,
            ``pruner.step(epoch)`` is called after every optimizer step.
        epoch: Zero-based epoch index forwarded to the pruner.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # NOTE(review): the pruner is stepped per batch but receives the
        # epoch index — confirm this matches PDPPruner's expected schedule.
        if pruner is not None and epoch is not None:
            pruner.step(epoch)

        # Weight by batch size so the result is a per-sample mean even when
        # the last batch is smaller.
        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return total_loss / total, 100.0 * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        total_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return total_loss / total, 100.0 * correct / total


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Parse CLI arguments, train ResNet18 with PDP, hard-prune, and save.

    Writes ``best_model.pt``, ``final_model.pt`` and ``history.json`` into
    ``--save_dir``.
    """
    parser = argparse.ArgumentParser(description="PDP Training on CIFAR-10")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--target_sparsity", type=float, default=0.85)
    parser.add_argument("--s", type=int, default=16,
                        help="Warmup epochs before pruning starts")
    parser.add_argument("--epsilon", type=float, default=0.015,
                        help="Gradual pruning rate per epoch")
    parser.add_argument("--tau", type=float, default=1e-4,
                        help="PDP temperature")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_dir", type=str, default="./checkpoints")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    device = torch.device(args.device)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Compare the device *type* so values such as "cuda:0" are seeded too.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)
    print(f"Using device: {device}")

    # Data
    train_loader, test_loader = get_cifar10_loaders(args.batch_size,
                                                    args.num_workers)
    print(f"Train batches: {len(train_loader)}, "
          f"Test batches: {len(test_loader)}")

    # Model
    model = ResNet18(num_classes=10).to(device)
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer & scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90],
                                               gamma=0.1)

    # PDP Pruner
    pruner = PDPPruner(
        model=model,
        target_sparsity=args.target_sparsity,
        s=args.s,
        epsilon=args.epsilon,
        tau=args.tau,
    )
    pruner.attach()

    # Training loop
    history = []
    best_acc = 0.0

    for epoch in range(args.epochs):
        # Capture the LR before scheduler.step() so the logged value is the
        # one this epoch was actually trained with, not the next epoch's.
        epoch_lr = optimizer.param_groups[0]["lr"]

        train_loss, train_acc = train_epoch(model, train_loader, optimizer,
                                            criterion, device, pruner=pruner,
                                            epoch=epoch)
        val_loss, val_acc = evaluate(model, test_loader, criterion, device)
        scheduler.step()

        current_sparsity = pruner.get_sparsity()
        effective = pruner.current_effective_sparsity

        print(f"Epoch {epoch+1:3d}/{args.epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | "
              f"Sparsity: {current_sparsity:.4f} (eff: {effective:.4f}) | "
              f"LR: {epoch_lr:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "sparsity": current_sparsity,
            "effective_sparsity": effective,
            "lr": epoch_lr,
        })

        if val_acc > best_acc:
            best_acc = val_acc
            ckpt_path = os.path.join(args.save_dir, "best_model.pt")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "pruner_state_dict": pruner.state_dict(),
                "val_acc": val_acc,
            }, ckpt_path)

    # Final hard prune and evaluation
    print("\n--- Final Hard Pruning ---")
    pruner.hard_prune()
    final_sparsity = pruner.get_sparsity()
    final_val_loss, final_val_acc = evaluate(model, test_loader, criterion,
                                             device)
    print(f"After hard prune: Sparsity={final_sparsity:.4f}, "
          f"Val Acc={final_val_acc:.2f}%")

    # Save final model
    final_path = os.path.join(args.save_dir, "final_model.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "pruner_state_dict": pruner.state_dict(),
        "final_sparsity": final_sparsity,
        "final_val_acc": final_val_acc,
    }, final_path)

    # Save history
    history_path = os.path.join(args.save_dir, "history.json")
    with open(history_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    print(f"\nBest validation accuracy: {best_acc:.2f}%")
    print(f"Final pruned model saved to {final_path}")


if __name__ == "__main__":
    main()