Muhammadidrees committed on
Commit
b73171b
·
verified ·
1 Parent(s): 9b97ee3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -248
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import gc
3
  import torch
4
  import gradio as gr
5
- from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
 
7
  # =============================
8
  # Configuration
@@ -13,26 +13,24 @@ TEMPERATURE = 0.5
13
  TOP_K = 50
14
  REPETITION_PENALTY = 1.1
15
 
16
- # Detect device
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"Loading model from {MODEL_PATH} on {device}...")
19
 
20
  # =============================
21
- # Load Tokenizer and Model
22
  # =============================
23
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
- model = LlamaForCausalLM.from_pretrained(
25
  MODEL_PATH,
26
  device_map="auto",
27
- torch_dtype=torch.float16,
28
  low_cpu_mem_usage=True
29
  )
30
 
31
- generator = model.generate
32
  print("✅ ChatDoctor model loaded successfully!\n")
33
 
34
  # =============================
35
- # Stopping Criteria
36
  # =============================
37
  class StopOnTokens(StoppingCriteria):
38
  def __init__(self, stop_ids):
@@ -49,132 +47,89 @@ class StopOnTokens(StoppingCriteria):
49
  return True
50
  return False
51
 
 
52
  # =============================
53
  # Medical Keywords and Validation
54
  # =============================
55
  MEDICAL_KEYWORDS = [
56
- # Symptoms
57
- "pain", "ache", "symptom", "hurt", "sore", "discomfort", "suffering",
58
- # Common conditions
59
- "fever", "cough", "cold", "flu", "infection", "allergy", "diabetes", "pressure",
60
- "asthma", "migraine", "nausea", "vomit", "diarrhea", "constipation",
61
- # Body parts
62
- "heart", "stomach", "head", "back", "chest", "throat", "lung", "kidney",
63
- "liver", "brain", "skin", "eye", "ear", "nose", "tooth", "teeth", "joint",
64
- "muscle", "bone", "neck", "shoulder", "knee", "ankle", "foot", "hand",
65
- # Medical terms
66
- "doctor", "hospital", "clinic", "emergency", "ambulance", "medication",
67
- "medicine", "prescription", "diagnosis", "treatment", "therapy", "cure",
68
- "sick", "ill", "disease", "condition", "disorder", "syndrome",
69
- # Injuries
70
- "injury", "wound", "cut", "bruise", "fracture", "sprain", "burn", "bleed",
71
- # Vitals and tests
72
- "blood", "pressure", "temperature", "pulse", "breathing", "test", "scan",
73
- # Mental health
74
- "stress", "anxiety", "depression", "mental", "sleep", "insomnia", "tired",
75
- "fatigue", "exhausted", "mood", "panic", "worry",
76
- # Lifestyle/wellness
77
- "diet", "nutrition", "exercise", "weight", "vitamin", "supplement", "healthy",
78
- "wellness", "fitness", "eating", "appetite", "lifestyle", "food", "fruit",
79
- "vegetable", "meal", "breakfast", "lunch", "dinner", "snack", "drink",
80
- "water", "hydration", "protein", "carb", "fat", "calorie", "sugar",
81
- "cholesterol", "gym", "workout", "run", "walk", "yoga", "sport",
82
- # Serious conditions
83
- "cancer", "tumor", "surgery", "stroke", "attack", "seizure", "diabetic",
84
- # Questions about health
85
- "health", "medical", "feel", "feeling", "comfortable", "uncomfortable",
86
- "recommendation", "recommend", "advice", "suggest", "should i", "better",
87
- "improve", "prevent", "avoid", "good for", "bad for"
88
  ]
89
 
90
  CASUAL_ONLY_PATTERNS = [
91
- "hey", "hi", "hello", "sup", "what's up", "whats up", "yo",
92
- "good morning", "good evening", "good afternoon", "good night",
93
- "how are you", "how r u", "wassup", "hiya", "greetings"
94
  ]
95
 
 
96
  def is_medical_query(message):
97
- """Check if the message contains medical-related content"""
98
  message_lower = message.lower()
