mnoorchenar commited on
Commit
f684a5c
·
verified ·
1 Parent(s): f371dad

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +928 -0
app.py ADDED
@@ -0,0 +1,928 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template_string, jsonify, request
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.utils.data import DataLoader
7
+ from torchvision import datasets, transforms
8
+ import base64
9
+ from io import BytesIO
10
+ import matplotlib
11
+ matplotlib.use('Agg')
12
+ import matplotlib.pyplot as plt
13
+ import threading
14
+ import time
15
+
16
+ app = Flask(__name__)
17
+
18
+ # Global variables for training state
19
+ training_state = {
20
+ 'is_training': False,
21
+ 'progress': 0,
22
+ 'current_epoch': 0,
23
+ 'total_epochs': 0,
24
+ 'losses': [],
25
+ 'trained': False,
26
+ 'current_loss': 0
27
+ }
28
+
29
+ # VAE Architecture
30
+ class VAE(nn.Module):
31
+ def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
32
+ super(VAE, self).__init__()
33
+
34
+ # Encoder
35
+ self.fc1 = nn.Linear(input_dim, hidden_dim)
36
+ self.fc_mu = nn.Linear(hidden_dim, latent_dim)
37
+ self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
38
+
39
+ # Decoder
40
+ self.fc3 = nn.Linear(latent_dim, hidden_dim)
41
+ self.fc4 = nn.Linear(hidden_dim, input_dim)
42
+
43
+ self.latent_dim = latent_dim
44
+
45
+ def encode(self, x):
46
+ h = F.relu(self.fc1(x))
47
+ mu = self.fc_mu(h)
48
+ logvar = self.fc_logvar(h)
49
+ return mu, logvar
50
+
51
+ def reparameterize(self, mu, logvar):
52
+ std = torch.exp(0.5 * logvar)
53
+ eps = torch.randn_like(std)
54
+ z = mu + eps * std
55
+ return z
56
+
57
+ def decode(self, z):
58
+ h = F.relu(self.fc3(z))
59
+ return torch.sigmoid(self.fc4(h))
60
+
61
+ def forward(self, x):
62
+ mu, logvar = self.encode(x)
63
+ z = self.reparameterize(mu, logvar)
64
+ return self.decode(z), mu, logvar
65
+
66
+ # Loss function
67
+ def vae_loss(recon_x, x, mu, logvar):
68
+ BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
69
+ KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
70
+ return BCE + KLD, BCE, KLD
71
+
72
+ # Load MNIST data
73
+ def load_mnist_data():
74
+ transform = transforms.Compose([
75
+ transforms.ToTensor(),
76
+ ])
77
+
78
+ train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
79
+
80
+ # Get subset for faster training and visualization
81
+ subset_size = 10000
82
+ indices = torch.randperm(len(train_dataset))[:subset_size]
83
+
84
+ data = []
85
+ labels = []
86
+
87
+ for idx in indices:
88
+ img, label = train_dataset[idx]
89
+ data.append(img.view(-1).numpy())
90
+ labels.append(label)
91
+
92
+ return np.array(data), np.array(labels)
93
+
94
+ # Initialize model and data
95
+ print("Loading MNIST dataset...")
96
+ vae = None
97
+ data, labels = load_mnist_data()
98
+ data_tensor = torch.FloatTensor(data)
99
+ print(f"Loaded {len(data)} MNIST samples")
100
+
101
+ # Train the VAE in a separate thread
102
+ def train_vae_thread(epochs, batch_size, learning_rate, hidden_dim, latent_dim):
103
+ global vae, training_state
104
+
105
+ training_state['is_training'] = True
106
+ training_state['progress'] = 0
107
+ training_state['current_epoch'] = 0
108
+ training_state['total_epochs'] = epochs
109
+ training_state['losses'] = []
110
+
111
+ # Initialize new model with specified parameters
112
+ vae = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim)
113
+ optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
114
+ dataset = torch.utils.data.TensorDataset(data_tensor)
115
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
116
+
117
+ for epoch in range(epochs):
118
+ vae.train()
119
+ total_loss = 0
120
+ batch_count = 0
121
+
122
+ for batch in dataloader:
123
+ x = batch[0]
124
+ optimizer.zero_grad()
125
+ recon_x, mu, logvar = vae(x)
126
+ loss, _, _ = vae_loss(recon_x, x, mu, logvar)
127
+ loss.backward()
128
+ optimizer.step()
129
+ total_loss += loss.item()
130
+ batch_count += 1
131
+
132
+ avg_loss = total_loss / len(dataloader.dataset)
133
+ training_state['losses'].append(avg_loss)
134
+ training_state['current_epoch'] = epoch + 1
135
+ training_state['current_loss'] = avg_loss
136
+ training_state['progress'] = int(((epoch + 1) / epochs) * 100)
137
+
138
+ print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
139
+
140
+ training_state['is_training'] = False
141
+ training_state['trained'] = True
142
+ print("Training complete!")
143
+
144
+ def fig_to_base64(fig):
145
+ buf = BytesIO()
146
+ fig.savefig(buf, format='png', bbox_inches='tight', dpi=100)
147
+ buf.seek(0)
148
+ img_str = base64.b64encode(buf.read()).decode()
149
+ plt.close(fig)
150
+ return img_str
151
+
152
+ HTML_TEMPLATE = '''
153
+ <!DOCTYPE html>
154
+ <html>
155
+ <head>
156
+ <title>VAE Interactive Playground</title>
157
+ <style>
158
+ * { margin: 0; padding: 0; box-sizing: border-box; }
159
+ body {
160
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
161
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
162
+ min-height: 100vh;
163
+ padding: 20px;
164
+ }
165
+ .container {
166
+ max-width: 1400px;
167
+ margin: 0 auto;
168
+ background: white;
169
+ border-radius: 20px;
170
+ padding: 30px;
171
+ box-shadow: 0 20px 60px rgba(0,0,0,0.3);
172
+ }
173
+ h1 {
174
+ text-align: center;
175
+ color: #667eea;
176
+ margin-bottom: 10px;
177
+ font-size: 2.5em;
178
+ }
179
+ .subtitle {
180
+ text-align: center;
181
+ color: #666;
182
+ margin-bottom: 30px;
183
+ font-size: 1.1em;
184
+ }
185
+ .tab-container {
186
+ display: flex;
187
+ gap: 10px;
188
+ margin-bottom: 20px;
189
+ border-bottom: 2px solid #eee;
190
+ flex-wrap: wrap;
191
+ }
192
+ .tab {
193
+ padding: 12px 24px;
194
+ background: none;
195
+ border: none;
196
+ cursor: pointer;
197
+ font-size: 16px;
198
+ color: #666;
199
+ border-bottom: 3px solid transparent;
200
+ transition: all 0.3s;
201
+ }
202
+ .tab:hover {
203
+ color: #667eea;
204
+ }
205
+ .tab.active {
206
+ color: #667eea;
207
+ border-bottom-color: #667eea;
208
+ font-weight: 600;
209
+ }
210
+ .tab-content {
211
+ display: none;
212
+ }
213
+ .tab-content.active {
214
+ display: block;
215
+ }
216
+ .grid {
217
+ display: grid;
218
+ grid-template-columns: repeat(auto-fit, minmax(400px, 1fr));
219
+ gap: 20px;
220
+ margin-top: 20px;
221
+ }
222
+ .card {
223
+ background: #f8f9fa;
224
+ border-radius: 12px;
225
+ padding: 20px;
226
+ box-shadow: 0 2px 8px rgba(0,0,0,0.1);
227
+ }
228
+ .card h3 {
229
+ color: #333;
230
+ margin-bottom: 15px;
231
+ font-size: 1.3em;
232
+ }
233
+ .card img {
234
+ width: 100%;
235
+ border-radius: 8px;
236
+ margin-top: 10px;
237
+ }
238
+ .slider-container {
239
+ margin: 15px 0;
240
+ }
241
+ .slider-container label {
242
+ display: block;
243
+ margin-bottom: 8px;
244
+ color: #555;
245
+ font-weight: 500;
246
+ }
247
+ .slider {
248
+ width: 100%;
249
+ height: 8px;
250
+ border-radius: 5px;
251
+ background: #ddd;
252
+ outline: none;
253
+ }
254
+ .slider::-webkit-slider-thumb {
255
+ appearance: none;
256
+ width: 20px;
257
+ height: 20px;
258
+ border-radius: 50%;
259
+ background: #667eea;
260
+ cursor: pointer;
261
+ }
262
+ .value-display {
263
+ display: inline-block;
264
+ background: #667eea;
265
+ color: white;
266
+ padding: 4px 12px;
267
+ border-radius: 12px;
268
+ font-size: 0.9em;
269
+ margin-left: 10px;
270
+ }
271
+ button {
272
+ background: #667eea;
273
+ color: white;
274
+ border: none;
275
+ padding: 12px 24px;
276
+ border-radius: 8px;
277
+ cursor: pointer;
278
+ font-size: 16px;
279
+ transition: all 0.3s;
280
+ margin: 10px 5px;
281
+ }
282
+ button:hover {
283
+ background: #5568d3;
284
+ transform: translateY(-2px);
285
+ box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4);
286
+ }
287
+ button:disabled {
288
+ background: #ccc;
289
+ cursor: not-allowed;
290
+ transform: none;
291
+ }
292
+ .architecture-box {
293
+ background: white;
294
+ border: 2px solid #667eea;
295
+ border-radius: 8px;
296
+ padding: 15px;
297
+ margin: 10px 0;
298
+ text-align: center;
299
+ }
300
+ .arrow {
301
+ text-align: center;
302
+ font-size: 24px;
303
+ color: #667eea;
304
+ margin: 5px 0;
305
+ }
306
+ .info-box {
307
+ background: #e3f2fd;
308
+ border-left: 4px solid #2196F3;
309
+ padding: 15px;
310
+ margin: 15px 0;
311
+ border-radius: 4px;
312
+ }
313
+ .loading {
314
+ text-align: center;
315
+ padding: 20px;
316
+ color: #666;
317
+ }
318
+ .training-controls {
319
+ background: #fff;
320
+ border: 2px solid #667eea;
321
+ border-radius: 12px;
322
+ padding: 25px;
323
+ margin: 20px 0;
324
+ }
325
+ .input-group {
326
+ margin: 15px 0;
327
+ }
328
+ .input-group label {
329
+ display: block;
330
+ margin-bottom: 5px;
331
+ color: #555;
332
+ font-weight: 500;
333
+ }
334
+ .input-group input, .input-group select {
335
+ width: 100%;
336
+ padding: 10px;
337
+ border: 2px solid #ddd;
338
+ border-radius: 6px;
339
+ font-size: 14px;
340
+ }
341
+ .input-group input:focus {
342
+ outline: none;
343
+ border-color: #667eea;
344
+ }
345
+ .progress-container {
346
+ background: #f0f0f0;
347
+ border-radius: 10px;
348
+ height: 30px;
349
+ margin: 20px 0;
350
+ overflow: hidden;
351
+ position: relative;
352
+ }
353
+ .progress-bar {
354
+ background: linear-gradient(90deg, #667eea, #764ba2);
355
+ height: 100%;
356
+ transition: width 0.3s;
357
+ display: flex;
358
+ align-items: center;
359
+ justify-content: center;
360
+ color: white;
361
+ font-weight: bold;
362
+ }
363
+ .status-badge {
364
+ display: inline-block;
365
+ padding: 6px 14px;
366
+ border-radius: 20px;
367
+ font-size: 0.9em;
368
+ font-weight: 600;
369
+ margin: 10px 5px;
370
+ }
371
+ .status-training {
372
+ background: #ffc107;
373
+ color: #000;
374
+ }
375
+ .status-ready {
376
+ background: #4caf50;
377
+ color: white;
378
+ }
379
+ .status-not-trained {
380
+ background: #f44336;
381
+ color: white;
382
+ }
383
+ .training-info {
384
+ background: #f8f9fa;
385
+ padding: 15px;
386
+ border-radius: 8px;
387
+ margin: 15px 0;
388
+ }
389
+ .training-info p {
390
+ margin: 5px 0;
391
+ color: #555;
392
+ }
393
+ .param-grid {
394
+ display: grid;
395
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
396
+ gap: 15px;
397
+ }
398
+ </style>
399
+ </head>
400
+ <body>
401
+ <div class="container">
402
+ <h1>🧠 Variational Autoencoder Playground</h1>
403
+ <p class="subtitle">Interactive visualization for understanding VAE architecture and latent space</p>
404
+
405
+ <div class="tab-container">
406
+ <button class="tab active" onclick="switchTab('training')">Training Dashboard</button>
407
+ <button class="tab" onclick="switchTab('architecture')">Architecture</button>
408
+ <button class="tab" onclick="switchTab('latent')">Latent Space</button>
409
+ <button class="tab" onclick="switchTab('reconstruction')">Reconstruction</button>
410
+ <button class="tab" onclick="switchTab('generation')">Generation</button>
411
+ </div>
412
+
413
+ <div id="training" class="tab-content active">
414
+ <div class="training-controls">
415
+ <h3>⚙️ Training Configuration</h3>
416
+ <p style="color: #666; margin-bottom: 20px;">Configure your VAE parameters and start training</p>
417
+
418
+ <div class="param-grid">
419
+ <div class="input-group">
420
+ <label>Number of Epochs</label>
421
+ <input type="number" id="epochs" value="30" min="1" max="200">
422
+ </div>
423
+
424
+ <div class="input-group">
425
+ <label>Batch Size</label>
426
+ <select id="batch_size">
427
+ <option value="32">32</option>
428
+ <option value="64">64</option>
429
+ <option value="128" selected>128</option>
430
+ <option value="256">256</option>
431
+ </select>
432
+ </div>
433
+
434
+ <div class="input-group">
435
+ <label>Learning Rate</label>
436
+ <select id="learning_rate">
437
+ <option value="0.0001">0.0001</option>
438
+ <option value="0.001" selected>0.001</option>
439
+ <option value="0.01">0.01</option>
440
+ </select>
441
+ </div>
442
+
443
+ <div class="input-group">
444
+ <label>Hidden Dimension</label>
445
+ <select id="hidden_dim">
446
+ <option value="200">200</option>
447
+ <option value="400" selected>400</option>
448
+ <option value="512">512</option>
449
+ </select>
450
+ </div>
451
+
452
+ <div class="input-group">
453
+ <label>Latent Dimension</label>
454
+ <select id="latent_dim">
455
+ <option value="2" selected>2</option>
456
+ <option value="5">5</option>
457
+ <option value="10">10</option>
458
+ <option value="20">20</option>
459
+ </select>
460
+ </div>
461
+ </div>
462
+
463
+ <div style="margin-top: 20px;">
464
+ <button id="train-btn" onclick="startTraining()">🚀 Start Training</button>
465
+ <button onclick="resetModel()">🔄 Reset Model</button>
466
+ </div>
467
+ </div>
468
+
469
+ <div class="training-info">
470
+ <h3>📊 Training Status</h3>
471
+ <p><strong>Status:</strong> <span id="status-badge" class="status-badge status-not-trained">Not Trained</span></p>
472
+ <p id="epoch-info"><strong>Epoch:</strong> 0 / 0</p>
473
+ <p id="loss-info"><strong>Current Loss:</strong> N/A</p>
474
+ </div>
475
+
476
+ <div id="progress-section" style="display: none;">
477
+ <h3>Training Progress</h3>
478
+ <div class="progress-container">
479
+ <div class="progress-bar" id="progress-bar" style="width: 0%">0%</div>
480
+ </div>
481
+ </div>
482
+
483
+ <div class="card" id="loss-curve-card" style="display: none;">
484
+ <h3>Real-time Training Loss</h3>
485
+ <div id="training-plot"></div>
486
+ <button onclick="updateLossCurve()">Refresh Loss Curve</button>
487
+ </div>
488
+ </div>
489
+
490
+ <div id="architecture" class="tab-content">
491
+ <div class="info-box">
492
+ <strong>VAE Architecture:</strong> A Variational Autoencoder learns to compress data into a lower-dimensional latent space and reconstruct it.
493
+ The key innovation is the reparameterization trick, which allows backpropagation through stochastic sampling.
494
+ </div>
495
+
496
+ <div class="architecture-box">
497
+ <h4>Input (784D)</h4>
498
+ <small>28×28 image flattened</small>
499
+ </div>
500
+ <div class="arrow">↓</div>
501
+ <div class="architecture-box" style="background: #fff3e0;">
502
+ <h4>Encoder: FC Layer (<span id="arch-hidden">400</span>D)</h4>
503
+ <small>ReLU activation</small>
504
+ </div>
505
+ <div class="arrow">↓</div>
506
+ <div class="architecture-box" style="background: #e8f5e9;">
507
+ <h4>Latent Space (<span id="arch-latent">2</span>D)</h4>
508
+ <small>μ (mean) and σ² (variance)</small>
509
+ </div>
510
+ <div class="arrow">↓ Reparameterization Trick</div>
511
+ <div class="architecture-box" style="background: #e8f5e9;">
512
+ <h4>Sample z ~ N(μ, σ²)</h4>
513
+ <small>z = μ + σ * ε, where ε ~ N(0,1)</small>
514
+ </div>
515
+ <div class="arrow">↓</div>
516
+ <div class="architecture-box" style="background: #f3e5f5;">
517
+ <h4>Decoder: FC Layer (<span id="arch-hidden2">400</span>D)</h4>
518
+ <small>ReLU activation</small>
519
+ </div>
520
+ <div class="arrow">↓</div>
521
+ <div class="architecture-box">
522
+ <h4>Output (784D)</h4>
523
+ <small>Reconstructed image</small>
524
+ </div>
525
+
526
+ <div class="info-box" style="background: #fff3e0; border-left-color: #ff9800; margin-top: 20px;">
527
+ <strong>Loss Function:</strong> VAE Loss = Reconstruction Loss (BCE) + KL Divergence<br>
528
+ • BCE: Measures how well we reconstruct the input<br>
529
+ • KLD: Regularizes latent space to be close to N(0,1)
530
+ </div>
531
+ </div>
532
+
533
+ <div id="latent" class="tab-content">
534
+ <div class="info-box" style="background: #fff3e0; border-left-color: #ff9800;">
535
+ ⚠️ Please train the model first in the Training Dashboard before using this feature.
536
+ </div>
537
+ <div class="card">
538
+ <h3>Latent Space Visualization</h3>
539
+ <p>Each point represents an MNIST digit encoded in 2D latent space. Colors indicate digit classes (0-9).</p>
540
+ <button onclick="loadLatentSpace()">Refresh Latent Space</button>
541
+ <div id="latent-plot" class="loading">Train the model first, then click button to generate...</div>
542
+ </div>
543
+ </div>
544
+
545
+ <div id="reconstruction" class="tab-content">
546
+ <div class="info-box" style="background: #fff3e0; border-left-color: #ff9800;">
547
+ ⚠️ Please train the model first in the Training Dashboard before using this feature.
548
+ </div>
549
+ <div class="card">
550
+ <h3>Input vs Reconstruction</h3>
551
+ <p>See how well the VAE reconstructs MNIST digits.</p>
552
+ <button onclick="loadReconstruction()">Show Random Reconstruction</button>
553
+ <div id="recon-plot" class="loading">Train the model first, then click button to generate...</div>
554
+ </div>
555
+ </div>
556
+
557
+ <div id="generation" class="tab-content">
558
+ <div class="info-box" style="background: #fff3e0; border-left-color: #ff9800;">
559
+ ⚠️ Please train the model first in the Training Dashboard before using this feature. Generation works best with 2D latent space.
560
+ </div>
561
+ <div class="card">
562
+ <h3>Generate from Latent Space</h3>
563
+ <p>Manipulate latent dimensions to generate new digit-like samples. Explore how different regions of latent space correspond to different digits!</p>
564
+
565
+ <div class="slider-container">
566
+ <label>Z1 (Latent Dimension 1): <span class="value-display" id="z1-val">0.00</span></label>
567
+ <input type="range" class="slider" id="z1" min="-3" max="3" step="0.1" value="0" oninput="updateValue('z1')">
568
+ </div>
569
+
570
+ <div class="slider-container">
571
+ <label>Z2 (Latent Dimension 2): <span class="value-display" id="z2-val">0.00</span></label>
572
+ <input type="range" class="slider" id="z2" min="-3" max="3" step="0.1" value="0" oninput="updateValue('z2')">
573
+ </div>
574
+
575
+ <button onclick="generateSample()">Generate Image</button>
576
+ <button onclick="randomSample()">Random Sample</button>
577
+ <button onclick="generateGrid()">Generate Grid (2D only)</button>
578
+
579
+ <div id="gen-plot" class="loading">Train the model first, then adjust sliders and click Generate...</div>
580
+ </div>
581
+ </div>
582
+ </div>
583
+
584
+ <script>
585
+ let progressInterval = null;
586
+
587
+ function switchTab(tabName) {
588
+ document.querySelectorAll('.tab').forEach(t => t.classList.remove('active'));
589
+ document.querySelectorAll('.tab-content').forEach(c => c.classList.remove('active'));
590
+ event.target.classList.add('active');
591
+ document.getElementById(tabName).classList.add('active');
592
+ }
593
+
594
+ function updateValue(id) {
595
+ const val = document.getElementById(id).value;
596
+ document.getElementById(id + '-val').textContent = parseFloat(val).toFixed(2);
597
+ }
598
+
599
+ async function startTraining() {
600
+ const epochs = parseInt(document.getElementById('epochs').value);
601
+ const batch_size = parseInt(document.getElementById('batch_size').value);
602
+ const learning_rate = parseFloat(document.getElementById('learning_rate').value);
603
+ const hidden_dim = parseInt(document.getElementById('hidden_dim').value);
604
+ const latent_dim = parseInt(document.getElementById('latent_dim').value);
605
+
606
+ // Update architecture display
607
+ document.getElementById('arch-hidden').textContent = hidden_dim;
608
+ document.getElementById('arch-hidden2').textContent = hidden_dim;
609
+ document.getElementById('arch-latent').textContent = latent_dim;
610
+
611
+ document.getElementById('train-btn').disabled = true;
612
+ document.getElementById('progress-section').style.display = 'block';
613
+ document.getElementById('loss-curve-card').style.display = 'block';
614
+
615
+ const response = await fetch('/start_training', {
616
+ method: 'POST',
617
+ headers: {'Content-Type': 'application/json'},
618
+ body: JSON.stringify({epochs, batch_size, learning_rate, hidden_dim, latent_dim})
619
+ });
620
+
621
+ const data = await response.json();
622
+
623
+ if (data.status === 'started') {
624
+ // Start polling for progress
625
+ progressInterval = setInterval(updateProgress, 500);
626
+ }
627
+ }
628
+
629
+ async function updateProgress() {
630
+ const response = await fetch('/training_progress');
631
+ const data = await response.json();
632
+
633
+ const progressBar = document.getElementById('progress-bar');
634
+ progressBar.style.width = data.progress + '%';
635
+ progressBar.textContent = data.progress + '%';
636
+
637
+ document.getElementById('epoch-info').innerHTML = `<strong>Epoch:</strong> ${data.current_epoch} / ${data.total_epochs}`;
638
+ document.getElementById('loss-info').innerHTML = `<strong>Current Loss:</strong> ${data.current_loss.toFixed(4)}`;
639
+
640
+ const statusBadge = document.getElementById('status-badge');
641
+ if (data.is_training) {
642
+ statusBadge.className = 'status-badge status-training';
643
+ statusBadge.textContent = 'Training...';
644
+ } else if (data.trained) {
645
+ statusBadge.className = 'status-badge status-ready';
646
+ statusBadge.textContent = 'Ready';
647
+ document.getElementById('train-btn').disabled = false;
648
+ clearInterval(progressInterval);
649
+ updateLossCurve();
650
+ } else {
651
+ statusBadge.className = 'status-badge status-not-trained';
652
+ statusBadge.textContent = 'Not Trained';
653
+ }
654
+ }
655
+
656
+ async function updateLossCurve() {
657
+ const response = await fetch('/training_curve');
658
+ const data = await response.json();
659
+ if (data.image) {
660
+ document.getElementById('training-plot').innerHTML = `<img src="data:image/png;base64,${data.image}">`;
661
+ }
662
+ }
663
+
664
+ async function resetModel() {
665
+ if (confirm('Are you sure you want to reset the model? All training progress will be lost.')) {
666
+ const response = await fetch('/reset_model', {method: 'POST'});
667
+ const data = await response.json();
668
+ if (data.status === 'reset') {
669
+ location.reload();
670
+ }
671
+ }
672
+ }
673
+
674
+ async function loadLatentSpace() {
675
+ document.getElementById('latent-plot').innerHTML = '<div class="loading">Generating...</div>';
676
+ const response = await fetch('/latent_space');
677
+ const data = await response.json();
678
+ if (data.error) {
679
+ document.getElementById('latent-plot').innerHTML = `<div class="loading" style="color: red;">${data.error}</div>`;
680
+ } else {
681
+ document.getElementById('latent-plot').innerHTML = `<img src="data:image/png;base64,${data.image}">`;
682
+ }
683
+ }
684
+
685
+ async function loadReconstruction() {
686
+ document.getElementById('recon-plot').innerHTML = '<div class="loading">Generating...</div>';
687
+ const response = await fetch('/reconstruction');
688
+ const data = await response.json();
689
+ if (data.error) {
690
+ document.getElementById('recon-plot').innerHTML = `<div class="loading" style="color: red;">${data.error}</div>`;
691
+ } else {
692
+ document.getElementById('recon-plot').innerHTML = `<img src="data:image/png;base64,${data.image}">`;
693
+ }
694
+ }
695
+
696
+ async function generateSample() {
697
+ const z1 = parseFloat(document.getElementById('z1').value);
698
+ const z2 = parseFloat(document.getElementById('z2').value);
699
+ document.getElementById('gen-plot').innerHTML = '<div class="loading">Generating...</div>';
700
+ const response = await fetch('/generate', {
701
+ method: 'POST',
702
+ headers: {'Content-Type': 'application/json'},
703
+ body: JSON.stringify({z1, z2})
704
+ });
705
+ const data = await response.json();
706
+ if (data.error) {
707
+ document.getElementById('gen-plot').innerHTML = `<div class="loading" style="color: red;">${data.error}</div>`;
708
+ } else {
709
+ document.getElementById('gen-plot').innerHTML = `<img src="data:image/png;base64,${data.image}">`;
710
+ }
711
+ }
712
+
713
+ async function randomSample() {
714
+ const z1 = (Math.random() * 6 - 3).toFixed(2);
715
+ const z2 = (Math.random() * 6 - 3).toFixed(2);
716
+ document.getElementById('z1').value = z1;
717
+ document.getElementById('z2').value = z2;
718
+ updateValue('z1');
719
+ updateValue('z2');
720
+ await generateSample();
721
+ }
722
+
723
+ async function generateGrid() {
724
+ document.getElementById('gen-plot').innerHTML = '<div class="loading">Generating grid...</div>';
725
+ const response = await fetch('/generate_grid');
726
+ const data = await response.json();
727
+ if (data.error) {
728
+ document.getElementById('gen-plot').innerHTML = `<div class="loading" style="color: red;">${data.error}</div>`;
729
+ } else {
730
+ document.getElementById('gen-plot').innerHTML = `<img src="data:image/png;base64,${data.image}">`;
731
+ }
732
+ }
733
+
734
+ // Check initial status
735
+ updateProgress();
736
+ </script>
737
+ </body>
738
+ </html>
739
+ '''
740
+
741
+ @app.route('/')
742
+ def index():
743
+ return render_template_string(HTML_TEMPLATE)
744
+
745
+ @app.route('/start_training', methods=['POST'])
746
+ def start_training():
747
+ global training_state
748
+
749
+ if training_state['is_training']:
750
+ return jsonify({'status': 'already_training'})
751
+
752
+ params = request.json
753
+ epochs = params.get('epochs', 30)
754
+ batch_size = params.get('batch_size', 128)
755
+ learning_rate = params.get('learning_rate', 0.001)
756
+ hidden_dim = params.get('hidden_dim', 400)
757
+ latent_dim = params.get('latent_dim', 2)
758
+
759
+ # Start training in a separate thread
760
+ thread = threading.Thread(
761
+ target=train_vae_thread,
762
+ args=(epochs, batch_size, learning_rate, hidden_dim, latent_dim)
763
+ )
764
+ thread.daemon = True
765
+ thread.start()
766
+
767
+ return jsonify({'status': 'started'})
768
+
769
+ @app.route('/training_progress')
770
+ def training_progress():
771
+ return jsonify({
772
+ 'is_training': training_state['is_training'],
773
+ 'progress': training_state['progress'],
774
+ 'current_epoch': training_state['current_epoch'],
775
+ 'total_epochs': training_state['total_epochs'],
776
+ 'current_loss': training_state['current_loss'],
777
+ 'trained': training_state['trained']
778
+ })
779
+
780
+ @app.route('/reset_model', methods=['POST'])
781
+ def reset_model():
782
+ global vae, training_state
783
+ vae = None
784
+ training_state = {
785
+ 'is_training': False,
786
+ 'progress': 0,
787
+ 'current_epoch': 0,
788
+ 'total_epochs': 0,
789
+ 'losses': [],
790
+ 'trained': False,
791
+ 'current_loss': 0
792
+ }
793
+ return jsonify({'status': 'reset'})
794
+
795
+ @app.route('/latent_space')
796
+ def latent_space():
797
+ if vae is None or not training_state['trained']:
798
+ return jsonify({'error': 'Model not trained yet. Please train the model first.'})
799
+
800
+ if vae.latent_dim != 2:
801
+ return jsonify({'error': 'Latent space visualization only works with 2D latent dimension.'})
802
+
803
+ vae.eval()
804
+ with torch.no_grad():
805
+ mu, _ = vae.encode(data_tensor)
806
+ mu_np = mu.numpy()
807
+
808
+ fig, ax = plt.subplots(figsize=(12, 10))
809
+ scatter = ax.scatter(mu_np[:, 0], mu_np[:, 1], c=labels, cmap='tab10',
810
+ alpha=0.6, s=30, edgecolors='black', linewidth=0.5)
811
+ ax.set_xlabel('Latent Dimension 1', fontsize=12, fontweight='bold')
812
+ ax.set_ylabel('Latent Dimension 2', fontsize=12, fontweight='bold')
813
+ ax.set_title('VAE Latent Space - MNIST Digits (2D)', fontsize=14, fontweight='bold')
814
+ ax.grid(True, alpha=0.3)
815
+ cbar = plt.colorbar(scatter, ax=ax, ticks=range(10))
816
+ cbar.set_label('Digit Class', fontsize=11)
817
+ cbar.ax.set_yticklabels(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'])
818
+
819
+ return jsonify({'image': fig_to_base64(fig)})
820
+
821
+ @app.route('/reconstruction')
822
+ def reconstruction():
823
+ if vae is None or not training_state['trained']:
824
+ return jsonify({'error': 'Model not trained yet. Please train the model first.'})
825
+
826
+ # Show multiple reconstructions
827
+ n_samples = 10
828
+ indices = np.random.choice(len(data), n_samples, replace=False)
829
+
830
+ vae.eval()
831
+ with torch.no_grad():
832
+ originals = data_tensor[indices]
833
+ reconstructions, _, _ = vae(originals)
834
+
835
+ fig, axes = plt.subplots(2, n_samples, figsize=(20, 4))
836
+
837
+ for i in range(n_samples):
838
+ # Original
839
+ axes[0, i].imshow(originals[i].numpy().reshape(28, 28), cmap='gray')
840
+ axes[0, i].set_title(f'Original\n(Digit {labels[indices[i]]})', fontsize=9)
841
+ axes[0, i].axis('off')
842
+
843
+ # Reconstruction
844
+ axes[1, i].imshow(reconstructions[i].numpy().reshape(28, 28), cmap='gray')
845
+ axes[1, i].set_title('Reconstructed', fontsize=9)
846
+ axes[1, i].axis('off')
847
+
848
+ fig.suptitle('MNIST Reconstruction Comparison', fontsize=14, fontweight='bold', y=1.02)
849
+ plt.tight_layout()
850
+
851
+ return jsonify({'image': fig_to_base64(fig)})
852
+
853
+ @app.route('/generate', methods=['POST'])
854
+ def generate():
855
+ if vae is None or not training_state['trained']:
856
+ return jsonify({'error': 'Model not trained yet. Please train the model first.'})
857
+
858
+ data = request.json
859
+ z1 = data['z1']
860
+ z2 = data['z2']
861
+
862
+ # Create latent vector with correct dimensions
863
+ if vae.latent_dim == 2:
864
+ z = torch.FloatTensor([[z1, z2]])
865
+ else:
866
+ # For higher dimensions, use z1 and z2 for first two dims, zeros for rest
867
+ z = torch.zeros(1, vae.latent_dim)
868
+ z[0, 0] = z1
869
+ z[0, 1] = z2
870
+
871
+ vae.eval()
872
+ with torch.no_grad():
873
+ generated = vae.decode(z)
874
+
875
+ fig, ax = plt.subplots(figsize=(6, 6))
876
+ ax.imshow(generated.numpy().reshape(28, 28), cmap='gray')
877
+ ax.set_title(f'Generated Digit\nz1={z1:.2f}, z2={z2:.2f}',
878
+ fontsize=13, fontweight='bold')
879
+ ax.axis('off')
880
+
881
+ return jsonify({'image': fig_to_base64(fig)})
882
+
883
+ @app.route('/generate_grid')
884
+ def generate_grid():
885
+ if vae is None or not training_state['trained']:
886
+ return jsonify({'error': 'Model not trained yet. Please train the model first.'})
887
+
888
+ if vae.latent_dim != 2:
889
+ return jsonify({'error': 'Grid generation only works with 2D latent dimension.'})
890
+
891
+ # Generate a grid of images by sampling latent space
892
+ n = 15
893
+ grid_x = np.linspace(-3, 3, n)
894
+ grid_y = np.linspace(-3, 3, n)
895
+
896
+ fig, axes = plt.subplots(n, n, figsize=(15, 15))
897
+
898
+ vae.eval()
899
+ with torch.no_grad():
900
+ for i, yi in enumerate(grid_y):
901
+ for j, xi in enumerate(grid_x):
902
+ z = torch.FloatTensor([[xi, yi]])
903
+ generated = vae.decode(z)
904
+ axes[i, j].imshow(generated.numpy().reshape(28, 28), cmap='gray')
905
+ axes[i, j].axis('off')
906
+
907
+ fig.suptitle('Latent Space Manifold (15×15 Grid)', fontsize=16, fontweight='bold')
908
+ plt.tight_layout()
909
+
910
+ return jsonify({'image': fig_to_base64(fig)})
911
+
912
+ @app.route('/training_curve')
913
+ def training_curve():
914
+ if not training_state['losses']:
915
+ return jsonify({'error': 'No training data available yet.'})
916
+
917
+ fig, ax = plt.subplots(figsize=(10, 6))
918
+ ax.plot(training_state['losses'], linewidth=2, color='#667eea')
919
+ ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
920
+ ax.set_ylabel('Loss', fontsize=12, fontweight='bold')
921
+ ax.set_title('VAE Training Loss Over Time', fontsize=14, fontweight='bold')
922
+ ax.grid(True, alpha=0.3)
923
+ ax.fill_between(range(len(training_state['losses'])), training_state['losses'], alpha=0.3, color='#667eea')
924
+
925
+ return jsonify({'image': fig_to_base64(fig)})
926
+
927
+ if __name__ == '__main__':
928
+ app.run(debug=True, port=5000, threaded=True)