tinylm / generate.py
Shiv-22's picture
add generate.py
f23be9c verified
#!/usr/bin/env python3
"""Generate text from a TinyLM checkpoint.
Usage:
# Download checkpoint from HF automatically:
python scripts/generate.py --prompt "The theory of relativity states that"
# Interactive mode:
python scripts/generate.py
# Local checkpoint:
python scripts/generate.py --checkpoint checkpoints/step_19999.pt
# Greedy decoding:
python scripts/generate.py --prompt "Once upon a time" --temperature 0
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
try:
from tinylm.model import ModelConfig, TinyLM
except ImportError:
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from tinylm.model import ModelConfig, TinyLM
# Hugging Face Hub repo and file that load_model() downloads when no local
# checkpoint path is given.
HF_CHECKPOINT_REPO: str = "Shiv-22/tinylm-checkpoints"
HF_CHECKPOINT_FILE: str = "step_19999.pt"
# Tokenizer that main() loads via AutoTokenizer.from_pretrained. Presumably it
# must match the vocabulary the checkpoint was trained with (its config's
# vocab_size) — TODO confirm. NOTE(review): this Llama-2 repo may be gated on
# the Hub; confirm users have access.
TOKENIZER: str = "meta-llama/Llama-2-7b-hf"
def load_model(checkpoint: str | None, device: str) -> TinyLM:
    """Rebuild a TinyLM from a checkpoint file and return it ready for inference.

    Args:
        checkpoint: Path to a local ``.pt`` file. When ``None``, the file
            ``HF_CHECKPOINT_FILE`` is fetched from the Hugging Face repo
            ``HF_CHECKPOINT_REPO`` with ``hf_hub_download``.
        device: Torch device string the model is moved to.

    Returns:
        The model with its weights loaded, moved to *device*, in eval mode.

    The checkpoint is expected to be a dict with a ``"config"`` entry (the
    ModelConfig hyper-parameters) and a ``"model"`` entry (the state dict).
    """
    if checkpoint is None:
        # Imported only on this branch, so a local checkpoint path never
        # needs huggingface_hub.
        from huggingface_hub import hf_hub_download

        print(f"Downloading checkpoint from {HF_CHECKPOINT_REPO}...")
        checkpoint = hf_hub_download(
            repo_id=HF_CHECKPOINT_REPO, filename=HF_CHECKPOINT_FILE
        )

    print(f"Loading {checkpoint} ...")
    # Load onto CPU first; weights_only=True stops torch.load from unpickling
    # arbitrary Python objects.
    payload = torch.load(checkpoint, map_location="cpu", weights_only=True)

    hparams = payload["config"]
    config_fields = (
        "n_layers", "d_model", "n_heads", "d_latent", "d_rope",
        "ffn_hidden", "ctx", "vocab_size", "tie_weights", "attention",
    )
    model = TinyLM(ModelConfig(**{field: hparams[field] for field in config_fields}))

    weights = payload["model"]
    # Strip the "_orig_mod." key prefix when present (presumably left behind
    # by saving a torch.compile-wrapped module) so names match plain TinyLM.
    compiled_prefix = "_orig_mod."
    if any(name.startswith(compiled_prefix) for name in weights):
        weights = {
            name.removeprefix(compiled_prefix): tensor
            for name, tensor in weights.items()
        }
    model.load_state_dict(weights)

    model = model.to(device)
    return model.eval()
def _sample_next_token(
    logits: torch.Tensor, temperature: float, top_p: float
) -> torch.Tensor:
    """Choose the next token id from last-position logits.

    Args:
        logits: Float tensor of shape ``(1, vocab)``.
        temperature: ``0`` selects the argmax (greedy). Otherwise the logits
            are divided by it before the softmax. Must be >= 0.
        top_p: Nucleus threshold. A token is dropped once the probability
            mass of the tokens ranked above it exceeds ``top_p``. Must be >= 0.

    Returns:
        Long tensor of shape ``(1, 1)`` holding the chosen token id.
    """
    if temperature == 0.0:
        return logits.argmax(dim=-1, keepdim=True)
    # Out-of-place division, so the caller's tensor (which may alias the
    # model output) is left untouched.
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True, dim=-1)
    cumsum = sorted_probs.cumsum(dim=-1)
    # cumsum - p is the mass strictly above each token. It is 0 for the
    # top-ranked token, so with top_p >= 0 at least one token always survives.
    sorted_probs[cumsum - sorted_probs > top_p] = 0.0
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
    sample_idx = torch.multinomial(sorted_probs, num_samples=1)
    # Map the position in the sorted order back to a vocabulary id.
    return sorted_ids.gather(1, sample_idx)


@torch.no_grad()
def generate(
    model: TinyLM,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9,
    device: str = "cpu",
) -> str:
    """Autoregressively continue *prompt* and return only the newly generated text.

    Args:
        model: TinyLM whose ``cfg.ctx`` bounds the input window. It is called
            on a ``(1, seq)`` long tensor and its output is indexed as
            ``[:, -1, :]``; the output is assumed to be (batch, seq, vocab)
            logits — TODO confirm.
        tokenizer: Tokenizer exposing ``bos_token_id``, ``eos_token_id``,
            ``encode`` and ``decode`` (used here as a Hugging Face tokenizer).
        prompt: Text to continue. A BOS token is prepended when the
            tokenizer defines one.
        max_new_tokens: Maximum number of tokens to generate. Generation stops
            early once the EOS token is produced.
        temperature: 0 for greedy decoding; otherwise the sampling temperature.
        top_p: Nucleus-sampling threshold (ignored when greedy).
        device: Device on which the token tensor is created. It should match
            the model's device.

    Returns:
        The decoded continuation, excluding the prompt, with special tokens
        skipped.

    Raises:
        ValueError: If ``temperature`` or ``top_p`` is negative, or if the
            prompt encodes to no tokens and the tokenizer has no BOS token.
    """
    # A negative temperature would invert the distribution and favour the
    # least likely tokens. A negative top_p would mask every token and make
    # the renormalisation divide 0 by 0.
    if temperature < 0.0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if top_p < 0.0:
        raise ValueError(f"top_p must be >= 0, got {top_p}")

    bos = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
    ids = bos + tokenizer.encode(prompt, add_special_tokens=False)
    if not ids:
        raise ValueError(
            "prompt encodes to no tokens and the tokenizer has no BOS token"
        )
    tokens = torch.tensor([ids], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        # Feed at most the model's context length; older tokens fall off the window.
        window = tokens[:, -model.cfg.ctx:]
        logits = model(window)[:, -1, :].float()  # (1, vocab)
        next_id = _sample_next_token(logits, temperature, top_p)
        tokens = torch.cat([tokens, next_id], dim=1)
        if next_id.item() == tokenizer.eos_token_id:
            break

    # `tokens` keeps the full history even though the model input is
    # windowed, so the prompt occupies exactly the first len(ids) positions.
    generated = tokens[0, len(ids):].tolist()
    return tokenizer.decode(generated, skip_special_tokens=True)
def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script (module docstring as help)."""
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--checkpoint", default=None,
                        help="Path to local .pt checkpoint (default: download from HF)")
    parser.add_argument("--prompt", default=None,
                        help="Prompt text (omit for interactive mode)")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature (0 = greedy)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Nucleus sampling probability threshold")
    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu")
    return parser