99
-
100
- # Check for medical keywords
101
  for keyword in MEDICAL_KEYWORDS:
102
  if keyword in message_lower:
103
  return True
104
-
105
- # Check for question words combined with longer messages (might be medical)
106
  question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does"]
107
  has_question = any(q in message_lower.split()[:3] for q in question_words)
108
-
109
- # If it has a question word and is longer than 5 words, might be medical
110
  if has_question and len(message.split()) > 5:
111
  return True
112
-
113
  return False
114
 
 
115
  def is_only_greeting(message):
116
- """Check if message is ONLY a casual greeting with no medical content"""
117
- message_lower = message.lower().strip()
118
-
119
- # Remove punctuation for checking
120
- message_clean = message_lower.replace("!", "").replace("?", "").replace(".", "").strip()
121
-
122
- # Check if it's a short greeting (3 words or less)
123
- if len(message_clean.split()) <= 3:
124
  for pattern in CASUAL_ONLY_PATTERNS:
125
- if message_clean == pattern or message_clean.startswith(pattern):
126
  return True
127
-
128
  return False
129
 
 
130
  # =============================
131
- # Get Response Function
132
  # =============================
133
  def get_response(user_input, history_context):
134
- """Generate response from ChatDoctor model"""
135
-
136
- # STRICT FILTERING: Only allow medical queries to reach the model
137
- if not is_medical_query(user_input):
138
- return "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and medical topics. I can help you with:\n\n• Symptoms and health concerns\n• Medical conditions and treatments\n• General health advice\n• Wellness and prevention\n\nPlease describe any health-related symptoms or medical questions you have, and I'll do my best to assist you."
139
-
140
- human_invitation = "Patient: "
141
- doctor_invitation = "ChatDoctor: "
142
-
143
- # Enhanced system instruction
144
- system_instruction = """You are ChatDoctor, a professional medical AI assistant. You ONLY discuss health, medical symptoms, treatments, and wellness topics.
145
-
146
- If a patient greets you or asks non-medical questions, you must respond professionally: "I'm ChatDoctor, here to help with your health concerns. What medical symptoms or health questions can I assist you with today?"
147
 
148
- Now continue the medical consultation:
 
 
 
 
 
 
 
 
149
 
150
- """
 
 
 
 
 
 
151
 
152
- # Build conversation from history
153
  history_text = [system_instruction]
154
  for human, assistant in history_context:
155
  if human:
156
- history_text.append(human_invitation + human)
157
  if assistant:
158
- history_text.append(doctor_invitation + assistant)
159
-
160
- # Add current user input with medical context reinforcement
161
- if not is_medical_query(user_input):
162
- user_input = f"{user_input} [Medical consultation context]"
163
-
164
- history_text.append(human_invitation + user_input)
165
 
166
- # Build conversation prompt
167
- prompt = "\n".join(history_text) + "\n" + doctor_invitation
168
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
169
 
170
- # Define stop words and their token IDs
171
  stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
172
  stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
173
  stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
174
 
175
- # Generate model response
176
  with torch.no_grad():
177
- output_ids = generator(
178
  input_ids,
179
  max_new_tokens=MAX_NEW_TOKENS,
180
  do_sample=True,
@@ -186,49 +141,30 @@ Now continue the medical consultation:
186
  eos_token_id=tokenizer.eos_token_id
187
  )
188
 
189
- # Decode and clean response
190
- full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
191
- response = full_output[len(prompt):].strip()
192
-
193
- # Remove any "Patient:" that might have slipped through
194
- for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
195
  if stop_word in response:
196
  response = response.split(stop_word)[0].strip()
197
  break
198
 
199
  response = response.strip()
200
-
201
- # Post-processing: Check if response seems off-topic
202
- response_lower = response.lower()
203
- chatbot_service_keywords = ["chatbot", "service", "error code", "cloud-based", "platform"]
204
-
205
- if any(keyword in response_lower for keyword in chatbot_service_keywords):
206
- # Model went off-topic, force redirect
207
- response = "I apologize for any confusion. I'm ChatDoctor, and I'm specifically designed to help with medical and health-related questions. Could you please tell me about any health symptoms or medical concerns you're experiencing?"
208
-
209
- # Free memory
210
  del input_ids, output_ids
