Spaces:
Runtime error
Runtime error
| import asyncio | |
| import base64 | |
| import json | |
| import os | |
| from typing import Literal | |
| import gradio as gr | |
| import numpy as np | |
| from fastrtc import AsyncStreamHandler, WebRTC, wait_for_item | |
| from google import genai | |
| from google.cloud import texttospeech | |
| from google.genai.types import FunctionDeclaration, LiveConnectConfig, Tool | |
| import helpers.datastore as datastore | |
| from helpers.prompts import load_prompt | |
| from tools import FUNCTION_MAP, TOOLS | |
# Load the questions once at import time. They are published in the shared
# datastore and rendered into the system prompt used for every live session.
# NOTE(review): both paths are relative to the current working directory.
# Decode explicitly as UTF-8 so non-ASCII question text does not depend on the
# platform's default locale encoding.
with open("questions.json", "r", encoding="utf-8") as f:
    questions_dict = json.load(f)
datastore.DATA_STORE["questions"] = questions_dict
SYSTEM_PROMPT = load_prompt(
    "src/prompts/default_prompt.jinja2", questions=questions_dict
)
class TTSConfig:
    """Hold a Google Cloud Text-to-Speech client together with its voice and audio settings."""

    def __init__(self):
        # Client used to issue synthesize_speech requests.
        self.client = texttospeech.TextToSpeechClient()
        # US-English Chirp 3 HD "Charon" voice.
        chosen_voice = texttospeech.VoiceSelectionParams(
            language_code="en-US",
            name="en-US-Chirp3-HD-Charon",
        )
        self.voice = chosen_voice
        # Ask for LINEAR16 (uncompressed 16-bit PCM) output.
        encoding = texttospeech.AudioEncoding.LINEAR16
        self.audio_config = texttospeech.AudioConfig(audio_encoding=encoding)
