# NOTE(review): removed paste artifacts ("Spaces:" / "Runtime error" log lines)
# that were not valid Python. This module is a Gradio + FastRTC demo that chains
# Gemini speech-to-text -> RAG/text-to-text -> live text-to-speech.
| import asyncio | |
| import base64 | |
| import os | |
| import time | |
| from io import BytesIO | |
| from google.genai import types | |
| from google.genai.types import ( | |
| LiveConnectConfig, | |
| SpeechConfig, | |
| VoiceConfig, | |
| PrebuiltVoiceConfig, | |
| Content, | |
| Part, | |
| ) | |
| import gradio as gr | |
| import numpy as np | |
| import websockets | |
| from dotenv import load_dotenv | |
| from fastrtc import ( | |
| AsyncAudioVideoStreamHandler, | |
| Stream, | |
| WebRTC, | |
| get_cloudflare_turn_credentials_async, | |
| wait_for_item, | |
| ) | |
| from google import genai | |
| from gradio.utils import get_space | |
| from PIL import Image | |
| # ------------------------------------------ | |
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| import pathlib | |
| from typing import AsyncGenerator, Literal | |
| import gradio as gr | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI | |
| from fastapi.responses import HTMLResponse | |
| from fastrtc import ( | |
| AsyncStreamHandler, | |
| Stream, | |
| get_cloudflare_turn_credentials_async, | |
| wait_for_item, | |
| ) | |
| from google import genai | |
| from google.genai.types import ( | |
| LiveConnectConfig, | |
| PrebuiltVoiceConfig, | |
| SpeechConfig, | |
| VoiceConfig, | |
| ) | |
| from gradio.utils import get_space | |
| from pydantic import BaseModel | |
| # ------------------------------------------------ | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import os | |
| import io | |
| import asyncio | |
| from pydub import AudioSegment | |
async def safe_get_ice_config_async():
    """Return Cloudflare TURN credentials when available, otherwise return a STUN-only fallback.

    Keeps the fastrtc helper from raising the HF_TOKEN / CLOUDFLARE_* error when
    those environment variables are not set during local testing.
    """
    has_hf_token = bool(os.getenv("HF_TOKEN"))
    has_cf_pair = bool(
        os.getenv("CLOUDFLARE_TURN_KEY_ID") and os.getenv("CLOUDFLARE_TURN_KEY_API_TOKEN")
    )
    if has_hf_token or has_cf_pair:
        try:
            return await get_cloudflare_turn_credentials_async()
        except Exception as e:
            print("Warning: failed to get Cloudflare TURN credentials, falling back to STUN-only. Error:", e)
    # Minimal STUN-only config so WebRTC can still attempt peer connections locally.
    return {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
| # Gemini: google-genai | |
| from google import genai | |
| # --------------------------------------------------- | |
| # VAD imports from reference code | |
| import collections | |
| import webrtcvad | |
| import time | |
| # Weaviate imports | |
| import weaviate | |
| from weaviate.classes.init import Auth | |
| from contextlib import contextmanager | |
| # helper functions | |
# --- Configuration -----------------------------------------------------------
# SECURITY FIX: API keys and secrets were previously hardcoded in this file.
# They are now read from the environment (populated by load_dotenv() above).
# The previously committed values are compromised and MUST be rotated/revoked
# in the Google, Hugging Face, Weaviate, and DeepInfra dashboards.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")
WEAVIATE_URL = os.getenv("WEAVIATE_URL", "")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY", "")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY", "")
# OpenAI-compatible endpoint for DeepInfra (used for text embeddings).
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")
from openai import OpenAI

# Embedding client; base_url now uses the DEEPINFRA_BASE_URL constant instead
# of a duplicated string literal.
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url=DEEPINFRA_BASE_URL,
)
@contextmanager
def weaviate_client():
    """
    Context manager that yields a Weaviate client and
    guarantees client.close() on exit.

    FIX: the @contextmanager decorator was missing. Without it this generator
    function cannot be used in a ``with`` statement (rag_autism does
    ``with weaviate_client() as client:``), which raised
    ``AttributeError: __enter__`` at runtime. ``contextmanager`` is already
    imported at the top of the file.
    """
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
    )
    try:
        yield client
    finally:
        # Always release the connection, even if the body raises.
        client.close()
def encode_audio(data: np.ndarray) -> dict:
    """Encode Audio data to send to the server"""
    # Raw PCM bytes, base64-encoded for transport in a JSON-friendly dict.
    payload = base64.b64encode(data.tobytes())
    return {
        "mime_type": "audio/pcm",
        "data": payload.decode("UTF-8"),
    }
