my-gpt-from-scratch / generate.py
edgemindroboticslabs's picture
Upload generate.py with huggingface_hub
9e87788 verified
"""Generate text from a trained checkpoint."""
import argparse
import torch
from model import GPT, GPTConfig
from tokenizer import load_tokenizer
def get_device():
    """Return the torch device to run on.

    Backends are probed in a fixed order: Apple MPS first, then CUDA,
    with the CPU as the fallback when neither reports itself available.
    """
    if torch.backends.mps.is_available():
        backend = "mps"
    elif torch.cuda.is_available():
        backend = "cuda"
    else:
        backend = "cpu"
    return torch.device(backend)
def load_model(checkpoint_path, device):
    """Rebuild a GPT model from a checkpoint file and put it in eval mode.

    The checkpoint is read as a mapping with two entries: ``"config"``
    (keyword arguments for ``GPTConfig``) and ``"model_state"`` (the state
    dict passed to ``load_state_dict``).

    Args:
        checkpoint_path: Path of the checkpoint handed to ``torch.load``.
        device: Device the tensors are mapped to and the model is moved to.

    Returns:
        The model with weights loaded, after ``eval()`` has been called.
    """
    # NOTE(review): weights_only=False makes torch.load fall back to full
    # pickle deserialization, which can execute arbitrary code embedded in
    # the file. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model_config = GPTConfig(**checkpoint["config"])
    net = GPT(model_config).to(device)
    net.load_state_dict(checkpoint["model_state"])
    net.eval()
    return net
def alpaca_prompt(instruction, input_text=""):
    """Format a prompt in Alpaca instruction style (for models trained on Alpaca).

    Args:
        instruction: Text placed under the ``### Instruction:`` header.
        input_text: Optional context placed under an ``### Input:`` header.
            ``None``, an empty string, or whitespace-only text omits the
            Input section entirely. Non-blank text is inserted as given
            (it is not stripped).

    Returns:
        The prompt string, with sections separated by a blank line and
        ending in the ``### Response:`` header followed by a newline.
    """
    sections = [f"### Instruction:\n{instruction}"]
    # Checking truthiness first lets None (an unset optional field) behave
    # like "" instead of raising AttributeError on .strip().
    if input_text and input_text.strip():
        sections.append(f"### Input:\n{input_text}")
    sections.append("### Response:\n")
    return "\n\n".join(sections)
def generate_text(model, tokenizer, prompt, max_new_tokens=200, temperature=1.0, top_k=40, device="cpu"):
    """Encode a prompt, run the model's generate method, and decode the result.

    Args:
        model: Object exposing ``generate(idx, max_new_tokens=, temperature=, top_k=)``.
        tokenizer: Object exposing ``encode(str)`` and ``decode(list)``.
        prompt: Text to encode and feed to the model.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        device: Device on which the input token tensor is created.

    Returns:
        The decoded text of the first sequence returned by ``model.generate``
        (presumably the prompt followed by the new tokens -- confirm against
        ``GPT.generate``).
    """
    # A prompt that encodes to nothing is replaced by the single token id 0
    # so the model always receives at least one token.
    token_ids = tokenizer.encode(prompt) or [0]

    # The extra list level adds a leading batch dimension of size one.
    batch = torch.tensor([token_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        generated = model.generate(
            batch,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    first_sequence = generated[0].tolist()
    return tokenizer.decode(first_sequence)
if __name__ == "__main__":
    # Command-line entry point: parse options, load the tokenizer and model,
    # build the prompt, then print the generated text.
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", default="checkpoints/best_model.pt")
    cli.add_argument("--tokenizer", default="tokenizer.json")
    cli.add_argument("--prompt", default="To be or not to be")
    cli.add_argument(
        "--instruction",
        default=None,
        help="Use Alpaca-style prompt. Overrides --prompt.",
    )
    cli.add_argument("--input", default="", help="Optional input for Alpaca prompt")
    cli.add_argument("--max_new_tokens", type=int, default=300)
    cli.add_argument("--temperature", type=float, default=0.8)
    cli.add_argument("--top_k", type=int, default=40)
    opts = cli.parse_args()

    device = get_device()
    print(f"Device: {device}")

    tokenizer = load_tokenizer(opts.tokenizer)
    model = load_model(opts.checkpoint, device)
    print(f"Model loaded ({model.num_params():,} params)\n")

    # --instruction, when given and non-empty, takes precedence over --prompt
    # and the formatted Alpaca prompt is echoed before generation.
    if not opts.instruction:
        prompt = opts.prompt
    else:
        prompt = alpaca_prompt(opts.instruction, opts.input)
        print(f"Prompt:\n{prompt}")

    sampling_options = {
        "max_new_tokens": opts.max_new_tokens,
        "temperature": opts.temperature,
        "top_k": opts.top_k,
    }
    print(generate_text(model, tokenizer, prompt, device=device, **sampling_options))