File size: 2,102 Bytes
267f903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/usr/bin/env python3
"""CLI: run one sampler on a prompt and print the result."""
from __future__ import annotations

import argparse
from pathlib import Path

import torch

from whale4b.core.runner import RunConfig, SamplingRunner
from whale4b.samplers import list_samplers


def parse_args() -> argparse.Namespace:
    """Parse command-line options for a single sampling run.

    Returns the parsed ``argparse.Namespace``; ``--checkpoint`` is the only
    required option, everything else has a default.
    """
    here = Path(__file__).parent  # default assets live next to this script
    parser = argparse.ArgumentParser(description="Whale3B diffusion LM sampler.")
    # Model / asset locations.
    parser.add_argument("--checkpoint", required=True, help="Path to .safetensors or .pt")
    parser.add_argument("--config", default=str(here / "configs" / "whale3b.yaml"))
    parser.add_argument("--tokenizer", default=str(here / "whale-tokenizer"))
    # Generation controls.
    parser.add_argument("--prompt", default="")
    parser.add_argument("--sampler", default="standard", choices=list_samplers())
    parser.add_argument("--steps", type=int, default=64)
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top-k", type=int, default=0)
    # Runtime environment.
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"])
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--no-ema", action="store_true")
    return parser.parse_args()


def main() -> None:
    """CLI entry point: build a RunConfig from the parsed args, run one
    sampling pass over the prompt, and print the continuation plus stats."""
    args = parse_args()
    cfg = RunConfig(
        ckpt_path=args.checkpoint,
        config_path=args.config,
        tokenizer_path=args.tokenizer,
        sampler=args.sampler,
        steps=args.steps,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        device=args.device,
        dtype=args.dtype,
        seed=args.seed,
        # --no-ema is a negative flag: EMA weights are used unless it is passed.
        use_ema=not args.no_ema,
    )
    runner = SamplingRunner(cfg)
    result = runner.run(prompt=args.prompt)

    print(f"\n=== CONTINUATION ===\n{result.new_text}")
    # Plain string, not an f-string — the original had a pointless f prefix
    # with no placeholders (ruff F541); output bytes are unchanged.
    print("\n=== STATS ===")
    print(
        f"sampler={result.sampler} | steps={result.steps_run} | "
        f"tokens={result.generated_tokens} | elapsed={result.elapsed_s:.2f}s"
    )


# Script entry guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()