| """ |
| PDP Training Script for CIFAR-10 with ResNet18 |
| Based on: PDP: Parameter-free Differentiable Pruning is All You Need (NeurIPS 2023) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| from datasets import load_dataset |
| import numpy as np |
| import argparse |
| import json |
| import os |
| from tqdm import tqdm |
|
|
| from pdp import PDPPruner |
|
|
|
|
| |
| |
| |
|
|
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride (default 1).
    """
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": False,
    }
    return nn.Conv2d(in_planes, out_planes, **conv_options)
|
|
|
|
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers.

    The shortcut branch is an empty ``nn.Sequential`` (identity) unless the
    stride is not 1 or the channel count changes, in which case it is a
    1x1 convolution followed by batch norm.

    Args:
        in_planes: Number of input channels.
        planes: Base number of output channels (multiplied by ``expansion``).
        stride: Stride of the first convolution and of the projection shortcut.
    """

    # Output channels are ``expansion * planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch. Submodules are created in a fixed order so parameter
        # initialisation and state_dict ordering stay stable.
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: project only when the identity would not match
        # the main branch's output shape.
        shape_preserved = stride == 1 and in_planes == out_planes
        if shape_preserved:
            projection = []
        else:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Apply the block to ``x`` and return the activated residual sum."""
        y = self.conv1(x)
        y = self.bn1(y)
        y = F.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        y += self.shortcut(x)
        return F.relu(y)
|
|
|
|
class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four residual stages and a linear classifier.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an integer
            ``expansion`` class attribute.
        num_blocks: Sequence of four ints giving the number of blocks in
            each stage.
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running input-channel count, updated by ``_make_layer``.
        self.in_planes = 64

        # Stem: bias-free 3x3 convolution, stride 1, padding 1.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks.

        Only the first block uses ``stride``; the remaining blocks use
        stride 1. ``self.in_planes`` is advanced after every block.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For a 4x4 feature map (32x32 input) this
        # matches a fixed 4x4 average pool, but it also keeps the flattened
        # width equal to the classifier's input for other spatial sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)
|
|
|
|
def ResNet18(num_classes=10):
    """Build a ``ResNet`` of ``BasicBlock``s with two blocks in each of the four stages.

    Args:
        num_classes: Number of output logits (default 10).
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
|
|
|
|
| |
| |
| |
|
|
class _BatchImageTransform:
    """Apply a per-image transform to a batch of ``{"img", "label"}`` rows.

    Defined at module level (not as a closure) so that instances can be
    pickled when ``DataLoader`` worker processes are started.
    """

    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, examples):
        pixel_values = [
            self.image_transform(img.convert("RGB")) for img in examples["img"]
        ]
        return {"pixel_values": pixel_values, "labels": examples["label"]}


def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Build the CIFAR-10 train and test ``DataLoader`` objects.

    Images are transformed lazily, when rows are fetched, via
    ``Dataset.with_transform``. The random crop and horizontal flip are
    therefore re-sampled on every access. Running them inside
    ``Dataset.map`` would apply them once and cache the result, so every
    epoch would see the same "augmented" images.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        ``(train_loader, test_loader)``. Each batch is a dict with a
        ``"pixel_values"`` entry (normalised image tensors) and a
        ``"labels"`` entry.
    """
    # Per-channel mean / std passed to ``transforms.Normalize``.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    ds_train = load_dataset("uoft-cs/cifar10", split="train")
    ds_test = load_dataset("uoft-cs/cifar10", split="test")

    # On-the-fly formatting: the callable receives a batch of raw rows each
    # time items are read and returns only the columns the training loop uses.
    ds_train = ds_train.with_transform(_BatchImageTransform(transform_train))
    ds_test = ds_test.with_transform(_BatchImageTransform(transform_test))

    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader
