from flask import Flask, render_template_string, jsonify, request import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms import base64 from io import BytesIO import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import threading import time import os app = Flask(__name__) # Global variables for training state training_state = { 'is_training': False, 'progress': 0, 'current_epoch': 0, 'total_epochs': 0, 'losses': [], 'trained': False, 'current_loss': 0 } # VAE Architecture class VAE(nn.Module): def __init__(self, input_dim=784, hidden_dim=400, latent_dim=2): super(VAE, self).__init__() # Encoder self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc_mu = nn.Linear(hidden_dim, latent_dim) self.fc_logvar = nn.Linear(hidden_dim, latent_dim) # Decoder self.fc3 = nn.Linear(latent_dim, hidden_dim) self.fc4 = nn.Linear(hidden_dim, input_dim) self.latent_dim = latent_dim def encode(self, x): h = F.relu(self.fc1(x)) mu = self.fc_mu(h) logvar = self.fc_logvar(h) return mu, logvar def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std return z def decode(self, z): h = F.relu(self.fc3(z)) return torch.sigmoid(self.fc4(h)) def forward(self, x): mu, logvar = self.encode(x) z = self.reparameterize(mu, logvar) return self.decode(z), mu, logvar # Loss function def vae_loss(recon_x, x, mu, logvar): BCE = F.binary_cross_entropy(recon_x, x, reduction='sum') KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return BCE + KLD, BCE, KLD # Load MNIST data def load_mnist_data(): transform = transforms.Compose([ transforms.ToTensor(), ]) train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) # Get subset for faster training and visualization subset_size = 10000 indices = torch.randperm(len(train_dataset))[:subset_size] data = [] labels = [] for idx in indices: img, label = train_dataset[idx] data.append(img.view(-1).numpy()) labels.append(label) return np.array(data), np.array(labels) # Initialize model and data print("Loading MNIST dataset...") vae = None data, labels = load_mnist_data() data_tensor = torch.FloatTensor(data) print(f"Loaded {len(data)} MNIST samples") # Train the VAE in a separate thread def train_vae_thread(epochs, batch_size, learning_rate, hidden_dim, latent_dim): global vae, training_state training_state['is_training'] = True training_state['progress'] = 0 training_state['current_epoch'] = 0 training_state['total_epochs'] = epochs training_state['losses'] = [] # Initialize new model with specified parameters vae = VAE(input_dim=784, hidden_dim=hidden_dim, latent_dim=latent_dim) optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate) dataset = torch.utils.data.TensorDataset(data_tensor) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) for epoch in range(epochs): vae.train() total_loss = 0 batch_count = 0 for batch in dataloader: x = batch[0] optimizer.zero_grad() recon_x, mu, logvar = vae(x) loss, _, _ = vae_loss(recon_x, x, mu, logvar) loss.backward() optimizer.step() total_loss += loss.item() batch_count += 1 avg_loss = total_loss / len(dataloader.dataset) training_state['losses'].append(avg_loss) training_state['current_epoch'] = epoch + 1 training_state['current_loss'] = avg_loss training_state['progress'] = int(((epoch + 1) / epochs) * 100) print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}") training_state['is_training'] = False training_state['trained'] = True print("Training complete!") def fig_to_base64(fig): buf = BytesIO() fig.savefig(buf, format='png', bbox_inches='tight', dpi=100) buf.seek(0) img_str = base64.b64encode(buf.read()).decode() plt.close(fig) return img_str HTML_TEMPLATE = ''' VAE Interactive Playground

🧠 Variational Autoencoder Playground

Interactive visualization for understanding VAE architecture and latent space

⚙️ Training Configuration

Configure your VAE parameters and start training

📊 Training Status

Status: Not Trained

Epoch: 0 / 0

Current Loss: N/A

VAE Architecture: A Variational Autoencoder learns to compress data into a lower-dimensional latent space and reconstruct it. The key innovation is the reparameterization trick, which allows backpropagation through stochastic sampling.

Input (784D)

28×28 image flattened

Encoder: FC Layer (400D)

ReLU activation

Latent Space (2D)

μ (mean) and σ² (variance)
↓ Reparameterization Trick

Sample z ~ N(μ, σ²)

z = μ + σ * ε, where ε ~ N(0,1)

Decoder: FC Layer (400D)

ReLU activation

Output (784D)

Reconstructed image
Loss Function: VAE Loss = Reconstruction Loss (BCE) + KL Divergence
• BCE: Measures how well we reconstruct the input
• KLD: Regularizes latent space to be close to N(0,1)
⚠️ Please train the model first in the Training Dashboard before using this feature.

Latent Space Visualization

Each point represents an MNIST digit encoded in 2D latent space. Colors indicate digit classes (0-9).

Train the model first, then click button to generate...
⚠️ Please train the model first in the Training Dashboard before using this feature.

Input vs Reconstruction

See how well the VAE reconstructs MNIST digits.

Train the model first, then click button to generate...
⚠️ Please train the model first in the Training Dashboard before using this feature. Generation works best with 2D latent space.

Generate from Latent Space

Manipulate latent dimensions to generate new digit-like samples. Explore how different regions of latent space correspond to different digits!

