krystv committed on
Commit
5614582
·
verified ·
1 Parent(s): 2b4ad8c

Add sampling.py — Euler and Heun ODE samplers

Browse files
Files changed (1) hide show
  1. liquidflow/sampling.py +91 -0
liquidflow/sampling.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sampling / inference for LiquidFlow.
3
+
4
+ Uses ODE integration (Euler or Heun's method) to solve dx/dt = v_θ(x_t, t), e.g. via the Euler update:
5
+ x_{t+dt} = x_t + v_θ(x_t, t) * dt
6
+
7
+ Starting from x_0 ~ N(0, I) and integrating to x_1 (clean image).
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from tqdm import tqdm
13
+
14
+
15
@torch.no_grad()
def euler_sample(model, shape, num_steps=50, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate samples with fixed-step Euler integration of the flow ODE.

    Integrates dx/dt = v_θ(x_t, t) from t=0 to t=1 in ``num_steps`` equal
    steps, starting from x_0 ~ N(0, I) drawn on ``device``.

    Args:
        model: Callable invoked as ``model(x, t, class_label)``; its result is
            used as the velocity and must be addable to ``x``. ``t`` is a
            ``(B,)`` tensor holding the current time for every batch element.
        shape: Shape of the sample batch; ``shape[0]`` is the batch size B.
        num_steps: Number of Euler steps; must be at least 1.
        device: Device on which the noise and time tensors are created.
        class_label: Conditioning forwarded to the model, or None.
        cfg_scale: Guidance scale. Guidance is applied only when this is > 0
            and ``class_label`` is not None; the model is then also called
            with ``None`` as the label and the two velocities are combined as
            ``v_uncond + cfg_scale * (v_cond - v_uncond)``.

    Returns:
        The state tensor after the final step (t=1).

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    # Without this guard, 0 fails with an opaque ZeroDivisionError and a
    # negative value silently returns the untouched initial noise.
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    batch_size = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    # Loop-invariant: decide once whether guidance is active.
    use_cfg = cfg_scale > 0 and class_label is not None

    for step in range(num_steps):
        t = torch.full((batch_size,), step * dt, device=device)
        v = model(x, t, class_label)
        if use_cfg:
            v_uncond = model(x, t, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        x = x + v * dt

    return x
36
+
37
+
38
@torch.no_grad()
def heun_sample(model, shape, num_steps=25, device='cpu', class_label=None, cfg_scale=0.0):
    """
    Generate samples with Heun's (2nd order) integration of the flow ODE.

    Integrates dx/dt = v_θ(x_t, t) from t=0 to t=1 in ``num_steps`` equal
    steps, starting from x_0 ~ N(0, I) drawn on ``device``. Every step except
    the last is a predictor-corrector step using two velocity evaluations;
    the last step is a plain Euler step with a single evaluation. When
    guidance is active each velocity evaluation calls the model twice.

    Args:
        model: Callable invoked as ``model(x, t, class_label)``; its result is
            used as the velocity and must be addable to ``x``. ``t`` is a
            ``(B,)`` tensor holding the current time for every batch element.
        shape: Shape of the sample batch; ``shape[0]`` is the batch size B.
        num_steps: Number of integration steps; must be at least 1.
        device: Device on which the noise and time tensors are created.
        class_label: Conditioning forwarded to the model, or None.
        cfg_scale: Guidance scale. Guidance is applied only when this is > 0
            and ``class_label`` is not None, combining velocities as
            ``v_uncond + cfg_scale * (v_cond - v_uncond)``.

    Returns:
        The state tensor after the final step (t=1).

    Raises:
        ValueError: If ``num_steps`` is less than 1.
    """
    # Without this guard, 0 fails with an opaque ZeroDivisionError and a
    # negative value silently returns the untouched initial noise.
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    batch_size = shape[0]
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    # Loop-invariant: decide once whether guidance is active.
    use_cfg = cfg_scale > 0 and class_label is not None

    def velocity(x_in, t_in):
        """Evaluate the (optionally guided) velocity at ``(x_in, t_in)``."""
        v = model(x_in, t_in, class_label)
        if use_cfg:
            v_uncond = model(x_in, t_in, None)
            v = v_uncond + cfg_scale * (v - v_uncond)
        return v

    last_step = num_steps - 1
    for step in range(num_steps):
        t = torch.full((batch_size,), step * dt, device=device)
        k1 = velocity(x, t)
        x_pred = x + dt * k1  # Euler predictor

        if step == last_step:
            # Final step keeps the Euler predictor; no corrector evaluation.
            x = x_pred
            break

        # Clamp guards against floating-point overshoot past t=1.
        t_next = torch.full((batch_size,), min((step + 1) * dt, 1.0), device=device)
        k2 = velocity(x_pred, t_next)
        x = x + dt * 0.5 * (k1 + k2)  # trapezoidal corrector

    return x
67
+
68
+
69
@torch.no_grad()
def generate_grid(model, num_images=16, num_steps=50, img_size=128,
                  device='cpu', class_label=None, cfg_scale=0.0, method='euler'):
    """
    Sample a batch of 3-channel square images rescaled for display.

    Runs the sampler selected by ``method`` ('euler' or 'heun') on a batch of
    shape (num_images, 3, img_size, img_size), clamps the result to [-1, 1]
    and maps it linearly to [0, 1]. Returns a (B, C, H, W) tensor.
    Raises ValueError for any other ``method``.
    """
    # Pick the integrator first so an unknown method fails before sampling.
    if method == 'euler':
        sampler = euler_sample
    elif method == 'heun':
        sampler = heun_sample
    else:
        raise ValueError(f"Unknown method: {method}")

    batch_shape = (num_images, 3, img_size, img_size)
    samples = sampler(model, batch_shape, num_steps, device, class_label, cfg_scale)

    # [-1, 1] -> [0, 1]
    clipped = samples.clamp(-1, 1)
    return clipped * 0.5 + 0.5
81
+
82
+
83
def make_grid_image(images, nrow=4, padding=2):
    """
    Arrange images into a grid and return it as a PIL Image.

    Args:
        images: Batch of images forwarded to ``torchvision.utils.make_grid``;
            values are expected in [0, 1] (no normalization is applied here).
        nrow: Number of images per grid row, forwarded to ``make_grid``.
        padding: Pixel padding between images, forwarded to ``make_grid``.

    Returns:
        A ``PIL.Image.Image`` built from the 8-bit grid.
    """
    from torchvision.utils import make_grid
    from PIL import Image
    import numpy as np

    grid = make_grid(images, nrow=nrow, padding=padding, normalize=False)
    # detach(): .numpy() refuses tensors that require grad, and this function
    # is not wrapped in no_grad. float(): NumPy cannot represent bfloat16.
    # permute: channels-first (C, H, W) -> channels-last (H, W, C) for PIL.
    grid = grid.detach().float().permute(1, 2, 0).cpu().numpy()
    # Round to the nearest level; a bare uint8 cast truncates and biases
    # every pixel downwards by up to one level.
    grid = np.rint(grid * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(grid)