RFTSystems committed on
Commit
8492c41
·
verified ·
1 Parent(s): fe05156

Update train_dclr_model.py

Browse files
Files changed (1) hide show
  1. train_dclr_model.py +202 -87
train_dclr_model.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
@@ -5,6 +6,7 @@ import torchvision
5
  import torchvision.transforms as transforms
6
  from torch.utils.data import DataLoader
7
  import matplotlib.pyplot as plt
 
8
 
9
  # Import the DCLR optimizer from the local file
10
  from dclr_optimizer import DCLR
@@ -26,98 +28,211 @@ class SimpleCNN(nn.Module):
26
  x = F.relu(self.fc1(x))
27
  return self.fc2(x)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # === CIFAR-10 Data Loading ===
30
# Previous (pre-commit) version of the script: train SimpleCNN with DCLR only,
# evaluate on the CIFAR-10 test split, then save weights, plots and accuracy.

# Single preprocessing pipeline used for both splits.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])

train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)

# === Training Configuration ===
model = SimpleCNN()

best_lr = 0.1
best_lambda = 0.1
optimizer = DCLR(model.parameters(), lr=best_lr, lambda_=best_lambda, verbose=False)

criterion = nn.CrossEntropyLoss()
extended_epochs = 20

print(f"Starting training for SimpleCNN with DCLR (lr={best_lr}, lambda_={best_lambda}) for {extended_epochs} epochs...")

# Per-epoch mean batch loss and training accuracy (percent).
losses = []
accs = []

# === Training Loop ===
for epoch in range(extended_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # DCLR's step takes the batch outputs as ``output_activations``.
        optimizer.step(output_activations=outputs)

        running_loss += loss.item()
        predicted = outputs.max(1)[1]
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100.0 * correct / total
    losses.append(epoch_loss)
    accs.append(epoch_acc)
    print(f"Epoch {epoch+1}/{extended_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

print("Training complete.")

# === Evaluate on Test Set ===
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in test_loader:
        predicted = model(inputs).max(1)[1]
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

test_acc = 100.0 * correct / total
print(f"Final Test Accuracy: {test_acc:.2f}%")

# === Save the Trained Model ===
torch.save(model.state_dict(), 'simple_cnn_dclr_tuned.pth')
print("Model saved to simple_cnn_dclr_tuned.pth")

# === Save Training Performance Plot ===
_epoch_axis = range(1, extended_epochs+1)
plt.figure()
plt.plot(_epoch_axis, losses, label='Loss')
plt.plot(_epoch_axis, accs, label='Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Value')
plt.legend()
plt.title('Training Performance on CIFAR-10')
plt.savefig('training_performance.png')
print("Training performance plot saved to training_performance.png")

# === Save Final Test Accuracy Plot ===
plt.figure()
plt.bar(['CIFAR-10'], [test_acc])
plt.ylabel('Accuracy (%)')
plt.title('Final Test Accuracy')
plt.savefig('final_test_accuracy.png')
print("Final test accuracy plot saved to final_test_accuracy.png")

# === Save Final Test Accuracy Number ===
with open("final_test_accuracy.txt", "w") as f:
    f.write(f"{test_acc:.2f}")
print("Final test accuracy saved to final_test_accuracy.txt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
6
  import torchvision.transforms as transforms
7
  from torch.utils.data import DataLoader
8
  import matplotlib.pyplot as plt
9
+ from datetime import datetime
10
 
11
  # Import the DCLR optimizer from the local file
12
  from dclr_optimizer import DCLR
 
28
  x = F.relu(self.fc1(x))
29
  return self.fc2(x)
30
 
31
+ # === Self-contained Lion optimizer (no external dependency) ===
32
class Lion(torch.optim.Optimizer):
    """Minimal Lion optimizer (Chen et al., 2023).

    For every parameter ``p`` with gradient ``g`` and momentum ``m`` one step
    performs, in this order::

        p <- p * (1 - lr * weight_decay)              # decoupled weight decay
        p <- p - lr * sign(beta1 * m + (1 - beta1) * g)
        m <- beta2 * m + (1 - beta2) * g

    The sign update is computed from the momentum of the *previous* step; the
    momentum buffer is refreshed only afterwards.

    Args:
        params: Iterable of parameters, or of parameter-group dicts.
        lr: Step size applied to the sign update.
        betas: ``(beta1, beta2)``. ``beta1`` blends momentum and gradient for
            the update direction; ``beta2`` is the momentum decay rate.
        weight_decay: Decoupled weight-decay coefficient (0 disables it).

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative, or a beta lies
            outside ``[0, 1)``.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), weight_decay=0.0):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not all(0.0 <= beta < 1.0 for beta in betas):
            raise ValueError(f"Invalid betas: {betas}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay: {weight_decay}")
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one Lion update to every parameter that has a gradient.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when none is given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, so re-enable autograd for the closure.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            wd = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                # Decoupled weight decay: shrink the weights directly rather
                # than adding wd * p to the gradient, which would leak into
                # the sign update and the momentum buffer.
                if wd != 0:
                    p.mul_(1 - lr * wd)

                # Direction uses the momentum from the previous step.
                update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_()
                p.add_(update, alpha=-lr)

                # Only now refresh the momentum for the next step.
                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
68
+
69
  # === CIFAR-10 Data Loading ===
70
def _tensor_and_normalize():
    """Return fresh ToTensor + Normalize steps used by both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]


# Training pipeline: random crop (padding 4) and horizontal flip, then the
# shared tensor conversion and normalisation.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        *_tensor_and_normalize(),
    ]
)

