"""HTTP and WebSocket entry point for the Voice Agent Service.

Routes:

* ``GET /`` — welcome message and service version.
* ``GET /health`` — configuration readiness (503 when degraded).
* ``POST /stt`` — one-shot speech-to-text for an uploaded audio file.
* ``POST /tts`` — streaming text-to-speech returning raw PCM (s16le, mono).
* ``WS /ws/voice`` — full-duplex voice session driven by ``VoicePipeline``.
"""

import json
import logging
import logging.config
from typing import Literal

import uvicorn
from fastapi import (
    FastAPI,
    File,
    Form,
    HTTPException,
    Query,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

# Local imports keep their original relative order.
from src.pipeline import VoicePipeline
from src.config import (
    DEEPGRAM_API_KEY,
    CARTESIA_API_KEY,
    CARTESIA_VOICE_ID,
    SAMPLE_RATE,
    GOOGLE_API_KEY,
    GOOGLE_PROJECT_ID,
    STT_PROVIDER,
    TTS_PROVIDER,
    WAKE_WORD_ENABLED,
)
from src.stt.deepgram_rest import transcribe_audio as deepgram_transcribe
from src.stt.gemini_stt import transcribe_audio as gemini_stt_transcribe
from src.stt.chirp3_client import transcribe_audio as chirp3_transcribe
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize
from src.tts.gemini_client import synthesize_stream as gemini_synthesize, GEMINI_SAMPLE_RATE

# Single logging configuration shared by this module and by uvicorn (see the
# __main__ block), so application and server logs use the same format.
LOG_CONFIG = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        },
    },
    "handlers": {
        "default": {
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    "root": {
        "level": "INFO",
        "handlers": ["default"],
    },
}

logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)

VERSION = "1.2.0"

app = FastAPI(title="Voice Agent Service", version=VERSION)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# web origin make credentialed requests; restrict allow_origins if the service
# ever relies on cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root() -> JSONResponse:
    """Return a welcome message and the service version.

    Named ``root`` (previously a second ``health``) so it no longer shares a
    module-level name with the ``/health`` handler below.
    """
    body: dict = {
        "message": "Welcome to the Voice Agent Service! Please use the /health endpoint to check service status.",
        "version": VERSION,
    }
    return JSONResponse(status_code=200, content=body)


@app.get("/health")
async def health() -> JSONResponse:
    """Report which providers have their configuration present.

    ``status`` is ``"ok"`` (HTTP 200) only when both Deepgram STT and
    Cartesia TTS are configured; otherwise ``"degraded"`` (HTTP 503).
    Google-backed providers (Gemini, Chirp3) are reported individually but do
    not affect the status code. Readiness here means "credentials are set",
    not that the upstream services are reachable.
    """
    stt_ready = bool(DEEPGRAM_API_KEY)
    tts_ready = bool(CARTESIA_API_KEY and CARTESIA_VOICE_ID)
    gemini_ready = bool(GOOGLE_API_KEY)
    chirp3_ready = bool(GOOGLE_PROJECT_ID)
    all_ready = stt_ready and tts_ready

    body: dict = {
        "status": "ok" if all_ready else "degraded",
        "version": VERSION,
        "stt_ready": stt_ready,
        "tts_ready": tts_ready,
        "gemini_tts_ready": gemini_ready,
        "gemini_stt_ready": gemini_ready,
        "chirp3_stt_ready": chirp3_ready,
    }
    if not all_ready:
        body["message"] = "One or more required configurations are missing. Check your .env file."
    return JSONResponse(status_code=200 if all_ready else 503, content=body)


class TTSRequest(BaseModel):
    """Body of ``POST /tts``."""

    text: str
    provider: Literal["cartesia", "gemini"] = "gemini"


@app.post("/stt")
async def speech_to_text(
    audio: UploadFile = File(...),
    provider: str = Form(default="chirp3"),
) -> JSONResponse:
    """Transcribe an uploaded audio file with the requested provider.

    Args:
        audio: Uploaded audio; its content type is forwarded as the mimetype
            (falling back to ``audio/wav`` when the client sent none).
        provider: One of ``"chirp3"`` (default), ``"gemini"`` or ``"deepgram"``.

    Returns:
        The provider's transcription result as JSON.

    Raises:
        HTTPException: 400 for an empty upload or an unknown provider; 503 when
            the chosen provider's credentials are not configured.
    """
    data = await audio.read()
    if not data:
        raise HTTPException(status_code=400, detail="Audio file is empty.")
    mimetype = audio.content_type or "audio/wav"

    if provider == "chirp3":
        if not GOOGLE_PROJECT_ID:
            raise HTTPException(status_code=503, detail="Chirp3 STT not configured: missing GOOGLE_PROJECT_ID.")
        result = await chirp3_transcribe(data, mimetype=mimetype)
    elif provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini STT not configured: missing GOOGLE_API_KEY.")
        result = await gemini_stt_transcribe(data, mimetype=mimetype)
    elif provider == "deepgram":
        if not DEEPGRAM_API_KEY:
            raise HTTPException(status_code=503, detail="Deepgram STT not configured: missing DEEPGRAM_API_KEY.")
        result = await deepgram_transcribe(data, mimetype=mimetype)
    else:
        # Previously any unrecognised value silently fell through to Deepgram.
        raise HTTPException(
            status_code=400,
            detail=f"Unknown STT provider {provider!r}; expected one of: chirp3, gemini, deepgram.",
        )
    return JSONResponse(content=result)


