| |
| """Generate completions from a trained JSCoder checkpoint. |
| |
| Two modes: |
| |
| - plain prompt completion:: |
| |
| python3 sample.py --ckpt checkpoints/ckpt.pt --prompt "function add(a, b) {" |
| |
| - fill-in-the-middle (autocomplete at the cursor), giving prefix + suffix:: |
| |
| python3 sample.py --ckpt checkpoints/fim/ckpt.pt --fim \ |
| --prefix "function sum(arr) {\n let total = 0;\n " \ |
| --suffix "\n return total;\n}" |
| |
| Use ``--fim`` (with --prefix/--suffix) to build a StarCoder-style FIM prompt; |
| generation stops at ``<|endoftext|>``. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
|
|
| from model.gpt import GPT, GPTConfig |
| from tokenizer.tokenizer import JSCoderTokenizer |
|
|
|
|
def pick_device(requested: str) -> str:
    """Resolve the ``--device`` value to a concrete torch device string.

    Any value other than ``"auto"`` is returned untouched. For ``"auto"`` the
    first available backend wins, checked in the order CUDA, MPS, CPU.
    """
    if requested == "auto":
        if torch.cuda.is_available():
            return "cuda"
        # ``torch.backends`` may not expose ``mps``; getattr tolerates that.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        return "cpu"
    return requested
|
|
|
|
def _int_at_least(minimum: int):
    """Build an argparse ``type=`` converter for integers >= *minimum*."""

    def convert(text: str) -> int:
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid integer: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be >= {minimum}, got {value}")
        return value

    return convert


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the sampling script.

    Returns:
        The populated ``argparse.Namespace``. Attribute names follow the
        option names with dashes turned into underscores (``max_new_tokens``,
        ``top_k``, ``num_samples``).

    Raises:
        SystemExit: raised by argparse (exit status 2) on invalid input,
            including ``--num-samples`` below 1 or a negative
            ``--max-new-tokens``.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks in ``--help`` output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--ckpt", default="checkpoints/ckpt.pt", help="path to the checkpoint file")
    p.add_argument("--tokenizer", default=None, help="path to js_bpe.json")
    p.add_argument("--prompt", default="// ", help="prompt for plain completion")
    p.add_argument("--fim", action="store_true", help="build a FIM prompt from --prefix/--suffix")
    p.add_argument("--prefix", default="", help="FIM: code before the cursor")
    p.add_argument("--suffix", default="", help="FIM: code after the cursor")
    p.add_argument(
        "--mode",
        choices=["psm", "spm"],
        default="psm",
        help="FIM prompt layout passed to the tokenizer",
    )
    # A negative token budget is meaningless; 0 stays legal (prints the prompt only).
    p.add_argument(
        "--max-new-tokens",
        type=_int_at_least(0),
        default=200,
        help="max_new_tokens passed to model.generate (>= 0)",
    )
    p.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="temperature passed to model.generate",
    )
    p.add_argument("--top-k", type=int, default=50, help="top_k passed to model.generate")
    # Zero or negative sample counts used to exit silently without output.
    p.add_argument(
        "--num-samples",
        type=_int_at_least(1),
        default=1,
        help="number of completions to generate (>= 1)",
    )
    p.add_argument("--device", default="auto", help='device string, or "auto" to pick cuda/mps/cpu')
    p.add_argument("--seed", type=int, default=1337, help="seed given to torch.manual_seed")
    return p.parse_args()
|
|
|
|
def main() -> None:
    """Load a checkpoint, build the prompt, and print sampled completions.

    Steps: parse the CLI options, seed torch, restore the ``GPT`` model from
    the checkpoint's ``"config"`` and ``"model"`` entries, encode either the
    plain prompt or a FIM prompt, then call ``model.generate`` once per
    requested sample and print the decoded text.

    Raises:
        SystemExit: if ``--ckpt`` does not point to an existing file.
    """
    import sys  # local import: only needed for the stderr warning below

    args = parse_args()
    torch.manual_seed(args.seed)
    device = pick_device(args.device)

    ckpt_path = Path(args.ckpt)
    # is_file() rather than exists(): a directory would pass exists() and
    # only fail later inside torch.load with a less helpful error.
    if not ckpt_path.is_file():
        raise SystemExit(f"No checkpoint at {ckpt_path}. Train one with train.py first.")
    # NOTE(security): depending on the torch version / ``weights_only``
    # setting, torch.load can unpickle arbitrary objects. Only load
    # checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    model = GPT(GPTConfig(**ckpt["config"]))
    model.load_state_dict(ckpt["model"])
    model.to(device)
    model.eval()

    tok = JSCoderTokenizer.load(args.tokenizer) if args.tokenizer else JSCoderTokenizer.load()
    # Prefer the end-of-text id stored with the checkpoint, if any.
    eot_id = ckpt.get("eot_id", tok.eot_id)

    if not args.fim and (args.prefix or args.suffix):
        # These options are only read in FIM mode; say so instead of
        # silently completing --prompt.
        print(
            "[sample] warning: --prefix/--suffix are ignored without --fim; "
            "using --prompt instead.",
            file=sys.stderr,
        )

    if args.fim:
        ids = tok.build_fim_prompt(args.prefix, args.suffix, mode=args.mode)
        seed_text = f"[FIM {args.mode}] prefix={args.prefix!r} suffix={args.suffix!r}"
    else:
        ids = tok.encode(args.prompt)
        seed_text = args.prompt

    # An empty prompt still needs one token to condition on.
    if not ids:
        ids = [eot_id]
    prompt_len = len(ids)

    idx = torch.tensor([ids], dtype=torch.long, device=device)
    print(f"[sample] device={device} prompt: {seed_text}\n")

    # Sampling needs no gradients; skip autograd bookkeeping.
    with torch.no_grad():
        for s in range(args.num_samples):
            out = model.generate(
                idx,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_k=args.top_k,
                eot_id=eot_id,
            )
            # Keep only the tokens that follow the prompt in row 0.
            generated_ids = out[0, prompt_len:].tolist()
            completion = tok.decode(generated_ids, skip_special_tokens=True)
            print(f"===== sample {s + 1}/{args.num_samples} =====")
            if args.fim:
                print(args.prefix + completion + args.suffix)
            else:
                print(args.prompt + completion)
            print()
|
|
|
|
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|