Muhammadidrees commited on
Commit
cb7958f
·
verified ·
1 Parent(s): b100d07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -570
app.py CHANGED
@@ -1,601 +1,204 @@
1
- import os
2
- import gc
3
- import re
4
- import time
5
- import torch
6
  import gradio as gr
7
- import numpy as np
8
- from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
9
- from transformers import pipeline
10
- from collections import defaultdict
11
- from datetime import datetime, timedelta
12
- import tempfile
13
-
14
- # =============================
15
- # Configuration
16
- # =============================
17
- MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
18
- WHISPER_MODEL = "openai/whisper-small" # Change to "openai/whisper-base" for faster, or "openai/whisper-medium" for better accuracy
19
- TTS_MODEL = "suno/bark-small" # Alternative: "facebook/mms-tts-eng" for faster TTS
20
-
21
- MAX_NEW_TOKENS = 200
22
- TEMPERATURE = 0.5
23
- TOP_K = 50
24
- REPETITION_PENALTY = 1.1
25
- MAX_HISTORY_TURNS = 5
26
-
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
- print(f"🚀 Loading models on {device}...")
29
-
30
- # =============================
31
- # Rate Limiting
32
- # =============================
33
- rate_limit_store = defaultdict(list)
34
- MAX_REQUESTS_PER_MINUTE = 10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- def check_rate_limit(session_id):
37
- """Simple rate limiting to prevent abuse"""
38
- now = datetime.now()
39
- rate_limit_store[session_id] = [
40
- timestamp for timestamp in rate_limit_store[session_id]
41
- if now - timestamp < timedelta(minutes=1)
42
- ]
43
-
44
- if len(rate_limit_store[session_id]) >= MAX_REQUESTS_PER_MINUTE:
45
- return False
46
-
47
- rate_limit_store[session_id].append(now)
48
- return True
 
 
 
 
 
 
 
 
 
 
49
 
50
- # ==========================
51
- # Load Models
52
- # =============================
53
- try:
54
- # Load ChatDoctor Model
55
- print("Loading ChatDoctor model...")
56
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
57
- model = AutoModelForCausalLM.from_pretrained(
58
- MODEL_PATH,
59
- device_map="auto",
60
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
61
- low_cpu_mem_usage=True
62
- )
63
- print("✅ ChatDoctor model loaded!")
64
-
65
- # Load Whisper (Speech-to-Text)
66
- print("Loading Whisper ASR model...")
67
- whisper_pipe = pipeline(
68
- "automatic-speech-recognition",
69
- model=WHISPER_MODEL,
70
- device=0 if torch.cuda.is_available() else -1
71
- )
72
- print("✅ Whisper model loaded!")
73
-
74
- # Load TTS Model
75
- print("Loading TTS model...")
76
  try:
77
- tts_pipe = pipeline(
78
- "text-to-speech",
79
- model=TTS_MODEL,
80
- device=0 if torch.cuda.is_available() else -1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  )
82
- print("✅ TTS model loaded!")
83
- TTS_AVAILABLE = True
84
- except Exception as e:
85
- print(f"⚠️ TTS model not available: {e}")
86
- TTS_AVAILABLE = False
87
-
88
- except Exception as e:
89
- print(f"❌ Error loading models: {e}")
90
- raise
91
-
92
- # =============================
93
- # Stop Criteria
94
- # =============================
95
- class StopOnTokens(StoppingCriteria):
96
- def __init__(self, stop_ids):
97
- self.stop_ids = stop_ids
98
 
99
- def __call__(self, input_ids, scores, **kwargs):
100
- for stop_id_seq in self.stop_ids:
101
- if len(stop_id_seq) == 1:
102
- if input_ids[0][-1] == stop_id_seq[0]:
103
- return True
104
- else:
105
- if len(input_ids[0]) >= len(stop_id_seq):
106
- if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
107
- return True
108
- return False
109
 
