#!/usr/bin/env python3 """Generate completions from a trained JSCoder checkpoint. Two modes: - plain prompt completion:: python3 sample.py --ckpt checkpoints/ckpt.pt --prompt "function add(a, b) {" - fill-in-the-middle (autocomplete at the cursor), giving prefix + suffix:: python3 sample.py --ckpt checkpoints/fim/ckpt.pt \ --prefix "function sum(arr) {\n let total = 0;\n " \ --suffix "\n return total;\n}" Use ``--fim`` (with --prefix/--suffix) to build a StarCoder-style FIM prompt; generation stops at ``<|endoftext|>``. """ from __future__ import annotations import argparse from pathlib import Path import torch from model.gpt import GPT, GPTConfig from tokenizer.tokenizer import JSCoderTokenizer def pick_device(requested: str) -> str: if requested != "auto": return requested if torch.cuda.is_available(): return "cuda" if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available(): return "mps" return "cpu" def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) p.add_argument("--ckpt", default="checkpoints/ckpt.pt") p.add_argument("--tokenizer", default=None, help="path to js_bpe.json") p.add_argument("--prompt", default="// ", help="prompt for plain completion") p.add_argument("--fim", action="store_true", help="build a FIM prompt from --prefix/--suffix") p.add_argument("--prefix", default="", help="FIM: code before the cursor") p.add_argument("--suffix", default="", help="FIM: code after the cursor") p.add_argument("--mode", choices=["psm", "spm"], default="psm") p.add_argument("--max-new-tokens", type=int, default=200) p.add_argument("--temperature", type=float, default=0.8) p.add_argument("--top-k", type=int, default=50) p.add_argument("--num-samples", type=int, default=1) p.add_argument("--device", default="auto") p.add_argument("--seed", type=int, default=1337) return p.parse_args() def main() -> None: args = parse_args() torch.manual_seed(args.seed) device = pick_device(args.device) ckpt_path = Path(args.ckpt) if not ckpt_path.exists(): raise SystemExit(f"No checkpoint at {ckpt_path}. Train one with train.py first.") ckpt = torch.load(ckpt_path, map_location="cpu") model = GPT(GPTConfig(**ckpt["config"])) model.load_state_dict(ckpt["model"]) model.to(device) model.eval() tok = JSCoderTokenizer.load(args.tokenizer) if args.tokenizer else JSCoderTokenizer.load() eot_id = ckpt.get("eot_id", tok.eot_id) if args.fim: ids = tok.build_fim_prompt(args.prefix, args.suffix, mode=args.mode) seed_text = f"[FIM {args.mode}] prefix={args.prefix!r} suffix={args.suffix!r}" else: ids = tok.encode(args.prompt) seed_text = args.prompt if not ids: ids = [eot_id] idx = torch.tensor([ids], dtype=torch.long, device=device) print(f"[sample] device={device} prompt: {seed_text}\n") for s in range(args.num_samples): out = model.generate( idx, max_new_tokens=args.max_new_tokens, temperature=args.temperature, top_k=args.top_k, eot_id=eot_id, ) generated_ids = out[0, len(ids):].tolist() completion = tok.decode(generated_ids, skip_special_tokens=True) print(f"===== sample {s + 1}/{args.num_samples} =====") if args.fim: print(args.prefix + completion + args.suffix) else: print(args.prompt + completion) print() if __name__ == "__main__": main()