# Test pipeline: no augmentation.
transform_test = transforms.Compose(_tensor_and_normalize())

# Both loaders use the same batch size and worker count; only shuffling differs.
_loader_options = dict(batch_size=128, num_workers=2)

train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
train_loader = DataLoader(train_set, shuffle=True, **_loader_options)

test_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
test_loader = DataLoader(test_set, shuffle=False, **_loader_options)
87
+
88
+ # === Utility: Train and evaluate with a given optimizer ===
89
def train_and_evaluate(optimizer_name, optimizer_ctor, optimizer_kwargs, epochs=20, save_prefix=""):
    """Train a fresh ``SimpleCNN`` with one optimizer, test it, and save artifacts.

    Iterates the module-level ``train_loader`` for ``epochs`` epochs, then
    measures accuracy on the module-level ``test_loader``. Four files named
    ``{save_prefix}_...`` are written: the model weights (``.pth``), a
    loss/accuracy curve (``.png``), a test-accuracy bar chart (``.png``) and
    the test accuracy as text (``.txt``).

    Args:
        optimizer_name: Label used in log lines and plot titles. When it equals
            "dclr" (case-insensitive) ``optimizer.step`` is called with
            ``output_activations=outputs``; otherwise it is called without
            arguments.
        optimizer_ctor: Callable invoked as
            ``optimizer_ctor(model.parameters(), **optimizer_kwargs)``.
        optimizer_kwargs: Keyword arguments forwarded to ``optimizer_ctor``.
        epochs: Number of passes over ``train_loader``.
        save_prefix: Filename prefix for the artifacts; an empty string falls
            back to ``optimizer_name.lower()``.

    Returns:
        dict with keys ``optimizer``, ``test_acc``, ``weights_path``,
        ``perf_plot_path``, ``acc_plot_path``, ``acc_txt_path``, ``losses``
        and ``accs`` (the last two hold one value per epoch).

    Raises:
        RuntimeError: If the DCLR path is selected but the optimizer object
            has no ``step`` attribute.
    """
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optimizer_ctor(model.parameters(), **optimizer_kwargs)

    # Decide once, not per batch, how the optimizer has to be stepped.
    needs_activations = optimizer_name.lower() == "dclr"
    if needs_activations and not hasattr(optimizer, "step"):
        raise RuntimeError("DCLR optimizer missing step(output_activations=...)")

    losses = []
    accs = []

    print(f"Starting training [{optimizer_name}] for {epochs} epochs...")
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # DCLR requires the batch outputs as output_activations.
            if needs_activations:
                optimizer.step(output_activations=outputs)
            else:
                optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        # Loss is averaged over batches; accuracy over individual samples.
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = 100.0 * correct / total
        losses.append(epoch_loss)
        accs.append(epoch_acc)
        print(f"[{optimizer_name}] Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.2f}%")

    print(f"Training complete for [{optimizer_name}]. Evaluating on test set...")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    test_acc = 100.0 * correct / total
    print(f"[{optimizer_name}] Final Test Accuracy: {test_acc:.2f}%")

    # Save artifacts with optimizer-specific names
    if save_prefix == "":
        save_prefix = optimizer_name.lower()

    # Model weights
    weights_path = f"{save_prefix}_simple_cnn.pth"
    torch.save(model.state_dict(), weights_path)
    print(f"[{optimizer_name}] Model saved to {weights_path}")

    # Training performance plot
    epoch_axis = range(1, epochs+1)
    perf_fig = plt.figure()
    plt.plot(epoch_axis, losses, label='Loss')
    plt.plot(epoch_axis, accs, label='Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Value')
    plt.legend()
    plt.title(f'Training Performance on CIFAR-10 ({optimizer_name})')
    perf_path = f"{save_prefix}_training_performance.png"
    plt.savefig(perf_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(perf_fig)
    print(f"[{optimizer_name}] Training performance plot saved to {perf_path}")

    # Final test accuracy plot
    acc_fig = plt.figure()
    plt.bar([optimizer_name], [test_acc])
    plt.ylabel('Accuracy (%)')
    plt.title(f'Final Test Accuracy ({optimizer_name})')
    acc_plot_path = f"{save_prefix}_final_test_accuracy.png"
    plt.savefig(acc_plot_path)
    plt.close(acc_fig)
    print(f"[{optimizer_name}] Final test accuracy plot saved to {acc_plot_path}")

    # Final test accuracy number
    acc_txt_path = f"{save_prefix}_final_test_accuracy.txt"
    with open(acc_txt_path, "w") as f:
        f.write(f"{test_acc:.2f}")
    print(f"[{optimizer_name}] Final test accuracy saved to {acc_txt_path}")

    return {
        "optimizer": optimizer_name,
        "test_acc": test_acc,
        "weights_path": weights_path,
        "perf_plot_path": perf_path,
        "acc_plot_path": acc_plot_path,
        "acc_txt_path": acc_txt_path,
        "losses": losses,
        "accs": accs,
    }
189
+
190
+ # === Run benchmarks for DCLR vs Adam vs Lion ===
191
def main():
    """Benchmark DCLR, Adam and Lion and write a results ledger.

    Each optimizer is trained for the same number of epochs through
    ``train_and_evaluate``. All artifacts are written under ``artifacts/``,
    and the three final test accuracies plus a UTC timestamp are recorded in
    ``artifacts/benchmark_results.txt``.
    """
    # Local import keeps this function self-contained; the file header only
    # imports ``datetime`` itself.
    from datetime import timezone

    # Outputs are placed in artifacts/ through path prefixes instead of
    # os.chdir, so the process-wide working directory is left alone and
    # relative paths such as the './data' dataset root keep resolving against
    # the launch directory.
    out_dir = "artifacts"
    os.makedirs(out_dir, exist_ok=True)

    epochs = 20

    # DCLR (using your tuned hyperparams)
    dclr_results = train_and_evaluate(
        optimizer_name="DCLR",
        optimizer_ctor=DCLR,
        optimizer_kwargs={"lr": 0.1, "lambda_": 0.1, "verbose": False},
        epochs=epochs,
        save_prefix=os.path.join(out_dir, "dclr"),
    )

    # Adam
    adam_results = train_and_evaluate(
        optimizer_name="Adam",
        optimizer_ctor=torch.optim.Adam,
        optimizer_kwargs={"lr": 0.001},
        epochs=epochs,
        save_prefix=os.path.join(out_dir, "adam"),
    )

    # Lion
    lion_results = train_and_evaluate(
        optimizer_name="Lion",
        optimizer_ctor=Lion,
        optimizer_kwargs={"lr": 0.001, "betas": (0.9, 0.99), "weight_decay": 0.0},
        epochs=epochs,
        save_prefix=os.path.join(out_dir, "lion"),
    )

    # Combined benchmark ledger
    ledger_path = os.path.join(out_dir, "benchmark_results.txt")
    # datetime.utcnow() is deprecated; take an aware UTC time and drop the
    # tzinfo so the text stays "<naive ISO timestamp>Z" as before.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    with open(ledger_path, "w") as f:
        f.write(f"Run timestamp: {timestamp}Z\n")
        f.write(f"DCLR: {dclr_results['test_acc']:.2f}%\n")
        f.write(f"Adam: {adam_results['test_acc']:.2f}%\n")
        f.write(f"Lion: {lion_results['test_acc']:.2f}%\n")
    print(f"Benchmark results saved to {ledger_path}")

    # Symlink or copy DCLR artifacts to legacy names for existing app (optional)
    # If your current app expects specific filenames at repo root, you can create copies:
    # For a clean setup, prefer reading from artifacts/ in app.py.


if __name__ == "__main__":
    main()