211
  gc.collect()
212
- torch.cuda.empty_cache()
 
213
 
214
  return response
215
 
216
- # =============================
217
- # Gradio Chat Function
218
- # =============================
219
- def chat_function(message, history):
220
- """Gradio chat interface function"""
221
- if not message.strip():
222
- return ""
223
-
224
- try:
225
- response = get_response(message, history)
226
- return response
227
- except Exception as e:
228
- return f"Error: {str(e)}"
229
 
230
  # =============================
231
- # Custom CSS
232
  # =============================
233
  custom_css = """
234
  #header {
@@ -239,18 +175,8 @@ custom_css = """
239
  border-radius: 10px;
240
  margin-bottom: 20px;
241
  }
242
-
243
- #header h1 {
244
- margin: 0;
245
- font-size: 2.5em;
246
- }
247
-
248
- #header p {
249
- margin: 10px 0 0 0;
250
- font-size: 1.1em;
251
- opacity: 0.9;
252
- }
253
-
254
  .disclaimer {
255
  background-color: #fff3cd;
256
  border: 1px solid #ffc107;
@@ -259,153 +185,77 @@ custom_css = """
259
  margin: 20px 0;
260
  color: #856404;
261
  }
262
-
263
- .disclaimer h3 {
264
- margin-top: 0;
265
- color: #856404;
266
- }
267
-
268
- footer {
269
- text-align: center;
270
- margin-top: 30px;
271
- color: #666;
272
- font-size: 0.9em;
273
- }
274
  """
275
 
276
- # =============================
277
- # Gradio Interface
278
- # =============================
279
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
280
- # Header
281
  gr.HTML("""
282
  <div id="header">
283
  <h1>🩺 ChatDoctor AI Assistant</h1>
284
- <p>Your AI-powered medical conversation partner</p>
285
  </div>
286
  """)
287
-
288
- # Disclaimer
289
  gr.HTML("""
290
  <div class="disclaimer">
291
  <h3>⚠️ Medical Disclaimer</h3>
292
- <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
293
- It is NOT a substitute for professional medical advice, diagnosis, or treatment.
294
- Always seek the advice of your physician or other qualified health provider with any questions
295
- you may have regarding a medical condition. Never disregard professional medical advice or
296
- delay in seeking it because of something you have read here.</p>
297
  </div>
298
  """)
299
-
300
- # Chatbot Interface
301
  chatbot = gr.Chatbot(
302
- height=500,
303
- placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. Please describe your symptoms or health question.</p></div>",
304
  show_label=False,
305
  avatar_images=(None, "🤖"),
306
  )
307
-
308
  with gr.Row():
309
- msg = gr.Textbox(
310
- placeholder="Describe your health symptoms or medical concern here...",
311
- show_label=False,
312
- scale=9,
313
- container=False
314
- )
315
- submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
316
-
317
  with gr.Row():
318
  clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
319
  retry_btn = gr.Button("🔄 Retry", scale=1)
320
-
321
- # Examples
322
- gr.Examples(
323
- examples=[
324
- "I have a persistent headache for 3 days. What should I do?",
325
- "What are the symptoms of diabetes?",
326
- "How can I improve my sleep quality?",
327
- "I have a fever and sore throat. Should I be concerned?",
328
- "What are some natural ways to reduce stress?",
329
- ],
330
- inputs=msg,
331
- label="💡 Example Medical Questions"
332
- )
333
-
334
- # Settings (collapsed by default)
335
  with gr.Accordion("⚙️ Advanced Settings", open=False):
336
- temperature_slider = gr.Slider(
337
- minimum=0.1,
338
- maximum=1.0,
339
- value=TEMPERATURE,
340
- step=0.1,
341
- label="Temperature (Creativity)",
342
- info="Higher values make responses more creative but less focused"
343
- )
344
- max_tokens_slider = gr.Slider(
345
- minimum=50,
346
- maximum=500,
347
- value=MAX_NEW_TOKENS,
348
- step=50,
349
- label="Max Response Length",
350
- info="Maximum number of tokens in response"
351
- )
352
- top_k_slider = gr.Slider(
353
- minimum=1,
354
- maximum=100,
355
- value=TOP_K,
356
- step=1,
357
- label="Top K",
358
- info="Limits vocabulary selection"
359
- )
360
-
361
- # Footer
362
- gr.HTML("""
363
- <footer>
364
- <p>Powered by ChatDoctor Model | Built with Gradio</p>
365
- <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
366
- </footer>
367
- """)
368
-
369
- # Event handlers
370
  def user_message(user_msg, history):
371
  return "", history + [[user_msg, None]]
372
-
373
- def bot_response(history, temp, max_tok, top_k_val):
374
  global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
375
- TEMPERATURE = temp
376
- MAX_NEW_TOKENS = int(max_tok)
377
- TOP_K = int(top_k_val)
378
-
379
  user_msg = history[-1][0]
380
- bot_msg = chat_function(user_msg, history[:-1])
381
  history[-1][1] = bot_msg
382
  return history
383
-
384
- # Connect events
 
 
 
 
 
 
 
385
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
386
- bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
387
  )
