Muhammadidrees commited on
Commit
6068b3b
·
verified ·
1 Parent(s): 0c83b07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -160
app.py CHANGED
@@ -1,17 +1,16 @@
 
1
  import gradio as gr
2
- import torch
3
  from transformers import AutoProcessor, AutoModelForVision2Seq
4
- from PaitentVoiceToText import record_and_transcribe
5
- from DocVoice import text_to_speech # Your TTS function
6
 
7
  # -------------------
8
- # 1️⃣ Load Model & Processor
9
  # -------------------
10
  def load_model():
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  dtype = torch.float16 if device == "cuda" else torch.float32
13
 
14
- # Load directly from Hugging Face
15
  processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)
16
  model = AutoModelForVision2Seq.from_pretrained(
17
  "Muhammadidrees/RaiyaChatDoc",
@@ -24,7 +23,7 @@ def load_model():
24
  processor, model, device = load_model()
25
 
26
  # -------------------
27
- # 2️⃣ Chat Logic Functions
28
  # -------------------
29
  def process_message(message, history, question_count):
30
  if not message.strip():
@@ -32,40 +31,32 @@ def process_message(message, history, question_count):
32
 
33
  history.append([message, None])
34
  question_count += 1
35
-
36
- should_analyze = (
37
- question_count >= 6 or
38
- any(word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"])
39
  )
40
 
41
- if should_analyze:
42
- system_prompt = (
43
- "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
44
- "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
45
- "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
46
- "Format as numbered list with diagnosis name and short explanation."
47
- )
48
- else:
49
- system_prompt = (
50
- "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
51
- "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
52
- "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
53
- )
54
-
55
  dialogue = []
56
  for user_msg, bot_msg in history[:-1]:
57
- if user_msg:
58
- dialogue.append(f"Patient: {user_msg}")
59
- if bot_msg:
60
- dialogue.append(f"Doctor: {bot_msg}")
61
  dialogue.append(f"Patient: {message}")
62
-
63
- conversation = "\n".join(dialogue)
64
- prompt = f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"
65
 
 
66
  inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
67
- max_tokens = 1000 if should_analyze else 25
68
-
69
  with torch.inference_mode():
70
  outputs = model.generate(
71
  **inputs,
@@ -74,39 +65,22 @@ def process_message(message, history, question_count):
74
  temperature=0.6,
75
  top_p=0.9,
76
  repetition_penalty=1.1,
77
- pad_token_id=processor.tokenizer.eos_token_id,
78
  )
79
 
 
80
  input_length = inputs["input_ids"].shape[1]
81
- generated_tokens = outputs[:, input_length:]
82
- response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
83
-
84
  if response.lower().startswith("doctor:"):
85
  response = response[7:].strip()
86
 
 
87
  if not should_analyze:
88
- sentences = response.split('?')
89
- if len(sentences) > 1:
90
- response = sentences[0].strip() + '?'
91
- cleanup_starts = [
92
- "I need to ask",
93
- "Let me ask",
94
- "I would like to know",
95
- "Can you tell me",
96
- "It would help if",
97
- ]
98
- for phrase in cleanup_starts:
99
- if response.startswith(phrase):
100
- parts = response.split(',', 1)
101
- if len(parts) > 1:
102
- response = parts[1].strip()
103
- if not response.endswith('?'):
104
- response += '?'
105
 
106
  history[-1][1] = response
107
- if should_analyze:
108
- question_count = 0
109
-
110
  return history, history, question_count
111
 
112
  def force_analysis(history, question_count):
@@ -116,124 +90,35 @@ def clear_chat():
116
  return [], [], 0
117
 
118
  # -------------------
119
- # 3️⃣ TTS Helper
120
  # -------------------
121
- def play_assistant_audio(response_text):
122
- if response_text:
123
- text_to_speech(response_text)
124
- return None
125
-
126
- # -------------------
127
- # 4️⃣ Gradio Interface
128
- # -------------------
129
- with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
130
  question_count_state = gr.State(0)
131
- assistant_responses_state = gr.State([])
132
-
133
- gr.Markdown(
134
- """
135
- # 🩺 Chat with ChatDOC
136
- Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
137
- """
138
- )
139
 
140
- chatbot = gr.Chatbot(
141
- value=[],
142
- height=400,
143
- show_label=False,
144
- avatar_images=(
145
- r"C:\Users\JAY\Downloads\model\user_msg.png",
146
- r"C:\Users\JAY\Downloads\model\bot_msg.jpg"
147
- ),
148
- bubble_full_width=False
149
- )
150
 
151
  with gr.Row():
152
- msg = gr.Textbox(
153
- placeholder="Describe your symptoms...",
154
- scale=4,
155
- container=False,
156
- show_label=False
157
- )
158
  send_btn = gr.Button("Send", variant="primary", scale=1)
159
- mic_btn = gr.Button("🎤 Speak", variant="secondary", scale=1)
160
 
161
  with gr.Row():
162
  analysis_btn = gr.Button("Request Analysis", variant="secondary")
163
  clear_btn = gr.Button("Clear Chat", variant="stop")
164
- play_audio_btn = gr.Button("🔊 Play Assistant Response", variant="secondary")
165
 
166
- # -------------------
167
- # Update assistant responses
168
- # -------------------
169
- def update_assistant_responses(history, assistant_responses):
170
- if history and history[-1][1]:
171
- assistant_responses.append(history[-1][1])
172
- return assistant_responses
173
-
174
- # -------------------
175
- # Submit handlers
176
- # -------------------
177
- def user_submit(message, history, question_count, assistant_responses):
178
- history, updated_history, question_count = process_message(message, history, question_count)
179
- assistant_responses = update_assistant_responses(history, assistant_responses)
180
- return updated_history, updated_history, question_count, assistant_responses
181
-
182
- def mic_submit(history, question_count, assistant_responses):
183
- user_text = record_and_transcribe(duration=5)
184
- history.append([user_text, None])
185
- history, updated_history, question_count = process_message(user_text, history, question_count)
186
- assistant_responses = update_assistant_responses(history, assistant_responses)
187
- return updated_history, updated_history, question_count, assistant_responses
188
-
189
- def clear_input():
190
- return ""
191
-
192
- # -------------------
193
- # Connect buttons
194
- # -------------------
195
- send_btn.click(
196
- user_submit,
197
- inputs=[msg, chatbot, question_count_state, assistant_responses_state],
198
- outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
199
- ).then(clear_input, outputs=[msg])
200
 
201
  msg.submit(
202
- user_submit,
203
- inputs=[msg, chatbot, question_count_state, assistant_responses_state],
204
- outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
205
- ).then(clear_input, outputs=[msg])
206
 
207
- mic_btn.click(
208
- mic_submit,
209
- inputs=[chatbot, question_count_state, assistant_responses_state],
210
- outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
211
- )
212
-
213
- analysis_btn.click(
214
- force_analysis,
215
- inputs=[chatbot, question_count_state],
216
- outputs=[chatbot, question_count_state]
217
- )
218
-
219
- clear_btn.click(
220
- clear_chat,
221
- outputs=[chatbot, chatbot, question_count_state]
222
- )
223
-
224
- play_audio_btn.click(
225
- lambda assistant_responses: play_assistant_audio(assistant_responses[-1]) if assistant_responses else None,
226
- inputs=[assistant_responses_state],
227
- outputs=[]
228
- )
229
 
230
  # -------------------
231
- # 5️⃣ Launch
232
  # -------------------
233
  if __name__ == "__main__":
234
- demo.launch(
235
- server_name="127.0.0.1",
236
- server_port=7860,
237
- share=False,
238
- debug=True
239
- )
 
1
+ # app.py
2
  import gradio as gr
 
3
  from transformers import AutoProcessor, AutoModelForVision2Seq
4
+ import torch
 
5
 
6
  # -------------------
7
+ # 1️⃣ Load Model
8
  # -------------------
9
  def load_model():
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
  dtype = torch.float16 if device == "cuda" else torch.float32
12
 
13
+ # Load model and processor from Hugging Face
14
  processor = AutoProcessor.from_pretrained("Muhammadidrees/RaiyaChatDoc", trust_remote_code=True)
15
  model = AutoModelForVision2Seq.from_pretrained(
16
  "Muhammadidrees/RaiyaChatDoc",
 
23
  processor, model, device = load_model()
24
 
25
  # -------------------
26
+ # 2️⃣ Chat Logic
27
  # -------------------
28
  def process_message(message, history, question_count):
29
  if not message.strip():
 
31
 
32
  history.append([message, None])
33
  question_count += 1
34
+
35
+ # Decide if analysis is needed
36
+ should_analyze = question_count >= 6 or any(
37
+ word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"]
38
  )
39
 
40
+ # System prompt
41
+ system_prompt = (
42
+ "You are a medical doctor. "
43
+ "Provide a comprehensive analysis of potential causes for symptoms."
44
+ if should_analyze else
45
+ "You are a medical doctor conducting a patient interview. Ask ONE specific question."
46
+ )
47
+
48
+ # Build conversation context
 
 
 
 
 
49
  dialogue = []
50
  for user_msg, bot_msg in history[:-1]:
51
+ if user_msg: dialogue.append(f"Patient: {user_msg}")
52
+ if bot_msg: dialogue.append(f"Doctor: {bot_msg}")
 
 
53
  dialogue.append(f"Patient: {message}")
54
+ prompt = f"{system_prompt}\n\nConversation:\n" + "\n".join(dialogue) + "\nDoctor:"
 
 
55
 
56
+ # Prepare input
57
  inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
58
+ max_tokens = 400 if should_analyze else 25
59
+
60
  with torch.inference_mode():
61
  outputs = model.generate(
62
  **inputs,
 
65
  temperature=0.6,
66
  top_p=0.9,
67
  repetition_penalty=1.1,
68
+ pad_token_id=processor.tokenizer.eos_token_id
69
  )
70
 
71
+ # Decode response
72
  input_length = inputs["input_ids"].shape[1]
73
+ response = processor.batch_decode(outputs[:, input_length:], skip_special_tokens=True)[0].strip()
 
 
74
  if response.lower().startswith("doctor:"):
75
  response = response[7:].strip()
76
 
77
+ # Concise question formatting
78
  if not should_analyze:
79
+ response = response.split('?')[0].strip() + '?'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  history[-1][1] = response
82
+ if should_analyze: question_count = 0
83
+
 
84
  return history, history, question_count
85
 
86
  def force_analysis(history, question_count):
 
90
  return [], [], 0
91
 
92
  # -------------------
93
+ # 3️⃣ Gradio Interface
94
  # -------------------
95
+ with gr.Blocks(title="ChatDOC") as demo:
 
 
 
 
 
 
 
 
96
  question_count_state = gr.State(0)
 
 
 
 
 
 
 
 
97
 
98
+ gr.Markdown("# 🩺 Chat with ChatDOC\nDescribe your symptoms and get guidance.")
99
+ chatbot = gr.Chatbot(value=[], height=400, show_label=False)
 
 
 
 
 
 
 
 
100
 
101
  with gr.Row():
102
+ msg = gr.Textbox(placeholder="Describe your symptoms...", scale=4, container=False, show_label=False)
 
 
 
 
 
103
  send_btn = gr.Button("Send", variant="primary", scale=1)
 
104
 
105
  with gr.Row():
106
  analysis_btn = gr.Button("Request Analysis", variant="secondary")
107
  clear_btn = gr.Button("Clear Chat", variant="stop")
 
108
 
109
+ send_event = send_btn.click(
110
+ process_message, inputs=[msg, chatbot, question_count_state], outputs=[chatbot, chatbot, question_count_state]
111
+ ).then(lambda: "", outputs=[msg])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  msg.submit(
114
+ process_message, inputs=[msg, chatbot, question_count_state], outputs=[chatbot, chatbot, question_count_state]
115
+ ).then(lambda: "", outputs=[msg])
 
 
116
 
117
+ analysis_btn.click(force_analysis, inputs=[chatbot, question_count_state], outputs=[chatbot, question_count_state])
118
+ clear_btn.click(clear_chat, outputs=[chatbot, chatbot, question_count_state])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # -------------------
121
+ # 4️⃣ Launch
122
  # -------------------
123
  if __name__ == "__main__":
124
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=False, debug=True)