| |
| """Generate text from a TinyLM checkpoint. |
| |
| Usage: |
| # Download checkpoint from HF automatically: |
| python scripts/generate.py --prompt "The theory of relativity states that" |
| |
| # Interactive mode: |
| python scripts/generate.py |
| |
| # Local checkpoint: |
| python scripts/generate.py --checkpoint checkpoints/step_19999.pt |
| |
| # Greedy decoding: |
| python scripts/generate.py --prompt "Once upon a time" --temperature 0 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from transformers import AutoTokenizer |
|
|
| try: |
| from tinylm.model import ModelConfig, TinyLM |
| except ImportError: |
| sys.path.insert(0, str(Path(__file__).parent.parent / "src")) |
| from tinylm.model import ModelConfig, TinyLM |
|
|
# Hugging Face Hub repo and filename that load_model() downloads when no
# --checkpoint path is supplied on the command line.
HF_CHECKPOINT_REPO: str = "Shiv-22/tinylm-checkpoints"
HF_CHECKPOINT_FILE: str = "step_19999.pt"
# Tokenizer identifier passed to AutoTokenizer.from_pretrained() in main().
# NOTE(review): meta-llama repos are typically gated on the Hub; confirm the
# runtime environment has access/credentials.
TOKENIZER: str = "meta-llama/Llama-2-7b-hf"
|
|
|
|
def load_model(checkpoint: str | None, device: str) -> TinyLM:
    """Instantiate a TinyLM from a saved checkpoint, ready for inference.

    Args:
        checkpoint: Path to a local ``.pt`` checkpoint. When ``None``, the file
            ``HF_CHECKPOINT_FILE`` is fetched from ``HF_CHECKPOINT_REPO`` with
            ``huggingface_hub.hf_hub_download`` and that local path is used.
        device: Torch device string the model is moved to after loading.

    Returns:
        The model with the checkpoint's weights loaded, moved to *device* and
        switched to eval mode.
    """
    if checkpoint is None:
        # Imported lazily so huggingface_hub is only needed for the download path.
        from huggingface_hub import hf_hub_download

        print(f"Downloading checkpoint from {HF_CHECKPOINT_REPO}...")
        checkpoint = hf_hub_download(
            repo_id=HF_CHECKPOINT_REPO,
            filename=HF_CHECKPOINT_FILE,
        )

    print(f"Loading {checkpoint} ...")
    # weights_only=True restricts unpickling to tensors and plain containers;
    # tensors land on CPU first and are moved to *device* at the end.
    payload = torch.load(checkpoint, map_location="cpu", weights_only=True)

    cfg_dict = payload["config"]
    config_keys = (
        "n_layers", "d_model", "n_heads",
        "d_latent", "d_rope", "ffn_hidden",
        "ctx", "vocab_size", "tie_weights",
        "attention",
    )
    # Only these keys are forwarded; any extra entries in the saved config are ignored.
    model = TinyLM(ModelConfig(**{key: cfg_dict[key] for key in config_keys}))

    weights = payload["model"]
    prefix = "_orig_mod."
    # Presumably the state dict was saved from a torch.compile()-wrapped module,
    # whose parameter names carry this prefix; strip it so the names match the
    # plain TinyLM.
    if any(name.startswith(prefix) for name in weights):
        weights = {name.removeprefix(prefix): tensor for name, tensor in weights.items()}
    model.load_state_dict(weights)

    model = model.to(device)
    model.eval()
    return model
|
|
|
|
def _sample_next_token(
    logits: torch.Tensor, temperature: float, top_p: float
) -> torch.Tensor:
    """Choose the next token id for each row of final-position logits.

    Args:
        logits: 2-D float tensor, presumably ``(batch, vocab)``.
        temperature: ``0`` selects greedy argmax. A positive value divides the
            logits before softmax, followed by nucleus (top-p) sampling.
        top_p: Nucleus threshold. A token is dropped once the total probability
            of strictly higher-ranked tokens exceeds it, so the top token is
            always kept for any ``top_p >= 0``.

    Returns:
        Long tensor of shape ``(batch, 1)`` holding the chosen token ids.
    """
    if temperature == 0.0:
        return logits.argmax(dim=-1, keepdim=True)

    # Out-of-place division: `logits` may be a view of the model's output.
    probs = F.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_ids = torch.sort(probs, descending=True, dim=-1)
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    sorted_probs = sorted_probs.masked_fill(mass_before > top_p, 0.0)
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    sample_idx = torch.multinomial(sorted_probs, num_samples=1)
    # Map positions in the sorted order back to vocabulary ids.
    return sorted_ids.gather(1, sample_idx)


@torch.no_grad()
def generate(
    model: TinyLM,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 200,
    temperature: float = 0.8,
    top_p: float = 0.9,
    device: str = "cpu",
) -> str:
    """Autoregressively continue *prompt* and return only the new text.

    Args:
        model: Called on a ``(1, T)`` long tensor of token ids. It must return
            a 3-D tensor indexed as ``[:, -1, :]``, presumably
            ``(batch, seq, vocab)`` logits. ``model.cfg.ctx`` caps the input
            window.
        tokenizer: Object with ``bos_token_id``, ``eos_token_id``,
            ``encode(text, add_special_tokens=False)`` and
            ``decode(ids, skip_special_tokens=True)`` (Hugging Face style).
        prompt: Text to condition on. A BOS token is prepended when the
            tokenizer defines one.
        max_new_tokens: Upper bound on generated tokens. Generation stops
            early when the EOS token is produced.
        temperature: ``0`` for greedy decoding, otherwise the softmax
            temperature for sampling.
        top_p: Nucleus sampling threshold. Ignored when ``temperature == 0``.
        device: Device the token tensor is created on. It should match the
            model's device.

    Returns:
        The decoded continuation, excluding the prompt and special tokens.

    Raises:
        ValueError: If ``temperature`` is negative, or if the prompt encodes
            to zero tokens and the tokenizer has no BOS token. There is then
            nothing to condition on.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    bos = [tokenizer.bos_token_id] if tokenizer.bos_token_id is not None else []
    ids = bos + tokenizer.encode(prompt, add_special_tokens=False)
    if not ids:
        raise ValueError(
            "prompt encodes to zero tokens and the tokenizer has no BOS token"
        )
    tokens = torch.tensor([ids], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        # Feed at most the last `ctx` tokens. There is no KV cache, so the
        # whole window is recomputed every step.
        window = tokens[:, -model.cfg.ctx:]
        logits = model(window)[:, -1, :].float()
        next_id = _sample_next_token(logits, temperature, top_p)
        tokens = torch.cat([tokens, next_id], dim=1)
        if next_id.item() == tokenizer.eos_token_id:
            break

    # `tokens` keeps the full history, so slicing past the prompt length
    # yields exactly the generated ids.
    generated = tokens[0, len(ids):].tolist()
    return tokenizer.decode(generated, skip_special_tokens=True)
|
|
|
|
def main() -> None:
    """Command-line entry point: load tokenizer and model, then generate text.

    With ``--prompt`` a single completion is printed. Otherwise an interactive
    loop reads prompts from stdin until Ctrl+C or EOF. Invalid sampling
    settings are rejected via ``parser.error``, which exits with status 2,
    before anything is downloaded or loaded.
    """
    parser = argparse.ArgumentParser(description=__doc__,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--checkpoint", default=None,
                        help="Path to local .pt checkpoint (default: download from HF)")
    parser.add_argument("--prompt", default=None,
                        help="Prompt text (omit for interactive mode)")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.8,
                        help="Sampling temperature (0 = greedy)")
    parser.add_argument("--top-p", type=float, default=0.9,
                        help="Nucleus sampling probability threshold")
    parser.add_argument("--device",
                        default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # Fail fast on values that would silently misbehave: a negative temperature
    # inverts the distribution, a negative top-p masks every candidate (NaN
    # probabilities), and a negative token budget generates nothing.
    if args.temperature < 0:
        parser.error(f"--temperature must be >= 0, got {args.temperature}")
    if args.top_p < 0:
        parser.error(f"--top-p must be >= 0, got {args.top_p}")
    if args.max_new_tokens < 0:
        parser.error(f"--max-new-tokens must be >= 0, got {args.max_new_tokens}")

    print(f"Loading tokenizer ({TOKENIZER}) ...")
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    model = load_model(args.checkpoint, args.device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Ready — {n_params / 1e6:.0f}M params on {args.device}\n")

    def run(prompt: str) -> None:
        """Generate a completion for *prompt* using the CLI settings and print it."""
        out = generate(
            model, tokenizer, prompt,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_p=args.top_p,
            device=args.device,
        )
        print(f"[prompt] {prompt}")
        print(f"[output] {out}\n")

    if args.prompt:
        run(args.prompt)
    else:
        print("Interactive mode — enter a prompt and press Enter. Ctrl+C to quit.\n")
        while True:
            try:
                prompt = input(">>> ").strip()
                if prompt:
                    run(prompt)
            except (KeyboardInterrupt, EOFError):
                # Ctrl+C (even mid-generation) or end of stdin ends the session.
                print("\nBye.")
                break


if __name__ == "__main__":
    main()
|
|