110
- # =============================
111
- # Medical Keywords and Validation
112
- # =============================
113
- MEDICAL_KEYWORDS = [
114
- "pain", "ache", "symptom", "hurt", "sore", "discomfort", "fever", "cough", "flu",
115
- "infection", "allergy", "diabetes", "pressure", "asthma", "migraine", "vomit",
116
- "stomach", "head", "chest", "throat", "heart", "lung", "liver", "kidney", "brain",
117
- "doctor", "hospital", "medicine", "treatment", "therapy", "surgery", "disease",
118
- "illness", "blood", "test", "scan", "health", "diet", "nutrition", "stress", "sleep",
119
- "weight", "vitamin", "fatigue", "anxiety", "depression", "nausea", "dizziness",
120
- "rash", "swelling", "injury", "bruise", "cold", "sneeze", "tired", "weak"
121
- ]
122
-
123
- EMERGENCY_KEYWORDS = [
124
- "suicide", "kill myself", "end my life", "chest pain", "can't breathe",
125
- "severe bleeding", "overdose", "poisoning", "unconscious", "seizure",
126
- "stroke", "heart attack", "choking"
127
- ]
128
-
129
- CASUAL_PATTERNS = [
130
- r"^(hey|hi|hello|sup|yo|wassup|hiya)\s*[\?\!\.]*$",
131
- r"^good\s+(morning|evening|afternoon|night)\s*[\?\!\.]*$",
132
- r"^how\s+are\s+you\s*[\?\!\.]*$",
133
- r"^what'?s\s+up\s*[\?\!\.]*$",
134
- ]
135
-
136
- DANGEROUS_PATTERNS = [
137
- r"take\s+\d+\s+(pills|tablets|capsules)",
138
- r"inject\s+(yourself|myself)",
139
- r"(don't|do not)\s+go\s+to\s+(hospital|doctor|emergency)",
140
- r"ignore\s+(doctor|medical|professional)",
141
- ]
142
-
143
- def is_emergency_query(message):
144
- message_lower = message.lower()
145
- return any(keyword in message_lower for keyword in EMERGENCY_KEYWORDS)
146
-
147
- def is_medical_query(message):
148
- message_lower = message.lower()
149
- for keyword in MEDICAL_KEYWORDS:
150
- if keyword in message_lower:
151
- return True
152
-
153
- question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does", "could", "would"]
154
- words = message_lower.split()
155
- has_question = any(q in words[:4] for q in question_words)
156
-
157
- if has_question and len(words) > 5:
158
- return True
159
-
160
- return False
161
-
162
- def is_only_greeting(message):
163
- message_clean = message.lower().strip()
164
- message_clean = re.sub(r'[!?.]+$', '', message_clean)
165
-
166
- for pattern in CASUAL_PATTERNS:
167
- if re.match(pattern, message_clean):
168
- return True
169
-
170
- return False
171
-
172
- def contains_dangerous_advice(response):
173
- response_lower = response.lower()
174
- for pattern in DANGEROUS_PATTERNS:
175
- if re.search(pattern, response_lower):
176
- return True
177
- return False
178
-
179
- # =============================
180
- # Speech Processing Functions
181
- # =============================
182
- def transcribe_audio(audio):
183
- """Convert speech to text using Whisper"""
184
- if audio is None:
185
- return ""
186
-
187
- try:
188
- # Handle different audio input formats
189
- if isinstance(audio, tuple):
190
- sample_rate, audio_data = audio
191
- else:
192
- audio_data = audio
193
-
194
- # Ensure audio is in the right format
195
- if isinstance(audio_data, np.ndarray):
196
- if audio_data.dtype != np.float32:
197
- audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
198
-
199
- # Transcribe
200
- result = whisper_pipe(audio_data)
201
- transcription = result["text"].strip()
202
-
203
- return transcription
204
-
205
- except Exception as e:
206
- print(f"Error in transcription: {e}")
207
- return ""
208
 
209
- def text_to_speech(text):
210
- """Convert text to speech"""
211
- if not TTS_AVAILABLE or not text:
212
- return None
213
-
214
- try:
215
- # Limit text length for TTS (to avoid timeout)
216
- if len(text) > 500:
217
- text = text[:500] + "..."
218
-
219
- # Generate speech
220
- speech = tts_pipe(text)
221
-
222
- # Extract audio data
223
- audio_data = speech["audio"]
224
- sampling_rate = speech["sampling_rate"]
225
-
226
- return (sampling_rate, audio_data)
227
-
228
  except Exception as e:
