# UNet Flow Matching Models

Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets. Training code based on keishihara/flow-matching.

## Models

This repository contains three UNet-based velocity-field models trained with Flow Matching:
### MNIST (28×28 Grayscale)

- Checkpoint: `mnist/ckpt.pth` (24 MB)
- Parameters: 6.2M
- Architecture: UNet with `num_channels=64`, `num_res_blocks=2`
- Conditional: Yes (10 classes, digits 0-9)
- Training: 50 epochs, `batch_size=128`, `lr=1e-3`
- Hardware: NVIDIA H100 GPU
### CIFAR-10 (32×32 RGB)

- Checkpoint: `cifar10/ckpt.pth` (35 MB)
- Parameters: 9.0M
- Architecture: UNet with `num_channels=64`, `num_res_blocks=2`
- Conditional: Yes (10 classes)
- Training: 50 epochs, `batch_size=128`, `lr=1e-3`
- Hardware: NVIDIA H100 GPU
### CelebA (64×64 RGB)

- Checkpoint: `celeba64/ckpt.pth` (332 MB)
- Parameters: 83.0M
- Architecture: UNet with `num_channels=128`, `num_res_blocks=2`
- Conditional: No (unconditional face generation)
- Training: 50 epochs, `batch_size=512`, `lr=1e-4`
- Dataset: 202,599 CelebA training images
- Final loss: 0.114
- Hardware: NVIDIA H100 GPU
## Sample Results

### MNIST

Generated MNIST digits at different velocity reuse thresholds.

### CIFAR-10

Generated CIFAR-10 images at different velocity reuse thresholds.

### CelebA 64×64

Generated 64×64 faces at different velocity reuse thresholds.
## Training Code

The models were trained using the Flow Matching implementation based on keishihara/flow-matching.

### Training Scripts

**MNIST** (`train_flow_matching_on_images.py`):

```bash
python train_flow_matching_on_images.py \
    --do_train \
    --dataset mnist \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3
```
**CIFAR-10**:

```bash
python train_flow_matching_on_images.py \
    --do_train \
    --dataset cifar10 \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --horizontal_flip
```
**CelebA** (`train_celeba64.py`):

```bash
python train_celeba64.py \
    --do_train \
    --n_epochs 50 \
    --batch_size 512 \
    --learning_rate 1e-4 \
    --horizontal_flip
```
Training code files included:

- `train_flow_matching_on_images.py` - for MNIST and CIFAR-10
- `train_celeba64.py` - for CelebA 64×64
## Usage

### Load Model

```python
import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="celeba64/ckpt.pth",
)

# Load the checkpoint (a state dict)
checkpoint = torch.load(ckpt_path, map_location="cuda")
```
### Inference (Sampling)

```python
import torch
from torchvision.utils import save_image
from flow_matching.models import UNetModel
from flow_matching.solver import ODESolver, ModelWrapper

device = "cuda"

# Create the model (CelebA example)
flow = UNetModel(
    dim=(3, 64, 64),
    num_channels=128,
    num_res_blocks=2,
    num_classes=0,
    class_cond=False,
).to(device)

# Load weights
flow.load_state_dict(checkpoint)
flow.eval()

# Create the solver
model_wrapper = ModelWrapper(flow)
solver = ODESolver(model_wrapper)

# Sample from Gaussian noise
batch_size = 4
x_init = torch.randn(batch_size, 3, 64, 64).to(device)
time_grid = torch.linspace(0, 1, 21).to(device)  # 20 Euler steps

with torch.no_grad():
    samples = solver.sample(
        x_init=x_init,
        step_size=0.05,
        method="euler",
        time_grid=time_grid,
    )

# Denormalize from [-1, 1] to [0, 1]
samples = ((samples + 1) / 2).clamp(0, 1)

# Save or visualize
save_image(samples, "generated_faces.png", nrow=2)
```
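Under the hood, Euler sampling is just fixed-step integration of the learned velocity field from t=0 to t=1. For reference, here is a self-contained sketch of that loop; `dummy_velocity` is an illustrative stand-in for the trained UNet, not part of this repository:

```python
import torch

@torch.no_grad()
def euler_sample(velocity_fn, x_init: torch.Tensor, n_steps: int = 20) -> torch.Tensor:
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed Euler steps."""
    x = x_init
    dt = 1.0 / n_steps
    for i in range(n_steps):
        # One time value per sample, as the UNet expects
        t = torch.full((x.shape[0],), i * dt, device=x.device)
        x = x + dt * velocity_fn(x, t)
    return x

# Dummy velocity field standing in for the trained model (illustrative only)
dummy_velocity = lambda x, t: -x

x_init = torch.randn(4, 3, 64, 64)
samples = euler_sample(dummy_velocity, x_init, n_steps=20)
```

With a real checkpoint you would pass the loaded `flow` (or a wrapper around it) as `velocity_fn`; the library's `ODESolver` adds adaptive methods and bookkeeping on top of this same idea.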
### Conditional Generation (MNIST/CIFAR-10)

```python
# For class-conditional models
flow = UNetModel(
    dim=(3, 32, 32),  # CIFAR-10; use (1, 28, 28) for MNIST
    num_channels=64,
    num_res_blocks=2,
    num_classes=10,
    class_cond=True,
).to(device)

# Load the CIFAR-10 checkpoint
ckpt = torch.load("cifar10/ckpt.pth", map_location=device)
flow.load_state_dict(ckpt)
flow.eval()

# Generate a specific class (e.g., class 3)
y = torch.tensor([3, 3, 3, 3]).to(device)  # batch of 4, all class 3

def ode_func(t, x):
    return flow(x=x, t=t, y=y)

# Then integrate ode_func as in the unconditional example
```
## Architecture Details

UNet based on OpenAI Guided Diffusion:

- Encoder-decoder structure with skip connections
- ResNet blocks with GroupNorm
- Self-attention at multiple resolutions
- Time embedding via sinusoidal position encoding
- Optional class embedding for conditional generation
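A generic sinusoidal timestep embedding can be sketched as follows; this is an illustrative version, and the exact frequency schedule in the Guided Diffusion UNet may differ slightly:

```python
import math
import torch

def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """Sinusoidal embedding of times t, shape (B,) -> (B, dim)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.linspace(0, 1, 4), dim=128)
```

The embedding is then projected by an MLP and added into each ResNet block, which is how the network conditions on t.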
## Flow Matching

Flow Matching learns a velocity field $v_\theta(x, t)$ that transports samples from the source distribution to the target along the ODE

$$\frac{dx_t}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I).$$

Training uses Conditional Flow Matching (CFM) with straight-line paths $x_t = (1 - t)\,x_0 + t\,x_1$, whose target velocity is the constant $x_1 - x_0$:

$$\mathcal{L}_{\text{CFM}} = \mathbb{E}_{t,\, x_0,\, x_1} \left\| v_\theta(x_t, t) - (x_1 - x_0) \right\|^2.$$
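The straight-line CFM objective reduces to a few lines per training step. Here is a minimal, self-contained sketch on toy 2-D data; `ToyVelocityNet` is an illustrative stand-in for the UNet and is not part of the training scripts:

```python
import torch
import torch.nn as nn

# Tiny stand-in for the UNet velocity field: maps (x, t) -> velocity.
class ToyVelocityNet(nn.Module):
    def __init__(self, dim: int = 2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, dim))

    def forward(self, x, t):
        return self.net(torch.cat([x, t[:, None]], dim=-1))

velocity_net = ToyVelocityNet()
optimizer = torch.optim.Adam(velocity_net.parameters(), lr=1e-3)

# One CFM training step with straight-line paths
x1 = torch.randn(128, 2)            # "data" batch (toy target samples)
x0 = torch.randn_like(x1)           # source noise
t = torch.rand(x1.shape[0])         # t ~ U[0, 1]

xt = (1 - t[:, None]) * x0 + t[:, None] * x1  # point on the straight path
target_v = x1 - x0                            # constant target velocity

loss = ((velocity_net(xt, t) - target_v) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

The image models in this repository follow the same recipe, with the UNet as the velocity network and dataset images (normalized to [-1, 1]) as `x1`.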
## Requirements

```bash
pip install torch torchvision
pip install torchdiffeq einops
```
## License

CC BY-NC-SA 4.0 - Non-commercial use only.

## Acknowledgments

- Training code based on keishihara/flow-matching
- UNet architecture from OpenAI Guided Diffusion

## Citation
```bibtex
@misc{flowmatching-unet-2024,
  title={UNet Flow Matching Models for Image Generation},
  author={WayBob},
  year={2024},
  howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
}
```