afouda committed on
Commit
4fada6b
·
verified ·
1 Parent(s): c71961b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +513 -0
app.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import os
4
+ import time
5
+ from io import BytesIO
6
+ from google.genai import types
7
+ from google.genai.types import (
8
+ LiveConnectConfig,
9
+ SpeechConfig,
10
+ VoiceConfig,
11
+ PrebuiltVoiceConfig,
12
+ Content,
13
+ Part,
14
+ )
15
+ import gradio as gr
16
+ import numpy as np
17
+ import websockets
18
+ from dotenv import load_dotenv
19
+ from fastrtc import (
20
+ AsyncAudioVideoStreamHandler,
21
+ Stream,
22
+ WebRTC,
23
+ get_cloudflare_turn_credentials_async,
24
+ wait_for_item,
25
+ )
26
+ from google import genai
27
+ from gradio.utils import get_space
28
+ from PIL import Image
29
+
30
+ # ------------------------------------------
31
+ import asyncio
32
+ import base64
33
+ import json
34
+ import os
35
+ import pathlib
36
+ from typing import AsyncGenerator, Literal
37
+
38
+ import gradio as gr
39
+ import numpy as np
40
+ from dotenv import load_dotenv
41
+ from fastapi import FastAPI
42
+ from fastapi.responses import HTMLResponse
43
+ from fastrtc import (
44
+ AsyncStreamHandler,
45
+ Stream,
46
+ get_cloudflare_turn_credentials_async,
47
+ wait_for_item,
48
+ )
49
+ from google import genai
50
+ from google.genai.types import (
51
+ LiveConnectConfig,
52
+ PrebuiltVoiceConfig,
53
+ SpeechConfig,
54
+ VoiceConfig,
55
+ )
56
+ from gradio.utils import get_space
57
+ from pydantic import BaseModel
58
+ # ------------------------------------------------
59
+ from dotenv import load_dotenv
60
+ load_dotenv()
61
+ import os
62
+ import io
63
+ import asyncio
64
+ from pydub import AudioSegment
65
+
66
+
67
async def safe_get_ice_config_async():
    """Resolve the ICE server configuration for the WebRTC component.

    Cloudflare TURN credentials are requested only when either ``HF_TOKEN`` or
    both ``CLOUDFLARE_TURN_KEY_ID`` and ``CLOUDFLARE_TURN_KEY_API_TOKEN`` are
    set in the environment. When none of them is set, or when the credential
    request raises, a STUN-only configuration is returned so local testing
    still works without those variables.
    """
    stun_only = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}

    has_hf_token = bool(os.getenv("HF_TOKEN"))
    has_cloudflare_keys = bool(
        os.getenv("CLOUDFLARE_TURN_KEY_ID")
        and os.getenv("CLOUDFLARE_TURN_KEY_API_TOKEN")
    )
    if not (has_hf_token or has_cloudflare_keys):
        return stun_only

    try:
        return await get_cloudflare_turn_credentials_async()
    except Exception as e:
        # Best effort: a failed TURN lookup must not prevent the app from starting.
        print("Warning: failed to get Cloudflare TURN credentials, falling back to STUN-only. Error:", e)

    return stun_only
82
+
83
+ # Gemini: google-genai
84
+ from google import genai
85
+ # ---------------------------------------------------
86
+ # VAD imports from reference code
87
+ import collections
88
+ import webrtcvad
89
+ import time
90
+ # Weaviate imports
91
+ import weaviate
92
+ from weaviate.classes.init import Auth
93
+ from contextlib import contextmanager
94
+ # helper functions
95
# Credentials and service endpoints are read from the environment instead of
# being committed to the repository. `load_dotenv()` is called earlier in this
# module, so values may also be supplied through a local, untracked `.env` file.
# A missing variable yields None here; the client that needs it will then fail
# at the point of use.
# NOTE(review): any key that was ever committed in this file's history must be
# treated as compromised and rotated.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")

# Host name of the Weaviate Cloud cluster that stores the retrieval documents.
WEAVIATE_URL = os.getenv("WEAVIATE_URL")

WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")

DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")

# OpenAI-compatible endpoint exposed by DeepInfra.
DEEPINFRA_BASE_URL = "https://api.deepinfra.com/v1/openai"
105
+
106
+ from openai import OpenAI
107
# OpenAI-compatible client pointed at DeepInfra. The endpoint comes from the
# DEEPINFRA_BASE_URL constant so the URL is defined in exactly one place.
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,
)
111
@contextmanager
def weaviate_client():
    """Open a Weaviate Cloud connection for the duration of a ``with`` block.

    Yields the connected client and always calls ``close()`` on it when the
    block exits, whether it finished normally or raised.
    """
    connection = weaviate.connect_to_weaviate_cloud(
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
        cluster_url=WEAVIATE_URL,
    )
    try:
        yield connection
    finally:
        connection.close()
125
+
126
def encode_audio(data: np.ndarray) -> dict:
    """Wrap an audio array as a base64 ``audio/pcm`` payload.

    The array's raw bytes are base64-encoded and returned as text under the
    ``"data"`` key, next to a fixed ``"mime_type"`` of ``"audio/pcm"``.
    """
    encoded = base64.b64encode(data.tobytes())
    return {"mime_type": "audio/pcm", "data": encoded.decode("UTF-8")}
132
def encode_audio2(data: np.ndarray) -> bytes:
    """Return the raw bytes of an audio array, with no container or encoding."""
    raw = data.tobytes()
    return raw
