Update app.py
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import os
|
|
| 2 |
import gc
|
| 3 |
import torch
|
| 4 |
import gradio as gr
|
| 5 |
-
from transformers import
|
| 6 |
|
| 7 |
# =============================
|
| 8 |
# Configuration
|
|
@@ -13,26 +13,24 @@ TEMPERATURE = 0.5
|
|
| 13 |
TOP_K = 50
|
| 14 |
REPETITION_PENALTY = 1.1
|
| 15 |
|
| 16 |
-
# Detect device
|
| 17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
-
print(f"Loading model from {MODEL_PATH} on {device}...")
|
| 19 |
|
| 20 |
# =============================
|
| 21 |
-
# Load
|
| 22 |
# =============================
|
| 23 |
-
tokenizer =
|
| 24 |
-
model =
|
| 25 |
MODEL_PATH,
|
| 26 |
device_map="auto",
|
| 27 |
-
torch_dtype=torch.float16,
|
| 28 |
low_cpu_mem_usage=True
|
| 29 |
)
|
| 30 |
|
| 31 |
-
generator = model.generate
|
| 32 |
print("✅ ChatDoctor model loaded successfully!\n")
|
| 33 |
|
| 34 |
# =============================
|
| 35 |
-
#
|
| 36 |
# =============================
|
| 37 |
class StopOnTokens(StoppingCriteria):
|
| 38 |
def __init__(self, stop_ids):
|
|
@@ -49,132 +47,89 @@ class StopOnTokens(StoppingCriteria):
|
|
| 49 |
return True
|
| 50 |
return False
|
| 51 |
|
|
|
|
| 52 |
# =============================
|
| 53 |
# Medical Keywords and Validation
|
| 54 |
# =============================
|
| 55 |
MEDICAL_KEYWORDS = [
|
| 56 |
-
|
| 57 |
-
"
|
| 58 |
-
|
| 59 |
-
"
|
| 60 |
-
"
|
| 61 |
-
|
| 62 |
-
"heart", "stomach", "head", "back", "chest", "throat", "lung", "kidney",
|
| 63 |
-
"liver", "brain", "skin", "eye", "ear", "nose", "tooth", "teeth", "joint",
|
| 64 |
-
"muscle", "bone", "neck", "shoulder", "knee", "ankle", "foot", "hand",
|
| 65 |
-
# Medical terms
|
| 66 |
-
"doctor", "hospital", "clinic", "emergency", "ambulance", "medication",
|
| 67 |
-
"medicine", "prescription", "diagnosis", "treatment", "therapy", "cure",
|
| 68 |
-
"sick", "ill", "disease", "condition", "disorder", "syndrome",
|
| 69 |
-
# Injuries
|
| 70 |
-
"injury", "wound", "cut", "bruise", "fracture", "sprain", "burn", "bleed",
|
| 71 |
-
# Vitals and tests
|
| 72 |
-
"blood", "pressure", "temperature", "pulse", "breathing", "test", "scan",
|
| 73 |
-
# Mental health
|
| 74 |
-
"stress", "anxiety", "depression", "mental", "sleep", "insomnia", "tired",
|
| 75 |
-
"fatigue", "exhausted", "mood", "panic", "worry",
|
| 76 |
-
# Lifestyle/wellness
|
| 77 |
-
"diet", "nutrition", "exercise", "weight", "vitamin", "supplement", "healthy",
|
| 78 |
-
"wellness", "fitness", "eating", "appetite", "lifestyle", "food", "fruit",
|
| 79 |
-
"vegetable", "meal", "breakfast", "lunch", "dinner", "snack", "drink",
|
| 80 |
-
"water", "hydration", "protein", "carb", "fat", "calorie", "sugar",
|
| 81 |
-
"cholesterol", "gym", "workout", "run", "walk", "yoga", "sport",
|
| 82 |
-
# Serious conditions
|
| 83 |
-
"cancer", "tumor", "surgery", "stroke", "attack", "seizure", "diabetic",
|
| 84 |
-
# Questions about health
|
| 85 |
-
"health", "medical", "feel", "feeling", "comfortable", "uncomfortable",
|
| 86 |
-
"recommendation", "recommend", "advice", "suggest", "should i", "better",
|
| 87 |
-
"improve", "prevent", "avoid", "good for", "bad for"
|
| 88 |
]
|
| 89 |
|
| 90 |
CASUAL_ONLY_PATTERNS = [
|
| 91 |
-
"hey", "hi", "hello", "sup", "
|
| 92 |
-
"
|
| 93 |
-
"how are you", "how r u", "wassup", "hiya", "greetings"
|
| 94 |
]
|
| 95 |
|
|
|
|
| 96 |
def is_medical_query(message):
|
| 97 |
-
"""Check if the message contains medical-related content"""
|
| 98 |
message_lower = message.lower()
|
| 99 |
-
|
| 100 |
-
# Check for medical keywords
|
| 101 |
for keyword in MEDICAL_KEYWORDS:
|
| 102 |
if keyword in message_lower:
|
| 103 |
return True
|
| 104 |
-
|
| 105 |
-
# Check for question words combined with longer messages (might be medical)
|
| 106 |
question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does"]
|
| 107 |
has_question = any(q in message_lower.split()[:3] for q in question_words)
|
| 108 |
-
|
| 109 |
-
# If it has a question word and is longer than 5 words, might be medical
|
| 110 |
if has_question and len(message.split()) > 5:
|
| 111 |
return True
|
| 112 |
-
|
| 113 |
return False
|
| 114 |
|
|
|
|
| 115 |
def is_only_greeting(message):
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
# Remove punctuation for checking
|
| 120 |
-
message_clean = message_lower.replace("!", "").replace("?", "").replace(".", "").strip()
|
| 121 |
-
|
| 122 |
-
# Check if it's a short greeting (3 words or less)
|
| 123 |
-
if len(message_clean.split()) <= 3:
|
| 124 |
for pattern in CASUAL_ONLY_PATTERNS:
|
| 125 |
-
if
|
| 126 |
return True
|
| 127 |
-
|
| 128 |
return False
|
| 129 |
|
|
|
|
| 130 |
# =============================
|
| 131 |
-
# Get Response
|
| 132 |
# =============================
|
| 133 |
def get_response(user_input, history_context):
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
# STRICT FILTERING: Only allow medical queries to reach the model
|
| 137 |
-
if not is_medical_query(user_input):
|
| 138 |
-
return "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and medical topics. I can help you with:\n\n• Symptoms and health concerns\n• Medical conditions and treatments\n• General health advice\n• Wellness and prevention\n\nPlease describe any health-related symptoms or medical questions you have, and I'll do my best to assist you."
|
| 139 |
-
|
| 140 |
-
human_invitation = "Patient: "
|
| 141 |
-
doctor_invitation = "ChatDoctor: "
|
| 142 |
-
|
| 143 |
-
# Enhanced system instruction
|
| 144 |
-
system_instruction = """You are ChatDoctor, a professional medical AI assistant. You ONLY discuss health, medical symptoms, treatments, and wellness topics.
|
| 145 |
-
|
| 146 |
-
If a patient greets you or asks non-medical questions, you must respond professionally: "I'm ChatDoctor, here to help with your health concerns. What medical symptoms or health questions can I assist you with today?"
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
-
# Build
|
| 153 |
history_text = [system_instruction]
|
| 154 |
for human, assistant in history_context:
|
| 155 |
if human:
|
| 156 |
-
history_text.append(
|
| 157 |
if assistant:
|
| 158 |
-
history_text.append(
|
| 159 |
-
|
| 160 |
-
# Add current user input with medical context reinforcement
|
| 161 |
-
if not is_medical_query(user_input):
|
| 162 |
-
user_input = f"{user_input} [Medical consultation context]"
|
| 163 |
-
|
| 164 |
-
history_text.append(human_invitation + user_input)
|
| 165 |
|
| 166 |
-
|
| 167 |
-
prompt = "\n".join(history_text) + "\n" + doctor_invitation
|
| 168 |
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
| 169 |
|
| 170 |
-
# Define stop words and their token IDs
|
| 171 |
stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
|
| 172 |
stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
|
| 173 |
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
|
| 174 |
|
| 175 |
-
# Generate model response
|
| 176 |
with torch.no_grad():
|
| 177 |
-
output_ids =
|
| 178 |
input_ids,
|
| 179 |
max_new_tokens=MAX_NEW_TOKENS,
|
| 180 |
do_sample=True,
|
|
@@ -186,49 +141,30 @@ Now continue the medical consultation:
|
|
| 186 |
eos_token_id=tokenizer.eos_token_id
|
| 187 |
)
|
| 188 |
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
# Remove any "Patient:" that might have slipped through
|
| 194 |
-
for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
|
| 195 |
if stop_word in response:
|
| 196 |
response = response.split(stop_word)[0].strip()
|
| 197 |
break
|
| 198 |
|
| 199 |
response = response.strip()
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
# Model went off-topic, force redirect
|
| 207 |
-
response = "I apologize for any confusion. I'm ChatDoctor, and I'm specifically designed to help with medical and health-related questions. Could you please tell me about any health symptoms or medical concerns you're experiencing?"
|
| 208 |
-
|
| 209 |
-
# Free memory
|
| 210 |
del input_ids, output_ids
|
| 211 |
gc.collect()
|
| 212 |
-
torch.cuda.
|
|
|
|
| 213 |
|
| 214 |
return response
|
| 215 |
|
| 216 |
-
# =============================
|
| 217 |
-
# Gradio Chat Function
|
| 218 |
-
# =============================
|
| 219 |
-
def chat_function(message, history):
|
| 220 |
-
"""Gradio chat interface function"""
|
| 221 |
-
if not message.strip():
|
| 222 |
-
return ""
|
| 223 |
-
|
| 224 |
-
try:
|
| 225 |
-
response = get_response(message, history)
|
| 226 |
-
return response
|
| 227 |
-
except Exception as e:
|
| 228 |
-
return f"Error: {str(e)}"
|
| 229 |
|
| 230 |
# =============================
|
| 231 |
-
#
|
| 232 |
# =============================
|
| 233 |
custom_css = """
|
| 234 |
#header {
|
|
@@ -239,18 +175,8 @@ custom_css = """
|
|
| 239 |
border-radius: 10px;
|
| 240 |
margin-bottom: 20px;
|
| 241 |
}
|
| 242 |
-
|
| 243 |
-
#header
|
| 244 |
-
margin: 0;
|
| 245 |
-
font-size: 2.5em;
|
| 246 |
-
}
|
| 247 |
-
|
| 248 |
-
#header p {
|
| 249 |
-
margin: 10px 0 0 0;
|
| 250 |
-
font-size: 1.1em;
|
| 251 |
-
opacity: 0.9;
|
| 252 |
-
}
|
| 253 |
-
|
| 254 |
.disclaimer {
|
| 255 |
background-color: #fff3cd;
|
| 256 |
border: 1px solid #ffc107;
|
|
@@ -259,153 +185,77 @@ custom_css = """
|
|
| 259 |
margin: 20px 0;
|
| 260 |
color: #856404;
|
| 261 |
}
|
| 262 |
-
|
| 263 |
-
.disclaimer h3 {
|
| 264 |
-
margin-top: 0;
|
| 265 |
-
color: #856404;
|
| 266 |
-
}
|
| 267 |
-
|
| 268 |
-
footer {
|
| 269 |
-
text-align: center;
|
| 270 |
-
margin-top: 30px;
|
| 271 |
-
color: #666;
|
| 272 |
-
font-size: 0.9em;
|
| 273 |
-
}
|
| 274 |
"""
|
| 275 |
|
| 276 |
-
# =============================
|
| 277 |
-
# Gradio Interface
|
| 278 |
-
# =============================
|
| 279 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
| 280 |
-
# Header
|
| 281 |
gr.HTML("""
|
| 282 |
<div id="header">
|
| 283 |
<h1>🩺 ChatDoctor AI Assistant</h1>
|
| 284 |
-
<p>Your AI-powered medical
|
| 285 |
</div>
|
| 286 |
""")
|
| 287 |
-
|
| 288 |
-
# Disclaimer
|
| 289 |
gr.HTML("""
|
| 290 |
<div class="disclaimer">
|
| 291 |
<h3>⚠️ Medical Disclaimer</h3>
|
| 292 |
-
<p
|
| 293 |
-
It is NOT a substitute for professional medical advice, diagnosis, or treatment
|
| 294 |
-
Always seek the advice of your physician or other qualified health provider with any questions
|
| 295 |
-
you may have regarding a medical condition. Never disregard professional medical advice or
|
| 296 |
-
delay in seeking it because of something you have read here.</p>
|
| 297 |
</div>
|
| 298 |
""")
|
| 299 |
-
|
| 300 |
-
# Chatbot Interface
|
| 301 |
chatbot = gr.Chatbot(
|
| 302 |
-
height=
|
| 303 |
-
placeholder="<div style='text-align:
|
| 304 |
show_label=False,
|
| 305 |
avatar_images=(None, "🤖"),
|
| 306 |
)
|
| 307 |
-
|
| 308 |
with gr.Row():
|
| 309 |
-
msg = gr.Textbox(
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
scale=9,
|
| 313 |
-
container=False
|
| 314 |
-
)
|
| 315 |
-
submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
|
| 316 |
-
|
| 317 |
with gr.Row():
|
| 318 |
clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
|
| 319 |
retry_btn = gr.Button("🔄 Retry", scale=1)
|
| 320 |
-
|
| 321 |
-
# Examples
|
| 322 |
-
gr.Examples(
|
| 323 |
-
examples=[
|
| 324 |
-
"I have a persistent headache for 3 days. What should I do?",
|
| 325 |
-
"What are the symptoms of diabetes?",
|
| 326 |
-
"How can I improve my sleep quality?",
|
| 327 |
-
"I have a fever and sore throat. Should I be concerned?",
|
| 328 |
-
"What are some natural ways to reduce stress?",
|
| 329 |
-
],
|
| 330 |
-
inputs=msg,
|
| 331 |
-
label="💡 Example Medical Questions"
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
# Settings (collapsed by default)
|
| 335 |
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
step=0.1,
|
| 341 |
-
label="Temperature (Creativity)",
|
| 342 |
-
info="Higher values make responses more creative but less focused"
|
| 343 |
-
)
|
| 344 |
-
max_tokens_slider = gr.Slider(
|
| 345 |
-
minimum=50,
|
| 346 |
-
maximum=500,
|
| 347 |
-
value=MAX_NEW_TOKENS,
|
| 348 |
-
step=50,
|
| 349 |
-
label="Max Response Length",
|
| 350 |
-
info="Maximum number of tokens in response"
|
| 351 |
-
)
|
| 352 |
-
top_k_slider = gr.Slider(
|
| 353 |
-
minimum=1,
|
| 354 |
-
maximum=100,
|
| 355 |
-
value=TOP_K,
|
| 356 |
-
step=1,
|
| 357 |
-
label="Top K",
|
| 358 |
-
info="Limits vocabulary selection"
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
# Footer
|
| 362 |
-
gr.HTML("""
|
| 363 |
-
<footer>
|
| 364 |
-
<p>Powered by ChatDoctor Model | Built with Gradio</p>
|
| 365 |
-
<p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
|
| 366 |
-
</footer>
|
| 367 |
-
""")
|
| 368 |
-
|
| 369 |
-
# Event handlers
|
| 370 |
def user_message(user_msg, history):
|
| 371 |
return "", history + [[user_msg, None]]
|
| 372 |
-
|
| 373 |
-
def bot_response(history, temp, max_tok,
|
| 374 |
global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
|
| 375 |
-
TEMPERATURE = temp
|
| 376 |
-
MAX_NEW_TOKENS = int(max_tok)
|
| 377 |
-
TOP_K = int(top_k_val)
|
| 378 |
-
|
| 379 |
user_msg = history[-1][0]
|
| 380 |
-
bot_msg =
|
| 381 |
history[-1][1] = bot_msg
|
| 382 |
return history
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 386 |
-
bot_response, [chatbot,
|
| 387 |
)
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
|
| 391 |
)
|
| 392 |
-
|
| 393 |
clear_btn.click(lambda: None, None, chatbot, queue=False)
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
retry_btn.click(retry_last, None, chatbot, queue=False)
|
| 399 |
|
| 400 |
# =============================
|
| 401 |
-
# Launch
|
| 402 |
# =============================
|
| 403 |
if __name__ == "__main__":
|
| 404 |
-
print("\n
|
| 405 |
demo.queue()
|
| 406 |
-
demo.launch(
|
| 407 |
-
server_name="0.0.0.0", # Accessible from network
|
| 408 |
-
server_port=7860,
|
| 409 |
-
share=False, # Set to True to create public link
|
| 410 |
-
show_error=True
|
| 411 |
-
)
|
|
|
|
| 2 |
import gc
|
| 3 |
import torch
|
| 4 |
import gradio as gr
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
|
| 6 |
|
| 7 |
# =============================
|
| 8 |
# Configuration
|
|
|
|
| 13 |
TOP_K = 50
|
| 14 |
REPETITION_PENALTY = 1.1
|
| 15 |
|
|
|
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 17 |
+
print(f"🚀 Loading model from {MODEL_PATH} on {device}...")
|
| 18 |
|
| 19 |
# =============================
|
| 20 |
+
# Load Model & Tokenizer
|
| 21 |
# =============================
|
| 22 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
| 23 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 24 |
MODEL_PATH,
|
| 25 |
device_map="auto",
|
| 26 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 27 |
low_cpu_mem_usage=True
|
| 28 |
)
|
| 29 |
|
|
|
|
| 30 |
print("✅ ChatDoctor model loaded successfully!\n")
|
| 31 |
|
| 32 |
# =============================
|
| 33 |
+
# Stop Criteria
|
| 34 |
# =============================
|
| 35 |
class StopOnTokens(StoppingCriteria):
|
| 36 |
def __init__(self, stop_ids):
|
|
|
|
| 47 |
return True
|
| 48 |
return False
|
| 49 |
|
| 50 |
+
|
| 51 |
# =============================
|
| 52 |
# Medical Keywords and Validation
|
| 53 |
# =============================
|
| 54 |
MEDICAL_KEYWORDS = [
|
| 55 |
+
"pain", "ache", "symptom", "hurt", "sore", "discomfort", "fever", "cough", "flu",
|
| 56 |
+
"infection", "allergy", "diabetes", "pressure", "asthma", "migraine", "vomit",
|
| 57 |
+
"stomach", "head", "chest", "throat", "heart", "lung", "liver", "kidney", "brain",
|
| 58 |
+
"doctor", "hospital", "medicine", "treatment", "therapy", "surgery", "disease",
|
| 59 |
+
"illness", "blood", "test", "scan", "health", "diet", "nutrition", "stress", "sleep",
|
| 60 |
+
"weight", "vitamin", "fatigue", "anxiety", "depression"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
]
|
| 62 |
|
| 63 |
CASUAL_ONLY_PATTERNS = [
|
| 64 |
+
"hey", "hi", "hello", "sup", "yo", "good morning", "good evening",
|
| 65 |
+
"how are you", "wassup", "hiya"
|
|
|
|
| 66 |
]
|
| 67 |
|
| 68 |
+
|
| 69 |
def is_medical_query(message):
|
|
|
|
| 70 |
message_lower = message.lower()
|
|
|
|
|
|
|
| 71 |
for keyword in MEDICAL_KEYWORDS:
|
| 72 |
if keyword in message_lower:
|
| 73 |
return True
|
|
|
|
|
|
|
| 74 |
question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does"]
|
| 75 |
has_question = any(q in message_lower.split()[:3] for q in question_words)
|
|
|
|
|
|
|
| 76 |
if has_question and len(message.split()) > 5:
|
| 77 |
return True
|
|
|
|
| 78 |
return False
|
| 79 |
|
| 80 |
+
|
| 81 |
def is_only_greeting(message):
|
| 82 |
+
message_lower = message.lower().strip().replace("!", "").replace("?", "").replace(".", "")
|
| 83 |
+
if len(message_lower.split()) <= 3:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
for pattern in CASUAL_ONLY_PATTERNS:
|
| 85 |
+
if message_lower == pattern or message_lower.startswith(pattern):
|
| 86 |
return True
|
|
|
|
| 87 |
return False
|
| 88 |
|
| 89 |
+
|
| 90 |
# =============================
|
| 91 |
+
# Get Response
|
| 92 |
# =============================
|
| 93 |
def get_response(user_input, history_context):
|
| 94 |
+
if is_only_greeting(user_input):
|
| 95 |
+
return "👋 Hello! I'm ChatDoctor — your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
+
if not is_medical_query(user_input):
|
| 98 |
+
return (
|
| 99 |
+
"Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
|
| 100 |
+
"I can help you with:\n"
|
| 101 |
+
"• Symptoms and medical conditions\n"
|
| 102 |
+
"• Treatment and prevention advice\n"
|
| 103 |
+
"• Fitness, diet, and mental health tips\n\n"
|
| 104 |
+
"Please describe your health concern in detail to get started."
|
| 105 |
+
)
|
| 106 |
|
| 107 |
+
human_prefix = "Patient:"
|
| 108 |
+
doctor_prefix = "ChatDoctor:"
|
| 109 |
+
system_instruction = (
|
| 110 |
+
"You are ChatDoctor, a professional medical AI assistant. "
|
| 111 |
+
"You provide accurate, concise, and empathetic responses to health-related questions only.\n\n"
|
| 112 |
+
"If the question is non-medical, politely redirect back to medical topics.\n"
|
| 113 |
+
)
|
| 114 |
|
| 115 |
+
# Build history
|
| 116 |
history_text = [system_instruction]
|
| 117 |
for human, assistant in history_context:
|
| 118 |
if human:
|
| 119 |
+
history_text.append(f"{human_prefix} {human}")
|
| 120 |
if assistant:
|
| 121 |
+
history_text.append(f"{doctor_prefix} {assistant}")
|
| 122 |
+
history_text.append(f"{human_prefix} {user_input}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
prompt = "\n".join(history_text) + f"\n{doctor_prefix} "
|
|
|
|
| 125 |
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
|
| 126 |
|
|
|
|
| 127 |
stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
|
| 128 |
stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
|
| 129 |
stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
|
| 130 |
|
|
|
|
| 131 |
with torch.no_grad():
|
| 132 |
+
output_ids = model.generate(
|
| 133 |
input_ids,
|
| 134 |
max_new_tokens=MAX_NEW_TOKENS,
|
| 135 |
do_sample=True,
|
|
|
|
| 141 |
eos_token_id=tokenizer.eos_token_id
|
| 142 |
)
|
| 143 |
|
| 144 |
+
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
|
| 145 |
+
|
| 146 |
+
for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
|
|
|
|
|
|
|
|
|
|
| 147 |
if stop_word in response:
|
| 148 |
response = response.split(stop_word)[0].strip()
|
| 149 |
break
|
| 150 |
|
| 151 |
response = response.strip()
|
| 152 |
+
if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud"]):
|
| 153 |
+
response = (
|
| 154 |
+
"I apologize for the confusion — I'm ChatDoctor, trained to assist with medical and health-related topics only. "
|
| 155 |
+
"Please tell me about your symptoms or health concerns."
|
| 156 |
+
)
|
| 157 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
del input_ids, output_ids
|
| 159 |
gc.collect()
|
| 160 |
+
if torch.cuda.is_available():
|
| 161 |
+
torch.cuda.empty_cache()
|
| 162 |
|
| 163 |
return response
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
# =============================
|
| 167 |
+
# Gradio Interface
|
| 168 |
# =============================
|
| 169 |
custom_css = """
|
| 170 |
#header {
|
|
|
|
| 175 |
border-radius: 10px;
|
| 176 |
margin-bottom: 20px;
|
| 177 |
}
|
| 178 |
+
#header h1 { margin: 0; font-size: 2.3em; }
|
| 179 |
+
#header p { margin: 5px 0 0; font-size: 1em; opacity: 0.9; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
.disclaimer {
|
| 181 |
background-color: #fff3cd;
|
| 182 |
border: 1px solid #ffc107;
|
|
|
|
| 185 |
margin: 20px 0;
|
| 186 |
color: #856404;
|
| 187 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
"""
|
| 189 |
|
|
|
|
|
|
|
|
|
|
| 190 |
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
|
|
|
|
| 191 |
gr.HTML("""
|
| 192 |
<div id="header">
|
| 193 |
<h1>🩺 ChatDoctor AI Assistant</h1>
|
| 194 |
+
<p>Your AI-powered medical consultation partner</p>
|
| 195 |
</div>
|
| 196 |
""")
|
|
|
|
|
|
|
| 197 |
gr.HTML("""
|
| 198 |
<div class="disclaimer">
|
| 199 |
<h3>⚠️ Medical Disclaimer</h3>
|
| 200 |
+
<p>This AI assistant is for informational purposes only.
|
| 201 |
+
It is NOT a substitute for professional medical advice, diagnosis, or treatment.</p>
|
|
|
|
|
|
|
|
|
|
| 202 |
</div>
|
| 203 |
""")
|
| 204 |
+
|
|
|
|
| 205 |
chatbot = gr.Chatbot(
|
| 206 |
+
height=480,
|
| 207 |
+
placeholder="<div style='text-align:center;padding:40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>Describe your symptoms or ask a health-related question to begin.</p></div>",
|
| 208 |
show_label=False,
|
| 209 |
avatar_images=(None, "🤖"),
|
| 210 |
)
|
| 211 |
+
|
| 212 |
with gr.Row():
|
| 213 |
+
msg = gr.Textbox(placeholder="Type your medical concern here...", show_label=False, scale=9, container=False)
|
| 214 |
+
send_btn = gr.Button("Send 📤", scale=1, variant="primary")
|
| 215 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
with gr.Row():
|
| 217 |
clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
|
| 218 |
retry_btn = gr.Button("🔄 Retry", scale=1)
|
| 219 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
with gr.Accordion("⚙️ Advanced Settings", open=False):
|
| 221 |
+
temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature")
|
| 222 |
+
max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
|
| 223 |
+
top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K")
|
| 224 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
def user_message(user_msg, history):
|
| 226 |
return "", history + [[user_msg, None]]
|
| 227 |
+
|
| 228 |
+
def bot_response(history, temp, max_tok, topk):
|
| 229 |
global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
|
| 230 |
+
TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
|
|
|
|
|
|
|
|
|
|
| 231 |
user_msg = history[-1][0]
|
| 232 |
+
bot_msg = get_response(user_msg, history[:-1])
|
| 233 |
history[-1][1] = bot_msg
|
| 234 |
return history
|
| 235 |
+
|
| 236 |
+
def retry_last(history, temp, max_tok, topk):
|
| 237 |
+
if not history:
|
| 238 |
+
return history
|
| 239 |
+
user_msg = history[-1][0]
|
| 240 |
+
bot_msg = get_response(user_msg, history[:-1])
|
| 241 |
+
history[-1][1] = bot_msg
|
| 242 |
+
return history
|
| 243 |
+
|
| 244 |
msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 245 |
+
bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot
|
| 246 |
)
|
| 247 |
+
send_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
|
| 248 |
+
bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot
|
|
|
|
| 249 |
)
|
|
|
|
| 250 |
clear_btn.click(lambda: None, None, chatbot, queue=False)
|
| 251 |
+
retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot)
|
| 252 |
+
|
| 253 |
+
gr.HTML(f"<footer><center><p>🧠 Powered by LLaMA-based ChatDoctor | Device: {device.upper()}</p></center></footer>")
|
|
|
|
|
|
|
| 254 |
|
| 255 |
# =============================
|
| 256 |
+
# Launch App
|
| 257 |
# =============================
|
| 258 |
if __name__ == "__main__":
|
| 259 |
+
print("\n💡 Launching ChatDoctor Gradio Interface...")
|
| 260 |
demo.queue()
|
| 261 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|