File size: 1,295 Bytes
c017df9
 
 
 
57f4322
 
 
 
 
0df907f
 
 
57f4322
 
0df907f
57f4322
 
 
0df907f
57f4322
 
0df907f
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

class MedChat:
    """Question-answering wrapper around a GPT-2 medical chatbot checkpoint.

    Loads a GPT-2 tokenizer and causal LM from ``path`` (Hugging Face Hub id or
    local directory), moves the model to ``device``, and answers one question
    per :meth:`forward` call using beam search.
    """

    DEFAULT_PATH = "jianghc/medical_chatbot"

    # Chat template for the checkpoint. ``{input}`` is filled with the user's
    # question. The previous template had no placeholder, so format_map() was a
    # no-op and the question was silently dropped from every prompt.
    PROMPT_TEMPLATE = (
        "The conversation between human and AI assistant.\n"
        "[|Human|] {input}\n"
        "[|AI|]"
    )

    # Token id 198 is "\n" in the stock GPT-2 BPE vocabulary. Using it as EOS
    # stops generation at the end of the AI's line.
    # NOTE(review): assumes the checkpoint keeps the stock GPT-2 vocab.
    NEWLINE_TOKEN_ID = 198

    def __init__(self, path=DEFAULT_PATH, device=None):
        """Load tokenizer and model.

        Args:
            path: Hub model id or local directory passed to ``from_pretrained``.
            device: Torch device string. Defaults to ``"cuda"`` when available,
                otherwise ``"cpu"``.
        """
        self.path = path
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.path)
        self.model = GPT2LMHeadModel.from_pretrained(self.path).to(self.device)

    @classmethod
    def _build_prompt(cls, question):
        """Insert ``question`` (converted with ``str()``) into the chat template."""
        # format_map substitutes the value verbatim; braces inside the question
        # are not interpreted as further placeholders.
        return cls.PROMPT_TEMPLATE.format_map({"input": f"{question}"})

    def forward(self, question, **generate_kwargs):
        """Generate an answer to ``question``.

        Args:
            question: The user's question; converted with ``str()`` before being
                inserted into the prompt.
            **generate_kwargs: Optional overrides for ``model.generate`` (e.g.
                ``num_beams=1``). They replace the defaults below.

        Returns:
            The decoded best beam with special tokens removed. This is the full
            sequence: the prompt text followed by the generated answer.
        """
        prompt = self._build_prompt(question)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generation_args = {
            "min_new_tokens": 1,
            "max_length": 512,  # total tokens: prompt + answer
            "num_beams": 3,
            "repetition_penalty": 1.2,
            "early_stopping": True,
            "eos_token_id": self.NEWLINE_TOKEN_ID,
        }
        generation_args.update(generate_kwargs)
        with torch.no_grad():
            beam_output = self.model.generate(**inputs, **generation_args)
        return self.tokenizer.decode(beam_output[0], skip_special_tokens=True)