# Source: new-language_model/utils/sampling_utils.py
# Uploaded by Raja-65 in commit 6ab8280 ("Add ELF demo"); file size 10 kB.
from typing import Optional
import torch
import torch.nn.functional as F
# ============================================
# Noise Schedulers (how to compute z from x0 and noise)
# ============================================
def add_noise(x0, noise, t, config, cond_seq_mask=None):
    """Build the noisy flow-matching input ``z`` from clean data and noise.

    Computes ``z = t * x0 + (1 - t) * noise * config.denoiser_noise_scale`` and,
    when a mask is supplied, blends the clean ``x0`` back in at masked positions:
    ``z = mask * x0 + (1 - mask) * z``.

    Args:
        x0: Clean tensor. ``t`` is reshaped to ``(-1, 1, 1)``, so ``x0`` must
            broadcast against that shape (assumes a 3-D batch-first layout --
            TODO confirm against callers).
        noise: Noise tensor, broadcastable against ``x0``.
        t: Tensor of per-sample times; flattened onto the leading axis.
        config: Object exposing ``denoiser_noise_scale``.
        cond_seq_mask: Optional mask selecting positions that keep ``x0``.
            A tensor mask with fewer dims than ``z`` is padded with trailing
            singleton axes (same convention as ``restore_cond``), and a bool
            mask is cast to ``z``'s dtype. Non-tensor masks (e.g. a Python
            scalar) are used as-is.

    Returns:
        The interpolated tensor ``z``.
    """
    t_expanded = t.reshape(-1, 1, 1)
    z = t_expanded * x0 + (1 - t_expanded) * noise * config.denoiser_noise_scale
    if cond_seq_mask is None:
        return z

    mask = cond_seq_mask
    if torch.is_tensor(mask):
        # A (batch, seq) mask cannot broadcast against (batch, seq, dim)
        # directly; append singleton axes until the ranks match.
        while mask.dim() < z.dim():
            mask = mask.unsqueeze(-1)
        # `1 - mask` is not defined for bool tensors.
        if mask.dtype == torch.bool:
            mask = mask.to(z.dtype)
    return mask * x0 + (1 - mask) * z
# ============================================
# Time Schedulers (how to sample t)
# ============================================
def sample_timesteps(
    batch_size: int,
    P_mean: float = -0.8,
    P_std: float = 0.8,
    time_schedule: str = 'logit_normal',
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
):
    """Draw ``batch_size`` random timesteps from the chosen schedule.

    Args:
        batch_size: How many timesteps to draw.
        P_mean: Mean of the underlying Gaussian (``'logit_normal'`` only).
        P_std: Standard deviation of the underlying Gaussian
            (``'logit_normal'`` only).
        time_schedule: ``'logit_normal'`` (sigmoid of a Gaussian draw) or
            ``'uniform'`` (``torch.rand``).
        device: Device for the returned tensor.
        dtype: Dtype for the returned tensor.

    Returns:
        A 1-D tensor of length ``batch_size`` with values between 0 and 1.

    Raises:
        ValueError: If ``time_schedule`` is not one of the supported names.
    """
    shape = (batch_size,)
    if time_schedule == 'uniform':
        return torch.rand(shape, dtype=dtype, device=device)
    if time_schedule != 'logit_normal':
        raise ValueError(f"Unknown time_schedule: {time_schedule}")
    # Logit-normal: squash a shifted/scaled standard normal through a sigmoid.
    gaussian = torch.randn(shape, dtype=dtype, device=device)
    return torch.sigmoid(gaussian * P_std + P_mean)