229
- print(f"Error in TTS: {e}")
230
- return None
231
-
232
- # =============================
233
- # Get Response
234
- # =============================
235
- def get_response(user_input, history_context, session_id="default"):
236
- """Generate response with enhanced safety and quality checks"""
237
-
238
- if not check_rate_limit(session_id):
239
- return "⏰ You've made too many requests. Please wait a minute before trying again."
240
-
241
- if is_emergency_query(user_input):
242
- return (
243
- "🚨 **EMERGENCY DETECTED** 🚨\n\n"
244
- "If you are experiencing a medical emergency, please:\n"
245
- "• Call emergency services immediately (911 in US, 999 in UK, 112 in EU)\n"
246
- "• Go to the nearest emergency room\n"
247
- "• Contact your local emergency hotline\n\n"
248
- "This AI cannot provide emergency medical care. Please seek immediate professional help."
249
- )
250
-
251
- if is_only_greeting(user_input):
252
- return "👋 Hello! I'm ChatDoctor — your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."
253
-
254
- if not is_medical_query(user_input):
255
- return (
256
- "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
257
- "I can help you with:\n"
258
- "• Symptoms and medical conditions\n"
259
- "• Treatment and prevention advice\n"
260
- "• Fitness, diet, and mental health tips\n\n"
261
- "Please describe your health concern in detail to get started."
262
- )
263
-
264
- human_prefix = "Patient:"
265
- doctor_prefix = "ChatDoctor:"
266
- system_instruction = (
267
- "You are ChatDoctor, a professional medical AI assistant. "
268
- "You provide accurate, concise, and empathetic responses to health-related questions only.\n"
269
- "Always recommend consulting a healthcare professional for serious conditions.\n"
270
- "Never provide dosage instructions or tell patients to avoid seeking professional help.\n\n"
271
- )
272
-
273
- limited_history = history_context[-MAX_HISTORY_TURNS:] if len(history_context) > MAX_HISTORY_TURNS else history_context
274
-
275
- history_text = [system_instruction]
276
- for human, assistant in limited_history:
277
- if human:
278
- history_text.append(f"{human_prefix} {human}")
279
- if assistant:
280
- history_text.append(f"{doctor_prefix} {assistant}")
281
- history_text.append(f"{human_prefix} {user_input}")
282
-
283
- prompt = "\n".join(history_text) + f"\n{doctor_prefix} "
284
-
285
- try:
286
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
287
-
288
- stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
289
- stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
290
- stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
291
-
292
- with torch.no_grad():
293
- output_ids = model.generate(
294
- input_ids,
295
- max_new_tokens=MAX_NEW_TOKENS,
296
- do_sample=True,
297
- temperature=TEMPERATURE,
298
- top_k=TOP_K,
299
- repetition_penalty=REPETITION_PENALTY,
300
- stopping_criteria=stopping_criteria,
301
- pad_token_id=tokenizer.eos_token_id,
302
- eos_token_id=tokenizer.eos_token_id
303
- )
304
-
305
- response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
306
-
307
- for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
308
- if stop_word in response:
309
- response = response.split(stop_word)[0].strip()
310
- break
311
-
312
- response = response.strip()
313
-
314
- if contains_dangerous_advice(response):
315
- response = (
316
- "I apologize, but I cannot provide that specific medical advice. "
317
- "Please consult with a qualified healthcare professional who can properly evaluate your situation."
318
- )
319
-
320
- if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud", "sorry, i don't have"]):
321
- response = (
322
- "I apologize for the confusion. I'm ChatDoctor, trained to assist with medical and health-related topics. "
323
- "Please tell me more about your symptoms or health concerns so I can help you better."
324
- )
325
-
326
- serious_conditions = ["cancer", "tumor", "heart disease", "stroke", "diabetes complications"]
327
- if any(condition in response.lower() for condition in serious_conditions):
328
- response += "\n\n⚠️ **Important:** Please consult a healthcare professional for proper diagnosis and treatment."
329
 
330
- del input_ids, output_ids
331
- gc.collect()
332
- if torch.cuda.is_available():
333
- torch.cuda.empty_cache()
334
 
335
- return response
336
-
337
- except Exception as e:
338
- print(f"Error generating response: {e}")
339
- return "I apologize, but I encountered an error processing your request. Please try rephrasing your question or try again later."
340
 