class AsyncGeminiHandler(AsyncStreamHandler):
    """Stream handler bridging a fastrtc audio stream and a Gemini Live session.

    Microphone frames are base64-encoded and forwarded to Gemini. Audio returned
    by the model is queued for playback. Text returned by the model is buffered
    per turn, stored through the ``store_input`` tool, and also synthesised with
    Google Cloud Text-to-Speech.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=16000,
        )
        # Base64-encoded PCM chunks waiting to be sent to the Gemini session.
        self.input_queue: asyncio.Queue = asyncio.Queue()
        # (sample_rate, int16 ndarray) tuples waiting to be emitted for playback.
        self.output_queue: asyncio.Queue = asyncio.Queue()
        # Completed bot-turn texts waiting to be synthesised by tts().
        self.text_queue: asyncio.Queue = asyncio.Queue()
        # Set by shutdown(); start_up() waits on it to tear the session down.
        self.quit: asyncio.Event = asyncio.Event()
        self.chunk_size = 1024
        self.tts_config: TTSConfig | None = TTSConfig()
        # Text accumulated during the current bot turn.
        self.text_buffer = ""

    def copy(self) -> "AsyncGeminiHandler":
        """Return a fresh handler with the same output settings."""
        return AsyncGeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _encode_audio(self, data: np.ndarray) -> str:
        """Return the raw bytes of ``data`` as a base64 UTF-8 string for the server."""
        return base64.b64encode(data.tobytes()).decode("UTF-8")

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Encode an incoming (sample_rate, array) audio frame and queue it for sending."""
        _, array = frame
        array = array.squeeze()
        audio_message = self._encode_audio(array)
        self.input_queue.put_nowait(audio_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next item from the output queue (via fastrtc's wait_for_item)."""
        return await wait_for_item(self.output_queue)

    async def start_up(self) -> None:
        """Connect to Gemini Live and run the worker tasks until shutdown() is called.

        Builds the genai client (API key, or Vertex AI when
        GOOGLE_GENAI_USE_VERTEXAI == "True"), opens a live session configured
        with the system prompt and tools, and runs process(), send_realtime()
        and tts() concurrently. When ``self.quit`` is set, the workers are
        cancelled so the task group exits and the live session is closed.

        Raises:
            ValueError: If GEMINI_API_KEY is missing when Vertex AI is not used.
        """
        if os.getenv("GOOGLE_GENAI_USE_VERTEXAI") != "True":
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("API Key is required")
            client = genai.Client(
                api_key=api_key,
                http_options={"api_version": "v1alpha"},
            )
        else:
            client = genai.Client(http_options={"api_version": "v1beta1"})
        config = LiveConnectConfig(
            system_instruction={
                "parts": [{"text": SYSTEM_PROMPT}],
                "role": "user",
            },
            tools=[
                Tool(
                    function_declarations=[
                        FunctionDeclaration(**tool) for tool in TOOLS
                    ]
                )
            ],
            response_modalities=["AUDIO"],
        )
        async with (
            client.aio.live.connect(
                model="gemini-2.0-flash-exp", config=config
            ) as session,  # live connection session (websocket)
            asyncio.TaskGroup() as tg,  # runs the workers concurrently
        ):
            self.session = session
            workers = [
                tg.create_task(self.process()),
                tg.create_task(self.send_realtime()),
                tg.create_task(self.tts()),
            ]
            # The workers loop forever, so without this the session would never
            # close. Once shutdown() fires, cancel them; TaskGroup ignores
            # cancelled children and exits normally, closing the session.
            await self.quit.wait()
            for worker in workers:
                worker.cancel()

    async def process(self) -> None:
        """Consume session responses in a continuous loop.

        - Audio data is queued for playback at ``self.output_sample_rate``.
        - Text is appended to ``self.text_buffer``.
        - Tool calls are executed and their results sent back to the session.
        - On turn completion, the buffered text is queued for TTS and stored
          via the ``store_input`` tool.

        Exceptions are logged and the loop restarts after a short delay.
        """
        while True:
            try:
                async for response in self.session.receive():
                    if data := response.data:
                        array = np.frombuffer(data, dtype=np.int16)
                        self.output_queue.put_nowait((self.output_sample_rate, array))
                        continue
                    if text := response.text:
                        print(f"Received text: {text}")
                        self.text_buffer += text
                    if response.tool_call is not None:
                        await self._run_tool_calls(response.tool_call)
                    if sc := response.server_content:
                        if sc.turn_complete and self.text_buffer:
                            # Clear the buffer before storing, so a failing
                            # store cannot leak this turn's text into the next one.
                            turn_text, self.text_buffer = self.text_buffer, ""
                            self.text_queue.put_nowait(turn_text)
                            FUNCTION_MAP["store_input"](role="bot", input=turn_text)
            except Exception as e:
                print(f"Error in processing: {e}")
                await asyncio.sleep(0.1)

    async def _run_tool_calls(self, tool_call) -> None:
        """Execute each function call in ``tool_call`` and send its result to the session.

        Errors in a single call are logged and do not affect the other calls.
        """
        # function_calls may be None; don't let that abort the surrounding turn.
        for tool in tool_call.function_calls or []:
            try:
                print(f"Calling tool: {tool.name}")
                # args may be None for zero-argument functions.
                tool_response = FUNCTION_MAP[tool.name](**(tool.args or {}))
                print(f"Tool response: {tool_response}")
                # NOTE(review): the raw result is sent as client input rather than
                # as a FunctionResponse; confirm this matches the Live API contract.
                await self.session.send(input=tool_response, end_of_turn=True)
                await asyncio.sleep(0.1)
            except Exception as e:
                print(f"Error in tool call: {e}")
                await asyncio.sleep(0.1)

    async def send_realtime(self) -> None:
        """Forward queued microphone audio to the session as 'audio/pcm' chunks.

        Runs until cancelled. Send errors are logged and retried after a short delay.
        """
        while True:
            try:
                data = await self.input_queue.get()
                msg = {"data": data, "mime_type": "audio/pcm"}
                await self.session.send(input=msg)
            except Exception as e:
                print(f"Error in real-time sending: {e}")
                await asyncio.sleep(0.1)

    async def tts(self) -> None:
        """Synthesise each completed bot turn with Cloud TTS and queue the audio.

        Runs until cancelled. Errors are logged and the loop continues after a
        short delay.
        """
        while True:
            try:
                text = await self.text_queue.get()
                if not text:
                    continue
                # synthesize_speech is a blocking call. Run it in a worker
                # thread so it doesn't stall audio streaming on the event loop.
                response = await asyncio.to_thread(
                    self.tts_config.client.synthesize_speech,
                    input=texttospeech.SynthesisInput(text=text),
                    voice=self.tts_config.voice,
                    audio_config=self.tts_config.audio_config,
                )
                self.output_queue.put_nowait(
                    self._decode_linear16(response.audio_content)
                )
            except Exception as e:
                print(f"Error in TTS: {e}")
                await asyncio.sleep(0.1)

    def _decode_linear16(self, audio_content: bytes) -> tuple[int, np.ndarray]:
        """Convert a Cloud TTS LINEAR16 payload into a (sample_rate, int16 array) tuple.

        LINEAR16 responses are documented to include a WAV header. Parsing it
        with ``wave`` drops the header bytes (which would otherwise be played as
        noise) and yields the actual sample rate. Payloads without a readable
        header fall back to raw PCM at ``self.output_sample_rate``.
        """
        import io
        import wave

        try:
            with wave.open(io.BytesIO(audio_content), "rb") as wav:
                sample_rate = wav.getframerate()
                frames = wav.readframes(wav.getnframes())
        except (wave.Error, EOFError):
            return self.output_sample_rate, np.frombuffer(audio_content, dtype=np.int16)
        return sample_rate, np.frombuffer(frames, dtype=np.int16)

    def shutdown(self) -> None:
        """Signal start_up() to stop the workers and close the live session."""
        self.quit.set()
# Main Gradio Interface
def registry(*args, **kwargs):
    """Build and return the Gradio Blocks interface for the voice demo.

    All positional and keyword arguments are accepted and ignored. This lets
    the function serve as the ``src`` callable handed to ``gr.load``.

    The interface has one "Voice Chat" tab containing:
      - a send/receive WebRTC audio component streamed through AsyncGeminiHandler;
      - a JSON view of ``datastore.DATA_STORE["questions"]``;
      - a JSON view of the collected answers, refreshed every second.
    """
    interface = gr.Blocks()
    with interface:
        with gr.Tabs():
            with gr.TabItem("Voice Chat"):
                gr.HTML(
                    """
                    <div style='text-align: left'>
                    <h1>ML6 Voice Demo</h1>
                    </div>
                    """
                )
                gemini_handler = AsyncGeminiHandler()
                with gr.Row():
                    audio = WebRTC(
                        label="Voice Chat",
                        modality="audio",
                        mode="send-receive",
                    )
                # Display components for questions and answers
                with gr.Row():
                    with gr.Column():
                        gr.JSON(
                            label="Questions",
                            value=datastore.DATA_STORE["questions"],
                        )
                    with gr.Column():
                        # "answers" is not initialised anywhere in this file;
                        # fall back to an empty object so the 1-second poll doesn't
                        # raise KeyError before the first answer is stored.
                        gr.JSON(
                            label="Answers",
                            value=lambda: datastore.DATA_STORE.get("answers", {}),
                            every=1,
                        )
                audio.stream(
                    gemini_handler,
                    inputs=[audio],
                    outputs=[audio],
                    time_limit=600,
                    concurrency_limit=10,
                )
    return interface
# Build the app through gr.load, using registry() as the source factory,
# then start the Gradio server.
demo = gr.load(name="demo", src=registry)
demo.launch()