|
|
|
|
| |
| |
| |
|
|
def train_epoch(model, loader, optimizer, criterion, device, pruner=None, epoch=None):
    """Run one training pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: Module to train; put into training mode here.
        loader: Iterable of dict batches with ``"pixel_values"`` and ``"labels"``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batch tensors are moved to.
        pruner: Optional object whose ``step(epoch)`` is called after every
            optimizer step.
        epoch: Epoch index forwarded to ``pruner.step``; the pruner is only
            stepped when both ``pruner`` and ``epoch`` are given.

    Returns:
        Sample-weighted mean loss and accuracy in percent over the epoch.
    """
    model.train()
    step_pruner = not (pruner is None or epoch is None)
    running_loss, n_correct, n_seen = 0.0, 0, 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        if step_pruner:
            pruner.step(epoch)

        # Weight the batch loss by its size so the epoch mean is per-sample.
        running_loss += loss.item() * inputs.size(0)
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    return running_loss / n_seen, 100.0 * n_correct / n_seen
|
|
|
|
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, accuracy_percent)``.

    Gradient tracking is disabled and the model is put into eval mode.

    Args:
        model: Module to evaluate.
        loader: Iterable of dict batches with ``"pixel_values"`` and ``"labels"``.
        criterion: Loss callable taking ``(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Sample-weighted mean loss and accuracy in percent over all batches.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        # Weight the batch loss by its size so the final mean is per-sample.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = logits.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    return loss_sum / n_seen, 100.0 * n_correct / n_seen
|
|
|
|
| |
| |
| |
|
|
def main():
    """Train ResNet18 on CIFAR-10 with a ``PDPPruner`` attached, then hard-prune.

    Steps:
        1. Parse command-line options and seed the torch RNG.
        2. Build the data loaders, model, SGD optimizer, LR schedule and pruner.
        3. Train for ``--epochs`` epochs, logging metrics and saving
           ``best_model.pt`` whenever validation accuracy improves.
        4. Call ``pruner.hard_prune()``, re-evaluate, and write
           ``final_model.pt`` and ``history.json`` to ``--save_dir``.
    """
    parser = argparse.ArgumentParser(description="PDP Training on CIFAR-10")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--target_sparsity", type=float, default=0.85)
    parser.add_argument("--s", type=int, default=16, help="Warmup epochs before pruning starts")
    parser.add_argument("--epsilon", type=float, default=0.015, help="Gradual pruning rate per epoch")
    parser.add_argument("--tau", type=float, default=1e-4, help="PDP temperature")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_dir", type=str, default="./checkpoints")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    if args.device == "cuda":
        torch.cuda.manual_seed(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Data
    train_loader, test_loader = get_cifar10_loaders(args.batch_size, args.num_workers)
    print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

    # Model
    model = ResNet18(num_classes=10).to(device)
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    # Optimisation. The LR drops by 10x after epochs 60 and 90; these
    # milestones are fixed and do not scale with --epochs.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 90], gamma=0.1)

    # Pruner
    pruner = PDPPruner(
        model=model,
        target_sparsity=args.target_sparsity,
        s=args.s,
        epsilon=args.epsilon,
        tau=args.tau,
    )
    pruner.attach()

    history = []
    best_acc = 0.0

    for epoch in range(args.epochs):
        # Read the LR before scheduler.step() so the log and history report
        # the rate this epoch was actually trained with, not the next one.
        epoch_lr = optimizer.param_groups[0]["lr"]

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device,
                                            pruner=pruner, epoch=epoch)
        val_loss, val_acc = evaluate(model, test_loader, criterion, device)
        scheduler.step()

        # float() keeps history JSON-serialisable should the pruner return a
        # 0-dim tensor or numpy scalar rather than a Python float
        # (NOTE(review): PDPPruner's return types are not visible here).
        current_sparsity = float(pruner.get_sparsity())
        effective = float(pruner.current_effective_sparsity)

        print(f"Epoch {epoch+1:3d}/{args.epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | "
              f"Sparsity: {current_sparsity:.4f} (eff: {effective:.4f}) | "
              f"LR: {epoch_lr:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "sparsity": current_sparsity,
            "effective_sparsity": effective,
            "lr": epoch_lr,
        })

        # NOTE(review): the best checkpoint is chosen on validation accuracy
        # alone, so it may come from an epoch with lower sparsity than the
        # target. Confirm this is intended.
        if val_acc > best_acc:
            best_acc = val_acc
            ckpt_path = os.path.join(args.save_dir, "best_model.pt")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "pruner_state_dict": pruner.state_dict(),
                "val_acc": val_acc,
            }, ckpt_path)

    # Final hard pruning and re-evaluation of the pruned model.
    print("\n--- Final Hard Pruning ---")
    pruner.hard_prune()
    final_sparsity = float(pruner.get_sparsity())
    final_val_loss, final_val_acc = evaluate(model, test_loader, criterion, device)
    print(f"After hard prune: Sparsity={final_sparsity:.4f}, Val Acc={final_val_acc:.2f}%")

    # Save the pruned model.
    final_path = os.path.join(args.save_dir, "final_model.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "pruner_state_dict": pruner.state_dict(),
        "final_sparsity": final_sparsity,
        "final_val_acc": final_val_acc,
    }, final_path)

    # Save the per-epoch training history.
    with open(os.path.join(args.save_dir, "history.json"), "w") as f:
        json.dump(history, f, indent=2)

    print(f"\nBest validation accuracy: {best_acc:.2f}%")
    print(f"Final pruned model saved to {final_path}")
|
|
|
|
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|