from flask import Flask, render_template_string, jsonify, request
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import base64
from io import BytesIO
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import threading
import time
import os
app = Flask(__name__)

# Global variables for training state.
# Shared between the Flask request handlers and the background training
# thread (train_vae_thread).  Mutated without a lock; handlers only read
# scalar values, so torn reads are tolerated for this demo app.
training_state = {
    'is_training': False,   # True while the background training thread runs
    'progress': 0,          # integer percent, 0-100
    'current_epoch': 0,     # 1-based epoch most recently completed
    'total_epochs': 0,      # epochs requested for the current run
    'losses': [],           # per-epoch average loss history
    'trained': False,       # True once at least one run has finished
    'current_loss': 0       # average loss of the most recent epoch
}
# VAE Architecture
class VAE(nn.Module):
    """Fully-connected variational autoencoder for flattened 28x28 images.

    Encoder: input -> hidden (ReLU) -> (mu, logvar).
    Decoder: latent -> hidden (ReLU) -> sigmoid output in [0, 1].
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2):
        super(VAE, self).__init__()
        # Encoder layers (attribute names kept stable for state_dict compatibility).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder layers.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        self.latent_dim = latent_dim

    def encode(self, x):
        """Map a batch of inputs to the latent Gaussian parameters (mu, logvar)."""
        hidden = F.relu(self.fc1(x))
        return self.fc_mu(hidden), self.fc_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via z = mu + sigma * eps, eps ~ N(0, 1)."""
        std = torch.exp(logvar * 0.5)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to pixel space; sigmoid keeps outputs in [0, 1]."""
        hidden = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Full pass: encode, sample via the reparameterization trick, decode."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
# Loss function
def vae_loss(recon_x, x, mu, logvar):
    """Return (total, reconstruction, KL) terms of the VAE objective.

    Both terms use sum reduction, so magnitudes scale with batch size;
    callers divide by the dataset size to report a per-sample loss.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, 1)), summed over batch and latent dims.
    kl_divergence = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
    return reconstruction + kl_divergence, reconstruction, kl_divergence
