"""whale4b -- open-source inference for the W1-4B diffusion language model."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

# Directory that contains this package; bundled default files are resolved
# against it. Defined ahead of the relative imports, as in the original order.
_PKG_DIR = Path(__file__).resolve().parent

# Re-export key classes for advanced users
from .core.model import LangDiT, create_model
from .core.loader import load_checkpoint, ModelWrapper
from .core.runner import SamplingRunner, RunConfig, GenerationResult
from .samplers import get_sampler, list_samplers


def generate(
    checkpoint: str,
    prompt: str = "",
    *,
    sampler: str = "standard",
    steps: int = 64,
    max_new_tokens: int = 256,
    temperature: float = 0.0,
    top_k: int = 0,
    device: str = "",
    dtype: str = "bf16",
    seed: Optional[int] = 1234,
    config: Optional[str] = None,
    tokenizer_path: Optional[str] = None,
) -> str:
    """Run a single generation and return the newly generated text.

    This is a one-shot helper: every call builds a fresh :class:`RunConfig`
    and a fresh :class:`SamplingRunner`. When generating repeatedly with the
    same model, construct a :class:`SamplingRunner` yourself and reuse it so
    the weights are not reloaded on each call.

    Args:
        checkpoint: Passed to ``RunConfig`` as ``ckpt_path``.
        prompt: Passed to ``SamplingRunner.run``.
        sampler, steps, max_new_tokens, temperature, top_k, device, dtype,
        seed: Forwarded unchanged to ``RunConfig`` under the same names.
        config: Passed as ``config_path``. A falsy value (``None`` or ``""``)
            selects ``configs/whale3b.yaml`` inside the package directory.
        tokenizer_path: A falsy value (``None`` or ``""``) selects
            ``whale-tokenizer`` inside the package directory.

    Returns:
        The ``new_text`` attribute of the result returned by the runner.
    """
    # Fall back to the files shipped next to this module when the caller
    # supplies nothing (truthiness check, so "" also triggers the fallback).
    # NOTE(review): the bundled default is named "whale3b.yaml" while the
    # package documents a 4B model -- confirm the file name is intended.
    if config:
        resolved_config = config
    else:
        resolved_config = str(_PKG_DIR / "configs" / "whale3b.yaml")

    if tokenizer_path:
        resolved_tokenizer = tokenizer_path
    else:
        resolved_tokenizer = str(_PKG_DIR / "whale-tokenizer")

    run_config = RunConfig(
        ckpt_path=checkpoint,
        config_path=resolved_config,
        tokenizer_path=resolved_tokenizer,
        sampler=sampler,
        steps=steps,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        device=device,
        dtype=dtype,
        seed=seed,
    )

    return SamplingRunner(run_config).run(prompt).new_text