siyah1 commited on
Commit
baba48f
·
verified ·
1 Parent(s): bdf07fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +384 -55
app.py CHANGED
@@ -1,70 +1,399 @@
 
 
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
3
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
 
19
- messages = [{"role": "system", "content": system_message}]
 
 
 
 
 
20
 
21
- messages.extend(history)
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
- messages.append({"role": "user", "content": message})
 
 
 
24
 
25
- response = ""
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
41
 
 
 
 
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ from threading import Event
6
+ from datetime import datetime
7
+
8
  import gradio as gr
9
+ import numpy as np
10
+ import websockets.sync.client
11
+ from dotenv import load_dotenv
12
+ from gradio_webrtc import StreamHandler, WebRTC, get_twilio_turn_credentials
13
 
14
+ load_dotenv()
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ class GeminiConfig:
18
+ def __init__(self, api_key):
19
+ self.api_key = api_key
20
+ self.host = "generativelanguage.googleapis.com"
21
+ self.model = "models/gemini-2.0-flash-exp"
22
+ self.ws_url = f"wss://{self.host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={self.api_key}"
23
 
24
+ class AudioProcessor:
25
+ @staticmethod
26
+ def encode_audio(data, sample_rate):
27
+ encoded = base64.b64encode(data.tobytes()).decode("UTF-8")
28
+ return {
29
+ "realtimeInput": {
30
+ "mediaChunks": [
31
+ {
32
+ "mimeType": f"audio/pcm;rate={sample_rate}",
33
+ "data": encoded,
34
+ }
35
+ ],
36
+ },
37
+ }
38
 
39
+ @staticmethod
40
+ def process_audio_response(data):
41
+ audio_data = base64.b64decode(data)
42
+ return np.frombuffer(audio_data, dtype=np.int16)
43
 
 
44
 
45
+ class ConversationTracker:
46
+ def __init__(self):
47
+ self.conversation_history = []
48
+ self.start_time = None
49
+ self.end_time = None
50
+ self.session_active = False
 
 
 
 
 
51
 
52
+ def start_session(self):
53
+ self.start_time = datetime.now()
54
+ self.session_active = True
55
+ self.conversation_history = []
56
 
57
+ def add_message(self, message, is_user=False):
58
+ timestamp = datetime.now()
59
+ self.conversation_history.append({
60
+ "timestamp": timestamp,
61
+ "message": message,
62
+ "speaker": "Patient" if is_user else "AI Agent",
63
+ "type": "voice"
64
+ })
65
 
66
+ def end_session(self):
67
+ self.end_time = datetime.now()
68
+ self.session_active = False
69
+
70
+ def generate_report(self):
71
+ if not self.conversation_history:
72
+ return "No conversation data available."
73
+
74
+ duration = (self.end_time - self.start_time).total_seconds() / 60 if self.end_time else 0
75
+
76
+ report = f"""
77
+ PRECONSULTATION SUMMARY REPORT
78
+ Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
79
+ Session Duration: {duration:.1f} minutes
80
+ Total Exchanges: {len(self.conversation_history)}
81
+
82
+ CONVERSATION SUMMARY:
83
+ This preconsultation session involved a voice-based interaction between the patient and an AI consultation agent. The AI gathered preliminary information to assist healthcare providers in understanding the patient's needs before their appointment.
84
+
85
+ KEY POINTS DISCUSSED:
86
  """