388
-
389
- submit_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
390
- bot_response, [chatbot, temperature_slider, max_tokens_slider, top_k_slider], chatbot
391
  )
392
-
393
  clear_btn.click(lambda: None, None, chatbot, queue=False)
394
-
395
- def retry_last():
396
- return None
397
-
398
- retry_btn.click(retry_last, None, chatbot, queue=False)
399
 
400
  # =============================
401
- # Launch Interface
402
  # =============================
403
  if __name__ == "__main__":
404
- print("\n🚀 Launching ChatDoctor Gradio Interface...")
405
  demo.queue()
406
- demo.launch(
407
- server_name="0.0.0.0", # Accessible from network
408
- server_port=7860,
409
- share=False, # Set to True to create public link
410
- show_error=True
411
- )
 
2
  import gc
3
  import torch
4
  import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
6
 
7
  # =============================
8
  # Configuration
 
13
  TOP_K = 50
14
  REPETITION_PENALTY = 1.1
15
 
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ print(f"🚀 Loading model from {MODEL_PATH} on {device}...")
18
 
19
  # =============================
20
+ # Load Model & Tokenizer
21
  # =============================
22
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
23
+ model = AutoModelForCausalLM.from_pretrained(
24
  MODEL_PATH,
25
  device_map="auto",
26
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
27
  low_cpu_mem_usage=True
28
  )
29
 
 
30
  print("✅ ChatDoctor model loaded successfully!\n")
31
 
32
  # =============================
33
+ # Stop Criteria
34
  # =============================
35
  class StopOnTokens(StoppingCriteria):
36
  def __init__(self, stop_ids):
 
47
  return True
48
  return False
49
 
50
+
51
  # =============================
52
  # Medical Keywords and Validation
53
  # =============================
# Substrings whose presence marks a message as health-related.
MEDICAL_KEYWORDS = [
    "pain", "ache", "symptom", "hurt", "sore", "discomfort", "fever", "cough", "flu",
    "infection", "allergy", "diabetes", "pressure", "asthma", "migraine", "vomit",
    "stomach", "head", "chest", "throat", "heart", "lung", "liver", "kidney", "brain",
    "doctor", "hospital", "medicine", "treatment", "therapy", "surgery", "disease",
    "illness", "blood", "test", "scan", "health", "diet", "nutrition", "stress", "sleep",
    "weight", "vitamin", "fatigue", "anxiety", "depression"
]

# Phrases that, on their own, make a short message a casual greeting.
CASUAL_ONLY_PATTERNS = [
    "hey", "hi", "hello", "sup", "yo", "good morning", "good evening",
    "how are you", "wassup", "hiya"
]


