| """ |
| SamplingRunner: wires model, tokenizer, schedule, and sampler together. |
| """ |
| from __future__ import annotations |
|
|
| import time |
| from dataclasses import dataclass |
| from typing import Callable, Optional |
|
|
| import torch |
| import yaml |
|
|
| from .loader import ModelWrapper, load_checkpoint |
| from .vocab import build_masked_input, decode_output, load_tokenizer |
| from ..samplers import get_sampler |
| from ..samplers.base import StepCallback |
|
|
|
|
| @dataclass |
class RunConfig:
    """All settings for one generation run."""

    # Checkpoint and model config paths.
    ckpt_path: str
    config_path: str

    # Tokenizer.
    tokenizer_path: str = "whale-tokenizer"

    # Core sampling controls.
    sampler: str = "standard"
    steps: int = 64
    max_new_tokens: int = 256
    temperature: float = 0.0
    top_k: int = 0  # 0 disables top-k filtering
    p: float = 0.9  # threshold forwarded to the sampler as "p" (top-p style)

    # Jump-sampler options.
    jump_last_steps: int = 10
    jump_frac: float = 0.10
    jump_min_tokens: int = 1
    no_mask_jump: bool = True

    # GIDD sampler options.
    gidd_eps: float = 1e-4
    gidd_min_p: float = 0.0
    posterior_temperature: float = 1.0
    suppress_mask_clean: bool = False
    rho_mode: str = "w1_train_like"
    gidd_exact_mode: bool = False
    fail_on_negative_mass: bool = False

    # Runtime settings.
    device: str = "cuda"
    dtype: str = "bf16"
    use_ema: bool = True
    strict: bool = False
    seed: Optional[int] = 1234


@dataclass
class GenerationResult:
    """Outcome of a single generation run, including timing and token counts."""

    full_text: str
    new_text: str
    prompt_tokens: int
    generated_tokens: int
    total_tokens: int
    elapsed_s: float
    sampler: str
    steps_run: int


class SamplingRunner:
| """ |
| Load once, call ``run()`` many times with different prompts. |
| |
| Usage:: |
| |
| runner = SamplingRunner(cfg) |
| result = runner.run("The quick brown") |
| """ |
|
|
| def __init__(self, cfg: RunConfig): |
| self.cfg = cfg |
|
|
| config = _load_yaml(cfg.config_path) |
| device = torch.device( |
| cfg.device if cfg.device |
| else ("cuda" if torch.cuda.is_available() else "cpu") |
| ) |
|
|
| self.model_wrapper = load_checkpoint( |
| ckpt_path=cfg.ckpt_path, |
| config=config, |
| device=device, |
| dtype=cfg.dtype, |
| use_ema=cfg.use_ema, |
| strict=cfg.strict, |
| ) |
|
|
| self.tokenizer = load_tokenizer(cfg.tokenizer_path) |
| self.config = config |
| self.device = device |
| self.mask_token_id = self.model_wrapper.mask_token_id |
| self.vocab_size = self.model_wrapper.vocab_size |
| self.max_seq_len = int(config["model"]["max_seq_len"]) |
|
|
| def run( |
| self, |
| prompt: str = "", |
| callback: Optional[StepCallback] = None, |
| ) -> GenerationResult: |
        cfg = self.cfg

        # Re-seed per run so repeated calls with the same config are reproducible.
        if cfg.seed is not None:
            torch.manual_seed(cfg.seed)
            if self.device.type == "cuda":
                torch.cuda.manual_seed_all(cfg.seed)

        # Build the initial sequence: prompt tokens followed by masked positions
        # for the sampler to fill in.
        x_init, prefix_len = build_masked_input(
            tokenizer=self.tokenizer,
            prompt=prompt,
            max_new_tokens=cfg.max_new_tokens,
            max_seq_len=self.max_seq_len,
            mask_token_id=self.mask_token_id,
            device=self.device,
        )
        total_len = x_init.shape[1]

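        # Linear time grid from t = 1.0 (fully noised, by the usual masked-
        # diffusion convention, assumed here rather than enforced) down to 1e-4;
        # the small positive floor avoids evaluating the model exactly at t = 0.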
        timesteps = torch.linspace(1.0, 1e-4, cfg.steps, device=self.device)

        # Flatten the relevant RunConfig fields into the plain dict the samplers
        # consume; each sampler reads only the keys it understands.
        sampler_cfg = {
            "temperature": cfg.temperature,
            "top_k": cfg.top_k,
            "p": cfg.p,
            "jump_last_steps": cfg.jump_last_steps,
            "jump_frac": cfg.jump_frac,
            "jump_min_tokens": cfg.jump_min_tokens,
            "no_mask_jump": cfg.no_mask_jump,
            "gidd_eps": cfg.gidd_eps,
            "gidd_min_p": cfg.gidd_min_p,
            "posterior_temperature": cfg.posterior_temperature,
            "suppress_mask_clean": cfg.suppress_mask_clean,
            "rho_mode": cfg.rho_mode,
            "gidd_exact_mode": cfg.gidd_exact_mode,
            "fail_on_negative_mass": cfg.fail_on_negative_mass,
        }

        sampler_fn = get_sampler(cfg.sampler)

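        # Synchronize around the sampling loop so elapsed_s measures completed
        # GPU work rather than just kernel-launch time.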
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        t0 = time.perf_counter()

        x_final = sampler_fn(
            model_fn=self.model_wrapper,
            x_init=x_init,
            prefix_len=prefix_len,
            vocab_size=self.vocab_size,
            mask_token_id=self.mask_token_id,
            timesteps=timesteps,
            cfg=sampler_cfg,
            callback=callback,
        )

        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        elapsed = time.perf_counter() - t0

        full_text, new_text = decode_output(self.tokenizer, x_final, prefix_len)
        generated_tokens = total_len - prefix_len

        return GenerationResult(
            full_text=full_text,
            new_text=new_text,
            prompt_tokens=prefix_len,
            generated_tokens=generated_tokens,
            total_tokens=total_len,
            elapsed_s=elapsed,
            sampler=cfg.sampler,
            steps_run=cfg.steps,
        )


def _load_yaml(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f) or {}
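

# A minimal usage sketch, kept under a __main__ guard so importing this module
# stays side-effect free. The checkpoint, config, and tokenizer paths below are
# placeholders, not files shipped with this repository.
if __name__ == "__main__":
    cfg = RunConfig(
        ckpt_path="checkpoints/last.pt",    # placeholder path
        config_path="configs/model.yaml",   # placeholder path
        sampler="standard",
        steps=64,
        max_new_tokens=64,
    )
    runner = SamplingRunner(cfg)
    result = runner.run("The quick brown")
    print(result.new_text)
    print(f"{result.generated_tokens} tokens in {result.elapsed_s:.2f}s")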