uday9k committed · verified
Commit 4d191db · Parent(s): e7bf2eb

Upload 3 files

uploading model and related files

Files changed (3)
  1. MNIST_VAE_Train.ipynb +0 -0
  2. README.md +145 -0
  3. customVAE_model2.pth +3 -0
MNIST_VAE_Train.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
---
language: en
tags:
- vae
- generative-model
- pytorch
- mnist
- unsupervised-learning
license: mit
datasets:
- mnist
---

# VAE Model for MNIST

This is a Variational Autoencoder (VAE) model trained on the MNIST dataset.

## Model Description

This repository contains a complete implementation of a Variational Autoencoder (VAE) trained on the MNIST handwritten digits dataset. The model learns to encode images into a 20-dimensional latent space and decode them back into reconstructed images, enabling both data compression and generation of new digit-like images.
The architecture is based on the approach introduced in **Auto-Encoding Variational Bayes (Kingma & Welling, 2013)**.

### Architecture Details

- **Model Type**: Variational Autoencoder (VAE)
- **Framework**: PyTorch
- **Input**: 28×28 grayscale images (784 dimensions)
- **Latent Space**: 20 dimensions
- **Encoder and Decoder Layers**: 2 each
- **Encoder and Decoder Hidden Units**: 1024 → 512 (encoder), 1024 → 512 (decoder)
- **Total Parameters**: ~4.8M
- **Data Type**: Binary/Continuous (automatically detected)
- **Current Implementation**: Binary (pixel > 0.5)

### Key Components

1. **Encoder Network**: Maps input images to latent distribution parameters (μ, σ²)
2. **Reparameterization Trick**: Enables differentiable sampling from the latent distribution
3. **Decoder Network**: Reconstructs images from latent space samples
4. **Loss Function**: The negative ELBO, i.e. a reconstruction term (Bernoulli: binary cross-entropy; Gaussian: negative log-likelihood) plus the KL divergence

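The components above can be sketched in PyTorch as follows. This is a minimal illustration, not the repository's exact code: the class and method names (`VAE`, `encode`, `reparameterize`) are assumptions, the decoder is assumed to mirror the encoder, and the authoritative definition lives in `MNIST_VAE_Train.ipynb`.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Sketch of a 784 -> 1024 -> 512 -> 20 VAE (decoder assumed mirrored)."""

    def __init__(self, input_dim=784, latent_dim=20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 1024), nn.ReLU(),
            nn.Linear(1024, 512), nn.ReLU(),
        )
        self.fc_mu = nn.Linear(512, latent_dim)      # μ head
        self.fc_logvar = nn.Linear(512, latent_dim)  # log σ² head
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, input_dim), nn.Sigmoid(),  # Bernoulli means in (0, 1)
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # z = μ + σ·ε with ε ~ N(0, I): keeps sampling differentiable in μ, σ
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar
```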
## Training Details

- **Dataset**: MNIST via `torchvision.datasets.MNIST` (60,000 training images, 10,000 test images)
- **Batch Size**: 128
- **Epochs**: 44
- **Optimizer**: Adam
- **Learning Rate**: 1e-3

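The objective listed under Key Components (binary cross-entropy reconstruction plus the closed-form KL divergence against the unit-Gaussian prior) can be written as a standalone function. A sketch, assuming a sum-over-pixels, mean-over-batch reduction; the notebook's exact reduction may differ:

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, target, mu, logvar):
    """Negative ELBO for a Bernoulli decoder."""
    # Reconstruction term: binary cross-entropy, summed over all pixels
    bce = F.binary_cross_entropy(recon, target, reduction="sum")
    # KL(N(μ, σ²) || N(0, I)) in closed form
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (bce + kld) / target.size(0)  # average over the batch

# Inputs are binarized as stated in the model card: pixel > 0.5
x = (torch.rand(128, 784) > 0.5).float()
```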
## Model Performance

### Metrics

- **Final Training Loss**: ~79.6
- **Final Validation Loss**: ~84.3
- **Reconstruction Loss**: ~48.0
- **KL Divergence**: ~31.5

### Capabilities

- ✅ High-quality digit reconstruction
- ✅ Smooth latent space interpolation
- ✅ Generation of new digit-like samples
- ✅ Well-organized latent space with digit clusters

## Usage

### Using Transformers

```python
from transformers import AutoModel
import torch

# Load model
model = AutoModel.from_pretrained("uday9k/Binarized_MNIST_VAE")

# Generate samples
with torch.no_grad():
    z = torch.randn(1, 20)  # Sample from the standard-normal prior
    generated = model.generate(z=z)
    # Reshape to a 28×28 image
    image = generated.view(28, 28).cpu().numpy()
```

### Visualizations Available

1. **Latent Space Visualization**: 2D scatter plot showing digit clusters
2. **Reconstructions**: Original vs. reconstructed digit comparisons
3. **Generated Samples**: New digits sampled from the latent space
4. **Interpolations**: Smooth transitions between different digits
5. **Training Curves**: Loss components over training epochs

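The interpolation visualization boils down to a straight-line path between two latent codes. The sketch below only builds the latent path; decoding each row into an image frame is done with the trained decoder, whose interface is defined in the notebook:

```python
import torch

# Latent codes of two digits (in practice, the encoder means μ of two inputs)
z_a = torch.randn(20)
z_b = torch.randn(20)

alphas = torch.linspace(0.0, 1.0, steps=10)
# Linear interpolation in latent space; decode each row to get the frames
z_path = torch.stack([(1 - a) * z_a + a * z_b for a in alphas])
```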
## Files and Outputs

- `MNIST_VAE_Train.ipynb`: Complete implementation with training and visualization
- `customVAE_model2.pth`: Trained model weights
- Generated within the notebook:
  - `generated_samples`: Grid of generated digit samples
  - `latent_space_visualization.png`: 2D latent space plot
  - `reconstruction_comparison.png`: Original vs. reconstructed images
  - `latent_interpolation.png`: Interpolation between digit pairs
  - `comprehensive_training_curves.png`: Training loss curves

## Applications

This VAE implementation can be used for:

- **Generative Modeling**: Create new handwritten digit images
- **Dimensionality Reduction**: Compress images to compact latent representations
- **Anomaly Detection**: Identify unusual digits using reconstruction error
- **Data Augmentation**: Generate synthetic training data
- **Representation Learning**: Learn meaningful features for downstream tasks
- **Educational Purposes**: Understand VAE concepts and implementation

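For the anomaly-detection use case, the per-image reconstruction error doubles as an outlier score. A hedged sketch; the function names and thresholding rule are illustrative, not taken from the notebook:

```python
import torch
import torch.nn.functional as F

def reconstruction_scores(x, x_recon):
    """Per-image binary cross-entropy; larger means harder to reconstruct."""
    per_pixel = F.binary_cross_entropy(x_recon, x, reduction="none")
    return per_pixel.view(x.size(0), -1).sum(dim=1)

def flag_anomalies(scores, threshold):
    # e.g. threshold = a high percentile of scores on held-out normal data
    return scores > threshold
```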
## Research and Educational Value

This implementation serves as an educational resource for:

- Understanding Variational Autoencoder theory and practice
- Learning PyTorch implementation techniques
- Exploring generative modeling concepts
- Analyzing latent space representations
- Studying the balance between reconstruction and regularization

## Citation

If you use this implementation in your research or projects, please cite:

```bibtex
@misc{vae_mnist_implementation,
  title={Variational Autoencoder Implementation for MNIST},
  author={Uday Jain},
  year={2026},
  url={https://huggingface.co/uday9k/Binarized_MNIST_VAE}
}
```

## License

This project is licensed under the MIT License; see the LICENSE file for details.

## Additional Resources

- **GitHub Repository**: [Profile](https://github.com/SpikeStriker/)

---

**Tags**: deep-learning, generative-ai, pytorch, vae, mnist, computer-vision, unsupervised-learning

**Model Card Authors**: Uday Jain
customVAE_model2.pth ADDED
Git LFS pointer:

version https://git-lfs.github.com/spec/v1
oid sha256:4c6fac73e7c30ffb37e71029f6f8d319507bb055b22d4e2f536227ead417d806
size 36823498