Muhammadidrees committed on
Commit
cb75158
Β·
verified Β·
1 Parent(s): f5f43e6

Rename frontend_VOic.py to app.py

Browse files
Files changed (1) hide show
  1. frontend_VOic.py β†’ app.py +462 -458
frontend_VOic.py β†’ app.py RENAMED
@@ -1,459 +1,463 @@
1
- import os
2
- import gc
3
- import torch
4
- import gradio as gr
5
- from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
-
7
- # =============================
8
- # Configuration
9
- # =============================
10
- MODEL_PATH = r"C:\Users\JAY\Downloads\Chatdoc\ChatDoctor\pretrained"
11
- MAX_NEW_TOKENS = 200
12
- TEMPERATURE = 0.5
13
- TOP_K = 50
14
- REPETITION_PENALTY = 1.1
15
-
16
- # Detect device
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"Loading model from {MODEL_PATH} on {device}...")
19
-
20
- # =============================
21
- # Load Tokenizer and Model
22
- # =============================
23
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
- model = LlamaForCausalLM.from_pretrained(
25
- MODEL_PATH,
26
- device_map="auto",
27
- torch_dtype=torch.float16,
28
- low_cpu_mem_usage=True
29
- )
30
-
31
- generator = model.generate
32
- print("βœ… ChatDoctor model loaded successfully!\n")
33
-
34
- # =============================
35
- # Stopping Criteria
36
- # =============================
37
- class StopOnTokens(StoppingCriteria):
38
- def __init__(self, stop_ids):
39
- self.stop_ids = stop_ids
40
-
41
- def __call__(self, input_ids, scores, **kwargs):
42
- for stop_id_seq in self.stop_ids:
43
- if len(stop_id_seq) == 1:
44
- if input_ids[0][-1] == stop_id_seq[0]:
45
- return True
46
- else:
47
- if len(input_ids[0]) >= len(stop_id_seq):
48
- if input_ids[0][-len(stop_id_seq):].tolist() == stop_id_seq:
49
- return True
50
- return False
51
-
52
- # =============================
53
- # Get Response Function
54
- # =============================
55
- def get_response(user_input, history_context):
56
- """Generate response from ChatDoctor model"""
57
- human_invitation = "Patient: "
58
- doctor_invitation = "ChatDoctor: "
59
-
60
- # Build conversation from history
61
- history_text = []
62
- for human, assistant in history_context:
63
- if human:
64
- history_text.append(human_invitation + human)
65
- if assistant:
66
- history_text.append(doctor_invitation + assistant)
67
-
68
- # Add current user input
69
- history_text.append(human_invitation + user_input)
70
-
71
- # Build conversation prompt
72
- prompt = "\n".join(history_text) + "\n" + doctor_invitation
73
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
74
-
75
- # Define stop words and their token IDs
76
- stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
77
- stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
78
- stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])
79
-
80
- # Generate model response
81
- with torch.no_grad():
82
- output_ids = generator(
83
- input_ids,
84
- max_new_tokens=MAX_NEW_TOKENS,
85
- do_sample=True,
86
- temperature=TEMPERATURE,
87
- top_k=TOP_K,
88
- repetition_penalty=REPETITION_PENALTY,
89
- stopping_criteria=stopping_criteria,
90
- pad_token_id=tokenizer.eos_token_id,
91
- eos_token_id=tokenizer.eos_token_id
92
- )
93
-
94
- # Decode and clean response
95
- full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
96
- response = full_output[len(prompt):].strip()
97
-
98
- # Remove any "Patient:" that might have slipped through
99
- for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
100
- if stop_word in response:
101
- response = response.split(stop_word)[0].strip()
102
- break
103
-
104
- response = response.strip()
105
-
106
- # Free memory
107
- del input_ids, output_ids
108
- gc.collect()
109
- torch.cuda.empty_cache()
110
-
111
- return response
112
-
113
- # =============================
114
- # Gradio Chat Function
115
- # =============================
116
- def chat_function(message, history):
117
- """Gradio chat interface function"""
118
- if not message.strip():
119
- return ""
120
-
121
- try:
122
- response = get_response(message, history)
123
- return response
124
- except Exception as e:
125
- return f"Error: {str(e)}"
126
-
127
- # =============================
128
- # Text-to-Speech Function
129
- # =============================
130
- def text_to_speech(text):
131
- """Convert text response to speech"""
132
- try:
133
- from gtts import gTTS
134
- import tempfile
135
-
136
- if not text or text.startswith("Error:"):
137
- return None
138
-
139
- # Create speech
140
- tts = gTTS(text=text, lang='en', slow=False)
141
-
142
- # Save to temporary file
143
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
144
- tts.save(temp_file.name)
145
-
146
- return temp_file.name
147
- except Exception as e:
148
- print(f"TTS Error: {e}")
149
- return None
150
-
151
- # =============================
152
- # Custom CSS
153
- # =============================
154
- custom_css = """
155
- #header {
156
- text-align: center;
157
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
158
- color: white;
159
- padding: 20px;
160
- border-radius: 10px;
161
- margin-bottom: 20px;
162
- }
163
-
164
- #header h1 {
165
- margin: 0;
166
- font-size: 2.5em;
167
- }
168
-
169
- #header p {
170
- margin: 10px 0 0 0;
171
- font-size: 1.1em;
172
- opacity: 0.9;
173
- }
174
-
175
- .disclaimer {
176
- background-color: #fff3cd;
177
- border: 1px solid #ffc107;
178
- border-radius: 8px;
179
- padding: 15px;
180
- margin: 20px 0;
181
- color: #856404;
182
- }
183
-
184
- .disclaimer h3 {
185
- margin-top: 0;
186
- color: #856404;
187
- }
188
-
189
- .voice-section {
190
- background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
191
- padding: 20px;
192
- border-radius: 10px;
193
- margin: 20px 0;
194
- }
195
-
196
- footer {
197
- text-align: center;
198
- margin-top: 30px;
199
- color: #666;
200
- font-size: 0.9em;
201
- }
202
- """
203
-
204
- # =============================
205
- # Gradio Interface
206
- # =============================
207
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
208
- # Header
209
- gr.HTML("""
210
- <div id="header">
211
- <h1>🩺 ChatDoctor AI Assistant</h1>
212
- <p>Your AI-powered medical conversation partner with Voice Support</p>
213
- </div>
214
- """)
215
-
216
- # Disclaimer
217
- gr.HTML("""
218
- <div class="disclaimer">
219
- <h3>⚠️ Medical Disclaimer</h3>
220
- <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
221
- It is NOT a substitute for professional medical advice, diagnosis, or treatment.
222
- Always seek the advice of your physician or other qualified health provider with any questions
223
- you may have regarding a medical condition. Never disregard professional medical advice or
224
- delay in seeking it because of something you have read here.</p>
225
- </div>
226
- """)
227
-
228
- with gr.Row():
229
- with gr.Column(scale=7):
230
- # Chatbot Interface
231
- chatbot = gr.Chatbot(
232
- height=500,
233
- placeholder="<div style='text-align: center; padding: 40px;'><h3>πŸ‘‹ Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. Type or speak your question!</p></div>",
234
- show_label=False,
235
- avatar_images=(None, "πŸ€–"),
236
- )
237
-
238
- with gr.Row():
239
- msg = gr.Textbox(
240
- placeholder="Type your message here... (e.g., 'I have a headache')",
241
- show_label=False,
242
- scale=9,
243
- container=False
244
- )
245
- submit_btn = gr.Button("Send πŸ“€", scale=1, variant="primary")
246
-
247
- with gr.Row():
248
- clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", scale=1)
249
- retry_btn = gr.Button("πŸ”„ Retry", scale=1)
250
-
251
- with gr.Column(scale=3):
252
- # Voice Input Section
253
- gr.HTML("<div class='voice-section'><h3 style='color: white; text-align: center; margin-top: 0;'>🎀 Voice Features</h3></div>")
254
-
255
- audio_input = gr.Audio(
256
- sources=["microphone"],
257
- type="filepath",
258
- label="πŸŽ™οΈ Speak Your Question",
259
- show_download_button=False
260
- )
261
-
262
- transcribed_text = gr.Textbox(
263
- label="πŸ“ Transcribed Text",
264
- placeholder="Your speech will appear here...",
265
- interactive=False,
266
- lines=3
267
- )
268
-
269
- send_voice_btn = gr.Button("Send Voice Message πŸ”Š", variant="primary")
270
-
271
- gr.Markdown("---")
272
-
273
- # Voice Output
274
- tts_enabled = gr.Checkbox(
275
- label="πŸ”Š Enable Text-to-Speech for responses",
276
- value=True,
277
- info="Hear the doctor's response"
278
- )
279
-
280
- audio_output = gr.Audio(
281
- label="πŸ”ˆ AI Response Audio",
282
- autoplay=False,
283
- visible=True
284
- )
285
-
286
- # Examples
287
- gr.Examples(
288
- examples=[
289
- "I have a persistent headache for 3 days. What should I do?",
290
- "What are the symptoms of diabetes?",
291
- "How can I improve my sleep quality?",
292
- "I have a fever and sore throat. Should I be concerned?",
293
- "What are some natural ways to reduce stress?",
294
- ],
295
- inputs=msg,
296
- label="πŸ’‘ Example Questions"
297
- )
298
-
299
- # Settings (collapsed by default)
300
- with gr.Accordion("βš™οΈ Advanced Settings", open=False):
301
- temperature_slider = gr.Slider(
302
- minimum=0.1,
303
- maximum=1.0,
304
- value=TEMPERATURE,
305
- step=0.1,
306
- label="Temperature (Creativity)",
307
- info="Higher values make responses more creative but less focused"
308
- )
309
- max_tokens_slider = gr.Slider(
310
- minimum=50,
311
- maximum=500,
312
- value=MAX_NEW_TOKENS,
313
- step=50,
314
- label="Max Response Length",
315
- info="Maximum number of tokens in response"
316
- )
317
- top_k_slider = gr.Slider(
318
- minimum=1,
319
- maximum=100,
320
- value=TOP_K,
321
- step=1,
322
- label="Top K",
323
- info="Limits vocabulary selection"
324
- )
325
-
326
- # Footer
327
- gr.HTML("""
328
- <footer>
329
- <p>Powered by ChatDoctor Model | Built with Gradio | Voice-Enabled 🎀</p>
330
- <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
331
- </footer>
332
- """)
333
-
334
- # =============================
335
- # Event Handlers
336
- # =============================
337
-
338
- def user_message(user_msg, history):
339
- return "", history + [[user_msg, None]], None
340
-
341
- def bot_response(history, temp, max_tok, top_k_val, tts_enabled_val):
342
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
343
- TEMPERATURE = temp
344
- MAX_NEW_TOKENS = int(max_tok)
345
- TOP_K = int(top_k_val)
346
-
347
- user_msg = history[-1][0]
348
- bot_msg = chat_function(user_msg, history[:-1])
349
- history[-1][1] = bot_msg
350
-
351
- # Generate audio if TTS is enabled
352
- audio_file = None
353
- if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
354
- audio_file = text_to_speech(bot_msg)
355
-
356
- return history, audio_file
357
-
358
- def transcribe_audio(audio_file):
359
- """Transcribe audio to text using Whisper"""
360
- if audio_file is None:
361
- return ""
362
-
363
- try:
364
- import whisper
365
- model = whisper.load_model("base")
366
- result = model.transcribe(audio_file)
367
- return result["text"]
368
- except ImportError:
369
- return "Error: Please install whisper: pip install openai-whisper"
370
- except Exception as e:
371
- return f"Transcription error: {str(e)}"
372
-
373
- def process_voice_input(audio_file, history, temp, max_tok, top_k_val, tts_enabled_val):
374
- """Process voice input: transcribe -> send -> get response"""
375
- if audio_file is None:
376
- return history, "", None, None
377
-
378
- # Transcribe
379
- transcribed = transcribe_audio(audio_file)
380
-
381
- if transcribed.startswith("Error:"):
382
- return history, transcribed, None, None
383
-
384
- # Add to chat
385
- history = history + [[transcribed, None]]
386
-
387
- # Get response
388
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
389
- TEMPERATURE = temp
390
- MAX_NEW_TOKENS = int(max_tok)
391
- TOP_K = int(top_k_val)
392
-
393
- bot_msg = chat_function(transcribed, history[:-1])
394
- history[-1][1] = bot_msg
395
-
396
- # Generate audio if TTS is enabled
397
- audio_file = None
398
- if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
399
- audio_file = text_to_speech(bot_msg)
400
-
401
- return history, transcribed, None, audio_file
402
-
403
- # Text input events
404
- msg.submit(
405
- user_message,
406
- [msg, chatbot],
407
- [msg, chatbot, audio_output],
408
- queue=False
409
- ).then(
410
- bot_response,
411
- [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
412
- [chatbot, audio_output]
413
- )
414
-
415
- submit_btn.click(
416
- user_message,
417
- [msg, chatbot],
418
- [msg, chatbot, audio_output],
419
- queue=False
420
- ).then(
421
- bot_response,
422
- [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
423
- [chatbot, audio_output]
424
- )
425
-
426
- # Voice input events
427
- audio_input.change(
428
- transcribe_audio,
429
- [audio_input],
430
- [transcribed_text]
431
- )
432
-
433
- send_voice_btn.click(
434
- process_voice_input,
435
- [audio_input, chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
436
- [chatbot, transcribed_text, audio_input, audio_output]
437
- )
438
-
439
- # Clear and retry
440
- clear_btn.click(lambda: (None, None, None), None, [chatbot, audio_output, transcribed_text], queue=False)
441
-
442
- retry_btn.click(lambda: None, None, chatbot, queue=False)
443
-
444
- # =============================
445
- # Launch Interface
446
- # =============================
447
- if __name__ == "__main__":
448
- print("\nπŸš€ Launching ChatDoctor Gradio Interface with Voice Support...")
449
- print("\nπŸ“¦ Required packages:")
450
- print(" pip install gradio gTTS openai-whisper")
451
- print("\nNote: Whisper will download models on first use (~100MB for base model)\n")
452
-
453
- demo.queue()
454
- demo.launch(
455
- server_name="0.0.0.0",
456
- server_port=7860,
457
- share=False,
458
- show_error=True
 
 
 
 
459
  )
 
import os
import gc
import torch
import gradio as gr
from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
from huggingface_hub import login

# Login using the token stored in repository secrets.
# Only attempt the login when the secret is actually set: login(token=None)
# falls back to an interactive prompt, which hangs a headless deployment.
_hf_token = os.getenv("HUGGINGFACE_TOKEN")
if _hf_token:
    login(token=_hf_token)

# =============================
# Configuration
# =============================
# NOTE(review): Hugging Face repo ids use forward slashes; the backslash in
# "zl111/ChatDoctor\result" looks like a Windows-path typo (possibly meant as
# the `subfolder="result"` argument). Kept verbatim so this edit does not
# silently change which model is loaded — confirm and fix the separator.
MODEL_PATH = r"zl111/ChatDoctor\result"
MAX_NEW_TOKENS = 200        # default generation length (tunable in the UI)
TEMPERATURE = 0.5           # default sampling temperature
TOP_K = 50                  # default top-k sampling cutoff
REPETITION_PENALTY = 1.1    # >1.0 discourages the model from repeating itself

# Detect device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model from {MODEL_PATH} on {device}...")

# =============================
# Load Tokenizer and Model
# =============================
tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LlamaForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",          # let accelerate place layers on the detected device(s)
    torch_dtype=torch.float16,  # halve memory use vs float32
    low_cpu_mem_usage=True
)

# Alias kept because the rest of the file calls `generator(...)`.
generator = model.generate
print("βœ… ChatDoctor model loaded successfully!\n")
37
+
38
# =============================
# Stopping Criteria
# =============================
class StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with any configured stop sequence.

    `stop_ids` is a list of token-id sequences (e.g. the encodings of
    "Patient:") that mark the start of the next speaker's turn.
    """

    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        generated = input_ids[0]
        for seq in self.stop_ids:
            if len(seq) == 1:
                # Single-token sequence: compare against the last token only.
                if generated[-1] == seq[0]:
                    return True
            elif len(generated) >= len(seq) and generated[-len(seq):].tolist() == seq:
                # Multi-token sequence: compare against the generated tail.
                return True
        return False
55
+
56
# =============================
# Get Response Function
# =============================
def get_response(user_input, history_context):
    """Generate a ChatDoctor reply for the latest patient message.

    Args:
        user_input: The patient's newest message.
        history_context: Prior turns as (patient, doctor) pairs; either
            element may be falsy and is then skipped.

    Returns:
        The model's reply text, stripped of any leaked "Patient:" turn.
    """
    human_invitation = "Patient: "
    doctor_invitation = "ChatDoctor: "

    # Build conversation from history
    history_text = []
    for human, assistant in history_context:
        if human:
            history_text.append(human_invitation + human)
        if assistant:
            history_text.append(doctor_invitation + assistant)

    # Add current user input
    history_text.append(human_invitation + user_input)

    # The prompt ends with the doctor tag so the model continues as the doctor.
    prompt = "\n".join(history_text) + "\n" + doctor_invitation
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Stop as soon as the model starts the next "Patient:" turn.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_ids = [tokenizer.encode(word, add_special_tokens=False) for word in stop_words]
    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_ids)])

    # Generate model response
    with torch.no_grad():
        output_ids = generator(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=stopping_criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens. Slicing by token count is robust,
    # unlike the previous `full_output[len(prompt):]`, which silently truncated
    # or leaked prompt text whenever decode() did not reproduce the prompt
    # byte-for-byte (tokenizer whitespace normalization, special tokens).
    prompt_token_count = input_ids.shape[-1]
    response = tokenizer.decode(
        output_ids[0][prompt_token_count:], skip_special_tokens=True
    ).strip()

    # Remove any "Patient:" that might have slipped through. The bare "Patient"
    # fallback is deliberately last so the more precise variants match first.
    for stop_word in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if stop_word in response:
            response = response.split(stop_word)[0].strip()
            break

    response = response.strip()

    # Free memory; only touch the CUDA allocator when CUDA is present so the
    # CPU-only path cannot fail.
    del input_ids, output_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return response
116
+
117
# =============================
# Gradio Chat Function
# =============================
def chat_function(message, history):
    """Bridge between the Gradio UI and the model.

    Returns "" for blank input, the model reply on success, or an
    "Error: ..." string if generation raises.
    """
    if not message.strip():
        return ""

    try:
        return get_response(message, history)
    except Exception as exc:
        return f"Error: {exc}"
130
+
131
# =============================
# Text-to-Speech Function
# =============================
def text_to_speech(text):
    """Convert a reply to an MP3 file and return its path.

    Returns None for empty input, error replies (the "Error:" prefix),
    a missing gTTS dependency, or any synthesis failure.
    """
    try:
        from gtts import gTTS
        import tempfile
        import os

        if not text or text.startswith("Error:"):
            return None

        # Create speech
        tts = gTTS(text=text, lang='en', slow=False)

        # Write to a *closed* temp path. The previous NamedTemporaryFile was
        # still open while gTTS wrote to its name, which fails on Windows
        # (the OS locks open files). mkstemp + close avoids that; the file is
        # intentionally kept on disk so Gradio can serve it.
        fd, audio_path = tempfile.mkstemp(suffix='.mp3')
        os.close(fd)
        tts.save(audio_path)

        return audio_path
    except Exception as e:
        print(f"TTS Error: {e}")
        return None
154
+
155
# =============================
# Custom CSS
# =============================
# Styling for the Gradio Blocks UI: purple gradient header, amber medical
# disclaimer card, pink voice-feature panel, and a muted footer.
custom_css = """
#header {
    text-align: center;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    padding: 20px;
    border-radius: 10px;
    margin-bottom: 20px;
}

#header h1 {
    margin: 0;
    font-size: 2.5em;
}