87
+
88
+ # Extract key information from conversation
89
+ user_messages = [msg["message"] for msg in self.conversation_history if msg["speaker"] == "Patient"]
90
+ if user_messages:
91
+ report += "- Patient concerns and symptoms mentioned during the session\n"
92
+ report += "- Medical history and current health status discussed\n"
93
+ report += "- Expectations and questions for the upcoming consultation\n\n"
94
+
95
+ report += "RECOMMENDATIONS FOR HEALTHCARE PROVIDER:\n"
96
+ report += "- Review the patient's expressed concerns\n"
97
+ report += "- Consider the preliminary information gathered\n"
98
+ report += "- Address any specific questions or anxieties mentioned\n"
99
+ report += "- Follow up on symptoms or conditions discussed\n\n"
100
+
101
+ report += "NOTE: This is an AI-generated summary for informational purposes only. "
102
+ report += "Professional medical judgment should always take precedence.\n"
103
+
104
+ return report
105
+
106
+
107
+ class GeminiHandler(StreamHandler):
108
+ def __init__(
109
+ self, expected_layout="mono", output_sample_rate=24000, output_frame_size=480
110
+ ) -> None:
111
+ super().__init__(
112
+ expected_layout,
113
+ output_sample_rate,
114
+ output_frame_size,
115
+ input_sample_rate=24000,
116
+ )
117
+ self.config = None
118
+ self.ws = None
119
+ self.all_output_data = None
120
+ self.audio_processor = AudioProcessor()
121
+ self.args_set = Event()
122
+ self.conversation_tracker = ConversationTracker()
123
+ self.system_prompt_sent = False
124
+
125
+ def copy(self):
126
+ handler = GeminiHandler(
127
+ expected_layout=self.expected_layout,
128
+ output_sample_rate=self.output_sample_rate,
129
+ output_frame_size=self.output_frame_size,
130
+ )
131
+ handler.conversation_tracker = self.conversation_tracker
132
+ return handler
133
+
134
+ def _initialize_websocket(self):
135
+ assert self.config, "Config not set"
136
+ try:
137
+ self.ws = websockets.sync.client.connect(self.config.ws_url, timeout=30)
138
+ initial_request = {
139
+ "setup": {
140
+ "model": self.config.model,
141
+ "systemInstruction": {
142
+ "parts": [
143
+ {
144
+ "text": """You are a friendly and professional AI preconsultation agent designed to help patients prepare for their medical appointments. Your role is to:
145
+
146
+ 1. Warmly greet patients and explain your purpose
147
+ 2. Gather preliminary information about their health concerns
148
+ 3. Ask relevant questions about symptoms, medical history, and current medications
149
+ 4. Address any anxieties or questions they have about their upcoming appointment
150
+ 5. Provide reassurance and basic health education when appropriate
151
+ 6. Keep the conversation focused and efficient (aim for 5-10 minutes)
152
+
153
+ Guidelines:
154
+ - Be empathetic and professional
155
+ - Ask one question at a time
156
+ - Listen actively and acknowledge concerns
157
+ - Don't provide medical diagnoses or treatment advice
158
+ - Encourage patients to discuss all concerns with their healthcare provider
159
+ - Keep responses concise but warm
160
+ - When the patient indicates they're ready to end or have covered their main concerns, offer to summarize and conclude
161
+
162
+ Start by introducing yourself and asking how you can help them prepare for their appointment."""
163
+ }
164
+ ]
165
+ }
166
+ }
167
+ }
168
+ self.ws.send(json.dumps(initial_request))
169
+ setup_response = json.loads(self.ws.recv())
170
+ print(f"Setup response: {setup_response}")
171
+ self.conversation_tracker.start_session()
172
+ except websockets.exceptions.WebSocketException as e:
173
+ print(f"WebSocket connection failed: {str(e)}")
174
+ self.ws = None
175
+ except Exception as e:
176
+ print(f"Setup failed: {str(e)}")
177
+ self.ws = None
178
+
179
+ async def fetch_args(self):
180
+ if self.channel:
181
+ self.channel.send("tick")
182
+
183
+ def set_args(self, args):
184
+ super().set_args(args)
185
+ self.args_set.set()
186
+
187
+ def receive(self, frame: tuple[int, np.ndarray]) -> None:
188
+ if not self.channel:
189
+ return
190
+ if not self.config:
191
+ # Get API key from environment variable
192
+ api_key = os.getenv('GEMINI_API_KEY')
193
+ if not api_key:
194
+ print("Error: GEMINI_API_KEY environment variable not set")
195
+ return
196
+ self.config = GeminiConfig(api_key)
197
+
198
+ try:
199
+ if not self.ws:
200
+ self._initialize_websocket()
201
+
202
+ _, array = frame
203
+ array = array.squeeze()
204
+ audio_message = self.audio_processor.encode_audio(
205
+ array, self.output_sample_rate
206
+ )
207
+ self.ws.send(json.dumps(audio_message))
208
+ except Exception as e:
209
+ print(f"Error in receive: {str(e)}")
210
+ if self.ws:
211
+ self.ws.close()
212
+ self.ws = None
213
+
214
+ def _process_server_content(self, content):
215
+ # Track AI responses
216
+ for part in content.get("parts", []):
217
+ if "text" in part:
218
+ self.conversation_tracker.add_message(part["text"], is_user=False)
219
+
220
+ data = part.get("inlineData", {}).get("data", "")
221
+ if data:
222
+ audio_array = self.audio_processor.process_audio_response(data)
223
+ if self.all_output_data is None:
224
+ self.all_output_data = audio_array
225
+ else:
226
+ self.all_output_data = np.concatenate(
227
+ (self.all_output_data, audio_array)
228
+ )
229
+
230
+ while self.all_output_data.shape[-1] >= self.output_frame_size:
231
+ yield (
232
+ self.output_sample_rate,
233
+ self.all_output_data[: self.output_frame_size].reshape(1, -1),
234
+ )
235
+ self.all_output_data = self.all_output_data[
236
+ self.output_frame_size :
237
+ ]
238
+
239
+ def generator(self):
240
+ while True:
241
+ if not self.ws or not self.config:
242
+ print("WebSocket not connected")
243
+ yield None
244
+ continue
245
+
246
+ try:
247
+ message = self.ws.recv(timeout=5)
248
+ msg = json.loads(message)
249
+
250
+ if "serverContent" in msg:
251
+ content = msg["serverContent"].get("modelTurn", {})
252
+ yield from self._process_server_content(content)
253
+ except TimeoutError:
254
+ print("Timeout waiting for server response")
255
+ yield None
256
+ except Exception as e:
257
+ print(f"Error in generator: {str(e)}")
258
+ yield None
259
+
260
+ def emit(self) -> tuple[int, np.ndarray] | None:
261
+ if not self.ws:
262
+ return None
263
+ if not hasattr(self, "_generator"):
264
+ self._generator = self.generator()
265
+ try:
266
+ return next(self._generator)
267
+ except StopIteration:
268
+ self.reset()
269
+ return None
270
+
271
+ def reset(self) -> None:
272
+ if hasattr(self, "_generator"):
273
+ delattr(self, "_generator")
274
+ self.all_output_data = None
275
+
276
+ def shutdown(self) -> None:
277
+ if self.ws:
278
+ self.ws.close()
279
+ if self.conversation_tracker.session_active:
280
+ self.conversation_tracker.end_session()
281
+
282
+ def check_connection(self):
283
+ try:
284
+ if not self.ws or self.ws.closed:
285
+ self._initialize_websocket()
286
+ return True
287
+ except Exception as e:
288
+ print(f"Connection check failed: {str(e)}")
289
+ return False
290
+
291
+
292
+ class PreconsultationApp:
293
+ def __init__(self):
294
+ self.handler = None
295
+ self.demo = self._create_interface()
296
+
297
+ def _create_interface(self):
298
+ with gr.Blocks(title="AI Preconsultation Agent") as demo:
299
+ gr.HTML("""
300
+ <div style='text-align: center; margin-bottom: 20px'>
301
+ <h1>🩺 AI Preconsultation Agent</h1>
302
+ <p>Prepare for your medical appointment with our AI assistant</p>
303
+ <p style='color: #666; font-size: 14px'>
304
+ This AI agent will help gather preliminary information before your consultation
305
+ </p>
306
+ </div>
307
+ """)
308
+
309
+ with gr.Row():
310
+ with gr.Column(scale=2):
311
+ webrtc = WebRTC(
312
+ label="Voice Consultation",
313
+ modality="audio",
314
+ mode="send-receive",
315
+ rtc_configuration=get_twilio_turn_credentials(),
316
+ )
317
+
318
+ with gr.Column(scale=1):
319
+ gr.HTML("""
320
+ <div style='background-color: #f0f9ff; padding: 15px; border-radius: 8px; margin-bottom: 15px'>
321
+ <h3 style='margin-top: 0'>How it works:</h3>
322
+ <ol style='margin-bottom: 0'>
323
+ <li>Click "Start" to begin the voice consultation</li>
324
+ <li>Speak naturally with the AI agent</li>
325
+ <li>Share your health concerns and questions</li>
326
+ <li>End the session when ready</li>
327
+ <li>Get a summary report for your healthcare provider</li>
328
+ </ol>
329
+ </div>
330
+ """)
331
+
332
+ end_session_btn = gr.Button(
333
+ "End Session & Generate Report",
334
+ variant="primary",
335
+ size="lg"
336
+ )
337
+
338
+ with gr.Row():
339
+ report_output = gr.Textbox(
340
+ label="Consultation Summary Report",
341
+ placeholder="Your consultation report will appear here after ending the session...",
342
+ lines=15,
343
+ max_lines=20,
344
+ visible=False
345
+ )
346
+
347
+ # Set up the WebRTC stream
348
+ self.handler = GeminiHandler()
349
+ webrtc.stream(
350
+ self.handler,
351
+ inputs=[webrtc],
352
+ outputs=[webrtc],
353
+ time_limit=600, # 10 minutes max
354
+ concurrency_limit=1,
355
+ )
356
+
357
+ # Handle end session
358
+ def end_session():
359
+ if self.handler and self.handler.conversation_tracker.session_active:
360
+ self.handler.conversation_tracker.end_session()
361
+ report = self.handler.conversation_tracker.generate_report()
362
+ return gr.update(value=report, visible=True)
363
+ return gr.update(value="No active session to end.", visible=True)
364
+
365
+ end_session_btn.click(
366
+ end_session,
367
+ outputs=[report_output]
368
+ )
369
+
370
+ gr.HTML("""
371
+ <div style='text-align: center; margin-top: 20px; padding: 15px; background-color: #fef3c7; border-radius: 8px'>
372
+ <p style='margin: 0; color: #92400e'>
373
+ <strong>Important:</strong> This AI agent is for preliminary consultation only.
374
+ Always consult with qualified healthcare professionals for medical advice.
375
+ </p>
376
+ </div>
377
+ """)
378
+
379
+ return demo
380
+
381
+ def launch(self):
382
+ # Check if API key is set
383
+ if not os.getenv('GEMINI_API_KEY'):
384
+ print("Error: Please set the GEMINI_API_KEY environment variable")
385
+ print("You can get a Gemini API key from: https://ai.google.dev/gemini-api/docs/api-key")
386
+ return
387
+
388
+ self.demo.launch(
389
+ server_name="0.0.0.0",
390
+ server_port=int(os.environ.get("PORT", 7860)),
391
+ ssl_verify=False,
392
+ ssl_keyfile=None,
393
+ ssl_certfile=None,
394
+ )
395
 
396
 
397
  if __name__ == "__main__":
398
+ app = PreconsultationApp()
399
+ app.launch()