Spaces:
Build error
Build error
import asyncio
import logging
import os
import sys
import traceback

from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from google import genai
from google.api_core import exceptions as google_exceptions
# Route INFO-and-above records through the root logger (basicConfig adds a
# stderr handler unless the root logger already has one).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file, if one is present.
load_dotenv()
| if sys.version_info < (3, 11, 0): | |
| import taskgroup, exceptiongroup | |
| asyncio.TaskGroup = taskgroup.TaskGroup | |
| asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup | |
# Audio format of the former pyaudio-based client (pyaudio has since been
# removed): 16-bit mono PCM, 16 kHz sample rate towards Gemini, 24 kHz back
# from it, chunk size 1024.

# Load configuration from environment variables.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
MODEL = os.environ.get("GEMINI_MODEL", "models/gemini-2.0-flash-live-001")

# Create the shared Gemini client. A missing key is only logged, never fatal:
# on Hugging Face Spaces it may be supplied as a secret, and the build must
# not fail because of it. When construction fails, ``client`` stays None so
# the module-level name is always defined; each WebSocket session then fails
# (and is closed with an error) when it tries to connect.
client = None
try:
    if not GOOGLE_API_KEY or GOOGLE_API_KEY == "YOUR_API_KEY_HERE":
        logger.warning("GOOGLE_API_KEY environment variable not set or is a placeholder.")
    client = genai.Client(api_key=GOOGLE_API_KEY)
except (KeyError, ValueError) as e:
    logger.critical(f"Error: {e}. Please set the GOOGLE_API_KEY environment variable.")
# Options for every Gemini Live session: answer with spoken audio, request a
# transcript of that audio output (read back as ``output_transcription`` in
# AudioLoop), and sample at temperature 1.0.
CONFIG = dict(
    response_modalities=["AUDIO"],
    output_audio_transcription={},
    generation_config=dict(temperature=1.0),
)
app = FastAPI()

# Mount the built frontend's asset bundle. The frontend build is expected to
# be copied into ./static inside the container. StaticFiles raises at
# construction time when its directory does not exist, so test the directory
# that is actually mounted (static/assets) rather than just its parent;
# otherwise a ./static without an assets/ subfolder crashes startup.
if os.path.isdir("static/assets"):
    app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")
else:
    logger.warning("static/assets not found; frontend assets will not be served.")
# NOTE(review): no route decorator was visible in the source, which left this
# handler unregistered; "/" is assumed from its purpose (serving index.html).
@app.get("/")
async def get():
    """Serve the built single-page frontend, or a hint page when it is missing.

    Returns:
        A FileResponse for ``static/index.html`` when the frontend build has
        been copied into the container, otherwise a small HTMLResponse notice.
    """
    # isfile (not exists) so a stray directory with that name is not served.
    if os.path.isfile("static/index.html"):
        return FileResponse("static/index.html")
    return HTMLResponse("<h1>Frontend not found. Please build the frontend.</h1>")
