# Source: new-language_model/utils/generation_utils.py (commit 6ab8280, "Add ELF demo").
from typing import Optional
import torch
import torch.nn as nn
from configs.config import Config, SamplingConfig
from utils.sampling_utils import restore_cond, _ode_step, _sde_step
# ============================================
# Generation utilities
# ============================================
def mask_after_eos(predicted_ids: torch.Tensor, eos_token_id: int, pad_token_id: int) -> torch.Tensor:
    """Replace the first EOS token and everything after it with padding.

    Args:
        predicted_ids: Token ids; the running EOS count is taken along dim 1,
            so the tensor needs at least two dimensions.
        eos_token_id: Id that marks the end of a sequence.
        pad_token_id: Id written at the first EOS position and at every
            position after it.

    Returns:
        A new tensor shaped like ``predicted_ids``. Rows that contain no EOS
        token are returned unchanged.
    """
    eos_hits = predicted_ids.eq(eos_token_id)
    # A running count of EOS hits is positive from the first EOS onward.
    seen_eos = torch.cumsum(eos_hits.to(torch.int32), dim=1) > 0
    padding = torch.full_like(predicted_ids, pad_token_id)
    return torch.where(seen_eos, padding, predicted_ids)
def shift_left(x: torch.Tensor, shift_per_sample: torch.Tensor, pad_value=0, axis: int = 1) -> torch.Tensor:
    """Shift each sample left along the sequence axis; pad emptied positions.

    Output position ``j`` of sample ``b`` takes its value from input position
    ``j + shift_per_sample[b]``. Any position whose source index falls outside
    ``[0, seq_len)`` is filled with ``pad_value``; this covers shifts past the
    end of the sequence and negative shifts (which move content to the right).

    Args:
        x: Tensor with the batch on dim 0 and at least one further dimension.
        shift_per_sample: One shift amount per batch entry; cast to int64 and
            moved to ``x``'s device.
        pad_value: Fill value for positions that have no source element.
        axis: Dimension to shift along. Negative values count from the end.
            Must not resolve to 0 (the batch dimension).

    Returns:
        A new tensor with the same dimension order as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions or ``axis``
            resolves to the batch dimension.
    """
    if x.dim() < 2:
        raise ValueError("x must have at least batch and sequence dimensions")
    if axis < 0:
        axis = x.dim() + axis
    if axis == 0:
        raise ValueError("axis=0 is the batch axis and cannot be shifted")

    # Work with the sequence on dim 1; the move is undone before returning.
    moved = axis != 1
    if moved:
        x = x.movedim(axis, 1)

    seq_len = x.shape[1]
    shifts = shift_per_sample.to(device=x.device, dtype=torch.long)
    positions = torch.arange(seq_len, device=x.device)
    source_idx = shifts[:, None] + positions[None, :]

    # Check BOTH bounds: without the lower bound a negative shift would be
    # clamped to index 0 and silently repeat the first element.
    valid = (source_idx >= 0) & (source_idx < seq_len)
    source_idx = source_idx.clamp(0, max(seq_len - 1, 0))

    # Append singleton dims so the (batch, seq) index and mask line up with
    # any trailing feature dims of x; for 2-D input this adds nothing.
    trailing_ones = [1] * (x.dim() - 2)
    idx = source_idx.view(*source_idx.shape, *trailing_ones).expand(-1, -1, *x.shape[2:])
    valid = valid.view(*valid.shape, *trailing_ones)

    shifted = torch.gather(x, 1, idx)
    shifted = torch.where(valid, shifted, torch.full_like(shifted, pad_value))

    if moved:
        shifted = shifted.movedim(1, axis)
    return shifted
