File size: 838 Bytes
861a64c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

class GPT2Inference:
    """Text-generation wrapper around a GPT-2 language-model head.

    Pairs the stock ``"gpt2"`` tokenizer with model weights loaded from
    ``model_path`` and keeps the model in evaluation mode.
    """

    def __init__(self, model_path):
        """Load the tokenizer and the model.

        Args:
            model_path: Passed unchanged to
                ``GPT2LMHeadModel.from_pretrained``.
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        # This object is used for inference only.
        self.model.eval()

    def generate_text(self, prompt, max_length=30, num_return_sequences=1):
        """Generate one or more continuations of ``prompt``.

        Args:
            prompt: Text to continue. If it encodes to zero tokens, generation
                is seeded with the tokenizer's BOS token instead.
            max_length: Forwarded to ``generate`` as ``max_length``.
            num_return_sequences: Number of sequences to return. Values above
                1 switch sampling on, because ``generate`` refuses to return
                several sequences under plain greedy decoding.

        Returns:
            A list of decoded strings (special tokens skipped), one per
            generated sequence.
        """
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
        if input_ids.numel() == 0:
            # An empty prompt gives the model nothing to condition on.
            input_ids = torch.tensor(
                [[self.tokenizer.bos_token_id]], dtype=torch.long
            )

        generation_kwargs = {
            'max_length': max_length,
            'num_return_sequences': num_return_sequences,
            # A single unpadded prompt: every position is a real token.
            'attention_mask': torch.ones_like(input_ids),
            # GPT-2 defines no pad token; name EOS explicitly so generate()
            # does not have to guess (and warn) on every call.
            'pad_token_id': self.tokenizer.eos_token_id,
        }
        if num_return_sequences > 1:
            # Only added when required, so the single-sequence path keeps
            # whatever decoding strategy the model is configured with.
            generation_kwargs['do_sample'] = True

        with torch.no_grad():
            outputs = self.model.generate(input_ids, **generation_kwargs)
        return [
            self.tokenizer.decode(output, skip_special_tokens=True)
            for output in outputs
        ]

def load_model(model_path='models/best_model.pt'):
    """Build a ``GPT2Inference`` for the given checkpoint path.

    Args:
        model_path: Location handed to ``GPT2Inference``. Defaults to the
            path this function previously hard-coded.

    Returns:
        The constructed ``GPT2Inference`` instance.
    """
    # NOTE(review): GPT2Inference forwards this path to
    # GPT2LMHeadModel.from_pretrained, which is documented to take a model id
    # or a directory saved with save_pretrained(). A single ``.pt`` file looks
    # suspicious here -- confirm how the checkpoint was written.
    return GPT2Inference(model_path)

# Module-level instance: load_model() runs as soon as this module is executed
# or imported, so the model is loaded at that point.
inference_model = load_model()