# ShakespeareGPT — src/utils.py
import torch
import tiktoken


def load_model(model_path):
    """Load the trained model weights from the specified checkpoint path."""
    # NOTE: the original imported GPTConfig from src.utils (this module), where
    # it is not defined; GPT and GPTConfig are assumed to live in src.inference.
    from src.inference import GPT, GPTConfig

    config = GPTConfig()
    model = GPT(config)
    # map_location="cpu" keeps the checkpoint loadable on machines without a GPU
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # inference mode: disables dropout
    return model


def tokenize_input(text):
    """Tokenize the input text using the GPT-2 BPE tokenizer."""
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(text)
    return torch.tensor(tokens).unsqueeze(0)  # add batch dimension: shape (1, T)


def decode_output(tokens):
    """Decode a 1-D tensor of generated token ids back to text."""
    enc = tiktoken.get_encoding("gpt2")
    return enc.decode(tokens.tolist())
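

# Illustrative round trip (a sketch, not part of the original module):
# encoding with tokenize_input and then decoding should reproduce the input.
# >>> decode_output(tokenize_input("To be, or not to be")[0])
# 'To be, or not to be'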


def generate_text(model, input_text, max_length=30):
    """Generate text from input_text via top-k sampling, stopping once the
    full sequence (prompt tokens included) reaches max_length tokens."""
    input_tokens = tokenize_input(input_text)
    generated_tokens = input_tokens
    while generated_tokens.size(1) < max_length:
        with torch.no_grad():
            # the model is assumed to return (logits, loss); keep the logits
            logits = model(generated_tokens)[0]
        logits = logits[:, -1, :]  # logits at the last position only
        probs = torch.softmax(logits, dim=-1)
        # top-k sampling: consider only the 50 most likely next tokens
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        ix = torch.multinomial(topk_probs, 1)  # sample an index within the top-k
        xcol = torch.gather(topk_indices, -1, ix)  # map back to vocabulary ids
        generated_tokens = torch.cat((generated_tokens, xcol), dim=1)
    return decode_output(generated_tokens[0])  # decode the first (only) sequence
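

# Minimal usage sketch. The checkpoint filename "model.pt" is an assumption
# about this repo's layout, not a path confirmed by the source.
if __name__ == "__main__":
    model = load_model("model.pt")
    print(generate_text(model, "O, what a rogue", max_length=50))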