Muhammadidrees commited on
Commit
c0ebb5e
Β·
verified Β·
1 Parent(s): 4df8616

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +208 -71
app.py CHANGED
@@ -4,25 +4,31 @@ import re
4
  import time
5
  import torch
6
  import gradio as gr
 
7
  from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 
8
  from collections import defaultdict
9
  from datetime import datetime, timedelta
 
10
 
11
  # =============================
12
  # Configuration
13
  # =============================
14
  MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
 
 
 
15
  MAX_NEW_TOKENS = 200
16
  TEMPERATURE = 0.5
17
  TOP_K = 50
18
  REPETITION_PENALTY = 1.1
19
- MAX_HISTORY_TURNS = 5 # Limit conversation history
20
 
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
- print(f"πŸš€ Loading model from {MODEL_PATH} on {device}...")
23
 
24
  # =============================
25
- # Rate Limiting (Simple IP-based)
26
  # =============================
27
  rate_limit_store = defaultdict(list)
28
  MAX_REQUESTS_PER_MINUTE = 10
@@ -42,9 +48,11 @@ def check_rate_limit(session_id):
42
  return True
43
 
44
  # ==========================
45
- # Load Model & Tokenizer
46
  # =============================
47
  try:
 
 
48
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
49
  model = AutoModelForCausalLM.from_pretrained(
50
  MODEL_PATH,
@@ -52,9 +60,33 @@ try:
52
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
53
  low_cpu_mem_usage=True
54
  )
