UNet Flow Matching Models

Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.

Training code based on: keishihara/flow-matching

Models

This repository contains three UNet-based velocity field models trained with Flow Matching:

MNIST (28Γ—28 Grayscale)

  • Checkpoint: mnist/ckpt.pth (24 MB)
  • Parameters: 6.2M
  • Architecture: UNet with num_channels=64, num_res_blocks=2
  • Conditional: Yes (10 classes, digits 0–9)
  • Training: 50 epochs, batch_size=128, lr=1e-3
  • Hardware: NVIDIA H100 GPU

CIFAR-10 (32Γ—32 RGB)

  • Checkpoint: cifar10/ckpt.pth (35 MB)
  • Parameters: 9.0M
  • Architecture: UNet with num_channels=64, num_res_blocks=2
  • Conditional: Yes (10 classes)
  • Training: 50 epochs, batch_size=128, lr=1e-3
  • Hardware: NVIDIA H100 GPU

CelebA (64Γ—64 RGB)

  • Checkpoint: celeba64/ckpt.pth (332 MB)
  • Parameters: 83.0M
  • Architecture: UNet with num_channels=128, num_res_blocks=2
  • Conditional: No (unconditional face generation)
  • Training: 50 epochs, batch_size=512, lr=1e-4
  • Dataset: 202,599 CelebA training images
  • Final loss: 0.114
  • Hardware: NVIDIA H100 GPU

Sample Results

MNIST

[Image: MNIST samples] Generated MNIST digits at different velocity reuse thresholds.

CIFAR-10

[Image: CIFAR-10 samples] Generated CIFAR-10 images at different velocity reuse thresholds.

CelebA 64Γ—64

[Image: CelebA samples] Generated 64×64 faces at different velocity reuse thresholds.
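The captions above refer to velocity reuse thresholds. One plausible reading (an assumption on our part, not something specified by the training code) is that during ODE integration the previously computed velocity is reused as long as the state has drifted less than a threshold since the last network call, trading accuracy for fewer model evaluations. A minimal pure-Python sketch of that idea:

```python
def euler_with_velocity_reuse(v, x0, n_steps, threshold):
    """Fixed-step Euler from t=0 to t=1 that reuses the cached velocity
    while the state has drifted less than `threshold` since the last
    model evaluation. Returns (final state, number of evaluations)."""
    dt = 1.0 / n_steps
    x = x0
    cached_v = v(0.0, x)              # the first evaluation is always needed
    x_at_eval, n_evals = x, 1
    for step in range(n_steps):
        if abs(x - x_at_eval) > threshold:
            cached_v = v(step * dt, x)  # state moved too far: refresh
            x_at_eval = x
            n_evals += 1
        x = x + dt * cached_v
    return x, n_evals

# Toy constant velocity field: the trajectory stays exact regardless of
# reuse, while far fewer evaluations are made than the 20 Euler steps.
x_final, n_evals = euler_with_velocity_reuse(lambda t, x: 2.0, 0.0, 20, 0.55)
```

A larger threshold means fewer velocity evaluations per sample but a coarser approximation of the flow, which is the trade-off the sample grids above visualize.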

Training Code

The models were trained with a Flow Matching implementation adapted from keishihara/flow-matching.

Training Scripts

MNIST:

# train_flow_matching_on_images.py
python train_flow_matching_on_images.py \
    --do_train \
    --dataset mnist \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3

CIFAR-10:

python train_flow_matching_on_images.py \
    --do_train \
    --dataset cifar10 \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --horizontal_flip

CelebA:

# train_celeba64.py
python train_celeba64.py \
    --do_train \
    --n_epochs 50 \
    --batch_size 512 \
    --learning_rate 1e-4 \
    --horizontal_flip

Training code files included:

  • train_flow_matching_on_images.py - For MNIST and CIFAR-10
  • train_celeba64.py - For CelebA 64Γ—64

Usage

Load Model

import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="celeba64/ckpt.pth"
)

# Load checkpoint
checkpoint = torch.load(ckpt_path, map_location="cuda")

Inference (Sampling)

import torch
from flow_matching.models import UNetModel
from flow_matching.solver import ODESolver, ModelWrapper

device = "cuda"

# Create model (CelebA example)
flow = UNetModel(
    dim=(3, 64, 64),
    num_channels=128,
    num_res_blocks=2,
    num_classes=0,
    class_cond=False,
).to(device)

# Load weights
flow.load_state_dict(checkpoint)
flow.eval()

# Create solver
model_wrapper = ModelWrapper(flow)
solver = ODESolver(model_wrapper)

# Sample from Gaussian noise
batch_size = 4
x_init = torch.randn(batch_size, 3, 64, 64).to(device)
time_grid = torch.linspace(0, 1, 21).to(device)  # 20 steps

with torch.no_grad():
    samples = solver.sample(
        x_init=x_init,
        step_size=0.05,
        method="euler",
        time_grid=time_grid
    )

# Denormalize from [-1, 1] to [0, 1]
samples = (samples + 1) / 2
samples = samples.clamp(0, 1)

# Save or visualize
from torchvision.utils import save_image
save_image(samples, "generated_faces.png", nrow=2)

Conditional Generation (MNIST/CIFAR-10)

# For class-conditional models
flow = UNetModel(
    dim=(3, 32, 32),  # CIFAR-10
    num_channels=64,
    num_res_blocks=2,
    num_classes=10,
    class_cond=True,
).to(device)

# Load CIFAR-10 checkpoint (map_location avoids errors on CPU-only hosts)
ckpt = torch.load("cifar10/ckpt.pth", map_location=device)
flow.load_state_dict(ckpt)

# Generate specific class (e.g., class 3)
y = torch.tensor([3, 3, 3, 3]).to(device)  # Batch of 4, all class 3

def ode_func(t, x):
    return flow(x=x, t=t, y=y)

# Then use solver as before
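"Use solver as before" can also be made explicit with a plain fixed-step Euler loop over `ode_func` (a sketch, not the repository's exact solver; shown on a toy 1-D field so it runs without the model):

```python
def euler_sample(ode_func, x0, n_steps=20):
    """Integrate dx/dt = ode_func(t, x) from t=0 to t=1 with fixed-step
    Euler. With the real model, x is a batch tensor and ode_func wraps
    the class-conditional UNet as defined above."""
    dt = 1.0 / n_steps
    x = x0
    for step in range(n_steps):
        x = x + dt * ode_func(step * dt, x)
    return x

# Toy check: for straight-line Flow Matching paths the true velocity is
# the constant x1 - x0, so Euler transports x0 to x1 (up to float error).
x0, x1 = -1.0, 3.0
x_final = euler_sample(lambda t, x: x1 - x0, x0)
```

With the real model, `x0` would be Gaussian noise and `ode_func` the closure above that injects the class labels `y` into each velocity evaluation.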

Architecture Details

UNet based on OpenAI Guided Diffusion:

  • Encoder-Decoder structure with skip connections
  • ResNet blocks with GroupNorm
  • Self-attention at multiple resolutions
  • Time embedding via sinusoidal position encoding
  • Optional class embedding for conditional generation
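The sinusoidal time embedding mentioned above can be sketched in a few lines (a generic transformer-style version; the exact frequency schedule in the repository's UNet may differ):

```python
import math

def sinusoidal_embedding(t, dim=8, max_period=10000.0):
    """Map a scalar time t to a dim-dimensional vector of sines and
    cosines at geometrically spaced frequencies."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return ([math.sin(t * f) for f in freqs]
            + [math.cos(t * f) for f in freqs])

# At t = 0 every sine component is 0 and every cosine component is 1;
# nearby times map to smoothly varying, distinct vectors.
emb = sinusoidal_embedding(0.0)
```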

Flow Matching

Flow Matching learns a velocity field that transports samples from source to target:

$$\frac{dx}{dt} = v_\theta(x_t, t), \qquad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{\text{data}}$$

Training uses Conditional Flow Matching (CFM) with straight-line paths:

$$\mathcal{L} = \mathbb{E}_{t,\, x_0,\, x_1} \left[ \left\| v_\theta(x_t, t) - \bigl(x_1 - (1-\sigma)\, x_0\bigr) \right\|^2 \right]$$
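Concretely, along the straight-line path x_t = (1 - (1 - σ) t) x_0 + t x_1 the regression target in the loss is the path's time derivative, x_1 - (1 - σ) x_0. A scalar sketch (the repository's exact σ and path construction may differ):

```python
def cfm_pair(x0, x1, t, sigma=0.0):
    """Interpolant and regression target for the straight-line CFM path
    x_t = (1 - (1 - sigma) t) x0 + t x1."""
    x_t = (1.0 - (1.0 - sigma) * t) * x0 + t * x1
    target = x1 - (1.0 - sigma) * x0   # d x_t / d t, constant in t
    return x_t, target

def cfm_loss(v_pred, target):
    """Per-sample squared error of the predicted velocity."""
    return (v_pred - target) ** 2

x_t, target = cfm_pair(x0=0.5, x1=-2.0, t=0.3)
# A network that predicts the target velocity exactly incurs zero loss.
loss = cfm_loss(target, target)
```

In training, x_0 is Gaussian noise, x_1 is a data sample, t is drawn uniformly from [0, 1], and the network v_θ(x_t, t) regresses onto this target.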

Requirements

pip install torch torchvision
pip install torchdiffeq einops

License

CC BY-NC-SA 4.0 - Non-commercial use only.

Acknowledgments

Training code adapted from keishihara/flow-matching; UNet architecture based on OpenAI's Guided Diffusion.
Citation

@misc{flowmatching-unet-2024,
  title={UNet Flow Matching Models for Image Generation},
  author={WayBob},
  year={2024},
  howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
}