Train the model first, then adjust sliders and click Generate...
''' @app.route('/') def index(): return render_template_string(HTML_TEMPLATE) @app.route('/start_training', methods=['POST']) def start_training(): global training_state if training_state['is_training']: return jsonify({'status': 'already_training'}) params = request.json epochs = params.get('epochs', 30) batch_size = params.get('batch_size', 128) learning_rate = params.get('learning_rate', 0.001) hidden_dim = params.get('hidden_dim', 400) latent_dim = params.get('latent_dim', 2) # Start training in a separate thread thread = threading.Thread( target=train_vae_thread, args=(epochs, batch_size, learning_rate, hidden_dim, latent_dim) ) thread.daemon = True thread.start() return jsonify({'status': 'started'}) @app.route('/training_progress') def training_progress(): return jsonify({ 'is_training': training_state['is_training'], 'progress': training_state['progress'], 'current_epoch': training_state['current_epoch'], 'total_epochs': training_state['total_epochs'], 'current_loss': training_state['current_loss'], 'trained': training_state['trained'] }) @app.route('/reset_model', methods=['POST']) def reset_model(): global vae, training_state vae = None training_state = { 'is_training': False, 'progress': 0, 'current_epoch': 0, 'total_epochs': 0, 'losses': [], 'trained': False, 'current_loss': 0 } return jsonify({'status': 'reset'}) @app.route('/latent_space') def latent_space(): if vae is None or not training_state['trained']: return jsonify({'error': 'Model not trained yet. Please train the model first.'}) if vae.latent_dim != 2: return jsonify({'error': 'Latent space visualization only works with 2D latent dimension.'}) vae.eval() with torch.no_grad(): mu, _ = vae.encode(data_tensor) mu_np = mu.numpy() fig, ax = plt.subplots(figsize=(12, 10)) scatter = ax.scatter(mu_np[:, 0], mu_np[:, 1], c=labels, cmap='tab10', alpha=0.6, s=30, edgecolors='black', linewidth=0.5) ax.set_xlabel('Latent Dimension 1', fontsize=12, fontweight='bold') ax.set_ylabel('Latent Dimension 2', fontsize=12, fontweight='bold') ax.set_title('VAE Latent Space - MNIST Digits (2D)', fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3) cbar = plt.colorbar(scatter, ax=ax, ticks=range(10)) cbar.set_label('Digit Class', fontsize=11) cbar.ax.set_yticklabels(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']) return jsonify({'image': fig_to_base64(fig)}) @app.route('/reconstruction') def reconstruction(): if vae is None or not training_state['trained']: return jsonify({'error': 'Model not trained yet. Please train the model first.'}) # Show multiple reconstructions n_samples = 10 indices = np.random.choice(len(data), n_samples, replace=False) vae.eval() with torch.no_grad(): originals = data_tensor[indices] reconstructions, _, _ = vae(originals) fig, axes = plt.subplots(2, n_samples, figsize=(20, 4)) for i in range(n_samples): # Original axes[0, i].imshow(originals[i].numpy().reshape(28, 28), cmap='gray') axes[0, i].set_title(f'Original\n(Digit {labels[indices[i]]})', fontsize=9) axes[0, i].axis('off') # Reconstruction axes[1, i].imshow(reconstructions[i].numpy().reshape(28, 28), cmap='gray') axes[1, i].set_title('Reconstructed', fontsize=9) axes[1, i].axis('off') fig.suptitle('MNIST Reconstruction Comparison', fontsize=14, fontweight='bold', y=1.02) plt.tight_layout() return jsonify({'image': fig_to_base64(fig)}) @app.route('/generate', methods=['POST']) def generate(): if vae is None or not training_state['trained']: return jsonify({'error': 'Model not trained yet. Please train the model first.'}) data = request.json z1 = data['z1'] z2 = data['z2'] # Create latent vector with correct dimensions if vae.latent_dim == 2: z = torch.FloatTensor([[z1, z2]]) else: # For higher dimensions, use z1 and z2 for first two dims, zeros for rest z = torch.zeros(1, vae.latent_dim) z[0, 0] = z1 z[0, 1] = z2 vae.eval() with torch.no_grad(): generated = vae.decode(z) fig, ax = plt.subplots(figsize=(6, 6)) ax.imshow(generated.numpy().reshape(28, 28), cmap='gray') ax.set_title(f'Generated Digit\nz1={z1:.2f}, z2={z2:.2f}', fontsize=13, fontweight='bold') ax.axis('off') return jsonify({'image': fig_to_base64(fig)}) @app.route('/generate_grid') def generate_grid(): if vae is None or not training_state['trained']: return jsonify({'error': 'Model not trained yet. Please train the model first.'}) if vae.latent_dim != 2: return jsonify({'error': 'Grid generation only works with 2D latent dimension.'}) # Generate a grid of images by sampling latent space n = 15 grid_x = np.linspace(-3, 3, n) grid_y = np.linspace(-3, 3, n) fig, axes = plt.subplots(n, n, figsize=(15, 15)) vae.eval() with torch.no_grad(): for i, yi in enumerate(grid_y): for j, xi in enumerate(grid_x): z = torch.FloatTensor([[xi, yi]]) generated = vae.decode(z) axes[i, j].imshow(generated.numpy().reshape(28, 28), cmap='gray') axes[i, j].axis('off') fig.suptitle('Latent Space Manifold (15×15 Grid)', fontsize=16, fontweight='bold') plt.tight_layout() return jsonify({'image': fig_to_base64(fig)}) @app.route('/training_curve') def training_curve(): if not training_state['losses']: return jsonify({'error': 'No training data available yet.'}) fig, ax = plt.subplots(figsize=(10, 6)) ax.plot(training_state['losses'], linewidth=2, color='#667eea') ax.set_xlabel('Epoch', fontsize=12, fontweight='bold') ax.set_ylabel('Loss', fontsize=12, fontweight='bold') ax.set_title('VAE Training Loss Over Time', fontsize=14, fontweight='bold') ax.grid(True, alpha=0.3) ax.fill_between(range(len(training_state['losses'])), training_state['losses'], alpha=0.3, color='#667eea') return jsonify({'image': fig_to_base64(fig)}) if __name__ == '__main__': port = int(os.environ.get('PORT', 7860)) app.run(host='0.0.0.0', port=port, debug=False, threaded=True)