# Source: tinyllm-cpu-char / hf_infer.py (Hugging Face repo by Vishwas1)
# Commit 170658b (verified): "Add CPU-trained tiny character LLM"
#!/usr/bin/env python3
"""Run inference from a local clone or directly from a Hugging Face repo."""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
from model import TinyGPT, TinyGPTConfig
def resolve_ckpt(args) -> Path:
    """Return the path of the checkpoint file to load.

    Resolution order:

    1. ``args.ckpt`` if truthy — returned as a ``Path`` without checking
       that it exists.
    2. ``args.repo_id`` if truthy — ``args.filename`` at ``args.revision`` is
       fetched with ``hf_hub_download`` and the local path it returns is used.
    3. ``args.filename`` interpreted relative to the current working
       directory, but only if that file exists.

    Raises:
        SystemExit: when none of the sources above yields a checkpoint.
    """
    explicit_path = args.ckpt
    if explicit_path:
        return Path(explicit_path)

    if not args.repo_id:
        # No explicit path and no Hub repo: try the default filename locally.
        candidate = Path(args.filename)
        if not candidate.exists():
            raise SystemExit("Provide --ckpt for local checkpoint or --repo-id for Hugging Face download.")
        return candidate

    downloaded = hf_hub_download(
        repo_id=args.repo_id,
        filename=args.filename,
        revision=args.revision,
    )
    return Path(downloaded)
def main():
    """Parse CLI arguments, load a TinyGPT checkpoint, and print a generated sample.

    The loaded checkpoint is indexed with the keys ``config`` (keyword
    arguments for ``TinyGPTConfig``), ``model_state`` (passed to
    ``load_state_dict``), ``stoi`` (character -> token id) and ``itos``
    (token id -> character; its keys are normalized with ``int``).
    The prompt is restricted to characters present in ``stoi`` before
    encoding, and the decoded generation is printed to stdout.
    """
    import sys

    p = argparse.ArgumentParser()
    p.add_argument("--repo-id", help="Hugging Face repo id, e.g. username/tinyllm-cpu-char")
    p.add_argument("--revision", default="main")
    p.add_argument("--filename", default="checkpoints/tinyllm_overfit_3k.pt")
    p.add_argument("--ckpt", help="Local checkpoint path")
    p.add_argument("--prompt", default="The little machine")
    p.add_argument("--tokens", type=int, default=300)
    p.add_argument("--temperature", type=float, default=0.7)
    p.add_argument("--top-k", type=int, default=10)
    args = p.parse_args()

    # The checkpoint may be downloaded from an arbitrary Hub repo, so do not let
    # torch.load unpickle arbitrary objects. weights_only=True limits it to
    # tensors and plain containers; it is only the default from PyTorch 2.6 on.
    ckpt = torch.load(resolve_ckpt(args), map_location="cpu", weights_only=True)

    cfg = TinyGPTConfig(**ckpt["config"])
    model = TinyGPT(cfg)
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    stoi = ckpt["stoi"]
    # itos keys may have been serialized as strings; normalize to int token ids.
    itos = {int(k): v for k, v in ckpt["itos"].items()}

    # Characters outside the vocabulary have no token id and must be removed;
    # report them rather than altering the prompt silently.
    dropped = sorted({ch for ch in args.prompt if ch not in stoi})
    if dropped:
        print(
            f"warning: dropping characters not in vocabulary: {''.join(dropped)!r}",
            file=sys.stderr,
        )
    prompt = "".join(ch for ch in args.prompt if ch in stoi)
    if not prompt:
        # Seed with a newline when the vocabulary has one; otherwise use any
        # known character so an empty prompt cannot raise KeyError.
        prompt = "\n" if "\n" in stoi else next(iter(stoi))

    idx = torch.tensor([[stoi[ch] for ch in prompt]], dtype=torch.long)
    # Pure inference: skip autograd graph construction and its memory cost.
    with torch.no_grad():
        out = model.generate(idx, max_new_tokens=args.tokens, temperature=args.temperature, top_k=args.top_k)
    print("".join(itos[int(i)] for i in out[0]))
# Run the CLI only when executed as a script; importing this module
# (e.g. to reuse resolve_ckpt) does not parse arguments or load a model.
if __name__ == "__main__":
    main()