Upload folder using huggingface_hub
Browse files- README.md +177 -30
- train_celeba64.py +207 -0
- train_flow_matching_on_images.py +198 -0
README.md
CHANGED
|
@@ -9,38 +9,42 @@ datasets:
|
|
| 9 |
- mnist
|
| 10 |
- cifar10
|
| 11 |
- celeba
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
# UNet Flow Matching Models
|
| 15 |
|
| 16 |
Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.
|
| 17 |
|
|
|
|
|
|
|
| 18 |
## Models
|
| 19 |
|
| 20 |
This repository contains three UNet-based velocity field models trained with Flow Matching:
|
| 21 |
|
| 22 |
### MNIST (28×28 Grayscale)
|
| 23 |
-
- **Checkpoint**: `mnist/ckpt.pth`
|
| 24 |
- **Parameters**: 6.2M
|
| 25 |
- **Architecture**: UNet with num_channels=64, num_res_blocks=2
|
| 26 |
-
- **Conditional**: Yes (10 classes)
|
| 27 |
-
- **Training**: 50 epochs
|
| 28 |
- **Hardware**: NVIDIA H100 GPU
|
| 29 |
|
| 30 |
-
### CIFAR-10 (32×32 RGB)
|
| 31 |
-
- **Checkpoint**: `cifar10/ckpt.pth`
|
| 32 |
- **Parameters**: 9.0M
|
| 33 |
- **Architecture**: UNet with num_channels=64, num_res_blocks=2
|
| 34 |
- **Conditional**: Yes (10 classes)
|
| 35 |
-
- **Training**: 50 epochs
|
| 36 |
- **Hardware**: NVIDIA H100 GPU
|
| 37 |
|
| 38 |
### CelebA (64×64 RGB)
|
| 39 |
-
- **Checkpoint**: `celeba64/ckpt.pth`
|
| 40 |
- **Parameters**: 83.0M
|
| 41 |
- **Architecture**: UNet with num_channels=128, num_res_blocks=2
|
| 42 |
- **Conditional**: No (unconditional face generation)
|
| 43 |
-
- **Training**: 50 epochs
|
|
|
|
| 44 |
- **Final loss**: 0.114
|
| 45 |
- **Hardware**: NVIDIA H100 GPU
|
| 46 |
|
|
@@ -48,49 +52,192 @@ This repository contains three UNet-based velocity field models trained with Flo
|
|
| 48 |
|
| 49 |
### MNIST
|
| 50 |

|
|
|
|
| 51 |
|
| 52 |
### CIFAR-10
|
| 53 |

|
|
|
|
| 54 |
|
| 55 |
### CelebA 64×64
|
| 56 |

