"""Generate text from a trained checkpoint."""

import argparse

import torch

from model import GPT, GPTConfig
from tokenizer import load_tokenizer


def get_device():
    """Return the best available torch device.

    Preference order is Apple MPS, then CUDA, then CPU.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def load_model(checkpoint_path, device):
    """Load a GPT model from ``checkpoint_path`` onto ``device`` in eval mode.

    The checkpoint is read as a dict with a ``"config"`` entry (keyword
    arguments for ``GPTConfig``) and a ``"model_state"`` entry (a state dict
    for ``GPT``).

    Args:
        checkpoint_path: Path to the saved checkpoint file.
        device: Device tensors are mapped to and the model is moved onto.

    Returns:
        The loaded ``GPT`` model, switched to eval mode.
    """
    # SECURITY: weights_only=False makes torch.load fully unpickle the file,
    # which can execute arbitrary code. Only load checkpoints you trust.
    # NOTE(review): weights_only=True may work if "config" holds only plain
    # Python types — confirm against the training script before switching.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    config = GPTConfig(**ckpt["config"])
    model = GPT(config).to(device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model


def alpaca_prompt(instruction, input_text=""):
    """Format a prompt in Alpaca instruction style (for models trained on Alpaca).

    Args:
        instruction: The task instruction.
        input_text: Optional extra context. It is omitted from the prompt when
            it is empty, whitespace-only, or ``None``.

    Returns:
        The formatted prompt, ending with the ``### Response:`` header so
        the model continues with its answer.
    """
    # Guard against None so callers can pass an absent input directly.
    if input_text and input_text.strip():
        return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n"
    return f"### Instruction:\n{instruction}\n\n### Response:\n"


def generate_text(model, tokenizer, prompt, max_new_tokens=200, temperature=1.0, top_k=40, device="cpu"):
    """Encode ``prompt``, sample a continuation and return the decoded text.

    Args:
        model: Model exposing ``generate(idx, max_new_tokens=..., temperature=..., top_k=...)``.
        tokenizer: Object with ``encode(str)`` and ``decode(list[int])``.
        prompt: Text to condition generation on.
        max_new_tokens: Number of tokens to sample.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k cutoff forwarded to ``model.generate``.
        device: Device on which the input tensor is created.

    Returns:
        The decoded text of the first sequence in the generated batch
        (``out[0]``).
    """
    encoded = tokenizer.encode(prompt)
    if not encoded:
        # Empty prompt: seed generation with a single token id 0.
        # NOTE(review): assumes id 0 is a valid token for this tokenizer — confirm.
        encoded = [0]
    idx = torch.tensor([encoded], dtype=torch.long, device=device)
    with torch.no_grad():
        out = model.generate(idx, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
    return tokenizer.decode(out[0].tolist())


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``,
            which is argparse's default.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/best_model.pt")
    parser.add_argument("--tokenizer", default="tokenizer.json")
    parser.add_argument("--prompt", default="To be or not to be")
    parser.add_argument("--instruction", default=None, help="Use Alpaca-style prompt. Overrides --prompt.")
    parser.add_argument("--input", default="", help="Optional input for Alpaca prompt")
    parser.add_argument("--max_new_tokens", type=int, default=300)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top_k", type=int, default=40)
    args = parser.parse_args(argv)

    device = get_device()
    print(f"Device: {device}")

    tokenizer = load_tokenizer(args.tokenizer)
    model = load_model(args.checkpoint, device)
    print(f"Model loaded ({model.num_params():,} params)\n")

    if args.instruction:
        prompt = alpaca_prompt(args.instruction, args.input)
        print(f"Prompt:\n{prompt}")
    else:
        prompt = args.prompt

    result = generate_text(
        model,
        tokenizer,
        prompt,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        device=device,
    )
    print(result)


if __name__ == "__main__":
    main()