| |
| """Run inference from a local clone or directly from a Hugging Face repo.""" |
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from huggingface_hub import hf_hub_download |
|
|
| from model import TinyGPT, TinyGPTConfig |
|
|
|
|
def resolve_ckpt(args: argparse.Namespace) -> Path:
    """Return the path of the checkpoint file selected by the CLI arguments.

    Sources are tried in this order:

    1. ``args.ckpt`` -- an explicit local path, returned as-is (its
       existence is not checked here).
    2. ``args.repo_id`` -- ``args.filename`` is fetched from that Hugging
       Face repo at ``args.revision`` via ``hf_hub_download`` and the
       path it returns is used.
    3. ``args.filename`` interpreted as a local path, if it exists.

    Args:
        args: Parsed arguments providing ``ckpt``, ``repo_id``,
            ``filename`` and ``revision`` attributes.

    Returns:
        Path to the checkpoint file.

    Raises:
        SystemExit: If none of the three sources yields a checkpoint.
    """
    if args.ckpt:
        return Path(args.ckpt)

    if args.repo_id:
        downloaded = hf_hub_download(
            repo_id=args.repo_id,
            filename=args.filename,
            revision=args.revision,
        )
        return Path(downloaded)

    # Last resort: the default/explicit filename may already exist locally,
    # e.g. when running from a clone of the repo.
    local = Path(args.filename)
    if local.exists():
        return local

    raise SystemExit("Provide --ckpt for local checkpoint or --repo-id for Hugging Face download.")
|
|
|
|
def main() -> None:
    """Parse CLI arguments, load a TinyGPT checkpoint and print a sample.

    The checkpoint path comes from ``resolve_ckpt``. The loaded checkpoint
    is read as a mapping with the keys ``"config"`` (keyword arguments for
    ``TinyGPTConfig``), ``"model_state"`` (passed to ``load_state_dict``),
    ``"stoi"`` (character -> token id) and ``"itos"`` (token id ->
    character).

    Prompt characters that are not in ``stoi`` are dropped; if nothing
    remains, a single newline is used as the prompt.

    Raises:
        SystemExit: If no checkpoint can be resolved, or if the prompt has
            no in-vocabulary characters and the vocabulary has no newline
            to fall back on.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--repo-id", help="Hugging Face repo id, e.g. username/tinyllm-cpu-char")
    parser.add_argument("--revision", default="main")
    parser.add_argument("--filename", default="checkpoints/tinyllm_overfit_3k.pt")
    parser.add_argument("--ckpt", help="Local checkpoint path")
    parser.add_argument("--prompt", default="The little machine")
    parser.add_argument("--tokens", type=int, default=300)
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--top-k", type=int, default=10)
    args = parser.parse_args()

    # NOTE(security): torch.load unpickles the file. A checkpoint fetched
    # from an arbitrary Hugging Face repo is untrusted input; only load
    # repos you trust, or consider passing weights_only=True where the
    # installed torch version supports it.
    ckpt = torch.load(resolve_ckpt(args), map_location="cpu")

    cfg = TinyGPTConfig(**ckpt["config"])
    model = TinyGPT(cfg)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    stoi = ckpt["stoi"]
    # Coerce keys to int -- presumably they may have been stored as strings.
    itos = {int(k): v for k, v in ckpt["itos"].items()}

    # Keep only characters the model knows; fall back to a newline prompt.
    prompt = "".join(ch for ch in args.prompt if ch in stoi)
    if not prompt:
        if "\n" not in stoi:
            # Previously this surfaced as a bare KeyError from stoi["\n"].
            raise SystemExit("Prompt contains no characters from the model vocabulary.")
        prompt = "\n"

    idx = torch.tensor([[stoi[ch] for ch in prompt]], dtype=torch.long)

    # Sampling needs no gradients; this is harmless if generate() already
    # disables them itself.
    with torch.no_grad():
        out = model.generate(
            idx,
            max_new_tokens=args.tokens,
            temperature=args.temperature,
            top_k=args.top_k,
        )

    # out[0] is decoded as the single sequence of the batch built above.
    print("".join(itos[int(i)] for i in out[0]))
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|