metadata
license: cc-by-nc-sa-4.0
tags:
  - flow-matching
  - generative-model
  - image-generation
  - pytorch
datasets:
  - mnist
  - cifar10
  - celeba
base_model: keishihara/flow-matching

UNet Flow Matching Models

Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.

Training code based on: keishihara/flow-matching

Models

This repository contains three UNet-based velocity field models trained with Flow Matching:

MNIST (28×28 Grayscale)

  • Checkpoint: mnist/ckpt.pth (24 MB)
  • Parameters: 6.2M
  • Architecture: UNet with num_channels=64, num_res_blocks=2
  • Conditional: Yes (10 classes, 0-9 digits)
  • Training: 50 epochs, batch_size=128, lr=1e-3
  • Hardware: NVIDIA H100 GPU

CIFAR-10 (32×32 RGB)

  • Checkpoint: cifar10/ckpt.pth (35 MB)
  • Parameters: 9.0M
  • Architecture: UNet with num_channels=64, num_res_blocks=2
  • Conditional: Yes (10 classes)
  • Training: 50 epochs, batch_size=128, lr=1e-3
  • Hardware: NVIDIA H100 GPU

CelebA (64×64 RGB)

  • Checkpoint: celeba64/ckpt.pth (332 MB)
  • Parameters: 83.0M
  • Architecture: UNet with num_channels=128, num_res_blocks=2
  • Conditional: No (unconditional face generation)
  • Training: 50 epochs, batch_size=512, lr=1e-4
  • Dataset: 202,599 CelebA training images
  • Final loss: 0.114
  • Hardware: NVIDIA H100 GPU
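The per-dataset settings above can be collected into a single lookup table. The sketch below is a hypothetical helper: `MODEL_CONFIGS` and `expected_fp32_mb` are illustrative names, not part of the training code. The "unet" entries reuse the UNetModel keyword arguments from the Usage section. The size estimate assumes every weight is stored as a 4-byte float32 with no optimizer state. The estimates (about 24.8, 36.0 and 332.0 MB) roughly match the listed checkpoint sizes, but that match does not confirm how the files were saved.

```python
# Hypothetical lookup table mirroring the per-dataset settings listed above.
# The "unet" kwargs match the UNetModel(...) calls in the Usage section.
MODEL_CONFIGS = {
    "mnist": {
        "checkpoint": "mnist/ckpt.pth",
        "params_millions": 6.2,
        "unet": dict(dim=(1, 28, 28), num_channels=64, num_res_blocks=2,
                     num_classes=10, class_cond=True),
    },
    "cifar10": {
        "checkpoint": "cifar10/ckpt.pth",
        "params_millions": 9.0,
        "unet": dict(dim=(3, 32, 32), num_channels=64, num_res_blocks=2,
                     num_classes=10, class_cond=True),
    },
    "celeba64": {
        "checkpoint": "celeba64/ckpt.pth",
        "params_millions": 83.0,
        "unet": dict(dim=(3, 64, 64), num_channels=128, num_res_blocks=2,
                     num_classes=0, class_cond=False),
    },
}


def expected_fp32_mb(params_millions: float) -> float:
    """Rough checkpoint size in MB, assuming 4-byte float32 weights only."""
    return params_millions * 4.0


for name, cfg in MODEL_CONFIGS.items():
    size = expected_fp32_mb(cfg["params_millions"])
    print(f"{name}: {cfg['checkpoint']} ~{size:.1f} MB")
```

With the model class from the Usage section, a model could then be built as UNetModel(**MODEL_CONFIGS["celeba64"]["unet"]).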

Sample Results

MNIST

[Figure: generated MNIST digits at different velocity reuse thresholds]

CIFAR-10

[Figure: generated CIFAR-10 images at different velocity reuse thresholds]

CelebA 64×64

[Figure: generated 64×64 CelebA faces at different velocity reuse thresholds]

Training Code

The models were trained using the Flow Matching implementation based on keishihara/flow-matching.

Training Scripts

MNIST:

# train_flow_matching_on_images.py
python train_flow_matching_on_images.py \
    --do_train \
    --dataset mnist \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3

CIFAR-10:

python train_flow_matching_on_images.py \
    --do_train \
    --dataset cifar10 \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --horizontal_flip

CelebA:

# train_celeba64.py
python train_celeba64.py \
    --do_train \
    --n_epochs 50 \
    --batch_size 512 \
    --learning_rate 1e-4 \
    --horizontal_flip

Training code files included:

  • train_flow_matching_on_images.py - For MNIST and CIFAR-10
  • train_celeba64.py - For CelebA 64×64

Usage

Load Model

import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="celeba64/ckpt.pth"
)

# Load checkpoint (falls back to CPU if no GPU is available)
checkpoint = torch.load(ckpt_path, map_location="cuda" if torch.cuda.is_available() else "cpu")

Inference (Sampling)

import torch
from flow_matching.models import UNetModel
from flow_matching.solver import ODESolver, ModelWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"

# Create model (CelebA example)
flow = UNetModel(
    dim=(3, 64, 64),
    num_channels=128,
    num_res_blocks=2,
    num_classes=0,
    class_cond=False,
).to(device)

# Load weights (`checkpoint` is the state dict loaded in the Load Model step)
flow.load_state_dict(checkpoint)
flow.eval()

# Create solver
model_wrapper = ModelWrapper(flow)
solver = ODESolver(model_wrapper)

# Sample from Gaussian noise
batch_size = 4
x_init = torch.randn(batch_size, 3, 64, 64, device=device)
time_grid = torch.linspace(0, 1, 21, device=device)  # 21 points = 20 Euler steps of 0.05

with torch.no_grad():
    samples = solver.sample(
        x_init=x_init,
        step_size=0.05,
        method="euler",
        time_grid=time_grid
    )

# Denormalize from [-1, 1] to [0, 1]
samples = (samples + 1) / 2
samples = samples.clamp(0, 1)

# Save or visualize
from torchvision.utils import save_image
save_image(samples, "generated_faces.png", nrow=2)

Conditional Generation (MNIST/CIFAR-10)

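# For class-conditional models (CIFAR-10 shown; for MNIST use dim=(1, 28, 28))
from huggingface_hub import hf_hub_download
from torchdiffeq import odeint

flow = UNetModel(
    dim=(3, 32, 32),  # CIFAR-10
    num_channels=64,
    num_res_blocks=2,
    num_classes=10,
    class_cond=True,
).to(device)

# Download and load the CIFAR-10 checkpoint
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="cifar10/ckpt.pth"
)
flow.load_state_dict(torch.load(ckpt_path, map_location=device))
flow.eval()

# Generate a specific class (e.g., class 3)
y = torch.full((4,), 3, dtype=torch.long, device=device)  # batch of 4, all class 3

# Velocity field with the labels bound in; (t, x) is torchdiffeq's argument order
def ode_func(t, x):
    return flow(x=x, t=t, y=y)

# 20 Euler steps from noise (t=0) to data (t=1); odeint returns the full trajectory
x_init = torch.randn(4, 3, 32, 32, device=device)
time_grid = torch.linspace(0, 1, 21, device=device)
with torch.no_grad():
    samples = odeint(ode_func, x_init, time_grid, method="euler")[-1]

# Assumes the same [-1, 1] normalization as the CelebA example
samples = ((samples + 1) / 2).clamp(0, 1)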

Architecture Details

UNet based on OpenAI Guided Diffusion:

  • Encoder-Decoder structure with skip connections
  • ResNet blocks with GroupNorm
  • Self-attention at multiple resolutions
  • Time embedding via sinusoidal position encoding
  • Optional class embedding for conditional generation
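The sinusoidal time embedding can be sketched as follows. This version follows the `timestep_embedding` function from OpenAI's guided-diffusion codebase, rewritten in NumPy so it runs without a GPU. It is an illustration, not code extracted from these checkpoints. In particular, this card does not say whether the UNet rescales the flow-matching time t ∈ [0, 1] before embedding it.

```python
import math

import numpy as np


def timestep_embedding(t: np.ndarray, dim: int, max_period: float = 10000.0) -> np.ndarray:
    """Sinusoidal embedding of a batch of scalar times t (shape [B]) -> [B, dim]."""
    half = dim // 2
    # Geometric ladder of frequencies from 1 down to 1/max_period
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float64) / half)
    args = t[:, None].astype(np.float64) * freqs[None, :]
    # First half cosines, second half sines (same ordering as guided-diffusion)
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # pad with a zero column when dim is odd
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb


t = np.array([0.0, 0.5, 1.0])
emb = timestep_embedding(t, dim=8)
print(emb.shape)  # (3, 8)
```

Each time value maps to a fixed vector of sines and cosines at different frequencies. In a guided-diffusion-style UNet, an MLP then projects this vector and adds it to the ResNet blocks' features.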

Flow Matching

Flow Matching learns a time-dependent velocity field that transports samples from a standard Gaussian source (t = 0) to the data distribution (t = 1). Sampling integrates the ODE:

$$\frac{dx}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{\text{data}}$$
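Integrating this ODE with fixed Euler steps is what the Inference example does with method="euler", step_size=0.05 and a 21-point time grid. The NumPy sketch below shows the update x ← x + h · v(x, t). `euler_sample` is a hypothetical helper for illustration, not the solver from the flow_matching package, and the velocity fields in the usage lines are toy examples rather than a trained model.

```python
import numpy as np


def euler_sample(velocity, x0: np.ndarray, n_steps: int = 20) -> np.ndarray:
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with fixed Euler steps."""
    ts = np.linspace(0.0, 1.0, n_steps + 1)  # 21 grid points -> 20 steps of size 0.05
    x = x0.copy()
    for t0, t1 in zip(ts[:-1], ts[1:]):
        x = x + (t1 - t0) * velocity(x, t0)
    return x


# Toy check: the constant velocity (x1 - x0) carries each x0 to x1 along a
# straight line, and Euler integrates a constant field exactly.
rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 2))
x1 = np.array([3.0, -1.0])
out = euler_sample(lambda x, t: x1 - x0, x0)
print(np.allclose(out, x1))  # True
```

Straight paths are why few Euler steps can work well. When the learned field is close to constant along each trajectory, the per-step error stays small.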

Training uses Conditional Flow Matching (CFM) with straight-line paths:

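$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t) - (x_1 - (1-\sigma)x_0) \|^2 \right]$$

where $x_t = (1 - (1-\sigma)t)\,x_0 + t\,x_1$, with $t$ typically sampled uniformly from $[0, 1]$, and $\sigma$ is a small minimum noise level ($\sigma = 0$ recovers plain linear interpolation $x_t = (1-t)x_0 + t\,x_1$). The regression target $x_1 - (1-\sigma)x_0$ is exactly $dx_t/dt$ along this path.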
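The CFM objective can be sketched in NumPy. This sketch assumes the standard conditional optimal-transport path x_t = (1 - (1 - sigma) * t) * x0 + t * x1, whose time derivative is the regression target x1 - (1 - sigma) * x0. `cfm_target` and `cfm_loss` are hypothetical names. The sigma value used to train these checkpoints is not stated, so the 1e-4 below is only a placeholder. A real training step would replace the oracle with the UNet's prediction.

```python
import numpy as np


def cfm_target(x0, x1, t, sigma=0.0):
    """Point on the conditional path and its velocity target.

    x_t = (1 - (1 - sigma) * t) * x0 + t * x1
    u_t = x1 - (1 - sigma) * x0   (equal to d x_t / dt)
    """
    t = t.reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast t over feature dims
    xt = (1.0 - (1.0 - sigma) * t) * x0 + t * x1
    ut = x1 - (1.0 - sigma) * x0
    return xt, ut


def cfm_loss(v_theta, x0, x1, t, sigma=0.0):
    """Mean squared error over batch and pixels (proportional to the squared norm)."""
    xt, ut = cfm_target(x0, x1, t, sigma)
    return np.mean((v_theta(xt, t) - ut) ** 2)


rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 3, 4, 4))   # Gaussian noise
x1 = rng.uniform(-1, 1, (8, 3, 4, 4))    # stand-in for images scaled to [-1, 1]
t = rng.uniform(0, 1, 8)                 # one time per sample
sigma = 1e-4                             # placeholder minimum noise level


def oracle(xt, t):
    # Cheats by knowing x0 and x1, so it reproduces the target exactly
    return x1 - (1.0 - sigma) * x0


print(cfm_loss(oracle, x0, x1, t, sigma))  # 0.0
```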

Requirements

pip install torch torchvision
pip install torchdiffeq einops huggingface_hub

License

CC BY-NC-SA 4.0: non-commercial use only; adaptations must be shared under the same license.

Acknowledgments

Citation

@misc{flowmatching-unet-2024,
  title={UNet Flow Matching Models for Image Generation},
  author={WayBob},
  year={2024},
  howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
}