# Author: vbuenosa-nttd
# Commit 36ebfb6 (verified): "Deploying TechHub Prototype"
import asyncio
import logging
import os
import sys
import traceback

from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from google import genai
from google.api_core import exceptions as google_exceptions
# Process-wide logging at INFO and above; this module logs through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Populate os.environ from a local .env file (when present) before any of the
# configuration below is read.
load_dotenv()

# asyncio.TaskGroup / ExceptionGroup were added in Python 3.11. On older
# interpreters, install the third-party backports under the same names so
# AudioLoop.run can use `asyncio.TaskGroup` unconditionally.
if sys.version_info < (3, 11):
    import taskgroup
    import exceptiongroup

    asyncio.TaskGroup = taskgroup.TaskGroup
    asyncio.ExceptionGroup = exceptiongroup.ExceptionGroup
# Audio format used when this ran locally with pyaudio (dependency since removed):
# 16-bit PCM, mono, 16 kHz sent to Gemini, 24 kHz received back, CHUNK_SIZE 1024.
# NOTE(review): the browser client presumably has to produce/consume these
# formats now -- confirm against the frontend.

# Configuration from environment variables (optionally populated from .env).
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
MODEL = os.environ.get("GEMINI_MODEL", "models/gemini-2.0-flash-live-001")

if not GOOGLE_API_KEY or GOOGLE_API_KEY == "YOUR_API_KEY_HERE":
    # On HF Spaces the key is injected as a secret and may be absent at build
    # time, so only warn here instead of exiting.
    logger.warning("GOOGLE_API_KEY environment variable not set or is a placeholder.")

# Create the Gemini client. A construction failure must not abort the import
# (that would break the build), so it is logged and `client` is left as None;
# the failure then surfaces when a session tries to connect. Defining the name
# either way avoids a NameError later.
client = None
try:
    client = genai.Client(api_key=GOOGLE_API_KEY)
except (KeyError, ValueError) as e:
    logger.critical(f"Error: {e}. Please set the GOOGLE_API_KEY environment variable.")
# Live session configuration passed to `client.aio.live.connect`: audio-only
# model responses, with output-audio transcription enabled (the empty mapping
# turns it on with default settings), sampled at temperature 1.0.
CONFIG = dict(
    response_modalities=["AUDIO"],
    output_audio_transcription={},
    generation_config=dict(temperature=1.0),
)
app = FastAPI()

# Serve the compiled frontend bundle. The frontend build is assumed to be copied
# into ./static inside the container. Starlette's StaticFiles raises RuntimeError
# when its directory does not exist, so test the directory that is actually
# mounted (static/assets) rather than only its parent.
if os.path.isdir("static/assets"):
    app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")
else:
    logger.warning("static/assets not found; frontend assets will not be served.")
@app.get("/")
async def get():
    """Serve the frontend entry page, or a notice when it has not been built.

    Returns:
        A FileResponse for static/index.html when that file exists; otherwise
        an HTMLResponse asking the operator to build the frontend.
    """
    # isfile (not exists) so a stray directory at this path is not handed to
    # FileResponse.
    if os.path.isfile("static/index.html"):
        return FileResponse("static/index.html")
    return HTMLResponse("<h1>Frontend not found. Please build the frontend.</h1>")
