"""Chained speech-to-speech demo: WebRTC audio -> VAD -> Gemini STT -> RAG
(Weaviate + DeepInfra embeddings) -> Gemini chat -> Gemini live TTS -> WebRTC.

Rewritten from a formatting-mangled original.  Fixes applied:
  * credentials are read from the environment instead of being hard-coded
    (the previously committed literal keys are leaked and MUST be rotated);
  * duplicate import blocks and the duplicate ``numpy_array_to_wav_bytes``
    definition were collapsed;
  * ``GeminiHandler.__init__`` no longer uses a mutable dict default;
  * the commented-out ``except`` clause guarding the RAG call in
    ``start_up`` is restored so the ``try`` block is complete.
"""

import asyncio
import base64
import collections
import io
import json
import os
import pathlib
import time
from contextlib import contextmanager
from io import BytesIO
from typing import AsyncGenerator, Literal

import gradio as gr
import numpy as np
import soundfile as sf
import weaviate
import webrtcvad
import websockets
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
    AsyncAudioVideoStreamHandler,
    AsyncStreamHandler,
    Stream,
    WebRTC,
    get_cloudflare_turn_credentials_async,
    wait_for_item,
)
from google import genai
from google.genai import types
from google.genai.types import (
    Content,
    LiveConnectConfig,
    Part,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from gradio.utils import get_space
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel
from pydub import AudioSegment
from weaviate.classes.init import Auth

# Load a local .env file (if present) before reading any configuration.
load_dotenv()

# ---------------------------------------------------------------------------
# Configuration / credentials
# ---------------------------------------------------------------------------
# SECURITY: these values were previously hard-coded literals committed to
# source control.  Those keys are compromised and must be rotated.  They are
# now read from the environment (populated by load_dotenv() above).
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY", "")
HF_TOKEN = os.getenv("HF_TOKEN", "")
WEAVIATE_URL = os.getenv("WEAVIATE_URL", "")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY", "")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY", "")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")

# OpenAI-compatible client pointed at DeepInfra (used for embeddings only).
openai = OpenAI(
    api_key=DEEPINFRA_API_KEY,
    base_url="https://api.deepinfra.com/v1/openai",
)


async def safe_get_ice_config_async():
    """Return Cloudflare TURN credentials when available, else a STUN-only fallback.

    This prevents the fastrtc helper from raising the HF_TOKEN / CLOUDFLARE_*
    error when those environment variables are not set during local testing.

    Returns:
        dict: An RTC configuration mapping with an ``iceServers`` list.
    """
    # Only attempt the Cloudflare helper when the required env vars exist.
    if os.getenv("HF_TOKEN") or (
        os.getenv("CLOUDFLARE_TURN_KEY_ID") and os.getenv("CLOUDFLARE_TURN_KEY_API_TOKEN")
    ):
        try:
            return await get_cloudflare_turn_credentials_async()
        except Exception as e:
            # BUG FIX: the original print statement contained a raw newline
            # inside the string literal, which is a syntax error.
            print(
                "Warning: failed to get Cloudflare TURN credentials, "
                "falling back to STUN-only. Error:",
                e,
            )
    # Fallback: minimal STUN servers so WebRTC can still attempt peer
    # connections locally.
    return {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}


@contextmanager
def weaviate_client():
    """Context manager that yields a Weaviate client and guarantees
    client.close() on exit.
    """
    client = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=Auth.api_key(WEAVIATE_API_KEY),
    )
    try:
        yield client
    finally:
        client.close()


def encode_audio(data: np.ndarray) -> dict:
    """Encode Audio data to send to the server."""
    return {
        "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
    }


def encode_audio2(data: np.ndarray) -> bytes:
    """Encode Audio data to send to the server."""
    return data.tobytes()


