""" SamplingRunner: wires model, tokenizer, schedule, and sampler together. """ from __future__ import annotations import time from dataclasses import dataclass from typing import Callable, Optional import torch import yaml from .loader import ModelWrapper, load_checkpoint from .vocab import build_masked_input, decode_output, load_tokenizer from ..samplers import get_sampler from ..samplers.base import StepCallback @dataclass class RunConfig: """All settings for one generation run.""" # Required ckpt_path: str config_path: str # Tokenizer tokenizer_path: str = "whale-tokenizer" # Sampler sampler: str = "standard" steps: int = 64 max_new_tokens: int = 256 temperature: float = 0.0 top_k: int = 0 p: float = 0.9 # Jump sampler extras jump_last_steps: int = 10 jump_frac: float = 0.10 jump_min_tokens: int = 1 no_mask_jump: bool = True # GIDD sampler extras gidd_eps: float = 1e-4 gidd_min_p: float = 0.0 posterior_temperature: float = 1.0 suppress_mask_clean: bool = False rho_mode: str = "w1_train_like" gidd_exact_mode: bool = False fail_on_negative_mass: bool = False # Runtime device: str = "cuda" dtype: str = "bf16" use_ema: bool = True strict: bool = False seed: Optional[int] = 1234 @dataclass class GenerationResult: full_text: str new_text: str prompt_tokens: int generated_tokens: int total_tokens: int elapsed_s: float sampler: str steps_run: int class SamplingRunner: """ Load once, call ``run()`` many times with different prompts. Usage:: runner = SamplingRunner(cfg) result = runner.run("The quick brown") """ def __init__(self, cfg: RunConfig): self.cfg = cfg config = _load_yaml(cfg.config_path) device = torch.device( cfg.device if cfg.device else ("cuda" if torch.cuda.is_available() else "cpu") ) self.model_wrapper = load_checkpoint( ckpt_path=cfg.ckpt_path, config=config, device=device, dtype=cfg.dtype, use_ema=cfg.use_ema, strict=cfg.strict, ) self.tokenizer = load_tokenizer(cfg.tokenizer_path) self.config = config self.device = device self.mask_token_id = self.model_wrapper.mask_token_id self.vocab_size = self.model_wrapper.vocab_size self.max_seq_len = int(config["model"]["max_seq_len"]) def run( self, prompt: str = "", callback: Optional[StepCallback] = None, ) -> GenerationResult: cfg = self.cfg if cfg.seed is not None: torch.manual_seed(cfg.seed) if self.device.type == "cuda": torch.cuda.manual_seed_all(cfg.seed) x_init, prefix_len = build_masked_input( tokenizer=self.tokenizer, prompt=prompt, max_new_tokens=cfg.max_new_tokens, max_seq_len=self.max_seq_len, mask_token_id=self.mask_token_id, device=self.device, ) total_len = x_init.shape[1] timesteps = torch.linspace(1.0, 1e-4, cfg.steps, device=self.device) sampler_cfg = { "temperature": cfg.temperature, "top_k": cfg.top_k, "p": cfg.p, "jump_last_steps": cfg.jump_last_steps, "jump_frac": cfg.jump_frac, "jump_min_tokens": cfg.jump_min_tokens, "no_mask_jump": cfg.no_mask_jump, "gidd_eps": cfg.gidd_eps, "gidd_min_p": cfg.gidd_min_p, "posterior_temperature": cfg.posterior_temperature, "suppress_mask_clean": cfg.suppress_mask_clean, "rho_mode": cfg.rho_mode, "gidd_exact_mode": cfg.gidd_exact_mode, "fail_on_negative_mass": cfg.fail_on_negative_mass, } sampler_fn = get_sampler(cfg.sampler) t0 = time.perf_counter() if self.device.type == "cuda": torch.cuda.synchronize(self.device) x_final = sampler_fn( model_fn=self.model_wrapper, x_init=x_init, prefix_len=prefix_len, vocab_size=self.vocab_size, mask_token_id=self.mask_token_id, timesteps=timesteps, cfg=sampler_cfg, callback=callback, ) if self.device.type == "cuda": 
            torch.cuda.synchronize(self.device)
        elapsed = time.perf_counter() - t0

        full_text, new_text = decode_output(self.tokenizer, x_final, prefix_len)
        generated_tokens = total_len - prefix_len
        return GenerationResult(
            full_text=full_text,
            new_text=new_text,
            prompt_tokens=prefix_len,
            generated_tokens=generated_tokens,
            total_tokens=total_len,
            elapsed_s=elapsed,
            sampler=cfg.sampler,
            steps_run=cfg.steps,
        )


def _load_yaml(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
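

# Minimal usage sketch, assuming a trained checkpoint and its training config
# exist on disk. The two paths below are placeholders, not files shipped with
# the repo; everything else uses only the API defined above.
if __name__ == "__main__":
    cfg = RunConfig(
        ckpt_path="checkpoints/last.pt",    # placeholder path
        config_path="configs/model.yaml",   # placeholder path
        sampler="standard",
        steps=64,
        max_new_tokens=128,
    )
    runner = SamplingRunner(cfg)
    result = runner.run("The quick brown")
    print(f"[{result.sampler}] {result.generated_tokens} tokens in {result.elapsed_s:.2f}s")
    print(result.new_text)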