# FlowMatching-Unet-Celeb-64x64 / train_celeba64.py
# Provenance: uploaded by WayBob via huggingface_hub (commit 4bc5a15, verified)
"""
Training script for CelebA 64x64 Flow Matching model.
Usage:
python train_celeba64.py --do_train --n_epochs 50 --batch_size 128
python train_celeba64.py --do_sample
"""
from dataclasses import dataclass
from functools import partial
from pathlib import Path
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.amp import GradScaler
from torch.utils.data import DataLoader
from torchvision.utils import make_grid, save_image
from tqdm import tqdm as std_tqdm
from transformers import HfArgumentParser
from flow_matching.datasets.image_datasets import (
get_image_dataset,
get_test_transform,
get_train_transform,
)
from flow_matching.models import UNetModel
from flow_matching.sampler import PathSampler
from flow_matching.solver import ModelWrapper, ODESolver
from flow_matching.utils import model_size_summary, set_seed
# Pre-configure tqdm so progress bars resize with the current terminal width.
tqdm = partial(std_tqdm, dynamic_ncols=True)
@dataclass
class ScriptArguments:
    """Command-line arguments for training / sampling the CelebA 64x64 flow model.

    Parsed by ``HfArgumentParser`` in the ``__main__`` guard; each field becomes
    a ``--flag`` of the same name.
    """

    do_train: bool = False  # run the training loop (train())
    do_sample: bool = False  # run sampling + trajectory animation
    dataset: str = "celeba"  # dataset name passed to get_image_dataset()
    image_size: int = 64  # Key parameter for CelebA
    batch_size: int = 128  # per-step batch size for the DataLoader
    n_epochs: int = 50  # number of full passes over the training set
    learning_rate: float = 1e-4  # AdamW learning rate
    sigma_min: float = 0.0  # minimum sigma for the probability path (PathSampler)
    seed: int = 42  # global RNG seed via set_seed()
    output_dir: str = "outputs"  # root dir; artifacts go under outputs/cfm/<dataset><size>/
    horizontal_flip: bool = True  # Important for faces
def train(args: ScriptArguments):
    """Train the flow matching model on CelebA 64x64.

    Trains an unconditional UNet velocity-field model with the conditional
    flow matching (CFM) objective under bfloat16 autocast + GradScaler, and
    saves the final checkpoint to ``<output_dir>/cfm/<dataset><size>/ckpt.pth``.

    Args:
        args: Parsed command-line configuration.
    """
    output_dir = Path(args.output_dir) / "cfm" / f"{args.dataset}{args.image_size}"
    output_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    print(f"Using device: {device}")
    print(f"Training CelebA at {args.image_size}x{args.image_size} resolution")

    # Load the dataset with resize (and optional horizontal flip augmentation).
    dataset = get_image_dataset(
        args.dataset,
        train=True,
        transform=get_train_transform(
            horizontal_flip=args.horizontal_flip,
            image_size=args.image_size,
        ),
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,  # keep batch statistics uniform across steps
        num_workers=4,
    )
    print(f"Loaded {args.dataset} dataset with {len(dataset):,} samples")

    # CelebA doesn't have classes, so we set num_classes=0 and class_cond=False.
    input_shape = dataset[0][0].size()
    print(f"{input_shape=}")
    # Load the UNet model WITHOUT class conditioning for CelebA.
    flow = UNetModel(
        input_shape,
        num_channels=128,  # Larger model for 64x64
        num_res_blocks=2,
        num_classes=0,  # No class conditioning
        class_cond=False,
    ).to(device)
    path_sampler = PathSampler(sigma_min=args.sigma_min)

    # Optimizer + AMP gradient scaler (scaling only active on CUDA).
    optimizer = torch.optim.AdamW(flow.parameters(), lr=args.learning_rate)
    scaler = GradScaler(enabled=device.type == "cuda")
    # Use the public accessor instead of the private `_enabled` attribute.
    print("GradScaler enabled:", scaler.is_enabled())
    model_size_summary(flow)

    for epoch in range(args.n_epochs):
        flow.train()
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1:2d}/{args.n_epochs}")
        for x_1, _ in pbar:  # CelebA returns (img, label) but we ignore label
            x_1 = x_1.to(device)
            # Probability path sample: noise x_0 -> data x_1 at random time t.
            x_0 = torch.randn_like(x_1)
            t = torch.rand(x_1.size(0), device=device, dtype=x_1.dtype)
            x_t, dx_t = path_sampler.sample(x_0, x_1, t)
            optimizer.zero_grad(set_to_none=True)
            # Conditional flow matching loss WITHOUT class conditioning.
            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                vf_t = flow(t=t, x=x_t)  # No y parameter
                loss = F.mse_loss(vf_t, dx_t)
            scaler.scale(loss).backward()
            # FIX: gradients must be unscaled before clipping; otherwise
            # clip_grad_norm_ operates on scaler-scaled gradients and the
            # effective max_norm is wrong (see PyTorch AMP docs).
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            pbar.set_postfix({"loss": loss.item()})

    torch.save(flow.state_dict(), output_dir / "ckpt.pth")
    print(f"Final checkpoint saved to {output_dir / 'ckpt.pth'}")