def _contains_medical_keyword(message_lower):
    """Return True if any MEDICAL_KEYWORDS entry occurs in the lower-cased text."""
    # NOTE(review): this is a substring test, so "head" also matches "ahead"
    # and "test" matches "latest" — confirm that leniency is intended.
    return any(keyword in message_lower for keyword in MEDICAL_KEYWORDS)


def _split_words(message):
    """Lower-case *message* and split it into words, treating punctuation as spaces."""
    cleaned = "".join(ch if ch.isalnum() else " " for ch in message.lower())
    return cleaned.split()


def is_medical_query(message):
    """Return True if *message* looks like a health-related question.

    A message qualifies when it contains any MEDICAL_KEYWORDS substring, or
    when one of its first three words is a question word and the message is
    longer than five words.
    """
    message_lower = message.lower()
    if _contains_medical_keyword(message_lower):
        return True

    question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does"]
    words = message_lower.split()
    has_question = any(q in words[:3] for q in question_words)
    return has_question and len(words) > 5


def is_only_greeting(message):
    """Return True if *message* is a short greeting with no medical content.

    The message must be at most three words, must begin with one of the
    CASUAL_ONLY_PATTERNS matched on whole words, and must not contain any
    medical keyword.
    """
    words = _split_words(message)
    if not words or len(words) > 3:
        return False

    # A greeting followed by a symptom ("hi, fever") is a medical message.
    if _contains_medical_keyword(message.lower()):
        return False

    for pattern in CASUAL_ONLY_PATTERNS:
        pattern_words = pattern.split()
        # Compare whole words so "hi" does not match "hip" or "his".
        if words[:len(pattern_words)] == pattern_words:
            return True
    return False
88
 
89
+
90
  # =============================
91
+ # Get Response
92
  # =============================
93
  def get_response(user_input, history_context):
94
+ if is_only_greeting(user_input):
95
+ return "👋 Hello! I'm ChatDoctor — your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."
 
 
 
 
 
 
 
 
 
 
 
96
 
97
+ if not is_medical_query(user_input):
98
+ return (
99
+ "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
100
+ "I can help you with:\n"
101
+ "• Symptoms and medical conditions\n"
102
+ "• Treatment and prevention advice\n"
103
+ "• Fitness, diet, and mental health tips\n\n"
104
+ "Please describe your health concern in detail to get started."
105
+ )
106
 
107
+ human_prefix = "Patient:"
108
+ doctor_prefix = "ChatDoctor:"
109
+ system_instruction = (
110
+ "You are ChatDoctor, a professional medical AI assistant. "
111
+ "You provide accurate, concise, and empathetic responses to health-related questions only.\n\n"
112
+ "If the question is non-medical, politely redirect back to medical topics.\n"
113
+ )
114
 
115
+ # Build history
116
  history_text = [system_instruction]
117
  for human, assistant in history_context:
118
  if human:
119
+ history_text.append(f"{human_prefix} {human}")
120
  if assistant:
121
+ history_text.append(f"{doctor_prefix} {assistant}")
122
+ history_text.append(f"{human_prefix} {user_input}")
 
 
 
 
 
123
 
124
+ prompt = "\n".join(history_text) + f"\n{doctor_prefix} "
 
125
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
126
 
 
127
  stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
128
  stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
129
  stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
130
 
 
131
  with torch.no_grad():
132
+ output_ids = model.generate(
133
  input_ids,
134
  max_new_tokens=MAX_NEW_TOKENS,
135
  do_sample=True,
 
141
  eos_token_id=tokenizer.eos_token_id
142
  )
143
 
144
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
145
+
146
+ for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
 
 
 
147
  if stop_word in response:
148
  response = response.split(stop_word)[0].strip()
149
  break
150
 
151
  response = response.strip()
152
+ if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud"]):
153
+ response = (
154
+ "I apologize for the confusion — I'm ChatDoctor, trained to assist with medical and health-related topics only. "
155
+ "Please tell me about your symptoms or health concerns."
156
+ )
157
+
 
 
 
 
158
  del input_ids, output_ids
159
  gc.collect()
