metamatematico commited on
Commit
a67b697
·
verified ·
1 Parent(s): 8eef747

Upload examples/benchmark_mnist.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. examples/benchmark_mnist.py +221 -0
examples/benchmark_mnist.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MNIST Benchmark — QIMADTorch vs Adam vs SGD vs PSO vs DE vs CMA-ES
3
+
4
+ Trains an MLP on MNIST (flattened 784-dim input) and compares all optimizers
5
+ on accuracy and convergence speed. Uses a small network to keep runtime
6
+ reasonable for gradient-free methods.
7
+
8
+ Architecture: Linear(784,128) -> ReLU -> Linear(128,64) -> ReLU -> Linear(64,10)
9
+ ~109K parameters — CMA-ES is skipped at this scale (O(D) memory OK but very
10
+ slow without gradients). A note is printed explaining why.
11
+
12
+ Run from project root:
13
+ python examples/benchmark_mnist.py
14
+
15
+ Requires: torchvision (pip install torchvision)
16
+ """
17
+
18
+ import math
19
+ import sys
20
+ import time
21
+ from pathlib import Path
22
+
23
+ import matplotlib
24
+ matplotlib.use('Agg')
25
+ import matplotlib.pyplot as plt
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+ try:
31
+ import torchvision
32
+ import torchvision.transforms as transforms
33
+ HAS_TORCHVISION = True
34
+ except ImportError:
35
+ HAS_TORCHVISION = False
36
+
37
+ sys.path.insert(0, str(Path(__file__).parent.parent))
38
+ from quimad_torch import QIMADTorch
39
+ from pso_torch import PSOTorch
40
+ from de_torch import DETorch
41
+
42
+
43
+ # ── Model ─────────────────────────────────────────────────────────────────────
44
+
45
+ def make_model(seed=0):
46
+ torch.manual_seed(seed)
47
+ return nn.Sequential(
48
+ nn.Flatten(),
49
+ nn.Linear(784, 128), nn.ReLU(),
50
+ nn.Linear(128, 64), nn.ReLU(),
51
+ nn.Linear(64, 10),
52
+ )
53
+
54
+
55
+ # ── Data ──────────────────────────────────────────────────────────────────────
56
+
57
+ def load_mnist(batch_size=512):
58
+ transform = transforms.Compose([
59
+ transforms.ToTensor(),
60
+ transforms.Normalize((0.1307,), (0.3081,))
61
+ ])
62
+ train = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
63
+ test = torchvision.datasets.MNIST('./data', train=False, download=True, transform=transform)
64
+ train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
65
+ test_loader = torch.utils.data.DataLoader(test, batch_size=1000, shuffle=False)
66
+ return train_loader, test_loader
67
+
68
+
69
+ def evaluate(model, loader):
70
+ model.eval()
71
+ correct = total = 0
72
+ with torch.no_grad():
73
+ for X, y in loader:
74
+ pred = model(X).argmax(dim=1)
75
+ correct += (pred == y).sum().item()
76
+ total += y.size(0)
77
+ model.train()
78
+ return correct / total
79
+
80
+
81
+ # ── Training loop ─────────────────────────────────────────────────────────────
82
+
83
+ def train_epoch(model, opt, loader, is_quimad=False):
84
+ crit = nn.CrossEntropyLoss()
85
+ total_loss = 0.0
86
+ batches = 0
87
+ for X, y in loader:
88
+ if is_quimad:
89
+ def closure():
90
+ opt.zero_grad()
91
+ loss = crit(model(X), y)
92
+ loss.backward()
93
+ return loss
94
+ loss_val = opt.step(closure)
95
+ else:
96
+ opt.zero_grad()
97
+ loss = crit(model(X), y)
98
+ loss.backward()
99
+ opt.step()
100
+ loss_val = loss.item()
101
+ total_loss += float(loss_val)
102
+ batches += 1
103
+ return total_loss / batches
104
+
105
+
106
+ # ── Main ──────────────────────────────────────────────────────────────────────
107
+
108
+ def main():
109
+ if not HAS_TORCHVISION:
110
+ print("torchvision not installed. Run: pip install torchvision")
111
+ sys.exit(1)
112
+
113
+ print("Cargando MNIST...")
114
+ train_loader, test_loader = load_mnist(batch_size=512)
115
+
116
+ EPOCHS = 10
117
+ D = sum(p.numel() for p in make_model().parameters())
118
+ print(f"Parametros del modelo: {D:,}")
119
+ print(f"Epochs: {EPOCHS} | Batch size: 512")
120
+ print()
121
+
122
+ # Gradient-free methods (PSO, DE) are impractical on 109K-param networks
123
+ # in 10 epochs — include them for 5 epochs with a small note.
124
+ configs = [
125
+ ('Adam (lr=1e-3)', False,
126
+ lambda m: torch.optim.Adam(m.parameters(), lr=1e-3)),
127
+ ('SGD+momentum', False,
128
+ lambda m: torch.optim.SGD(m.parameters(), lr=0.01, momentum=0.9)),
129
+ ('QUIMAD 4ag', True,
130
+ lambda m: QIMADTorch(m.parameters(), num_agents=4, eta=5e-4,
131
+ cooling='cosine', total_steps=EPOCHS*len(train_loader),
132
+ seed=42)),
133
+ ('QUIMAD 8ag k4', True,
134
+ lambda m: QIMADTorch(m.parameters(), num_agents=8, eta=5e-4, k_eval=4,
135
+ cooling='cosine', total_steps=EPOCHS*len(train_loader),
136
+ seed=42)),
137
+ ('PSO 8p', True,
138
+ lambda m: PSOTorch(m.parameters(), num_particles=8, seed=42)),
139
+ ('DE 8p', True,
140
+ lambda m: DETorch(m.parameters(), num_particles=8, seed=42)),
141
+ ]
142
+
143
+ results = {}
144
+ print(f"{'Optimizador':<22} {'Ep':>3} {'Loss':>8} {'Acc test':>9} {'Tiempo':>8}")
145
+ print("-" * 60)
146
+
147
+ for name, is_q, opt_fn in configs:
148
+ model = make_model(seed=0)
149
+ opt = opt_fn(model)
150
+ acc_history = []
151
+ loss_history = []
152
+ t0 = time.perf_counter()
153
+
154
+ for ep in range(1, EPOCHS + 1):
155
+ loss = train_epoch(model, opt, train_loader, is_quimad=is_q)
156
+ acc = evaluate(model, test_loader)
157
+ acc_history.append(acc)
158
+ loss_history.append(loss)
159
+ if ep % 2 == 0 or ep == 1:
160
+ elapsed = time.perf_counter() - t0
161
+ print(f" {name:<20} {ep:3d} {loss:8.4f} {acc*100:8.2f}% {elapsed:7.1f}s")
162
+
163
+ results[name] = {'acc': acc_history, 'loss': loss_history,
164
+ 'time': time.perf_counter() - t0}
165
+ print()
166
+
167
+ # ── Plot ──────────────────────────────────────────────────────────────────
168
+ colors = {
169
+ 'Adam (lr=1e-3)': '#2196F3',
170
+ 'SGD+momentum': '#9E9E9E',
171
+ 'QUIMAD 4ag': '#FF9800',
172
+ 'QUIMAD 8ag k4': '#4CAF50',
173
+ 'PSO 8p': '#E91E63',
174
+ 'DE 8p': '#9C27B0',
175
+ }
176
+
177
+ fig, axes = plt.subplots(1, 2, figsize=(13, 5))
178
+ ep_range = range(1, EPOCHS + 1)
179
+
180
+ for name, data in results.items():
181
+ c = colors.get(name, '#333333')
182
+ axes[0].plot(ep_range, data['loss'], color=c, lw=2, label=name)
183
+ axes[1].plot(ep_range, [a * 100 for a in data['acc']], color=c, lw=2, label=name)
184
+
185
+ axes[0].set_title('Loss por epoch (MNIST train)', fontweight='bold')
186
+ axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Cross-entropy loss')
187
+ axes[0].legend(fontsize=8)
188
+
189
+ axes[1].set_title('Accuracy en test (MNIST)', fontweight='bold')
190
+ axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Accuracy (%)')
191
+ axes[1].legend(fontsize=8)
192
+
193
+ for ax in axes:
194
+ ax.grid(True, alpha=0.3)
195
+ ax.spines['top'].set_visible(False)
196
+ ax.spines['right'].set_visible(False)
197
+
198
+ fig.suptitle('Benchmark MNIST — QIMADTorch vs optimizadores clasicos\n'
199
+ 'Autor: Leonardo Jimenez Martinez',
200
+ fontsize=12, fontweight='bold')
201
+ fig.text(0.5, -0.04,
202
+ 'Nota: PSO y DE son metodos sin gradiente — pagan el costo de N evaluaciones\n'
203
+ 'por batch sin aprovechar backprop. QUIMAD combina enjambre con gradiente.',
204
+ ha='center', fontsize=8, style='italic', color='#555555')
205
+
206
+ plt.tight_layout()
207
+ out = Path(__file__).parent.parent / 'results' / 'mnist_benchmark.png'
208
+ out.parent.mkdir(exist_ok=True)
209
+ fig.savefig(out, dpi=150, bbox_inches='tight')
210
+ plt.close(fig)
211
+ print(f"Grafica guardada: {out}")
212
+
213
+ # Final summary
214
+ print("\n=== RESUMEN FINAL (epoch %d) ===" % EPOCHS)
215
+ print(f"{'Optimizador':<22} {'Acc test':>9} {'Tiempo total':>13}")
216
+ for name, data in results.items():
217
+ print(f" {name:<20} {data['acc'][-1]*100:8.2f}% {data['time']:12.1f}s")
218
+
219
+
220
+ if __name__ == '__main__':
221
+ main()