341
- # =============================
342
- # Gradio Interface
343
- # =============================
344
  custom_css = """
345
- #header {
346
- text-align: center;
347
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
348
- color: white;
349
- padding: 25px;
350
- border-radius: 12px;
351
- margin-bottom: 20px;
352
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
353
  }
354
- #header h1 { margin: 0; font-size: 2.5em; font-weight: 700; }
355
- #header p { margin: 5px 0 0; font-size: 1.1em; opacity: 0.95; }
356
- .disclaimer {
357
- background-color: #fff3cd;
358
- border-left: 4px solid #ffc107;
359
- border-radius: 8px;
360
- padding: 18px;
361
- margin: 20px 0;
362
- color: #856404;
363
  }
364
- .disclaimer h3 { margin-top: 0; color: #d39e00; }
365
- .emergency-warning {
366
- background-color: #f8d7da;
367
- border-left: 4px solid #dc3545;
368
  border-radius: 8px;
369
  padding: 15px;
370
- margin: 15px 0;
371
- color: #721c24;
372
- }
373
- .voice-section {
374
- background: linear-gradient(135deg, #e0c3fc 0%, #8ec5fc 100%);
375
- border-radius: 10px;
376
- padding: 20px;
377
- margin: 15px 0;
378
- }
379
- footer {
380
- margin-top: 30px;
381
- padding: 15px;
382
- text-align: center;
383
- color: #6c757d;
384
- font-size: 0.9em;
385
  }
386
  """
387
 
388
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
389
- session_state = gr.State(value=str(time.time()))
390
-
391
- gr.HTML("""
392
- <div id="header">
393
- <h1>🩺 ChatDoctor AI Assistant</h1>
394
- <p>🎤 Voice-Enabled Medical Consultation Partner</p>
395
- </div>
396
- """)
397
-
398
- gr.HTML("""
399
- <div class="disclaimer">
400
- <h3>⚠️ Medical Disclaimer</h3>
401
- <p><strong>This AI assistant is for informational purposes only.</strong>
402
- It is NOT a substitute for professional medical advice, diagnosis, or treatment.
403
- Always seek the advice of your physician or qualified health provider with any questions
404
- you may have regarding a medical condition.</p>
405
- </div>
406
- """)
407
-
408
- gr.HTML("""
409
- <div class="emergency-warning">
410
- <h4>🚨 In Case of Emergency</h4>
411
- <p>If you are experiencing a medical emergency, call emergency services immediately
412
- (911 in US, 999 in UK, 112 in EU) or go to the nearest emergency room.</p>
413
- </div>
414
- """)
415
-
416
- with gr.Tab("💬 Text Chat"):
417
- chatbot = gr.Chatbot(
418
- height=500,
419
- placeholder="<div style='text-align:center;padding:50px;'><h3>👋 Welcome to ChatDoctor!</h3><p style='color:#6c757d;'>Describe your symptoms or ask a health-related question to begin.</p></div>",
420
- show_label=False,
421
- avatar_images=(None, "🤖"),
422
- )
423
 
424
- with gr.Row():
425
- msg = gr.Textbox(
426
- placeholder="Type your medical concern here...",
427
- show_label=False,
428
- scale=9,
429
- container=False,
430
- lines=1
431
- )
432
- send_btn = gr.Button("Send 📤", scale=1, variant="primary")
 
 
433
 
434
- with gr.Row():
435
- clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
436
- retry_btn = gr.Button("🔄 Retry", scale=1)
 
 
 
 
 
437
 
438
- with gr.Tab("🎤 Voice Chat"):
439
- gr.HTML('<div class="voice-section"><h3>🎙️ Voice Interaction</h3><p>Record your medical question and get voice responses!</p></div>')
440
-
441
- voice_chatbot = gr.Chatbot(
442
- height=400,
443
- placeholder="<div style='text-align:center;padding:40px;'><h3>🎤 Voice Chat Mode</h3><p>Click the microphone to record your question</p></div>",
444
  show_label=False,
445
- avatar_images=(None, "🤖"),
446
- )
447
-
448
- with gr.Row():
449
- audio_input = gr.Audio(
450
- sources=["microphone"],
451
- type="numpy",
452
- label="🎤 Record Your Question",
453
- scale=8
454
- )
455
- voice_send_btn = gr.Button("Send Voice 🎙️", scale=2, variant="primary")
456
-
457
- audio_output = gr.Audio(
458
- label="🔊 Voice Response",
459
- autoplay=True,
460
- visible=TTS_AVAILABLE
461
- )
462
-
463
- transcribed_text = gr.Textbox(
464
- label="📝 Transcribed Text",
465
- interactive=False,
466
- visible=True
467
  )
