---
license: cc-by-nc-sa-4.0
tags:
- flow-matching
- generative-model
- image-generation
- pytorch
datasets:
- mnist
- cifar10
- celeba
base_model: keishihara/flow-matching
---
# UNet Flow Matching Models
Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.
**Training code based on**: [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
## Models
This repository contains three UNet-based velocity field models trained with Flow Matching:
### MNIST (28×28 Grayscale)
- **Checkpoint**: `mnist/ckpt.pth` (24 MB)
- **Parameters**: 6.2M
- **Architecture**: UNet with num_channels=64, num_res_blocks=2
- **Conditional**: Yes (10 classes, digits 0-9)
- **Training**: 50 epochs, batch_size=128, lr=1e-3
- **Hardware**: NVIDIA H100 GPU
### CIFAR-10 (32×32 RGB)
- **Checkpoint**: `cifar10/ckpt.pth` (35 MB)
- **Parameters**: 9.0M
- **Architecture**: UNet with num_channels=64, num_res_blocks=2
- **Conditional**: Yes (10 classes)
- **Training**: 50 epochs, batch_size=128, lr=1e-3
- **Hardware**: NVIDIA H100 GPU
### CelebA (64×64 RGB)
- **Checkpoint**: `celeba64/ckpt.pth` (332 MB)
- **Parameters**: 83.0M
- **Architecture**: UNet with num_channels=128, num_res_blocks=2
- **Conditional**: No (unconditional face generation)
- **Training**: 50 epochs, batch_size=512, lr=1e-4
- **Dataset**: 202,599 CelebA training images
- **Final loss**: 0.114
- **Hardware**: NVIDIA H100 GPU
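The listed checkpoint sizes line up with roughly 4 bytes per parameter. That would fit float32 weights stored without optimizer state. This is an inference from the numbers above, not something the card states. The sketch below reproduces the estimate; `fp32_checkpoint_mb` is a hypothetical helper name.

```python
def fp32_checkpoint_mb(num_params: float, bytes_per_param: int = 4) -> float:
    """Estimate checkpoint size in decimal megabytes (1 MB = 10**6 bytes).

    Assumes the file holds only the weights, stored as float32 (4 bytes each),
    and ignores the small pickle/metadata overhead.
    """
    return num_params * bytes_per_param / 1e6


# Parameter counts from the model list above
for name, params in [("mnist", 6.2e6), ("cifar10", 9.0e6), ("celeba64", 83.0e6)]:
    print(f"{name}: ~{fp32_checkpoint_mb(params):.1f} MB")
```

The estimates (24.8, 36.0, and 332.0 MB) are close to the listed 24, 35, and 332 MB. The small gaps are consistent with the parameter counts being rounded.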
## Sample Results
### MNIST
![MNIST Samples](mnist_with_diff.png)
*Generated MNIST digits at different velocity reuse thresholds*
### CIFAR-10
![CIFAR-10 Samples](cifar10_with_diff.png)
*Generated CIFAR-10 images at different velocity reuse thresholds*
### CelebA 64×64
![CelebA Samples](celeba_with_diff.png)
*Generated 64×64 faces at different velocity reuse thresholds*
## Training Code
The models were trained using the Flow Matching implementation based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git).
### Training Scripts
**MNIST**:
```bash
# train_flow_matching_on_images.py
python train_flow_matching_on_images.py \
--do_train \
--dataset mnist \
--n_epochs 50 \
--batch_size 128 \
--learning_rate 1e-3
```
**CIFAR-10**:
```bash
python train_flow_matching_on_images.py \
--do_train \
--dataset cifar10 \
--n_epochs 50 \
--batch_size 128 \
--learning_rate 1e-3 \
--horizontal_flip
```
**CelebA**:
```bash
# train_celeba64.py
python train_celeba64.py \
--do_train \
--n_epochs 50 \
--batch_size 512 \
--learning_rate 1e-4 \
--horizontal_flip
```
Training code files included:
- `train_flow_matching_on_images.py` - For MNIST and CIFAR-10
- `train_celeba64.py` - For CelebA 64×64
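You can relate the epoch counts above to the number of optimizer updates with `ceil(dataset_size / batch_size) * n_epochs`. This sketch makes three assumptions, any of which the actual scripts may violate:

- one update per batch,
- no gradient accumulation,
- the last partial batch is kept (`drop_last=False`).

The MNIST and CIFAR-10 sizes are the standard training-split sizes (60,000 and 50,000 images). The CelebA size is the figure given on this card.

```python
import math


def total_updates(dataset_size: int, batch_size: int, n_epochs: int) -> int:
    """Optimizer steps for a run, assuming one step per (possibly partial) batch."""
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    return steps_per_epoch * n_epochs


runs = {
    "mnist": (60_000, 128, 50),
    "cifar10": (50_000, 128, 50),
    "celeba64": (202_599, 512, 50),
}
for name, (n, bs, epochs) in runs.items():
    print(name, total_updates(n, bs, epochs))
```

Under these assumptions, the CelebA run takes 396 updates per epoch, or 19,800 in total. Its larger batch size is why the total is comparable to CIFAR-10 despite a dataset four times larger.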
## Usage
### Load Model
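```python
import torch
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"

# Download the CelebA checkpoint ("mnist/ckpt.pth" and "cifar10/ckpt.pth" live in the same repo)
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="celeba64/ckpt.pth",
)

# Load the checkpoint onto the available device.
# PyTorch >= 2.6 defaults to weights_only=True; pass weights_only=False only for files you trust.
checkpoint = torch.load(ckpt_path, map_location=device)
```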
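The examples on this card pass the loaded object straight to `load_state_dict`. That only works if `ckpt.pth` is a bare `state_dict`, and the card does not say whether it is one or a dictionary that wraps one. The hypothetical helper below is a defensive sketch. It unwraps a few common wrapper keys: `"model"`, `"state_dict"`, and `"ema"`. These key names are guesses and are not confirmed for this repository. It also strips two prefixes: `module.`, added by `torch.nn.DataParallel`, and `_orig_mod.`, added by `torch.compile`.

```python
from typing import Any, Dict, Mapping

# Hypothetical wrapper keys: print(checkpoint.keys()) to see what the file actually uses.
CANDIDATE_KEYS = ("model", "state_dict", "ema")
# Prefixes added by torch.nn.DataParallel ("module.") and torch.compile ("_orig_mod.").
PREFIXES = ("module.", "_orig_mod.")


def strip_prefixes(name: str) -> str:
    """Remove wrapper prefixes repeatedly, e.g. "module._orig_mod.conv.weight" -> "conv.weight"."""
    changed = True
    while changed:
        changed = False
        for prefix in PREFIXES:
            if name.startswith(prefix):
                name = name[len(prefix):]
                changed = True
    return name


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a flat parameter-name -> tensor mapping from a loaded checkpoint."""
    state = checkpoint
    for key in CANDIDATE_KEYS:
        # Unwrap only if the value is itself a mapping (i.e. a nested state_dict)
        if isinstance(state.get(key), Mapping):
            state = state[key]
            break
    return {strip_prefixes(name): value for name, value in state.items()}
```

With `flow` and `checkpoint` from the examples here, you would call `flow.load_state_dict(extract_state_dict(checkpoint))`.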
### Inference (Sampling)
```python
import torch
from flow_matching.models import UNetModel
from flow_matching.solver import ODESolver, ModelWrapper
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create model (CelebA example)
flow = UNetModel(
dim=(3, 64, 64),
num_channels=128,
num_res_blocks=2,
num_classes=0,
class_cond=False,
).to(device)
# Load weights (`checkpoint` comes from the Load Model step; assumes it is a bare state_dict)
flow.load_state_dict(checkpoint)
flow.eval()
# Create solver
model_wrapper = ModelWrapper(flow)
solver = ODESolver(model_wrapper)
# Sample from Gaussian noise
batch_size = 4
x_init = torch.randn(batch_size, 3, 64, 64).to(device)
time_grid = torch.linspace(0, 1, 21).to(device) # 20 steps
with torch.no_grad():
samples = solver.sample(
x_init=x_init,
step_size=0.05,
method="euler",
time_grid=time_grid
)
# Denormalize from [-1, 1] to [0, 1]
samples = (samples + 1) / 2
samples = samples.clamp(0, 1)
# Save or visualize
from torchvision.utils import save_image
save_image(samples, "generated_faces.png", nrow=2)
```
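Euler integration with `step_size=0.05` from $t=0$ to $t=1$ amounts to 20 updates of the form $x_{k+1} = x_k + h\,v_\theta(x_k, t_k)$ with $h = 0.05$. The framework-free sketch below writes that update out explicitly. `euler_sample` is a hypothetical helper, not part of the `flow_matching` package. `velocity` stands in for the model call. Because the code only uses `+` and `*`, it runs on plain floats or on torch tensors.

```python
from typing import Callable, TypeVar

T = TypeVar("T")  # plain floats or torch tensors: only + and * are used on x


def euler_sample(velocity: Callable[[T, float], T], x_init: T, n_steps: int = 20) -> T:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed-step Euler."""
    h = 1.0 / n_steps
    x = x_init
    for k in range(n_steps):
        t = k * h  # left endpoint of the k-th interval
        x = x + h * velocity(x, t)
    return x
```

With PyTorch, call it inside `torch.no_grad()`. The model may expect a per-sample tensor time rather than a float. In that case, wrap it, for example `velocity=lambda x, t: flow(x=x, t=torch.full((x.shape[0],), t, device=x.device))`.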
### Conditional Generation (MNIST/CIFAR-10)
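```python
import torch
from huggingface_hub import hf_hub_download
from torchdiffeq import odeint
from flow_matching.models import UNetModel

device = "cuda" if torch.cuda.is_available() else "cpu"
# For class-conditional models
flow = UNetModel(
    dim=(3, 32, 32),  # CIFAR-10; MNIST is (1, 28, 28)
    num_channels=64,
    num_res_blocks=2,
    num_classes=10,
    class_cond=True,
).to(device)
# Download and load the CIFAR-10 checkpoint
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="cifar10/ckpt.pth",
)
flow.load_state_dict(torch.load(ckpt_path, map_location=device))
flow.eval()
# Generate a specific class (e.g., class 3)
y = torch.full((4,), 3, dtype=torch.long, device=device)  # batch of 4, all class 3
def ode_func(t, x):
    return flow(x=x, t=t, y=y)
# Integrate noise (t=0) -> data (t=1) with 20 Euler steps via torchdiffeq's odeint(func(t, x), y0, t)
x_init = torch.randn(4, 3, 32, 32, device=device)
with torch.no_grad():
    trajectory = odeint(ode_func, x_init, torch.linspace(0, 1, 21, device=device), method="euler")
samples = ((trajectory[-1] + 1) / 2).clamp(0, 1)  # final state, mapped from [-1, 1] to [0, 1]
```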
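Class indices for the conditional models follow the dataset labels. For MNIST the label is the digit itself. For CIFAR-10 the standard label order, as in `torchvision.datasets.CIFAR10`, is listed below, so class 3 in the example above would be `cat`. This assumes the training script used the torchvision label order, which the card does not state explicitly. The hypothetical `labels_for` helper builds a label batch that covers several classes at once.

```python
from typing import List

# Standard CIFAR-10 label order (index -> name), as in torchvision.datasets.CIFAR10.classes
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def labels_for(names: List[str], per_class: int) -> List[int]:
    """Return class indices, each repeated per_class times, for a conditional batch."""
    return [CIFAR10_CLASSES.index(n) for n in names for _ in range(per_class)]
```

For example, `y = torch.tensor(labels_for(["cat", "dog"], 2), device=device)` gives a batch of two cats and two dogs. In that case `x_init` needs a matching batch size of 4.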
## Architecture Details
**UNet** based on OpenAI Guided Diffusion:
- Encoder-Decoder structure with skip connections
- ResNet blocks with GroupNorm
- Self-attention at multiple resolutions
- Time embedding via sinusoidal position encoding
- Optional class embedding for conditional generation
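In Guided Diffusion, the time embedding is a sinusoidal encoding of the scalar time. The sketch below follows the Guided Diffusion `timestep_embedding` formula: frequencies are geometrically spaced from 1 down to 1/`max_period`, with cosines first and then sines. For readability it is written for a single scalar using only the standard library. The card does not say whether this repository feeds $t \in [0, 1]$ in directly or rescales it first.

```python
import math
from typing import List


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Sinusoidal embedding of a scalar time t into a dim-dimensional vector.

    First half: cos(t * f_i); second half: sin(t * f_i), with
    f_i = exp(-ln(max_period) * i / half) for i = 0 .. half-1.
    """
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:  # odd dims are padded with a single zero, as in Guided Diffusion
        emb.append(0.0)
    return emb
```

In Guided Diffusion this vector then passes through a small MLP, and the result is added inside each ResNet block. A PyTorch version would compute the same thing vectorized over the batch.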
## Flow Matching
Flow Matching learns a velocity field $v_\theta$ that transports samples from a standard Gaussian source distribution ($t=0$) to the data distribution ($t=1$) by following the ODE:
$$\frac{dx_t}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{data}$$
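Training uses Conditional Flow Matching (CFM) with straight-line paths $x_t = \left(1 - (1-\sigma_{\min})\,t\right)x_0 + t\,x_1$. The time derivative of this path, $x_1 - (1-\sigma_{\min})x_0$, is the regression target:
$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t) - (x_1 - (1-\sigma_{\min})x_0) \|^2 \right]$$
Here $\sigma_{\min}$ is a small constant that leaves a little noise around $x_1$ at $t=1$.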
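The straight-line path behind this loss and its regression target can be written as two small functions. The sketch uses plain floats so it runs without PyTorch. The same expressions apply elementwise to tensors, with `t` reshaped to broadcast over the image dimensions. The default `sigma=1e-4` for the small noise constant is an assumption; the card does not state the value used in training. A finite-difference check confirms that the target is the path's time derivative.

```python
def cfm_path(x0: float, x1: float, t: float, sigma: float = 1e-4) -> float:
    """Point on the straight-line conditional path (t=0 gives x0, t=1 gives x1 + sigma * x0)."""
    return (1.0 - (1.0 - sigma) * t) * x0 + t * x1


def cfm_target(x0: float, x1: float, sigma: float = 1e-4) -> float:
    """Regression target x1 - (1 - sigma) * x0, i.e. d/dt of cfm_path (constant in t)."""
    return x1 - (1.0 - sigma) * x0


# Central finite difference of the path should match the analytic target
x0, x1, t, eps = 0.7, -1.2, 0.3, 1e-6
fd = (cfm_path(x0, x1, t + eps) - cfm_path(x0, x1, t - eps)) / (2 * eps)
print(abs(fd - cfm_target(x0, x1)) < 1e-6)
```

In the standard CFM recipe, each training sample gets a random $t$. The UNet receives `cfm_path(x0, x1, t)`, and the loss is the mean squared difference between its output and `cfm_target(x0, x1)`.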
## Requirements
```bash
pip install torch torchvision
pip install torchdiffeq einops huggingface_hub
```
## License
CC BY-NC-SA 4.0 - Non-commercial use only.
## Acknowledgments
- Training code based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
- UNet architecture from [OpenAI Guided Diffusion](https://github.com/openai/guided-diffusion)
## Citation
```bibtex
@misc{flowmatching-unet-2024,
title={UNet Flow Matching Models for Image Generation},
author={WayBob},
year={2024},
howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
}
```