WayBob committed (verified) · Commit 4bc5a15 · Parent: 502d9d5

Upload folder using huggingface_hub

Files changed (3):
  1. README.md +177 -30
  2. train_celeba64.py +207 -0
  3. train_flow_matching_on_images.py +198 -0
README.md CHANGED
@@ -9,38 +9,42 @@ datasets:
  - mnist
  - cifar10
  - celeba
  ---
 
  # UNet Flow Matching Models
 
  Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.
 
  ## Models
 
  This repository contains three UNet-based velocity field models trained with Flow Matching:
 
  ### MNIST (28×28 Grayscale)
- - **Checkpoint**: `mnist/ckpt.pth`
  - **Parameters**: 6.2M
  - **Architecture**: UNet with num_channels=64, num_res_blocks=2
- - **Conditional**: Yes (10 classes)
- - **Training**: 50 epochs on MNIST dataset
  - **Hardware**: NVIDIA H100 GPU
 
- ### CIFAR-10 (32×32 RGB)
- - **Checkpoint**: `cifar10/ckpt.pth`
  - **Parameters**: 9.0M
  - **Architecture**: UNet with num_channels=64, num_res_blocks=2
  - **Conditional**: Yes (10 classes)
- - **Training**: 50 epochs on CIFAR-10 dataset
  - **Hardware**: NVIDIA H100 GPU
 
  ### CelebA (64×64 RGB)
- - **Checkpoint**: `celeba64/ckpt.pth`
  - **Parameters**: 83.0M
  - **Architecture**: UNet with num_channels=128, num_res_blocks=2
  - **Conditional**: No (unconditional face generation)
- - **Training**: 50 epochs on CelebA dataset (202,599 images)
  - **Final loss**: 0.114
  - **Hardware**: NVIDIA H100 GPU
@@ -48,49 +52,192 @@ This repository contains three UNet-based velocity field models trained with Flow Matching:
 
  ### MNIST
  ![MNIST Samples](mnist_with_diff.png)
 
  ### CIFAR-10
  ![CIFAR-10 Samples](cifar10_with_diff.png)
 
  ### CelebA 64×64
  ![CelebA Samples](celeba_with_diff.png)
 
  ## Usage
 
  ```python
  import torch
 
- # Load checkpoint (example for CelebA)
- checkpoint = torch.load("celeba64/ckpt.pth", map_location="cuda")
 
- # The checkpoint contains the model state dict
- # You need to create the UNet model first, then load the weights
  ```
 
- ## Architecture
 
- Based on the UNet architecture from OpenAI Guided Diffusion, adapted for Flow Matching.
 
- **Key components**:
- - Time embedding layers
- - ResNet blocks with adaptive normalization
- - Self-attention blocks
- - U-Net skip connections
- - Class conditioning (for MNIST and CIFAR-10)
 
- ## Training Details
 
- All models trained with:
- - **Optimizer**: AdamW
- - **Epochs**: 50
- - **GPU**: NVIDIA H100
- - **Loss function**: MSE between predicted and target velocity fields
 
- Dataset-specific:
- - **MNIST**: batch_size=128, lr=1e-3
- - **CIFAR-10**: batch_size=128, lr=1e-3, horizontal_flip=True
- - **CelebA**: batch_size=512, lr=1e-4, horizontal_flip=True
 
  ## License
 
  CC BY-NC-SA 4.0 - Non-commercial use only.
  - mnist
  - cifar10
  - celeba
+ base_model: keishihara/flow-matching
  ---
 
  # UNet Flow Matching Models
 
  Pre-trained UNet models for Flow Matching on MNIST, CIFAR-10, and CelebA datasets.
 
+ **Training code based on**: [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
+
  ## Models
 
  This repository contains three UNet-based velocity field models trained with Flow Matching:
 
  ### MNIST (28×28 Grayscale)
+ - **Checkpoint**: `mnist/ckpt.pth` (24 MB)
  - **Parameters**: 6.2M
  - **Architecture**: UNet with num_channels=64, num_res_blocks=2
+ - **Conditional**: Yes (10 classes, digits 0-9)
+ - **Training**: 50 epochs, batch_size=128, lr=1e-3
  - **Hardware**: NVIDIA H100 GPU
 
+ ### CIFAR-10 (32×32 RGB)
+ - **Checkpoint**: `cifar10/ckpt.pth` (35 MB)
  - **Parameters**: 9.0M
  - **Architecture**: UNet with num_channels=64, num_res_blocks=2
  - **Conditional**: Yes (10 classes)
+ - **Training**: 50 epochs, batch_size=128, lr=1e-3
  - **Hardware**: NVIDIA H100 GPU
 
  ### CelebA (64×64 RGB)
+ - **Checkpoint**: `celeba64/ckpt.pth` (332 MB)
  - **Parameters**: 83.0M
  - **Architecture**: UNet with num_channels=128, num_res_blocks=2
  - **Conditional**: No (unconditional face generation)
+ - **Training**: 50 epochs, batch_size=512, lr=1e-4
+ - **Dataset**: 202,599 CelebA training images
  - **Final loss**: 0.114
  - **Hardware**: NVIDIA H100 GPU
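The listed checkpoint sizes are consistent with float32 storage at roughly 4 bytes per parameter. A quick back-of-the-envelope check (pure Python; small deviations from the listed figures come from rounding conventions and non-parameter buffers in the state dict):

```python
# Estimate fp32 checkpoint size: ~4 bytes per parameter.
param_counts = {"mnist": 6.2e6, "cifar10": 9.0e6, "celeba64": 83.0e6}

for name, n in param_counts.items():
    size_mb = n * 4 / 1e6  # float32 = 4 bytes per parameter
    print(f"{name}: ~{size_mb:.0f} MB")
# mnist: ~25 MB
# cifar10: ~36 MB
# celeba64: ~332 MB
```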
 
 
 
  ### MNIST
  ![MNIST Samples](mnist_with_diff.png)
+ *Generated MNIST digits at different velocity reuse thresholds*
 
  ### CIFAR-10
  ![CIFAR-10 Samples](cifar10_with_diff.png)
+ *Generated CIFAR-10 images at different velocity reuse thresholds*
 
  ### CelebA 64×64
  ![CelebA Samples](celeba_with_diff.png)
+ *Generated 64×64 faces at different velocity reuse thresholds*
+
+ ## Training Code
+
+ The models were trained with the Flow Matching implementation from [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git).
+
+ ### Training Scripts
+
+ **MNIST** (`train_flow_matching_on_images.py`):
+ ```bash
+ python train_flow_matching_on_images.py \
+     --do_train \
+     --dataset mnist \
+     --n_epochs 50 \
+     --batch_size 128 \
+     --learning_rate 1e-3
+ ```
+
+ **CIFAR-10**:
+ ```bash
+ python train_flow_matching_on_images.py \
+     --do_train \
+     --dataset cifar10 \
+     --n_epochs 50 \
+     --batch_size 128 \
+     --learning_rate 1e-3 \
+     --horizontal_flip
+ ```
+
+ **CelebA** (`train_celeba64.py`):
+ ```bash
+ python train_celeba64.py \
+     --do_train \
+     --n_epochs 50 \
+     --batch_size 512 \
+     --learning_rate 1e-4 \
+     --horizontal_flip
+ ```
+
+ Training code files included:
+ - `train_flow_matching_on_images.py` - for MNIST and CIFAR-10
+ - `train_celeba64.py` - for CelebA 64×64
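The scripts parse these flags with `transformers.HfArgumentParser` over a `ScriptArguments` dataclass. For readers without `transformers` installed, a stdlib-only sketch of the same command-line interface (illustrative stand-in, not the scripts' actual parser):

```python
import argparse

# Stand-in for HfArgumentParser(ScriptArguments): each dataclass field
# becomes a flag; booleans defaulting to False become store_true switches.
parser = argparse.ArgumentParser()
parser.add_argument("--do_train", action="store_true")
parser.add_argument("--do_sample", action="store_true")
parser.add_argument("--dataset", default="mnist")
parser.add_argument("--n_epochs", type=int, default=50)
parser.add_argument("--batch_size", type=int, default=128)
parser.add_argument("--learning_rate", type=float, default=1e-3)
parser.add_argument("--horizontal_flip", action="store_true")

# Same flags as the CIFAR-10 command above:
args = parser.parse_args(["--do_train", "--dataset", "cifar10", "--horizontal_flip"])
print(args.dataset, args.n_epochs, args.horizontal_flip)  # cifar10 50 True
```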
 
  ## Usage
 
+ ### Load Model
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+
+ # Download the checkpoint from the Hub
+ ckpt_path = hf_hub_download(
+     repo_id="WayBob/FlowMatching-Unet-Celeb-64x64",
+     filename="celeba64/ckpt.pth",
+ )
+
+ # Load the state dict
+ checkpoint = torch.load(ckpt_path, map_location="cuda")
+ ```
+
+ ### Inference (Sampling)
+
  ```python
  import torch
+ from flow_matching.models import UNetModel
+ from flow_matching.solver import ODESolver, ModelWrapper
+
+ device = "cuda"
+
+ # Create the model (CelebA example)
+ flow = UNetModel(
+     (3, 64, 64),
+     num_channels=128,
+     num_res_blocks=2,
+     num_classes=0,
+     class_cond=False,
+ ).to(device)
+
+ # Load the weights
+ flow.load_state_dict(checkpoint)
+ flow.eval()
+
+ # Wrap the model and create the ODE solver (as in the training scripts)
+ class WrappedModel(ModelWrapper):
+     def forward(self, x, t, **extras):
+         return self.model(x=x, t=t)
+
+ solver = ODESolver(WrappedModel(flow))
+
+ # Sample from Gaussian noise
+ batch_size = 4
+ x_init = torch.randn(batch_size, 3, 64, 64, device=device)
+ time_grid = torch.linspace(0, 1, 21).to(device)  # 20 integration steps
+
+ with torch.no_grad():
+     samples = solver.sample(
+         x_init=x_init,
+         step_size=0.05,
+         method="euler",
+         time_grid=time_grid,
+     )
+
+ # Denormalize from [-1, 1] to [0, 1]
+ samples = ((samples + 1) / 2).clamp(0, 1)
+
+ # Save or visualize
+ from torchvision.utils import save_image
+ save_image(samples, "generated_faces.png", nrow=2)
+ ```
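With `method="euler"`, the solver simply repeats `x ← x + v(x, t)·Δt` along the time grid. A dependency-free sketch of that loop, using a toy scalar field in place of the trained UNet:

```python
import math

def euler_integrate(v, x0, n_steps):
    """Fixed-step Euler integration of dx/dt = v(x, t) from t=0 to t=1."""
    x, dt = x0, 1.0 / n_steps
    for i in range(n_steps):
        x = x + v(x, i * dt) * dt  # one Euler step
    return x

# Toy field dx/dt = -x has the exact solution x(1) = x0 * e^{-1},
# so the numerical result should land close to exp(-1) ≈ 0.3679.
approx = euler_integrate(lambda x, t: -x, 1.0, 1000)
print(approx, math.exp(-1))
```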
 
+ ### Conditional Generation (MNIST/CIFAR-10)
 
+ ```python
+ # For the class-conditional models
+ flow = UNetModel(
+     (3, 32, 32),  # CIFAR-10
+     num_channels=64,
+     num_res_blocks=2,
+     num_classes=10,
+     class_cond=True,
+ ).to(device)
+
+ # Load the CIFAR-10 checkpoint
+ ckpt = torch.load("cifar10/ckpt.pth", map_location=device)
+ flow.load_state_dict(ckpt)
+ flow.eval()
+
+ # Generate a specific class (e.g., class 3): a batch of 4, all class 3
+ y = torch.tensor([3, 3, 3, 3]).to(device)
+
+ # Forward the labels through the wrapper, then sample as before:
+ class CondWrappedModel(ModelWrapper):
+     def forward(self, x, t, **extras):
+         return self.model(x=x, t=t, **extras)
+
+ solver = ODESolver(CondWrappedModel(flow))
+ # solver.sample(..., y=y) passes the labels to the model at every step
  ```
198
 
+ ## Architecture Details
+
+ **UNet** based on OpenAI Guided Diffusion:
+ - Encoder-Decoder structure with skip connections
+ - ResNet blocks with GroupNorm
+ - Self-attention at multiple resolutions
+ - Time embedding via sinusoidal position encoding
+ - Optional class embedding for conditional generation
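The sinusoidal time embedding follows the standard transformer-style recipe used in Guided Diffusion: the scalar t is mapped to cosines and sines at geometrically spaced frequencies. A minimal sketch (dimension handling simplified relative to the actual `UNetModel` internals):

```python
import math

def timestep_embedding(t: float, dim: int, max_period: float = 10000.0):
    """Embed a scalar time t into `dim` sinusoidal features."""
    half = dim // 2
    # Frequencies decay geometrically from 1 down to ~1/max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(t * f) for f in freqs] + [math.sin(t * f) for f in freqs]

emb = timestep_embedding(0.5, 8)
print(len(emb))  # 8
```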
 
+ ## Flow Matching
+
+ Flow Matching learns a velocity field that transports samples from the source distribution to the data distribution:
+
+ $$\frac{dx}{dt} = v_\theta(x_t, t), \quad x_0 \sim \mathcal{N}(0, I), \quad x_1 \sim p_{\mathrm{data}}$$
+
+ Training uses Conditional Flow Matching (CFM) with straight-line paths $x_t = (1 - (1 - \sigma_{\min})t)\,x_0 + t\,x_1$:
+
+ $$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[ \| v_\theta(x_t, t) - (x_1 - (1 - \sigma_{\min})x_0) \|^2 \right]$$
+
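Concretely, this pairing is what the training scripts obtain from `PathSampler.sample(x_0, x_1, t)`: a point on the straight-line path and the constant target velocity. A scalar sketch under that assumption (the scripts default to sigma_min=0):

```python
def cfm_pair(x0: float, x1: float, t: float, sigma_min: float = 0.0):
    """Straight-line CFM path sample x_t and its target velocity dx_t."""
    x_t = (1.0 - (1.0 - sigma_min) * t) * x0 + t * x1
    dx_t = x1 - (1.0 - sigma_min) * x0  # constant along the path
    return x_t, dx_t

# With sigma_min = 0 the path is plain linear interpolation:
x_t, dx_t = cfm_pair(x0=2.0, x1=6.0, t=0.25)
print(x_t, dx_t)  # 3.0 4.0
```

The training loss is then just the MSE between the model's predicted velocity at `x_t` and `dx_t`.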
+ ## Requirements
+
+ ```bash
+ pip install torch torchvision
+ pip install torchdiffeq einops
+ ```
 
  ## License
 
  CC BY-NC-SA 4.0 - Non-commercial use only.
+
+ ## Acknowledgments
+
+ - Training code based on [keishihara/flow-matching](https://github.com/keishihara/flow-matching.git)
+ - UNet architecture from [OpenAI Guided Diffusion](https://github.com/openai/guided-diffusion)
+
+ ## Citation
+
+ ```bibtex
+ @misc{flowmatching-unet-2024,
+   title={UNet Flow Matching Models for Image Generation},
+   author={WayBob},
+   year={2024},
+   howpublished={\url{https://huggingface.co/WayBob/FlowMatching-Unet-Celeb-64x64}}
+ }
+ ```
train_celeba64.py ADDED
@@ -0,0 +1,207 @@
"""
Training script for CelebA 64x64 Flow Matching model.

Usage:
    python train_celeba64.py --do_train --n_epochs 50 --batch_size 128
    python train_celeba64.py --do_sample
"""
from dataclasses import dataclass
from functools import partial
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.amp import GradScaler
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm as std_tqdm
from transformers import HfArgumentParser

from flow_matching.datasets.image_datasets import (
    get_image_dataset,
    get_test_transform,
    get_train_transform,
)
from flow_matching.models import UNetModel
from flow_matching.sampler import PathSampler
from flow_matching.solver import ModelWrapper, ODESolver
from flow_matching.utils import model_size_summary, set_seed

tqdm = partial(std_tqdm, dynamic_ncols=True)


@dataclass
class ScriptArguments:
    do_train: bool = False
    do_sample: bool = False
    dataset: str = "celeba"
    image_size: int = 64  # Key parameter for CelebA
    batch_size: int = 128
    n_epochs: int = 50
    learning_rate: float = 1e-4
    sigma_min: float = 0.0
    seed: int = 42
    output_dir: str = "outputs"
    horizontal_flip: bool = True  # Important for faces


def train(args: ScriptArguments):
    """Train the flow matching model on CelebA 64x64."""
    output_dir = Path(args.output_dir) / "cfm" / f"{args.dataset}{args.image_size}"
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    print(f"Using device: {device}")
    print(f"Training CelebA at {args.image_size}x{args.image_size} resolution")

    # Load the dataset with resize
    dataset = get_image_dataset(
        args.dataset,
        train=True,
        transform=get_train_transform(
            horizontal_flip=args.horizontal_flip,
            image_size=args.image_size,
        ),
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4)
    print(f"Loaded {args.dataset} dataset with {len(dataset):,} samples")

    # CelebA has no class labels, so we use num_classes=0 and class_cond=False
    input_shape = dataset[0][0].size()
    print(f"{input_shape=}")

    # Load the UNet model WITHOUT class conditioning for CelebA
    flow = UNetModel(
        input_shape,
        num_channels=128,  # Larger model for 64x64
        num_res_blocks=2,
        num_classes=0,  # No class conditioning
        class_cond=False,
    ).to(device)
    path_sampler = PathSampler(sigma_min=args.sigma_min)

    # Load the optimizer
    optimizer = torch.optim.AdamW(flow.parameters(), lr=args.learning_rate)
    scaler = GradScaler(enabled=device.type == "cuda")
    print("GradScaler enabled:", scaler.is_enabled())
    model_size_summary(flow)

    for epoch in range(args.n_epochs):
        flow.train()
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1:2d}/{args.n_epochs}")

        for x_1, _ in pbar:  # CelebA returns (img, label) but we ignore the label
            x_1 = x_1.to(device)

            # Compute the probability path samples
            x_0 = torch.randn_like(x_1)
            t = torch.rand(x_1.size(0), device=device, dtype=x_1.dtype)
            x_t, dx_t = path_sampler.sample(x_0, x_1, t)

            flow.zero_grad(set_to_none=True)

            # Compute the conditional flow matching loss WITHOUT class conditioning
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                vf_t = flow(t=t, x=x_t)  # No y argument
                loss = F.mse_loss(vf_t, dx_t)

            # Gradient scaling and backprop (unscale before clipping so the
            # norm threshold applies to the true gradients)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix({"loss": loss.item()})

    torch.save(flow.state_dict(), output_dir / "ckpt.pth")
    print(f"Final checkpoint saved to {output_dir / 'ckpt.pth'}")


def generate_samples_and_save_animation(args: ScriptArguments):
    """Generate samples following the flow and save the animation."""
    output_dir = Path(args.output_dir) / "cfm" / f"{args.dataset}{args.image_size}"
    assert output_dir.is_dir(), f"Output directory {output_dir} does not exist"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    print(f"Using device: {device}")

    # Load the dataset
    dataset = get_image_dataset(
        args.dataset,
        train=False,
        transform=get_test_transform(image_size=args.image_size),
    )
    input_shape = dataset[0][0].size()

    # Load the flow model
    flow = UNetModel(
        input_shape,
        num_channels=128,
        num_res_blocks=2,
        num_classes=0,
        class_cond=False,
    ).to(device)
    state_dict = torch.load(output_dir / "ckpt.pth", weights_only=True)
    flow.load_state_dict(state_dict)
    flow.eval()

    # Use the ODE solver to sample trajectories
    class WrappedModel(ModelWrapper):
        def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
            return self.model(x=x, t=t)

    samples_count = 64  # 8x8 grid
    sample_steps = 101
    time_steps = torch.linspace(0, 1, sample_steps).to(device)

    wrapped_model = WrappedModel(flow)
    step_size = 0.05
    x_init = torch.randn((samples_count, *input_shape), dtype=torch.float32, device=device)
    solver = ODESolver(wrapped_model)
    sol = solver.sample(
        x_init=x_init,
        step_size=step_size,
        method="midpoint",
        time_grid=time_steps,
        return_intermediates=True,
    )
    sol = sol.detach().cpu()
    final_samples = sol[-1]

    save_image(final_samples, output_dir / "final_samples.png", nrow=8, normalize=True)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    grid = make_grid(final_samples, nrow=8, normalize=True)
    ax[0].imshow(grid.permute(1, 2, 0))
    ax[0].set_title("Final samples (t = 1.0)", fontsize=16)
    ax[0].axis("off")

    def update(frame: int):
        grid = make_grid(sol[frame], nrow=8, normalize=True)
        ax[1].clear()
        ax[1].imshow(grid.permute(1, 2, 0))
        ax[1].set_title(f"t = {time_steps[frame].item():.2f}", fontsize=16)
        ax[1].axis("off")

    fig.subplots_adjust(left=0.02, right=0.98, top=0.90, bottom=0.05, wspace=0.1)
    ani = animation.FuncAnimation(fig, update, frames=sample_steps)
    ani.save(output_dir / "trajectory.gif", writer="pillow", fps=20)
    print(f"Generated trajectory saved to {output_dir / 'trajectory.gif'}")


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args, *_ = parser.parse_args_into_dataclasses()

    if script_args.do_train:
        train(script_args)

    if script_args.do_sample:
        generate_samples_and_save_animation(script_args)
train_flow_matching_on_images.py ADDED
@@ -0,0 +1,198 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.amp import GradScaler
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm as std_tqdm
from transformers import HfArgumentParser

from flow_matching.datasets.image_datasets import (
    get_image_dataset,
    get_test_transform,
    get_train_transform,
)
from flow_matching.models import UNetModel
from flow_matching.sampler import PathSampler
from flow_matching.solver import ModelWrapper, ODESolver
from flow_matching.utils import model_size_summary, set_seed

tqdm = partial(std_tqdm, dynamic_ncols=True)


@dataclass
class ScriptArguments:
    do_train: bool = False
    do_sample: bool = False
    dataset: str = "mnist"
    batch_size: int = 128
    n_epochs: int = 10
    learning_rate: float = 1e-3
    sigma_min: float = 0.0
    seed: int = 42
    output_dir: str = "outputs"
    horizontal_flip: bool = False


def train(args: ScriptArguments):
    """Train the flow matching model on the given dataset."""
    output_dir = Path(args.output_dir) / "cfm" / args.dataset
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    print(f"Using device: {device}")

    # Load the dataset
    dataset = get_image_dataset(
        args.dataset,
        train=True,
        transform=get_train_transform(horizontal_flip=args.horizontal_flip),
    )
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    print(f"Loaded {args.dataset} dataset with {len(dataset):,} samples")

    num_classes = len(dataset.classes)
    input_shape = dataset[0][0].size()
    print(f"{input_shape=}, {num_classes=}")

    # Load the UNet model with class conditioning for flow matching
    flow = UNetModel(
        input_shape,
        num_channels=64,
        num_res_blocks=2,
        num_classes=num_classes,
        class_cond=True,
    ).to(device)
    path_sampler = PathSampler(sigma_min=args.sigma_min)

    # Load the optimizer
    optimizer = torch.optim.AdamW(flow.parameters(), lr=args.learning_rate)
    scaler = GradScaler(enabled=device.type == "cuda")
    print("GradScaler enabled:", scaler.is_enabled())
    model_size_summary(flow)

    for epoch in range(args.n_epochs):
        flow.train()
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1:2d}/{args.n_epochs}")

        for x_1, y in pbar:
            x_1, y = x_1.to(device), y.to(device)

            # Compute the probability path samples
            x_0 = torch.randn_like(x_1)
            t = torch.rand(x_1.size(0), device=device, dtype=x_1.dtype)
            x_t, dx_t = path_sampler.sample(x_0, x_1, t)

            flow.zero_grad(set_to_none=True)

            # Compute the conditional flow matching loss with class conditioning
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                vf_t = flow(t=t, x=x_t, y=y)
                loss = F.mse_loss(vf_t, dx_t)

            # Gradient scaling and backprop (unscale before clipping so the
            # norm threshold applies to the true gradients)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix({"loss": loss.item()})

    torch.save(flow.state_dict(), output_dir / "ckpt.pth")
    print(f"Final checkpoint saved to {output_dir / 'ckpt.pth'}")


def generate_samples_and_save_animation(args: ScriptArguments):
    """Generate samples following the flow and save the animation."""
    output_dir = Path(args.output_dir) / "cfm" / args.dataset
    assert output_dir.is_dir(), f"Output directory {output_dir} does not exist"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    print(f"Using device: {device}")

    # Load the dataset
    dataset = get_image_dataset(
        args.dataset,
        train=False,
        transform=get_test_transform(),
    )
    input_shape = dataset[0][0].size()
    num_classes = len(dataset.classes)

    # Load the flow model
    flow = UNetModel(
        input_shape,
        num_channels=64,
        num_res_blocks=2,
        num_classes=num_classes,
        class_cond=True,
    ).to(device)
    state_dict = torch.load(output_dir / "ckpt.pth", weights_only=True)
    flow.load_state_dict(state_dict)
    flow.eval()

    # Use the ODE solver to sample trajectories; extras (the labels) are
    # forwarded to the model at every step
    class WrappedModel(ModelWrapper):
        def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
            return self.model(x=x, t=t, **extras)

    samples_per_class = 10
    sample_steps = 101
    time_steps = torch.linspace(0, 1, sample_steps).to(device)
    class_list = torch.arange(num_classes, device=device).repeat(samples_per_class)

    wrapped_model = WrappedModel(flow)
    step_size = 0.05
    x_init = torch.randn((class_list.size(0), *input_shape), dtype=torch.float32, device=device)
    solver = ODESolver(wrapped_model)
    sol = solver.sample(
        x_init=x_init,
        step_size=step_size,
        method="midpoint",
        time_grid=time_steps,
        return_intermediates=True,
        y=class_list,
    )
    sol = sol.detach().cpu()
    final_samples = sol[-1]

    save_image(final_samples, output_dir / "final_samples.png", nrow=num_classes, normalize=True)

    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    grid = make_grid(final_samples, nrow=num_classes, normalize=True)
    ax[0].imshow(grid.permute(1, 2, 0))
    ax[0].set_title("Final samples (t = 1.0)", fontsize=16)
    ax[0].axis("off")

    def update(frame: int):
        grid = make_grid(sol[frame], nrow=num_classes, normalize=True)
        ax[1].clear()
        ax[1].imshow(grid.permute(1, 2, 0))
        ax[1].set_title(f"t = {time_steps[frame].item():.2f}", fontsize=16)
        ax[1].axis("off")

    fig.subplots_adjust(left=0.02, right=0.98, top=0.90, bottom=0.05, wspace=0.1)
    ani = animation.FuncAnimation(fig, update, frames=sample_steps)
    ani.save(output_dir / "trajectory.gif", writer="pillow", fps=20)
    print(f"Generated trajectory saved to {output_dir / 'trajectory.gif'}")


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args, *_ = parser.parse_args_into_dataclasses()

    if script_args.do_train:
        train(script_args)

    if script_args.do_sample:
        generate_samples_and_save_animation(script_args)