def get_sampling_steps(
    n_steps: int, time_schedule: str = "logit_normal",
    P_mean: float = -0.8, P_std: float = 0.8,
    device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the grid of ``n_steps + 1`` time values used by one sampling run.

    - ``"uniform"``: deterministic, evenly spaced from 0 to 1 inclusive.
    - ``"logit_normal"``: ``n_steps - 1`` random logit-normal draws, sorted
      ascending, with 0 prepended and 1 appended.

    Raises:
        ValueError: If ``time_schedule`` is not one of the names above.
    """
    if time_schedule == "uniform":
        return torch.linspace(0.0, 1.0, n_steps + 1, dtype=dtype, device=device)
    if time_schedule != "logit_normal":
        raise ValueError(f"Unknown time_schedule: {time_schedule}")

    # Random interior points; the two endpoints are pinned separately below.
    interior = sample_timesteps(
        batch_size=n_steps - 1,
        P_mean=P_mean,
        P_std=P_std,
        time_schedule=time_schedule,
        device=device,
        dtype=dtype,
    )
    interior, _ = torch.sort(interior)
    start = torch.zeros((1,), dtype=dtype, device=interior.device)
    stop = torch.ones((1,), dtype=dtype, device=interior.device)
    return torch.cat([start, interior, stop], dim=0)
# ============================================
# CFG Scale Sampling (how to sample cfg scale)
# ============================================
def sample_cfg_scale(batch_size, cfg_min=0.0, cfg_max=3.0,
                     dtype=torch.float32, device=None):
    """Draw ``batch_size`` CFG scales between ``cfg_min`` and ``cfg_max``.

    The quantity ``1 + scale`` is sampled log-uniformly between
    ``1 + cfg_min`` and ``1 + cfg_max``; the shift by one keeps the logarithm
    defined when ``cfg_min`` is 0.

    Returns:
        A 1-D tensor of length ``batch_size``.
    """
    uniform = torch.rand((batch_size,), dtype=dtype, device=device)
    lower = float(1.0 + cfg_min)
    upper = float(1.0 + cfg_max)
    # log(upper / lower), kept as a tensor matching the draw's dtype and device.
    log_span = torch.log(torch.tensor(upper / lower, dtype=dtype, device=uniform.device))
    shifted = lower * torch.exp(uniform * log_span)
    return shifted - 1.0
# ============================================
# Conditioning helpers (preserve clean tokens during sampling)
# ============================================
def restore_cond(z_updated, cond_seq, cond_seq_mask):
    """Overwrite the conditioning positions of ``z_updated`` with ``cond_seq``.

    Args:
        z_updated: Tensor produced by a denoising step.
        cond_seq: Tensor holding the values to restore, or ``None``.
        cond_seq_mask: Tensor that is ``> 0`` at positions to restore, or
            ``None``. If it has fewer dims than the larger of ``z_updated`` and
            ``cond_seq``, trailing singleton axes are appended so it broadcasts.

    Returns:
        ``torch.where(mask > 0, cond_seq, z_updated)``, or ``z_updated``
        itself when either ``cond_seq`` or ``cond_seq_mask`` is ``None``.
    """
    # With no conditioning there is nothing to restore. `restore_vx` already
    # tolerates cond_seq=None; mirror that here so direct callers do not fail
    # on `None.dim()`.
    if cond_seq is None or cond_seq_mask is None:
        return z_updated

    mask = cond_seq_mask
    target_ndim = max(z_updated.dim(), cond_seq.dim())
    while mask.dim() < target_ndim:
        mask = mask.unsqueeze(-1)
    return torch.where(mask > 0, cond_seq, z_updated)
def restore_vx(v, x, cond_seq, cond_seq_mask):
    """Pin conditioning positions in a ``(v, x)`` pair.

    At masked positions ``x`` is replaced by the clean ``cond_seq`` and ``v``
    is set to zero, so conditioning tokens do not move. When ``cond_seq`` is
    ``None`` both tensors are returned untouched.

    Returns:
        The ``(v, x)`` tuple, in that order.
    """
    if cond_seq is None:
        return v, x
    x_pinned = restore_cond(x, cond_seq, cond_seq_mask)
    v_pinned = restore_cond(v, torch.zeros_like(cond_seq), cond_seq_mask)
    return v_pinned, x_pinned
# ============================================
# Flow-matching forward passes (with optional self-cond / CFG)
# ============================================
def net_out_to_v_x(net_out, z, t, t_eps=5e-2):
    """Turn an x-prediction network output into a ``(v, x)`` pair.

    The velocity is ``v = (x - z) / max(1 - t, t_eps)``; the clamp keeps the
    division bounded as ``t`` approaches 1.

    Args:
        net_out: The predicted ``x``, or a tuple whose first element is the
            predicted ``x``. Any further tuple elements (decoder logits) are
            ignored here.
        z: Current noisy state.
        t: Per-sample times; reshaped to ``(-1, 1, 1)`` for broadcasting.
        t_eps: Lower bound applied to ``1 - t``.

    Returns:
        ``(v, x)``.
    """
    x = net_out[0] if isinstance(net_out, tuple) else net_out
    time_left = (1.0 - t.reshape(-1, 1, 1)).clamp(min=t_eps)
    return (x - z) / time_left, x
def _forward_sample_self_cond(
    model, z, t_batch, x_pred_prev, config,
    self_cond_cfg_scale, cond_seq, cond_seq_mask,
):
    """Run the model once or twice, depending on the self-conditioning mode.

    Modes, checked in this order:

    1. ``config.num_self_cond_cfg_tokens > 0``: one call on
       ``cat([z, x_pred_prev])`` (a blank placeholder stands in when there is
       no previous prediction), with the guidance scale passed to the model as
       a per-sample ``self_cond_cfg_scale`` tensor.
    2. ``config.self_cond_prob == 0``: one call on ``z`` alone.
    3. Otherwise: guidance between a pass with a blank self-cond input and a
       pass with ``x_pred_prev``. Scale 0 (or no previous prediction) returns
       the blank pass, scale 1 returns the self-conditioned pass, and any
       other scale returns ``blank + scale * (self_cond - blank)``.

    Returns:
        ``(v, x)`` with conditioning positions restored via ``restore_vx``.
    """
    t_eps = config.t_eps
    self_cond_prob = config.self_cond_prob

    def _predict(model_input, **model_kwargs):
        # One model call -> (v, x) with conditioning positions pinned.
        raw = model(model_input, t_batch, deterministic=True, **model_kwargs)
        v, x = net_out_to_v_x(raw, z, t_batch, t_eps)
        return restore_vx(v, x, cond_seq=cond_seq, cond_seq_mask=cond_seq_mask)

    def _blank_self_cond():
        # All-zero self-cond input, except conditioning positions hold cond_seq.
        return restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)

    # Mode 1: the model takes the guidance scale as an input.
    if config.num_self_cond_cfg_tokens > 0:
        prev = _blank_self_cond() if x_pred_prev is None else x_pred_prev
        stacked = torch.cat([z, prev], dim=-1)
        scale_per_sample = torch.full(
            (z.shape[0],), float(self_cond_cfg_scale),
            dtype=z.dtype, device=z.device,
        )
        return _predict(stacked, self_cond_cfg_scale=scale_per_sample)

    # Mode 2: model was not trained with self-conditioning.
    if self_cond_prob == 0:
        return _predict(z)

    # Mode 3: explicit guidance between blank and self-conditioned passes.
    v_blank = x_blank = None
    if self_cond_cfg_scale != 1 or x_pred_prev is None:
        v_blank, x_blank = _predict(torch.cat([z, _blank_self_cond()], dim=-1))
    if self_cond_cfg_scale == 0.0 or x_pred_prev is None:
        return v_blank, x_blank

    v_self, x_self = _predict(torch.cat([z, x_pred_prev], dim=-1))
    if self_cond_cfg_scale == 1:
        return v_self, x_self

    v_mix = v_blank + self_cond_cfg_scale * (v_self - v_blank)
    x_mix = x_blank + self_cond_cfg_scale * (x_self - x_blank)
    return restore_vx(v_mix, x_mix, cond_seq=cond_seq, cond_seq_mask=cond_seq_mask)
def _forward_sample(
    model, z, t_batch, x_pred_prev, config,
    cfg_scale, self_cond_cfg_scale, cond_seq, cond_seq_mask,
):
    """Predict ``(v, x)`` with classifier-free guidance over the conditioning.

    A conditional pass always runs. When ``cfg_scale`` is not 1, a second pass
    runs with the conditioning positions of ``z``, ``x_pred_prev`` and
    ``cond_seq`` zeroed out, and the two are combined as
    ``uncond + cfg_scale * (cond - uncond)``. Both passes go through
    ``_forward_sample_self_cond``.

    Returns:
        ``(v, x)`` with the real ``cond_seq`` restored at conditioning positions.
    """
    v_cond, x_cond = _forward_sample_self_cond(
        model, z, t_batch, x_pred_prev, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    if cfg_scale == 1.0:
        return v_cond, x_cond

    def _zero_cond_positions(tensor):
        # Replace masked (conditioning) positions of `tensor` with zeros.
        return restore_cond(tensor, torch.zeros_like(tensor), cond_seq_mask)

    # Unconditional pass: the conditioning prefix is blanked everywhere it
    # could leak in (state, previous prediction, and the restore target).
    z_blank = _zero_cond_positions(z)
    prev_blank = None if x_pred_prev is None else _zero_cond_positions(x_pred_prev)
    v_uncond, x_uncond = _forward_sample_self_cond(
        model, z_blank, t_batch, prev_blank, config,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=torch.zeros_like(cond_seq), cond_seq_mask=cond_seq_mask,
    )

    v_guided = v_uncond + cfg_scale * (v_cond - v_uncond)
    x_guided = x_uncond + cfg_scale * (x_cond - x_uncond)
    return restore_vx(v_guided, x_guided, cond_seq, cond_seq_mask)
def _ode_step(
    model, z, t, t_next, x_pred_prev,
    config, cfg_scale, self_cond_cfg_scale,
    cond_seq, cond_seq_mask,
):
    """Advance ``z`` from time ``t`` to ``t_next`` with one Euler step.

    Returns:
        ``(z_next, x_pred)``: the updated state ``z + (t_next - t) * v`` and
        the model's current prediction of ``x``.
    """
    # Every sample in the batch shares the same scalar time.
    batch = z.shape[0]
    t_batch = torch.full((batch,), float(t), dtype=z.dtype, device=z.device)
    velocity, x_pred = _forward_sample(
        model=model,
        z=z,
        t_batch=t_batch,
        x_pred_prev=x_pred_prev,
        config=config,
        cfg_scale=cfg_scale,
        self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq,
        cond_seq_mask=cond_seq_mask,
    )
    z_next = z + (t_next - t) * velocity
    return z_next, x_pred
def _sde_step(
    model, z, t, t_next, x_pred_prev,
    config, cfg_scale, self_cond_cfg_scale,
    cond_seq, cond_seq_mask, gamma, generator,
):
    """Per-step SDE-style sampler with hybrid (t-and-step) noise scaling.

    The state is first pushed back toward noise, then integrated forward:

    - ``h = t_next - t`` and ``alpha = clamp(1 - gamma * h, 0, 1)`` is the
      signal-preservation fraction.
    - ``t_back = alpha * t`` and ``z_back = alpha * z + (1 - alpha) * eps``,
      with conditioning positions restored from ``cond_seq``.
    - The model is evaluated at ``(z_back, t_back)`` and an Euler step of
      size ``t_next - t_back`` is taken.

    ``gamma = 0`` gives ``alpha = 1`` and degenerates to a plain ODE step.
    Uniform-N-step equivalence with the old multiplicative ``gamma_old``:
    ``gamma_hybrid = gamma_old * N``.

    Args:
        gamma: Noise-injection strength.
        generator: ``torch.Generator`` used for the fresh noise on non-CUDA
            inputs, or ``None`` for the global RNG.

    Returns:
        ``(z_next, x_pred)``.
    """
    h = float(t_next - t)
    alpha = max(0.0, min(1.0, 1.0 - gamma * h))
    t_back = alpha * float(t)

    if z.is_cuda:
        # NOTE(review): as before, `generator` is not used on CUDA, so the
        # noise comes from the global CUDA RNG. Presumably callers pass a CPU
        # generator that cannot drive CUDA sampling -- confirm.
        eps = torch.randn(z.shape, dtype=z.dtype, device=z.device)
    else:
        # Draw on the generator's own device (or z's when there is none) and
        # move the result to z. Without this, a z on a non-CUDA accelerator
        # would get CPU noise and the blend below would raise a device
        # mismatch. For CPU tensors this is the same draw as before.
        noise_device = z.device if generator is None else generator.device
        eps = torch.randn(
            z.shape, generator=generator, dtype=z.dtype, device=noise_device,
        ).to(z.device)
    eps = eps * config.denoiser_noise_scale

    z_back = restore_cond(alpha * z + (1.0 - alpha) * eps, cond_seq, cond_seq_mask)
    t_batch = torch.full((z.shape[0],), t_back, dtype=z.dtype, device=z.device)
    v_pred, x_pred = _forward_sample(
        model=model, z=z_back, t_batch=t_batch, x_pred_prev=x_pred_prev,
        config=config, cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )
    return z_back + (t_next - t_back) * v_pred, x_pred