GAN for MNIST Digit Generation
This repository contains a Generative Adversarial Network (GAN) trained on the MNIST dataset to generate realistic handwritten digits. The model was trained as part of the Generative AI course.
Model Details
- Model Type: GAN
- Dataset: MNIST (handwritten digits)
- Generator Input: Latent vector of size 100
- Output: 28x28 grayscale images
- Framework: PyTorch
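The card does not describe the layers inside `gan.Generator`. To illustrate how a 100-dimensional latent vector can map to a 28x28 grayscale image, here is a minimal sketch that assumes a plain fully connected (MLP) design with a tanh output. `SketchGenerator`, `SketchDiscriminator`, and every layer size are hypothetical. Their weights are not compatible with `gan_mnist.pth`.

```python
import torch
from torch import nn

LATENT_DIM = 100  # matches "Generator Input" above
IMG_SIZE = 28     # 28x28 grayscale output


class SketchGenerator(nn.Module):
    """Hypothetical MLP generator: z (N, 100) -> images (N, 1, 28, 28) in [-1, 1]."""

    def __init__(self, latent_dim: int = LATENT_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, IMG_SIZE * IMG_SIZE),
            nn.Tanh(),  # assumption: training images were normalized to [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Reshape the flat 784-value output into one-channel 28x28 images
        return self.net(z).view(-1, 1, IMG_SIZE, IMG_SIZE)


class SketchDiscriminator(nn.Module):
    """Hypothetical MLP discriminator: images (N, 1, 28, 28) -> real/fake logits (N, 1)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(IMG_SIZE * IMG_SIZE, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),  # raw logit; pair with nn.BCEWithLogitsLoss
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```

A tanh output layer is a common choice when training images are scaled to [-1, 1]. If the repository's generator uses a sigmoid output or a convolutional design instead, the input and output shapes stay the same but the internals differ.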
Training Details
- Optimizer: Adam
- Learning Rate: 0.0002
- Beta1: 0.5
- Epochs: 50
- Batch Size: 64
- Weight Decay: 0.0001
- Logging: Weights & Biases
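The hyperparameters listed above map onto `torch.optim.Adam` roughly as follows. This sketch assumes the standard binary cross-entropy GAN objective with alternating discriminator and generator updates. The loss, the beta2 value of 0.999 (PyTorch's default), and the helper names `make_optimizers` and `train_step` are assumptions, not taken from the training code. Note that `weight_decay` in `torch.optim.Adam` adds an L2 penalty to the gradients rather than applying decoupled decay as `torch.optim.AdamW` does.

```python
import torch
from torch import nn


def make_optimizers(generator: nn.Module, discriminator: nn.Module):
    """Adam settings from Training Details; beta2=0.999 is an assumed default."""
    kwargs = dict(lr=0.0002, betas=(0.5, 0.999), weight_decay=0.0001)
    opt_g = torch.optim.Adam(generator.parameters(), **kwargs)
    opt_d = torch.optim.Adam(discriminator.parameters(), **kwargs)
    return opt_g, opt_d


def train_step(generator, discriminator, opt_g, opt_d, real, latent_dim=100):
    """One adversarial update on a batch of real images; returns (d_loss, g_loss)."""
    criterion = nn.BCEWithLogitsLoss()  # assumes the discriminator outputs logits
    batch_size = real.size(0)
    ones = torch.ones(batch_size, 1, device=real.device)
    zeros = torch.zeros(batch_size, 1, device=real.device)

    # Discriminator update: push real logits toward 1 and fake logits toward 0.
    # fake.detach() keeps this step from backpropagating into the generator.
    fake = generator(torch.randn(batch_size, latent_dim, device=real.device))
    d_loss = criterion(discriminator(real), ones) + criterion(discriminator(fake.detach()), zeros)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator update (non-saturating loss): make the discriminator label fakes as real.
    g_loss = criterion(discriminator(fake), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
    return d_loss.item(), g_loss.item()
```

A full run would call `train_step` once per batch of 64 MNIST images (scaled to [-1, 1] to match a tanh generator) for 50 epochs. With Weights & Biases, the returned losses could be logged each step with `wandb.log({"d_loss": d_loss, "g_loss": g_loss})` after calling `wandb.init`.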
Usage
Loading the Model
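To load the trained generator weights and sample a batch of digits, use the following snippet:
```python
import torch
from gan import Generator  # Generator class from this repository's gan module

latent_dim = 100  # must match the latent size used during training
generator = Generator(latent_dim)
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine
generator.load_state_dict(torch.load("./gan_mnist.pth", map_location="cpu"))
generator.eval()  # inference mode for layers such as dropout and batch norm

# Draw 16 latent vectors from a standard normal and generate 16 digits
with torch.no_grad():
    z = torch.randn(16, latent_dim)
    samples = generator(z)
```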
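`samples` in the snippet above is a tensor, not an image. One way to view it is to tile the batch into a single grayscale grid. This minimal sketch assumes the generator ends in tanh (pixel values in [-1, 1]) and that each sample holds 28x28 pixels in one of the layouts `(N, 784)`, `(N, 28, 28)`, or `(N, 1, 28, 28)`. `samples_to_grid` is a hypothetical helper, not part of the repository.

```python
import math

import torch


def samples_to_grid(samples: torch.Tensor, ncol: int = 4) -> torch.Tensor:
    """Tile generator outputs into one uint8 image of shape (rows*28, ncol*28).

    Assumes tanh outputs in [-1, 1]; accepts (N, 784), (N, 28, 28) or (N, 1, 28, 28).
    """
    imgs = samples.detach().cpu().reshape(-1, 28, 28)
    imgs = (imgs.clamp(-1, 1) + 1) / 2  # rescale [-1, 1] -> [0, 1]
    n = imgs.shape[0]
    rows = math.ceil(n / ncol)
    pad = rows * ncol - n
    if pad:  # fill the rest of the last row with black tiles
        imgs = torch.cat([imgs, torch.zeros(pad, 28, 28)])
    # (rows, ncol, 28, 28) -> (rows, 28, ncol, 28) -> (rows*28, ncol*28)
    grid = imgs.reshape(rows, ncol, 28, 28).permute(0, 2, 1, 3).reshape(rows * 28, ncol * 28)
    return (grid * 255).round().to(torch.uint8)
```

With Pillow installed, `Image.fromarray(samples_to_grid(samples).numpy()).save("samples.png")` writes the 16 samples as a 4x4 grid. If torchvision is available, `torchvision.utils.save_image` is an alternative.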
Example Results

License
This project is licensed under the MIT License.