# tiny-38m / sample.py
# darthcrawl's picture
# Add files using upload-large-folder tool
# 6e14144 verified
"""Generate from a trained checkpoint."""
import argparse
import dataclasses
from pathlib import Path

import torch
from tokenizers import Tokenizer

from config import ModelConfig
from model import GPT
def main():
    """Load a checkpoint and print ``--num-samples`` generations for ``--prompt``.

    The checkpoint file is read as a dict with a ``"model_cfg"`` entry (keyword
    arguments for ``ModelConfig``; keys the current config does not accept are
    ignored) and a ``"model"`` entry holding the ``GPT`` state dict.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="checkpoints/best.pt")
    p.add_argument("--tokenizer", type=str, default="data/tokenizer.json")
    p.add_argument("--prompt", type=str, default="Once upon a time")
    p.add_argument("--max-new-tokens", type=int, default=256)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top-k", type=int, default=200)
    p.add_argument("--num-samples", type=int, default=1)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--device", type=str, default=None)
    args = p.parse_args()

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    # torch.manual_seed seeds the CPU and all CUDA generators, so runs are reproducible.
    torch.manual_seed(args.seed)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects from
    # the file; only load checkpoints from a trusted source.
    ckpt = torch.load(args.ckpt, map_location=device, weights_only=False)

    # Keep only keys ModelConfig.__init__ accepts. This tolerates configs written
    # by other versions of ModelConfig. It also skips init=False fields, which
    # dataclasses.asdict() saves but the constructor rejects.
    init_fields = {f.name for f in dataclasses.fields(ModelConfig) if f.init}
    cfg = ModelConfig(**{k: v for k, v in ckpt["model_cfg"].items() if k in init_fields})
    model = GPT(cfg).to(device).eval()

    # A model wrapped by torch.compile() saves its keys with an "_orig_mod." prefix.
    # Strip the prefix (only if present) so the weights load into the plain module.
    state_dict = ckpt["model"]
    compile_prefix = "_orig_mod."
    if any(key.startswith(compile_prefix) for key in state_dict):
        state_dict = {
            (key[len(compile_prefix):] if key.startswith(compile_prefix) else key): value
            for key, value in state_dict.items()
        }
    model.load_state_dict(state_dict)

    tok = Tokenizer.from_file(args.tokenizer)
    eot = tok.token_to_id("<|endoftext|>")  # None if the vocab lacks this token
    ids = tok.encode(args.prompt).ids
    if not ids:
        # An empty prompt still needs one conditioning token; fall back to EOT.
        if eot is None:
            p.error("prompt encodes to no tokens and the tokenizer has no "
                    "<|endoftext|> token to start generation from")
        ids = [eot]
    x = torch.tensor([ids], dtype=torch.long, device=device)

    # Sampling never backpropagates. no_grad avoids keeping activations for
    # autograd, in case GPT.generate does not already disable it.
    with torch.no_grad():
        for s in range(args.num_samples):
            out = model.generate(
                x, max_new_tokens=args.max_new_tokens,
                temperature=args.temperature, top_k=args.top_k, eos_id=eot,
            )[0].tolist()
            text = tok.decode(out)
            print(f"\n--- sample {s + 1} ---")
            print(text)
# Script entry point: run sampling only when executed directly, not when imported.
if __name__ == "__main__":
    main()