#!/usr/bin/env python3 """Run inference from a local clone or directly from a Hugging Face repo.""" from __future__ import annotations import argparse from pathlib import Path import torch from huggingface_hub import hf_hub_download from model import TinyGPT, TinyGPTConfig def resolve_ckpt(args) -> Path: if args.ckpt: return Path(args.ckpt) if args.repo_id: return Path( hf_hub_download( repo_id=args.repo_id, filename=args.filename, revision=args.revision, ) ) local = Path(args.filename) if local.exists(): return local raise SystemExit("Provide --ckpt for local checkpoint or --repo-id for Hugging Face download.") def main(): p = argparse.ArgumentParser() p.add_argument("--repo-id", help="Hugging Face repo id, e.g. username/tinyllm-cpu-char") p.add_argument("--revision", default="main") p.add_argument("--filename", default="checkpoints/tinyllm_overfit_3k.pt") p.add_argument("--ckpt", help="Local checkpoint path") p.add_argument("--prompt", default="The little machine") p.add_argument("--tokens", type=int, default=300) p.add_argument("--temperature", type=float, default=0.7) p.add_argument("--top-k", type=int, default=10) args = p.parse_args() ckpt = torch.load(resolve_ckpt(args), map_location="cpu") cfg = TinyGPTConfig(**ckpt["config"]) model = TinyGPT(cfg) model.load_state_dict(ckpt["model_state"]) model.eval() stoi = ckpt["stoi"] itos = {int(k): v for k, v in ckpt["itos"].items()} prompt = "".join(ch for ch in args.prompt if ch in stoi) or "\n" idx = torch.tensor([[stoi[ch] for ch in prompt]], dtype=torch.long) out = model.generate(idx, max_new_tokens=args.tokens, temperature=args.temperature, top_k=args.top_k) print("".join(itos[int(i)] for i in out[0])) if __name__ == "__main__": main()