@app.post("/tts")
async def text_to_speech(req: TTSRequest) -> StreamingResponse:
    """Stream synthesized speech for ``req.text`` as raw PCM.

    The audio format is advertised in the ``X-Sample-Rate``, ``X-Encoding``
    and ``X-Channels`` response headers.

    Raises:
        HTTPException: 400 for blank text; 503 when the selected provider is
            not configured.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")
    # Only metadata at INFO; the user's text itself is logged at DEBUG.
    logger.info("TTS request: provider=%s, text_len=%d", req.provider, len(req.text))
    logger.debug("TTS request text: %r", req.text)

    # Configuration is checked before building the StreamingResponse: once
    # streaming starts, the 200 status and headers are already sent, so a
    # failure inside the stream could no longer be reported as a 503.
    if req.provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini TTS not configured.")
        stream = gemini_synthesize(req.text)
        sample_rate = GEMINI_SAMPLE_RATE
    else:
        if not (CARTESIA_API_KEY and CARTESIA_VOICE_ID):
            raise HTTPException(status_code=503, detail="Cartesia TTS not configured.")
        stream = cartesia_synthesize(req.text)
        sample_rate = SAMPLE_RATE

    return StreamingResponse(
        stream,
        media_type="audio/pcm",
        headers={
            "X-Sample-Rate": str(sample_rate),
            "X-Encoding": "pcm_s16le",
            "X-Channels": "1",
        },
    )


@app.websocket("/ws/voice")
async def voice_ws(
    ws: WebSocket,
    stt_provider: str = Query(default=STT_PROVIDER),
    tts_provider: str = Query(default=TTS_PROVIDER),
    wake_word_enabled: bool = Query(default=WAKE_WORD_ENABLED),
) -> None:
    """Run a voice session over a WebSocket.

    Binary frames are fed to the pipeline as audio. Text frames are JSON
    control messages of the form ``{"action": ...}``:

    * ``"stop"`` ends the session;
    * ``"ping"`` is answered with ``{"event": "pong"}``;
    * ``"interrupt"`` calls ``pipeline.interrupt()``;
    * ``"speak"`` speaks the stripped ``"text"`` field, if non-blank.

    Malformed or non-object control messages are ignored (logged at DEBUG).
    Immediately after accepting, a ``tts_config`` event describing the output
    audio format is sent.
    """
    await ws.accept()
    logger.info(
        "Client connected: %s (stt=%s, tts=%s)",
        ws.client,
        stt_provider,
        tts_provider,
    )

    async def send_audio(chunk: bytes) -> None:
        """Send a binary audio frame, ignoring a client that already disconnected."""
        try:
            await ws.send_bytes(chunk)
        except WebSocketDisconnect:
            pass

    async def send_event(event: dict) -> None:
        """Send a JSON event frame, ignoring a client that already disconnected."""
        try:
            await ws.send_text(json.dumps(event))
        except WebSocketDisconnect:
            pass

    tts_sample_rate = GEMINI_SAMPLE_RATE if tts_provider == "gemini" else SAMPLE_RATE
    await send_event({
        "event": "tts_config",
        "tts_provider": tts_provider,
        "stt_provider": stt_provider,
        "sample_rate": tts_sample_rate,
        "encoding": "pcm_s16le",
        "channels": 1,
    })

    pipeline = VoicePipeline(
        send_audio=send_audio,
        send_event=send_event,
        stt_provider=stt_provider,
        tts_provider=tts_provider,
        wake_word_enabled=wake_word_enabled,
    )
    pipeline.start()

    try:
        while True:
            message = await ws.receive()
            # ws.receive() reports a client disconnect as a message instead of
            # raising WebSocketDisconnect; calling it again after that raises
            # RuntimeError, so the loop must end here.
            if message["type"] == "websocket.disconnect":
                logger.info("Client disconnected: %s", ws.client)
                break

            if message.get("bytes"):
                pipeline.feed_audio(message["bytes"])
                continue

            raw_text = message.get("text")
            if not raw_text:
                continue
            try:
                msg = json.loads(raw_text)
            except json.JSONDecodeError:
                logger.debug("Ignoring malformed control message from %s", ws.client)
                continue
            if not isinstance(msg, dict):
                # Valid JSON that is not an object (e.g. a list) has no .get().
                logger.debug("Ignoring non-object control message from %s", ws.client)
                continue

            action = msg.get("action")
            if action == "stop":
                break
            elif action == "ping":
                await send_event({"event": "pong"})
            elif action == "interrupt":
                await pipeline.interrupt()
            elif action == "speak":
                text = msg.get("text")
                # Guard against null or non-string "text" values.
                if isinstance(text, str) and text.strip():
                    await pipeline.speak(text.strip())
    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", ws.client)
    finally:
        # Always tear the pipeline down; a failure here must not mask the
        # original exit path.
        try:
            await pipeline.stop_async()
        except Exception:
            logger.exception("Error during pipeline stop")


if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string. With
    # reload disabled an import string is unnecessary, and it would re-import
    # this file as a second module (repeating logging setup and app
    # creation). It would also tie startup to the file being named main.py.
    # Switch back to an import string only if reload/workers are enabled.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False, log_config=LOG_CONFIG)