# W1-4B-dLLM-Base / sample.py
# Cynthiawhaletech's picture
# Initial release: W1-4B dLLM Base
# 267f903
#!/usr/bin/env python3
"""CLI: run one sampler on a prompt and print the result."""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from whale4b.core.runner import RunConfig, SamplingRunner
from whale4b.samplers import list_samplers
def parse_args() -> argparse.Namespace:
    """Parse command-line options for a single sampling run.

    Returns:
        argparse.Namespace carrying checkpoint/config/tokenizer paths and
        the sampling hyperparameters consumed by main().
    """
    # Fixed: description said "Whale3B", but this file/package is W1-4B (`whale4b`).
    p = argparse.ArgumentParser(description="W1-4B diffusion LM sampler.")
    p.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pt")
    # NOTE(review): the config default still references the legacy "whale3b.yaml"
    # filename — left unchanged since the on-disk name cannot be verified here;
    # confirm against the repo's configs/ directory.
    p.add_argument("--config", default=str(Path(__file__).parent / "configs" / "whale3b.yaml"))
    p.add_argument("--tokenizer", default=str(Path(__file__).parent / "whale-tokenizer"))
    p.add_argument("--prompt", default="")
    # Choices come from the sampler registry so the CLI stays in sync with code.
    p.add_argument("--sampler", default="standard", choices=list_samplers())
    p.add_argument("--steps", type=int, default=64)
    p.add_argument("--max-new-tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=0.0)  # 0.0 presumably means greedy — confirm in sampler
    p.add_argument("--top-k", type=int, default=0)  # 0 presumably disables top-k filtering — confirm in sampler
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    p.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    p.add_argument("--seed", type=int, default=1234)
    p.add_argument("--no-ema", action="store_true")  # inverted into use_ema by main()
    return p.parse_args()
def main() -> None:
    """Entry point: build a RunConfig from CLI args, run one sampling pass, print results."""
    args = parse_args()
    cfg = RunConfig(
        ckpt_path=args.checkpoint,
        config_path=args.config,
        tokenizer_path=args.tokenizer,
        sampler=args.sampler,
        steps=args.steps,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        device=args.device,
        dtype=args.dtype,
        seed=args.seed,
        # CLI exposes the negative flag; the config wants the positive sense.
        use_ema=not args.no_ema,
    )
    runner = SamplingRunner(cfg)
    result = runner.run(prompt=args.prompt)
    print(f"\n=== CONTINUATION ===\n{result.new_text}")
    # Fixed: was an f-string with no placeholders (ruff F541).
    print("\n=== STATS ===")
    print(
        f"sampler={result.sampler} | steps={result.steps_run} | "
        f"tokens={result.generated_tokens} | elapsed={result.elapsed_s:.2f}s"
    )


if __name__ == "__main__":
    main()