Spaces:
Runtime error
Runtime error
| import asyncio | |
| import base64 | |
| import os | |
| import time | |
| from io import BytesIO | |
| from google.genai import types | |
| from google.genai.types import ( | |
| LiveConnectConfig, | |
| SpeechConfig, | |
| VoiceConfig, | |
| PrebuiltVoiceConfig, | |
| Content, | |
| Part, | |
| ) | |
| import gradio as gr | |
| import numpy as np | |
| import websockets | |
| from dotenv import load_dotenv | |
| from fastrtc import ( | |
| AsyncAudioVideoStreamHandler, | |
| Stream, | |
| WebRTC, | |
| get_cloudflare_turn_credentials_async, | |
| wait_for_item, | |
| ) | |
| from google import genai | |
| from gradio.utils import get_space | |
| from PIL import Image | |
| # ------------------------------------------ | |
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| import pathlib | |
| from typing import AsyncGenerator, Literal | |
| import gradio as gr | |
| import numpy as np | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI | |
| from fastapi.responses import HTMLResponse | |
| from fastrtc import ( | |
| AsyncStreamHandler, | |
| Stream, | |
| get_cloudflare_turn_credentials_async, | |
| wait_for_item, | |
| ) | |
| from google import genai | |
| from google.genai.types import ( | |
| LiveConnectConfig, | |
| PrebuiltVoiceConfig, | |
| SpeechConfig, | |
| VoiceConfig, | |
| ) | |
| from gradio.utils import get_space | |
| from pydantic import BaseModel | |
| # ------------------------------------------------ | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| import os | |
| import io | |
| import asyncio | |
| from pydub import AudioSegment | |
| # Gemini: google-genai | |
| from google import genai | |
| # --------------------------------------------------- | |
| # VAD imports from reference code | |
| import collections | |
| import webrtcvad | |
| import time | |
# --- Service configuration ---------------------------------------------------
# Secrets are read from the environment (load_dotenv() above also loads a local
# .env file) instead of being committed to source. The key values previously
# hard-coded at this spot were published together with the code and must be
# treated as compromised: revoke and rotate them with each provider.
# NOTE(review): in the visible code GeminiHandler reads GEMINI_API_KEY via
# os.getenv directly; these module-level names are kept for any other callers.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_URL = os.getenv(
    "QDRANT_URL",
    "https://6a3aade6-e8ad-4a6c-a579-21f5af90b7e8.us-east4-0.gcp.cloud.qdrant.io",
)
WEAVIATE_URL = os.getenv(
    "WEAVIATE_URL", "yorcqe2sqswhcaivxvt9a.c0.us-west3.gcp.weaviate.cloud"
)
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv(
    "DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai"
)
def encode_audio(data: np.ndarray) -> dict:
    """Wrap an audio array as a JSON-friendly PCM payload.

    The array's raw in-memory bytes are base64-encoded and returned as text
    under ``"data"``, tagged with the ``"audio/pcm"`` MIME type.
    """
    raw = data.tobytes()
    payload = base64.b64encode(raw).decode("UTF-8")
    return {"mime_type": "audio/pcm", "data": payload}
def encode_audio2(data: np.ndarray) -> bytes:
    """Return the array's raw in-memory bytes, with no base64 wrapping."""
    raw_bytes = data.tobytes()
    return raw_bytes
| import soundfile as sf | |
def numpy_array_to_wav_bytes(audio_array, sample_rate=16000):
    """Convert a NumPy audio array to in-memory WAV bytes.

    (Previously this function was defined twice in a row; the first,
    immediately-shadowed copy has been folded into this single definition.)

    Args:
        audio_array (np.ndarray): Audio samples, shaped ``(frames,)`` or
            ``(frames, channels)`` as accepted by ``soundfile.write``.
        sample_rate (int): Sample rate in Hz written into the WAV header.

    Returns:
        bytes: A complete WAV file (header plus sample data).
    """
    with io.BytesIO() as buffer:
        sf.write(buffer, audio_array, sample_rate, format="WAV")
        # Read the bytes before the context manager closes the buffer.
        return buffer.getvalue()