# ============================================
# Single-batch sampling (PyTorch)
# ============================================
@torch.no_grad()
def _generate_samples_single_batch(
    model: nn.Module,
    generator: torch.Generator,
    z: torch.Tensor,
    t_steps: torch.Tensor,
    cond_seq: Optional[torch.Tensor],
    cond_seq_mask: Optional[torch.Tensor],
    config: Config,
    sampling_config: SamplingConfig,
    cfg_scale: float,
    self_cond_cfg_scale: float,
) -> torch.Tensor:
    """Generate samples for a single batch (PyTorch Euler / SDE rollout).

    Walks consecutive pairs of ``t_steps``. Every pair except the last is
    advanced with the step selected by ``sampling_config.sampling_method``
    (``"sde"`` or ``"ode"``); the last pair is always advanced with the ODE
    step.

    Args:
        model: Network forwarded to the step functions.
        generator: Random generator passed to the SDE step only.
        z: Initial state, unpacked as ``(batch, max_length, d_model)``.
        t_steps: 1-D sequence of time values with at least two entries.
        cond_seq: Optional conditioning sequence. When ``None``, all-zero
            ``cond_seq`` and ``cond_seq_mask`` tensors are used instead (any
            ``cond_seq_mask`` passed in is then ignored).
        cond_seq_mask: Mask accompanying ``cond_seq``.
        config: Model config forwarded to the step functions; its optional
            ``use_bf16`` attribute (default ``True``) enables bf16 autocast
            when ``z`` lives on CUDA.
        sampling_config: Supplies ``sampling_method`` and, optionally,
            ``sde_gamma`` (default ``0.0``).
        cfg_scale: Forwarded unchanged to the step functions.
        self_cond_cfg_scale: Forwarded unchanged to the step functions.

    Returns:
        The state ``z`` after the final step.

    Raises:
        ValueError: If the sampling method is neither ``"sde"`` nor ``"ode"``,
            or if ``t_steps`` has fewer than two entries.
    """
    method = sampling_config.sampling_method
    # Validate before doing any work so a bad method is rejected even when the
    # schedule is so short that only the final ODE step would run.
    if method not in ("sde", "ode"):
        raise ValueError(f"Invalid sampling method: {method}")

    # Read the whole schedule once instead of calling .item() twice per step.
    times = t_steps.tolist()
    if len(times) < 2:
        raise ValueError(
            f"t_steps must contain at least 2 time values, got {len(times)}"
        )

    batch_size, max_length, d_model = z.shape
    if cond_seq is None:
        cond_seq = torch.zeros((batch_size, max_length, d_model), dtype=z.dtype, device=z.device)
        cond_seq_mask = torch.zeros((batch_size, max_length), dtype=z.dtype, device=z.device)

    step_kwargs = dict(
        model=model, config=config,
        cfg_scale=cfg_scale, self_cond_cfg_scale=self_cond_cfg_scale,
        cond_seq=cond_seq, cond_seq_mask=cond_seq_mask,
    )

    # NOTE(review): restore_cond presumably re-imposes the conditioning
    # positions onto its first argument; confirm in utils.sampling_utils.
    z = restore_cond(z, cond_seq, cond_seq_mask)
    x_pred = restore_cond(torch.zeros_like(z), cond_seq, cond_seq_mask)

    sde_gamma = getattr(sampling_config, "sde_gamma", 0.0)
    use_bf16 = bool(getattr(config, "use_bf16", True)) and z.is_cuda

    with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
        # All (t, t_next) pairs except the final one.
        for t, t_next in zip(times[:-2], times[1:-1]):
            if method == "sde":
                z, x_pred = _sde_step(
                    z=z, t=t, t_next=t_next, x_pred_prev=x_pred,
                    gamma=sde_gamma, generator=generator, **step_kwargs,
                )
            else:
                z, x_pred = _ode_step(z=z, t=t, t_next=t_next, x_pred_prev=x_pred, **step_kwargs)

        # Last step always with ODE.
        z, x_pred = _ode_step(
            z=z, t=times[-2], t_next=times[-1], x_pred_prev=x_pred, **step_kwargs,
        )
    return z
@torch.no_grad()
def _dlm_decode_batch(z: torch.Tensor, model: nn.Module, t_final_val,
config, self_cond_cfg_scale: float) -> torch.Tensor:
"""Decode z -> tokens with the DLM decoder head."""
batch_size = z.shape[0]
if isinstance(t_final_val, torch.Tensor) and t_final_val.dim() == 0:
t_final = torch.full((batch_size,), t_final_val.item(), dtype=z.dtype, device=z.device)
else:
t_final = torch.full((batch_size,), float(t_final_val), dtype=z.dtype, device=z.device)
sc_batch = (
torch.full((batch_size,), float(self_cond_cfg_scale), dtype=z.dtype, device=z.device)
if config.num_self_cond_cfg_tokens > 0 else None
)
z_input = torch.cat([z, torch.zeros_like(z)], dim=-1) if config.self_cond_prob > 0 else z
use_bf16 = bool(getattr(config, "use_bf16", True)) and z.is_cuda
with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_bf16):
_, decoder_logits = model(
z_input, t_final, deterministic=True,
self_cond_cfg_scale=sc_batch,
decoder_step_active=True,
)
return decoder_logits.argmax(dim=-1)
def _build_run_name(sampling_method, num_sampling_steps, cfg_scale, self_cond_cfg_scale,
                    time_schedule, sde_gamma, suffix):
    """Assemble a run-name string from the sampling settings.

    Layout: ``{method}-steps{N}-cfg{scale}[-sccfg{s}]-ts_{schedule}[-gamma{g}]-{suffix}``.
    The ``-sccfg`` part appears only when ``self_cond_cfg_scale != 1.0`` and
    the ``-gamma`` part only when ``sampling_method == "sde"``.
    """
    pieces = [f"{sampling_method}-steps{num_sampling_steps}-cfg{cfg_scale}"]
    if self_cond_cfg_scale != 1.0:
        pieces.append(f"-sccfg{self_cond_cfg_scale}")
    pieces.append(f"-ts_{time_schedule}")
    if sampling_method == "sde":
        pieces.append(f"-gamma{sde_gamma}")
    pieces.append(f"-{suffix}")
    return "".join(pieces)