ESPR3SS0 committed on
Commit
e4e6e3c
·
verified ·
1 Parent(s): c2ff581

Add train_pdp.py

Browse files
Files changed (1) hide show
  1. train_pdp.py +297 -0
train_pdp.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PDP Training Script for CIFAR-10 with ResNet18
3
+ Based on: PDP: Parameter-free Differentiable Pruning is All You Need (NeurIPS 2023)
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ from torch.utils.data import DataLoader
11
+ from torchvision import transforms
12
+ from datasets import load_dataset
13
+ import numpy as np
14
+ import argparse
15
+ import json
16
+ import os
17
+ from tqdm import tqdm
18
+
19
+ from pdp import PDPPruner
20
+
21
+
22
+ # ---------------------------------------------------------------------------
23
+ # CIFAR-10 adapted ResNet18
24
+ # ---------------------------------------------------------------------------
25
+
26
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with one pixel of padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
29
+
30
+
31
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    The shortcut is an empty ``nn.Sequential`` (identity) unless the block
    changes the spatial stride or the channel count, in which case it is a
    1x1 convolution followed by batch normalization.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Create the block's layers.

        Args:
            in_planes: Number of channels entering the block.
            planes: Number of channels produced by each 3x3 convolution.
            stride: Stride of the first convolution (and of the projection
                shortcut, when one is needed).
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is only required when the residual and the main
        # branch would otherwise disagree in shape.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            projection = [
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        else:
            projection = []
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)
54
+
55
+
56
class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem convolution sized for small images.

    The network is a 3x3 stem convolution, four residual stages with
    64/128/256/512 base channels (strides 1/2/2/2), global average pooling
    and a linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """Assemble the network.

        Args:
            block: Residual block class. It must expose an ``expansion``
                class attribute and accept ``(in_planes, planes, stride)``.
            num_blocks: Sequence of four ints, the number of blocks in each
                of the four stages.
            num_classes: Size of the classifier output.
        """
        super().__init__()
        # Running channel count; updated by ``_make_layer`` as stages are built.
        self.in_planes = 64
        # First conv adapted for 32x32 CIFAR-10
        self.conv1 = conv3x3(3, 64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` residual blocks.

        Only the first block of the stage uses ``stride``; the remaining
        blocks use stride 1. ``self.in_planes`` is advanced after every block
        so the next block (and the next stage) receives the right input width.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch with 3 channels (the stem convolution expects 3
                input channels).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs the final feature map is
        # 4x4, so this matches the previous fixed ``avg_pool2d(out, 4)``, but
        # it no longer breaks the classifier for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out
87
+
88
+
89
def ResNet18(num_classes=10):
    """Construct a ``ResNet`` of ``BasicBlock``s with two blocks per stage.

    Args:
        num_classes: Number of output classes for the classifier.

    Returns:
        The freshly initialized ``ResNet`` instance.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)
91
+
92
+
93
+ # ---------------------------------------------------------------------------
94
+ # Data loading
95
+ # ---------------------------------------------------------------------------
96
+
97
class _ImageBatchTransform:
    """Picklable on-the-fly transform for a batch of dataset rows.

    Implemented as a top-level class (not a nested function) so that
    DataLoader worker processes can pickle the dataset that carries it.
    """

    def __init__(self, image_transform):
        # Callable applied to each RGB PIL image, returning a tensor.
        self.image_transform = image_transform

    def __call__(self, examples):
        """Map ``{"img": [...], "label": [...]}`` to model-ready columns."""
        images = [self.image_transform(img.convert("RGB")) for img in examples["img"]]
        return {"pixel_values": images, "labels": examples["label"]}


def get_cifar10_loaders(batch_size=128, num_workers=4):
    """Create the CIFAR-10 train and test ``DataLoader``s.

    The data comes from the ``uoft-cs/cifar10`` dataset loaded through
    ``datasets.load_dataset``. Each batch is a dict with the keys
    ``"pixel_values"`` and ``"labels"``.

    The image transforms are attached with ``with_transform`` so they run
    every time an example is fetched. Applying them with ``Dataset.map``
    would evaluate the random crop/flip only once during preprocessing,
    freezing the augmentation so that every epoch sees identical images
    (and would also store the whole dataset as float tensors).

    Args:
        batch_size: Batch size used for both loaders.
        num_workers: Number of DataLoader worker processes for both loaders.

    Returns:
        Tuple ``(train_loader, test_loader)``; only the training loader
        shuffles.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    ds_train = load_dataset("uoft-cs/cifar10", split="train")
    ds_test = load_dataset("uoft-cs/cifar10", split="test")

    # Lazy transforms: evaluated on each __getitem__, so the random
    # augmentations are re-sampled every time an image is drawn.
    ds_train = ds_train.with_transform(_ImageBatchTransform(transform_train))
    ds_test = ds_test.with_transform(_ImageBatchTransform(transform_test))

    train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    return train_loader, test_loader
131
+
132
+
133
+ # ---------------------------------------------------------------------------
134
+ # Training & evaluation helpers
135
+ # ---------------------------------------------------------------------------
136
+
137
def train_epoch(model, loader, optimizer, criterion, device, pruner=None, epoch=None):
    """Run one optimization pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of dict batches with ``"pixel_values"`` and
            ``"labels"`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batch tensors are moved to.
        pruner: Optional object whose ``step(epoch)`` is invoked after every
            optimizer step.
        epoch: Epoch index forwarded to ``pruner.step``; the pruner is only
            stepped when both ``pruner`` and ``epoch`` are given.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` averaged over all samples.
    """
    model.train()

    step_pruner = not (pruner is None or epoch is None)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        if step_pruner:
            pruner.step(epoch)

        # Weight the batch loss by its size so the final value is a
        # per-sample average even when the last batch is smaller.
        loss_sum += loss.item() * inputs.size(0)
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    mean_loss = loss_sum / n_seen
    accuracy = 100.0 * n_correct / n_seen
    return mean_loss, accuracy
159
+
160
+
161
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Measure loss and accuracy of ``model`` over ``loader`` without gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of dict batches with ``"pixel_values"`` and
            ``"labels"`` tensors.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)`` averaged over all samples.
    """
    model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch in loader:
        inputs = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)

        outputs = model(inputs)
        batch_loss = criterion(outputs, targets)

        # Size-weighted so the result is a per-sample mean.
        loss_sum += batch_loss.item() * inputs.size(0)
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += (predicted == targets).sum().item()

    mean_loss = loss_sum / n_seen
    accuracy = 100.0 * n_correct / n_seen
    return mean_loss, accuracy
176
+
177
+
178
+ # ---------------------------------------------------------------------------
179
+ # Main
180
+ # ---------------------------------------------------------------------------
181
+
182
def main():
    """Train ResNet18 on CIFAR-10 with a ``PDPPruner`` attached, then hard-prune.

    Steps:
        1. Parse command-line options and seed PyTorch.
        2. Build the data loaders, model, SGD optimizer, ``MultiStepLR``
           scheduler and the pruner.
        3. Train for ``--epochs`` epochs, evaluating after each one and
           saving ``best_model.pt`` whenever validation accuracy improves.
        4. Call ``pruner.hard_prune()``, re-evaluate, and write
           ``final_model.pt`` and ``history.json`` into ``--save_dir``.
    """
    parser = argparse.ArgumentParser(description="PDP Training on CIFAR-10")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--target_sparsity", type=float, default=0.85)
    parser.add_argument("--s", type=int, default=16, help="Warmup epochs before pruning starts")
    parser.add_argument("--epsilon", type=float, default=0.015, help="Gradual pruning rate per epoch")
    parser.add_argument("--tau", type=float, default=1e-4, help="PDP temperature")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--save_dir", type=str, default="./checkpoints")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    # Previously hard-coded to [60, 90]; exposed so the schedule can follow
    # a non-default --epochs value. The default keeps the old behavior.
    parser.add_argument("--milestones", type=int, nargs="+", default=[60, 90],
                        help="Epochs at which the learning rate is multiplied by 0.1")
    args = parser.parse_args()

    device = torch.device(args.device)

    torch.manual_seed(args.seed)
    # Compare the device *type* so that values such as "cuda:1" are also
    # recognised as CUDA devices.
    if device.type == "cuda":
        torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.save_dir, exist_ok=True)

    print(f"Using device: {device}")

    # Data
    train_loader, test_loader = get_cifar10_loaders(args.batch_size, args.num_workers)
    print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

    # Model
    model = ResNet18(num_classes=10).to(device)
    print(f"Model params: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer & scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)

    # PDP Pruner
    pruner = PDPPruner(
        model=model,
        target_sparsity=args.target_sparsity,
        s=args.s,
        epsilon=args.epsilon,
        tau=args.tau,
    )
    pruner.attach()

    # Training loop
    history = []
    best_acc = 0.0

    for epoch in range(args.epochs):
        # Read the learning rate *before* stepping the scheduler so that the
        # logged value is the one actually used to train this epoch.
        epoch_lr = optimizer.param_groups[0]["lr"]

        train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device,
                                            pruner=pruner, epoch=epoch)
        val_loss, val_acc = evaluate(model, test_loader, criterion, device)
        scheduler.step()

        current_sparsity = pruner.get_sparsity()
        effective = pruner.current_effective_sparsity

        print(f"Epoch {epoch+1:3d}/{args.epochs} | "
              f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}% | "
              f"Sparsity: {current_sparsity:.4f} (eff: {effective:.4f}) | "
              f"LR: {epoch_lr:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "sparsity": current_sparsity,
            "effective_sparsity": effective,
            "lr": epoch_lr,
        })

        # NOTE(review): "validation" here is the CIFAR-10 test split, so the
        # best checkpoint is selected on test data; also, this checkpoint is
        # written before the final hard prune below — confirm both are intended.
        if val_acc > best_acc:
            best_acc = val_acc
            ckpt_path = os.path.join(args.save_dir, "best_model.pt")
            torch.save({
                "epoch": epoch + 1,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "pruner_state_dict": pruner.state_dict(),
                "val_acc": val_acc,
            }, ckpt_path)

    # Final hard prune and evaluation
    print("\n--- Final Hard Pruning ---")
    pruner.hard_prune()
    final_sparsity = pruner.get_sparsity()
    final_val_loss, final_val_acc = evaluate(model, test_loader, criterion, device)
    print(f"After hard prune: Sparsity={final_sparsity:.4f}, Val Acc={final_val_acc:.2f}%")

    # Save final model
    final_path = os.path.join(args.save_dir, "final_model.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "pruner_state_dict": pruner.state_dict(),
        "final_sparsity": final_sparsity,
        "final_val_acc": final_val_acc,
    }, final_path)

    # Save history
    with open(os.path.join(args.save_dir, "history.json"), "w", encoding="utf-8") as f:
        json.dump(history, f, indent=2)

    print(f"\nBest validation accuracy: {best_acc:.2f}%")
    print(f"Final pruned model saved to {final_path}")
294
+
295
+
296
# Script entry point: run training only when this file is executed directly.
+ if __name__ == "__main__":
297
+ main()