Project 11: Variational Autoencoder (VAE)
Level: Advanced | Dataset: MNIST (torchvision) | Framework: PyTorch
Objective
Build a VAE to learn a continuous, structured latent space and generate new digits. Cover: reparameterization trick, ELBO loss (reconstruction + KL divergence), latent space interpolation, conditional generation.
Project Structure
11_vae_mnist/
├── notebooks/
│   ├── 01_vae_theory.ipynb
│   ├── 02_train.ipynb
│   └── 03_latent_explore.ipynb
├── data/
├── models/model.pkl
├── charts/
├── path_utils.py
├── dashboard_core.py
└── app.py
Notebook 01: Theory (01_vae_theory.ipynb)
STOP 1: AE vs VAE Core Difference
Write theory cells explaining:
- AE: x → z (point) → x̂, a deterministic latent space
- VAE: x → (μ, σ) → z ~ N(μ, σ²) → x̂, a stochastic latent space
- Run a simple AE on 2D toy data, show the disconnected latent space
- Agent stops here. Explain:
- Why AE's latent space has "holes": decoder was never trained on points between training samples
- What "holes" cause: generating from the middle of latent space gives nonsense
- How VAE fixes this: forces latent space to be a continuous smooth Gaussian
- The key intuition: VAE learns WHERE to put things + HOW WIDE to make the region around them
- Wait for user confirmation before continuing
STOP 2: Probabilistic Encoder
- In a VAE, the encoder outputs TWO vectors: mu and log_var (each shape [B, latent_dim])
- From these, we sample: z = mu + epsilon * std, where epsilon ~ N(0,1) and std = exp(0.5 * log_var)
- Agent stops here. Explain:
- Why we output log_var not var: log_var can be any real number, var must be positive
- What the encoder is learning: a distribution over z, not a single point
- The sampling process: every forward pass samples a different z (stochastic)
- Why this stochasticity enables generation: we can sample from N(0,1) without needing an input
- Wait for confirmation
STOP 3: Reparameterization Trick
Write math cells:
- Naive: z ~ N(μ, σ²); cannot backpropagate through sampling (stochastic node)
- Reparameterized: z = μ + σ * ε, where ε ~ N(0,1); gradients flow through μ and σ
- Implement both and show that naive breaks .backward()
- Agent stops here. Explain:
- Why we can't backpropagate through a sampling operation (not a deterministic function)
- The trick: move the randomness to ε (a separate input), make z a DETERMINISTIC function of (μ, σ, ε)
- Why gradients now flow through μ and σ: they're just parameters in z = μ + σ * ε
- This is one of the most important tricks in modern deep learning
- Wait for confirmation
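The contrast in STOP 3 can be sketched as follows. This is a minimal illustration, not the notebook's required implementation: it assumes the "naive" version is written with torch.distributions.Normal(...).sample(), which draws under no_grad and so returns a tensor with no graph attached. The 4-element tensors and the seed are arbitrary.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Stand-ins for encoder outputs (in the real model these come from a network).
mu = torch.zeros(4, requires_grad=True)
log_var = torch.zeros(4, requires_grad=True)
std = torch.exp(0.5 * log_var)

# Naive: sample() draws under no_grad, so z_naive is cut off from mu and log_var.
z_naive = Normal(mu, std).sample()
naive_failed = False
try:
    z_naive.sum().backward()
except RuntimeError:
    naive_failed = True  # no grad_fn: there is nothing to backpropagate through

# Reparameterized: the randomness lives in eps; z is a deterministic function
# of (mu, log_var, eps), so autograd can differentiate it.
eps = torch.randn_like(std)
z = mu + std * eps
z.sum().backward()

print("naive sample requires_grad:", z_naive.requires_grad)
print("dz/dmu:", mu.grad)            # all ones
print("dz/dlog_var:", log_var.grad)  # 0.5 * std * eps
```

Normal(...).rsample() applies the same reparameterization internally, so it is a drop-in alternative to writing mu + std * eps by hand.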
Notebook 02: Training (02_train.ipynb)
STOP 4: VAE Architecture
Encoder:
Flatten (28*28=784) → Linear(784, 400) → ReLU
→ Linear(400, latent_dim*2), split into → mu [B, latent_dim], log_var [B, latent_dim]
Reparameterize: z = mu + exp(0.5*log_var) * epsilon
Decoder:
Linear(latent_dim, 400) → ReLU
Linear(400, 784) → Sigmoid [output in (0,1), matching pixel values]
- Use latent_dim=20
- Agent stops here. Explain:
- Why Sigmoid at decoder output: MNIST pixels in [0,1]
- Why latent_dim=20: enough to encode digit identity + style
- How to split encoder output into mu and log_var: mu, log_var = out.chunk(2, dim=1) or out[:, :ld] and out[:, ld:]
- What the 20-dimensional z represents: each dimension captures some aspect of digit variation
- Wait for confirmation
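One way the STOP 4 architecture could be written as a module is sketched below. The layer sizes come from the spec above; the class name VAE, the method names, and the hidden_dim argument are assumptions made for illustration.

```python
import torch
from torch import nn


class VAE(nn.Module):
    """Fully connected VAE for 28x28 images (hypothetical naming)."""

    def __init__(self, latent_dim: int = 20, hidden_dim: int = 400):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Flatten(),                            # [B, 1, 28, 28] -> [B, 784]
            nn.Linear(28 * 28, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),   # mu and log_var, concatenated
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 28 * 28),
            nn.Sigmoid(),                            # outputs in (0, 1)
        )

    def encode(self, x):
        mu, log_var = self.encoder(x).chunk(2, dim=1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


model = VAE()
x = torch.rand(8, 1, 28, 28)  # random stand-in for a batch of MNIST images
x_hat, mu, log_var = model(x)
print(x_hat.shape, mu.shape, log_var.shape)
```

The decoder returns flat [B, 784] vectors, so the loss has to compare against a flattened x, and display code has to reshape to 28×28.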
STOP 5: ELBO Loss Function
import torch
import torch.nn.functional as F

def elbo_loss(x, x_hat, mu, log_var, beta=1.0):
    # Flatten x to [B, 784] so it matches the decoder output x_hat
    x = x.reshape(x.size(0), -1)
    # Reconstruction: binary cross entropy (pixels are in [0,1])
    recon = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # KL divergence: push posterior N(mu, sigma^2) toward prior N(0, 1)
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return (recon + beta * kl) / x.size(0)  # normalize by batch size
- Agent stops here. Explain:
- What ELBO means: Evidence Lower BOund. The ELBO is a lower bound on log p(x), so maximizing it pushes the likelihood up; with beta=1, elbo_loss returns the negative ELBO, which we minimize
- The two terms:
- Reconstruction loss: AE objective, how well we reconstruct the input
- KL divergence: regularizer that pushes the z distribution toward a standard Gaussian
- Why KL toward N(0,1): so we can sample from N(0,1) at generation time
- The KL formula: -0.5 * sum(1 + log_var - mu² - exp(log_var)); derive this
- What β does (β-VAE): β>1 encourages a more disentangled latent space
- Wait for confirmation
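The derivation requested in STOP 5 can be sketched per latent dimension as follows, taking q(z) = N(μ, σ²) as the encoder's posterior and p(z) = N(0, 1) as the prior. The symbols q and p are notation introduced here; log σ² corresponds to log_var in the code.

```latex
\begin{aligned}
D_{\mathrm{KL}}\big(q \,\|\, p\big)
  &= \mathbb{E}_{q}\big[\log q(z) - \log p(z)\big] \\
\log q(z) &= -\tfrac12 \log(2\pi\sigma^2) - \frac{(z-\mu)^2}{2\sigma^2},
  \qquad
  \log p(z) = -\tfrac12 \log(2\pi) - \frac{z^2}{2} \\
\mathbb{E}_{q}\big[(z-\mu)^2\big] &= \sigma^2,
  \qquad
  \mathbb{E}_{q}\big[z^2\big] = \mu^2 + \sigma^2 \\
D_{\mathrm{KL}}\big(q \,\|\, p\big)
  &= -\tfrac12 \log\sigma^2 - \tfrac12 + \tfrac12\big(\mu^2 + \sigma^2\big)
   = -\tfrac12\big(1 + \log\sigma^2 - \mu^2 - \sigma^2\big)
\end{aligned}
```

Because the posterior is diagonal, the total KL is the sum of this expression over all latent dimensions, which gives -0.5 * sum(1 + log_var - mu² - exp(log_var)).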
STOP 6: KL Annealing
- Start with β=0, linearly increase to β=1 over first 20 epochs
- Plot reconstruction loss and KL loss separately per epoch
- Agent stops here. Explain:
- What KL annealing solves: "posterior collapse"; without annealing, the KL term often collapses to 0 early
- Posterior collapse: model ignores z and decoder learns to reconstruct without latent code
- Why starting with β=0 allows the model to learn reconstruction first
- How to detect posterior collapse: KL term goes to 0, mu and log_var stop changing
- Wait for confirmation
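The linear schedule from STOP 6 could look like the sketch below. The function name kl_beta and its parameters are hypothetical, and it assumes epochs are counted from 0, so β first reaches 1 at epoch 20.

```python
def kl_beta(epoch: int, warmup_epochs: int = 20, beta_max: float = 1.0) -> float:
    """Linear KL warm-up: 0 at epoch 0, beta_max from warmup_epochs onward."""
    if warmup_epochs <= 0:
        return beta_max
    return beta_max * min(1.0, epoch / warmup_epochs)


schedule = [kl_beta(e) for e in range(50)]
print(schedule[0], schedule[10], schedule[20], schedule[49])  # 0.0 0.5 1.0 1.0
```

Passing the returned value as beta to elbo_loss each epoch is enough to apply the schedule; nothing else in the loss changes.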
STOP 7: Training Loop
- 50 epochs, Adam lr=1e-3, batch_size=128
- Track total ELBO, reconstruction term, KL term separately
- Plot all three curves
- Agent stops here. Explain:
- Why tracking reconstruction and KL separately is important (not just total loss)
- What healthy training looks like: KL increases gradually, reconstruction decreases
- What unhealthy training looks like: KL → 0 (collapse) or reconstruction doesn't decrease
- Wait for confirmation
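A sketch of one training epoch that tracks the three quantities from STOP 7 separately. Assumptions: TinyVAE is a compact stand-in with the STOP 4 layer sizes, train_epoch is a hypothetical helper name, and random tensors replace MNIST so the example runs without a download. Only 3 epochs are run here; the spec calls for 50.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyVAE(nn.Module):
    """Stand-in model; forward returns (x_hat, mu, log_var)."""

    def __init__(self, latent_dim=20):
        super().__init__()
        self.enc = nn.Sequential(nn.Flatten(), nn.Linear(784, 400), nn.ReLU(),
                                 nn.Linear(400, 2 * latent_dim))
        self.dec = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                                 nn.Linear(400, 784), nn.Sigmoid())

    def forward(self, x):
        mu, log_var = self.enc(x).chunk(2, dim=1)
        z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
        return self.dec(z), mu, log_var


def train_epoch(model, loader, optimizer, beta):
    """One pass over loader; returns per-sample (total, recon, kl) averages."""
    model.train()
    recon_sum, kl_sum, n = 0.0, 0.0, 0
    for (x,) in loader:
        x_flat = x.reshape(x.size(0), -1)
        x_hat, mu, log_var = model(x)
        recon = F.binary_cross_entropy(x_hat, x_flat, reduction="sum")
        kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        loss = (recon + beta * kl) / x.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        recon_sum += recon.item()
        kl_sum += kl.item()
        n += x.size(0)
    recon_avg, kl_avg = recon_sum / n, kl_sum / n
    return recon_avg + beta * kl_avg, recon_avg, kl_avg


torch.manual_seed(0)
fake_images = torch.rand(256, 1, 28, 28)  # random stand-in for MNIST
loader = DataLoader(TensorDataset(fake_images), batch_size=128, shuffle=True)
model = TinyVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

history = []
for epoch in range(3):
    beta = min(1.0, epoch / 20)  # linear KL warm-up from STOP 6
    history.append(train_epoch(model, loader, optimizer, beta))
    print(epoch, history[-1])
```

Storing the reconstruction and KL averages unweighted by β keeps the KL curve comparable across epochs while β is still being annealed.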
Notebook 03: Latent Space Exploration (03_latent_explore.ipynb)
STOP 8: 2D Latent Space Visualization
- If latent_dim=2 (train a separate 2D version): plot all test digits in 2D z-space colored by digit label
- If latent_dim=20: use t-SNE to project to 2D
- Agent stops here. Explain:
- What we expect: each digit forms a cluster, similar digits (4 vs 9, 3 vs 8) cluster closer
- What "disentangled" means: different dimensions control different factors (style, rotation, thickness)
- How the VAE latent space is structured compared to regular AE (smooth, no holes)
- Wait for confirmation
STOP 9: Latent Space Interpolation
- Encode two different digit images to get z1, z2
- Generate 10 intermediate z values: z = (1-t)*z1 + t*z2 for t in [0,1]
- Decode each z and display as a row of images
- Agent stops here. Explain:
- Why interpolation works in VAE but not in AE: VAE latent space is continuous (no holes)
- What a smooth interpolation shows: gradual morphing from digit A to digit B
- What a broken AE interpolation shows: sudden jumps and nonsense images in the middle
- This is the key visual proof that VAE learns a better structured latent space
- Wait for confirmation
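The interpolation step from STOP 9 can be sketched like this. The helper name interpolate_latents is hypothetical, z1 and z2 are random vectors standing in for encoded digits (using the encoder's mu as z is a common choice), and the untrained decoder only demonstrates shapes.

```python
import torch
from torch import nn


def interpolate_latents(z1: torch.Tensor, z2: torch.Tensor, steps: int = 10) -> torch.Tensor:
    """Return a [steps, latent_dim] tensor moving linearly from z1 to z2."""
    t = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # [steps, 1]
    return (1 - t) * z1.unsqueeze(0) + t * z2.unsqueeze(0)


torch.manual_seed(0)
latent_dim = 20
z1, z2 = torch.randn(latent_dim), torch.randn(latent_dim)  # stand-ins for encoded digits
path = interpolate_latents(z1, z2)

# Untrained stand-in for the trained decoder, used only to show the shapes.
decoder = nn.Sequential(nn.Linear(latent_dim, 784), nn.Sigmoid())
with torch.no_grad():
    frames = decoder(path).view(-1, 28, 28)
print(path.shape, frames.shape)
```

The first and last rows of path equal z1 and z2, so the displayed row starts and ends at the reconstructions of the two chosen digits.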
STOP 10: Generation from Prior
- Sample 64 z vectors from N(0,1): z = torch.randn(64, latent_dim)
- Decode all 64 samples
- Display as 8×8 grid of generated digits
- Agent stops here. Explain:
- Why sampling from N(0,1) works: KL loss forced the posterior to be close to N(0,1)
- What good generation looks like: recognizable digits, diverse styles
- What bad generation looks like (poor training or posterior collapse): blurry identical images
- The fundamental difference from AE: AE cannot generate because we don't know which z values are valid
- Wait for confirmation
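Sampling and tiling for STOP 10 might look like the sketch below. The decoder is an untrained stand-in, so the output is noise rather than digits; the grid is assembled with plain tensor ops as one possible approach (a plotting or image-grid utility would also work).

```python
import torch
from torch import nn

torch.manual_seed(0)
latent_dim = 20

# Untrained stand-in for the trained decoder from STOP 4.
decoder = nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                        nn.Linear(400, 784), nn.Sigmoid())
decoder.eval()

with torch.no_grad():
    z = torch.randn(64, latent_dim)          # sample from the N(0, 1) prior
    samples = decoder(z).view(64, 28, 28)

# Tile into an 8x8 grid: [row, col, h, w] -> [row, h, col, w] -> [224, 224]
grid = samples.view(8, 8, 28, 28).permute(0, 2, 1, 3).reshape(8 * 28, 8 * 28)
print(grid.shape)
```

The resulting grid is a single 2D tensor that can be passed directly to an image display call such as matplotlib's imshow with a grayscale colormap.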
STOP 11: Digit-Conditioned Generation (Simple)
- For each of 10 digit classes: find all test samples, average their z vectors → class prototype z
- Generate 10 images from the 10 prototype z vectors
- Agent stops here. Explain:
- What "class prototype in latent space" means: center of the class cluster
- How to do true conditional generation (Conditional VAE, CVAE): feed label as input to encoder and decoder
- Why this simple approach works at all: VAE clusters same-class digits together in z-space
- Wait for confirmation
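Computing the class prototypes from STOP 11 can be sketched as follows. The function name class_prototypes is hypothetical, and the latent codes here are synthetic (random vectors shifted by label) rather than real encoder outputs.

```python
import torch


def class_prototypes(z: torch.Tensor, labels: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    """Mean latent vector per class; returns [num_classes, latent_dim]."""
    protos = torch.zeros(num_classes, z.size(1))
    for c in range(num_classes):
        mask = labels == c
        if mask.any():  # a class with no samples keeps a zero prototype
            protos[c] = z[mask].mean(dim=0)
    return protos


torch.manual_seed(0)
labels = torch.arange(10).repeat(20)                     # 200 labels, 20 per class
z = torch.randn(200, 20) + labels.unsqueeze(1).float()   # synthetic latent codes
protos = class_prototypes(z, labels)
print(protos.shape)
```

Decoding protos with the trained decoder then yields one image per digit class.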
STOP 12: Reconstruction Quality
- Pick 20 test images
- Show side by side: original | reconstruction
- Compute SSIM (Structural Similarity) between originals and reconstructions
- Agent stops here. Explain:
- Why reconstructions look slightly blurry: VAE averages over the distribution → blurriness
- The VAE-GAN tradeoff: VAE is blurry but stable; GAN is sharp but unstable to train
- What SSIM measures vs MSE: SSIM accounts for structure, not just pixel values
- Wait for confirmation
STOP 13: Save & Generation Function
- Save model.state_dict()
- Write generate(n=16) → n images sampled from prior
- Write encode_and_reconstruct(pil_image) → z_vector, reconstructed_image
- Write interpolate(img1, img2, steps=10) → list of 10 images
- Agent stops here. Explain:
- Why generation requires no input (unlike all previous projects)
- The three use modes of a trained VAE: reconstruct, encode, generate
- Which operations should run under torch.no_grad() and which must not (generation and other inference should always use it to skip graph building; training steps must not)
- Wait for confirmation
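A sketch of the save/load round trip and a generate helper for STOP 13. Assumptions: only the decoder is shown, make_decoder is a hypothetical factory, generate takes the decoder as an extra argument (the spec's signature is generate(n=16)), and the file is written to a temporary directory instead of models/.

```python
import os
import tempfile

import torch
from torch import nn


def make_decoder(latent_dim: int = 20) -> nn.Module:
    """Hypothetical factory matching the STOP 4 decoder layout."""
    return nn.Sequential(nn.Linear(latent_dim, 400), nn.ReLU(),
                         nn.Linear(400, 784), nn.Sigmoid())


@torch.no_grad()  # inference only: no autograd graph is built
def generate(decoder: nn.Module, n: int = 16, latent_dim: int = 20) -> torch.Tensor:
    """Decode n samples from the N(0, 1) prior into [n, 28, 28] images."""
    decoder.eval()
    z = torch.randn(n, latent_dim)
    return decoder(z).view(n, 28, 28)


decoder = make_decoder()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    torch.save(decoder.state_dict(), path)        # save weights only
    restored = make_decoder()                     # rebuild the architecture first
    restored.load_state_dict(torch.load(path, map_location="cpu"))

images = generate(restored, n=16)
print(images.shape, images.requires_grad)
```

Saving the state_dict rather than the whole module means the loading code must construct the same architecture before calling load_state_dict.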
dashboard_core.py
Functions:
- load_model() → model
- generate_digits(n=16) → grid image
- interpolate(z1, z2, steps=10) → list of PIL images
- get_latent_viz() → 2D coords + labels for all test digits
- get_training_curves() → recon_loss, kl_loss, total_loss arrays
app.py: Streamlit (~80 lines)
Sections:
- "Generate" button → display 8×8 grid of generated digits
- Upload digit image → show reconstruction + z vector values
- Tab 1: ELBO curve (split into recon + KL)
- Tab 2: t-SNE latent space scatter plot colored by digit
- Tab 3: Interpolation visualization between two selected digits
Key Concepts Covered
- AE vs VAE: deterministic vs stochastic latent space
- Reparameterization trick (the core DL trick)
- ELBO loss = reconstruction + β * KL divergence
- KL divergence math: -0.5 * sum(1 + log_var - mu² - exp(log_var))
- Posterior collapse and KL annealing
- Latent space interpolation (proof of continuity)
- Generation from N(0,1) prior
- β-VAE for disentanglement