| """Generate text from a trained checkpoint.""" |
| import argparse |
| import torch |
|
|
| from model import GPT, GPTConfig |
| from tokenizer import load_tokenizer |
|
|
|
|
def get_device():
    """Choose the torch device to run on.

    Backends are probed in a fixed order: MPS first, then CUDA. If neither
    reports itself available, the CPU device is used.

    Returns:
        torch.device: ``mps``, ``cuda`` or ``cpu``.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
|
|
|
|
def load_model(checkpoint_path, device):
    """Rebuild a GPT model from a saved checkpoint.

    The checkpoint is read as a mapping with two entries: ``"config"``
    (unpacked as keyword arguments into ``GPTConfig``) and
    ``"model_state"`` (passed to ``load_state_dict``).

    Args:
        checkpoint_path: Location of the checkpoint, handed to ``torch.load``.
        device: Device used both as ``map_location`` for loading and as the
            target of ``.to(...)`` for the model.

    Returns:
        The model with weights loaded and ``eval()`` already called.
    """
    # NOTE(review): weights_only=False makes torch.load use full unpickling,
    # which can run arbitrary code -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    hyperparams = GPTConfig(**checkpoint["config"])
    net = GPT(hyperparams)
    net = net.to(device)

    net.load_state_dict(checkpoint["model_state"])
    net.eval()
    return net
|
|
|
|
def alpaca_prompt(instruction, input_text=""):
    """Format a prompt in Alpaca instruction style (for models trained on Alpaca).

    Args:
        instruction: Text placed under the ``### Instruction:`` header.
        input_text: Optional text placed under an ``### Input:`` header.
            ``None``, an empty string, or a whitespace-only string means
            "no input" and the section is left out. When present, the text
            is inserted as given (it is not stripped).

    Returns:
        str: The sections joined by blank lines, ending with the
        ``### Response:`` header followed by a newline.
    """
    sections = [f"### Instruction:\n{instruction}"]

    # Treat None like "" so a missing optional field does not crash on .strip().
    if input_text is not None and input_text.strip():
        sections.append(f"### Input:\n{input_text}")

    sections.append("### Response:\n")
    return "\n\n".join(sections)
|
|
|
|
def generate_text(model, tokenizer, prompt, max_new_tokens=200, temperature=1.0, top_k=40, device="cpu"):
    """Run the model on a prompt and return the decoded output.

    Args:
        model: Object exposing ``generate(idx, max_new_tokens=, temperature=, top_k=)``.
        tokenizer: Object exposing ``encode(text)`` and ``decode(ids)``.
        prompt: Text to encode and feed to the model.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        device: Device on which the input tensor is created.

    Returns:
        The result of ``tokenizer.decode`` applied to the first row of what
        ``model.generate`` returns (presumably prompt plus continuation --
        confirm against the model's ``generate``).
    """
    # An empty encoding is replaced by the single token id 0 so the model
    # always receives at least one token.
    token_ids = tokenizer.encode(prompt) or [0]

    # Shape the ids as a batch containing one sequence.
    batch = torch.tensor([token_ids], dtype=torch.long, device=device)

    sampling = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
    }
    with torch.no_grad():
        generated = model.generate(batch, **sampling)

    first_row = generated[0].tolist()
    return tokenizer.decode(first_row)
|
|
|
|
if __name__ == "__main__":
    # Command-line interface: each entry is (flag, add_argument keyword options).
    cli = argparse.ArgumentParser()
    option_table = (
        ("--checkpoint", {"default": "checkpoints/best_model.pt"}),
        ("--tokenizer", {"default": "tokenizer.json"}),
        ("--prompt", {"default": "To be or not to be"}),
        ("--instruction", {"default": None,
                           "help": "Use Alpaca-style prompt. Overrides --prompt."}),
        ("--input", {"default": "", "help": "Optional input for Alpaca prompt"}),
        ("--max_new_tokens", {"type": int, "default": 300}),
        ("--temperature", {"type": float, "default": 0.8}),
        ("--top_k", {"type": int, "default": 40}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    args = cli.parse_args()

    # Pick the compute device and report it.
    device = get_device()
    print(f"Device: {device}")

    # Load the tokenizer first, then the model, and report the model size.
    tokenizer = load_tokenizer(args.tokenizer)
    model = load_model(args.checkpoint, device)
    print(f"Model loaded ({model.num_params():,} params)\n")

    # The plain --prompt is the default; a non-empty --instruction replaces it
    # with an Alpaca-formatted prompt, which is echoed before generation.
    prompt = args.prompt
    if args.instruction:
        prompt = alpaca_prompt(args.instruction, args.input)
        print(f"Prompt:\n{prompt}")

    # Forward the sampling options from the command line and print the output.
    sampling = {
        name: getattr(args, name)
        for name in ("max_new_tokens", "temperature", "top_k")
    }
    print(generate_text(model, tokenizer, prompt, device=device, **sampling))
|
|