#!/usr/bin/env python3
# jscoder-300m / sample.py (uploaded with huggingface_hub, revision 8385328)
"""Generate completions from a trained JSCoder checkpoint.

Two modes:

- plain prompt completion::

    python3 sample.py --ckpt checkpoints/ckpt.pt --prompt "function add(a, b) {"

- fill-in-the-middle (autocomplete at the cursor), giving prefix + suffix::

    python3 sample.py --ckpt checkpoints/fim/ckpt.pt --fim \\
        --prefix "function sum(arr) {\\n let total = 0;\\n " \\
        --suffix "\\n return total;\\n}"

  Note: inside double quotes the shell passes ``\\n`` through as two literal
  characters; use ``$'...'`` quoting (bash/zsh) to send real newlines.

Use ``--fim`` (with --prefix/--suffix) to build a StarCoder-style FIM prompt;
generation stops at ``<|endoftext|>``.
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from model.gpt import GPT, GPTConfig
from tokenizer.tokenizer import JSCoderTokenizer
def pick_device(requested: str) -> str:
    """Resolve the ``--device`` value to a concrete torch device string.

    Any value other than ``"auto"`` is returned unchanged. For ``"auto"`` the
    first available backend wins, probed in the order CUDA, MPS, CPU.
    """
    if requested == "auto":
        if torch.cuda.is_available():
            return "cuda"
        # ``torch.backends`` may not expose ``mps``, hence the guarded lookup.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"
    return requested
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the sampler's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing an explicit list allows programmatic use.

    Returns:
        The populated namespace. ``fim`` is set to ``True`` whenever a
        non-empty ``--prefix`` or ``--suffix`` is supplied, so those options
        are never silently ignored when ``--fim`` itself is omitted.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--ckpt", default="checkpoints/ckpt.pt", help="path to the model checkpoint")
    p.add_argument("--tokenizer", default=None, help="path to js_bpe.json")
    p.add_argument("--prompt", default="// ", help="prompt for plain completion")
    p.add_argument(
        "--fim",
        action="store_true",
        help="build a FIM prompt from --prefix/--suffix (implied when either is given)",
    )
    p.add_argument("--prefix", default="", help="FIM: code before the cursor")
    p.add_argument("--suffix", default="", help="FIM: code after the cursor")
    p.add_argument("--mode", choices=["psm", "spm"], default="psm", help="FIM prompt layout")
    p.add_argument("--max-new-tokens", type=int, default=200, help="token budget per sample")
    p.add_argument("--temperature", type=float, default=0.8, help="sampling temperature")
    p.add_argument("--top-k", type=int, default=50, help="top-k sampling cutoff")
    p.add_argument("--num-samples", type=int, default=1, help="number of completions to print")
    p.add_argument("--device", default="auto", help="torch device, or 'auto' to detect one")
    p.add_argument("--seed", type=int, default=1337, help="torch RNG seed")

    args = p.parse_args(argv)
    # Supplying a prefix or suffix only makes sense for fill-in-the-middle;
    # without this they would be dropped and --prompt completed instead.
    if args.prefix or args.suffix:
        args.fim = True
    return args
def main() -> None:
    """Load a checkpoint, build the prompt, and print sampled completions.

    Reads options via ``parse_args()``, seeds torch, restores the model from
    the checkpoint's ``"config"`` and ``"model"`` entries, then prints
    ``--num-samples`` completions of either the plain prompt or the FIM
    prompt built from ``--prefix``/``--suffix``.

    Raises:
        SystemExit: if the checkpoint path does not exist.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = pick_device(args.device)

    ckpt_path = Path(args.ckpt)
    if not ckpt_path.exists():
        raise SystemExit(f"No checkpoint at {ckpt_path}. Train one with train.py first.")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code depending on the torch version's ``weights_only``
    # default. Only load checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model = GPT(GPTConfig(**ckpt["config"]))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()

    tok = JSCoderTokenizer.load(args.tokenizer) if args.tokenizer else JSCoderTokenizer.load()
    # Prefer the end-of-text id stored in the checkpoint; fall back to the
    # tokenizer's own id when the checkpoint does not carry one.
    eot_id = ckpt.get("eot_id", tok.eot_id)

    if args.fim:
        ids = tok.build_fim_prompt(args.prefix, args.suffix, mode=args.mode)
        seed_text = f"[FIM {args.mode}] prefix={args.prefix!r} suffix={args.suffix!r}"
    else:
        ids = tok.encode(args.prompt)
        seed_text = args.prompt
    if not ids:
        # An empty prompt would give a zero-length input; seed with EOT instead.
        ids = [eot_id]
    prompt_len = len(ids)
    idx = torch.tensor([ids], dtype=torch.long, device=device)

    print(f"[sample] device={device} prompt: {seed_text}\n")
    # Sampling never needs gradients; disabling autograd avoids building a
    # graph in case ``generate`` does not already do so itself.
    with torch.no_grad():
        for s in range(args.num_samples):
            out = model.generate(
                idx,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                eot_id=eot_id,
            )
            # Keep only the tokens produced after the prompt.
            generated_ids = out[0, prompt_len:].tolist()
            completion = tok.decode(generated_ids, skip_special_tokens=True)
            print(f"===== sample {s + 1}/{args.num_samples} =====")
            if args.fim:
                # Show the infill in place, between the given prefix and suffix.
                print(args.prefix + completion + args.suffix)
            else:
                print(args.prompt + completion)
            print()


if __name__ == "__main__":
    main()