def generate_samples_and_save_animation(args: ScriptArguments):
    """Generate samples following the flow and save the animation.

    Loads the checkpoint written by :func:`train`, integrates the learned
    velocity field from noise (t=0) to data (t=1) with an ODE solver, saves
    an 8x8 grid of final samples and a GIF of the full trajectory.

    Args:
        args: Parsed command-line configuration (must match training config).
    """
    output_dir = Path(args.output_dir) / "cfm" / f"{args.dataset}{args.image_size}"
    assert output_dir.is_dir(), f"Output directory {output_dir} does not exist"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(args.seed)
    print(f"Using device: {device}")

    # Load the test split only to recover the input shape / normalization.
    dataset = get_image_dataset(
        args.dataset,
        train=False,
        transform=get_test_transform(image_size=args.image_size),
    )
    input_shape = dataset[0][0].size()

    # Rebuild the model with the same architecture used in train().
    flow = UNetModel(
        input_shape,
        num_channels=128,
        num_res_blocks=2,
        num_classes=0,
        class_cond=False,
    ).to(device)
    state_dict = torch.load(output_dir / "ckpt.pth", weights_only=True)
    flow.load_state_dict(state_dict)
    flow.eval()

    # Adapter giving the solver the (x, t) calling convention it expects.
    class WrappedModel(ModelWrapper):
        def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
            return self.model(x=x, t=t)

    samples_count = 64  # 8x8 grid
    sample_steps = 101
    time_steps = torch.linspace(0, 1, sample_steps).to(device)
    wrapped_model = WrappedModel(flow)
    step_size = 0.05
    x_init = torch.randn((samples_count, *input_shape), dtype=torch.float32, device=device)
    solver = ODESolver(wrapped_model)
    # FIX: disable autograd during sampling — the original recorded the full
    # computation graph across all solver steps (detach() afterwards does not
    # avoid building it), wasting a large amount of memory.
    with torch.no_grad():
        sol = solver.sample(
            x_init=x_init,
            step_size=step_size,
            method="midpoint",
            time_grid=time_steps,
            return_intermediates=True,
        )
    sol = sol.detach().cpu()
    final_samples = sol[-1]
    save_image(final_samples, output_dir / "final_samples.png", nrow=8, normalize=True)

    # Left panel: static final grid. Right panel: animated trajectory over t.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    grid = make_grid(final_samples, nrow=8, normalize=True)
    ax[0].imshow(grid.permute(1, 2, 0))
    ax[0].set_title("Final samples (t = 1.0)", fontsize=16)
    ax[0].axis("off")

    def update(frame: int):
        # Redraw the right panel with the intermediate state at this frame.
        grid = make_grid(sol[frame], nrow=8, normalize=True)
        ax[1].clear()
        ax[1].imshow(grid.permute(1, 2, 0))
        ax[1].set_title(f"t = {time_steps[frame].item():.2f}", fontsize=16)
        ax[1].axis("off")

    fig.subplots_adjust(left=0.02, right=0.98, top=0.90, bottom=0.05, wspace=0.1)
    ani = animation.FuncAnimation(fig, update, frames=sample_steps)
    ani.save(output_dir / "trajectory.gif", writer="pillow", fps=20)
    print(f"Generated trajectory saved to {output_dir / 'trajectory.gif'}")
if __name__ == "__main__":
    # Turn CLI flags into a ScriptArguments instance; extra parser outputs
    # (if any) are ignored — only the first dataclass matters here.
    cli_parser = HfArgumentParser(ScriptArguments)
    script_args = cli_parser.parse_args_into_dataclasses()[0]
    # The two phases are independent; both may run in a single invocation.
    if script_args.do_train:
        train(script_args)
    if script_args.do_sample:
        generate_samples_and_save_animation(script_args)