Muhammadidrees commited on
Commit
72f0197
·
verified ·
1 Parent(s): 79f2975

Create DocBrain.py

Browse files
Files changed (1) hide show
  1. DocBrain.py +240 -0
DocBrain.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoProcessor, AutoModelForVision2Seq
3
+ import torch
4
+ from PaitentVoiceToText import record_and_transcribe # Your STT function
5
+ from DocVoice import text_to_speech # Your TTS function
6
+
7
+ # -------------------
8
+ # 1️⃣ Load Model & Processor
9
+ # -------------------
10
+ def load_model():
11
+ local_dir = r"C:\Users\JAY\Downloads\model\CHATDOCMODEL"
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ dtype = torch.float16 if device == "cuda" else torch.float32
14
+
15
+ processor = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True)
16
+ model = AutoModelForVision2Seq.from_pretrained(
17
+ local_dir,
18
+ dtype=dtype,
19
+ device_map=None
20
+ )
21
+ model.to(device)
22
+ return processor, model, device
23
+
24
+ processor, model, device = load_model()
25
+
26
+ # -------------------
27
+ # 2️⃣ Chat Logic Functions
28
+ # -------------------
29
+ def process_message(message, history, question_count):
30
+ if not message.strip():
31
+ return history, history, question_count
32
+
33
+ history.append([message, None])
34
+ question_count += 1
35
+
36
+ should_analyze = (
37
+ question_count >= 6 or
38
+ any(word in message.lower() for word in ["analysis", "diagnose", "what do you think", "causes"])
39
+ )
40
+
41
+ if should_analyze:
42
+ system_prompt = (
43
+ "You are a medical doctor. Based on the patient's responses, provide a comprehensive analysis "
44
+ "of potential causes for their symptoms. Start with 'Based on the information provided by the patient, "
45
+ "potential causes of [symptoms] could include:' and list 3-4 possible diagnoses with brief explanations. "
46
+ "Format as numbered list with diagnosis name and short explanation."
47
+ )
48
+ else:
49
+ system_prompt = (
50
+ "You are a medical doctor conducting a patient interview. Ask ONE specific, direct medical question "
51
+ "to gather important diagnostic information. Keep it brief - just ask the question without explanations. "
52
+ "Focus on key areas like: age, medical history, medications, lifestyle, family history, or symptom details."
53
+ )
54
+
55
+ dialogue = []
56
+ for user_msg, bot_msg in history[:-1]:
57
+ if user_msg:
58
+ dialogue.append(f"Patient: {user_msg}")
59
+ if bot_msg:
60
+ dialogue.append(f"Doctor: {bot_msg}")
61
+ dialogue.append(f"Patient: {message}")
62
+
63
+ conversation = "\n".join(dialogue)
64
+ prompt = f"{system_prompt}\n\nConversation:\n{conversation}\nDoctor:"
65
+
66
+ inputs = processor(text=prompt, images=None, return_tensors="pt").to(device)
67
+ max_tokens = 1000 if should_analyze else 25
68
+
69
+ with torch.inference_mode():
70
+ outputs = model.generate(
71
+ **inputs,
72
+ max_new_tokens=max_tokens,
73
+ do_sample=True,
74
+ temperature=0.6,
75
+ top_p=0.9,
76
+ repetition_penalty=1.1,
77
+ pad_token_id=processor.tokenizer.eos_token_id,
78
+ )
79
+
80
+ input_length = inputs["input_ids"].shape[1]
81
+ generated_tokens = outputs[:, input_length:]
82
+ response = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0].strip()
83
+
84
+ if response.lower().startswith("doctor:"):
85
+ response = response[7:].strip()
86
+
87
+ if not should_analyze:
88
+ sentences = response.split('?')
89
+ if len(sentences) > 1:
90
+ response = sentences[0].strip() + '?'
91
+ cleanup_starts = [
92
+ "I need to ask",
93
+ "Let me ask",
94
+ "I would like to know",
95
+ "Can you tell me",
96
+ "It would help if",
97
+ ]
98
+ for phrase in cleanup_starts:
99
+ if response.startswith(phrase):
100
+ parts = response.split(',', 1)
101
+ if len(parts) > 1:
102
+ response = parts[1].strip()
103
+ if not response.endswith('?'):
104
+ response += '?'
105
+
106
+ history[-1][1] = response
107
+ if should_analyze:
108
+ question_count = 0
109
+
110
+ return history, history, question_count
111
+
112
+ def force_analysis(history, question_count):
113
+ return history, 10
114
+
115
+ def clear_chat():
116
+ return [], [], 0
117
+
118
+ # -------------------
119
+ # 3️⃣ TTS Helper
120
+ # -------------------
121
+ def play_assistant_audio(response_text):
122
+ if response_text:
123
+ text_to_speech(response_text)
124
+ return None
125
+
126
+ # -------------------
127
+ # 4️⃣ Gradio Interface
128
+ # -------------------
129
+ with gr.Blocks(title="ChatDOC", theme=gr.themes.Soft()) as demo:
130
+ question_count_state = gr.State(0)
131
+ assistant_responses_state = gr.State([])
132
+
133
+ gr.Markdown(
134
+ """
135
+ # 🩺 Chat with ChatDOC
136
+ Welcome! I'm your AI medical assistant. Please describe your symptoms and I'll ask relevant questions to help understand your condition better.
137
+ """
138
+ )
139
+
140
+ chatbot = gr.Chatbot(
141
+ value=[],
142
+ height=400,
143
+ show_label=False,
144
+ avatar_images=(
145
+ r"C:\Users\JAY\Downloads\model\user_msg.png",
146
+ r"C:\Users\JAY\Downloads\model\bot_msg.jpg"
147
+ ),
148
+ bubble_full_width=False
149
+ )
150
+
151
+ with gr.Row():
152
+ msg = gr.Textbox(
153
+ placeholder="Describe your symptoms...",
154
+ scale=4,
155
+ container=False,
156
+ show_label=False
157
+ )
158
+ send_btn = gr.Button("Send", variant="primary", scale=1)
159
+ mic_btn = gr.Button("🎤 Speak", variant="secondary", scale=1)
160
+
161
+ with gr.Row():
162
+ analysis_btn = gr.Button("Request Analysis", variant="secondary")
163
+ clear_btn = gr.Button("Clear Chat", variant="stop")
164
+ play_audio_btn = gr.Button("🔊 Play Assistant Response", variant="secondary")
165
+
166
+ # -------------------
167
+ # Update assistant responses
168
+ # -------------------
169
+ def update_assistant_responses(history, assistant_responses):
170
+ if history and history[-1][1]:
171
+ assistant_responses.append(history[-1][1])
172
+ return assistant_responses
173
+
174
+ # -------------------
175
+ # Submit handlers
176
+ # -------------------
177
+ def user_submit(message, history, question_count, assistant_responses):
178
+ history, updated_history, question_count = process_message(message, history, question_count)
179
+ assistant_responses = update_assistant_responses(history, assistant_responses)
180
+ return updated_history, updated_history, question_count, assistant_responses
181
+
182
+ def mic_submit(history, question_count, assistant_responses):
183
+ user_text = record_and_transcribe(duration=5)
184
+ # Show user message immediately
185
+ history.append([user_text, None])
186
+ history, updated_history, question_count = process_message(user_text, history, question_count)
187
+ assistant_responses = update_assistant_responses(history, assistant_responses)
188
+ return updated_history, updated_history, question_count, assistant_responses
189
+
190
+ def clear_input():
191
+ return ""
192
+
193
+ # -------------------
194
+ # Connect buttons
195
+ # -------------------
196
+ send_btn.click(
197
+ user_submit,
198
+ inputs=[msg, chatbot, question_count_state, assistant_responses_state],
199
+ outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
200
+ ).then(clear_input, outputs=[msg])
201
+
202
+ msg.submit(
203
+ user_submit,
204
+ inputs=[msg, chatbot, question_count_state, assistant_responses_state],
205
+ outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
206
+ ).then(clear_input, outputs=[msg])
207
+
208
+ mic_btn.click(
209
+ mic_submit,
210
+ inputs=[chatbot, question_count_state, assistant_responses_state],
211
+ outputs=[chatbot, chatbot, question_count_state, assistant_responses_state]
212
+ )
213
+
214
+ analysis_btn.click(
215
+ force_analysis,
216
+ inputs=[chatbot, question_count_state],
217
+ outputs=[chatbot, question_count_state]
218
+ )
219
+
220
+ clear_btn.click(
221
+ clear_chat,
222
+ outputs=[chatbot, chatbot, question_count_state]
223
+ )
224
+
225
+ play_audio_btn.click(
226
+ lambda assistant_responses: play_assistant_audio(assistant_responses[-1]) if assistant_responses else None,
227
+ inputs=[assistant_responses_state],
228
+ outputs=[]
229
+ )
230
+
231
+ # -------------------
232
+ # 5️⃣ Launch
233
+ # -------------------
234
+ if __name__ == "__main__":
235
+ demo.launch(
236
+ server_name="127.0.0.1",
237
+ server_port=7860,
238
+ share=False,
239
+ debug=True
240
+ )