def encode_audio2(data: np.ndarray) -> bytes:
    """Encode Audio data to send to the server"""
    # Raw little-endian PCM bytes, no base64 wrapping.
    raw = data.tobytes()
    return raw
import soundfile as sf


def numpy_array_to_wav_bytes(audio_array, sample_rate=16000):
    """
    Convert a NumPy audio array to WAV bytes.

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D).
        sample_rate (int): Sample rate in Hz.

    Returns:
        bytes: WAV-formatted audio data.
    """
    # FIX: this function was defined twice with identical behavior; the first
    # (dead) definition has been removed and the documented one kept.
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format='WAV')
    buffer.seek(0)
    return buffer.read()
| # webrtc handler class | |
class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API with chained latency calculation.

    Chained pipeline: WebRTC mic audio -> webrtcvad endpointing (receive) ->
    Gemini speech-to-text -> RAG retrieval over Weaviate + Gemini text-to-text
    (start_up loop) -> Gemini live text-to-speech -> audio chunks back to the
    client via emit().
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        """Set up queues, model names, and VAD state.

        Args:
            expected_layout: Channel layout expected from the client.
            output_sample_rate: Sample rate (Hz) of the audio emitted back.
            prompt_dict: Optional prompt configuration; defaults to
                ``{"prompt": "PHQ-9"}``. (FIX: this was a mutable default
                argument, which Python shares across every call.)
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()    # transcripts awaiting TTS
        self.output_queue: asyncio.Queue = asyncio.Queue()   # (rate, ndarray) for emit()
        self.quit: asyncio.Event = asyncio.Event()
        # FIX: build the default per instance instead of sharing one dict.
        self.prompt_dict = {"prompt": "PHQ-9"} if prompt_dict is None else prompt_dict
        # self.model = "gemini-2.5-flash-preview-tts"
        self.model = "gemini-2.0-flash-live-001"   # live session used for TTS
        self.t2t_model = "gemini-2.5-flash-lite"   # chat (text-to-text) model
        self.s2t_model = "gemini-2.5-flash-lite"   # transcription model
        # --- VAD Initialization ---
        self.vad = webrtcvad.Vad(3)  # aggressiveness 3 = most restrictive
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2  # 16-bit mono PCM
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        self.vad_ratio = 0.9  # fraction of the ring buffer that flips the trigger
        self.vad_triggered = False
        self.wav_data = bytearray()         # accumulated PCM for current utterance
        self.internal_buffer = bytearray()  # partial VAD frames between receive() calls
        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler (fastrtc creates one per connection)."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )

    def s2t(self, audio) -> str:
        """Transcribe WAV bytes with the Gemini S2T model and return the text.

        Requires ``self.s2t_client``, which is created in start_up().
        """
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type='audio/wav'),
                'Generate a transcript of the speech.'
            ]
        )
        return response.text

    def embed_texts(self, texts: list[str], batch_size: int = 50) -> list[list[float]]:
        """Embed a list of texts using the configured OpenAI/DeepInfra client.

        Returns a list of embedding vectors (or empty lists on failure for each item),
        always aligned 1:1 with the input order.
        """
        all_embeddings: list[list[float]] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                resp = openai.embeddings.create(
                    model="Qwen/Qwen3-Embedding-8B",
                    input=batch,
                    encoding_format="float"
                )
                batch_embs = [item.embedding for item in resp.data]
                all_embeddings.extend(batch_embs)
            except Exception as e:
                print(f"Embedding batch error (items {i}–{i+len(batch)-1}): {e}")
                # Pad with empty vectors so positions stay aligned with inputs.
                all_embeddings.extend([[] for _ in batch])
        return all_embeddings

    def s2t_and_embed(self, audio) -> list[float]:
        """Convert speech to text, then embed the transcript."""
        transcript = self.s2t(audio)  # Step 1: Speech → Text
        if not transcript:
            return []
        embeddings = self.embed_texts([transcript])  # Step 2: Text → Embedding
        return embeddings[0] if embeddings else []

    def encode_query(self, query: str) -> list[float] | None:
        """Generate a single embedding vector for a query string."""
        embs = self.embed_texts([query], batch_size=1)
        if embs and embs[0]:
            print("Query embedding (first 5 dims):", embs[0][:5])
            return embs[0]
        print("Failed to generate query embedding.")
        return None

    def rag_autism(self, query: str, top_k: int = 3) -> dict:
        """
        Run a RAG retrieval on the 'UserDocument' collection in Weaviate using v4 syntax.
        Returns up to `top_k` matching text chunks as {'answer': [texts...]}
        """
        qe = self.encode_query(query)
        if not qe:
            return {"answer": []}
        try:
            with weaviate_client() as client:
                books_collection = client.collections.get("UserDocument")
                response = books_collection.query.near_vector(
                    near_vector=qe,
                    limit=top_k,
                    return_properties=["text"]
                )
            # Extract the text property from each object
            hits = [obj.properties.get("text") for obj in response.objects if "text" in obj.properties]
            # --- FIX: REMOVE REPEATED CONTEXT ---
            # dict.fromkeys keeps only unique items while preserving order.
            unique_hits = list(dict.fromkeys(hits))
            if not unique_hits:
                return {"answer": []}
            return {"answer": unique_hits}
        except Exception as e:
            print("RAG Error:", e)
            return {"answer": []}

    def t2t(self, text: str) -> str:
        """
        Sends text to the pre-initialized chat model and returns the text response.
        Returns "" on failure so callers can skip TTS for empty replies.
        """
        try:
            # Ensure the chat session exists before using it.
            if not hasattr(self, 'chat'):
                print("Error: Chat session (self.chat) is not initialized.")
                return "I'm sorry, my chat function is not ready."
            # Use the existing chat session to send the message.
            print("--> Attempting to send prompt to t2t model...")
            response = self.chat.send_message(text)
            print("--> Successfully received response from t2t model.")
            return response.text
        except Exception as e:
            print(f"t2t error: {e}")
            return ""

    async def start_up(self):
        """Create the Gemini clients and run the main TTS loop.

        Consumes transcripts from stream(), augments each with RAG context,
        asks the chat model for a reply, then speaks the reply through the
        Gemini live session; resulting audio is pushed onto output_queue.
        """
        # Flag for if we are using text-to-text in the middle of the chain or not.
        self.t2t_bool = False
        self.sys_prompt = None
        self.t2t_client = genai.Client(api_key=GEMINI_API_KEY)
        self.s2t_client = genai.Client(api_key=GEMINI_API_KEY)
        if self.sys_prompt is not None:
            chat_config = types.GenerateContentConfig(system_instruction=self.sys_prompt)
        else:
            chat_config = types.GenerateContentConfig(system_instruction="You are a helpful assistant.")
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
        self.t2s_client = genai.Client(api_key=GEMINI_API_KEY)
        voice_name = "Puck"
        if self.t2t_bool:
            sys_instruction = f""" You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism .
Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
Always be clear, non-judgmental, and supportive."""
        else:
            sys_instruction = self.sys_prompt
        # Build the live-session config; system_instruction is only attached
        # when one is actually available.
        if sys_instruction is not None:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
                system_instruction=Content(parts=[Part.from_text(text=sys_instruction)])
            )
        else:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
            )
        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if not text_from_user or not text_from_user.strip():
                    continue
                # 1) Run RAG retrieval on the user input to get contextual snippets
                try:
                    rag_res = self.rag_autism(text_from_user, top_k=3)
                    context_snippets = rag_res.get("answer", []) if isinstance(rag_res, dict) else []
                    # --- ADDED THIS BLOCK TO PRINT THE RAG CONTEXT ---
                    if context_snippets:
                        print("\n--- RAG CONTEXT RETRIEVED ---")
                        for i, snippet in enumerate(context_snippets):
                            print(f"Snippet {i+1}: {snippet}...")
                        print("-----------------------------\n")
                except Exception as e:
                    print("Error running RAG:", e)
                    context_snippets = []
                # 2) Build the prompt for t2t model including retrieved context
                combined_context = "\n\n".join(context_snippets) if context_snippets else ""
                if combined_context:
                    prompt = (
                        "Please answer the user's question based on the following context. "
                        "Be helpful and concise.\n\n"
                        f"--- CONTEXT ---\n{combined_context}\n\n"
                        f"--- USER QUESTION ---\n{text_from_user}"
                    )
                else:
                    prompt = (
                        "Answer the user's question from your own knowledge as a helpful assistant "
                        "specializing in Autism Spectrum Disorder.\n\n"
                        f"--- USER QUESTION ---\n{text_from_user}"
                    )
                print(prompt)
                # 3) Send prompt to chat (t2t) to obtain reply text
                try:
                    reply_text = self.t2t(prompt)
                    print("\n--- FINAL AI RESPONSE ---")
                    print(reply_text)
                    print("-----------------------------")
                except Exception as e:
                    print("t2t generation error:", e)
                    reply_text = ""
                if not reply_text:
                    print("No t2t reply generated, skipping t2s send.")
                    continue
                # 4) Send the reply_text to the live TTS session to speak it
                try:
                    text_to_speak = f"Read the following text aloud exactly as it is, without adding or changing anything: '{reply_text}'"
                    print(f">>> MODIFIED TEXT SENT TO T2S API: '{text_to_speak}' <<<")
                    await session.send_client_content(
                        turns=types.Content(role='user', parts=[types.Part(text=text_to_speak)])
                    )
                    async for resp_chunk in session.receive():
                        if getattr(resp_chunk, "data", None):
                            # Live API returns raw 16-bit PCM audio chunks.
                            array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                            self.output_queue.put_nowait((self.output_sample_rate, array))
                except Exception as e:
                    print("Error sending to live TTS session:", e)

    async def stream(self) -> AsyncGenerator[str, None]:
        """Yield user transcripts from the input queue until shutdown.

        FIX: the return annotation said ``bytes`` but receive() enqueues
        transcript strings; corrected to ``str``.
        """
        while not self.quit.is_set():
            try:
                # Get the text message to be converted to speech
                text_to_speak = await self.input_queue.get()
                yield text_to_speak
            except (asyncio.TimeoutError, TimeoutError):
                # NOTE(review): queue.get() has no timeout here, so this branch
                # is currently unreachable; kept for safety.
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Consume one incoming audio frame and run VAD endpointing.

        PCM is buffered until a full 20 ms VAD frame is available. Once the
        ring buffer is mostly voiced we start recording; once it is mostly
        unvoiced we stop, transcribe the utterance, and queue the transcript
        for the TTS loop in start_up().
        """
        sr, array = frame
        audio_bytes = array.tobytes()
        self.internal_buffer.extend(audio_bytes)
        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = self.internal_buffer[:self.VAD_FRAME_BYTES]
            self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES:]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)
            if not self.vad_triggered:
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_voiced = len([f for f, speech in self.vad_ring_buffer if speech])
                if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    # Keep the padding frames so the utterance onset is not clipped.
                    for f, s in self.vad_ring_buffer:
                        self.wav_data.extend(f)
                    self.vad_ring_buffer.clear()
            else:
                self.wav_data.extend(vad_frame)
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech])
                if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("End of speech detected.")
                    self.end_of_speech_time = time.monotonic()
                    self.vad_triggered = False
                    full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)
                    audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
                    text_input = self.s2t(audio_input_wav)
                    # --- ADDED THIS BLOCK TO PRINT THE S2T TRANSCRIPT ---
                    print("\n--- FULL S2T TRANSCRIPT ---")
                    print(f"'{text_input}'")
                    print("---------------------------\n")
                    # ----------------------------------------------------
                    if text_input and text_input.strip():
                        if self.t2t_bool:
                            text_message = self.t2t(text_input)
                        else:
                            text_message = text_input
                        self.input_queue.put_nowait(text_message)
                    else:
                        print("STT returned empty transcript, skipping.")
                    # Reset state for the next utterance.
                    self.vad_ring_buffer.clear()
                    self.wav_data = bytearray()

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next (sample_rate, audio) chunk for playback, if any."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal the stream() loop (and thus start_up) to stop."""
        self.quit.set()
# --- Gradio UI: one bidirectional WebRTC audio widget wired to GeminiHandler --
with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")
    with gr.Row() as row2:
        with gr.Column():
            # ICE servers are resolved lazily via safe_get_ice_config_async
            # (Cloudflare TURN when credentials exist, otherwise STUN fallback).
            webrtc2 = WebRTC(
                label="Audio Chat",
                modality="audio",
                mode="send-receive",
                elem_id="audio-source",
                rtc_configuration=safe_get_ice_config_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
            # Session time/concurrency limits only apply when running on
            # Hugging Face Spaces (get_space() is truthy there).
            webrtc2.stream(
                GeminiHandler(),
                inputs=[webrtc2],
                outputs=[webrtc2],
                time_limit=180 if get_space() else None,
                concurrency_limit=2 if get_space() else None,
            )
if __name__ == "__main__":
    # Bind to all interfaces; the PORT env var overrides the default 7860
    # (the port Hugging Face Spaces expects a web app to listen on).
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT",7860)),
        debug=True,
    )