| # webrtc handler class | |
class GeminiHandler(AsyncStreamHandler):
    """Chained speech-to-speech handler for Gemini.

    Per utterance:
      1. ``receive`` runs WebRTC VAD over incoming PCM frames and buffers the
         audio between detected speech onset and offset.
      2. At end of speech the buffered audio is transcribed with ``s2t`` and,
         when ``t2t_bool`` is set, passed through a text chat (``t2t``).
      3. The resulting text is queued. ``start_up`` sends each item as a user
         turn to a Gemini Live session and pushes the returned audio chunks to
         ``output_queue``, which ``emit`` drains.

    ``end_of_speech_time`` is recorded at every end of speech, but nothing in
    this class computes a latency from it yet.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        """Initialise queues, model names and VAD state.

        Args:
            expected_layout: Channel layout forwarded to ``AsyncStreamHandler``.
            output_sample_rate: Sample rate (Hz) attached to emitted audio.
            prompt_dict: Prompt settings stored on the handler and carried over
                by ``copy()`` (not otherwise read in this class). Defaults to a
                fresh ``{"prompt": "PHQ-9"}`` per instance.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        # Build a new dict per instance: a dict literal as the default value
        # would be one object shared (and mutable) across every handler.
        self.prompt_dict = {"prompt": "PHQ-9"} if prompt_dict is None else prompt_dict
        self.model = "gemini-2.0-flash-live-001"
        self.t2t_model = "gemini-2.0-flash"
        self.s2t_model = "gemini-2.0-flash"

        # --- VAD initialization ---
        self.vad = webrtcvad.Vad(3)  # aggressiveness 3 = webrtcvad's strictest mode
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2  # 2 bytes per int16 sample
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        # Sliding window of (frame_bytes, is_speech) used to decide onset/offset.
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        # Fraction of the window that must agree before the state flips.
        self.vad_ratio = 0.9
        self.vad_triggered = False
        self.wav_data = bytearray()  # audio of the utterance in progress
        self.internal_buffer = bytearray()  # leftover bytes shorter than one VAD frame
        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with the same output rate and prompt settings."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )

    def t2t(self, text: str) -> str:
        """Send ``text`` to the text chat session and return its reply (blocking)."""
        print(f"Sending text to Gemini: {text}")
        response = self.chat.send_message(text)
        print(f"Received response from Gemini: {response.text}")
        return response.text

    def s2t(self, audio) -> str:
        """Transcribe WAV-encoded ``audio`` bytes with Gemini (blocking)."""
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type='audio/wav'),
                'Generate a transcript of the speech.'
            ]
        )
        return response.text

    def _live_config(self, sys_instruction: str | None, voice_name: str) -> LiveConnectConfig:
        """Build the Live API config, attaching a system instruction only if set."""
        speech_config = SpeechConfig(
            voice_config=VoiceConfig(
                prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
            )
        )
        if sys_instruction is None:
            return LiveConnectConfig(
                response_modalities=["AUDIO"],
                speech_config=speech_config,
            )
        return LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=speech_config,
            system_instruction=Content(parts=[Part.from_text(text=sys_instruction)]),
        )

    async def start_up(self):
        """Create the Gemini clients, open a Live session and voice queued text.

        Each text item from ``stream()`` is sent as a user turn. Audio chunks
        from the reply are queued on ``output_queue`` as
        ``(output_sample_rate, int16 array)``.
        """
        # Chain switches. Both are fixed here, so by default each transcript
        # goes straight to the Live model with no system prompt.
        self.t2t_bool = False
        self.sys_prompt = None
        self.t2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        self.s2t_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        chat_instruction = (
            self.sys_prompt if self.sys_prompt is not None else "You are a helpful assistant."
        )
        chat_config = types.GenerateContentConfig(system_instruction=chat_instruction)
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
        self.t2s_client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY"))

        if self.t2t_bool:
            sys_instruction = f""" You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism .
Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
Always be clear, non-judgmental, and supportive."""
        else:
            sys_instruction = self.sys_prompt
        config = self._live_config(sys_instruction, voice_name="Puck")

        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if not (text_from_user and text_from_user.strip()):
                    continue
                if self.t2t_bool:
                    prompt = f"""
You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism .
Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).
Always be clear, non-judgmental, and supportive.
{text_from_user}
"""
                else:
                    prompt = text_from_user
                await session.send_client_content(
                    turns=types.Content(role='user', parts=[types.Part(text=prompt)])
                )
                async for resp_chunk in session.receive():
                    if resp_chunk.data:
                        array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))

    async def stream(self) -> AsyncGenerator[str, None]:
        """Yield queued text items until ``shutdown()`` sets the quit event."""
        while not self.quit.is_set():
            try:
                # Wake up periodically so a shutdown is noticed even when no more
                # text arrives; an untimed get() would block forever.
                text_to_speak = await asyncio.wait_for(self.input_queue.get(), timeout=0.1)
            except (asyncio.TimeoutError, TimeoutError):
                continue
            yield text_to_speak

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Split incoming audio into VAD frames and drive the speech state machine.

        NOTE(review): frame sizing assumes 16-bit mono PCM at ``VAD_RATE``
        (16 kHz, matching ``input_sample_rate``). TODO confirm the dtype
        fastrtc delivers.
        """
        sr, array = frame
        self.internal_buffer.extend(array.tobytes())
        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = bytes(self.internal_buffer[:self.VAD_FRAME_BYTES])
            del self.internal_buffer[:self.VAD_FRAME_BYTES]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)
            window = self.vad_ring_buffer
            threshold = self.vad_ratio * window.maxlen
            if not self.vad_triggered:
                window.append((vad_frame, is_speech))
                voiced = sum(1 for _, speech in window if speech)
                if voiced > threshold:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    # Keep the padding frames so the utterance start isn't clipped.
                    for buffered_frame, _ in window:
                        self.wav_data.extend(buffered_frame)
                    window.clear()
            else:
                self.wav_data.extend(vad_frame)
                window.append((vad_frame, is_speech))
                unvoiced = sum(1 for _, speech in window if not speech)
                if unvoiced > threshold:
                    print("End of speech detected.")
                    self.end_of_speech_time = time.monotonic()
                    self.vad_triggered = False
                    # Snapshot and reset state before awaiting, so the VAD is
                    # ready for the next utterance while this one is processed.
                    utterance = bytes(self.wav_data)
                    window.clear()
                    self.wav_data = bytearray()
                    await self._handle_utterance(utterance, sr)

    async def _handle_utterance(self, pcm: bytes, sample_rate: int) -> None:
        """Transcribe one utterance (optionally answer it) and queue the text."""
        audio_input_wav = numpy_array_to_wav_bytes(
            np.frombuffer(pcm, dtype=np.int16), sample_rate
        )
        try:
            # s2t/t2t make blocking network calls; run them in a worker thread
            # so the event loop keeps serving emit() and the Live session.
            text_input = await asyncio.to_thread(self.s2t, audio_input_wav)
            if not (text_input and text_input.strip()):
                print("STT returned empty transcript, skipping.")
                return
            if self.t2t_bool:
                text_message = await asyncio.to_thread(self.t2t, text_input)
            else:
                text_message = text_input
        except Exception as exc:
            # Per-utterance boundary: log and drop this utterance rather than
            # letting one failed API call tear down the whole audio stream.
            print(f"Speech pipeline failed for this utterance: {exc!r}")
            return
        self.input_queue.put_nowait(text_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued ``(sample_rate, audio)`` chunk, if any."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Signal ``stream()`` to stop yielding text."""
        self.quit.set()
# Gradio UI: a single audio-only WebRTC component whose stream is served by
# GeminiHandler. When running on a Space, sessions are capped at 180 s and at
# most 2 run concurrently; otherwise neither limit is applied.
on_space = get_space()

with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")
    with gr.Row() as row2:
        with gr.Column():
            webrtc2 = WebRTC(
                label="Audio Chat",
                modality="audio",
                mode="send-receive",
                elem_id="audio-source",
                rtc_configuration=get_cloudflare_turn_credentials_async,
                icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
                pulse_color="rgb(255, 255, 255)",
                icon_button_color="rgb(255, 255, 255)",
            )
        # The same WebRTC component is both the audio source and the sink.
        webrtc2.stream(
            GeminiHandler(),
            inputs=[webrtc2],
            outputs=[webrtc2],
            time_limit=180 if on_space else None,
            concurrency_limit=2 if on_space else None,
        )
if __name__ == "__main__":
    # Serve on a fixed port, 7860 (Gradio's default).
    launch_options = {"server_port": 7860}
    demo.launch(**launch_options)