W1-4B-dLLM-Base/core/runner.py
"""
SamplingRunner: wires model, tokenizer, schedule, and sampler together.
"""
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Optional

import torch
import yaml

from .loader import ModelWrapper, load_checkpoint
from .vocab import build_masked_input, decode_output, load_tokenizer
from ..samplers import get_sampler
from ..samplers.base import StepCallback


@dataclass
class RunConfig:
    """All settings for one generation run."""

    # Required
    ckpt_path: str
    config_path: str

    # Tokenizer
    tokenizer_path: str = "whale-tokenizer"

    # Sampler
    sampler: str = "standard"
    steps: int = 64
    max_new_tokens: int = 256
    temperature: float = 0.0
    top_k: int = 0
    p: float = 0.9

    # Jump sampler extras
    jump_last_steps: int = 10
    jump_frac: float = 0.10
    jump_min_tokens: int = 1
    no_mask_jump: bool = True

    # GIDD sampler extras
    gidd_eps: float = 1e-4
    gidd_min_p: float = 0.0
    posterior_temperature: float = 1.0
    suppress_mask_clean: bool = False
    rho_mode: str = "w1_train_like"
    gidd_exact_mode: bool = False
    fail_on_negative_mass: bool = False

    # Runtime
    device: str = "cuda"
    dtype: str = "bf16"
    use_ema: bool = True
    strict: bool = False
    seed: Optional[int] = 1234
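
# Illustrative override (a sketch; the sampler names "jump" and "gidd" are
# assumptions based on the field groups above, not confirmed by this file):
#
#     RunConfig(ckpt_path="...", config_path="...",
#               sampler="jump", steps=128, jump_frac=0.25)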


@dataclass
class GenerationResult:
    """Decoded text plus token counts and timing for one generation run."""

    full_text: str
    new_text: str
    prompt_tokens: int
    generated_tokens: int
    total_tokens: int
    elapsed_s: float
    sampler: str
    steps_run: int


class SamplingRunner:
    """
    Load once, call ``run()`` many times with different prompts.

    Usage::

        runner = SamplingRunner(cfg)
        result = runner.run("The quick brown")
    """

    def __init__(self, cfg: RunConfig):
        self.cfg = cfg
        config = _load_yaml(cfg.config_path)
        device = torch.device(
            cfg.device
            if cfg.device
            else ("cuda" if torch.cuda.is_available() else "cpu")
        )
        self.model_wrapper = load_checkpoint(
            ckpt_path=cfg.ckpt_path,
            config=config,
            device=device,
            dtype=cfg.dtype,
            use_ema=cfg.use_ema,
            strict=cfg.strict,
        )
        self.tokenizer = load_tokenizer(cfg.tokenizer_path)
        self.config = config
        self.device = device
        self.mask_token_id = self.model_wrapper.mask_token_id
        self.vocab_size = self.model_wrapper.vocab_size
        self.max_seq_len = int(config["model"]["max_seq_len"])
    def run(
        self,
        prompt: str = "",
        callback: Optional[StepCallback] = None,
    ) -> GenerationResult:
        cfg = self.cfg
        if cfg.seed is not None:
            torch.manual_seed(cfg.seed)
            if self.device.type == "cuda":
                torch.cuda.manual_seed_all(cfg.seed)
        x_init, prefix_len = build_masked_input(
            tokenizer=self.tokenizer,
            prompt=prompt,
            max_new_tokens=cfg.max_new_tokens,
            max_seq_len=self.max_seq_len,
            mask_token_id=self.mask_token_id,
            device=self.device,
        )
        total_len = x_init.shape[1]
        # Diffusion time runs from t=1.0 (fully masked) down toward 0 (denoised).
        timesteps = torch.linspace(1.0, 1e-4, cfg.steps, device=self.device)
        sampler_cfg = {
            "temperature": cfg.temperature,
            "top_k": cfg.top_k,
            "p": cfg.p,
            "jump_last_steps": cfg.jump_last_steps,
            "jump_frac": cfg.jump_frac,
            "jump_min_tokens": cfg.jump_min_tokens,
            "no_mask_jump": cfg.no_mask_jump,
            "gidd_eps": cfg.gidd_eps,
            "gidd_min_p": cfg.gidd_min_p,
            "posterior_temperature": cfg.posterior_temperature,
            "suppress_mask_clean": cfg.suppress_mask_clean,
            "rho_mode": cfg.rho_mode,
            "gidd_exact_mode": cfg.gidd_exact_mode,
            "fail_on_negative_mass": cfg.fail_on_negative_mass,
        }
        sampler_fn = get_sampler(cfg.sampler)
        # Synchronize before starting the clock so elapsed_s measures sampling only.
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        t0 = time.perf_counter()
        x_final = sampler_fn(
            model_fn=self.model_wrapper,
            x_init=x_init,
            prefix_len=prefix_len,
            vocab_size=self.vocab_size,
            mask_token_id=self.mask_token_id,
            timesteps=timesteps,
            cfg=sampler_cfg,
            callback=callback,
        )
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        elapsed = time.perf_counter() - t0
        full_text, new_text = decode_output(self.tokenizer, x_final, prefix_len)
        generated_tokens = total_len - prefix_len
        return GenerationResult(
            full_text=full_text,
            new_text=new_text,
            prompt_tokens=prefix_len,
            generated_tokens=generated_tokens,
            total_tokens=total_len,
            elapsed_s=elapsed,
            sampler=cfg.sampler,
            steps_run=cfg.steps,
        )


def _load_yaml(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
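

# Example end-to-end usage (a minimal sketch; the checkpoint and config paths
# below are hypothetical placeholders, not files shipped with this repo):
#
#     cfg = RunConfig(
#         ckpt_path="checkpoints/w1-4b.pt",
#         config_path="configs/w1-4b.yaml",
#     )
#     runner = SamplingRunner(cfg)
#     result = runner.run("The quick brown")
#     print(f"{result.generated_tokens} tokens in {result.elapsed_s:.2f}s")
#     print(result.new_text)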