#header p {
    margin: 10px 0 0 0;
    font-size: 1.1em;
    opacity: 0.9;
}

.disclaimer {
    background-color: #fff3cd;
    border: 1px solid #ffc107;
    border-radius: 8px;
    padding: 15px;
    margin: 20px 0;
    color: #856404;
}

.disclaimer h3 {
    margin-top: 0;
    color: #856404;
}

.voice-section {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
    padding: 20px;
    border-radius: 10px;
    margin: 20px 0;
}

footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""
207
+
208
# =============================
# Gradio Interface
# =============================
# NOTE(review): the original indentation was lost in the diff rendering;
# the Examples / Settings / Footer sections are placed at Blocks level here —
# confirm against the deployed layout.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.HTML("""
    <div id="header">
        <h1>🩺 ChatDoctor AI Assistant</h1>
        <p>Your AI-powered medical conversation partner with Voice Support</p>
    </div>
    """)

    # Disclaimer
    gr.HTML("""
    <div class="disclaimer">
        <h3>⚠️ Medical Disclaimer</h3>
        <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
        It is NOT a substitute for professional medical advice, diagnosis, or treatment.
        Always seek the advice of your physician or other qualified health provider with any questions
        you may have regarding a medical condition. Never disregard professional medical advice or
        delay in seeking it because of something you have read here.</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=7):
            # Chatbot Interface
            chatbot = gr.Chatbot(
                height=500,
                placeholder="<div style='text-align: center; padding: 40px;'><h3>πŸ‘‹ Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. Type or speak your question!</p></div>",
                show_label=False,
                avatar_images=(None, "πŸ€–"),
            )

            with gr.Row():
                msg = gr.Textbox(
                    placeholder="Type your message here... (e.g., 'I have a headache')",
                    show_label=False,
                    scale=9,
                    container=False
                )
                submit_btn = gr.Button("Send πŸ“€", scale=1, variant="primary")

            with gr.Row():
                clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", scale=1)
                retry_btn = gr.Button("πŸ”„ Retry", scale=1)

        with gr.Column(scale=3):
            # Voice Input Section
            gr.HTML("<div class='voice-section'><h3 style='color: white; text-align: center; margin-top: 0;'>🎀 Voice Features</h3></div>")

            # Microphone capture; `type="filepath"` hands a file path to the
            # transcription handler rather than raw audio data.
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="πŸŽ™οΈ Speak Your Question",
                show_download_button=False
            )

            transcribed_text = gr.Textbox(
                label="πŸ“ Transcribed Text",
                placeholder="Your speech will appear here...",
                interactive=False,
                lines=3
            )

            send_voice_btn = gr.Button("Send Voice Message πŸ”Š", variant="primary")

            gr.Markdown("---")

            # Voice Output
            tts_enabled = gr.Checkbox(
                label="πŸ”Š Enable Text-to-Speech for responses",
                value=True,
                info="Hear the doctor's response"
            )

            audio_output = gr.Audio(
                label="πŸ”ˆ AI Response Audio",
                autoplay=False,
                visible=True
            )

    # Examples
    gr.Examples(
        examples=[
            "I have a persistent headache for 3 days. What should I do?",
            "What are the symptoms of diabetes?",
            "How can I improve my sleep quality?",
            "I have a fever and sore throat. Should I be concerned?",
            "What are some natural ways to reduce stress?",
        ],
        inputs=msg,
        label="πŸ’‘ Example Questions"
    )

    # Settings (collapsed by default)
    with gr.Accordion("βš™οΈ Advanced Settings", open=False):
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=TEMPERATURE,
            step=0.1,
            label="Temperature (Creativity)",
            info="Higher values make responses more creative but less focused"
        )
        max_tokens_slider = gr.Slider(
            minimum=50,
            maximum=500,
            value=MAX_NEW_TOKENS,
            step=50,
            label="Max Response Length",
            info="Maximum number of tokens in response"
        )
        top_k_slider = gr.Slider(
            minimum=1,
            maximum=100,
            value=TOP_K,
            step=1,
            label="Top K",
            info="Limits vocabulary selection"
        )

    # Footer
    gr.HTML("""
    <footer>
        <p>Powered by ChatDoctor Model | Built with Gradio | Voice-Enabled 🎀</p>
        <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
    </footer>
    """)

    # =============================
    # Event Handlers
    # =============================

    def user_message(user_msg, history):
        """Stage the user's turn: clear the textbox, append the message to the
        chat with a pending reply slot, and reset the response audio."""
        return "", history + [[user_msg, None]], None

    def bot_response(history, temp, max_tok, top_k_val, tts_enabled_val):
        """Fill in the pending reply slot for the newest user turn."""
        # NOTE(review): mutating module-level globals applies one user's
        # slider settings to every concurrent session — not request-safe.
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

        user_msg = history[-1][0]
        bot_msg = chat_function(user_msg, history[:-1])
        history[-1][1] = bot_msg

        # Generate audio if TTS is enabled
        audio_file = None
        if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
            audio_file = text_to_speech(bot_msg)

        return history, audio_file

    def transcribe_audio(audio_file):
        """Transcribe audio to text using Whisper."""
        if audio_file is None:
            return ""

        try:
            import whisper
            # NOTE(review): the Whisper "base" model is reloaded on every
            # call; caching it would avoid repeated load time. The local
            # name `model` only shadows the global LLaMA model here.
            model = whisper.load_model("base")
            result = model.transcribe(audio_file)
            return result["text"]
        except ImportError:
            return "Error: Please install whisper: pip install openai-whisper"
        except Exception as e:
            return f"Transcription error: {str(e)}"

    def process_voice_input(audio_file, history, temp, max_tok, top_k_val, tts_enabled_val):
        """Process voice input: transcribe -> send -> get response."""
        if audio_file is None:
            return history, "", None, None

        # Transcribe
        transcribed = transcribe_audio(audio_file)

        if transcribed.startswith("Error:"):
            return history, transcribed, None, None

        # Add to chat
        history = history + [[transcribed, None]]

        # Get response (same global-mutation caveat as bot_response)
        global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
        TEMPERATURE = temp
        MAX_NEW_TOKENS = int(max_tok)
        TOP_K = int(top_k_val)

        bot_msg = chat_function(transcribed, history[:-1])
        history[-1][1] = bot_msg

        # Generate audio if TTS is enabled
        audio_file = None
        if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
            audio_file = text_to_speech(bot_msg)

        return history, transcribed, None, audio_file

    # Text input events: stage the user turn without queueing, then run the
    # (slow) generation step through the queue.
    msg.submit(
        user_message,
        [msg, chatbot],
        [msg, chatbot, audio_output],
        queue=False
    ).then(
        bot_response,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
        [chatbot, audio_output]
    )

    submit_btn.click(
        user_message,
        [msg, chatbot],
        [msg, chatbot, audio_output],
        queue=False
    ).then(
        bot_response,
        [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
        [chatbot, audio_output]
    )

    # Voice input events: live transcription preview on every recording change.
    audio_input.change(
        transcribe_audio,
        [audio_input],
        [transcribed_text]
    )

    send_voice_btn.click(
        process_voice_input,
        [audio_input, chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
        [chatbot, transcribed_text, audio_input, audio_output]
    )

    # Clear and retry
    clear_btn.click(lambda: (None, None, None), None, [chatbot, audio_output, transcribed_text], queue=False)

    # NOTE(review): "Retry" does not retry — the lambda returns None into
    # `chatbot`, which clears the conversation. A real retry would drop the
    # last bot reply and re-run bot_response.
    retry_btn.click(lambda: None, None, chatbot, queue=False)
447
+
448
# =============================
# Launch Interface
# =============================
if __name__ == "__main__":
    # Startup banner with install hints for the optional voice dependencies
    # (gTTS for speech output, openai-whisper for speech input).
    print("\nπŸš€ Launching ChatDoctor Gradio Interface with Voice Support...")
    print("\nπŸ“¦ Required packages:")
    print(" pip install gradio gTTS openai-whisper")
    print("\nNote: Whisper will download models on first use (~100MB for base model)\n")

    # Enable request queueing (long generations would otherwise time out),
    # then serve on all interfaces at Gradio's standard port.
    demo.queue()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )