Muhammadidrees committed on
Commit
e666052
·
verified ·
1 Parent(s): e579965

Delete withvoiceandfrontend .py

Browse files
Files changed (1) hide show
  1. withvoiceandfrontend .py +0 -459
withvoiceandfrontend .py DELETED
@@ -1,459 +0,0 @@
1
- import os
2
- import gc
3
- import torch
4
- import gradio as gr
5
- from transformers import LlamaTokenizer, LlamaForCausalLM, StoppingCriteria, StoppingCriteriaList
6
-
7
- # =============================
8
- # Configuration
9
- # =============================
10
- MODEL_PATH = r"Muhammadidrees/JayConverstionalModel"
11
- MAX_NEW_TOKENS = 200
12
- TEMPERATURE = 0.5
13
- TOP_K = 50
14
- REPETITION_PENALTY = 1.1
15
-
16
- # Detect device
17
- device = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"Loading model from {MODEL_PATH} on {device}...")
19
-
20
- # =============================
21
- # Load Tokenizer and Model
22
- # =============================
23
- tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
24
- model = LlamaForCausalLM.from_pretrained(
25
- MODEL_PATH,
26
- device_map="auto",
27
- torch_dtype=torch.float16,
28
- low_cpu_mem_usage=True
29
- )
30
-
31
- generator = model.generate
32
- print("✅ ChatDoctor model loaded successfully!\n")
33
-
34
- # =============================
35
- # Stopping Criteria
36
- # =============================
37
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the tail of the running sequence matches
    any of the configured stop-token sequences.

    Each entry in ``stop_ids`` is a list of token ids; single-token entries
    are matched against the last generated token only.
    """

    def __init__(self, stop_ids):
        # stop_ids: list of token-id lists (one per stop phrase).
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Only batch element 0 is inspected (generation here is unbatched).
        sequence = input_ids[0]
        for candidate in self.stop_ids:
            span = len(candidate)
            if span == 1:
                if sequence[-1] == candidate[0]:
                    return True
            elif len(sequence) >= span and sequence[-span:].tolist() == candidate:
                return True
        return False
51
-
52
- # =============================
53
- # Get Response Function
54
- # =============================
55
def get_response(user_input, history_context):
    """Generate one ChatDoctor reply for *user_input*.

    history_context is a list of (patient, doctor) message pairs; either
    element may be falsy and is then skipped. Returns the cleaned reply text.
    """
    patient_prefix = "Patient: "
    doctor_prefix = "ChatDoctor: "

    # Flatten the chat history into alternating role-tagged lines.
    lines = []
    for patient_turn, doctor_turn in history_context:
        if patient_turn:
            lines.append(patient_prefix + patient_turn)
        if doctor_turn:
            lines.append(doctor_prefix + doctor_turn)
    lines.append(patient_prefix + user_input)

    # The trailing doctor tag cues the model to answer in the doctor role.
    prompt = "\n".join(lines) + "\n" + doctor_prefix
    encoded = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # Stop decoding the moment the model starts inventing the next patient turn.
    stop_words = ["Patient:", "\nPatient:", "Patient :", "\n\nPatient"]
    stop_sequences = [tokenizer.encode(w, add_special_tokens=False) for w in stop_words]
    criteria = StoppingCriteriaList([StopOnTokens(stop_sequences)])

    with torch.no_grad():
        generated = generator(
            encoded,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=True,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            stopping_criteria=criteria,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Keep only the newly generated continuation past the prompt text.
    # NOTE(review): assumes decode() reproduces the prompt verbatim at the
    # front of the output — usually true for LLaMA tokenizers, but verify.
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    response = decoded[len(prompt):].strip()

    # Strip any leaked "Patient" tag and everything after it.
    for marker in ["Patient:", "Patient :", "\nPatient:", "\nPatient", "Patient"]:
        if marker in response:
            response = response.split(marker)[0].strip()
            break

    response = response.strip()

    # Release tensors eagerly; empty_cache() is a no-op on CPU-only builds.
    del encoded, generated
    gc.collect()
    torch.cuda.empty_cache()

    return response
112
-
113
- # =============================
114
- # Gradio Chat Function
115
- # =============================
116
def chat_function(message, history):
    """Bridge between the Gradio chat widget and the model.

    Returns "" for blank input, the model's reply on success, or an
    "Error: ..." string if generation raises.
    """
    # Ignore whitespace-only submissions.
    if not message.strip():
        return ""
    try:
        return get_response(message, history)
    except Exception as exc:
        # Surface the failure in the chat instead of crashing the UI.
        return f"Error: {str(exc)}"
126
-
127
- # =============================
128
- # Text-to-Speech Function
129
- # =============================
130
def text_to_speech(text):
    """Synthesize *text* to an MP3 with gTTS and return the file path.

    Returns None for empty input, for our own "Error:" strings, or on any
    failure (gTTS not installed, no network, ...).
    """
    try:
        from gtts import gTTS
        import tempfile

        # Never vocalize blank output or error messages.
        if not text or text.startswith("Error:"):
            return None

        speech = gTTS(text=text, lang='en', slow=False)

        # delete=False: the file must outlive this call so Gradio can play it.
        out = tempfile.NamedTemporaryFile(delete=False, suffix='.mp3')
        speech.save(out.name)
        return out.name
    except Exception as err:
        # Best-effort feature: log and degrade to no audio.
        print(f"TTS Error: {err}")
        return None
150
-
151
- # =============================
152
- # Custom CSS
153
- # =============================
154
- custom_css = """
155
- #header {
156
- text-align: center;
157
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
158
- color: white;
159
- padding: 20px;
160
- border-radius: 10px;
161
- margin-bottom: 20px;
162
- }
163
-
164
- #header h1 {
165
- margin: 0;
166
- font-size: 2.5em;
167
- }
168
-
169
- #header p {
170
- margin: 10px 0 0 0;
171
- font-size: 1.1em;
172
- opacity: 0.9;
173
- }
174
-
175
- .disclaimer {
176
- background-color: #fff3cd;
177
- border: 1px solid #ffc107;
178
- border-radius: 8px;
179
- padding: 15px;
180
- margin: 20px 0;
181
- color: #856404;
182
- }
183
-
184
- .disclaimer h3 {
185
- margin-top: 0;
186
- color: #856404;
187
- }
188
-
189
- .voice-section {
190
- background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
191
- padding: 20px;
192
- border-radius: 10px;
193
- margin: 20px 0;
194
- }
195
-
196
- footer {
197
- text-align: center;
198
- margin-top: 30px;
199
- color: #666;
200
- font-size: 0.9em;
201
- }
202
- """
203
-
204
- # =============================
205
- # Gradio Interface
206
- # =============================
207
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
208
- # Header
209
- gr.HTML("""
210
- <div id="header">
211
- <h1>🩺 ChatDoctor AI Assistant</h1>
212
- <p>Your AI-powered medical conversation partner with Voice Support</p>
213
- </div>
214
- """)
215
-
216
- # Disclaimer
217
- gr.HTML("""
218
- <div class="disclaimer">
219
- <h3>⚠️ Medical Disclaimer</h3>
220
- <p><strong>Important:</strong> This AI assistant is for informational and educational purposes only.
221
- It is NOT a substitute for professional medical advice, diagnosis, or treatment.
222
- Always seek the advice of your physician or other qualified health provider with any questions
223
- you may have regarding a medical condition. Never disregard professional medical advice or
224
- delay in seeking it because of something you have read here.</p>
225
- </div>
226
- """)
227
-
228
- with gr.Row():
229
- with gr.Column(scale=7):
230
- # Chatbot Interface
231
- chatbot = gr.Chatbot(
232
- height=500,
233
- placeholder="<div style='text-align: center; padding: 40px;'><h3>👋 Welcome to ChatDoctor!</h3><p>I'm here to discuss your health concerns. Type or speak your question!</p></div>",
234
- show_label=False,
235
- avatar_images=(None, "🤖"),
236
- )
237
-
238
- with gr.Row():
239
- msg = gr.Textbox(
240
- placeholder="Type your message here... (e.g., 'I have a headache')",
241
- show_label=False,
242
- scale=9,
243
- container=False
244
- )
245
- submit_btn = gr.Button("Send 📤", scale=1, variant="primary")
246
-
247
- with gr.Row():
248
- clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
249
- retry_btn = gr.Button("🔄 Retry", scale=1)
250
-
251
- with gr.Column(scale=3):
252
- # Voice Input Section
253
- gr.HTML("<div class='voice-section'><h3 style='color: white; text-align: center; margin-top: 0;'>🎤 Voice Features</h3></div>")
254
-
255
- audio_input = gr.Audio(
256
- sources=["microphone"],
257
- type="filepath",
258
- label="🎙️ Speak Your Question",
259
- show_download_button=False
260
- )
261
-
262
- transcribed_text = gr.Textbox(
263
- label="📝 Transcribed Text",
264
- placeholder="Your speech will appear here...",
265
- interactive=False,
266
- lines=3
267
- )
268
-
269
- send_voice_btn = gr.Button("Send Voice Message 🔊", variant="primary")
270
-
271
- gr.Markdown("---")
272
-
273
- # Voice Output
274
- tts_enabled = gr.Checkbox(
275
- label="🔊 Enable Text-to-Speech for responses",
276
- value=True,
277
- info="Hear the doctor's response"
278
- )
279
-
280
- audio_output = gr.Audio(
281
- label="🔈 AI Response Audio",
282
- autoplay=False,
283
- visible=True
284
- )
285
-
286
- # Examples
287
- gr.Examples(
288
- examples=[
289
- "I have a persistent headache for 3 days. What should I do?",
290
- "What are the symptoms of diabetes?",
291
- "How can I improve my sleep quality?",
292
- "I have a fever and sore throat. Should I be concerned?",
293
- "What are some natural ways to reduce stress?",
294
- ],
295
- inputs=msg,
296
- label="💡 Example Questions"
297
- )
298
-
299
- # Settings (collapsed by default)
300
- with gr.Accordion("⚙️ Advanced Settings", open=False):
301
- temperature_slider = gr.Slider(
302
- minimum=0.1,
303
- maximum=1.0,
304
- value=TEMPERATURE,
305
- step=0.1,
306
- label="Temperature (Creativity)",
307
- info="Higher values make responses more creative but less focused"
308
- )
309
- max_tokens_slider = gr.Slider(
310
- minimum=50,
311
- maximum=500,
312
- value=MAX_NEW_TOKENS,
313
- step=50,
314
- label="Max Response Length",
315
- info="Maximum number of tokens in response"
316
- )
317
- top_k_slider = gr.Slider(
318
- minimum=1,
319
- maximum=100,
320
- value=TOP_K,
321
- step=1,
322
- label="Top K",
323
- info="Limits vocabulary selection"
324
- )
325
-
326
- # Footer
327
- gr.HTML("""
328
- <footer>
329
- <p>Powered by ChatDoctor Model | Built with Gradio | Voice-Enabled 🎤</p>
330
- <p>Device: """ + device.upper() + """ | Model: LLaMA-based Medical AI</p>
331
- </footer>
332
- """)
333
-
334
- # =============================
335
- # Event Handlers
336
- # =============================
337
-
338
- def user_message(user_msg, history):
339
- return "", history + [[user_msg, None]], None
340
-
341
- def bot_response(history, temp, max_tok, top_k_val, tts_enabled_val):
342
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
343
- TEMPERATURE = temp
344
- MAX_NEW_TOKENS = int(max_tok)
345
- TOP_K = int(top_k_val)
346
-
347
- user_msg = history[-1][0]
348
- bot_msg = chat_function(user_msg, history[:-1])
349
- history[-1][1] = bot_msg
350
-
351
- # Generate audio if TTS is enabled
352
- audio_file = None
353
- if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
354
- audio_file = text_to_speech(bot_msg)
355
-
356
- return history, audio_file
357
-
358
- def transcribe_audio(audio_file):
359
- """Transcribe audio to text using Whisper"""
360
- if audio_file is None:
361
- return ""
362
-
363
- try:
364
- import whisper
365
- model = whisper.load_model("base")
366
- result = model.transcribe(audio_file)
367
- return result["text"]
368
- except ImportError:
369
- return "Error: Please install whisper: pip install openai-whisper"
370
- except Exception as e:
371
- return f"Transcription error: {str(e)}"
372
-
373
- def process_voice_input(audio_file, history, temp, max_tok, top_k_val, tts_enabled_val):
374
- """Process voice input: transcribe -> send -> get response"""
375
- if audio_file is None:
376
- return history, "", None, None
377
-
378
- # Transcribe
379
- transcribed = transcribe_audio(audio_file)
380
-
381
- if transcribed.startswith("Error:"):
382
- return history, transcribed, None, None
383
-
384
- # Add to chat
385
- history = history + [[transcribed, None]]
386
-
387
- # Get response
388
- global TEMPERATURE, MAX_NEW_TOKENS, TOP_K
389
- TEMPERATURE = temp
390
- MAX_NEW_TOKENS = int(max_tok)
391
- TOP_K = int(top_k_val)
392
-
393
- bot_msg = chat_function(transcribed, history[:-1])
394
- history[-1][1] = bot_msg
395
-
396
- # Generate audio if TTS is enabled
397
- audio_file = None
398
- if tts_enabled_val and bot_msg and not bot_msg.startswith("Error:"):
399
- audio_file = text_to_speech(bot_msg)
400
-
401
- return history, transcribed, None, audio_file
402
-
403
- # Text input events
404
- msg.submit(
405
- user_message,
406
- [msg, chatbot],
407
- [msg, chatbot, audio_output],
408
- queue=False
409
- ).then(
410
- bot_response,
411
- [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
412
- [chatbot, audio_output]
413
- )
414
-
415
- submit_btn.click(
416
- user_message,
417
- [msg, chatbot],
418
- [msg, chatbot, audio_output],
419
- queue=False
420
- ).then(
421
- bot_response,
422
- [chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
423
- [chatbot, audio_output]
424
- )
425
-
426
- # Voice input events
427
- audio_input.change(
428
- transcribe_audio,
429
- [audio_input],
430
- [transcribed_text]
431
- )
432
-
433
- send_voice_btn.click(
434
- process_voice_input,
435
- [audio_input, chatbot, temperature_slider, max_tokens_slider, top_k_slider, tts_enabled],
436
- [chatbot, transcribed_text, audio_input, audio_output]
437
- )
438
-
439
- # Clear and retry
440
- clear_btn.click(lambda: (None, None, None), None, [chatbot, audio_output, transcribed_text], queue=False)
441
-
442
- retry_btn.click(lambda: None, None, chatbot, queue=False)
443
-
444
- # =============================
445
- # Launch Interface
446
- # =============================
447
- if __name__ == "__main__":
448
- print("\n🚀 Launching ChatDoctor Gradio Interface with Voice Support...")
449
- print("\n📦 Required packages:")
450
- print(" pip install gradio gTTS openai-whisper")
451
- print("\nNote: Whisper will download models on first use (~100MB for base model)\n")
452
-
453
- demo.queue()
454
- demo.launch(
455
- server_name="0.0.0.0",
456
- server_port=7860,
457
- share=False,
458
- show_error=True
459
- )