55
- print("βœ… ChatDoctor model loaded successfully!\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  except Exception as e:
57
- print(f"❌ Error loading model: {e}")
58
  raise
59
 
60
  # =============================
@@ -75,7 +107,6 @@ class StopOnTokens(StoppingCriteria):
75
  return True
76
  return False
77
 
78
-
79
  # =============================
80
  # Medical Keywords and Validation
81
  # =============================
@@ -89,7 +120,6 @@ MEDICAL_KEYWORDS = [
89
  "rash", "swelling", "injury", "bruise", "cold", "sneeze", "tired", "weak"
90
  ]
91
 
92
- # Emergency keywords that should trigger immediate medical attention warning
93
  EMERGENCY_KEYWORDS = [
94
  "suicide", "kill myself", "end my life", "chest pain", "can't breathe",
95
  "severe bleeding", "overdose", "poisoning", "unconscious", "seizure",
@@ -103,23 +133,23 @@ CASUAL_PATTERNS = [
103
  r"^what'?s\s+up\s*[\?\!\.]*$",
104
  ]
105
 
 
 
 
 
 
 
106
 
107
  def is_emergency_query(message):
108
- """Detect if query contains emergency keywords"""
109
  message_lower = message.lower()
110
  return any(keyword in message_lower for keyword in EMERGENCY_KEYWORDS)
111
 
112
-
113
  def is_medical_query(message):
114
- """Enhanced medical query detection"""
115
  message_lower = message.lower()
116
-
117
- # Check for medical keywords
118
  for keyword in MEDICAL_KEYWORDS:
119
  if keyword in message_lower:
120
  return True
121
 
122
- # Check for question patterns with sufficient length
123
  question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does", "could", "would"]
124
  words = message_lower.split()
125
  has_question = any(q in words[:4] for q in question_words)
@@ -129,42 +159,75 @@ def is_medical_query(message):
129
 
130
  return False
131
 
132
-
133
  def is_only_greeting(message):
134
- """Improved greeting detection using regex"""
135
  message_clean = message.lower().strip()
136
-
137
- # Remove punctuation for matching
138
  message_clean = re.sub(r'[!?.]+$', '', message_clean)
139
 
140
- # Check if it matches any casual pattern
141
  for pattern in CASUAL_PATTERNS:
142
  if re.match(pattern, message_clean):
143
  return True
144
 
145
  return False
146
 
147
-
148
- # =============================
149
- # Safety Filter
150
- # =============================
151
- DANGEROUS_PATTERNS = [
152
- r"take\s+\d+\s+(pills|tablets|capsules)",
153
- r"inject\s+(yourself|myself)",
154
- r"(don't|do not)\s+go\s+to\s+(hospital|doctor|emergency)",
155
- r"ignore\s+(doctor|medical|professional)",
156
- ]
157
-
158
  def contains_dangerous_advice(response):
159
- """Check if response contains potentially dangerous medical advice"""
160
  response_lower = response.lower()
161
-
162
  for pattern in DANGEROUS_PATTERNS:
163
  if re.search(pattern, response_lower):
164
  return True
165
-
166
  return False
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  # =============================
170
  # Get Response
@@ -172,11 +235,9 @@ def contains_dangerous_advice(response):
172
  def get_response(user_input, history_context, session_id="default"):
173
  """Generate response with enhanced safety and quality checks"""
174
 
175
- # Rate limiting check
176
  if not check_rate_limit(session_id):
177
  return "⏰ You've made too many requests. Please wait a minute before trying again."
178
 
179
- # Emergency detection
180
  if is_emergency_query(user_input):
181
  return (
182
  "🚨 **EMERGENCY DETECTED** 🚨\n\n"
@@ -187,11 +248,9 @@ def get_response(user_input, history_context, session_id="default"):
187
  "This AI cannot provide emergency medical care. Please seek immediate professional help."
188
  )
189
 
190
- # Greeting detection
191
  if is_only_greeting(user_input):
192
  return "πŸ‘‹ Hello! I'm ChatDoctor β€” your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."
193
 
194
- # Non-medical query handling
195
  if not is_medical_query(user_input):
196
  return (
197
  "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
@@ -202,7 +261,6 @@ def get_response(user_input, history_context, session_id="default"):
202
  "Please describe your health concern in detail to get started."
203
  )
204
 
205
- # Build prompt with limited history
206
  human_prefix = "Patient:"
207
  doctor_prefix = "ChatDoctor:"
208
  system_instruction = (
@@ -212,7 +270,6 @@ def get_response(user_input, history_context, session_id="default"):
212
  "Never provide dosage instructions or tell patients to avoid seeking professional help.\n\n"
213
  )
214
 
215
- # Limit history to prevent token overflow
216
  limited_history = history_context[-MAX_HISTORY_TURNS:] if len(history_context) > MAX_HISTORY_TURNS else history_context
217
 
218
  history_text = [system_instruction]
@@ -228,7 +285,6 @@ def get_response(user_input, history_context, session_id="default"):
228
  try:
229
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
230
 
231
- # Stop words for cleaner output
232
  stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
233
  stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
234
  stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
@@ -248,7 +304,6 @@ def get_response(user_input, history_context, session_id="default"):
248
 
249
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
250
 
251
- # Clean up response
252
  for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
253
  if stop_word in response:
254
  response = response.split(stop_word)[0].strip()
@@ -256,26 +311,22 @@ def get_response(user_input, history_context, session_id="default"):
256
 
257
  response = response.strip()
258
 
259
- # Safety filter
260
  if contains_dangerous_advice(response):
261
  response = (
262
  "I apologize, but I cannot provide that specific medical advice. "
263
  "Please consult with a qualified healthcare professional who can properly evaluate your situation."
264
  )
265
 
266
- # Filter out inappropriate content
267
  if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud", "sorry, i don't have"]):
268
  response = (
269
  "I apologize for the confusion. I'm ChatDoctor, trained to assist with medical and health-related topics. "
270
  "Please tell me more about your symptoms or health concerns so I can help you better."
271
  )
272
 
273
- # Add disclaimer for serious conditions
274
  serious_conditions = ["cancer", "tumor", "heart disease", "stroke", "diabetes complications"]
275
  if any(condition in response.lower() for condition in serious_conditions):
276
  response += "\n\n⚠️ **Important:** Please consult a healthcare professional for proper diagnosis and treatment."
277
 
278
- # Clean up memory
279
  del input_ids, output_ids
280
  gc.collect()
281
  if torch.cuda.is_available():
@@ -287,7 +338,6 @@ def get_response(user_input, history_context, session_id="default"):
287
  print(f"Error generating response: {e}")
288
  return "I apologize, but I encountered an error processing your request. Please try rephrasing your question or try again later."
289
 
290
-
291
  # =============================
292
  # Gradio Interface
293
  # =============================
@@ -320,6 +370,12 @@ custom_css = """
320
  margin: 15px 0;
321
  color: #721c24;
322
  }
 
 
 
 
 
 
323
  footer {
324
  margin-top: 30px;
325
  padding: 15px;
@@ -328,13 +384,14 @@ footer {
328
  font-size: 0.9em;
329
  }
330
  """
 
331
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
332
- session_state = gr.State(value=str(time.time())) # Unique session ID
333
 
334
  gr.HTML("""
335
  <div id="header">
336
  <h1>🩺 ChatDoctor AI Assistant</h1>
337
- <p>Your AI-powered medical consultation partner</p>
338
  </div>
339
  """)
340
 
@@ -356,32 +413,73 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
356
  </div>
357
  """)
358
 
359
- chatbot = gr.Chatbot(
360
- height=500,
361
- placeholder="<div style='text-align:center;padding:50px;'><h3>πŸ‘‹ Welcome to ChatDoctor!</h3><p style='color:#6c757d;'>Describe your symptoms or ask a health-related question to begin.</p><p style='color:#dc3545;margin-top:15px;'><strong>Remember:</strong> This is not a replacement for professional medical care.</p></div>",
362
- show_label=False,
363
- avatar_images=(None, "πŸ€–"),
364
- )
365
-
366
- with gr.Row():
367
- msg = gr.Textbox(
368
- placeholder="Type your medical concern here... (e.g., 'I have a headache for 3 days')",
369
  show_label=False,
370
- scale=9,
371
- container=False,
372
- lines=1
373
  )
374
- send_btn = gr.Button("Send πŸ“€", scale=1, variant="primary")
375
 
376
- with gr.Row():
377
- clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", scale=1)
378
- retry_btn = gr.Button("πŸ”„ Retry", scale=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
  with gr.Accordion("βš™οΈ Advanced Settings", open=False):
381
  temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature (Lower = More Focused)")
382
  max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
383
  top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K Sampling")
384
 
 
 
 
385
  def user_message(user_msg, history):
386
  if not user_msg.strip():
387
  return "", history
@@ -407,6 +505,36 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
407
  history[-1][1] = bot_msg
408
  return history
409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
411
  bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
412
  )
@@ -416,10 +544,19 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
416
  clear_btn.click(lambda: None, None, chatbot, queue=False)
417
  retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot)
418
 
 
 
 
 
 
 
 
 
419
  gr.HTML(f"""
420
  <footer>
421
- <p><strong>🧠 Powered by LLaMA-based ChatDoctor</strong></p>
422
  <p>Device: {device.upper()} | Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute</p>
 
423
  <p style='font-size:0.85em;margin-top:10px;'>
424
  This AI provides general health information only. Always consult healthcare professionals for medical advice.
425
  </p>
@@ -430,11 +567,11 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
430
  # Launch App
431
  # =============================
432
  if __name__ == "__main__":
433
- print("\nπŸ’‘ Launching Enhanced ChatDoctor Gradio Interface...")
434
  print(f"πŸ“Š Configuration:")
435
- print(f" - Max History Turns: {MAX_HISTORY_TURNS}")
436
- print(f" - Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute")
437
  print(f" - Device: {device.upper()}")
 
 
 
438
  demo.queue()
439
- demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
440
-
 
4
  import time
5
  import torch
6
  import gradio as gr
7
+ import numpy as np
8
  from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
9
+ from transformers import pipeline
10
  from collections import defaultdict
11
  from datetime import datetime, timedelta
12
+ import tempfile
13
 
14
  # =============================
15
  # Configuration
16
  # =============================
17
  MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
18
+ WHISPER_MODEL = "openai/whisper-small" # Change to "openai/whisper-base" for faster, or "openai/whisper-medium" for better accuracy
19
+ TTS_MODEL = "suno/bark-small" # Alternative: "facebook/mms-tts-eng" for faster TTS
20
+
21
  MAX_NEW_TOKENS = 200
22
  TEMPERATURE = 0.5
23
  TOP_K = 50
24
  REPETITION_PENALTY = 1.1
25
+ MAX_HISTORY_TURNS = 5
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
+ print(f"πŸš€ Loading models on {device}...")
29
 
30
  # =============================
31
+ # Rate Limiting
32
  # =============================
33
  rate_limit_store = defaultdict(list)
34
  MAX_REQUESTS_PER_MINUTE = 10
 
48
  return True
49
 
50
  # ==========================
51
+ # Load Models
52
  # =============================
53
  try:
54
+ # Load ChatDoctor Model
55
+ print("Loading ChatDoctor model...")
56
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
57
  model = AutoModelForCausalLM.from_pretrained(
58
  MODEL_PATH,
 
60
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
61
  low_cpu_mem_usage=True
62
  )
63
+ print("βœ… ChatDoctor model loaded!")
64
+
65
+ # Load Whisper (Speech-to-Text)
66
+ print("Loading Whisper ASR model...")
67
+ whisper_pipe = pipeline(
68
+ "automatic-speech-recognition",
69
+ model=WHISPER_MODEL,
70
+ device=0 if torch.cuda.is_available() else -1
71
+ )
72
+ print("βœ… Whisper model loaded!")
73
+
74
+ # Load TTS Model
75
+ print("Loading TTS model...")
76
+ try:
77
+ tts_pipe = pipeline(
78
+ "text-to-speech",
79
+ model=TTS_MODEL,
80
+ device=0 if torch.cuda.is_available() else -1
81
+ )
82
+ print("βœ… TTS model loaded!")
83
+ TTS_AVAILABLE = True
84
+ except Exception as e:
85
+ print(f"⚠️ TTS model not available: {e}")
86
+ TTS_AVAILABLE = False
87
+
88
  except Exception as e:
89
+ print(f"❌ Error loading models: {e}")
90
  raise
91
 
92
  # =============================
 
107
  return True
108
  return False
109
 
 
110
  # =============================
111
  # Medical Keywords and Validation
112
  # =============================
 
120
  "rash", "swelling", "injury", "bruise", "cold", "sneeze", "tired", "weak"
121
  ]
122
 
 
123
  EMERGENCY_KEYWORDS = [
124
  "suicide", "kill myself", "end my life", "chest pain", "can't breathe",
125
  "severe bleeding", "overdose", "poisoning", "unconscious", "seizure",
 
133
  r"^what'?s\s+up\s*[\?\!\.]*$",
134
  ]
135
 
136
+ DANGEROUS_PATTERNS = [
137
+ r"take\s+\d+\s+(pills|tablets|capsules)",
138
+ r"inject\s+(yourself|myself)",
139
+ r"(don't|do not)\s+go\s+to\s+(hospital|doctor|emergency)",
140
+ r"ignore\s+(doctor|medical|professional)",
141
+ ]
142
 
143
  def is_emergency_query(message):
 
144
  message_lower = message.lower()
145
  return any(keyword in message_lower for keyword in EMERGENCY_KEYWORDS)
146
 
 
147
  def is_medical_query(message):
 
148
  message_lower = message.lower()
 
 
149
  for keyword in MEDICAL_KEYWORDS:
150
  if keyword in message_lower:
151
  return True
152
 
 
153
  question_words = ["what", "how", "why", "when", "where", "can", "should", "is", "are", "do", "does", "could", "would"]
154
  words = message_lower.split()
155
  has_question = any(q in words[:4] for q in question_words)
 
159
 
160
  return False
161
 
 
162
  def is_only_greeting(message):
 
163
  message_clean = message.lower().strip()
 
 
164
  message_clean = re.sub(r'[!?.]+$', '', message_clean)
165
 
 
166
  for pattern in CASUAL_PATTERNS:
167
  if re.match(pattern, message_clean):
168
  return True
169
 
170
  return False
171
 
 
 
 
 
 
 
 
 
 
 
 
172
  def contains_dangerous_advice(response):
 
173
  response_lower = response.lower()
 
174
  for pattern in DANGEROUS_PATTERNS:
175
  if re.search(pattern, response_lower):
176
  return True
 
177
  return False
178
 
179
+ # =============================
180
+ # Speech Processing Functions
181
+ # =============================
182
+ def transcribe_audio(audio):
183
+ """Convert speech to text using Whisper"""
184
+ if audio is None:
185
+ return ""
186
+
187
+ try:
188
+ # Handle different audio input formats
189
+ if isinstance(audio, tuple):
190
+ sample_rate, audio_data = audio
191
+ else:
192
+ audio_data = audio
193
+
194
+ # Ensure audio is in the right format
195
+ if isinstance(audio_data, np.ndarray):
196
+ if audio_data.dtype != np.float32:
197
+ audio_data = audio_data.astype(np.float32) / np.iinfo(audio_data.dtype).max
198
+
199
+ # Transcribe
200
+ result = whisper_pipe(audio_data)
201
+ transcription = result["text"].strip()
202
+
203
+ return transcription
204
+
205
+ except Exception as e:
206
+ print(f"Error in transcription: {e}")
207
+ return ""
208
+
209
+ def text_to_speech(text):
210
+ """Convert text to speech"""
211
+ if not TTS_AVAILABLE or not text:
212
+ return None
213
+
214
+ try:
215
+ # Limit text length for TTS (to avoid timeout)
216
+ if len(text) > 500:
217
+ text = text[:500] + "..."
218
+
219
+ # Generate speech
220
+ speech = tts_pipe(text)
221
+
222
+ # Extract audio data
223
+ audio_data = speech["audio"]
224
+ sampling_rate = speech["sampling_rate"]
225
+
226
+ return (sampling_rate, audio_data)
227
+
228
+ except Exception as e:
229
+ print(f"Error in TTS: {e}")
230
+ return None
231
 
232
  # =============================
233
  # Get Response
 
235
  def get_response(user_input, history_context, session_id="default"):
236
  """Generate response with enhanced safety and quality checks"""
237
 
 
238
  if not check_rate_limit(session_id):
239
  return "⏰ You've made too many requests. Please wait a minute before trying again."
240
 
 
241
  if is_emergency_query(user_input):
242
  return (
243
  "🚨 **EMERGENCY DETECTED** 🚨\n\n"
 
248
  "This AI cannot provide emergency medical care. Please seek immediate professional help."
249
  )
250
 
 
251
  if is_only_greeting(user_input):
252
  return "πŸ‘‹ Hello! I'm ChatDoctor β€” your AI medical assistant. Please tell me about any health symptoms or medical concerns you'd like to discuss."
253
 
 
254
  if not is_medical_query(user_input):
255
  return (
256
  "Hello! I'm ChatDoctor, an AI medical assistant specialized in health and wellness.\n\n"
 
261
  "Please describe your health concern in detail to get started."
262
  )
263
 
 
264
  human_prefix = "Patient:"
265
  doctor_prefix = "ChatDoctor:"
266
  system_instruction = (
 
270
  "Never provide dosage instructions or tell patients to avoid seeking professional help.\n\n"
271
  )
272
 
 
273
  limited_history = history_context[-MAX_HISTORY_TURNS:] if len(history_context) > MAX_HISTORY_TURNS else history_context
274
 
275
  history_text = [system_instruction]
 
285
  try:
286
  input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
287
 
 
288
  stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
289
  stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
290
  stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
 
304
 
305
  response = tokenizer.decode(output_ids[0], skip_special_tokens=True)[len(prompt):].strip()
306
 
 
307
  for stop_word in ["Patient:", "Patient :", "\nPatient", "Patient"]:
308
  if stop_word in response:
309
  response = response.split(stop_word)[0].strip()
 
311
 
312
  response = response.strip()
313
 
 
314
  if contains_dangerous_advice(response):
315
  response = (
316
  "I apologize, but I cannot provide that specific medical advice. "
317
  "Please consult with a qualified healthcare professional who can properly evaluate your situation."
318
  )
319
 
 
320
  if any(x in response.lower() for x in ["chatbot", "api key", "error", "cloud", "sorry, i don't have"]):
321
  response = (
322
  "I apologize for the confusion. I'm ChatDoctor, trained to assist with medical and health-related topics. "
323
  "Please tell me more about your symptoms or health concerns so I can help you better."
324
  )
325
 
 
326
  serious_conditions = ["cancer", "tumor", "heart disease", "stroke", "diabetes complications"]
327
  if any(condition in response.lower() for condition in serious_conditions):
328
  response += "\n\n⚠️ **Important:** Please consult a healthcare professional for proper diagnosis and treatment."
329
 
 
330
  del input_ids, output_ids
331
  gc.collect()
332
  if torch.cuda.is_available():
 
338
  print(f"Error generating response: {e}")
339
  return "I apologize, but I encountered an error processing your request. Please try rephrasing your question or try again later."
340
 
 
341
  # =============================
342
  # Gradio Interface
343
  # =============================
 
370
  margin: 15px 0;
371
  color: #721c24;
372
  }
