zizzimars commited on
Commit
c017df9
·
verified ·
1 Parent(s): ae7a494

Create medchat

Browse files
Files changed (1) hide show
  1. medchat +28 -0
medchat ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
3
+
4
+ class MedChat:
5
+ def __init__(self):
6
+ self.path = "jianghc/medical_chatbot"
7
+ def __call__(self):
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ tokenizer = GPT2Tokenizer.from_pretrained(path)
10
+ model = GPT2LMHeadModel.from_pretrained(path).to(device)
11
+ prompt_input = (
12
+ "The conversation between human and AI assistant.\n"
13
+ "[|Human|] {input}\n"
14
+ "[|AI|]"
15
+ )
16
+ sentence = prompt_input.format_map({'input': "what is parkinson's disease?"})
17
+ inputs = tokenizer(sentence, return_tensors="pt").to(device)
18
+
19
+ with torch.no_grad():
20
+ beam_output = model.generate(**inputs,
21
+ min_new_tokens=1,
22
+ max_length=512,
23
+ num_beams=3,
24
+ repetition_penalty=1.2,
25
+ early_stopping=True,
26
+ eos_token_id=198
27
+ )
28
+ print(tokenizer.decode(beam_output[0], skip_special_tokens=True))