|
|
--- |
|
|
license: cc-by-nc-sa-4.0 |
|
|
tags: |
|
|
- flow-matching |
|
|
- generative-model |
|
|
- image-generation |
|
|
- pytorch |
|
|
datasets: |
|
|
- mnist |
|
|
- cifar10 |
|
|
- celeba |
|
|
base_model: keishihara/flow-matching |
|
|
--- |
|
|
|
|
|
# UNet Flow Matching Models |
|
|
|
|
|
Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets. |
|
|
|
|
|
**Training code based on**: [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git) |
|
|
|
|
|
## Models |
|
|
|
|
|
This repository contains three UNet-based velocity field models trained with Flow Matching: |
|
|
|
|
|
### MNIST (28×28 Grayscale) |
|
|
- **Checkpoint**: `mnist/ckpt.pth` (24 MB) |
|
|
- **Parameters**: 6.2M |
|
|
- **Architecture**: UNet with num_channels=64, num_res_blocks=2 |
|
|
- **Conditional**: Yes (10 classes, digits 0–9)
|
|
- **Training**: 50 epochs, batch_size=128, lr=1e-3 |
|
|
- **Hardware**: NVIDIA H100 GPU |
|
|
|
|
|
### CIFAR-10 (32×32 RGB) |
|
|
- **Checkpoint**: `cifar10/ckpt.pth` (35 MB) |
|
|
- **Parameters**: 9.0M |
|
|
- **Architecture**: UNet with num_channels=64, num_res_blocks=2 |
|
|
- **Conditional**: Yes (10 classes) |
|
|
- **Training**: 50 epochs, batch_size=128, lr=1e-3 |
|
|
- **Hardware**: NVIDIA H100 GPU |
|
|
|
|
|
### CelebA (64×64 RGB) |
|
|
- **Checkpoint**: `celeba64/ckpt.pth` (332 MB) |
|
|
- **Parameters**: 83.0M |
|
|
- **Architecture**: UNet with num_channels=128, num_res_blocks=2 |
|
|
- **Conditional**: No (unconditional face generation) |
|
|
- **Training**: 50 epochs, batch_size=512, lr=1e-4 |
|
|
- **Dataset**: 202,599 CelebA training images |
|
|
- **Final loss**: 0.114 |
|
|
- **Hardware**: NVIDIA H100 GPU |
|
|
|
|
|
## Sample Results |
|
|
|
|
|
### MNIST |
|
|
 |
|
|
*Generated MNIST digits at different velocity reuse thresholds* |
|
|
|
|
|
### CIFAR-10 |
|
|
 |
|
|
*Generated CIFAR-10 images at different velocity reuse thresholds* |
|
|
|
|
|
### CelebA 64×64 |
|
|
 |
|
|
*Generated 64×64 faces at different velocity reuse thresholds* |
|
|
|
|
|
## Training Code |
|
|
|
|
|
The models were trained using the Flow Matching implementation based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git). |
|
|
|
|
|
### Training Scripts |
|
|
|
|
|
**MNIST**: |
|
|
```bash
|
|
# train_flow_matching_on_images.py |
|
|
python train_flow_matching_on_images.py \ |
|
|
--do_train \ |
|
|
--dataset mnist \ |
|
|
--n_epochs 50 \ |
|
|
--batch_size 128 \ |
|
|
--learning_rate 1e-3 |
|
|
``` |
|
|
|
|
|
**CIFAR-10**: |
|
|
```bash
|
|
python train_flow_matching_on_images.py \ |
|
|
--do_train \ |
|
|
--dataset cifar10 \ |
|
|
--n_epochs 50 \ |
|
|
--batch_size 128 \ |
|
|
--learning_rate 1e-3 \ |
|
|
--horizontal_flip |
|
|
``` |
|
|
|
|
|
**CelebA**: |
|
|
```bash
|
|
# train_celeba64.py |
|
|
python train_celeba64.py \ |
|
|
--do_train \ |
|
|
--n_epochs 50 \ |
|
|
--batch_size 512 \ |
|
|
--learning_rate 1e-4 \ |
|
|
--horizontal_flip |
|
|
``` |
|
|
|
|
|
Training code files included: |
|
|
- `train_flow_matching_on_images.py` - For MNIST and CIFAR-10 |
|
|
- `train_celeba64.py` - For CelebA 64×64 |
|
|
|
|
|
## Usage |
|
|
|
|
|
### Load Model |
|
|
|
|
|
```python |
|
|
import torch |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Download checkpoint |
|
|
ckpt_path = hf_hub_download( |
|
|
repo_id="WayBob/FlowMatching-Unet-Celeb-64x64", |
|
|
filename="celeba64/ckpt.pth" |
|
|
) |
|
|
|
|
|
# Load checkpoint |
|
|
checkpoint = torch.load(ckpt_path, map_location="cuda") |
|
|
``` |
|
|
|
|
|
### Inference (Sampling) |
|
|
|
|
|
```python |
|
|
import torch |
|
|
from flow_matching.models import UNetModel |
|
|
from flow_matching.solver import ODESolver, ModelWrapper |
|
|
|
|
|
device = "cuda" |
|
|
|
|
|
# Create model (CelebA example) |
|
|
flow = UNetModel( |
|
|
dim=(3, 64, 64), |
|
|
num_channels=128, |
|
|
num_res_blocks=2, |
|
|
num_classes=0, |
|
|
class_cond=False, |
|
|
).to(device) |
|
|
|
|
|
# Load weights |
|
|
flow.load_state_dict(checkpoint) |
|
|
flow.eval() |
|
|
|
|
|
# Create solver |
|
|
model_wrapper = ModelWrapper(flow) |
|
|
solver = ODESolver(model_wrapper) |
|
|
|
|
|
# Sample from Gaussian noise |
|
|
batch_size = 4 |
|
|
x_init = torch.randn(batch_size, 3, 64, 64).to(device) |
|
|
time_grid = torch.linspace(0, 1, 21).to(device) # 20 steps |
|
|
|
|
|
with torch.no_grad(): |
|
|
samples = solver.sample( |
|
|
x_init=x_init, |
|
|
step_size=0.05, |
|
|
method="euler", |
|
|
time_grid=time_grid |
|
|
) |
|
|
|
|
|
# Denormalize from [-1, 1] to [0, 1] |
|
|
samples = (samples + 1) / 2 |
|
|
samples = samples.clamp(0, 1) |
|
|
|
|
|
# Save or visualize |
|
|
from torchvision.utils import save_image |
|
|
save_image(samples, "generated_faces.png", nrow=2) |
|
|
``` |
|
|
|
|
|
### Conditional Generation (MNIST/CIFAR-10) |
|
|
|
|
|
```python |
|
|
# For class-conditional models |
|
|
flow = UNetModel( |
|
|
dim=(3, 32, 32), # CIFAR-10 |
|
|
num_channels=64, |
|
|
num_res_blocks=2, |
|
|
num_classes=10, |
|
|
class_cond=True, |
|
|
).to(device) |
|
|
|
|
|
# Load CIFAR-10 checkpoint |
|
|
ckpt = torch.load("cifar10/ckpt.pth") |
|
|
flow.load_state_dict(ckpt) |
|
|
|
|
|
# Generate specific class (e.g., class 3) |
|
|
y = torch.tensor([3, 3, 3, 3]).to(device) # Batch of 4, all class 3 |
|
|
|
|
|
def ode_func(t, x): |
|
|
return flow(x=x, t=t, y=y) |
|
|
|
|
|
# Then use solver as before |
|
|
``` |
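The "use solver as before" step can also be written as a plain fixed-step Euler loop, with no solver wrapper at all. Below is a minimal sketch: the lambda is a dummy velocity field standing in for the trained UNet's `ode_func`, so shapes and the `(t, x)` calling convention are assumptions, not the repository's exact API.

```python
import torch

@torch.no_grad()
def euler_sample(ode_func, x_init: torch.Tensor, n_steps: int = 20) -> torch.Tensor:
    """Integrate dx/dt = ode_func(t, x) from t=0 to t=1 with fixed Euler steps."""
    x = x_init
    dt = 1.0 / n_steps
    for i in range(n_steps):
        # Broadcast the scalar time to one entry per batch element.
        t = torch.full((x.shape[0],), i * dt, device=x.device)
        x = x + dt * ode_func(t, x)
    return x

# Usage with a dummy velocity field (identity) in place of the trained model:
samples = euler_sample(lambda t, x: x, torch.randn(4, 3, 32, 32))
```

With `ode_func` closing over a fixed label batch `y` (as in the snippet above), this produces class-conditional samples without any extra machinery.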
|
|
|
|
|
## Architecture Details |
|
|
|
|
|
The **UNet** architecture follows [OpenAI Guided Diffusion](https://github.com/openai/guided-diffusion):
|
|
- Encoder-Decoder structure with skip connections |
|
|
- ResNet blocks with GroupNorm |
|
|
- Self-attention at multiple resolutions |
|
|
- Time embedding via sinusoidal position encoding |
|
|
- Optional class embedding for conditional generation |
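The sinusoidal time embedding from the list above can be sketched as follows (a minimal sketch; the function name, `dim`, and `max_period` are illustrative, not the repository's exact implementation):

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10_000.0) -> torch.Tensor:
    """Map times t of shape (B,) to sinusoidal features of shape (B, dim)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=128)
print(emb.shape)  # torch.Size([3, 128])
```

The resulting vector is typically passed through a small MLP before being added to each ResNet block's activations.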
|
|
|
|
|
## Flow Matching |
|
|
|
|
|
Flow Matching learns a velocity field $v_\theta$ that transports samples from a source distribution (standard Gaussian noise) to the data distribution:
|
|
|
|
|
$$\frac{dx}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{\text{data}}$$
|
|
|
|
|
Training uses Conditional Flow Matching (CFM) with straight-line conditional paths, where $\sigma$ is a small smoothing constant (plain straight lines when $\sigma = 0$):
|
|
|
|
|
$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t) - (x_1 - (1-\sigma)x_0) \|^2 \right]$$ |
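The CFM objective above can be sketched as a single loss helper (a minimal sketch under the stated loss; `velocity_net` is a placeholder for the UNet, and `sigma` plays the role of $\sigma$ in the formula):

```python
import torch

def cfm_loss(velocity_net, x1: torch.Tensor, sigma: float = 0.0) -> torch.Tensor:
    """Conditional Flow Matching loss for one batch of data x1."""
    x0 = torch.randn_like(x1)                             # source: standard Gaussian
    t = torch.rand(x1.shape[0], *([1] * (x1.dim() - 1)))  # one time per sample
    x_t = (1 - (1 - sigma) * t) * x0 + t * x1             # point on the conditional path
    target = x1 - (1 - sigma) * x0                        # d/dt of the path: the target velocity
    pred = velocity_net(x_t, t)
    return ((pred - target) ** 2).mean()

# Usage with a trivial stand-in model (a real run would pass the UNet):
loss = cfm_loss(lambda x, t: x, torch.randn(8, 3, 32, 32))
```

Note that the path derivative equals the regression target, `x1 - (1 - sigma) * x0`, which is why the loss needs no simulation of the ODE during training.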
|
|
|
|
|
## Requirements |
|
|
|
|
|
```bash |
|
|
pip install torch torchvision |
|
|
pip install torchdiffeq einops |
|
|
``` |
|
|
|
|
|
## License |
|
|
|
|
|
CC BY-NC-SA 4.0 - Non-commercial use only. |
|
|
|
|
|
## Acknowledgments |
|
|
|
|
|
- Training code based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git) |
|
|
- UNet architecture from [OpenAI Guided Diffusion](https://github.com/openai/guided-diffusion) |
|
|
|
|
|
## Citation |
|
|
|
|
|
```bibtex |
|
|
@misc{flowmatching-unet-2024, |
|
|
title={UNet Flow Matching Models for Image Generation}, |
|
|
author={WayBob}, |
|
|
year={2024}, |
|
|
howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}} |
|
|
} |
|
|
``` |
|
|
|