# CIFAR-10 Diffusion Model
A lightweight diffusion model trained from scratch on the CIFAR-10 dataset in just 14.5 minutes using PyTorch.
## Model Description
This is a **SimpleUNet-based diffusion model** trained to generate 32×32 RGB images resembling those in CIFAR-10. It demonstrates the fundamentals of diffusion-based image generation with a compact architecture suited to teaching and quick experimentation.
### Key Features
- 🚀 **Fast Training**: Complete training in under 15 minutes on an RTX 3060
- 💾 **Lightweight**: Only 16.8M parameters (~64MB model size)
- 🎯 **Educational**: Clean, well-documented code for learning diffusion models
- ⚡ **Efficient Inference**: Generate images in seconds on consumer GPUs
## Model Details
| Attribute | Value |
|-----------|-------|
| **Architecture** | SimpleUNet with ResNet blocks + Attention |
| **Parameters** | 16,808,835 |
| **Dataset** | CIFAR-10 (50,000 training images) |
| **Image Size** | 32×32 RGB |
| **Training Steps** | 7,820 (20 epochs × 391 batches) |
| **Training Time** | 14.54 minutes |
| **Hardware** | NVIDIA RTX 3060 (0.43 GB VRAM used) |
| **Framework** | PyTorch 2.0+ |
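
The model size and step count in the table follow from simple arithmetic. This quick check assumes float32 weights (4 bytes per parameter) and that the final partial batch of each epoch is kept. Under those assumptions the ~64MB figure corresponds to mebibytes, or roughly 67 MB in decimal units.

```python
import math

params = 16_808_835
size_bytes = params * 4                    # float32: 4 bytes per parameter (assumed)
size_mib = size_bytes / 2**20              # ≈ 64.1 MiB
steps_per_epoch = math.ceil(50_000 / 128)  # 390 full batches + 1 partial batch
total_steps = steps_per_epoch * 20         # 20 epochs
print(f"{size_mib:.1f} MiB, {steps_per_epoch} steps/epoch, {total_steps} total steps")
```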
## Quick Start
### Installation
```bash
pip install torch torchvision matplotlib tqdm pillow numpy
```
### Basic Usage
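```python
import torch
import matplotlib.pyplot as plt
# SimpleUNet and DDPMScheduler are defined in inference_example.py (see Files below);
# this import assumes that script sits next to your code and can be imported.
from inference_example import SimpleUNet, DDPMScheduler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load model (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
checkpoint = torch.load('complete_diffusion_model.pth', map_location=device)
model = SimpleUNet(**checkpoint['model_config'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device).eval()
# Initialize the noise scheduler (provides alpha_cumprod)
scheduler = DDPMScheduler(**checkpoint['diffusion_config'])
# Generate images
@torch.no_grad()
def generate_images(model, scheduler, num_images=4):
    device = next(model.parameters()).device
    images = torch.randn(num_images, 3, 32, 32, device=device)
    for t in range(999, -1, -20):  # 50 denoising steps: t = 999, 979, ..., 19
        timestep = torch.full((num_images,), t, device=device, dtype=torch.long)
        noise_pred = model(images, timestep)
        # Deterministic DDIM-style update (eta = 0), skipping 20 timesteps at a time
        alpha_t = scheduler.alpha_cumprod[t]
        # torch.sqrt needs a tensor, so the last step uses tensor(1.0) rather than a float
        alpha_prev = scheduler.alpha_cumprod[t - 20] if t >= 20 else torch.tensor(1.0)
        pred_x0 = (images - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
        images = torch.sqrt(alpha_prev) * pred_x0 + torch.sqrt(1 - alpha_prev) * noise_pred
    return images
# Generate and display
generated = generate_images(model, scheduler)
# Assumes training images were normalized to [-1, 1]; map back to [0, 1] for display
generated = ((generated.clamp(-1, 1) + 1) / 2).cpu()
fig, axes = plt.subplots(1, len(generated), figsize=(2 * len(generated), 2))
for ax, img in zip(axes, generated):
    ax.imshow(img.permute(1, 2, 0).numpy())
    ax.axis('off')
plt.show()
```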
## Training Details
- **Loss Function**: MSE between predicted and actual noise
- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-6)
- **LR Scheduler**: CosineAnnealingLR
- **Batch Size**: 128
- **Final Loss**: 0.0363 (73% below the epoch-1 loss of 0.1349)
- **Diffusion Steps**: 1000 (linear beta schedule)
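
The objective above can be sketched as a single training step. This is a minimal illustration, not the repository's training code. It assumes the common DDPM linear-schedule endpoints (beta from 1e-4 to 0.02), images scaled to [-1, 1], and a `CosineAnnealingLR` stepped once per batch with `T_max=7820`. `TinyDenoiser` is a hypothetical stand-in for `SimpleUNet` so the snippet runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

T = 1000
# Linear beta schedule; the 1e-4 -> 0.02 endpoints are the usual DDPM defaults (assumed here)
betas = torch.linspace(1e-4, 0.02, T)
alpha_cumprod = torch.cumprod(1.0 - betas, dim=0)

class TinyDenoiser(nn.Module):
    """Hypothetical stand-in for SimpleUNet; ignores t to keep the example small."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)

def training_step(model, x0, optimizer):
    """One noise-prediction step on a batch x0 scaled to [-1, 1] (assumed normalization)."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)
    noise = torch.randn_like(x0)
    a_bar = alpha_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    # Forward process: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise
    loss = F.mse_loss(model(x_t, t), noise)  # MSE between predicted and actual noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

model = TinyDenoiser()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
# T_max=7820 assumes the LR schedule is stepped once per batch over all 20 epochs
lr_schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=7820)
x0 = torch.rand(8, 3, 32, 32) * 2 - 1  # random stand-in for a CIFAR-10 batch
loss = training_step(model, x0, optimizer)
lr_schedule.step()
```

Predicting the noise rather than the clean image keeps the regression target at unit variance for every timestep, which is why a plain MSE loss works across the whole schedule.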
## Performance
### Training Loss Curve
The training loss decreases steadily:
- **Epoch 1**: 0.1349 → **Epoch 20**: 0.0363
- **Best Loss**: 0.0358 (Epoch 19)
- **Stable convergence** of the training loss (no validation loss is reported, so overfitting is not measured)
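
Assuming "initial" refers to the epoch-1 loss listed here, the quoted 73% reduction can be reproduced directly:

```python
epoch1_loss, epoch20_loss, best_loss = 0.1349, 0.0363, 0.0358
reduction = (epoch1_loss - epoch20_loss) / epoch1_loss     # ≈ 0.731, the quoted 73%
best_reduction = (epoch1_loss - best_loss) / epoch1_loss   # ≈ 0.735, reached at epoch 19
print(f"final: {reduction:.1%}, best: {best_reduction:.1%}")
```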
### Generation Quality
- ✅ Captures CIFAR-10 color distributions
- ✅ Generates diverse, non-repetitive outputs
- ⚠️ Produces mostly abstract patterns (recognizable objects likely need longer training)
- 🎯 Suitable for color/texture generation tasks
## Files in this Repository
| File | Description | Size |
|------|-------------|------|
| `complete_diffusion_model.pth` | Full model with config and weights | ~64MB |
| `diffusion_model_final.pth` | Training checkpoint (epoch 20) | ~64MB |
| `model_info.json` | Training metadata and hyperparameters | <1KB |
| `inference_example.py` | Complete inference script with model classes | ~5KB |
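
`complete_diffusion_model.pth` bundles the model configuration with the weights, which is what lets the Quick Start rebuild the model from `checkpoint['model_config']`. The sketch below shows that save/load pattern using the same three keys. The stand-in model (`nn.Conv2d`), its config values, and the `num_timesteps` key are hypothetical, and an in-memory buffer replaces the file.

```python
import io
import torch
import torch.nn as nn

# Hypothetical config; the real checkpoint stores SimpleUNet's constructor arguments
model_config = {'in_channels': 3, 'out_channels': 3, 'kernel_size': 3, 'padding': 1}
model = nn.Conv2d(**model_config)  # stand-in for SimpleUNet

checkpoint = {
    'model_config': model_config,
    'model_state_dict': model.state_dict(),
    'diffusion_config': {'num_timesteps': 1000},  # hypothetical key name
}
buffer = io.BytesIO()
torch.save(checkpoint, buffer)

# Rebuild the model from the stored config, then load the weights into it
buffer.seek(0)
loaded = torch.load(buffer, map_location='cpu')
restored = nn.Conv2d(**loaded['model_config'])
restored.load_state_dict(loaded['model_state_dict'])
```

Storing plain dicts and a `state_dict` (rather than a pickled module) keeps the checkpoint loadable even if the class code is refactored, as long as the constructor arguments still match.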
## Model Architecture
```
SimpleUNet(
  time_embedding: TimeEmbedding(128)
  encoder: 3 ResNet blocks with downsampling
  middle: ResNet + Self-Attention + ResNet
  decoder: 3 ResNet blocks with upsampling
  output: GroupNorm → SiLU → Conv2d
)
```
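
The summary does not show how `TimeEmbedding(128)` or the ResNet blocks are implemented. A common design, sketched here as an assumption rather than the repository's exact code, maps each timestep to 128 sinusoidal features and adds a learned projection of them inside every residual block. `sinusoidal_embedding` and `ResBlock` are hypothetical names.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def sinusoidal_embedding(t, dim=128):
    """Hypothetical: map integer timesteps (B,) to (B, dim) sin/cos features."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32, device=t.device) / half)
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class ResBlock(nn.Module):
    """Hypothetical ResNet block that injects the time embedding between two convs."""
    def __init__(self, in_ch, out_ch, time_dim=128, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_ch)
        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # 1x1 conv so the skip path matches out_ch when the channel count changes
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.time_proj(t_emb)[:, :, None, None]  # broadcast over H and W
        h = self.conv2(F.silu(self.norm2(h)))
        return h + self.skip(x)

t = torch.tensor([0, 500, 999])
emb = sinusoidal_embedding(t)                  # shape (3, 128)
block = ResBlock(32, 64)
out = block(torch.randn(3, 32, 16, 16), emb)   # shape (3, 64, 16, 16)
```

Adding the projected embedding as a per-channel bias lets every block condition on the noise level without changing the spatial shape of the feature maps.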
## Use Cases
- 📚 **Educational**: Learn diffusion model fundamentals
- 🔬 **Research**: Baseline for diffusion experiments
- 🎨 **Art**: Generate abstract textures and patterns
- ⚡ **Prototyping**: Quick diffusion model testing
## Limitations & Improvements
### Current Limitations
- Generates abstract patterns rather than recognizable objects
- Trained at a low 32×32 resolution
- Trained for only 20 epochs
### Suggested Improvements
1. **Extended Training**: 50-100 epochs for better object generation
2. **Larger Architecture**: Increase model capacity
3. **Advanced Sampling**: Implement DDIM or DPM-Solver++
4. **Higher Resolution**: Train on 64×64 or 128×128 images
5. **Better Datasets**: Use CelebA-HQ or custom datasets
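
For the sampling item, note that the Quick Start loop already performs a deterministic DDIM-style update. The sketch below is a general DDIM step with the `eta` parameter, written from the published DDIM update rule rather than taken from this repository. `eta=0` reproduces the deterministic update, and `eta=1` adds DDPM-like noise. `a_t` and `a_prev` are cumulative alpha products, like the values in `scheduler.alpha_cumprod`. DPM-Solver++ is not covered.

```python
import torch

def ddim_step(x_t, eps, a_t, a_prev, eta=0.0):
    """Sketch of one DDIM update from cumulative alpha a_t to a_prev (a_prev = 1.0 at the last step).

    eta=0.0 is deterministic (the same update as the Quick Start loop);
    eta=1.0 adds noise matching the DDPM posterior variance for this step.
    """
    a_t = torch.as_tensor(a_t, dtype=x_t.dtype, device=x_t.device)
    a_prev = torch.as_tensor(a_prev, dtype=x_t.dtype, device=x_t.device)
    pred_x0 = (x_t - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)
    sigma = eta * torch.sqrt((1 - a_prev) / (1 - a_t)) * torch.sqrt(1 - a_t / a_prev)
    # clamp guards against tiny negative values from floating-point rounding
    direction = torch.sqrt((1 - a_prev - sigma**2).clamp(min=0)) * eps
    return torch.sqrt(a_prev) * pred_x0 + direction + sigma * torch.randn_like(x_t)

# Usage with random stand-ins for the current sample and the model's noise prediction
x_t = torch.randn(4, 3, 32, 32)
eps = torch.randn_like(x_t)
x_prev = ddim_step(x_t, eps, a_t=0.5, a_prev=0.8, eta=1.0)
```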
## Citation
```bibtex
@misc{cifar10-diffusion-2025,
  title={CIFAR-10 Diffusion Model: Fast Training Implementation},
  author={Karthik},
  year={2025},
  publisher={Hugging Face},
  howpublished={\url{https://huggingface.co/karthik-2905/DiffusionPretrained}}
}
```
## License
MIT License: free for research and commercial use.

---

**🚀 Want to train your own?** Check out the [full implementation](https://github.com/GruheshKurra/DiffusionModelPretrained) with Jupyter notebooks and step-by-step training code!

**📊 Training Stats**: 16.8M params • 14.5 min training • RTX 3060 • PyTorch 2.0