# CIFAR-10 Diffusion Model
A lightweight diffusion model trained from scratch on the CIFAR-10 dataset in just 14.5 minutes using PyTorch.
## Model Description
This is a **SimpleUNet-based diffusion model** trained to generate 32×32 RGB images similar to the CIFAR-10 dataset. The model demonstrates the fundamentals of diffusion-based image generation with a compact architecture suitable for educational purposes and quick experimentation.
### Key Features
- 🚀 **Fast Training**: Complete training in under 15 minutes on an RTX 3060
- 💾 **Lightweight**: Only 16.8M parameters (~64MB model size)
- 🎯 **Educational**: Clean, well-documented code for learning diffusion models
- ⚡ **Efficient Inference**: Generate images in seconds on consumer GPUs
## Model Details
| Attribute | Value |
|-----------|-------|
| **Architecture** | SimpleUNet with ResNet blocks + Attention |
| **Parameters** | 16,808,835 |
| **Dataset** | CIFAR-10 (50,000 training images) |
| **Image Size** | 32×32 RGB |
| **Training Steps** | 7,820 (20 epochs × 391 batches) |
| **Training Time** | 14.54 minutes |
| **Hardware** | NVIDIA RTX 3060 (0.43GB VRAM used) |
| **Framework** | PyTorch 2.0+ |
## Quick Start
### Installation
```bash
pip install torch torchvision matplotlib tqdm pillow numpy
```
### Basic Usage
```python
import torch
import matplotlib.pyplot as plt

# SimpleUNet and DDPMScheduler are defined in inference_example.py (in this repo)
from inference_example import SimpleUNet, DDPMScheduler

# Load model (map_location keeps this working on CPU-only machines)
checkpoint = torch.load('complete_diffusion_model.pth', map_location='cpu')
model = SimpleUNet(**checkpoint['model_config'])
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Initialize scheduler
scheduler = DDPMScheduler(**checkpoint['diffusion_config'])

# Generate images
@torch.no_grad()
def generate_images(model, scheduler, num_images=4):
    device = next(model.parameters()).device
    images = torch.randn(num_images, 3, 32, 32, device=device)
    for t in range(999, -1, -20):  # 50 denoising steps
        timestep = torch.full((num_images,), t, device=device, dtype=torch.long)
        noise_pred = model(images, timestep)
        # Simplified deterministic step: estimate x0, then re-noise to t-20
        alpha_t = scheduler.alpha_cumprod[t]
        alpha_prev = scheduler.alpha_cumprod[t - 20] if t >= 20 else torch.tensor(1.0)
        pred_x0 = (images - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
        images = torch.sqrt(alpha_prev) * pred_x0 + torch.sqrt(1 - alpha_prev) * noise_pred
    return images

# Generate and display (model outputs are in [-1, 1]; rescale to [0, 1])
generated = generate_images(model, scheduler)
grid = (generated.clamp(-1, 1) + 1) / 2
fig, axes = plt.subplots(1, 4, figsize=(8, 2))
for ax, img in zip(axes, grid):
    ax.imshow(img.permute(1, 2, 0).cpu().numpy())
    ax.axis('off')
plt.show()
```
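Note that the loop visits every 20th timestep of the 1,000-step schedule, and because the update is deterministic (no noise is re-injected), it acts as a coarse 50-step DDIM-style sampler rather than full ancestral DDPM sampling.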
## Training Details
- **Loss Function**: MSE between predicted and actual noise
- **Optimizer**: AdamW (lr=1e-4, weight_decay=1e-6)
- **Scheduler**: CosineAnnealingLR
- **Batch Size**: 128
- **Final Loss**: 0.0363 (73% reduction from initial)
- **Diffusion Steps**: 1000 (linear beta schedule; a minimal training-step sketch follows this list)
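For concreteness, here is a minimal sketch of one training step under the settings above: sample a random timestep, noise the batch via the closed-form forward process, and regress the model output against the true noise with an MSE loss. The beta endpoints (1e-4 to 0.02) are the standard DDPM defaults and an assumption here, as are the helper names:

```python
import torch
import torch.nn.functional as F

# Linear beta schedule (assumed standard DDPM endpoints) and cumulative alphas
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def training_step(model, x0, optimizer):
    """One DDPM step: noise a clean batch, predict the noise, take an MSE loss."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)  # random timesteps
    noise = torch.randn_like(x0)
    a = alphas_cumprod.to(x0.device)[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(a) * x0 + torch.sqrt(1 - a) * noise       # forward diffusion
    loss = F.mse_loss(model(x_t, t), noise)                    # predicted vs. actual noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```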
## Performance
### Training Loss Curve
The model shows excellent convergence:
- **Epoch 1**: 0.1349 → **Epoch 20**: 0.0363
- **Best Loss**: 0.0358 (Epoch 19)
- **Stable convergence** without overfitting
### Generation Quality
- ✅ Captures CIFAR-10 color distributions
- ✅ Generates diverse, non-repetitive outputs
- ⚠️ Abstract patterns (needs longer training for recognizable objects)
- 🎯 Suitable for color/texture generation tasks
## Files in this Repository
| File | Description | Size |
|------|-------------|------|
| `complete_diffusion_model.pth` | Full model with config and weights | ~64MB |
| `diffusion_model_final.pth` | Training checkpoint (epoch 20) | ~64MB |
| `model_info.json` | Training metadata and hyperparameters | <1KB |
| `inference_example.py` | Complete inference script with model classes | ~5KB |
## Model Architecture
```
SimpleUNet(
time_embedding: TimeEmbedding(128)
encoder: 3 ResNet blocks with downsampling
middle: ResNet + Self-Attention + ResNet
decoder: 3 ResNet blocks with upsampling
  output: GroupNorm → SiLU → Conv2d
)
```
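For readers who want to see that layout as code, below is a structural sketch of the encoder/middle/decoder pattern. It is not the shipped SimpleUNet (that lives in `inference_example.py`): the channel widths, the plain conv blocks standing in for ResNet/attention blocks, and the skip wiring are all illustrative assumptions.

```python
import math
import torch
import torch.nn as nn

def sinusoidal_embedding(t, dim):
    """Transformer-style sinusoidal timestep embedding."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / (half - 1))
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

class TinyUNet(nn.Module):
    """Structural sketch only: encoder with downsampling, middle block, decoder
    with upsampling and skip connections, GroupNorm → SiLU → Conv2d output head."""
    def __init__(self, ch=64, t_dim=128):
        super().__init__()
        self.t_mlp = nn.Sequential(nn.Linear(t_dim, ch), nn.SiLU(), nn.Linear(ch, ch))
        self.enc1 = nn.Conv2d(3, ch, 3, padding=1)                     # 32x32
        self.enc2 = nn.Conv2d(ch, ch * 2, 3, stride=2, padding=1)      # 32 -> 16
        self.enc3 = nn.Conv2d(ch * 2, ch * 4, 3, stride=2, padding=1)  # 16 -> 8
        self.mid = nn.Conv2d(ch * 4, ch * 4, 3, padding=1)
        self.dec3 = nn.ConvTranspose2d(ch * 4, ch * 2, 4, stride=2, padding=1)  # 8 -> 16
        self.dec2 = nn.ConvTranspose2d(ch * 4, ch, 4, stride=2, padding=1)      # 16 -> 32 (skip concat doubles channels)
        self.out = nn.Sequential(nn.GroupNorm(8, ch * 2), nn.SiLU(),
                                 nn.Conv2d(ch * 2, 3, 3, padding=1))
    def forward(self, x, t):
        temb = self.t_mlp(sinusoidal_embedding(t, 128))[:, :, None, None]
        h1 = torch.relu(self.enc1(x)) + temb   # inject timestep information
        h2 = torch.relu(self.enc2(h1))
        h3 = torch.relu(self.enc3(h2))
        m = torch.relu(self.mid(h3))
        d3 = torch.relu(self.dec3(m))
        d2 = torch.relu(self.dec2(torch.cat([d3, h2], dim=1)))  # skip connection
        return self.out(torch.cat([d2, h1], dim=1))             # skip connection
```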
## Use Cases
- 📚 **Educational**: Learn diffusion model fundamentals
- 🔬 **Research**: Baseline for diffusion experiments
- 🎨 **Art**: Generate abstract textures and patterns
- ⚡ **Prototyping**: Quick diffusion model testing
## Limitations & Improvements
### Current Limitations
- Generates abstract patterns rather than recognizable objects
- Trained on small 32×32 resolution
- Limited to 20 training epochs
### Suggested Improvements
1. **Extended Training**: 50-100 epochs for better object generation
2. **Larger Architecture**: Increase model capacity
3. **Advanced Sampling**: Implement DDIM or DPM-Solver++ (a DDIM sketch follows this list)
4. **Higher Resolution**: Train on 64×64 or 128×128 images
5. **Better Datasets**: Use CelebA-HQ or custom datasets
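As a starting point for improvement 3, here is a sketch of a single DDIM update (Song et al., 2020), written against the same `scheduler.alpha_cumprod` tensor the Quick Start example uses; the function name and the `eta` plumbing are illustrative:

```python
import torch

@torch.no_grad()
def ddim_step(model, images, t, t_prev, scheduler, eta=0.0):
    """One DDIM update from timestep t to t_prev. eta=0 is fully deterministic;
    eta=1 approximates stochastic DDPM-style sampling."""
    device = images.device
    timestep = torch.full((images.shape[0],), t, device=device, dtype=torch.long)
    noise_pred = model(images, timestep)

    alpha_t = scheduler.alpha_cumprod[t]
    alpha_prev = scheduler.alpha_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0)

    # Predict x0 from the noise estimate, then step toward t_prev
    pred_x0 = (images - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
    sigma = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))
    dir_xt = torch.sqrt(1 - alpha_prev - sigma ** 2) * noise_pred
    noise = sigma * torch.randn_like(images) if eta > 0 else 0.0
    return torch.sqrt(alpha_prev) * pred_x0 + dir_xt + noise
```

Because the eta=0 path is deterministic, it typically holds up well with far fewer steps (e.g., 50) than the full 1,000-step schedule.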
## Citation
```bibtex
@misc{cifar10-diffusion-2025,
title={CIFAR-10 Diffusion Model: Fast Training Implementation},
author={Karthik},
year={2025},
publisher={Hugging Face},
howpublished={\url{https://huggingface.co/karthik-2905/DiffusionPretrained}}
}
```
## License
MIT License - Free for research and commercial use.
---
**🚀 Want to train your own?** Check out the [full implementation](https://github.com/GruheshKurra/DiffusionModelPretrained) with Jupyter notebooks and step-by-step training code!

**📊 Training Stats**: 16.8M params • 14.5 min training • RTX 3060 • PyTorch 2.0