# Load MNIST data
def load_mnist_data(subset_size=10000, data_dir='./data'):
    """Download MNIST (train split) and return a random subset as flat arrays.

    Args:
        subset_size: number of samples to draw (default 10000, chosen for
            faster training and visualization).  Capped at the dataset size.
        data_dir: directory where torchvision caches the downloaded dataset.

    Returns:
        (data, labels): ``data`` is a float32 ndarray of shape
        (subset_size, 784) with pixel values in [0, 1]; ``labels`` is an
        int ndarray of the corresponding digit classes.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    # Guard against requesting more samples than exist.
    subset_size = min(subset_size, len(train_dataset))
    indices = torch.randperm(len(train_dataset))[:subset_size]
    images = []
    labels = []
    for idx in indices:
        img, label = train_dataset[idx]
        images.append(img.view(-1).numpy())
        labels.append(label)
    # np.stack builds the (N, 784) array directly from the flat vectors.
    return np.stack(images), np.array(labels)
# Initialize model and data.
# Module-import side effect: downloads MNIST on first run, then loads a
# 10k-sample subset into memory for all endpoints to share.
print("Loading MNIST dataset...")
vae = None  # rebound to a fresh VAE by train_vae_thread when training starts
data, labels = load_mnist_data()
data_tensor = torch.FloatTensor(data)
print(f"Loaded {len(data)} MNIST samples")
# Train the VAE in a separate thread
def train_vae_thread(epochs, batch_size, learning_rate, hidden_dim, latent_dim):
    """Train a fresh VAE on the global MNIST subset, updating training_state.

    Runs in a daemon thread started by /start_training.  Mutates module
    globals: rebinds ``vae`` to a new model and writes progress/loss values
    into ``training_state`` for the polling /training_progress endpoint.

    Args:
        epochs: number of passes over the data subset.
        batch_size: minibatch size for the DataLoader.
        learning_rate: Adam learning rate.
        hidden_dim: width of the encoder/decoder hidden layer.
        latent_dim: dimensionality of the latent space.
    """
    global vae, training_state
    training_state['is_training'] = True
    training_state['progress'] = 0
    training_state['current_epoch'] = 0
    training_state['total_epochs'] = epochs
    training_state['losses'] = []
    try:
        # Initialize new model with specified parameters
        vae = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim)
        optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
        dataset = torch.utils.data.TensorDataset(data_tensor)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for epoch in range(epochs):
            vae.train()
            total_loss = 0.0
            for batch in dataloader:
                x = batch[0]
                optimizer.zero_grad()
                recon_x, mu, logvar = vae(x)
                loss, _, _ = vae_loss(recon_x, x, mu, logvar)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            # Per-sample average (vae_loss uses reduction='sum').
            avg_loss = total_loss / len(dataloader.dataset)
            training_state['losses'].append(avg_loss)
            training_state['current_epoch'] = epoch + 1
            training_state['current_loss'] = avg_loss
            training_state['progress'] = int(((epoch + 1) / epochs) * 100)
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        # Only mark trained on a clean, complete run.
        training_state['trained'] = True
        print("Training complete!")
    finally:
        # Always clear the flag so an exception mid-training does not wedge
        # the UI in a permanent "training" state (blocking future runs).
        training_state['is_training'] = False
def fig_to_base64(fig):
    """Render a matplotlib figure to a base64-encoded PNG string and close it."""
    buffer = BytesIO()
    fig.savefig(buffer, format='png', bbox_inches='tight', dpi=100)
    # The figure is fully written to the buffer; release it immediately to
    # avoid accumulating open figures across requests.
    plt.close(fig)
    return base64.b64encode(buffer.getvalue()).decode()
# Inline page template rendered by the index route.
# NOTE(review): this reads as a plain-text skeleton of the UI copy rather
# than real HTML/JS -- presumably the markup was stripped at some point;
# confirm against the original template before shipping.
HTML_TEMPLATE = '''
VAE Interactive Playground
🧠 Variational Autoencoder Playground
Interactive visualization for understanding VAE architecture and latent space
⚙️ Training Configuration
Configure your VAE parameters and start training
📊 Training Status
Status:Not Trained
Epoch: 0 / 0
Current Loss: N/A
Training Progress
0%
Real-time Training Loss
VAE Architecture: A Variational Autoencoder learns to compress data into a lower-dimensional latent space and reconstruct it.
The key innovation is the reparameterization trick, which allows backpropagation through stochastic sampling.
Input (784D)
28×28 image flattened
↓
Encoder: FC Layer (400D)
ReLU activation
↓
Latent Space (2D)
μ (mean) and σ² (variance)
↓ Reparameterization Trick
Sample z ~ N(μ, σ²)
z = μ + σ * ε, where ε ~ N(0,1)
↓
Decoder: FC Layer (400D)
ReLU activation
↓
Output (784D)
Reconstructed image
Loss Function: VAE Loss = Reconstruction Loss (BCE) + KL Divergence
• BCE: Measures how well we reconstruct the input
• KLD: Regularizes latent space to be close to N(0,1)
⚠️ Please train the model first in the Training Dashboard before using this feature.
Latent Space Visualization
Each point represents an MNIST digit encoded in 2D latent space. Colors indicate digit classes (0-9).
Train the model first, then click button to generate...
⚠️ Please train the model first in the Training Dashboard before using this feature.
Input vs Reconstruction
See how well the VAE reconstructs MNIST digits.
Train the model first, then click button to generate...
⚠️ Please train the model first in the Training Dashboard before using this feature. Generation works best with 2D latent space.
Generate from Latent Space
Manipulate latent dimensions to generate new digit-like samples. Explore how different regions of latent space correspond to different digits!
Train the model first, then adjust sliders and click Generate...
'''
@app.route('/')
def index():
    """Serve the single-page playground UI from the inline template."""
    return render_template_string(HTML_TEMPLATE)
@app.route('/start_training', methods=['POST'])
def start_training():
    """Kick off VAE training in a background daemon thread.

    Accepts an optional JSON body with keys: epochs, batch_size,
    learning_rate, hidden_dim, latent_dim; missing keys fall back to
    defaults.  Returns {'status': 'started'}, or
    {'status': 'already_training'} if a run is already in progress.
    """
    global training_state
    if training_state['is_training']:
        return jsonify({'status': 'already_training'})
    # silent=True tolerates a missing or non-JSON body (request.json would
    # return None or raise 415 there, crashing on .get below).
    params = request.get_json(silent=True) or {}
    # Coerce explicitly so string-typed values (e.g. from form-built JSON)
    # cannot crash the training thread later.
    epochs = int(params.get('epochs', 30))
    batch_size = int(params.get('batch_size', 128))
    learning_rate = float(params.get('learning_rate', 0.001))
    hidden_dim = int(params.get('hidden_dim', 400))
    latent_dim = int(params.get('latent_dim', 2))
    # Start training in a separate daemon thread; progress is polled via
    # /training_progress.
    thread = threading.Thread(
        target=train_vae_thread,
        args=(epochs, batch_size, learning_rate, hidden_dim, latent_dim)
    )
    thread.daemon = True
    thread.start()
    return jsonify({'status': 'started'})
@app.route('/training_progress')
def training_progress():
    """Report current training status for the UI's polling loop."""
    reported = ('is_training', 'progress', 'current_epoch',
                'total_epochs', 'current_loss', 'trained')
    # Expose only the polling-relevant keys ('losses' is served elsewhere).
    return jsonify({key: training_state[key] for key in reported})
