---
license: cc-by-nc-sa-4.0
tags:
- flow-matching
- generative-model
- image-generation
- pytorch
datasets:
- mnist
- cifar10
- celeba
base_model: keishihara/flow-matching
---

# UNet Flow Matching Models

Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.

**Training code based on**: [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)

## Models

This repository contains three UNet-based velocity field models trained with Flow Matching:

### MNIST (28×28 Grayscale)

- **Checkpoint**: `mnist/ckpt.pth` (24 MB)
- **Parameters**: 6.2M
- **Architecture**: UNet with `num_channels=64`, `num_res_blocks=2`
- **Conditional**: Yes (10 classes, digits 0-9)
- **Training**: 50 epochs, `batch_size=128`, `lr=1e-3`
- **Hardware**: NVIDIA H100 GPU

### CIFAR-10 (32×32 RGB)

- **Checkpoint**: `cifar10/ckpt.pth` (35 MB)
- **Parameters**: 9.0M
- **Architecture**: UNet with `num_channels=64`, `num_res_blocks=2`
- **Conditional**: Yes (10 classes)
- **Training**: 50 epochs, `batch_size=128`, `lr=1e-3`
- **Hardware**: NVIDIA H100 GPU

### CelebA (64×64 RGB)

- **Checkpoint**: `celeba64/ckpt.pth` (332 MB)
- **Parameters**: 83.0M
- **Architecture**: UNet with `num_channels=128`, `num_res_blocks=2`
- **Conditional**: No (unconditional face generation)
- **Training**: 50 epochs, `batch_size=512`, `lr=1e-4`
- **Dataset**: 202,599 CelebA training images
- **Final loss**: 0.114
- **Hardware**: NVIDIA H100 GPU

## Sample Results

### MNIST

![MNIST Samples](mnist_with_diff.png)

*Generated MNIST digits at different velocity reuse thresholds*

### CIFAR-10

![CIFAR-10 Samples](cifar10_with_diff.png)

*Generated CIFAR-10 images at different velocity reuse thresholds*

### CelebA 64×64

![CelebA Samples](celeba_with_diff.png)

*Generated 64×64 faces at different velocity reuse thresholds*

## Training Code

The models were trained using the Flow Matching implementation based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git).
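Before running the full training scripts, the objective they optimize can be illustrated in isolation. The sketch below is a rough, self-contained illustration of one Conditional Flow Matching training step, not the repository's actual code; `TinyVelocityNet` and `cfm_loss` are hypothetical names, with a toy MLP standing in for the UNet.

```python
import torch
import torch.nn as nn


class TinyVelocityNet(nn.Module):
    """Toy velocity field v_theta(x, t); the real models use a UNet."""

    def __init__(self, dim=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, 64), nn.SiLU(), nn.Linear(64, dim)
        )

    def forward(self, x, t):
        # Feed time in as an extra input feature
        return self.net(torch.cat([x, t[:, None]], dim=-1))


def cfm_loss(model, x1, sigma=0.0):
    """One CFM training step on a data batch x1.

    With the path x_t = (1 - (1 - sigma) t) x0 + t x1, the target
    velocity is x1 - (1 - sigma) x0, matching the loss in this card.
    """
    x0 = torch.randn_like(x1)                      # source: Gaussian noise
    t = torch.rand(x1.shape[0], device=x1.device)  # t ~ U[0, 1]
    t_ = t.view(-1, *([1] * (x1.dim() - 1)))       # broadcast over features
    xt = (1 - (1 - sigma) * t_) * x0 + t_ * x1     # point on the path
    target = x1 - (1 - sigma) * x0                 # conditional velocity
    return ((model(xt, t) - target) ** 2).mean()


model = TinyVelocityNet()
x1 = torch.randn(128, 2)  # stand-in "data" batch
loss = cfm_loss(model, x1)
loss.backward()  # gradients flow into the velocity network
```

In a real loop this loss would be minimized with an optimizer over the dataset; the scripts below do exactly that with the UNet models described above.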
### Training Scripts

**MNIST**:

```bash
python train_flow_matching_on_images.py \
    --do_train \
    --dataset mnist \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3
```

**CIFAR-10**:

```bash
python train_flow_matching_on_images.py \
    --do_train \
    --dataset cifar10 \
    --n_epochs 50 \
    --batch_size 128 \
    --learning_rate 1e-3 \
    --horizontal_flip
```

**CelebA**:

```bash
python train_celeba64.py \
    --do_train \
    --n_epochs 50 \
    --batch_size 512 \
    --learning_rate 1e-4 \
    --horizontal_flip
```

Training code files included:

- `train_flow_matching_on_images.py` - for MNIST and CIFAR-10
- `train_celeba64.py` - for CelebA 64×64

## Usage

### Load Model

```python
import torch
from huggingface_hub import hf_hub_download

# Download checkpoint
ckpt_path = hf_hub_download(
    repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
    filename="celeba64/ckpt.pth",
)

# Load checkpoint
checkpoint = torch.load(ckpt_path, map_location="cuda")
```

### Inference (Sampling)

```python
import torch
from torchvision.utils import save_image

from flow_matching.models import UNetModel
from flow_matching.solver import ODESolver, ModelWrapper

device = "cuda"

# Create model (CelebA example)
flow = UNetModel(
    dim=(3, 64, 64),
    num_channels=128,
    num_res_blocks=2,
    num_classes=0,
    class_cond=False,
).to(device)

# Load weights
flow.load_state_dict(checkpoint)
flow.eval()

# Create solver
model_wrapper = ModelWrapper(flow)
solver = ODESolver(model_wrapper)

# Sample from Gaussian noise
batch_size = 4
x_init = torch.randn(batch_size, 3, 64, 64).to(device)
time_grid = torch.linspace(0, 1, 21).to(device)  # 21 points = 20 steps

with torch.no_grad():
    samples = solver.sample(
        x_init=x_init,
        step_size=0.05,
        method="euler",
        time_grid=time_grid,
    )

# Denormalize from [-1, 1] to [0, 1]
samples = (samples + 1) / 2
samples = samples.clamp(0, 1)

# Save or visualize
save_image(samples, "generated_faces.png", nrow=2)
```

### Conditional Generation (MNIST/CIFAR-10)

```python
# For class-conditional models
flow = UNetModel(
    dim=(3, 32, 32),  # CIFAR-10
    num_channels=64,
    num_res_blocks=2,
    num_classes=10,
    class_cond=True,
).to(device)

# Load CIFAR-10 checkpoint
ckpt = torch.load("cifar10/ckpt.pth")
flow.load_state_dict(ckpt)

# Generate a specific class (e.g., class 3)
y = torch.tensor([3, 3, 3, 3]).to(device)  # Batch of 4, all class 3

def ode_func(t, x):
    return flow(x=x, t=t, y=y)

# Then use the solver as before
```

## Architecture Details

**UNet** based on OpenAI Guided Diffusion:

- Encoder-decoder structure with skip connections
- ResNet blocks with GroupNorm
- Self-attention at multiple resolutions
- Time embedding via sinusoidal position encoding
- Optional class embedding for conditional generation

## Flow Matching

Flow Matching learns a velocity field that transports samples from a source distribution to the data distribution:

$$\frac{dx}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{\mathrm{data}}$$

Training uses Conditional Flow Matching (CFM) with (near-)straight-line paths:

$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t) - (x_1 - (1-\sigma)x_0) \|^2 \right]$$

## Requirements

```bash
pip install torch torchvision
pip install torchdiffeq einops
```

## License

CC BY-NC-SA 4.0 - Non-commercial use only.

## Acknowledgments

- Training code based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
- UNet architecture from [OpenAI Guided Diffusion](https://github.com/openai/guided-diffusion)

## Citation

```bibtex
@misc{flowmatching-unet-2024,
  title={UNet Flow Matching Models for Image Generation},
  author={WayBob},
  year={2024},
  howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
}
```
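The Conditional Generation snippet above ends with "use the solver as before". If the packaged `ODESolver` is unavailable, the integration it performs can be sketched with a plain fixed-step Euler loop. `euler_sample` and the toy `ode_func` below are hypothetical stand-ins for illustration only; the toy field ignores the class label and the real one would wrap the class-conditional UNet.

```python
import torch


def ode_func(t, x):
    """Toy velocity field with the same (t, x) signature as above."""
    return -x  # dx/dt = -x: state decays toward zero


def euler_sample(ode_func, x_init, n_steps=20):
    """Integrate dx/dt = v(t, x) from t=0 to t=1 with fixed-step Euler."""
    x = x_init
    dt = 1.0 / n_steps
    for i in range(n_steps):
        # Per-sample time tensor, as the model expects a batch of times
        t = torch.full((x.shape[0],), i * dt)
        x = x + dt * ode_func(t, x)
    return x


x_init = torch.randn(4, 3, 32, 32)  # CIFAR-10-shaped noise
samples = euler_sample(ode_func, x_init)
```

With the real conditional `ode_func`, the output would then be denormalized from [-1, 1] to [0, 1] exactly as in the unconditional sampling example.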