Add sampling.py — Euler and Heun ODE samplers
Browse files- liquidflow/sampling.py +91 -0
liquidflow/sampling.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sampling / inference for LiquidFlow.
|
| 3 |
+
|
| 4 |
+
Numerically integrates the flow ODE dx/dt = v_θ(x_t, t) with Euler or Heun's method; the basic Euler update is:
|
| 5 |
+
x_{t+dt} = x_t + v_θ(x_t, t) * dt
|
| 6 |
+
|
| 7 |
+
Starting from x_0 ~ N(0, I) and integrating to x_1 (clean image).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@torch.no_grad()
def euler_sample(model, shape, num_steps=50, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate samples by integrating the flow ODE with the explicit Euler method.

    Integrates dx/dt = v_θ(x_t, t) from t=0 to t=1 in ``num_steps`` equal steps,
    starting from x_0 ~ N(0, I), and returns the state reached at t=1.

    Args:
        model: Callable invoked as ``model(x, t, class_label)``; its result is
            used as the velocity and must broadcast against ``x``.
        shape: Shape of the sample tensor; ``shape[0]`` is used as the batch size.
        num_steps: Number of Euler steps. Must be >= 1.
        device: Device on which the initial noise and time tensors are created.
        class_label: Conditioning forwarded to the model, or None.
        cfg_scale: Classifier-free guidance scale. Guidance is applied only when
            this is > 0 and ``class_label`` is not None; it costs one extra
            model call (with ``None`` as the label) per step.

    Returns:
        Tensor of shape ``shape`` with the integrated samples.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    # Reject invalid step counts explicitly: 0 would divide by zero below and a
    # negative value would skip the loop and silently return raw noise.
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps

    # The guidance decision does not change between steps, so evaluate it once.
    use_cfg = cfg_scale > 0 and class_label is not None

    for i in range(num_steps):
        # One time value per batch element, all equal to the current step time.
        t = torch.full((B,), i * dt, device=device)
        v = model(x, t, class_label)
        if use_cfg:
            # Extrapolate from the unconditional velocity towards the conditional one.
            v_uncond = model(x, t, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        x = x + v * dt

    return x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@torch.no_grad()
def heun_sample(model, shape, num_steps=25, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate samples using Heun's method (2nd order) ODE integration.

    Integrates dx/dt = v_θ(x_t, t) from t=0 to t=1 in ``num_steps`` equal steps,
    starting from x_0 ~ N(0, I). Every step except the last is a Heun
    predictor-corrector step (two velocity evaluations); the final step is a
    plain Euler step (one velocity evaluation). When guidance is active, each
    velocity evaluation calls the model twice.

    Args:
        model: Callable invoked as ``model(x, t, class_label)``; its result is
            used as the velocity and must broadcast against ``x``.
        shape: Shape of the sample tensor; ``shape[0]`` is used as the batch size.
        num_steps: Number of integration steps. Must be >= 1.
        device: Device on which the initial noise and time tensors are created.
        class_label: Conditioning forwarded to the model, or None.
        cfg_scale: Classifier-free guidance scale. Guidance is applied only when
            this is > 0 and ``class_label`` is not None.

    Returns:
        Tensor of shape ``shape`` with the integrated samples.

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    # Reject invalid step counts explicitly: 0 would divide by zero below and a
    # negative value would skip the loop and silently return raw noise.
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    B = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps

    # The guidance decision does not change between evaluations.
    use_cfg = cfg_scale > 0 and class_label is not None

    def get_v(x_in, t_in):
        """Return the (optionally guided) velocity at state ``x_in`` and time ``t_in``."""
        v = model(x_in, t_in, class_label)
        if use_cfg:
            v_uncond = model(x_in, t_in, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        return v

    last_step = num_steps - 1
    for i in range(num_steps):
        t = torch.full((B,), i * dt, device=device)
        k1 = get_v(x, t)
        # Euler predictor; also the final answer on the last step.
        x_hat = x + dt * k1
        if i == last_step:
            # Final step stays first-order: no velocity is evaluated at the end time.
            x = x_hat
            break
        # Clamp guards against floating-point overshoot past t=1.
        t_next = torch.full((B,), min((i + 1) * dt, 1.0), device=device)
        k2 = get_v(x_hat, t_next)
        # Corrector: average the slopes at both ends of the step.
        x = x + dt * 0.5 * (k1 + k2)

    return x
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
def generate_grid(model, num_images=16, num_steps=50, img_size=128,
                  device='cpu', class_label=None, cfg_scale=0.0, method='euler'):
    """
    Sample a batch of 3-channel square images and rescale them for display.

    Args:
        model: Velocity model handed to the selected sampler.
        num_images: Batch size of the generated tensor.
        num_steps: Number of integration steps passed to the sampler.
        img_size: Height and width of each generated image.
        device: Device passed to the sampler.
        class_label: Conditioning passed to the sampler, or None.
        cfg_scale: Guidance scale passed to the sampler.
        method: ``'euler'`` or ``'heun'``.

    Returns:
        Tensor of shape (num_images, 3, img_size, img_size) with values in [0, 1].

    Raises:
        ValueError: If ``method`` is neither ``'euler'`` nor ``'heun'``.
    """
    # Guard clause: reject unsupported integrators before doing any work.
    if not (method == 'euler' or method == 'heun'):
        raise ValueError(f"Unknown method: {method}")

    sampler = euler_sample if method == 'euler' else heun_sample
    batch_shape = (num_images, 3, img_size, img_size)
    samples = sampler(model, batch_shape, num_steps, device, class_label, cfg_scale)

    # Clip to [-1, 1], then map linearly onto [0, 1].
    clipped = samples.clamp(-1, 1)
    return clipped * 0.5 + 0.5
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def make_grid_image(images, nrow=4, padding=2):
    """
    Arrange a batch of images into a single grid and return it as a PIL Image.

    Args:
        images: Image batch forwarded to ``torchvision.utils.make_grid``.
            Values are scaled by 255 without normalisation, so they are
            presumably expected in [0, 1] (confirm against callers).
        nrow: Forwarded to ``make_grid`` as ``nrow``.
        padding: Forwarded to ``make_grid`` as ``padding``.

    Returns:
        ``PIL.Image.Image`` built from the uint8 grid array.
    """
    # Imported lazily so the samplers stay usable without torchvision / Pillow.
    from torchvision.utils import make_grid
    from PIL import Image
    import numpy as np

    grid = make_grid(images, nrow=nrow, padding=padding, normalize=False)
    # detach() lets tensors that still require grad be converted to NumPy;
    # permute moves the channel axis last, as Image.fromarray expects.
    grid = grid.detach().permute(1, 2, 0).cpu().numpy()
    # Add 0.5 before the truncating uint8 cast so values round to the nearest
    # level instead of always rounding down.
    grid = (grid * 255 + 0.5).clip(0, 255).astype(np.uint8)
    return Image.fromarray(grid)
|