# GAN Implementation - MNIST Digit Generation
A comprehensive implementation of Generative Adversarial Networks (GANs) for generating MNIST handwritten digits using PyTorch.
## 🔥 Features
- **Complete GAN Implementation**: Both standard and optimized versions
- **MNIST Digit Generation**: Generate realistic handwritten digits
- **Multiple Training Modes**: Standard and lite modes for different performance needs
- **Comprehensive Logging**: Detailed training logs and progress tracking
- **Hardware Support**: MPS (Apple Silicon), CUDA, and CPU
- **Visualization**: Real-time training progress and generated samples
## 📊 Results
The implementation successfully generates realistic MNIST digits with:
- **Generator Parameters**: 576K (lite) / 3.5M (standard)
- **Discriminator Parameters**: 533K (lite) / 2.7M (standard)
- **Training Time**: ~5 minutes (lite mode) / ~30 minutes (standard)
## 🚀 Quick Start
### Installation
```bash
# Clone the repository
git clone https://github.com/GruheshKurra/GAN_Implementation.git
cd GAN_Implementation
# Install dependencies
pip install -r requirements.txt
```
### Usage
1. **Open the Jupyter Notebook**:
```bash
jupyter notebook Gan.ipynb
```
2. **Run the cells** to train the GAN and generate digits
3. **Choose your mode**:
- **Standard Mode**: Full implementation with detailed logging
- **Lite Mode**: Optimized for faster training and lower resource usage
## πŸ“ Project Structure
```
GAN_Implementation/
├── Gan.ipynb                                  # Main implementation notebook
├── requirements.txt                           # Python dependencies
├── README.md                                  # This file
├── Generative Adversarial Networks (GANs).md  # Theory and documentation
├── gan_training.log                           # Training logs (standard mode)
├── gan_training_lite.log                      # Training logs (lite mode)
├── generator_lite.pth                         # Saved model weights
└── data/                                      # MNIST dataset
    └── MNIST/
        └── raw/                               # Raw MNIST data files
```
## 🧠 Implementation Details
### Architecture
**Generator Network**:
- Input: Random noise vector (100D standard / 64D lite)
- Hidden layers with ReLU activations and batch normalization
- Output: 784D vector (28x28 MNIST image)
- Activation: Tanh (output range [-1, 1])
**Discriminator Network**:
- Input: 784D image vector
- Hidden layers with LeakyReLU activations and dropout
- Output: Single probability (real vs fake)
- Activation: Sigmoid
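As a sketch, the two networks described above might look like the following in PyTorch. The layer widths here are illustrative assumptions, not the notebook's exact sizes:

```python
import torch
import torch.nn as nn

LATENT_DIM = 100   # 64 in lite mode
IMG_DIM = 28 * 28  # flattened 28x28 MNIST image

# Generator: noise vector -> flattened image in [-1, 1]
generator = nn.Sequential(
    nn.Linear(LATENT_DIM, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Linear(512, IMG_DIM),
    nn.Tanh(),  # output range [-1, 1], matching the normalized data
)

# Discriminator: flattened image -> probability it is real
discriminator = nn.Sequential(
    nn.Linear(IMG_DIM, 512),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(512, 256),
    nn.LeakyReLU(0.2),
    nn.Dropout(0.3),
    nn.Linear(256, 1),
    nn.Sigmoid(),
)

z = torch.randn(16, LATENT_DIM)  # batch of random noise vectors
fake = generator(z)              # shape (16, 784)
score = discriminator(fake)      # shape (16, 1)
```

Flipping `LATENT_DIM` to 64 and shrinking the hidden widths gives the lite-mode variant of the same architecture.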
### Training Process
1. **Data Preparation**: MNIST dataset normalized to [-1, 1]
2. **Adversarial Training**:
- Discriminator learns to distinguish real vs fake images
- Generator learns to fool the discriminator
3. **Loss Function**: Binary Cross-Entropy Loss
4. **Optimization**: Adam optimizer with β₁ = 0.5, β₂ = 0.999
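The four steps above can be sketched as one training iteration. This is a minimal sketch with placeholder one-layer models and random stand-in data, not the notebook's actual networks; the learning rate of 2e-4 is an assumption, while the BCE loss and Adam betas follow the steps above:

```python
import torch
import torch.nn as nn

latent_dim = 100
generator = nn.Sequential(nn.Linear(latent_dim, 784), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(784, 1), nn.Sigmoid())

criterion = nn.BCELoss()
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

real = torch.rand(64, 784) * 2 - 1  # stand-in for a normalized MNIST batch
ones, zeros = torch.ones(64, 1), torch.zeros(64, 1)

# Discriminator step: push real -> 1 and fake -> 0
fake = generator(torch.randn(64, latent_dim)).detach()  # detach: no generator grads here
d_loss = criterion(discriminator(real), ones) + criterion(discriminator(fake), zeros)
opt_d.zero_grad()
d_loss.backward()
opt_d.step()

# Generator step: try to fool the discriminator (fake -> 1)
fake = generator(torch.randn(64, latent_dim))
g_loss = criterion(discriminator(fake), ones)
opt_g.zero_grad()
g_loss.backward()
opt_g.step()
```

In the real notebook this iteration runs over every batch of the (normalized) MNIST loader for the configured number of epochs.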
## 📈 Training Modes
### Standard Mode
- **Latent Dimension**: 100
- **Epochs**: 50-100
- **Batch Size**: 64-128
- **Dataset**: Full MNIST (60K samples)
- **Best for**: High-quality results
### Lite Mode
- **Latent Dimension**: 64
- **Epochs**: 50
- **Batch Size**: 64
- **Dataset**: Subset (10K samples)
- **Best for**: Quick experimentation and testing
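The two modes above can be captured as a small config table. The standard-mode values here pick the upper end of the ranges listed (epochs 50-100, batch size 64-128); the dict keys are illustrative names, not the notebook's:

```python
# Hyperparameters for the two training modes described above.
MODES = {
    "standard": {"latent_dim": 100, "epochs": 100, "batch_size": 128, "n_samples": 60_000},
    "lite": {"latent_dim": 64, "epochs": 50, "batch_size": 64, "n_samples": 10_000},
}

def get_config(mode: str = "lite") -> dict:
    """Return a copy of the hyperparameter set for the chosen training mode."""
    if mode not in MODES:
        raise ValueError(f"unknown mode: {mode!r}")
    return dict(MODES[mode])
```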
## 🔧 Technical Features
- **Device Auto-Detection**: Automatically uses MPS, CUDA, or CPU
- **Memory Optimization**: Efficient memory usage with cache clearing
- **Progress Tracking**: Real-time loss monitoring and sample generation
- **Model Persistence**: Save/load trained models
- **Comprehensive Logging**: Detailed training metrics and timing
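Device auto-detection and model persistence can be sketched as below. The preference order (MPS, then CUDA, then CPU) follows the feature list above; the placeholder model and temp-file path are assumptions, while the repo itself stores the lite generator as `generator_lite.pth`:

```python
import os
import tempfile

import torch
import torch.nn as nn

def autodetect_device() -> torch.device:
    """Prefer MPS (Apple Silicon), then CUDA, then fall back to CPU."""
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = autodetect_device()

# Model persistence: save and reload weights via a state dict.
model = nn.Linear(64, 784)  # placeholder standing in for the lite generator
path = os.path.join(tempfile.gettempdir(), "generator_lite.pth")
torch.save(model.state_dict(), path)

reloaded = nn.Linear(64, 784)
reloaded.load_state_dict(torch.load(path, map_location="cpu"))
```

Loading with `map_location` lets weights trained on one device (e.g. MPS) be restored on another.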
## 📊 Performance Metrics
| Mode | Training Time | Generator Loss | Discriminator Loss | Quality |
|------|---------------|----------------|-------------------|---------|
| Standard | ~30 min | ~1.5 | ~0.7 | High |
| Lite | ~5 min | ~2.0 | ~0.6 | Good |
## 🎯 Use Cases
- **Educational**: Learn GAN fundamentals with working code
- **Research**: Baseline for GAN experiments
- **Prototyping**: Quick testing of GAN modifications
- **Production**: Scalable digit generation system
## 🔗 Links & Resources
- **GitHub Repository**: [https://github.com/GruheshKurra/GAN_Implementation](https://github.com/GruheshKurra/GAN_Implementation)
- **Hugging Face**: [https://huggingface.co/karthik-2905/GAN_Implementation](https://huggingface.co/karthik-2905/GAN_Implementation)
- **Blog Post**: Coming soon on daily.dev
- **Theory Documentation**: See `Generative Adversarial Networks (GANs).md`
## πŸ› οΈ Requirements
- Python 3.7+
- PyTorch 2.0+
- torchvision 0.15+
- matplotlib 3.5+
- numpy 1.21+
- jupyter 1.0+
## πŸ“ License
This project is open source and available under the MIT License.
## 🤝 Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## 📞 Contact
- **Author**: Karthik
- **GitHub**: [@GruheshKurra](https://github.com/GruheshKurra)
## πŸ™ Acknowledgments
- Original GAN paper by Ian Goodfellow et al.
- PyTorch team for the excellent deep learning framework
- MNIST dataset creators
---
**⭐ If you find this implementation helpful, please give it a star!**