Spaces:
Sleeping
Sleeping
| import torch | |
| import tiktoken | |
def load_model(model_path, map_location=None):
    """Load a trained GPT model from a state-dict checkpoint.

    Args:
        model_path: Path to the checkpoint file produced by ``torch.save``.
        map_location: Optional device remapping forwarded to ``torch.load``
            (e.g. ``'cpu'`` to load GPU-trained weights on a CPU-only
            machine). ``None`` preserves the original loading behavior.

    Returns:
        The model switched to eval mode, ready for inference.
    """
    # Local imports keep this module importable even when the project
    # package is not on sys.path (matches the original's structure).
    from src.inference import GPT
    from src.utils import GPTConfig

    config = GPTConfig()
    model = GPT(config)
    model.load_state_dict(torch.load(model_path, map_location=map_location))
    model.eval()  # disable dropout/batch-norm training behavior
    return model
def tokenize_input(text):
    """Encode *text* with the GPT-2 BPE tokenizer.

    Returns a LongTensor of shape ``(1, n_tokens)`` — the leading batch
    dimension is what the model's forward pass expects.
    """
    encoder = tiktoken.get_encoding('gpt2')
    token_ids = encoder.encode(text)
    # Wrapping in a one-element list prepends the batch dimension.
    return torch.tensor([token_ids])
def decode_output(tokens):
    """Turn a tensor of GPT-2 token ids back into a text string."""
    gpt2_encoder = tiktoken.get_encoding('gpt2')
    token_ids = tokens.tolist()
    return gpt2_encoder.decode(token_ids)
def generate_text(model, input_text, max_length=30, top_k=50):
    """Autoregressively generate text conditioned on *input_text*.

    Args:
        model: Model whose forward pass returns a tuple whose first element
            is a logits tensor of shape ``(batch, seq, vocab)`` — inferred
            from the ``[0]`` / ``[:, -1, :]`` indexing below; confirm
            against the model's actual return signature.
        input_text: Prompt string to condition generation on.
        max_length: Total token budget (prompt + generated). If the prompt
            already meets it, no tokens are generated.
        top_k: Sample only among the ``top_k`` most probable next tokens.
            Defaults to 50, the value previously hard-coded.

    Returns:
        The decoded text of the full first sequence, prompt included.
    """
    tokens = tokenize_input(input_text)
    # no_grad hoisted out of the loop: nothing inside needs gradients,
    # and re-entering the context every step was pure overhead.
    with torch.no_grad():
        while tokens.size(1) < max_length:
            logits = model(tokens)[0]           # (batch, seq, vocab)
            next_logits = logits[:, -1, :]      # logits for the next position only
            probs = torch.softmax(next_logits, dim=-1)
            # Restrict sampling to the k most probable tokens; multinomial
            # renormalizes the selected probabilities internally.
            topk_probs, topk_indices = torch.topk(probs, top_k, dim=-1)
            sampled = torch.multinomial(topk_probs, 1)
            next_token = torch.gather(topk_indices, -1, sampled)
            tokens = torch.cat((tokens, next_token), dim=1)
    return decode_output(tokens[0])