Spaces:
Sleeping
Sleeping
File size: 838 Bytes
861a64c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
class GPT2Inference:
    """Wrap a GPT-2 language-model checkpoint for prompt-based text generation."""

    def __init__(self, model_path, tokenizer_name_or_path="gpt2"):
        """Load the tokenizer and model, and switch the model to evaluation mode.

        Args:
            model_path: Forwarded to ``GPT2LMHeadModel.from_pretrained`` (a hub
                model id or a directory written by ``save_pretrained``).
            tokenizer_name_or_path: Forwarded to ``GPT2Tokenizer.from_pretrained``.
                Defaults to the stock ``"gpt2"`` vocabulary (the previous
                hard-coded value); override it when the checkpoint was trained
                with a different tokenizer.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name_or_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        # eval() turns off training-mode behaviour such as dropout before generating.
        self.model.eval()

    def generate_text(self, prompt, max_length=30, num_return_sequences=1, **generate_kwargs):
        """Generate continuations of ``prompt``.

        Args:
            prompt: Input text to continue.
            max_length: Maximum total length in tokens, forwarded to
                ``model.generate`` (in transformers this counts the prompt tokens
                as well as the generated ones).
            num_return_sequences: Number of sequences to return. Under the default
                greedy decoding transformers only allows 1; pass e.g.
                ``do_sample=True`` or ``num_beams=...`` to request more.
            **generate_kwargs: Extra keyword arguments forwarded unchanged to
                ``model.generate`` (e.g. ``do_sample``, ``top_k``, ``num_beams``).

        Returns:
            A list of decoded strings, one per returned sequence, with special
            tokens removed.
        """
        encoded = self.tokenizer(prompt, return_tensors="pt")

        # GPT-2's tokenizer defines no pad token. generate() would fall back to
        # EOS on its own but warns on every call, so make that choice explicit
        # (unless the caller supplied one).
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id
        generate_kwargs.setdefault("pad_token_id", pad_token_id)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=encoded["input_ids"],
                # Passing the mask explicitly avoids generate() having to infer it.
                attention_mask=encoded["attention_mask"],
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                **generate_kwargs,
            )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
def load_model(model_path='models/best_model.pt'):
    """Build a ``GPT2Inference`` for the given checkpoint.

    Args:
        model_path: Forwarded to ``GPT2Inference``. Defaults to
            ``'models/best_model.pt'``, the previously hard-coded path.

    Returns:
        A ``GPT2Inference`` instance wrapping the loaded model.
    """
    # NOTE(review): GPT2LMHeadModel.from_pretrained expects a hub model id or a
    # directory written by save_pretrained(); a single '.pt' file path may not
    # load. Confirm how this checkpoint was saved.
    inference_model = GPT2Inference(model_path)
    return inference_model
# Built once at import time: importing this module loads the model.
inference_model = load_model()