class AudioLoop:
    """Bridge one client WebSocket to one Gemini Live API session.

    Binary frames received from the WebSocket are forwarded to Gemini as PCM
    audio. Audio coming back from Gemini is sent to the WebSocket as binary
    frames; text parts and transcripts are sent as text frames.
    """

    # RFC 6455 section 5.5: a close frame payload is at most 125 bytes, two of
    # which carry the status code, leaving 123 bytes for the UTF-8 reason.
    _MAX_CLOSE_REASON_BYTES = 123

    def __init__(self, websocket: "WebSocket"):
        self.websocket = websocket
        # The open Gemini session while run() is active; None otherwise.
        self.session = None

    async def run(self):
        """Open a Gemini Live session and pump data both ways until one side stops.

        The first pump to finish (client disconnect, Gemini stream ending or
        failing) cancels the other one, so the Gemini session is closed as soon
        as the conversation is over instead of lingering for a gone client.

        Google API errors and unexpected errors are logged and the WebSocket is
        closed with code 1011. Cancellation is logged and re-raised so the
        caller's task is really cancelled.
        """
        try:
            async with client.aio.live.connect(model=MODEL, config=CONFIG) as session:
                self.session = session
                logger.info("Gemini Live API session started.")
                async with asyncio.TaskGroup() as tg:
                    pumps = [
                        tg.create_task(self.receive_from_gemini()),
                        tg.create_task(self.send_to_gemini()),
                    ]
                    await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
                    for pump in pumps:
                        # No-op for the pump that already finished; TaskGroup
                        # treats a cancelled child as a normal exit.
                        pump.cancel()
        except asyncio.CancelledError:
            logger.info("Audio loop cancelled.")
            raise
        # NOTE(review): google-genai may raise its own error types (e.g.
        # google.genai.errors.APIError) that this clause does not match; those
        # reach the generic handler below — confirm against the pinned SDK.
        except google_exceptions.GoogleAPICallError as e:
            logger.error(f"Google API call error in audio loop: {e}")
            await self._close_websocket(1011, f"Google API Error: {e}")
        except Exception as e:
            logger.exception(f"An error occurred in the audio loop: {e}")
            await self._close_websocket(1011, "Internal Server Error")
        finally:
            self.session = None

    async def _close_websocket(self, code: int, reason: str) -> None:
        """Best-effort close of the client socket; a failure is logged, not raised."""
        try:
            await self.websocket.close(code=code, reason=self._truncate_reason(reason))
        except Exception as close_error:
            # e.g. the socket may already be closed; nothing left to notify.
            logger.warning(f"Could not close WebSocket cleanly: {close_error}")

    @classmethod
    def _truncate_reason(cls, reason: str) -> str:
        """Clip *reason* to the close-frame limit without splitting a UTF-8 character."""
        encoded = reason.encode("utf-8")
        if len(encoded) <= cls._MAX_CLOSE_REASON_BYTES:
            return reason
        return encoded[: cls._MAX_CLOSE_REASON_BYTES].decode("utf-8", errors="ignore")

    async def send_to_gemini(self):
        """Receive audio from the WebSocket and forward it to the Gemini session.

        Frames arriving while no session is set are dropped. The loop ends on
        client disconnect or on any other error (logged, not raised).
        """
        while True:
            try:
                data = await self.websocket.receive_bytes()
                if self.session:
                    # NOTE(review): newer google-genai releases deprecate
                    # send(input=...) in favour of send_realtime_input —
                    # confirm against the pinned SDK version.
                    await self.session.send(
                        input={"data": data, "mime_type": "audio/pcm"}
                    )
            except WebSocketDisconnect:
                logger.info("Client disconnected from WebSocket.")
                break
            except Exception as e:
                logger.error(f"Error receiving from websocket or sending to Gemini: {e}")
                break

    async def receive_from_gemini(self):
        """Receive audio and text from Gemini and forward them to the WebSocket.

        Every server message is checked for audio AND text. Audio that arrives in
        the same message as a transcript no longer causes that transcript to be
        dropped. Text is flattened to one line (CR removed, LF turned into a
        space), and whitespace-only fragments are skipped. The loop ends when no
        session is set, on client disconnect, or on any other error (logged).
        """
        while True:
            session = self.session
            if session is None:
                # Without this guard the loop would spin forever without awaiting.
                return
            try:
                async for response in session.receive():
                    if audio := response.data:
                        await self.websocket.send_bytes(audio)
                    else:
                        # response.data is expected to already include inline
                        # audio parts; this per-part loop is only a fallback.
                        for chunk in self._inline_audio_chunks(response):
                            await self.websocket.send_bytes(chunk)
                    for text_chunk in self._extract_texts(response):
                        if not text_chunk:
                            continue
                        normalized = text_chunk.replace("\r", "").replace("\n", " ")
                        if normalized.strip():
                            logger.info(f"Received text: {normalized.strip()}")
                            await self.websocket.send_text(normalized)
            except WebSocketDisconnect:
                logger.info("Client disconnected. Stopping receive loop.")
                break
            except Exception as e:
                logger.error(f"Error receiving from Gemini or sending to websocket: {e}")
                break

    @staticmethod
    def _model_turn_parts(response) -> list:
        """Return the parts of the message's model turn, or [] when there are none."""
        server_content = getattr(response, "server_content", None)
        model_turn = getattr(server_content, "model_turn", None)
        return list(getattr(model_turn, "parts", None) or [])

    @classmethod
    def _inline_audio_chunks(cls, response) -> list:
        """Collect non-empty ``inline_data.data`` payloads from the model-turn parts."""
        chunks = []
        for part in cls._model_turn_parts(response):
            inline_data = getattr(part, "inline_data", None)
            if data := getattr(inline_data, "data", None):
                chunks.append(data)
        return chunks

    @classmethod
    def _extract_texts(cls, response) -> list:
        """Collect the text fragments a server message carries, in forwarding order.

        Order: model-turn part texts, output transcript, input transcript,
        ``output_text`` (single value or sequence), then ``response.text``.
        ``response.text`` is only used when no part text was found, because the
        SDK builds it by concatenating those same part texts.
        """
        part_texts = [
            text
            for part in cls._model_turn_parts(response)
            if (text := getattr(part, "text", None))
        ]
        texts = list(part_texts)

        server_content = getattr(response, "server_content", None)
        for field in ("output_transcription", "input_transcription"):
            transcription = getattr(server_content, field, None)
            if transcript_text := getattr(transcription, "text", None):
                texts.append(transcript_text)

        if output_text := getattr(response, "output_text", None):
            if isinstance(output_text, (list, tuple)):
                texts.extend(output_text)
            else:
                texts.append(output_text)

        if not part_texts and (response_text := getattr(response, "text", None)):
            texts.append(response_text)
        return texts
# NOTE(review): no route decorator was visible in the source, which left this
# endpoint unregistered; "/ws" is an assumed path — confirm it matches the URL
# the frontend connects to.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a client WebSocket and bridge it to Gemini via one AudioLoop.

    Disconnects and ordinary exceptions are logged rather than propagated; the
    closing log line is emitted however the connection ends.
    """
    await websocket.accept()
    logger.info("WebSocket connection accepted.")
    audio_loop = AudioLoop(websocket)
    try:
        await audio_loop.run()
    except WebSocketDisconnect:
        logger.info("Client disconnected.")
    except Exception as e:
        logger.exception(f"Error in websocket endpoint: {e}")
    finally:
        logger.info("WebSocket connection closed.")