Ashutosh

Project 11: Variational Autoencoder (VAE)

Level: Advanced | Dataset: MNIST (torchvision) | Framework: PyTorch


Objective

Build a VAE to learn a continuous, structured latent space and generate new digits. Cover: reparameterization trick, ELBO loss (reconstruction + KL divergence), latent space interpolation, conditional generation.


Project Structure

11_vae_mnist/
├── notebooks/
│   ├── 01_vae_theory.ipynb
│   ├── 02_train.ipynb
│   └── 03_latent_explore.ipynb
├── data/
├── models/model.pkl
├── charts/
├── path_utils.py
├── dashboard_core.py
└── app.py

Notebook 01: Theory (01_vae_theory.ipynb)

STOP 1: AE vs VAE Core Difference

Write theory cells explaining:

  • AE: x → z (point) → x̂ (deterministic latent space)
  • VAE: x → (μ, σ) → z ~ N(μ, σ²) → x̂ (stochastic latent space)
  • Run a simple AE on 2D toy data, show the disconnected latent space
  • Agent stops here. Explain:
    • Why AE's latent space has "holes": decoder was never trained on points between training samples
    • What "holes" cause: generating from the middle of latent space gives nonsense
    • How VAE fixes this: forces latent space to be a continuous smooth Gaussian
    • The key intuition: VAE learns WHERE to put things + HOW WIDE to make the region around them
  • Wait for user confirmation before continuing

STOP 2: Probabilistic Encoder

  • In a VAE, the encoder outputs TWO vectors: mu and log_var (each shape [B, latent_dim])
  • From these, we sample: z = mu + epsilon * std where epsilon ~ N(0,1) and std = exp(0.5 * log_var)
  • Agent stops here. Explain:
    • Why we output log_var not var: log_var can be any real number, var must be positive
    • What the encoder is learning: a distribution over z, not a single point
    • The sampling process: every forward pass samples a different z (stochastic)
    • Why this stochasticity enables generation: we can sample from N(0,1) without needing an input
  • Wait for confirmation
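The sampling step above can be sketched in a few lines. This is a minimal illustration, assuming PyTorch; the batch size and the constant tensors standing in for the encoder outputs are placeholders, not part of the project spec.

```python
import torch

torch.manual_seed(0)

B, latent_dim = 4, 20  # illustrative sizes only

# Stand-ins for the two encoder outputs; a real encoder computes these from x.
mu = torch.zeros(B, latent_dim)
log_var = torch.full((B, latent_dim), -2.0)  # any real number is a valid log-variance

std = torch.exp(0.5 * log_var)  # always positive, whatever value log_var takes
eps = torch.randn_like(std)     # epsilon ~ N(0, 1)
z = mu + eps * std              # one stochastic sample

# A second pass with fresh noise gives a different z for the same (mu, log_var).
z_again = mu + torch.randn_like(std) * std
print(z.shape, bool((std > 0).all()), torch.equal(z, z_again))
```

Because `std` comes from an exponential, the encoder's last layer needs no positivity constraint.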

STOP 3: Reparameterization Trick

Write math cells:

  • Naive: z ~ N(μ, σ²); cannot backpropagate through sampling (stochastic node)
  • Reparameterized: z = μ + σ * ε, where ε ~ N(0,1); gradients flow through μ and σ
  • Implement both and show that naive breaks .backward()
  • Agent stops here. Explain:
    • Why we can't backpropagate through a sampling operation (not a deterministic function)
    • The trick: move the randomness to ε (a separate input), make z a DETERMINISTIC function of (μ, σ, ε)
    • Why gradients now flow through μ and σ: they're just parameters in z = μ + σ * ε
    • This is one of the most important tricks in modern deep learning
  • Wait for confirmation
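One possible way to show the contrast between the naive and reparameterized forms, assuming the naive version is written with `torch.distributions.Normal(...).sample()`, which draws without tracking gradients. The three-element tensors are illustrative.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

mu = torch.zeros(3, requires_grad=True)
log_var = torch.zeros(3, requires_grad=True)
std = torch.exp(0.5 * log_var)

# Naive: .sample() does not record gradients, so the graph is cut at the draw.
z_naive = Normal(mu, std).sample()
try:
    z_naive.sum().backward()
    naive_failed = False
except RuntimeError:
    naive_failed = True

# Reparameterized: the noise is an input; z is a deterministic function of (mu, std, eps).
eps = torch.randn(3)
z = mu + std * eps
z.sum().backward()

print(naive_failed)
print(mu.grad)       # dz/dmu = 1 for every element
print(log_var.grad)  # dz/dlog_var = 0.5 * std * eps
```

`Normal(mu, std).rsample()` applies the same reparameterization internally, so it is an alternative to writing `mu + std * eps` by hand.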

Notebook 02: Training (02_train.ipynb)

STOP 4: VAE Architecture

Encoder:
  Flatten (28*28=784) → Linear(784, 400) → ReLU
  → Linear(400, latent_dim*2) split into → mu [B, latent_dim], log_var [B, latent_dim]

Reparameterize: z = mu + exp(0.5*log_var) * epsilon

Decoder:
  Linear(latent_dim, 400) → ReLU
  Linear(400, 784) → Sigmoid [output in (0,1): pixel values]
  • Use latent_dim=20
  • Agent stops here. Explain:
    • Why Sigmoid at decoder output: MNIST pixels in [0,1]
    • Why latent_dim=20: enough to encode digit identity + style
    • How to split encoder output into mu and log_var: mu, log_var = out.chunk(2, dim=1) or out[:, :ld] and out[:, ld:]
    • What the 20-dimensional z represents: each dimension captures some aspect of digit variation
  • Wait for confirmation
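The architecture above could be written as a single module along these lines. It is a sketch under the stated layer sizes; the class name `VAE`, the method names, and the `hidden_dim` argument are choices made here for illustration.

```python
import torch
from torch import nn


class VAE(nn.Module):
    def __init__(self, latent_dim: int = 20, hidden_dim: int = 400):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),  # mu and log_var side by side
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 28 * 28),
            nn.Sigmoid(),  # outputs in (0, 1), matching MNIST pixel range
        )

    def encode(self, x):
        # Split the 2 * latent_dim output into two [B, latent_dim] halves.
        mu, log_var = self.encoder(x).chunk(2, dim=1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


model = VAE()
x = torch.rand(8, 1, 28, 28)  # random stand-in for a batch of MNIST images
x_hat, mu, log_var = model(x)
print(x_hat.shape, mu.shape, log_var.shape)
```

Keeping `encode`, `reparameterize`, and `decode` as separate methods makes the later notebooks simpler, since interpolation and generation call the decoder without going through the encoder.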

STOP 5: ELBO Loss Function

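import torch
import torch.nn.functional as F

def elbo_loss(x, x_hat, mu, log_var, beta=1.0):
    """Negative ELBO averaged over the batch (lower is better)."""
    x = x.flatten(1)  # [B, 1, 28, 28] -> [B, 784] so the target matches x_hat
    # Reconstruction: binary cross entropy (pixels are in [0,1])
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # KL divergence: push posterior N(mu, sigma^2) toward prior N(0,1)
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return (recon + beta * kl) / x.size(0)  # normalize by batch size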
  • Agent stops here. Explain:
    • What ELBO means: Evidence Lower BOund, a lower bound on log p(x). Maximizing the ELBO pushes the data likelihood up; the loss above is the negative ELBO, so we minimize it
    • The two terms:
      1. Reconstruction loss: AE objective β€” how well we reconstruct input
      2. KL divergence: regularizer β€” pushes z distribution toward standard Gaussian
    • Why KL toward N(0,1): so we can sample from N(0,1) at generation time
    • The KL formula: -0.5 * sum(1 + log_var - mu² - exp(log_var)); derive this
    • What β does (β-VAE): β>1 encourages more disentangled latent space
  • Wait for confirmation
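The requested derivation of the KL formula, for one latent dimension with posterior q = N(μ, σ²) and prior p = N(0, 1), can be sketched as follows. It uses only the two Gaussian log-densities and the identities E_q[(z-μ)²] = σ² and E_q[z²] = μ² + σ².

```latex
\begin{aligned}
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big)
  &= \mathbb{E}_{q}\big[\log q(z) - \log p(z)\big] \\
  &= \mathbb{E}_{q}\Big[-\tfrac12\log(2\pi\sigma^2) - \tfrac{(z-\mu)^2}{2\sigma^2}
       + \tfrac12\log(2\pi) + \tfrac{z^2}{2}\Big] \\
  &= -\tfrac12\log\sigma^2 - \tfrac12 + \tfrac12\big(\mu^2+\sigma^2\big) \\
  &= -\tfrac12\big(1 + \log\sigma^2 - \mu^2 - \sigma^2\big)
\end{aligned}
```

Substituting log_var = log σ² (so σ² = exp(log_var)) and summing over the independent latent dimensions gives the expression used in elbo_loss.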

STOP 6: KL Annealing

  • Start with β=0, linearly increase to β=1 over first 20 epochs
  • Plot reconstruction loss and KL loss separately per epoch
  • Agent stops here. Explain:
    • What KL annealing solves: "posterior collapse". Without annealing, the KL term often collapses to 0 early
    • Posterior collapse: the model ignores z and the decoder learns to reconstruct without the latent code
    • Why starting with β=0 allows the model to learn reconstruction first
    • How to detect posterior collapse: the KL term goes to 0, and mu and log_var stop depending on the input (mu ≈ 0 and log_var ≈ 0 for every x)
  • Wait for confirmation
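A linear warm-up schedule like the one described could look like this. The function name `beta_schedule` and the convention that epochs are counted from 0 (so β first reaches 1 at epoch 20) are assumptions made for this sketch.

```python
def beta_schedule(epoch: int, warmup_epochs: int = 20, beta_max: float = 1.0) -> float:
    """Linear warm-up: 0 at epoch 0, beta_max from epoch `warmup_epochs` onward."""
    return beta_max * min(1.0, epoch / warmup_epochs)


betas = [beta_schedule(e) for e in range(50)]
print(betas[0], betas[10], betas[20], betas[49])  # 0.0 0.5 1.0 1.0
```

The returned value is passed as `beta` to the loss each epoch; nothing else in the training step changes.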

STOP 7: Training Loop

  • 50 epochs, Adam lr=1e-3, batch_size=128
  • Track total ELBO, reconstruction term, KL term separately
  • Plot all three curves
  • Agent stops here. Explain:
    • Why tracking reconstruction and KL separately is important (not just total loss)
    • What healthy training looks like: KL increases gradually, reconstruction decreases
    • What unhealthy training looks like: KL → 0 (collapse) or reconstruction doesn't decrease
  • Wait for confirmation
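A compact sketch of a loop that tracks the three quantities separately. To stay self-contained it uses random tensors in place of MNIST batches and only two epochs; the `history` dictionary and the inline encoder/decoder are illustrative stand-ins, not the project's required structure.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
latent_dim = 20

# Compact stand-in for the STOP 4 model (same layer sizes).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 400), nn.ReLU(),
                        nn.Linear(400, 2 * latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                        nn.Linear(400, 784), nn.Sigmoid())
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# Random tensors in [0, 1] standing in for MNIST; swap in a real DataLoader.
batches = [torch.rand(128, 1, 28, 28) for _ in range(3)]

history = {"total": [], "recon": [], "kl": []}
for epoch in range(2):  # the project calls for 50 epochs
    beta = min(1.0, epoch / 20)  # linear KL annealing over the first 20 epochs
    sums = {"total": 0.0, "recon": 0.0, "kl": 0.0}
    n_seen = 0
    for x in batches:
        mu, log_var = encoder(x).chunk(2, dim=1)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        x_hat = decoder(z)

        recon = F.binary_cross_entropy(x_hat, x.flatten(1), reduction="sum")
        kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = (recon + beta * kl) / x.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the raw (unweighted) terms so the curves stay comparable across beta.
        sums["recon"] += recon.item()
        sums["kl"] += kl.item()
        sums["total"] += (recon + beta * kl).item()
        n_seen += x.size(0)
    for key in history:
        history[key].append(sums[key] / n_seen)  # per-sample averages for plotting

print(history)
```

Logging the unweighted KL rather than β * KL keeps the KL curve interpretable during warm-up, when β is still small.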

Notebook 03: Latent Space Exploration (03_latent_explore.ipynb)

STOP 8: 2D Latent Space Visualization

  • If latent_dim=2 (train a separate 2D version): plot all test digits in 2D z-space colored by digit label
  • If latent_dim=20: use t-SNE to project to 2D
  • Agent stops here. Explain:
    • What we expect: each digit forms a cluster, similar digits (4 vs 9, 3 vs 8) cluster closer
    • What "disentangled" means: different dimensions control different factors (style, rotation, thickness)
    • How the VAE latent space is structured compared to regular AE (smooth, no holes)
  • Wait for confirmation

STOP 9: Latent Space Interpolation

  • Encode two different digit images to get z1, z2
  • Generate 10 intermediate z values: z = (1-t)*z1 + t*z2 for t in [0,1]
  • Decode each z and display as a row of images
  • Agent stops here. Explain:
    • Why interpolation works in VAE but not in AE: VAE latent space is continuous (no holes)
    • What a smooth interpolation shows: gradual morphing from digit A to digit B
    • What a broken AE interpolation shows: sudden jumps and nonsense images in the middle
    • This is the key visual proof that VAE learns a better structured latent space
  • Wait for confirmation
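The interpolation steps could be implemented roughly as below. The decoder is an untrained stand-in and `z1`, `z2` are random vectors; in the notebook they would be the encoder means of two real digit images. The helper name `interpolate_latents` is hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim = 20

# Untrained stand-in for the trained decoder.
decoder = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                        nn.Linear(400, 784), nn.Sigmoid())


def interpolate_latents(z1, z2, steps=10):
    """Return `steps` points on the straight line from z1 to z2, endpoints included."""
    t = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # shape [steps, 1]
    return (1 - t) * z1 + t * z2                      # shape [steps, latent_dim]


z1, z2 = torch.randn(latent_dim), torch.randn(latent_dim)
path = interpolate_latents(z1, z2)
with torch.no_grad():
    frames = decoder(path).view(-1, 28, 28)  # one 28x28 image per step
print(path.shape, frames.shape)
```

Decoding all steps in one batch keeps the row of images in order from digit A to digit B.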

STOP 10: Generation from Prior

  • Sample 64 z vectors from N(0,1): z = torch.randn(64, latent_dim)
  • Decode all 64 samples
  • Display as 8×8 grid of generated digits
  • Agent stops here. Explain:
    • Why sampling from N(0,1) works: KL loss forced the posterior to be close to N(0,1)
    • What good generation looks like: recognizable digits, diverse styles
    • What bad generation looks like (poor training or posterior collapse): blurry identical images
    • The fundamental difference from AE: AE cannot generate because we don't know which z values are valid
  • Wait for confirmation
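A sketch of sampling from the prior and tiling the results into an 8×8 grid using only tensor reshapes. The decoder here is an untrained stand-in, so the output is noise; with a trained decoder the same code yields digits.

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim = 20

# Untrained stand-in for the trained decoder.
decoder = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                        nn.Linear(400, 784), nn.Sigmoid())

z = torch.randn(64, latent_dim)  # 64 samples from the N(0, I) prior
with torch.no_grad():
    samples = decoder(z)         # [64, 784], values in (0, 1)

# [row, col, h, w] -> [row, h, col, w] -> one [8*28, 8*28] image.
grid = samples.view(8, 8, 28, 28).permute(0, 2, 1, 3).reshape(8 * 28, 8 * 28)
print(grid.shape)
```

The resulting 2D tensor can be passed directly to an image plotting call such as matplotlib's `imshow` with a gray colormap.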

STOP 11: Digit-Conditioned Generation (Simple)

  • For each of 10 digit classes: find all test samples, average their z vectors → class prototype z
  • Generate 10 images from the 10 prototype z vectors
  • Agent stops here. Explain:
    • What "class prototype in latent space" means: center of the class cluster
    • How to do true conditional generation (Conditional VAE, CVAE): feed label as input to encoder and decoder
    • Why this simple approach works at all: VAE clusters same-class digits together in z-space
  • Wait for confirmation
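Computing the class prototypes amounts to a per-label mean over latent vectors. In this sketch `z_all` and `labels` are random stand-ins; in the notebook they would be the encoder means for the test set and the corresponding MNIST labels.

```python
import torch

torch.manual_seed(0)
latent_dim = 20

# Stand-in data: 50 fake latent vectors, 5 per digit class.
labels = torch.arange(10).repeat(5)
z_all = torch.randn(50, latent_dim)

# One prototype per digit: the mean latent vector of that class.
prototypes = torch.stack([z_all[labels == d].mean(dim=0) for d in range(10)])
print(prototypes.shape)
```

Passing `prototypes` through the decoder then gives one image per digit class.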

STOP 12: Reconstruction Quality

  • Pick 20 test images
  • Show side by side: original | reconstruction
  • Compute SSIM (Structural Similarity) between originals and reconstructions
  • Agent stops here. Explain:
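    • Why reconstructions look slightly blurry: the pixel-wise reconstruction loss rewards the decoder for outputting the average of all images consistent with a given z, and averaging smears fine detail
    • The VAE-GAN tradeoff: VAE → blurry but stable, GAN → sharp but training unstable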
    • What SSIM measures vs MSE: SSIM accounts for structure, not just pixel values
  • Wait for confirmation

STOP 13: Save & Generation Function

  • Save model.state_dict()
  • Write generate(n=16) → n images sampled from prior
  • Write encode_and_reconstruct(pil_image) → z_vector, reconstructed_image
  • Write interpolate(img1, img2, steps=10) → list of 10 images
  • Agent stops here. Explain:
    • Why generation requires no input (unlike all previous projects)
    • The three use modes of a trained VAE: reconstruct, encode, generate
    • Which operations should run under torch.no_grad() and which should not (generation, encoding, and reconstruction at inference time should; training steps must not)
  • Wait for confirmation
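A minimal sketch of saving weights with state_dict and generating from the prior under torch.no_grad(). The `make_decoder` helper, the temporary file name, and the extra `decoder` argument on `generate` are assumptions made to keep the example self-contained; the project's own generate(n=16) would use the loaded model instead.

```python
import os
import tempfile

import torch
from torch import nn


def make_decoder(latent_dim=20):
    # Same layer sizes as the decoder described in STOP 4.
    return nn.Sequential(
        nn.Linear(latent_dim, 400), nn.ReLU(), nn.Linear(400, 784), nn.Sigmoid()
    )


def generate(decoder, n=16, latent_dim=20):
    """Sample n latent vectors from N(0, I) and decode them into 28x28 images."""
    decoder.eval()
    with torch.no_grad():  # inference only, so skip building the autograd graph
        z = torch.randn(n, latent_dim)
        return decoder(z).view(n, 1, 28, 28)


decoder = make_decoder()  # untrained stand-in for the trained model

# Save only the weights, then load them into a freshly built module.
path = os.path.join(tempfile.mkdtemp(), "vae_decoder.pt")  # hypothetical file name
torch.save(decoder.state_dict(), path)

restored = make_decoder()
restored.load_state_dict(torch.load(path))

images = generate(restored)
print(images.shape, images.requires_grad)
```

Saving the state_dict rather than the whole module means the loading code must rebuild the same architecture first, which is why the constructor lives in a reusable helper.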

dashboard_core.py

Functions:

  • load_model() → model
  • generate_digits(n=16) → grid image
  • interpolate(z1, z2, steps=10) → list of PIL images
  • get_latent_viz() → 2D coords + labels for all test digits
  • get_training_curves() → recon_loss, kl_loss, total_loss arrays

app.py: Streamlit (~80 lines)

Sections:

  1. "Generate" button → display 8×8 grid of generated digits
  2. Upload digit image → show reconstruction + z vector values
  3. Tab 1: ELBO curve (split into recon + KL)
  4. Tab 2: t-SNE latent space scatter plot colored by digit
  5. Tab 3: Interpolation visualization between two selected digits

Key Concepts Covered

  • AE vs VAE: deterministic vs stochastic latent space
  • Reparameterization trick (the core DL trick)
  • ELBO loss = reconstruction + β * KL divergence
  • KL divergence math: -0.5 * sum(1 + log_var - mu² - exp(log_var))
  • Posterior collapse and KL annealing
  • Latent space interpolation (proof of continuity)
  • Generation from N(0,1) prior
  • β-VAE for disentanglement