Bman21 commited on
Commit
6661059
·
verified ·
1 Parent(s): 51d222f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+
4
+ # -------------------
5
+ # CONFIG
6
+ # -------------------
7
+ MODEL_NAME = "Varshitha/flan-t5-small-finetuned-medicine" # Medical Flan-T5
8
+ DEVICE = "cpu"
9
+ MAX_TOKENS = 256
10
+
11
+ SYSTEM_MESSAGE = (
12
+ "You are a helpful and accurate medical tutor. "
13
+ "Explain clearly and safely, using standard medical knowledge. "
14
+ "If unsure, say you don't have enough information."
15
+ )
16
+
17
+ # -------------------
18
+ # LOAD MODEL
19
+ # -------------------
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
22
+ generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=-1)
23
+
24
+ # -------------------
25
+ # CHAT FUNCTION
26
+ # -------------------
27
+ def respond(message, history):
28
+ if not message.strip():
29
+ return "Please enter a medical question."
30
+
31
+ prompt = f"{SYSTEM_MESSAGE}\n\nQuestion: {message}\nAnswer:"
32
+ response = generator(prompt, max_new_tokens=MAX_TOKENS)[0]["generated_text"]
33
+ return response
34
+
35
+ # -------------------
36
+ # GRADIO APP
37
+ # -------------------
38
+ with gr.Blocks(title="Simple Medical Tutor") as demo:
39
+ gr.Markdown("# 🩺 Simple Medical Tutor (CPU-Friendly)")
40
+ chatbot = gr.Chatbot(label="Tutor Chat")
41
+ msg = gr.Textbox(label="Ask a medical question...")
42
+ clear = gr.Button("Clear Chat")
43
+
44
+ def user_input(message, history):
45
+ bot_reply = respond(message, history)
46
+ history.append((message, bot_reply))
47
+ return "", history
48
+
49
+ msg.submit(user_input, [msg, chatbot], [msg, chatbot])
50
+ clear.click(lambda: None, None, chatbot, queue=False)
51
+
52
+ demo.launch()