160
+ if torch.cuda.is_available():
161
+ torch.cuda.empty_cache()
162
 
163
  return response
164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  # =============================
167
+ # Gradio Interface
168
  # =============================
169
  custom_css = """
170
  #header {
 
175
  border-radius: 10px;
176
  margin-bottom: 20px;
177
  }
178
+ #header h1 { margin: 0; font-size: 2.3em; }
179
+ #header p { margin: 5px 0 0; font-size: 1em; opacity: 0.9; }
 
 
 
 
 
 
 
 
 
 
180
  .disclaimer {
181
  background-color: #fff3cd;
182
  border: 1px solid #ffc107;
 
185
  margin: 20px 0;
186
  color: #856404;
187
  }
 
 
 
 
 
 
 
 
 
 
 
 
188
  """
189
 
 
 
 
190
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
 
191
  gr.HTML("""
192
  <div id="header">
193
  <h1>🩺 ChatDoctor AI Assistant</h1>
194
+ <p>Your AI-powered medical consultation partner</p>
195
  </div>
196
  """)
 
 
197
  gr.HTML("""
198
  <div class="disclaimer">
199
  <h3>⚠️ Medical Disclaimer</h3>
200
+ <p>This AI assistant is for informational purposes only.
201
+ It is NOT a substitute for professional medical advice, diagnosis, or treatment.</p>
 
 
 
202
  </div>
203
  """)
204
+
 
205
  chatbot = gr.Chatbot(
206
+ height=480,
207
+ placeholder="<div style='text-align:center;padding:40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>Describe your symptoms or ask a health-related question to begin.</p></div>",
208
  show_label=False,
209
  avatar_images=(None, "🤖"),
210
  )
211
+
212
  with gr.Row():
213
+ msg = gr.Textbox(placeholder="Type your medical concern here...", show_label=False, scale=9, container=False)
214
+ send_btn = gr.Button("Send 📤", scale=1, variant="primary")
215
+
 
 
 
 
 
216
  with gr.Row():
217
  clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
218
  retry_btn = gr.Button("🔄 Retry", scale=1)
219
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  with gr.Accordion("⚙️ Advanced Settings", open=False):
221
+ temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature")
222
+ max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
223
+ top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K")
224
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def user_message(user_msg, history):
    """Clear the textbox and append the new user turn with a pending reply.

    Returns a tuple of the new textbox value (empty string) and the updated
    history, whose last entry is ``[user_msg, None]``.
    """
    # Tolerate a history of None so a missing chat value does not raise.
    return "", (history or []) + [[user_msg, None]]


def _apply_generation_settings(temp, max_tok, topk):
    """Copy the slider values into the module-level generation settings."""
    global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
    TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)


def bot_response(history, temp, max_tok, topk):
    """Fill in the reply for the last turn in *history* and return it.

    The slider values are stored in the module-level settings before
    get_response is called. An empty history is returned unchanged.
    """
    if not history:
        return history
    _apply_generation_settings(temp, max_tok, topk)
    user_msg = history[-1][0]
    bot_msg = get_response(user_msg, history[:-1])
    history[-1][1] = bot_msg
    return history


def retry_last(history, temp, max_tok, topk):
    """Regenerate the reply for the last turn using the current slider values.

    Delegates to bot_response so a retry applies the same settings as a
    normal send. An empty history is returned unchanged.
    """
    if not history:
        return history
    return bot_response(history, temp, max_tok, topk)
243
+
244
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
245
+ bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot
246
  )
247
+ send_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
248
+ bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot
 
249
  )
 
250
  clear_btn.click(lambda: None, None, chatbot, queue=False)
251
+ retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider], chatbot)
252
+
253
+ gr.HTML(f"<footer><center><p>🧠 Powered by LLaMA-based ChatDoctor | Device: {device.upper()}</p></center></footer>")
 
 
254
 
255
  # =============================
256
+ # Launch App
257
  # =============================
258
  if __name__ == "__main__":
259
+ print("\n💡 Launching ChatDoctor Gradio Interface...")
260
  demo.queue()
261
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)