373
+ .voice-section {
374
+ background: linear-gradient(135deg, #e0c3fc 0%, #8ec5fc 100%);
375
+ border-radius: 10px;
376
+ padding: 20px;
377
+ margin: 15px 0;
378
+ }
379
  footer {
380
  margin-top: 30px;
381
  padding: 15px;
 
384
  font-size: 0.9em;
385
  }
386
  """
387
+
388
  with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
389
+ session_state = gr.State(value=str(time.time()))
390
 
391
  gr.HTML("""
392
  <div id="header">
393
  <h1>🩺 ChatDoctor AI Assistant</h1>
394
+ <p>🎀 Voice-Enabled Medical Consultation Partner</p>
395
  </div>
396
  """)
397
 
 
413
  </div>
414
  """)
415
 
416
+ with gr.Tab("πŸ’¬ Text Chat"):
417
+ chatbot = gr.Chatbot(
418
+ height=500,
419
+ placeholder="<div style='text-align:center;padding:50px;'><h3>πŸ‘‹ Welcome to ChatDoctor!</h3><p style='color:#6c757d;'>Describe your symptoms or ask a health-related question to begin.</p></div>",
 
 
 
 
 
 
420
  show_label=False,
421
+ avatar_images=(None, "πŸ€–"),
 
 
422
  )
 
