# GAN Implementation - MNIST Digit Generation
A comprehensive implementation of Generative Adversarial Networks (GANs) for generating MNIST handwritten digits using PyTorch.
## Features
- **Complete GAN Implementation**: Both standard and optimized versions
- **MNIST Digit Generation**: Generate realistic handwritten digits
- **Multiple Training Modes**: Standard and lite modes for different performance needs
- **Comprehensive Logging**: Detailed training logs and progress tracking
- **GPU Support**: MPS (Apple Silicon), CUDA, and CPU support
- **Visualization**: Real-time training progress and generated samples
## Results
The implementation successfully generates realistic MNIST digits with:
- **Generator Parameters**: 576K (lite) / 3.5M (standard)
- **Discriminator Parameters**: 533K (lite) / 2.7M (standard)
- **Training Time**: ~5 minutes (lite mode) / ~30 minutes (standard)
## Quick Start
### Installation
```bash
# Clone the repository
git clone https://github.com/GruheshKurra/GAN_Implementation.git
cd GAN_Implementation
# Install dependencies
pip install -r requirements.txt
```
### Usage
1. **Open the Jupyter Notebook**:
```bash
jupyter notebook Gan.ipynb
```
2. **Run the cells** to train the GAN and generate digits
3. **Choose your mode**:
- **Standard Mode**: Full implementation with detailed logging
- **Lite Mode**: Optimized for faster training and lower resource usage
## Project Structure
```
GAN_Implementation/
├── Gan.ipynb                                  # Main implementation notebook
├── requirements.txt                           # Python dependencies
├── README.md                                  # This file
├── Generative Adversarial Networks (GANs).md  # Theory and documentation
├── gan_training.log                           # Training logs (standard mode)
├── gan_training_lite.log                      # Training logs (lite mode)
├── generator_lite.pth                         # Saved model weights
└── data/                                      # MNIST dataset
    └── MNIST/
        └── raw/                               # Raw MNIST data files
```
## Implementation Details
### Architecture
**Generator Network**:
- Input: Random noise vector (100D standard / 64D lite)
- Hidden: Fully connected layers with BatchNorm and ReLU activations
- Output: 784D vector (a flattened 28×28 MNIST image)
- Output activation: Tanh (range [-1, 1])

**Discriminator Network**:
- Input: 784D flattened image vector
- Hidden: Fully connected layers with LeakyReLU activations and Dropout
- Output: Single probability (real vs. fake)
- Output activation: Sigmoid
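The exact layer definitions live in `Gan.ipynb`; a minimal fully connected sketch matching the description above (layer widths are illustrative, lite-mode latent size) might look like:

```python
import torch
import torch.nn as nn

LATENT_DIM = 64  # lite mode; 100 in standard mode


class Generator(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 784),  # flattened 28x28 image
            nn.Tanh(),            # output range [-1, 1]
        )

    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability that the input is real
        )

    def forward(self, x):
        return self.net(x)
```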
### Training Process
1. **Data Preparation**: MNIST dataset normalized to [-1, 1]
2. **Adversarial Training**:
- Discriminator learns to distinguish real vs fake images
- Generator learns to fool the discriminator
3. **Loss Function**: Binary Cross-Entropy Loss
4. **Optimization**: Adam optimizer with β₁ = 0.5, β₂ = 0.999
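The adversarial step described above can be sketched as a single training-step function (a simplified version of what the notebook does; `train_step` and its argument names are illustrative):

```python
import torch
import torch.nn as nn


def train_step(G, D, real_imgs, opt_G, opt_D, latent_dim, device):
    """One adversarial update: discriminator first, then generator.
    real_imgs: (B, 784) tensor normalized to [-1, 1]."""
    criterion = nn.BCELoss()
    batch = real_imgs.size(0)
    real_labels = torch.ones(batch, 1, device=device)
    fake_labels = torch.zeros(batch, 1, device=device)

    # --- Discriminator: learn to distinguish real from fake ---
    z = torch.randn(batch, latent_dim, device=device)
    fake_imgs = G(z).detach()  # don't backprop into G on the D step
    d_loss = criterion(D(real_imgs), real_labels) + \
             criterion(D(fake_imgs), fake_labels)
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # --- Generator: learn to fool the discriminator ---
    z = torch.randn(batch, latent_dim, device=device)
    g_loss = criterion(D(G(z)), real_labels)  # "pretend" fakes are real
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()
    return d_loss.item(), g_loss.item()
```

The optimizers would be constructed with the betas listed above, e.g. `torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))` (the learning rate here is a common GAN default, not taken from this README).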
## Training Modes
### Standard Mode
- **Latent Dimension**: 100
- **Epochs**: 50-100
- **Batch Size**: 64-128
- **Dataset**: Full MNIST (60K samples)
- **Best for**: High-quality results
### Lite Mode
- **Latent Dimension**: 64
- **Epochs**: 50
- **Batch Size**: 64
- **Dataset**: Subset (10K samples)
- **Best for**: Quick experimentation and testing
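The two modes could be captured in a small config dict (values taken from the lists above; the names `MODES` and `to_gan_range` are illustrative), alongside the [-1, 1] normalization used in data preparation:

```python
import torch

# Hypothetical per-mode hyperparameters mirroring the lists above
# (standard mode's epoch/batch ranges collapsed to their upper values).
MODES = {
    "standard": {"latent_dim": 100, "epochs": 100, "batch_size": 128, "subset": None},
    "lite":     {"latent_dim": 64,  "epochs": 50,  "batch_size": 64,  "subset": 10_000},
}


def to_gan_range(img_uint8: torch.Tensor) -> torch.Tensor:
    """Map a uint8 image tensor in [0, 255] to float32 in [-1, 1],
    matching the generator's Tanh output range."""
    return img_uint8.float().div(255.0).mul(2.0).sub(1.0)
```

In a torchvision pipeline the same normalization is usually written as `transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])`.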
## Technical Features
- **Device Auto-Detection**: Automatically uses MPS, CUDA, or CPU
- **Memory Optimization**: Efficient memory usage with cache clearing
- **Progress Tracking**: Real-time loss monitoring and sample generation
- **Model Persistence**: Save/load trained models
- **Comprehensive Logging**: Detailed training metrics and timing
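Device auto-detection in the MPS → CUDA → CPU order listed above is typically a few lines (`pick_device` is an illustrative name, not necessarily the notebook's):

```python
import torch


def pick_device() -> torch.device:
    """Prefer Apple-Silicon MPS, then CUDA, falling back to CPU."""
    mps = getattr(torch.backends, "mps", None)  # guard for older PyTorch builds
    if mps is not None and mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


# Model persistence via state dicts (the filename matches generator_lite.pth
# in the project tree above):
#   torch.save(G.state_dict(), "generator_lite.pth")
#   G.load_state_dict(torch.load("generator_lite.pth", map_location=pick_device()))
```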
## Performance Metrics
| Mode | Training Time | Generator Loss | Discriminator Loss | Quality |
|------|---------------|----------------|-------------------|---------|
| Standard | ~30 min | ~1.5 | ~0.7 | High |
| Lite | ~5 min | ~2.0 | ~0.6 | Good |
## Use Cases
- **Educational**: Learn GAN fundamentals with working code
- **Research**: Baseline for GAN experiments
- **Prototyping**: Quick testing of GAN modifications
- **Production**: Starting point for a scalable digit-generation service
## Links & Resources
- **GitHub Repository**: [https://github.com/GruheshKurra/GAN_Implementation](https://github.com/GruheshKurra/GAN_Implementation)
- **Hugging Face**: [https://huggingface.co/karthik-2905/GAN_Implementation](https://huggingface.co/karthik-2905/GAN_Implementation)
- **Blog Post**: [Coming Soon on daily.dev]
- **Theory Documentation**: See `Generative Adversarial Networks (GANs).md`
## Requirements
- Python 3.7+
- PyTorch 2.0+
- torchvision 0.15+
- matplotlib 3.5+
- numpy 1.21+
- jupyter 1.0+
## License
This project is open source and available under the MIT License.
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
## Contact
- **Author**: Karthik
- **GitHub**: [@GruheshKurra](https://github.com/GruheshKurra)
## Acknowledgments
- Original GAN paper by Ian Goodfellow et al.
- PyTorch team for the excellent deep learning framework
- MNIST dataset creators
---
**If you find this implementation helpful, please give it a star!**