Spaces:
Sleeping
Sleeping
| from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
| import torch | |
class GPT2Inference:
    """Thin inference wrapper around a fine-tuned GPT-2 checkpoint.

    Loads the tokenizer and model once in ``__init__`` and exposes a
    single ``generate_text`` method that returns decoded strings.
    """

    def __init__(self, model_path, tokenizer_name="gpt2"):
        """Load tokenizer and model weights.

        Args:
            model_path: Checkpoint path/identifier passed to
                ``GPT2LMHeadModel.from_pretrained``.
            tokenizer_name: Tokenizer to load; defaults to the stock
                ``"gpt2"`` vocabulary (matches the original behavior).
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name)
        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        # Disable dropout etc. — this object is inference-only.
        self.model.eval()

    def generate_text(self, prompt, max_length=30, num_return_sequences=1):
        """Generate continuations of ``prompt``.

        Args:
            prompt: Input text to continue.
            max_length: Total token budget (prompt + generated tokens).
            num_return_sequences: Number of continuations to return.

        Returns:
            List of ``num_return_sequences`` decoded strings.
        """
        # Calling the tokenizer (rather than .encode) also yields the
        # attention mask, which generate() needs for correct masking.
        encoded = self.tokenizer(prompt, return_tensors='pt')
        with torch.no_grad():
            outputs = self.model.generate(
                encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                # GPT-2 defines no pad token; reuse EOS so generate()
                # can pad batched outputs instead of warning/failing.
                pad_token_id=self.tokenizer.eos_token_id,
                # Greedy decoding yields exactly one distinct sequence;
                # asking for several without sampling raises ValueError
                # in transformers, so enable sampling only in that case.
                do_sample=num_return_sequences > 1,
            )
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]
def load_model(model_path='models/best_model.pt'):
    """Construct a ``GPT2Inference`` for the given checkpoint.

    Args:
        model_path: Checkpoint location; defaults to the path the
            original hard-coded value used, so existing callers of
            ``load_model()`` are unaffected.

    Returns:
        A ready-to-use ``GPT2Inference`` instance.

    NOTE(review): ``from_pretrained`` expects a directory containing
    ``config.json`` plus weight files, not a bare ``.pt`` file —
    confirm the saved-checkpoint layout matches this path.
    """
    return GPT2Inference(model_path)
| inference_model = load_model() |