class AudioLoop:
    """Bridge one browser WebSocket to one Gemini Live API session.

    Binary WebSocket frames are forwarded to Gemini as ``audio/pcm`` input.
    Audio returned by Gemini goes back to the browser as binary frames. Text
    (model text parts plus output/input audio transcriptions) goes back as text
    frames.
    """

    # A close frame is a WebSocket control frame: its payload (2-byte status
    # code + UTF-8 reason) may not exceed 125 bytes (RFC 6455, section 5.5).
    _MAX_CLOSE_REASON_BYTES = 123

    def __init__(self, websocket: WebSocket):
        """Wrap an already-accepted WebSocket; the Gemini session is opened by run()."""
        self.websocket = websocket
        # Active Live API session; assigned inside run() once connected.
        self.session = None

    async def run(self):
        """Connect to Gemini and pump audio both ways until one side stops.

        When either pump finishes (client disconnect, Gemini error), the other
        is cancelled so the Live session is released promptly. Errors are
        logged and reported to the browser by closing the WebSocket with code
        1011. Cancellation of run() itself is logged and re-raised.
        """
        try:
            if client is None:
                logger.error("Gemini client is not configured; check GOOGLE_API_KEY.")
                await self._close_with_error("Gemini client is not configured")
                return
            async with client.aio.live.connect(model=MODEL, config=CONFIG) as session:
                self.session = session
                logger.info("Gemini Live API session started.")
                async with asyncio.TaskGroup() as tg:
                    receive_task = tg.create_task(self.receive_from_gemini())
                    send_task = tg.create_task(self.send_to_gemini())
                    # Both pumps *return* (rather than raise) when their side goes
                    # away. Without this, the group would keep waiting on the other
                    # pump -- e.g. on session.receive() after the browser has left --
                    # and hold the Gemini session open indefinitely.
                    receive_task.add_done_callback(lambda _task: send_task.cancel())
                    send_task.add_done_callback(lambda _task: receive_task.cancel())
        except asyncio.CancelledError:
            logger.info("Audio loop cancelled.")
            # Propagate so whoever cancelled this coroutine sees it end as cancelled.
            raise
        except google_exceptions.GoogleAPICallError as e:
            # NOTE(review): google-genai reports its own errors
            # (google.genai.errors); confirm this google.api_core type can still
            # be raised here.
            logger.error("Google API call error in audio loop: %s", e)
            await self._close_with_error(f"Google API Error: {e}")
        except Exception as e:
            logger.exception("An error occurred in the audio loop: %s", e)
            await self._close_with_error("Internal Server Error")

    async def _close_with_error(self, reason: str) -> None:
        """Close the WebSocket with code 1011, trimming `reason` to the close-frame limit."""
        encoded = reason.encode("utf-8")[: self._MAX_CLOSE_REASON_BYTES]
        # Trimming may cut a multi-byte character in half; drop the partial bytes.
        safe_reason = encoded.decode("utf-8", errors="ignore")
        try:
            await self.websocket.close(code=1011, reason=safe_reason)
        except (RuntimeError, WebSocketDisconnect) as e:
            # The connection may already be gone (e.g. the client disconnected
            # first); there is nobody left to report the error to.
            logger.debug("Could not close WebSocket: %s", e)

    async def send_to_gemini(self):
        """Forward binary WebSocket frames to the Gemini session as PCM audio.

        Returns when the client disconnects or when receiving/sending fails.
        Frames that arrive while no session is set are dropped.
        """
        while True:
            try:
                data = await self.websocket.receive_bytes()
                if self.session:
                    # NOTE(review): Live API audio input is usually labelled with
                    # its rate (e.g. "audio/pcm;rate=16000") and newer google-genai
                    # deprecates `send` in favour of `send_realtime_input` -- confirm
                    # against the SDK version and what the frontend sends.
                    await self.session.send(
                        input={"data": data, "mime_type": "audio/pcm"}
                    )
            except WebSocketDisconnect:
                logger.info("Client disconnected from WebSocket.")
                break
            except Exception as e:
                logger.error("Error receiving from websocket or sending to Gemini: %s", e)
                break

    async def receive_from_gemini(self):
        """Forward Gemini output (audio and text) to the WebSocket.

        Returns when the client disconnects or receiving/sending fails, or
        immediately when no session is set.
        """
        if self.session is None:
            # Without a session there is nothing to await; looping here would
            # spin forever and block the event loop.
            logger.warning("No active Gemini session; receive loop not started.")
            return
        while True:
            try:
                # receive() yields the messages of one model turn; the outer loop
                # picks up the next turn.
                async for response in self.session.receive():
                    await self._forward_response(response)
            except WebSocketDisconnect:
                logger.info("Client disconnected. Stopping receive loop.")
                break
            except Exception as e:
                logger.error("Error receiving from Gemini or sending to websocket: %s", e)
                break

    async def _forward_response(self, response):
        """Send one server message's audio (binary frames) and text (text frames)."""
        parts = self._model_turn_parts(response)
        # `response.data` concatenates every inline-data part of the model turn.
        if data := response.data:
            await self.websocket.send_bytes(data)
        else:
            # Fallback: forward inline audio parts individually.
            for part in parts:
                inline_data = getattr(part, "inline_data", None)
                if inline_data and (chunk := getattr(inline_data, "data", None)):
                    await self.websocket.send_bytes(chunk)
        # Text is collected for audio messages too, so a transcription that
        # arrives alongside audio is not dropped.
        for text_chunk in self._collect_texts(response, parts):
            if not isinstance(text_chunk, str):
                continue
            normalized = text_chunk.replace("\r", "").replace("\n", " ")
            if normalized.strip():
                logger.info("Received text: %s", normalized.strip())
                await self.websocket.send_text(normalized)

    @staticmethod
    def _model_turn_parts(response):
        """Return the parts of the message's model turn, or [] when it has none."""
        server_content = response.server_content
        if server_content and server_content.model_turn and server_content.model_turn.parts:
            return server_content.model_turn.parts
        return []

    @staticmethod
    def _collect_texts(response, parts):
        """Gather the text chunks carried by one server message, in forwarding order.

        Order: model-turn text parts, output transcription, input transcription,
        then any ``output_text`` attribute (read defensively via getattr).
        ``response.text`` is deliberately not read: the SDK builds it from the
        same model-turn text parts gathered here, so reading both sent every
        text chunk twice.
        """
        texts = [text for part in parts if (text := getattr(part, "text", None))]
        server_content = getattr(response, "server_content", None)
        if server_content:
            # NOTE: user (input) and model (output) transcriptions are sent to the
            # client as indistinguishable text frames.
            for field in ("output_transcription", "input_transcription"):
                transcription = getattr(server_content, field, None)
                if transcription and (text := getattr(transcription, "text", None)):
                    texts.append(text)
        output_text = getattr(response, "output_text", None)
        if output_text:
            if isinstance(output_text, (list, tuple)):
                texts.extend(output_text)
            else:
                texts.append(output_text)
        return texts
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a browser connection on /ws and bridge it to Gemini until it ends.

    Disconnects and unexpected errors are logged, not re-raised; the final
    "closed" message is logged on every exit path.
    """
    await websocket.accept()
    logger.info("WebSocket connection accepted.")
    bridge = AudioLoop(websocket)
    try:
        await bridge.run()
    except WebSocketDisconnect:
        logger.info("Client disconnected.")
    except Exception as err:
        logger.error(f"Error in websocket endpoint: {err}")
    finally:
        logger.info("WebSocket connection closed.")