468
-
469
- with gr.Row():
470
- voice_clear_btn = gr.Button("🗑️ Clear Voice Chat", scale=1)
471
-
472
- if not TTS_AVAILABLE:
473
- gr.Warning("⚠️ TTS model not available. Voice responses disabled. Text responses will still work.")
474
-
475
- with gr.Accordion("⚙️ Advanced Settings", open=False):
476
- temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature (Lower = More Focused)")
477
- max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
478
- top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K Sampling")
479
-
480
- # =============================
481
- # Text Chat Functions
482
- # =============================
483
- def user_message(user_msg, history):
484
- if not user_msg.strip():
485
- return "", history
486
- return "", history + [[user_msg, None]]
487
-
488
- def bot_response(history, temp, max_tok, topk, session_id):
489
- if not history or history[-1][1] is not None:
490
- return history
491
-
492
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
493
- TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
494
-
495
- user_msg = history[-1][0]
496
- bot_msg = get_response(user_msg, history[:-1], session_id)
497
- history[-1][1] = bot_msg
498
- return history
499
-
500
- def retry_last(history, temp, max_tok, topk, session_id):
501
- if not history:
502
- return history
503
- user_msg = history[-1][0]
504
- bot_msg = get_response(user_msg, history[:-1], session_id)
505
- history[-1][1] = bot_msg
506
- return history
507
-
508
- # =============================
509
- # Voice Chat Functions
510
- # =============================
511
- def text_to_speech(text):
512
- # Convert text to speech using Bark
513
- from transformers import AutoProcessor, BarkModel
514
- import numpy as np
515
-
516
- processor = AutoProcessor.from_pretrained("suno/bark-small")
517
- model = BarkModel.from_pretrained("suno/bark-small")
518
-
519
- inputs = processor(text, voice_preset="v2/en_speaker_6", return_tensors="pt")
520
- speech = model.generate(**inputs)
521
-
522
- # ✅ Extract and normalize audio data
523
- audio_data = speech["audio"]
524
- sampling_rate = speech["sampling_rate"]
525
-
526
- # 🔊 Normalize & clip Bark audio output to avoid struct.error
527
- if isinstance(audio_data, np.ndarray):
528
- audio_data = np.clip(audio_data, -1.0, 1.0).astype(np.float32)
529
- else:
530
- audio_data = np.array(audio_data, dtype=np.float32)
531
- audio_data = np.clip(audio_data, -1.0, 1.0)
532
-
533
- return (sampling_rate, audio_data)
534
-
535
- def process_voice_input(audio, history, temp, max_tok, topk, session_id):
536
- """Process voice input: transcribe, get response, convert to speech"""
537
- if audio is None:
538
- return history, "", None
539
-
540
- # Transcribe audio to text
541
- transcribed = transcribe_audio(audio)
542
-
543
- if not transcribed:
544
- return history, "⚠️ Could not transcribe audio. Please try again.", None
545
-
546
- # Add to history
547
- history = history + [[transcribed, None]]
548
-
549
- # Get bot response
550
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
551
- TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
552
-
553
- bot_msg = get_response(transcribed, history[:-1], session_id)
554
- history[-1][1] = bot_msg
555
-
556
- # Convert response to speech
557
- audio_response = text_to_speech(bot_msg) if TTS_AVAILABLE else None
558
-
559
- return history, transcribed, audio_response
560
 
561
- # Text Chat Events
562
- msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
563
- bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
564
  )
565
- send_btn.click(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
566
- bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
567
  )
568
- clear_btn.click(lambda: None, None, chatbot, queue=False)
569
- retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot)
570
-
571
- # Voice Chat Events
572
- voice_send_btn.click(
573
- process_voice_input,
574
- [audio_input, voice_chatbot, temp_slider, max_tok_slider, top_k_slider, session_state],
575
- [voice_chatbot, transcribed_text, audio_output]
576
  )
577
- voice_clear_btn.click(lambda: (None, "", None), None, [voice_chatbot, transcribed_text, audio_output], queue=False)
578
 
