| |
| """CLI: run one sampler on a prompt and print the result.""" |
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
|
|
| from whale4b.core.runner import RunConfig, SamplingRunner |
| from whale4b.samplers import list_samplers |
|
|
|
|
| def parse_args(): |
| p = argparse.ArgumentParser(description="Whale3B diffusion LM sampler.") |
| p.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pt") |
| p.add_argument("--config", default=str(Path(__file__).parent / "configs" / "whale3b.yaml")) |
| p.add_argument("--tokenizer", default=str(Path(__file__).parent / "whale-tokenizer")) |
| p.add_argument("--prompt", default="") |
| p.add_argument("--sampler", default="standard", choices=list_samplers()) |
| p.add_argument("--steps", type=int, default=64) |
| p.add_argument("--max-new-tokens", type=int, default=256) |
| p.add_argument("--temperature", type=float, default=0.0) |
| p.add_argument("--top-k", type=int, default=0) |
| p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") |
| p.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"]) |
| p.add_argument("--seed", type=int, default=1234) |
| p.add_argument("--no-ema", action="store_true") |
| return p.parse_args() |
|
|
|
|
def main():
    """Run the sampler CLI end to end.

    Parses the command line and builds a ``RunConfig`` from it. It then runs a
    single ``SamplingRunner`` pass on ``--prompt`` and prints the generated
    continuation, followed by a one-line stats summary.
    """
    opts = parse_args()

    run_cfg = RunConfig(
        ckpt_path=opts.checkpoint,
        config_path=opts.config,
        tokenizer_path=opts.tokenizer,
        sampler=opts.sampler,
        steps=opts.steps,
        max_new_tokens=opts.max_new_tokens,
        temperature=opts.temperature,
        top_k=opts.top_k,
        device=opts.device,
        dtype=opts.dtype,
        seed=opts.seed,
        use_ema=not opts.no_ema,  # --no-ema is an opt-out flag
    )
    outcome = SamplingRunner(run_cfg).run(prompt=opts.prompt)

    print(f"\n=== CONTINUATION ===\n{outcome.new_text}")
    print("\n=== STATS ===")
    summary = " | ".join(
        (
            f"sampler={outcome.sampler}",
            f"steps={outcome.steps_run}",
            f"tokens={outcome.generated_tokens}",
            f"elapsed={outcome.elapsed_s:.2f}s",
        )
    )
    print(summary)
|
|
|
|
# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|