423
 
424
+ with gr.Row():
425
+ msg = gr.Textbox(
426
+ placeholder="Type your medical concern here...",
427
+ show_label=False,
428
+ scale=9,
429
+ container=False,
430
+ lines=1
431
+ )
432
+ send_btn = gr.Button("Send πŸ“€", scale=1, variant="primary")
433
+
434
+ with gr.Row():
435
+ clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", scale=1)
436
+ retry_btn = gr.Button("πŸ”„ Retry", scale=1)
437
+
438
+ with gr.Tab("🎀 Voice Chat"):
439
+ gr.HTML('<div class="voice-section"><h3>πŸŽ™οΈ Voice Interaction</h3><p>Record your medical question and get voice responses!</p></div>')
440
+
441
+ voice_chatbot = gr.Chatbot(
442
+ height=400,
443
+ placeholder="<div style='text-align:center;padding:40px;'><h3>🎀 Voice Chat Mode</h3><p>Click the microphone to record your question</p></div>",
444
+ show_label=False,
445
+ avatar_images=(None, "πŸ€–"),
446
+ )
447
+
448
+ with gr.Row():
449
+ audio_input = gr.Audio(
450
+ sources=["microphone"],
451
+ type="numpy",
452
+ label="🎀 Record Your Question",
453
+ scale=8
454
+ )
455
+ voice_send_btn = gr.Button("Send Voice πŸŽ™οΈ", scale=2, variant="primary")
456
+
457
+ audio_output = gr.Audio(
458
+ label="πŸ”Š Voice Response",
459
+ autoplay=True,
460
+ visible=TTS_AVAILABLE
461
+ )
462
+
463
+ transcribed_text = gr.Textbox(
464
+ label="πŸ“ Transcribed Text",
465
+ interactive=False,
466
+ visible=True
467
+ )
468
+
469
+ with gr.Row():
470
+ voice_clear_btn = gr.Button("πŸ—‘οΈ Clear Voice Chat", scale=1)
471
+
472
+ if not TTS_AVAILABLE:
473
+ gr.Warning("⚠️ TTS model not available. Voice responses disabled. Text responses will still work.")
474
 