579
- gr.HTML(f"""
580
- <footer>
581
- <p><strong>🧠 Powered by LLaMA + Whisper + TTS</strong></p>
582
- <p>Device: {device.upper()} | Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute</p>
583
- <p>🎤 Voice: Whisper ASR | 🔊 TTS: {"Enabled" if TTS_AVAILABLE else "Disabled"}</p>
584
- <p style='font-size:0.85em;margin-top:10px;'>
585
- This AI provides general health information only. Always consult healthcare professionals for medical advice.
586
- </p>
587
- </footer>
588
- """)
589
-
590
- # =============================
591
- # Launch App
592
- # =============================
593
  if __name__ == "__main__":
594
- print("\n💡 Launching ChatDoctor with Voice Support...")
595
- print(f"📊 Configuration:")
596
- print(f" - Device: {device.upper()}")
597
- print(f" - Whisper Model: {WHISPER_MODEL}")
598
- print(f" - TTS Available: {TTS_AVAILABLE}")
599
- print(f" - Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute")
600
- demo.queue()
601
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
 
 
 
 
 
 
1
  import gradio as gr
2
+ from groq import GroqClient
3
+
4
+ # ==============================
5
+ # Initialize Groq client
6
+ # ==============================
7
+ client = GroqClient(api_key="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")
8
+
9
+ # ==============================
10
+ # System Prompt for Doctor
11
+ # ==============================
12
+ SYSTEM_PROMPT = """
13
+ You are Dr. HealBot, a calm, knowledgeable, and empathetic doctor talking to a patient.
14
+
15
+ GOAL:
16
+ Have a natural conversation — ask 3-4 short medical questions to understand the patient's condition,
17
+ then start giving practical advice including:
18
+ - possible over-the-counter medicines (generic name only)
19
+ - simple lifestyle or habit changes
20
+ - nutrition or exercise guidance
21
+ - when to see a real doctor
22
+
23
+ TONE & STYLE:
24
+ - Speak like a real doctor, short and direct sentences (1-2 lines max).
25
+ - Be warm but professional.
26
+ - Use plain language — no medical jargon unless necessary.
27
+ - No bullet points or lists — just natural speech.
28
+ - Only one question per response, until enough info is gathered.
29
+ - After about 4 patient answers, switch to giving advice.
30
+
31
+ CONVERSATION FLOW EXAMPLE:
32
+ Doctor: How can I help you?
33
+ Patient: I’ve had a cough for 2 weeks.
34
+ Doctor: Is it dry or with phlegm?
35
+ Patient: With phlegm.
36
+ Doctor: Do you have fever or chest pain?
37
+ Patient: Mild fever.
38
+ Doctor: Do you smoke or have allergies?
39
+ Patient: I smoke.
40
+ Doctor: Sounds like a mild chest infection. You can try paracetamol for fever and warm fluids.
41
+ Cut down on smoking and rest. If symptoms persist beyond 5 days, see a doctor.
42
+
43
+ ALWAYS END with a gentle reminder:
44
+ "Please consult a qualified doctor if it doesn’t improve or if symptoms worsen."
45
+ """
46
 
