import torch
import tiktoken


def load_model(model_path):
    """Load the trained model from the specified path."""
    from src.inference import GPT
    from src.utils import GPTConfig

    config = GPTConfig()
    model = GPT(config)
    # map_location keeps the checkpoint loadable on CPU-only machines
    # even if the state_dict was saved from a GPU run.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    return model


def tokenize_input(text):
    """Tokenize the input text using the GPT-2 tokenizer."""
    enc = tiktoken.get_encoding("gpt2")
    tokens = enc.encode(text)
    return torch.tensor(tokens).unsqueeze(0)  # add batch dimension


def decode_output(tokens):
    """Decode the generated tokens back to text."""
    enc = tiktoken.get_encoding("gpt2")
    return enc.decode(tokens.tolist())


def generate_text(model, input_text, max_length=30):
    """Generate text with the trained model, continuing from the input text."""
    input_tokens = tokenize_input(input_text)
    generated_tokens = input_tokens
    while generated_tokens.size(1) < max_length:
        with torch.no_grad():
            logits = model(generated_tokens)[0]  # model's forward returns (logits, ...)
            logits = logits[:, -1, :]            # keep logits for the last position only
            probs = torch.softmax(logits, dim=-1)
            # Top-k sampling (k=50): sample only from the 50 most likely tokens.
            # torch.multinomial accepts unnormalized weights, so no renormalization is needed.
            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
            ix = torch.multinomial(topk_probs, 1)
            xcol = torch.gather(topk_indices, -1, ix)  # map sampled top-k index back to a vocab id
            generated_tokens = torch.cat((generated_tokens, xcol), dim=1)
    return decode_output(generated_tokens[0])  # decode the first (and only) sequence in the batch
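

# Example usage, as a minimal sketch: the checkpoint path below is hypothetical
# and should point to a state_dict produced by the training run.
if __name__ == "__main__":
    model = load_model("checkpoints/model.pt")  # hypothetical checkpoint path
    print(generate_text(model, "Hello, I'm a language model,", max_length=30))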