---
license: mit
pipeline_tag: unconditional-image-generation
tags:
- diffusion
- rectified-flow
- patch-diffusion
- anime
---
# Waifu Diffusion
A 130M-parameter diffusion model trained on 10,000 anime faces (90% monochrome) using **rectified flow**, **patch diffusion**, and **CIELAB color space decoupling**.
## Model Details
- **Architecture**: Diffusion Transformer (DiT-B) with Vision RoPE
- **Parameters**: 130M
- **Training Data**: 10k anime faces (80×80), 90% corrupted to grayscale
- **Training**: 1280 epochs at batch size 256
- **Sampling**: 50-step Euler integration
### Versions
| Model | Details |
|-------|---------|
| `waifu_diffusion_1280_bs256.safetensors` | Full training (1280 epochs, bs=256) |
| `waifu_diffusion_128_bs32.safetensors` | Shallow trained version (128 epochs, bs=32) |
## Quick Start
```python
import torch
import numpy as np
from safetensors.torch import load_file
from skimage import color

# JiT is the DiT-B backbone defined in the GitHub repo linked below;
# import it from there before running this snippet.

# Load model
model = JiT(
    input_size=80,
    patch_size=4,
    in_channels=3,
    hidden_size=768,
    depth=12,
    num_heads=12,
    num_classes=1,
)
state_dict = load_file("waifu_diffusion_1280_bs256.safetensors")
model.load_state_dict(state_dict)
model.eval()

# Generate with 50-step Euler integration of the rectified flow
device = "cuda"
model.to(device)
num_steps = 50
with torch.no_grad():
    xt = torch.randn((1, 3, 80, 80), device=device)
    y = torch.zeros(1, dtype=torch.long, device=device)
    for step in range(num_steps):
        t = torch.tensor(step / num_steps, device=device)
        # The model predicts the clean sample x1; convert to a velocity
        pred_x1 = model(xt, t, y, top_idx=0, left_idx=0)
        v = (pred_x1 - xt) / max(1.0 - step / num_steps, 1e-2)
        xt = xt + v / num_steps

# Convert CIELAB → RGB (channels were scaled to [-1, 1] during training)
lab = torch.clamp(pred_x1[0], -1, 1).cpu().numpy()
L = (lab[0] + 1) * 50   # L in [0, 100]
a = lab[1] * 128        # a in [-128, 128]
b = lab[2] * 128        # b in [-128, 128]
rgb = color.lab2rgb(np.stack([L, a, b], axis=-1))
```
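To save the decoded image, convert the float RGB array (`lab2rgb` returns values in [0, 1]) to 8-bit and write it with Pillow. A minimal sketch, using a random stand-in for the decoded `rgb` array from the snippet above:

```python
import numpy as np
from PIL import Image

# Stand-in for the decoded `rgb` array (80×80×3 floats in [0, 1])
rgb = np.random.rand(80, 80, 3)

# Scale to uint8 for saving; clip guards against tiny out-of-range values
img = Image.fromarray((np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8))
img.save("sample.png")
```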
## Key Techniques
- **Rectified Flow**: Straight-line paths from noise to data (50 steps vs. 1000s for DDPM)
- **CIELAB Decoupling**: Separate luminance from color; mask gradients on monochrome → learn structure from all 10k, color from 1k
- **Patch Diffusion**: Random 40×80 px crops act as data augmentation; effectively 10k → ~50k samples
- **Vision RoPE**: 2D rotary embeddings for spatial consistency across patches
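The CIELAB decoupling above can be sketched as a masked loss: compute per-channel reconstruction error, then zero the a/b (color) terms for samples that were corrupted to grayscale, so those samples only contribute gradients through the luminance channel. This is an illustrative NumPy sketch, not the repo's training code; `is_color` is an assumed per-sample flag:

```python
import numpy as np

def masked_lab_loss(pred, target, is_color):
    """Masked MSE in CIELAB space.

    pred, target: (B, 3, H, W) arrays with channels (L, a, b)
    is_color: (B,) boolean — False for samples corrupted to grayscale
    """
    err = (pred - target) ** 2             # (B, 3, H, W)
    per_channel = err.mean(axis=(2, 3))    # (B, 3)
    # Luminance term trains on every sample
    l_loss = per_channel[:, 0]
    # Color (a, b) terms are masked out for monochrome samples
    ab_loss = per_channel[:, 1:].mean(axis=1) * is_color
    return (l_loss + ab_loss).mean()

rng = np.random.default_rng(0)
pred = rng.normal(size=(4, 3, 8, 8))
target = rng.normal(size=(4, 3, 8, 8))
is_color = np.array([True, False, False, False])  # ~10% color, like the dataset
loss = masked_lab_loss(pred, target, is_color)
```

This is how a 90%-monochrome dataset can still teach structure from all 10k faces while learning color only from the ~1k color samples.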
## Links
- **GitHub**: https://github.com/ruwwww/waifu_diffusion
- **Blog Post**: [Training a Waifu Diffusion Model](https://ruwwww.github.io/al-folio/blog/2026/waifu-diffusion/)
## Citation
```bibtex
@misc{waifu_diffusion_2026,
  author = {Abdurrahman Izzuddin Al Faruq},
  title  = {Training a Waifu Diffusion Model with Patch Diffusion and Rectified Flow},
  year   = {2026},
  url    = {https://github.com/ruwwww/waifu_diffusion}
}
```
## License
MIT