47
+ # ==============================
48
+ # Initial greeting
49
+ # ==============================
50
+ INITIAL_MESSAGE = "How can I help you today?"
51
+
52
+ # ==============================
53
+ # Chat logic
54
+ # ==============================
55
+ def chat_with_doctor(message, history):
56
+ messages = [{"role": "system", "content": SYSTEM_PROMPT}]
57
+
58
+ # Build chat history
59
+ for chat in history:
60
+ if isinstance(chat, dict):
61
+ messages.append(chat)
62
+ elif isinstance(chat, (list, tuple)) and len(chat) == 2:
63
+ if chat[0]:
64
+ messages.append({"role": "user", "content": chat[0]})
65
+ if chat[1]:
66
+ messages.append({"role": "assistant", "content": chat[1]})
67
+
68
+ # Add current patient message
69
+ messages.append({"role": "user", "content": message})
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  try:
72
+ # Count how many patient turns have occurred
73
+ patient_turns = sum(1 for chat in history if isinstance(chat, (list, tuple)) and chat[0])
74
+
75
+ # After 4 patient turns, guide the model to provide recommendations
76
+ if patient_turns >= 4:
77
+ messages.append({
78
+ "role": "system",
79
+ "content": (
80
+ "Now begin giving specific recommendations based on the patient's symptoms. "
81
+ "Include possible generic medicines (like paracetamol, ibuprofen, etc.), "
82
+ "lifestyle and nutrition tips, and when to seek medical attention. "
83
+ "Keep it short and empathetic, like a real doctor speaking naturally."
84
+ )
85
+ })
86
+
87
+ # Generate the response using Groq LLM
88
+ chat_completion = client.chat.completions.create(
89
+ messages=messages,
90
+ model="llama-3.3-70b-versatile",
91
+ temperature=0.6,
92
+ max_tokens=120,
93
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ response = chat_completion.choices[0].message.content.strip()
 
 
 
 
 
 
 
 
 
96
 
97
+ # Append to history
98
+ history.append([message, response])
99
+ return history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  except Exception as e:
102
+ error_msg = f"⚠️ Error: {str(e)}. Please check your API connection and try again."
103
+ history.append([message, error_msg])
104
+ return history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
 
 
 
 
106
 
107
+ def reset_conversation():
108
+ """Reset the chat to start fresh"""
109
+ return []
 
 
110
 
111
+ # ==============================
112
+ # Custom CSS
113
+ # ==============================
114
  custom_css = """
115
+ #chatbot {
116
+ height: 600px;
 
 
 
 
 
 
117
  }
118
+ .gradio-container {
119
+ font-family: 'Arial', sans-serif;
 
 
 
 
 
 
 
120
  }
121
+ #warning {
122
+ background-color: #fff3cd;
123
+ border: 1px solid #ffc107;
 
124
  border-radius: 8px;
125
  padding: 15px;
126
+ margin: 10px 0;
127
+ color: #856404;
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  }
129
  """
130
 
131
+ # ==============================
132
+ # Gradio Interface
133
+ # ==============================
134
+ with gr.Blocks(css=custom_css, title="AI Medical Consultant") as demo:
135
+ gr.Markdown(
136
+ """
137
+ # 🏥 AI Medical Consultant
138
+ ### Realistic Doctor-Patient Conversation • Medicine • Lifestyle • Nutrition
139
+ """
140
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ gr.HTML(
143
+ """
144
+ <div id="warning">
145
+ <strong>⚠️ Medical Disclaimer:</strong><br>
146
+ This AI provides general health information only. It is <b>NOT</b> a substitute for
147
+ professional medical advice, diagnosis, or treatment.<br>
148
+ Always consult qualified healthcare providers for medical concerns.<br>
149
+ For emergencies, call your local emergency number immediately.
150
+ </div>
151
+ """
152
+ )
153
 
154
+ chatbot = gr.Chatbot(
155
+ value=[[None, INITIAL_MESSAGE]],
156
+ elem_id="chatbot",
157
+ height=600,
158
+ show_label=False,
159
+ avatar_images=(None, "https://api.dicebear.com/7.x/bottts/svg?seed=doctor"),
160
+ type="tuples"
161
+ )
162
 
163
+ with gr.Row():
164
+ msg = gr.Textbox(
165
+ placeholder="Describe your symptoms or ask a question...",
 
 
 
166
  show_label=False,
167
+ scale=9,
168
+ lines=2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  )
170
+ submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
171
+
172
+ with gr.Row():
173
+ clear_btn = gr.Button("🔄 Start New Consultation", variant="secondary")
174
+
175
+ gr.Markdown(
176
+ """
177
+ ### 💡 Tips for Best Results:
178
+ - Be specific about your symptoms (location, severity, duration)
179
+ - Mention any relevant medical history or medications
180
+ - Ask follow-up questions freely
181
+ """
182
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
+ # Event Handlers
185
+ msg.submit(chat_with_doctor, [msg, chatbot], [chatbot]).then(
186
+ lambda: gr.update(value=""), None, [msg]
187
  )
188
+ submit_btn.click(chat_with_doctor, [msg, chatbot], [chatbot]).then(
189
+ lambda: gr.update(value=""), None, [msg]
190
  )
191
+ clear_btn.click(reset_conversation, None, [chatbot]).then(
192
+ lambda: [[None, INITIAL_MESSAGE]], None, [chatbot]
 
 
 
 
 
 
193
  )
 
194
 
195
+ # ==============================
196
+ # Launch app
197
+ # ==============================
 
 
 
 
 
 
 
 
 
 
 
198
  if __name__ == "__main__":
199
+ demo.launch(
200
+ share=True,
201
+ show_error=True,
202
+ server_name="0.0.0.0",
203
+ server_port=7860
204
+ )