def numpy_array_to_wav_bytes(audio_array, sample_rate: int = 16000) -> bytes:
    """Convert a NumPy audio array to WAV bytes.

    NOTE: the original file defined this function twice; the duplicate has
    been removed (behavior is identical to the surviving definition).

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D).
        sample_rate (int): Sample rate in Hz.

    Returns:
        bytes: WAV-formatted audio data.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format="WAV")
    buffer.seek(0)
    return buffer.read()


class GeminiHandler(AsyncStreamHandler):
    """Handler for the Gemini API with chained latency calculation.

    Pipeline per utterance: webrtcvad end-pointing (receive) -> Gemini STT
    (s2t) -> optional Gemini chat (t2t) -> queue -> start_up's loop performs
    RAG + chat and streams TTS audio back through output_queue (emit).
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        # BUG FIX: the original declared ``prompt_dict: dict = {"prompt": "PHQ-9"}``
        # as a mutable default argument, shared across every instance.  The
        # None-sentinel below preserves the same effective default.
        self.prompt_dict = prompt_dict if prompt_dict is not None else {"prompt": "PHQ-9"}

        # Model selection.
        # self.model = "gemini-2.5-flash-preview-tts"
        self.model = "gemini-2.0-flash-live-001"
        self.t2t_model = "gemini-2.5-flash-lite"
        self.s2t_model = "gemini-2.5-flash-lite"

        # --- VAD Initialization ---
        self.vad = webrtcvad.Vad(3)  # aggressiveness 3 = most aggressive filtering
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2  # int16 samples -> 2 bytes each
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        self.vad_ratio = 0.9
        self.vad_triggered = False
        self.wav_data = bytearray()
        self.internal_buffer = bytearray()
        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler instance with the same configuration."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )

    def s2t(self, audio) -> str:
        """Transcribe WAV audio bytes to text via the Gemini STT model."""
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type='audio/wav'),
                'Generate a transcript of the speech.',
            ],
        )
        return response.text

    def embed_texts(self, texts: list[str], batch_size: int = 50) -> list[list[float]]:
        """Embed a list of texts using the configured OpenAI/DeepInfra client.

        Returns a list of embedding vectors (or empty lists on failure for
        each item).
        """
        all_embeddings: list[list[float]] = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            try:
                resp = openai.embeddings.create(
                    model="Qwen/Qwen3-Embedding-8B",
                    input=batch,
                    encoding_format="float",
                )
                batch_embs = [item.embedding for item in resp.data]
                all_embeddings.extend(batch_embs)
            except Exception as e:
                print(f"Embedding batch error (items {i}–{i+len(batch)-1}): {e}")
                # Keep output aligned with input: one empty vector per failed item.
                all_embeddings.extend([[] for _ in batch])
        return all_embeddings

    def s2t_and_embed(self, audio) -> list[float]:
        """Convert speech to text, then embed the transcript."""
        transcript = self.s2t(audio)  # Step 1: Speech -> Text
        if not transcript:
            return []
        embeddings = self.embed_texts([transcript])  # Step 2: Text -> Embedding
        return embeddings[0] if embeddings else []

    def encode_query(self, query: str) -> list[float] | None:
        """Generate a single embedding vector for a query string."""
        embs = self.embed_texts([query], batch_size=1)
        if embs and embs[0]:
            print("Query embedding (first 5 dims):", embs[0][:5])
            return embs[0]
        print("Failed to generate query embedding.")
        return None

    def rag_autism(self, query: str, top_k: int = 3) -> dict:
        """Run a RAG retrieval on the 'UserDocument' collection in Weaviate
        using v4 syntax.

        Returns up to `top_k` matching text chunks as {'answer': [texts...]}.
        """
        qe = self.encode_query(query)
        if not qe:
            return {"answer": []}
        try:
            with weaviate_client() as client:
                books_collection = client.collections.get("UserDocument")
                response = books_collection.query.near_vector(
                    near_vector=qe,
                    limit=top_k,
                    return_properties=["text"],
                )
            # Extract the text property from each object.
            hits = [
                obj.properties.get("text")
                for obj in response.objects
                if "text" in obj.properties
            ]
            # Deduplicate while preserving order (dict keys are ordered).
            unique_hits = list(dict.fromkeys(hits))
            if not unique_hits:
                return {"answer": []}
            return {"answer": unique_hits}
        except Exception as e:
            print("RAG Error:", e)
            return {"answer": []}

    def t2t(self, text: str) -> str:
        """Send text to the pre-initialized chat model and return the text response."""
        try:
            # Ensure the chat session exists before using it.
            if not hasattr(self, 'chat'):
                print("Error: Chat session (self.chat) is not initialized.")
                return "I'm sorry, my chat function is not ready."
            # Use the existing chat session to send the message.
            print("--> Attempting to send prompt to t2t model...")
            response = self.chat.send_message(text)
            print("--> Successfully received response from t2t model.")
            return response.text
        except Exception as e:
            print(f"t2t error: {e}")
            return ""

    async def start_up(self):
        """Open the Gemini live TTS session and speak replies for each
        transcribed user utterance (RAG -> chat -> TTS)."""
        # Flag for if we are using text-to-text in the middle of the chain or not.
        self.t2t_bool = False
        self.sys_prompt = None
        self.t2t_client = genai.Client(api_key=GEMINI_API_KEY)
        self.s2t_client = genai.Client(api_key=GEMINI_API_KEY)

        if self.sys_prompt is not None:
            chat_config = types.GenerateContentConfig(system_instruction=self.sys_prompt)
        else:
            chat_config = types.GenerateContentConfig(
                system_instruction="You are a helpful assistant."
            )
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)

        self.t2s_client = genai.Client(api_key=GEMINI_API_KEY)
        voice_name = "Puck"
        if self.t2t_bool:
            sys_instruction = f""" You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism . Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD). Always be clear, non-judgmental, and supportive."""
        else:
            sys_instruction = self.sys_prompt

        if sys_instruction is not None:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
                system_instruction=Content(parts=[Part.from_text(text=sys_instruction)]),
            )
        else:
            config = LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=SpeechConfig(
                    voice_config=VoiceConfig(
                        prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                    )
                ),
            )

        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if not text_from_user or not text_from_user.strip():
                    continue

                # 1) Run RAG retrieval on the user input to get contextual snippets.
                try:
                    rag_res = self.rag_autism(text_from_user, top_k=3)
                    context_snippets = (
                        rag_res.get("answer", []) if isinstance(rag_res, dict) else []
                    )
                    if context_snippets:
                        print("\n--- RAG CONTEXT RETRIEVED ---")
                        for i, snippet in enumerate(context_snippets):
                            print(f"Snippet {i+1}: {snippet}...")
                        print("-----------------------------\n")
                # BUG FIX: this except clause was commented out in the
                # original, leaving the try block syntactically incomplete.
                except Exception as e:
                    print("Error running RAG:", e)
                    context_snippets = []

                # 2) Build the prompt for the t2t model including retrieved context.
                combined_context = "\n\n".join(context_snippets) if context_snippets else ""
                if combined_context:
                    prompt = (
                        "Please answer the user's question based on the following context. "
                        "Be helpful and concise.\n\n"
                        f"--- CONTEXT ---\n{combined_context}\n\n"
                        f"--- USER QUESTION ---\n{text_from_user}"
                    )
                else:
                    prompt = (
                        "Answer the user's question from your own knowledge as a helpful assistant "
                        "specializing in Autism Spectrum Disorder.\n\n"
                        f"--- USER QUESTION ---\n{text_from_user}"
                    )
                print(prompt)

                # 3) Send prompt to chat (t2t) to obtain reply text.
                try:
                    reply_text = self.t2t(prompt)
                    print("\n--- FINAL AI RESPONSE ---")
                    print(reply_text)
                    print("-----------------------------")
                except Exception as e:
                    print("t2t generation error:", e)
                    reply_text = ""
                if not reply_text:
                    print("No t2t reply generated, skipping t2s send.")
                    continue

                # 4) Send the reply_text to the live TTS session to speak it.
                try:
                    text_to_speak = f"Read the following text aloud exactly as it is, without adding or changing anything: '{reply_text}'"
                    print(f">>> MODIFIED TEXT SENT TO T2S API: '{text_to_speak}' <<<")
                    await session.send_client_content(
                        turns=types.Content(role='user', parts=[types.Part(text=text_to_speak)])
                    )
                    # Forward each raw PCM chunk from the TTS stream to the client.
                    async for resp_chunk in session.receive():
                        if getattr(resp_chunk, "data", None):
                            array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                            self.output_queue.put_nowait((self.output_sample_rate, array))
                except Exception as e:
                    print("Error sending to live TTS session:", e)

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield user utterance texts from input_queue until shutdown."""
        while not self.quit.is_set():
            try:
                # Get the text message to be converted to speech.
                text_to_speak = await self.input_queue.get()
                yield text_to_speak
            except (asyncio.TimeoutError, TimeoutError):
                pass

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Consume one incoming audio frame and run VAD end-pointing.

        Buffers raw PCM, slices it into fixed-size VAD frames, detects the
        start/end of speech, and on end-of-speech transcribes the utterance
        and enqueues the resulting text onto input_queue.
        """
        sr, array = frame
        audio_bytes = array.tobytes()
        self.internal_buffer.extend(audio_bytes)

        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = self.internal_buffer[:self.VAD_FRAME_BYTES]
            self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES:]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)

            if not self.vad_triggered:
                # Waiting for speech: keep a padding ring buffer so the very
                # start of the utterance is not lost.
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_voiced = len([f for f, speech in self.vad_ring_buffer if speech])
                if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    for f, s in self.vad_ring_buffer:
                        self.wav_data.extend(f)
                    self.vad_ring_buffer.clear()
            else:
                # Recording: accumulate audio and watch for sustained silence.
                self.wav_data.extend(vad_frame)
                self.vad_ring_buffer.append((vad_frame, is_speech))
                num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech])
                if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
                    print("End of speech detected.")
                    self.end_of_speech_time = time.monotonic()
                    self.vad_triggered = False

                    full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)
                    audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
                    text_input = self.s2t(audio_input_wav)
                    print("\n--- FULL S2T TRANSCRIPT ---")
                    print(f"'{text_input}'")
                    print("---------------------------\n")

                    if text_input and text_input.strip():
                        if self.t2t_bool:
                            text_message = self.t2t(text_input)
                        else:
                            text_message = text_input
                        self.input_queue.put_nowait(text_message)
                    else:
                        print("STT returned empty transcript, skipping.")

                    self.vad_ring_buffer.clear()
                    self.wav_data = bytearray()

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next (sample_rate, audio) chunk for playback, if any."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal the stream() loop to stop."""
        self.quit.set()


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")
    with gr.Row() as row2:
        with gr.Column():
            webrtc2 = WebRTC(
                label="Audio Chat",
                modality="audio",
                mode="send-receive",
                elem_id="audio-source",
                rtc_configuration=safe_get_ice_config_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
            webrtc2.stream(
                GeminiHandler(),
                inputs=[webrtc2],
                outputs=[webrtc2],
                # Limit session length / concurrency only when hosted on Spaces.
                time_limit=180 if get_space() else None,
                concurrency_limit=2 if get_space() else None,
            )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        debug=True,
    )