475
  with gr.Accordion("βš™οΈ Advanced Settings", open=False):
476
  temp_slider = gr.Slider(0.1, 1.0, TEMPERATURE, 0.1, label="Temperature (Lower = More Focused)")
477
  max_tok_slider = gr.Slider(50, 500, MAX_NEW_TOKENS, 50, label="Max Tokens")
478
  top_k_slider = gr.Slider(1, 100, TOP_K, 1, label="Top-K Sampling")
479
 
480
+ # =============================
481
+ # Text Chat Functions
482
+ # =============================
483
  def user_message(user_msg, history):
484
  if not user_msg.strip():
485
  return "", history
 
505
  history[-1][1] = bot_msg
506
  return history
507
 
508
+ # =============================
509
+ # Voice Chat Functions
510
+ # =============================
511
+ def process_voice_input(audio, history, temp, max_tok, topk, session_id):
512
+ """Process voice input: transcribe, get response, convert to speech"""
513
+ if audio is None:
514
+ return history, "", None
515
+
516
+ # Transcribe audio to text
517
+ transcribed = transcribe_audio(audio)
518
+
519
+ if not transcribed:
520
+ return history, "⚠️ Could not transcribe audio. Please try again.", None
521
+
522
+ # Add to history
523
+ history = history + [[transcribed, None]]
524
+
525
+ # Get bot response
526
+ global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
527
+ TEMPERATURE, MAX_NEW_TOKENS, TOP_K = temp, int(max_tok), int(topk)
528
+
529
+ bot_msg = get_response(transcribed, history[:-1], session_id)
530
+ history[-1][1] = bot_msg
531
+
532
+ # Convert response to speech
533
+ audio_response = text_to_speech(bot_msg) if TTS_AVAILABLE else None
534
+
535
+ return history, transcribed, audio_response
536
+
537
+ # Text Chat Events
538
  msg.submit(user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
539
  bot_response, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot
540
  )
 
