zizzimars committed on
Commit
57f4322
·
verified ·
1 Parent(s): 5f56892

Update medchat.py

Browse files
Files changed (1) hide show
  1. medchat.py +26 -23
medchat.py CHANGED
@@ -2,27 +2,30 @@ import torch
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
4
  class MedChat:
5
- def __init__(self):
6
- self.path = "jianghc/medical_chatbot"
7
- def __call__(self):
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
- tokenizer = GPT2Tokenizer.from_pretrained(path)
10
- model = GPT2LMHeadModel.from_pretrained(path).to(device)
11
- prompt_input = (
12
- "The conversation between human and AI assistant.\n"
13
- "[|Human|] {input}\n"
14
- "[|AI|]"
15
- )
16
- sentence = prompt_input.format_map({'input': "what is parkinson's disease?"})
17
- inputs = tokenizer(sentence, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- with torch.no_grad():
20
- beam_output = model.generate(**inputs,
21
- min_new_tokens=1,
22
- max_length=512,
23
- num_beams=3,
24
- repetition_penalty=1.2,
25
- early_stopping=True,
26
- eos_token_id=198
27
- )
28
- print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
 
2
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
 
class MedChat:
    """Question-answering wrapper around the ``jianghc/medical_chatbot`` GPT-2 model.

    The tokenizer and model are loaded once at construction time and reused
    for every call, so build one instance and call it repeatedly.
    """

    def __init__(self):
        """Load the tokenizer and model, placing the model on GPU when available."""
        self.path = "jianghc/medical_chatbot"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.path)
        self.model = GPT2LMHeadModel.from_pretrained(self.path).to(self.device)

    def prompt(self, input):
        """Wrap a user message in the conversation template.

        Args:
            input: The human turn of the conversation. It is interpolated
                with an f-string, so non-string values are converted with
                ``str()``.

        Returns:
            The full prompt text, ending with the ``[|AI|]`` marker that
            cues the model to produce the assistant's reply.
        """
        # The parameter name shadows the builtin ``input``; it is kept
        # because it is part of the method's existing signature.
        return (
            "The conversation between human and AI assistant.\n"
            f"[|Human|] {input}\n"
            "[|AI|]"
        )

    def __call__(self, question):
        """Generate a reply to ``question`` with beam search.

        Args:
            question: The user's question, inserted into the conversation
                template by :meth:`prompt`.

        Returns:
            The decoded text of the first sequence returned by
            ``generate``, with special tokens stripped.
        """
        sentence = self.prompt(question)
        inputs = self.tokenizer(sentence, return_tensors="pt").to(self.device)
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            beam_output = self.model.generate(
                **inputs,
                min_new_tokens=1,
                max_length=512,
                num_beams=3,
                repetition_penalty=1.2,
                early_stopping=True,
                # NOTE(review): 198 is presumably the newline token id in the
                # GPT-2 vocabulary, which would end generation at the close
                # of the AI turn - confirm against the tokenizer.
                eos_token_id=198,
            )
        return self.tokenizer.decode(beam_output[0], skip_special_tokens=True)
30
+
31