135
+
136
+ import soundfile as sf
137
+
138
def numpy_array_to_wav_bytes(audio_array: np.ndarray, sample_rate: int = 16000) -> bytes:
    """
    Convert a NumPy audio array to WAV bytes.

    This function used to be defined twice in a row with identical behaviour;
    only the second definition was ever reachable, so the two are merged here.

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D).
        sample_rate (int): Sample rate in Hz.

    Returns:
        bytes: WAV-formatted audio data.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format='WAV')
    # getvalue() returns the whole buffer regardless of the current position.
    return buffer.getvalue()
159
+ # webrtc handler class
160
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API with chained latency calculation.

    Per-utterance pipeline:

    1. ``receive`` segments incoming 16-bit PCM frames with WebRTC VAD.
    2. A finished utterance is transcribed with ``s2t`` and the transcript is
       placed on ``input_queue``.
    3. ``start_up`` consumes transcripts, retrieves context with
       ``rag_autism``, asks the chat model (``t2t``) for a reply and has a
       Gemini Live session read the reply aloud.
    4. Audio chunks from the Live session are placed on ``output_queue`` and
       handed to the caller by ``emit``.

    All blocking SDK calls are run in worker threads so the event loop that
    also drives audio I/O is not stalled while a request is in flight.
    """

    # How long `stream` waits on the queue before re-checking `self.quit`.
    _QUEUE_POLL_SECONDS = 0.1

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        """Create a handler.

        Args:
            expected_layout: Channel layout forwarded to the base handler.
            output_sample_rate: Sample rate reported for emitted audio.
            prompt_dict: Prompt selection; defaults to ``{"prompt": "PHQ-9"}``.
                A fresh dict is created per instance when omitted so that
                instances never share a mutable default.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.prompt_dict = {"prompt": "PHQ-9"} if prompt_dict is None else prompt_dict
        # self.model = "gemini-2.5-flash-preview-tts"
        self.model = "gemini-2.0-flash-live-001"
        self.t2t_model = "gemini-2.5-flash-lite"
        self.s2t_model = "gemini-2.5-flash-lite"

        # --- VAD Initialization ---
        self.vad = webrtcvad.Vad(3)
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        # Two bytes per sample (int16 PCM).
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        # Sliding window of (frame, is_speech) pairs used for start/end detection.
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        # Fraction of the window that must be voiced/unvoiced to flip state.
        self.vad_ratio = 0.9
        self.vad_triggered = False
        # Audio of the utterance currently being recorded.
        self.wav_data = bytearray()
        # Bytes received but not yet cut into complete VAD frames.
        self.internal_buffer = bytearray()

        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False

    def copy(self) -> "GeminiHandler":
        """Return a new handler configured like this one."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )

    def s2t(self, audio) -> str:
        """Transcribe WAV bytes with the speech-to-text model (blocking call).

        Requires ``self.s2t_client``, which is created in ``start_up``.
        """
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type='audio/wav'),
                'Generate a transcript of the speech.'
            ]
        )
        return response.text

    def embed_texts(self, texts: list[str], batch_size: int = 50) -> list[list[float]]:
        """Embed a list of texts using the configured OpenAI/DeepInfra client.

        Returns one embedding per input text, in order. When a batch request
        fails, an empty list is stored for every text in that batch.
        """
        all_embeddings: list[list[float]] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                resp = openai.embeddings.create(
                    model="Qwen/Qwen3-Embedding-8B",
                    input=batch,
                    encoding_format="float"
                )
                all_embeddings.extend(item.embedding for item in resp.data)
            except Exception as e:
                print(f"Embedding batch error (items {i}–{i+len(batch)-1}): {e}")
                all_embeddings.extend([[] for _ in batch])
        return all_embeddings

    def s2t_and_embed(self, audio) -> list[float]:
        """Convert speech to text, then embed the transcript."""
        transcript = self.s2t(audio)  # Step 1: Speech → Text
        if not transcript:
            return []
        embeddings = self.embed_texts([transcript])  # Step 2: Text → Embedding
        return embeddings[0] if embeddings else []

    def encode_query(self, query: str) -> list[float] | None:
        """Generate a single embedding vector for a query string.

        Returns None when no embedding could be produced.
        """
        embs = self.embed_texts([query], batch_size=1)
        if embs and embs[0]:
            print("Query embedding (first 5 dims):", embs[0][:5])
            return embs[0]
        print("Failed to generate query embedding.")
        return None

    def rag_autism(self, query: str, top_k: int = 3) -> dict:
        """
        Run a RAG retrieval on the 'UserDocument' collection in Weaviate using v4 syntax.
        Returns up to `top_k` matching text chunks as {'answer': [texts...]}.

        Duplicates are removed while preserving order, and objects whose
        ``text`` property is missing or empty are skipped so callers can
        safely join the snippets. Any failure yields {'answer': []}.
        """
        qe = self.encode_query(query)
        if not qe:
            return {"answer": []}

        try:
            with weaviate_client() as client:
                books_collection = client.collections.get("UserDocument")
                response = books_collection.query.near_vector(
                    near_vector=qe,
                    limit=top_k,
                    return_properties=["text"]
                )

                hits = [
                    text
                    for obj in response.objects
                    if (text := obj.properties.get("text"))
                ]

            # dict.fromkeys keeps the first occurrence of each snippet, in order.
            return {"answer": list(dict.fromkeys(hits))}

        except Exception as e:
            print("RAG Error:", e)
            return {"answer": []}

    def t2t(self, text: str) -> str:
        """
        Sends text to the pre-initialized chat model and returns the text response.

        Returns an apology string when the chat session has not been created
        yet, and an empty string when the request fails.
        """
        try:
            # Ensure the chat session exists before using it.
            if not hasattr(self, 'chat'):
                print("Error: Chat session (self.chat) is not initialized.")
                return "I'm sorry, my chat function is not ready."

            # Use the existing chat session to send the message.
            print("--> Attempting to send prompt to t2t model...")
            response = self.chat.send_message(text)
            print("--> Successfully received response from t2t model.")
            return response.text
        except Exception as e:
            print(f"t2t error: {e}")
            return ""

    @staticmethod
    def _build_prompt(text_from_user: str, context_snippets: list[str]) -> str:
        """Build the chat prompt, with retrieved context when there is any."""
        combined_context = "\n\n".join(context_snippets) if context_snippets else ""
        if combined_context:
            return (
                "Please answer the user's question based on the following context. "
                "Be helpful and concise.\n\n"
                f"--- CONTEXT ---\n{combined_context}\n\n"
                f"--- USER QUESTION ---\n{text_from_user}"
            )
        return (
            "Answer the user's question from your own knowledge as a helpful assistant "
            "specializing in Autism Spectrum Disorder.\n\n"
            f"--- USER QUESTION ---\n{text_from_user}"
        )

    async def start_up(self):
        """Create the Gemini clients and run the transcript → speech loop.

        Runs until ``stream`` stops yielding, i.e. until ``shutdown`` is called.
        """
        # Flag for if we are using text-to-text in the middle of the chain or not.
        self.t2t_bool = False
        self.sys_prompt = None

        self.t2t_client = genai.Client(api_key=GEMINI_API_KEY)
        self.s2t_client = genai.Client(api_key=GEMINI_API_KEY)

        chat_instruction = (
            self.sys_prompt if self.sys_prompt is not None else "You are a helpful assistant."
        )
        chat_config = types.GenerateContentConfig(system_instruction=chat_instruction)
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)

        self.t2s_client = genai.Client(api_key=GEMINI_API_KEY)

        voice_name = "Puck"
        if self.t2t_bool:
            sys_instruction = (
                " You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism .\n"
                "Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).\n"
                "Always be clear, non-judgmental, and supportive."
            )
        else:
            sys_instruction = self.sys_prompt

        # The system instruction is only attached when one is configured.
        live_kwargs = {
            "response_modalities": ["AUDIO"],
            "speech_config": SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                )
            ),
        }
        if sys_instruction is not None:
            live_kwargs["system_instruction"] = Content(
                parts=[Part.from_text(text=sys_instruction)]
            )
        config = LiveConnectConfig(**live_kwargs)

        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if not text_from_user or not text_from_user.strip():
                    continue

                # 1) Run RAG retrieval on the user input to get contextual snippets.
                #    The retrieval is blocking, so it runs in a worker thread.
                try:
                    rag_res = await asyncio.to_thread(self.rag_autism, text_from_user, top_k=3)
                    context_snippets = rag_res.get("answer", []) if isinstance(rag_res, dict) else []

                    if context_snippets:
                        print("\n--- RAG CONTEXT RETRIEVED ---")
                        for i, snippet in enumerate(context_snippets):
                            print(f"Snippet {i+1}: {snippet}...")
                        print("-----------------------------\n")

                except Exception as e:
                    print("Error running RAG:", e)
                    context_snippets = []

                # 2) Build the prompt for t2t model including retrieved context
                prompt = self._build_prompt(text_from_user, context_snippets)
                print(prompt)

                # 3) Send prompt to chat (t2t) to obtain reply text (blocking → thread)
                try:
                    reply_text = await asyncio.to_thread(self.t2t, prompt)
                    print("\n--- FINAL AI RESPONSE ---")
                    print(reply_text)
                    print("-----------------------------")
                except Exception as e:
                    print("t2t generation error:", e)
                    reply_text = ""

                if not reply_text:
                    print("No t2t reply generated, skipping t2s send.")
                    continue

                # 4) Send the reply_text to the live TTS session to speak it
                try:
                    text_to_speak = f"Read the following text aloud exactly as it is, without adding or changing anything: '{reply_text}'"

                    print(f">>> MODIFIED TEXT SENT TO T2S API: '{text_to_speak}' <<<")
                    await session.send_client_content(
                        turns=types.Content(role='user', parts=[types.Part(text=text_to_speak)])
                    )
                    async for resp_chunk in session.receive():
                        if getattr(resp_chunk, "data", None):
                            array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                            self.output_queue.put_nowait((self.output_sample_rate, array))
                except Exception as e:
                    print("Error sending to live TTS session:", e)

    async def stream(self) -> AsyncGenerator[str, None]:
        """Yield transcripts from ``input_queue`` until ``quit`` is set.

        The queue is polled with a short timeout so that a pending ``get`` does
        not keep the generator alive forever after ``shutdown`` is called.
        """
        while not self.quit.is_set():
            try:
                # Get the text message to be converted to speech
                text_to_speak = await asyncio.wait_for(
                    self.input_queue.get(), timeout=self._QUEUE_POLL_SECONDS
                )
            except (asyncio.TimeoutError, TimeoutError):
                continue
            yield text_to_speak

    async def _transcribe_and_enqueue(self, utterance: np.ndarray, sample_rate: int) -> None:
        """Transcribe one finished utterance and queue the resulting text.

        Transcription errors are logged and the utterance is dropped, matching
        how ``t2t`` and ``rag_autism`` report their own failures.
        """
        audio_input_wav = numpy_array_to_wav_bytes(utterance, sample_rate)
        try:
            text_input = await asyncio.to_thread(self.s2t, audio_input_wav)
        except Exception as e:
            print(f"s2t error: {e}")
            return

        print("\n--- FULL S2T TRANSCRIPT ---")
        print(f"'{text_input}'")
        print("---------------------------\n")

        if text_input and text_input.strip():
            if self.t2t_bool:
                text_message = await asyncio.to_thread(self.t2t, text_input)
            else:
                text_message = text_input
            self.input_queue.put_nowait(text_message)
        else:
            print("STT returned empty transcript, skipping.")

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Consume one incoming audio frame and run VAD-based segmentation.

        Audio is cut into fixed-size VAD frames. Recording starts once more
        than ``vad_ratio`` of the ring buffer is voiced and stops once more
        than ``vad_ratio`` of it is unvoiced; the finished utterance is then
        transcribed and queued for ``start_up``.
        """
        sr, array = frame
        self.internal_buffer.extend(array.tobytes())

        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            # bytes copy so the VAD and the ring buffer hold an immutable
            # snapshot of the frame; the consumed prefix is removed in place.
            vad_frame = bytes(self.internal_buffer[: self.VAD_FRAME_BYTES])
            del self.internal_buffer[: self.VAD_FRAME_BYTES]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)

            if not self.vad_triggered:
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_voiced = sum(1 for _, speech in self.vad_ring_buffer if speech)
                if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    # Keep the buffered lead-in so the start of speech is not clipped.
                    for buffered_frame, _ in self.vad_ring_buffer:
                        self.wav_data.extend(buffered_frame)
                    self.vad_ring_buffer.clear()
                continue

            self.wav_data.extend(vad_frame)
            self.vad_ring_buffer.append((vad_frame, is_speech))
            num_unvoiced = sum(1 for _, speech in self.vad_ring_buffer if not speech)
            if num_unvoiced <= self.vad_ratio * self.vad_ring_buffer.maxlen:
                continue

            print("End of speech detected.")

            self.end_of_speech_time = time.monotonic()

            self.vad_triggered = False
            full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)

            # Reset the segmentation state BEFORE awaiting the transcription so
            # that a failed or slow request can never leak stale audio into
            # the next utterance.
            self.vad_ring_buffer.clear()
            self.wav_data = bytearray()

            await self._transcribe_and_enqueue(full_utterance_np, sr)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next ``(sample_rate, samples)`` item from ``output_queue``."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal ``stream`` (and therefore ``start_up``) to stop."""
        self.quit.set()
488
+
489
# Gradio UI: a single WebRTC audio component wired to GeminiHandler.
with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")

    with gr.Row() as row2, gr.Column():
        webrtc2 = WebRTC(
            label="Audio Chat",
            modality="audio",
            mode="send-receive",
            elem_id="audio-source",
            # Passed as a callable, not called here.
            rtc_configuration=safe_get_ice_config_async,
            icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
            pulse_color="rgb(255, 255, 255)",
            icon_button_color="rgb(255, 255, 255)",
        )

    # Time and concurrency limits apply only when get_space() is truthy.
    webrtc2.stream(
        GeminiHandler(),
        inputs=[webrtc2],
        outputs=[webrtc2],
        time_limit=180 if get_space() else None,
        concurrency_limit=2 if get_space() else None,
    )


if __name__ == "__main__":
    demo.launch(server_port=9090, debug=True)