544
  clear_btn.click(lambda: None, None, chatbot, queue=False)
545
  retry_btn.click(retry_last, [chatbot, temp_slider, max_tok_slider, top_k_slider, session_state], chatbot)
546
 
547
+ # Voice Chat Events
548
+ voice_send_btn.click(
549
+ process_voice_input,
550
+ [audio_input, voice_chatbot, temp_slider, max_tok_slider, top_k_slider, session_state],
551
+ [voice_chatbot, transcribed_text, audio_output]
552
+ )
553
+ voice_clear_btn.click(lambda: (None, "", None), None, [voice_chatbot, transcribed_text, audio_output], queue=False)
554
+
555
  gr.HTML(f"""
556
  <footer>
557
+ <p><strong>🧠 Powered by LLaMA + Whisper + TTS</strong></p>
558
  <p>Device: {device.upper()} | Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute</p>
559
+ <p>🎀 Voice: Whisper ASR | πŸ”Š TTS: {"Enabled" if TTS_AVAILABLE else "Disabled"}</p>
560
  <p style='font-size:0.85em;margin-top:10px;'>
561
  This AI provides general health information only. Always consult healthcare professionals for medical advice.
562
  </p>
 
567
  # Launch App
568
  # =============================
569
  if __name__ == "__main__":
570
+ print("\nπŸ’‘ Launching ChatDoctor with Voice Support...")
571
  print(f"πŸ“Š Configuration:")
 
 
572
  print(f" - Device: {device.upper()}")
573
+ print(f" - Whisper Model: {WHISPER_MODEL}")
574
+ print(f" - TTS Available: {TTS_AVAILABLE}")
575
+ print(f" - Rate Limit: {MAX_REQUESTS_PER_MINUTE} requests/minute")
576
  demo.queue()
577
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False)