def main() -> None:
    """CLI entry point: load the tokenizer and model, then generate text.

    With ``--prompt``, one completion is printed. Without it, prompts are read
    from stdin in a loop until Ctrl+C or end of input.
    """
    parser = _build_parser()
    args = parser.parse_args()
    # Reject meaningless settings before the slow tokenizer/model load: a
    # negative temperature inverts the sampling distribution, and a negative
    # top-p masks every token.
    if args.temperature < 0:
        parser.error("--temperature must be >= 0")
    if args.top_p < 0:
        parser.error("--top-p must be >= 0")
    if args.max_new_tokens < 0:
        parser.error("--max-new-tokens must be >= 0")

    print(f"Loading tokenizer ({TOKENIZER}) ...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    model = load_model(args.checkpoint, args.device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Ready — {n_params / 1e6:.0f}M params on {args.device}\n")

    def run(prompt: str) -> None:
        """Generate a completion for *prompt* and print it beside the prompt."""
        out = generate(
            model, tokenizer, prompt,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            device=args.device,
        )
        print(f"[prompt] {prompt}")
        print(f"[output] {out}\n")

    # `is not None` (not truthiness): an explicit --prompt "" is honoured
    # instead of silently dropping into interactive mode. The help text says
    # interactive mode is for when --prompt is omitted.
    if args.prompt is not None:
        run(args.prompt)
        return

    print("Interactive mode — enter a prompt and press Enter. Ctrl+C to quit.\n")
    while True:
        try:
            prompt = input(">>> ").strip()
            if prompt:
                run(prompt)
        except (KeyboardInterrupt, EOFError):
            # Ctrl+C (including mid-generation) or end of stdin ends the session.
            print("\nBye.")
            break


if __name__ == "__main__":
    main()