#!/usr/bin/env python3 """CLI: run one sampler on a prompt and print the result.""" from __future__ import annotations import argparse from pathlib import Path import torch from whale4b.core.runner import RunConfig, SamplingRunner from whale4b.samplers import list_samplers def parse_args(): p = argparse.ArgumentParser(description="Whale3B diffusion LM sampler.") p.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pt") p.add_argument("--config", default=str(Path(__file__).parent / "configs" / "whale3b.yaml")) p.add_argument("--tokenizer", default=str(Path(__file__).parent / "whale-tokenizer")) p.add_argument("--prompt", default="") p.add_argument("--sampler", default="standard", choices=list_samplers()) p.add_argument("--steps", type=int, default=64) p.add_argument("--max-new-tokens", type=int, default=256) p.add_argument("--temperature", type=float, default=0.0) p.add_argument("--top-k", type=int, default=0) p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) p.add_argument("--seed", type=int, default=1234) p.add_argument("--no-ema", action="store_true") return p.parse_args() def main(): args = parse_args() cfg = RunConfig( ckpt_path=args.checkpoint, config_path=args.config, tokenizer_path=args.tokenizer, sampler=args.sampler, steps=args.steps, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, device=args.device, dtype=args.dtype, seed=args.seed, use_ema=not args.no_ema, ) runner = SamplingRunner(cfg) result = runner.run(prompt=args.prompt) print(f"\n=== CONTINUATION ===\n{result.new_text}") print(f"\n=== STATS ===") print( f"sampler={result.sampler} | steps={result.steps_run} | " f"tokens={result.generated_tokens} | elapsed={result.elapsed_s:.2f}s" ) if __name__ == "__main__": main()