VAE Model for MNIST
This is a Variational Autoencoder (VAE) model trained on the MNIST dataset.
Model Description
This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 20-dimensional latent space and decode them back into reconstructed images, enabling both data compression and the generation of new digit-like images. The architecture is based on the approach outlined in Auto-Encoding Variational Bayes (Kingma and Welling, 2013).
Architecture Details
- Model Type: Variational Autoencoder (VAE)
- Framework: PyTorch
- Input: 28×28 grayscale images
- Latent Space: 20 dimensions
- Encoder and Decoder Layers: 2
- Encoder and Decoder Hidden Channels: Conv(8) → Conv(16) (encoder), Conv(16) → Conv(8) (decoder)
- Total Parameters: ~52k
- Data type: Binary/Continuous (automatically detected)
- Current Implementation: Continuous
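The architecture above can be sketched roughly as follows. This is an illustrative reconstruction, not the exact checkpoint: the class name, kernel sizes, strides, and activations are assumptions chosen to match the listed channel counts (Conv 8 → 16), the 20-dimensional latent space, and the 28×28 input; with these choices the parameter count lands near the listed ~52k total.

```python
import torch
import torch.nn as nn

class ConvVAE(nn.Module):
    """Hypothetical sketch of the described VAE (names/kernels assumed)."""

    def __init__(self, latent_dim=20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),   # 28x28 -> 14x14
            nn.ReLU(),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),  # 14x14 -> 7x7
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(16 * 7 * 7, latent_dim)      # mean head
        self.fc_logvar = nn.Linear(16 * 7 * 7, latent_dim)  # log-variance head
        self.fc_dec = nn.Linear(latent_dim, 16 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 8, 3, stride=2, padding=1, output_padding=1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(8, 1, 3, stride=2, padding=1, output_padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps, with eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        h = self.fc_dec(z).view(-1, 16, 7, 7)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

model = ConvVAE()
recon, mu, logvar = model(torch.zeros(2, 1, 28, 28))
```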
Key Components
- Encoder Network: Maps input images to latent distribution parameters (μ, σ²)
- Reparameterization Trick: Enables differentiable sampling from the latent distribution
- Decoder Network: Reconstructs images from latent space samples
- Loss Function: Negative ELBO, i.e. a reconstruction term (Bernoulli: binary cross-entropy; Gaussian: negative log-likelihood) plus the KL divergence to the prior
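The loss described above can be sketched as follows; the function name and reduction choices are illustrative, not the repository's exact code. The KL term uses the standard closed form against an N(0, I) prior.

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, x, mu, logvar, likelihood="gaussian"):
    """Negative ELBO sketch: reconstruction term + KL divergence."""
    if likelihood == "bernoulli":
        # Binary data: binary cross-entropy reconstruction term.
        rec = F.binary_cross_entropy(recon, x, reduction="sum")
    else:
        # Continuous data: Gaussian negative log-likelihood; with unit
        # variance this reduces (up to constants) to squared error / 2.
        rec = 0.5 * ((recon - x) ** 2).sum()
    # Closed-form KL( N(mu, sigma^2) || N(0, I) )
    kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return (rec + kl) / x.size(0)  # average over the batch

x = torch.rand(4, 1, 28, 28)
# Perfect reconstruction with mu=0, logvar=0 gives zero loss.
loss = vae_loss(x, x, torch.zeros(4, 20), torch.zeros(4, 20))
```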
Training Details
- Dataset: MNIST (60,000 training images, 10,000 test images), loaded via torchvision.datasets.MNIST
- Batch Size: 128
- Epochs: 80
- Optimizer: Adam
- Learning Rate: 1e-3
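A minimal training-step sketch using the listed hyperparameters (Adam, lr=1e-3, batch size 128). Random tensors and a stand-in model/loss keep it self-contained; for real training, substitute torchvision.datasets.MNIST and the VAE with its ELBO loss.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Random tensors stand in for the MNIST training set here.
data = TensorDataset(torch.rand(512, 1, 28, 28))
loader = DataLoader(data, batch_size=128, shuffle=True)

model = torch.nn.Linear(28 * 28, 20)                   # stand-in for the VAE
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for (batch,) in loader:                                # one epoch (of 80)
    loss = model(batch.flatten(1)).pow(2).mean()       # stand-in for the ELBO loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```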
Model Performance
Metrics
- Final Training Loss: ~32.14
- Final Validation Loss: ~32.42
- Reconstruction Loss: ~19.87
- KL Divergence: ~12.91
Capabilities
- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters
Usage
Using Transformers
```python
from transformers import AutoModel
import torch

# Load model
model = AutoModel.from_pretrained("uday9k/Gaussian_MNIST_VAE")

# Generate samples
with torch.no_grad():
    z = torch.randn(1, 20)  # Sample from the 20-d prior
    generated = model.generate(z=z)

# Reshape to image
image = generated.view(28, 28).cpu().numpy()
```
Visualizations Available
- Latent Space Visualization: 2D scatter plot showing digit clusters
- Reconstructions: Original vs. reconstructed digit comparisons
- Generated Samples: New digits sampled from the latent space
- Interpolations: Smooth transitions between different digits
- Training Curves: Loss components over training epochs
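The interpolation visualization above boils down to linearly blending two latent codes and decoding each intermediate point. A sketch (the helper name is illustrative; decoding assumes a model that maps an (N, 20) latent batch to images):

```python
import torch

def interpolate(z_a, z_b, steps=10):
    """Linearly blend two latent codes into a path of `steps` points."""
    alphas = torch.linspace(0.0, 1.0, steps).view(-1, 1)
    return (1 - alphas) * z_a + alphas * z_b  # (steps, latent_dim)

# Path between two random 20-d latent codes.
z_path = interpolate(torch.randn(1, 20), torch.randn(1, 20))
# images = model.decode(z_path)  # decode each step into a 28x28 digit
```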
Files and Outputs
- MNIST_VAE_Train Continous.ipynb: Complete implementation with training and visualization
- customVAE_model2.pth: Trained model weights
- generated_samples: Grid of generated digit samples (in the notebook)
- latent_space_visualization: 2D latent space plot (in the notebook)
- reconstruction_comparison: Original vs. reconstructed images (in the notebook)
- latent_interpolation: Interpolation between digit pairs (in the notebook)
- comprehensive_training_curves: Training loss curves (in the notebook)
Applications
This VAE implementation can be used for:
- Generative Modeling: Create new handwritten digit images
- Dimensionality Reduction: Compress images to 20D representations
- Anomaly Detection: Identify unusual digits using reconstruction error
- Data Augmentation: Generate synthetic training data
- Representation Learning: Learn meaningful features for downstream tasks
- Educational Purposes: Understand VAE concepts and implementation
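The anomaly-detection application listed above can be sketched with a few lines: score each input by its per-image reconstruction error and flag high scorers. The helper name is illustrative, and `reconstruct` is a stand-in for running the trained VAE's encoder and decoder.

```python
import torch

def anomaly_scores(x, reconstruct):
    """Per-image mean squared reconstruction error."""
    recon = reconstruct(x)
    return ((recon - x) ** 2).flatten(1).mean(dim=1)

x = torch.rand(8, 1, 28, 28)
# With a real VAE, pass its reconstruction function here.
scores = anomaly_scores(x, lambda t: torch.zeros_like(t))
flagged = scores > 0.5  # threshold chosen on held-out validation data
```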
Research and Educational Value
This implementation serves as an excellent educational resource for:
- Understanding Variational Autoencoders theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization
Citation
If you use this implementation in your research or projects, please cite:
@misc{vae_mnist_implementation,
title={Variational Autoencoder Implementation for MNIST},
author={Uday Jain},
year={2026},
url={https://huggingface.co/uday9k/Gaussian_MNIST_VAE}
}
License
This project is licensed under the MIT License - see the LICENSE file for details.
Additional Resources
- GitHub Repository: Profile
Tags: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning
Model Card Authors: Uday Jain