@app.route('/reset_model', methods=['POST'])
def reset_model():
    """Discard the trained model and reset all training bookkeeping.

    Refuses to reset while a training thread is running: the live thread
    keeps mutating the state dict, so a mid-run reset would be silently
    undone (and 'trained' re-set) when the thread finishes.
    """
    global vae, training_state
    if training_state['is_training']:
        return jsonify({'status': 'training_in_progress'})
    vae = None
    training_state = {
        'is_training': False,
        'progress': 0,
        'current_epoch': 0,
        'total_epochs': 0,
        'losses': [],
        'trained': False,
        'current_loss': 0
    }
    return jsonify({'status': 'reset'})
@app.route('/latent_space')
def latent_space():
    """Scatter-plot the encoder means of the MNIST subset (2D latent only)."""
    if vae is None or not training_state['trained']:
        return jsonify({'error': 'Model not trained yet. Please train the model first.'})
    if vae.latent_dim != 2:
        return jsonify({'error': 'Latent space visualization only works with 2D latent dimension.'})
    vae.eval()
    # Use the deterministic encoder means (not sampled z) as coordinates.
    with torch.no_grad():
        means, _ = vae.encode(data_tensor)
    points = means.numpy()
    fig, ax = plt.subplots(figsize=(12, 10))
    scatter = ax.scatter(points[:, 0], points[:, 1], c=labels, cmap='tab10',
                         alpha=0.6, s=30, edgecolors='black', linewidth=0.5)
    ax.set_xlabel('Latent Dimension 1', fontsize=12, fontweight='bold')
    ax.set_ylabel('Latent Dimension 2', fontsize=12, fontweight='bold')
    ax.set_title('VAE Latent Space - MNIST Digits (2D)', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    colorbar = plt.colorbar(scatter, ax=ax, ticks=range(10))
    colorbar.set_label('Digit Class', fontsize=11)
    colorbar.ax.set_yticklabels([str(digit) for digit in range(10)])
    return jsonify({'image': fig_to_base64(fig)})
@app.route('/reconstruction')
def reconstruction():
    """Show originals and VAE reconstructions for 10 random MNIST digits."""
    if vae is None or not training_state['trained']:
        return jsonify({'error': 'Model not trained yet. Please train the model first.'})
    # Pick a fresh random sample of digits each call.
    n_samples = 10
    chosen = np.random.choice(len(data), n_samples, replace=False)
    vae.eval()
    with torch.no_grad():
        originals = data_tensor[chosen]
        reconstructions, _, _ = vae(originals)
    fig, axes = plt.subplots(2, n_samples, figsize=(20, 4))
    for col in range(n_samples):
        digit = labels[chosen[col]]
        # Top row: the original input image.
        axes[0, col].imshow(originals[col].numpy().reshape(28, 28), cmap='gray')
        axes[0, col].set_title(f'Original\n(Digit {digit})', fontsize=9)
        axes[0, col].axis('off')
        # Bottom row: the model's reconstruction of that image.
        axes[1, col].imshow(reconstructions[col].numpy().reshape(28, 28), cmap='gray')
        axes[1, col].set_title('Reconstructed', fontsize=9)
        axes[1, col].axis('off')
    fig.suptitle('MNIST Reconstruction Comparison', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    return jsonify({'image': fig_to_base64(fig)})
@app.route('/generate', methods=['POST'])
def generate():
    """Decode a user-chosen latent point into a generated digit image.

    JSON body: {'z1': float, 'z2': float}.  z1/z2 fill the first (up to)
    two latent dimensions; any remaining dimensions are held at zero.
    Returns {'image': <base64 PNG>} or {'error': ...}.
    """
    if vae is None or not training_state['trained']:
        return jsonify({'error': 'Model not trained yet. Please train the model first.'})
    # NOTE: named `payload`, not `data` -- the original shadowed the
    # module-level MNIST `data` array used by sibling endpoints.
    payload = request.get_json(silent=True) or {}
    try:
        z1 = float(payload['z1'])
        z2 = float(payload['z2'])
    except (KeyError, TypeError, ValueError):
        # Missing/non-numeric values previously produced an unhandled 500.
        return jsonify({'error': 'Request must include numeric z1 and z2 values.'})
    # Build a latent vector matching the model's dimensionality.
    z = torch.zeros(1, vae.latent_dim)
    z[0, 0] = z1
    if vae.latent_dim > 1:
        # Guard: with latent_dim == 1 there is no second slot (the original
        # raised IndexError here); z2 is simply ignored in that case.
        z[0, 1] = z2
    vae.eval()
    with torch.no_grad():
        generated = vae.decode(z)
    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(generated.numpy().reshape(28, 28), cmap='gray')
    ax.set_title(f'Generated Digit\nz1={z1:.2f}, z2={z2:.2f}',
                 fontsize=13, fontweight='bold')
    ax.axis('off')
    return jsonify({'image': fig_to_base64(fig)})
@app.route('/generate_grid')
def generate_grid():
    """Decode a 15x15 grid of latent points in [-3, 3]^2 into a manifold plot."""
    if vae is None or not training_state['trained']:
        return jsonify({'error': 'Model not trained yet. Please train the model first.'})
    if vae.latent_dim != 2:
        return jsonify({'error': 'Grid generation only works with 2D latent dimension.'})
    # Sweep the latent plane over +/-3 std deviations of the prior.
    n = 15
    xs = np.linspace(-3, 3, n)
    ys = np.linspace(-3, 3, n)
    fig, axes = plt.subplots(n, n, figsize=(15, 15))
    vae.eval()
    with torch.no_grad():
        for row in range(n):
            for col in range(n):
                latent = torch.FloatTensor([[xs[col], ys[row]]])
                decoded = vae.decode(latent)
                axes[row, col].imshow(decoded.numpy().reshape(28, 28), cmap='gray')
                axes[row, col].axis('off')
    fig.suptitle('Latent Space Manifold (15×15 Grid)', fontsize=16, fontweight='bold')
    plt.tight_layout()
    return jsonify({'image': fig_to_base64(fig)})
@app.route('/training_curve')
def training_curve():
    """Plot the per-epoch training loss recorded so far."""
    losses = training_state['losses']
    if not losses:
        return jsonify({'error': 'No training data available yet.'})
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(losses, linewidth=2, color='#667eea')
    ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax.set_ylabel('Loss', fontsize=12, fontweight='bold')
    ax.set_title('VAE Training Loss Over Time', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    # Shade under the curve for readability.
    ax.fill_between(range(len(losses)), losses, alpha=0.3, color='#667eea')
    return jsonify({'image': fig_to_base64(fig)})
if __name__ == '__main__':
    # Default to 7860 (Hugging Face Spaces convention) unless PORT is set.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get('PORT', 7860)),
        debug=False,
        threaded=True,
    )