|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
## Usage
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
```python
|
| 61 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
#
|
| 64 |
-
checkpoint = torch.load("celeba64/ckpt.pth", map_location="cuda")
|
| 65 |
|
| 66 |
-
|
| 67 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
```
|
| 69 |
|
| 70 |
-
## Architecture
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
-
|
| 73 |
|
| 74 |
-
|
| 75 |
-
- Time embedding layers
|
| 76 |
-
- ResNet blocks with adaptive normalization
|
| 77 |
-
- Self-attention blocks
|
| 78 |
-
- U-Net skip connections
|
| 79 |
-
- Class conditioning (for MNIST and CIFAR-10)
|
| 80 |
|
| 81 |
-
|
| 82 |
|
| 83 |
-
|
| 84 |
-
- **Optimizer**: AdamW
|
| 85 |
-
- **Epochs**: 50
|
| 86 |
-
- **GPU**: NVIDIA H100
|
| 87 |
-
- **Loss function**: MSE between predicted and target velocity fields
|
| 88 |
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
|
| 94 |
## License
|
| 95 |
|
| 96 |
CC BY-NC-SA 4.0 - Non-commercial use only.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
- mnist
|
| 10 |
- cifar10
|
| 11 |
- celeba
|
| 12 |
+
base_model: keishihara/flow-matching
|
| 13 |
---
|
| 14 |
|
| 15 |
# UNet Flow Matching Models
|
| 16 |
|
| 17 |
Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.
|
| 18 |
|
| 19 |
+
**Training code based on**: [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
|
| 20 |
+
|
| 21 |
## Models
|
| 22 |
|
| 23 |
This repository contains three UNet-based velocity field models trained with Flow Matching:
|
| 24 |
|
| 25 |
### MNIST (28×28 Grayscale)
|
| 26 |
+
- **Checkpoint**: `mnist/ckpt.pth` (24 MB)
|
| 27 |
- **Parameters**: 6.2M
|
| 28 |
- **Architecture**: UNet with num_channels=64, num_res_blocks=2
|
| 29 |
+
- **Conditional**: Yes (10 classes, 0-9 digits)
|
| 30 |
+
- **Training**: 50 epochs, batch_size=128, lr=1e-3
|
| 31 |
- **Hardware**: NVIDIA H100 GPU
|
| 32 |
|
| 33 |
+
### CIFAR-10 (32×32 RGB)
|
| 34 |
+
- **Checkpoint**: `cifar10/ckpt.pth` (35 MB)
|
| 35 |
- **Parameters**: 9.0M
|
| 36 |
- **Architecture**: UNet with num_channels=64, num_res_blocks=2
|
| 37 |
- **Conditional**: Yes (10 classes)
|
| 38 |
+
- **Training**: 50 epochs, batch_size=128, lr=1e-3
|
| 39 |
- **Hardware**: NVIDIA H100 GPU
|
| 40 |
|
| 41 |
### CelebA (64×64 RGB)
|
| 42 |
+
- **Checkpoint**: `celeba64/ckpt.pth` (332 MB)
|
| 43 |
- **Parameters**: 83.0M
|
| 44 |
- **Architecture**: UNet with num_channels=128, num_res_blocks=2
|
| 45 |
- **Conditional**: No (unconditional face generation)
|
| 46 |
+
- **Training**: 50 epochs, batch_size=512, lr=1e-4
|
| 47 |
+
- **Dataset**: 202,599 CelebA training images
|
| 48 |
- **Final loss**: 0.114
|
| 49 |
- **Hardware**: NVIDIA H100 GPU
|
| 50 |
|
|
|
|
| 52 |
|
| 53 |
### MNIST
|
| 54 |

|
| 55 |
+
*Generated MNIST digits at different velocity reuse thresholds*
|
| 56 |
|
| 57 |
### CIFAR-10
|
| 58 |

|
| 59 |
+
*Generated CIFAR-10 images at different velocity reuse thresholds*
|
| 60 |
|
| 61 |
### CelebA 64×64
|
| 62 |

|
| 63 |
+
*Generated 64×64 faces at different velocity reuse thresholds*
|
| 64 |
+
|
| 65 |
+
## Training Code
|
| 66 |
+
|
| 67 |
+
The models were trained using the Flow Matching implementation based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git).
|
| 68 |
+
|
| 69 |
+
### Training Scripts
|
| 70 |
+
|
| 71 |
+
**MNIST**:
|
| 72 |
+
```python
|
| 73 |
+
# train_flow_matching_on_images.py
|
| 74 |
+
python train_flow_matching_on_images.py \
|
| 75 |
+
--do_train \
|
| 76 |
+
--dataset mnist \
|
| 77 |
+
--n_epochs 50 \
|
| 78 |
+
--batch_size 128 \
|
| 79 |
+
--learning_rate 1e-3
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
**CIFAR-10**:
|
| 83 |
+
```python
|
| 84 |
+
python train_flow_matching_on_images.py \
|
| 85 |
+
--do_train \
|
| 86 |
+
--dataset cifar10 \
|
| 87 |
+
--n_epochs 50 \
|
| 88 |
+
--batch_size 128 \
|
| 89 |
+
--learning_rate 1e-3 \
|
| 90 |
+
--horizontal_flip
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
**CelebA**:
|
| 94 |
+
```python
|
| 95 |
+
# train_celeba64.py
|
| 96 |
+
python train_celeba64.py \
|
| 97 |
+
--do_train \
|
| 98 |
+
--n_epochs 50 \
|
| 99 |
+
--batch_size 512 \
|
| 100 |
+
--learning_rate 1e-4 \
|
| 101 |
+
--horizontal_flip
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Training code files included:
|
| 105 |
+
- `train_flow_matching_on_images.py` - For MNIST and CIFAR-10
|
| 106 |
+
- `train_celeba64.py` - For CelebA 64×64
|
| 107 |
|
| 108 |
## Usage
|
| 109 |
|
| 110 |
+
### Load Model
|
| 111 |
+
|
| 112 |
+
```python
|
| 113 |
+
import torch
|
| 114 |
+
from huggingface_hub import hf_hub_download
|
| 115 |
+
|
| 116 |
+
# Download checkpoint
|
| 117 |
+
ckpt_path = hf_hub_download(
|
| 118 |
+
repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
|
| 119 |
+
filename="celeba64/ckpt.pth"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Load checkpoint
|
| 123 |
+
checkpoint = torch.load(ckpt_path, map_location="cuda")
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
### Inference (Sampling)
|
| 127 |
+
|
| 128 |
```python
|
| 129 |
import torch
|
| 130 |
+
from flow_matching.models import UNetModel
|
| 131 |
+
from flow_matching.solver import ODESolver, ModelWrapper
|
| 132 |
+
|
| 133 |
+
device = "cuda"
|
| 134 |
+
|
| 135 |
+
# Create model (CelebA example)
|
| 136 |
+
flow = UNetModel(
|
| 137 |
+
dim=(3, 64, 64),
|
| 138 |
+
num_channels=128,
|
| 139 |
+
num_res_blocks=2,
|
| 140 |
+
num_classes=0,
|
| 141 |
+
class_cond=False,
|
| 142 |
+
).to(device)
|
| 143 |
+
|
| 144 |
+
# Load weights
|
| 145 |
+
flow.load_state_dict(checkpoint)
|
| 146 |
+
flow.eval()
|
| 147 |
+
|
| 148 |
+
# Create solver
|
| 149 |
+
model_wrapper = ModelWrapper(flow)
|
| 150 |
+
solver = ODESolver(model_wrapper)
|
| 151 |
+
|
| 152 |
+
# Sample from Gaussian noise
|
| 153 |
+
batch_size = 4
|
| 154 |
+
x_init = torch.randn(batch_size, 3, 64, 64).to(device)
|
| 155 |
+
time_grid = torch.linspace(0, 1, 21).to(device) # 20 steps
|
| 156 |
+
|
| 157 |
+
with torch.no_grad():
|
| 158 |
+
samples = solver.sample(
|
| 159 |
+
x_init=x_init,
|
| 160 |
+
step_size=0.05,
|
| 161 |
+
method="euler",
|
| 162 |
+
time_grid=time_grid
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Denormalize from [-1, 1] to [0, 1]
|
| 166 |
+
samples = (samples + 1) / 2
|
| 167 |
+
samples = samples.clamp(0, 1)
|
| 168 |
+
|
| 169 |
+
# Save or visualize
|
| 170 |
+
from torchvision.utils import save_image
|
| 171 |
+
save_image(samples, "generated_faces.png", nrow=2)
|
| 172 |
+
```
|
| 173 |
|
| 174 |
+
### Conditional Generation (MNIST/CIFAR-10)
|
|
|
|
| 175 |
|
| 176 |
+
```python
|
| 177 |
+
# For class-conditional models
|
| 178 |
+
flow = UNetModel(
|
| 179 |
+
dim=(3, 32, 32), # CIFAR-10
|
| 180 |
+
num_channels=64,
|
| 181 |
+
num_res_blocks=2,
|
| 182 |
+
num_classes=10,
|
| 183 |
+
class_cond=True,
|
| 184 |
+
).to(device)
|
| 185 |
+
|
| 186 |
+
# Load CIFAR-10 checkpoint
|
| 187 |
+
ckpt = torch.load("cifar10/ckpt.pth")
|
| 188 |
+
flow.load_state_dict(ckpt)
|
| 189 |
+
|
| 190 |
+
# Generate specific class (e.g., class 3)
|
| 191 |
+
y = torch.tensor([3, 3, 3, 3]).to(device) # Batch of 4, all class 3
|
| 192 |
+
|
| 193 |
+
def ode_func(t, x):
|
| 194 |
+
return flow(x=x, t=t, y=y)
|
| 195 |
+
|
| 196 |
+
# Then use solver as before
|
| 197 |
```
|
| 198 |
|
| 199 |
+
## Architecture Details
|
| 200 |
+
|
| 201 |
+
**UNet** based on OpenAI Guided Diffusion:
|
| 202 |
+
- Encoder-Decoder structure with skip connections
|
| 203 |
+
- ResNet blocks with GroupNorm
|
| 204 |
+
- Self-attention at multiple resolutions
|
| 205 |
+
- Time embedding via sinusoidal position encoding
|
| 206 |
+
- Optional class embedding for conditional generation
|
| 207 |
|
| 208 |
+
## Flow Matching
|
| 209 |
|
| 210 |
+
Flow Matching learns a velocity field that transports samples from source to target:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
+
$$\frac{dx}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{data}$$
|
| 213 |
|
| 214 |
+
Training uses Conditional Flow Matching (CFM) with straight-line paths:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
+
$$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t) - (x_1 - (1-\sigma)x_0) \|^2 \right]$$
|
| 217 |
+
|
| 218 |
+
## Requirements
|
| 219 |
+
|
| 220 |
+
```bash
|
| 221 |
+
pip install torch torchvision
|
| 222 |
+
pip install torchdiffeq einops
|
| 223 |
+
```
|
| 224 |
|
| 225 |
## License
|
| 226 |
|
| 227 |
CC BY-NC-SA 4.0 - Non-commercial use only.
|
| 228 |
+
|
| 229 |
+
## Acknowledgments
|
| 230 |
+
|
| 231 |
+
- Training code based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
|
| 232 |
+
- UNet architecture from [OpenAI Guided Diffusion](https://github.com/openai/guided-diffusion)
|
| 233 |
+
|
| 234 |
+
## Citation
|
| 235 |
+
|
| 236 |
+
```bibtex
|
| 237 |
+
@misc{flowmatching-unet-2024,
|
| 238 |
+
title={UNet Flow Matching Models for Image Generation},
|
| 239 |
+
author={WayBob},
|
| 240 |
+
year={2024},
|
| 241 |
+
howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
|
| 242 |
+
}
|
| 243 |
+
```
|
train_celeba64.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training script for CelebA 64x64 Flow Matching model.
|
| 3 |
+
Usage:
|
| 4 |
+
python train_celeba64.py --do_train --n_epochs 50 --batch_size 128
|
| 5 |
+
python train_celeba64.py --do_sample
|
| 6 |
+
"""
|
| 7 |
+
from dataclasses import dataclass
|
| 8 |
+
from functools import partial
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
|
| 11 |
+
import matplotlib.animation as animation
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from torch import Tensor
|
| 16 |
+
from torch.amp import GradScaler
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
from torchvision.utils import make_grid, save_image
|
| 19 |
+
from tqdm import tqdm as std_tqdm
|
| 20 |
+
from transformers import HfArgumentParser
|
| 21 |
+
|
| 22 |
+
from flow_matching.datasets.image_datasets import (
|
| 23 |
+
get_image_dataset,
|
| 24 |
+
get_test_transform,
|
| 25 |
+
get_train_transform,
|
| 26 |
+
)
|
| 27 |
+
from flow_matching.models import UNetModel
|
| 28 |
+
from flow_matching.sampler import PathSampler
|
| 29 |
+
from flow_matching.solver import ModelWrapper, ODESolver
|
| 30 |
+
from flow_matching.utils import model_size_summary, set_seed
|
| 31 |
+
|
| 32 |
+
tqdm = partial(std_tqdm, dynamic_ncols=True)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class ScriptArguments:
|
| 37 |
+
do_train: bool = False
|
| 38 |
+
do_sample: bool = False
|
| 39 |
+
dataset: str = "celeba"
|
| 40 |
+
image_size: int = 64 # Key parameter for CelebA
|
| 41 |
+
batch_size: int = 128
|
| 42 |
+
n_epochs: int = 50
|
| 43 |
+
learning_rate: float = 1e-4
|
| 44 |
+
sigma_min: float = 0.0
|
| 45 |
+
seed: int = 42
|
| 46 |
+
output_dir: str = "outputs"
|
| 47 |
+
horizontal_flip: bool = True # Important for faces
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def train(args: ScriptArguments):
|
| 51 |
+
"""Train the flow matching model on CelebA 64x64."""
|
| 52 |
+
|
| 53 |
+
output_dir = Path(args.output_dir) / "cfm" / f"{args.dataset}{args.image_size}"
|
| 54 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 55 |
+
|
| 56 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 57 |
+
set_seed(args.seed)
|
| 58 |
+
print(f"Using device: {device}")
|
| 59 |
+
print(f"Training CelebA at {args.image_size}x{args.image_size} resolution")
|
| 60 |
+
|
| 61 |
+
# Load the dataset with resize
|
| 62 |
+
dataset = get_image_dataset(
|
| 63 |
+
args.dataset,
|
| 64 |
+
train=True,
|
| 65 |
+
transform=get_train_transform(
|
| 66 |
+
horizontal_flip=args.horizontal_flip,
|
| 67 |
+
image_size=args.image_size
|
| 68 |
+
),
|
| 69 |
+
)
|
| 70 |
+
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4)
|
| 71 |
+
print(f"Loaded {args.dataset} dataset with {len(dataset):,} samples")
|
| 72 |
+
|
| 73 |
+
# CelebA doesn't have classes, so we set num_classes=0 and class_cond=False
|
| 74 |
+
input_shape = dataset[0][0].size()
|
| 75 |
+
print(f"{input_shape=}")
|
| 76 |
+
|
| 77 |
+
# Load the UNet model WITHOUT class conditioning for CelebA
|
| 78 |
+
flow = UNetModel(
|
| 79 |
+
input_shape,
|
| 80 |
+
num_channels=128, # Larger model for 64x64
|
| 81 |
+
num_res_blocks=2,
|
| 82 |
+
num_classes=0, # No class conditioning
|
| 83 |
+
class_cond=False,
|
| 84 |
+
).to(device)
|
| 85 |
+
path_sampler = PathSampler(sigma_min=args.sigma_min)
|
| 86 |
+
|
| 87 |
+
# Load the optimizer
|
| 88 |
+
optimizer = torch.optim.AdamW(flow.parameters(), lr=args.learning_rate)
|
| 89 |
+
scaler = GradScaler(enabled=device.type == "cuda")
|
| 90 |
+
print("GradScaler enabled:", scaler._enabled)
|
| 91 |
+
model_size_summary(flow)
|
| 92 |
+
|
| 93 |
+
for epoch in range(args.n_epochs):
|
| 94 |
+
flow.train()
|
| 95 |
+
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1:2d}/{args.n_epochs}")
|
| 96 |
+
|
| 97 |
+
for x_1, _ in pbar: # CelebA returns (img, label) but we ignore label
|
| 98 |
+
x_1 = x_1.to(device)
|
| 99 |
+
|
| 100 |
+
# Compute the probability path samples
|
| 101 |
+
x_0 = torch.randn_like(x_1)
|
| 102 |
+
t = torch.rand(x_1.size(0), device=device, dtype=x_1.dtype)
|
| 103 |
+
x_t, dx_t = path_sampler.sample(x_0, x_1, t)
|
| 104 |
+
|
| 105 |
+
flow.zero_grad(set_to_none=True)
|
| 106 |
+
|
| 107 |
+
# Compute the conditional flow matching loss WITHOUT class conditioning
|
| 108 |
+
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
| 109 |
+
vf_t = flow(t=t, x=x_t) # No y parameter
|
| 110 |
+
loss = F.mse_loss(vf_t, dx_t)
|
| 111 |
+
|
| 112 |
+
# Gradient scaling and backprop
|
| 113 |
+
scaler.scale(loss).backward()
|
| 114 |
+
torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
|
| 115 |
+
scaler.step(optimizer)
|
| 116 |
+
scaler.update()
|
| 117 |
+
|
| 118 |
+
pbar.set_postfix({"loss": loss.item()})
|
| 119 |
+
|
| 120 |
+
torch.save(flow.state_dict(), output_dir / "ckpt.pth")
|
| 121 |
+
print(f"Final checkpoint saved to {output_dir / 'ckpt.pth'}")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def generate_samples_and_save_animation(args: ScriptArguments):
|
| 125 |
+
"""Generate samples following the flow and save the animation."""
|
| 126 |
+
|
| 127 |
+
output_dir = Path(args.output_dir) / "cfm" / f"{args.dataset}{args.image_size}"
|
| 128 |
+
assert output_dir.is_dir(), f"Output directory {output_dir} does not exist"
|
| 129 |
+
|
| 130 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 131 |
+
set_seed(args.seed)
|
| 132 |
+
print(f"Using device: {device}")
|
| 133 |
+
|
| 134 |
+
# Load the dataset
|
| 135 |
+
dataset = get_image_dataset(
|
| 136 |
+
args.dataset,
|
| 137 |
+
train=False,
|
| 138 |
+
transform=get_test_transform(image_size=args.image_size),
|
| 139 |
+
)
|
| 140 |
+
input_shape = dataset[0][0].size()
|
| 141 |
+
|
| 142 |
+
# Load the flow model
|
| 143 |
+
flow = UNetModel(
|
| 144 |
+
input_shape,
|
| 145 |
+
num_channels=128,
|
| 146 |
+
num_res_blocks=2,
|
| 147 |
+
num_classes=0,
|
| 148 |
+
class_cond=False,
|
| 149 |
+
).to(device)
|
| 150 |
+
state_dict = torch.load(output_dir / "ckpt.pth", weights_only=True)
|
| 151 |
+
flow.load_state_dict(state_dict)
|
| 152 |
+
flow.eval()
|
| 153 |
+
|
| 154 |
+
# Use ODE solver to sample trajectories
|
| 155 |
+
class WrappedModel(ModelWrapper):
|
| 156 |
+
def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
|
| 157 |
+
return self.model(x=x, t=t)
|
| 158 |
+
|
| 159 |
+
samples_count = 64 # 8x8 grid
|
| 160 |
+
sample_steps = 101
|
| 161 |
+
time_steps = torch.linspace(0, 1, sample_steps).to(device)
|
| 162 |
+
|
| 163 |
+
wrapped_model = WrappedModel(flow)
|
| 164 |
+
step_size = 0.05
|
| 165 |
+
x_init = torch.randn((samples_count, *input_shape), dtype=torch.float32, device=device)
|
| 166 |
+
solver = ODESolver(wrapped_model)
|
| 167 |
+
sol = solver.sample(
|
| 168 |
+
x_init=x_init,
|
| 169 |
+
step_size=step_size,
|
| 170 |
+
method="midpoint",
|
| 171 |
+
time_grid=time_steps,
|
| 172 |
+
return_intermediates=True,
|
| 173 |
+
)
|
| 174 |
+
sol = sol.detach().cpu()
|
| 175 |
+
final_samples = sol[-1]
|
| 176 |
+
|
| 177 |
+
save_image(final_samples, output_dir / "final_samples.png", nrow=8, normalize=True)
|
| 178 |
+
|
| 179 |
+
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
|
| 180 |
+
grid = make_grid(final_samples, nrow=8, normalize=True)
|
| 181 |
+
ax[0].imshow(grid.permute(1, 2, 0))
|
| 182 |
+
ax[0].set_title("Final samples (t = 1.0)", fontsize=16)
|
| 183 |
+
ax[0].axis("off")
|
| 184 |
+
|
| 185 |
+
def update(frame: int):
|
| 186 |
+
grid = make_grid(sol[frame], nrow=8, normalize=True)
|
| 187 |
+
ax[1].clear()
|
| 188 |
+
ax[1].imshow(grid.permute(1, 2, 0))
|
| 189 |
+
ax[1].set_title(f"t = {time_steps[frame].item():.2f}", fontsize=16)
|
| 190 |
+
ax[1].axis("off")
|
| 191 |
+
|
| 192 |
+
fig.subplots_adjust(left=0.02, right=0.98, top=0.90, bottom=0.05, wspace=0.1)
|
| 193 |
+
ani = animation.FuncAnimation(fig, update, frames=sample_steps)
|
| 194 |
+
ani.save(output_dir / "trajectory.gif", writer="pillow", fps=20)
|
| 195 |
+
print(f"Generated trajectory saved to {output_dir / 'trajectory.gif'}")
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
if __name__ == "__main__":
|
| 199 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 200 |
+
script_args, *_ = parser.parse_args_into_dataclasses()
|
| 201 |
+
|
| 202 |
+
if script_args.do_train:
|
| 203 |
+
train(script_args)
|
| 204 |
+
|
| 205 |
+
if script_args.do_sample:
|
| 206 |
+
generate_samples_and_save_animation(script_args)
|
| 207 |
+
|
train_flow_matching_on_images.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from functools import partial
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import matplotlib.animation as animation
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.amp import GradScaler
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
from torchvision.utils import make_grid, save_image
|
| 13 |
+
from tqdm import tqdm as std_tqdm
|
| 14 |
+
from transformers import HfArgumentParser
|
| 15 |
+
|
| 16 |
+
from flow_matching.datasets.image_datasets import (
|
| 17 |
+
get_image_dataset,
|
| 18 |
+
get_test_transform,
|
| 19 |
+
get_train_transform,
|
| 20 |
+
)
|
| 21 |
+
from flow_matching.models import UNetModel
|
| 22 |
+
from flow_matching.sampler import PathSampler
|
| 23 |
+
from flow_matching.solver import ModelWrapper, ODESolver
|
| 24 |
+
from flow_matching.utils import model_size_summary, set_seed
|
| 25 |
+
|
| 26 |
+
tqdm = partial(std_tqdm, dynamic_ncols=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
|
| 30 |
+
class ScriptArguments:
|
| 31 |
+
do_train: bool = False
|
| 32 |
+
do_sample: bool = False
|
| 33 |
+
dataset: str = "mnist"
|
| 34 |
+
batch_size: int = 128
|
| 35 |
+
n_epochs: int = 10
|
| 36 |
+
learning_rate: float = 1e-3
|
| 37 |
+
sigma_min: float = 0.0
|
| 38 |
+
seed: int = 42
|
| 39 |
+
output_dir: str = "outputs"
|
| 40 |
+
horizontal_flip: bool = False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def train(args: ScriptArguments):
|
| 44 |
+
"""Train the flow matching model on the given dataset."""
|
| 45 |
+
|
| 46 |
+
output_dir = Path(args.output_dir) / "cfm" / args.dataset
|
| 47 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 48 |
+
|
| 49 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 50 |
+
set_seed(args.seed)
|
| 51 |
+
print(f"Using device: {device}")
|
| 52 |
+
|
| 53 |
+
# Load the dataset
|
| 54 |
+
dataset = get_image_dataset(
|
| 55 |
+
args.dataset,
|
| 56 |
+
train=True,
|
| 57 |
+
transform=get_train_transform(horizontal_flip=args.horizontal_flip),
|
| 58 |
+
)
|
| 59 |
+
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
|
| 60 |
+
print(f"Loaded {args.dataset} dataset with {len(dataset):,} samples")
|
| 61 |
+
|
| 62 |
+
num_classes = len(dataset.classes)
|
| 63 |
+
input_shape = dataset[0][0].size()
|
| 64 |
+
print(f"{input_shape=}, {num_classes=}")
|
| 65 |
+
|
| 66 |
+
# Load the UNet model with class conditioning for flow matching
|
| 67 |
+
flow = UNetModel(
|
| 68 |
+
input_shape,
|
| 69 |
+
num_channels=64,
|
| 70 |
+
num_res_blocks=2,
|
| 71 |
+
num_classes=num_classes,
|
| 72 |
+
class_cond=True,
|
| 73 |
+
).to(device)
|
| 74 |
+
path_sampler = PathSampler(sigma_min=args.sigma_min)
|
| 75 |
+
|
| 76 |
+
# Load the optimizer
|
| 77 |
+
optimizer = torch.optim.AdamW(flow.parameters(), lr=args.learning_rate)
|
| 78 |
+
scaler = GradScaler(enabled=device.type == "cuda")
|
| 79 |
+
print("GradScaler enabled:", scaler._enabled)
|
| 80 |
+
model_size_summary(flow)
|
| 81 |
+
|
| 82 |
+
for epoch in range(args.n_epochs):
|
| 83 |
+
flow.train()
|
| 84 |
+
pbar = tqdm(dataloader, desc=f"Epoch {epoch+1:2d}/{args.n_epochs}")
|
| 85 |
+
|
| 86 |
+
for x_1, y in pbar:
|
| 87 |
+
x_1, y = x_1.to(device), y.to(device)
|
| 88 |
+
|
| 89 |
+
# Compute the probability path samples
|
| 90 |
+
x_0 = torch.randn_like(x_1)
|
| 91 |
+
t = torch.rand(x_1.size(0), device=device, dtype=x_1.dtype)
|
| 92 |
+
x_t, dx_t = path_sampler.sample(x_0, x_1, t)
|
| 93 |
+
|
| 94 |
+
flow.zero_grad(set_to_none=True)
|
| 95 |
+
|
| 96 |
+
# Compute the conditional flow matching loss with class conditioning
|
| 97 |
+
with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
|
| 98 |
+
vf_t = flow(t=t, x=x_t, y=y)
|
| 99 |
+
loss = F.mse_loss(vf_t, dx_t)
|
| 100 |
+
|
| 101 |
+
# Gradient scaling and backprop
|
| 102 |
+
scaler.scale(loss).backward()
|
| 103 |
+
torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0) # clip gradients
|
| 104 |
+
scaler.step(optimizer)
|
| 105 |
+
scaler.update()
|
| 106 |
+
|
| 107 |
+
pbar.set_postfix({"loss": loss.item()})
|
| 108 |
+
|
| 109 |
+
torch.save(flow.state_dict(), output_dir / "ckpt.pth")
|
| 110 |
+
print(f"Final checkpoint saved to {output_dir / 'ckpt.pth'}")
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def generate_samples_and_save_animation(args: ScriptArguments):
|
| 114 |
+
"""Generate samples following the flow and save the animation."""
|
| 115 |
+
|
| 116 |
+
output_dir = Path(args.output_dir) / "cfm" / args.dataset
|
| 117 |
+
assert output_dir.is_dir(), f"Output directory {output_dir} does not exist"
|
| 118 |
+
|
| 119 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 120 |
+
set_seed(args.seed)
|
| 121 |
+
print(f"Using device: {device}")
|
| 122 |
+
|
| 123 |
+
# Load the dataset
|
| 124 |
+
dataset = get_image_dataset(
|
| 125 |
+
args.dataset,
|
| 126 |
+
train=False,
|
| 127 |
+
transform=get_test_transform(),
|
| 128 |
+
)
|
| 129 |
+
input_shape = dataset[0][0].size()
|
| 130 |
+
num_classes = len(dataset.classes)
|
| 131 |
+
|
| 132 |
+
# Load the flow model
|
| 133 |
+
flow = UNetModel(
|
| 134 |
+
input_shape,
|
| 135 |
+
num_channels=64,
|
| 136 |
+
num_res_blocks=2,
|
| 137 |
+
num_classes=num_classes,
|
| 138 |
+
class_cond=True,
|
| 139 |
+
).to(device)
|
| 140 |
+
state_dict = torch.load(output_dir / "ckpt.pth", weights_only=True)
|
| 141 |
+
flow.load_state_dict(state_dict)
|
| 142 |
+
flow.eval()
|
| 143 |
+
|
| 144 |
+
# Use ODE solver to sample trajectories
|
| 145 |
+
class WrappedModel(ModelWrapper):
|
| 146 |
+
def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
|
| 147 |
+
return self.model(x=x, t=t, **extras)
|
| 148 |
+
|
| 149 |
+
samples_per_class = 10
|
| 150 |
+
sample_steps = 101
|
| 151 |
+
time_steps = torch.linspace(0, 1, sample_steps).to(device)
|
| 152 |
+
class_list = torch.arange(num_classes, device=device).repeat(samples_per_class)
|
| 153 |
+
|
| 154 |
+
wrapped_model = WrappedModel(flow)
|
| 155 |
+
step_size = 0.05
|
| 156 |
+
x_init = torch.randn((class_list.size(0), *input_shape), dtype=torch.float32, device=device)
|
| 157 |
+
solver = ODESolver(wrapped_model)
|
| 158 |
+
sol = solver.sample(
|
| 159 |
+
x_init=x_init,
|
| 160 |
+
step_size=step_size,
|
| 161 |
+
method="midpoint",
|
| 162 |
+
time_grid=time_steps,
|
| 163 |
+
return_intermediates=True,
|
| 164 |
+
y=class_list,
|
| 165 |
+
)
|
| 166 |
+
sol = sol.detach().cpu()
|
| 167 |
+
final_samples = sol[-1]
|
| 168 |
+
|
| 169 |
+
save_image(final_samples, output_dir / "final_samples.png", nrow=num_classes, normalize=True)
|
| 170 |
+
|
| 171 |
+
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
|
| 172 |
+
grid = make_grid(final_samples, nrow=num_classes, normalize=True)
|
| 173 |
+
ax[0].imshow(grid.permute(1, 2, 0))
|
| 174 |
+
ax[0].set_title("Final samples (t = 1.0)", fontsize=16)
|
| 175 |
+
ax[0].axis("off")
|
| 176 |
+
|
| 177 |
+
def update(frame: int):
|
| 178 |
+
grid = make_grid(sol[frame], nrow=num_classes, normalize=True)
|
| 179 |
+
ax[1].clear()
|
| 180 |
+
ax[1].imshow(grid.permute(1, 2, 0))
|
| 181 |
+
ax[1].set_title(f"t = {time_steps[frame].item():.2f}", fontsize=16)
|
| 182 |
+
ax[1].axis("off")
|
| 183 |
+
|
| 184 |
+
fig.subplots_adjust(left=0.02, right=0.98, top=0.90, bottom=0.05, wspace=0.1)
|
| 185 |
+
ani = animation.FuncAnimation(fig, update, frames=sample_steps)
|
| 186 |
+
ani.save(output_dir / "trajectory.gif", writer="pillow", fps=20)
|
| 187 |
+
print(f"Generated trajectory saved to {output_dir / 'trajectory.gif'}")
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__ == "__main__":
|
| 191 |
+
parser = HfArgumentParser(ScriptArguments)
|
| 192 |
+
script_args, *_ = parser.parse_args_into_dataclasses()
|
| 193 |
+
|
| 194 |
+
if script_args.do_train:
|
| 195 |
+
train(script_args)
|
| 196 |
+
|
| 197 |
+
if script_args.do_sample:
|
| 198 |
+
generate_samples_and_save_animation(script_args)
|