File size: 3,715 Bytes
8385328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
r"""Generate completions from a trained JSCoder checkpoint.

Two modes:

- plain prompt completion::

    python3 sample.py --ckpt checkpoints/ckpt.pt --prompt "function add(a, b) {"

- fill-in-the-middle (autocomplete at the cursor), giving prefix + suffix::

    python3 sample.py --ckpt checkpoints/fim/ckpt.pt --fim \
        --prefix $'function sum(arr) {\n  let total = 0;\n  ' \
        --suffix $'\n  return total;\n}'

Use ``--fim`` (with --prefix/--suffix) to build a StarCoder-style FIM prompt;
generation stops at ``<|endoftext|>``. The ``$'...'`` quoting (bash/zsh) turns
``\n`` into real newlines; inside plain double quotes the shell would pass a
literal backslash followed by ``n`` through to the model instead.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch

from model.gpt import GPT, GPTConfig
from tokenizer.tokenizer import JSCoderTokenizer


def pick_device(requested: str) -> str:
    """Resolve the ``--device`` option to a concrete torch device string.

    Anything other than ``"auto"`` is returned unchanged. For ``"auto"`` the
    first available backend wins, in the order CUDA, Apple MPS, CPU.
    """
    if requested == "auto":
        if torch.cuda.is_available():
            return "cuda"
        # Older torch builds have no ``torch.backends.mps`` attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        return "mps" if mps_backend is not None and mps_backend.is_available() else "cpu"
    return requested


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the sampler's command-line options.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing a list allows programmatic use and testing.

    Returns:
        The parsed options namespace.
    """
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--ckpt", default="checkpoints/ckpt.pt", help="checkpoint file written by train.py")
    p.add_argument("--tokenizer", default=None, help="path to js_bpe.json")
    p.add_argument("--prompt", default="// ", help="prompt for plain completion")
    p.add_argument("--fim", action="store_true", help="build a FIM prompt from --prefix/--suffix")
    p.add_argument("--prefix", default="", help="FIM: code before the cursor")
    p.add_argument("--suffix", default="", help="FIM: code after the cursor")
    p.add_argument(
        "--mode",
        choices=["psm", "spm"],
        default="psm",
        help="FIM prompt order: psm (prefix-suffix-middle) or spm (suffix-prefix-middle)",
    )
    p.add_argument("--max-new-tokens", type=int, default=200, help="maximum tokens to generate per sample")
    p.add_argument("--temperature", type=float, default=0.8, help="sampling temperature")
    p.add_argument("--top-k", type=int, default=50, help="top-k cutoff used when sampling")
    p.add_argument("--num-samples", type=int, default=1, help="number of completions to print")
    p.add_argument("--device", default="auto", help="torch device; 'auto' picks cuda, then mps, then cpu")
    p.add_argument("--seed", type=int, default=1337, help="seed for torch.manual_seed")
    return p.parse_args(argv)


def main() -> None:
    """Load a checkpoint and print ``--num-samples`` completions.

    In FIM mode the prompt is built with ``tok.build_fim_prompt`` from
    ``--prefix``/``--suffix``, and each completion is printed spliced between
    them. Otherwise ``--prompt`` is encoded and each completion is printed
    after it. FIM mode is selected by ``--fim``, or implicitly when
    ``--prefix`` or ``--suffix`` is non-empty.

    Raises:
        SystemExit: if the checkpoint file does not exist.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = pick_device(args.device)

    ckpt_path = Path(args.ckpt)
    if not ckpt_path.exists():
        raise SystemExit(f"No checkpoint at {ckpt_path}. Train one with train.py first.")
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    model = GPT(GPTConfig(**ckpt["config"]))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()

    tok = JSCoderTokenizer.load(args.tokenizer) if args.tokenizer else JSCoderTokenizer.load()
    # An end-of-text id stored in the checkpoint takes precedence over the tokenizer's.
    eot_id = ckpt.get("eot_id", tok.eot_id)

    # Without this, --prefix/--suffix given without --fim would be silently
    # ignored and the default --prompt sampled instead.
    use_fim = args.fim or bool(args.prefix or args.suffix)

    if use_fim:
        ids = tok.build_fim_prompt(args.prefix, args.suffix, mode=args.mode)
        seed_text = f"[FIM {args.mode}] prefix={args.prefix!r} suffix={args.suffix!r}"
    else:
        ids = tok.encode(args.prompt)
        seed_text = args.prompt

    if not ids:
        # An empty prompt still needs at least one token to condition on.
        ids = [eot_id]

    idx = torch.tensor([ids], dtype=torch.long, device=device)
    prompt_len = len(ids)
    print(f"[sample] device={device} prompt: {seed_text}\n")

    # GPT.generate may already disable autograd internally (not visible here);
    # no_grad() guarantees no autograd graph is built during sampling either way.
    with torch.no_grad():
        for s in range(args.num_samples):
            out = model.generate(
                idx,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                eot_id=eot_id,
            )
            # Assumes generate returns prompt + new tokens, so drop the prompt part.
            generated_ids = out[0, prompt_len:].tolist()
            completion = tok.decode(generated_ids, skip_special_tokens=True)
            print(f"===== sample {s + 1}/{args.num_samples} =====")
            if use_fim:
                print(args.prefix + completion + args.suffix)
            else:
                print(args.prompt + completion)
            print()


# Run the sampler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()