Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import logging.config | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, Form, UploadFile, HTTPException, Query | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from pydantic import BaseModel | |
| from src.pipeline import VoicePipeline | |
| from typing import Literal | |
| from src.config import ( | |
| DEEPGRAM_API_KEY, CARTESIA_API_KEY, CARTESIA_VOICE_ID, | |
| SAMPLE_RATE, GOOGLE_API_KEY, GOOGLE_PROJECT_ID, STT_PROVIDER, TTS_PROVIDER, WAKE_WORD_ENABLED, | |
| ) | |
| from src.stt.deepgram_rest import transcribe_audio as deepgram_transcribe | |
| from src.stt.gemini_stt import transcribe_audio as gemini_stt_transcribe | |
| from src.stt.chirp3_client import transcribe_audio as chirp3_transcribe | |
| from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize | |
| from src.tts.gemini_client import synthesize_stream as gemini_synthesize, GEMINI_SAMPLE_RATE | |
# Logging setup shared by this module and by uvicorn (which receives the same
# dict via ``log_config=LOG_CONFIG`` when the file is run as a script): one
# stream handler attached to the root logger at INFO, with timestamped lines.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

LOG_CONFIG = {
    "version": 1,
    # Do not disable loggers that already exist when this config is applied
    # (e.g. loggers created by modules imported above).
    "disable_existing_loggers": False,
    "formatters": {"default": {"format": _LOG_FORMAT}},
    "handlers": {
        "default": {"class": "logging.StreamHandler", "formatter": "default"},
    },
    "root": {"level": "INFO", "handlers": ["default"]},
}

logging.config.dictConfig(LOG_CONFIG)
VERSION = "1.2.0"
logger = logging.getLogger(__name__)

# The ASGI application object that uvicorn serves.
# NOTE(review): no route decorators (e.g. ``@app.get``) appear next to the
# handler functions in this copy of the file — confirm the endpoints are
# actually registered on ``app``.
app = FastAPI(title="Voice Agent Service", version=VERSION)

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site make credentialed cross-origin requests (Starlette reflects the
# request Origin in that case — verify for the installed version). Consider an
# explicit origin allow-list if cookies or auth headers are ever used.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
async def root() -> JSONResponse:
    """Return a static welcome message that points clients at ``/health``.

    Always responds with HTTP 200 and a JSON body containing ``message`` and
    the service ``version``.
    """
    # Named ``root`` rather than ``health``: the real health-check handler in
    # this module is also called ``health``, and the duplicate name made that
    # later definition silently rebind this one. Route registration is done
    # by the decorator, not the function name.
    # NOTE(review): no route decorator is visible for this handler in this
    # copy of the file — presumably it is served on "/"; confirm.
    payload: dict = {
        "message": "Welcome to the Voice Agent Service! Please use the /health endpoint to check service status.",
        "version": VERSION,
    }
    return JSONResponse(status_code=200, content=payload)
async def health() -> JSONResponse:
    """Report whether the configured speech providers have credentials.

    ``stt_ready`` / ``tts_ready`` describe the default providers selected by
    ``STT_PROVIDER`` and ``TTS_PROVIDER``. The per-provider ``*_ready`` keys
    report every backend this module can dispatch to, so clients can see
    which alternative providers are usable.

    Returns:
        JSONResponse with HTTP 200 and ``status: "ok"`` when both default
        providers are configured; otherwise HTTP 503 with
        ``status: "degraded"`` and an explanatory ``message``.
    """
    # Judge readiness by the providers actually configured, not by fixed
    # Deepgram/Cartesia keys, so a Chirp3/Gemini-only deployment is not
    # reported as degraded (503).
    # NOTE(review): assumes STT_PROVIDER / TTS_PROVIDER use the same provider
    # names as the REST handlers ("chirp3", "gemini", ...) — confirm against
    # src.config and VoicePipeline.
    stt_ready = _stt_provider_ready(STT_PROVIDER)
    tts_ready = _tts_provider_ready(TTS_PROVIDER)
    all_ready = stt_ready and tts_ready
    gemini_ready = bool(GOOGLE_API_KEY)
    body: dict = {
        "status": "ok" if all_ready else "degraded",
        "version": VERSION,
        "stt_provider": STT_PROVIDER,
        "tts_provider": TTS_PROVIDER,
        "stt_ready": stt_ready,
        "tts_ready": tts_ready,
        "deepgram_stt_ready": bool(DEEPGRAM_API_KEY),
        "chirp3_stt_ready": bool(GOOGLE_PROJECT_ID),
        "gemini_stt_ready": gemini_ready,
        "cartesia_tts_ready": bool(CARTESIA_API_KEY and CARTESIA_VOICE_ID),
        "gemini_tts_ready": gemini_ready,
    }
    if not all_ready:
        body["message"] = "One or more required configurations are missing. Check your .env file."
    return JSONResponse(status_code=200 if all_ready else 503, content=body)


def _stt_provider_ready(provider: str) -> bool:
    """Return True if credentials for STT backend *provider* are configured.

    Mirrors the dispatch in ``speech_to_text``: ``"chirp3"`` needs
    ``GOOGLE_PROJECT_ID``, ``"gemini"`` needs ``GOOGLE_API_KEY``, and any other
    value is served by Deepgram and needs ``DEEPGRAM_API_KEY``.
    """
    if provider == "chirp3":
        return bool(GOOGLE_PROJECT_ID)
    if provider == "gemini":
        return bool(GOOGLE_API_KEY)
    return bool(DEEPGRAM_API_KEY)


def _tts_provider_ready(provider: str) -> bool:
    """Return True if credentials for TTS backend *provider* are configured.

    Mirrors the dispatch in ``text_to_speech``: ``"gemini"`` needs
    ``GOOGLE_API_KEY``; any other value is served by Cartesia and needs both
    ``CARTESIA_API_KEY`` and ``CARTESIA_VOICE_ID``.
    """
    if provider == "gemini":
        return bool(GOOGLE_API_KEY)
    return bool(CARTESIA_API_KEY and CARTESIA_VOICE_ID)
class TTSRequest(BaseModel):
    # Request body for the text-to-speech handler (``text_to_speech``).
    # Documented with ``#`` comments rather than a docstring because pydantic
    # copies a model docstring into its JSON/OpenAPI schema description.
    text: str  # text to synthesize; the handler rejects blank text with HTTP 400
    # Synthesis backend. Values outside the Literal fail pydantic validation;
    # "gemini" streams at GEMINI_SAMPLE_RATE, "cartesia" at SAMPLE_RATE.
    provider: Literal["cartesia", "gemini"] = "gemini"
async def speech_to_text(
    audio: UploadFile = File(...),
    provider: str = Form(default="chirp3"),
) -> JSONResponse:
    """Transcribe an uploaded audio file with the selected STT provider.

    Args:
        audio: Uploaded audio. Its ``content_type`` is forwarded to the
            provider as the mimetype (``"audio/wav"`` when absent).
        provider: ``"chirp3"`` (default) or ``"gemini"``; any other value is
            routed to Deepgram.

    Returns:
        JSONResponse wrapping whatever the provider's ``transcribe_audio``
        returned.

    Raises:
        HTTPException: 400 if the upload is empty; 503 if the chosen
            provider's credentials are not configured.
    """
    data = await audio.read()
    if not data:
        raise HTTPException(status_code=400, detail="Audio file is empty.")
    mimetype = audio.content_type or "audio/wav"

    if provider == "chirp3":
        if not GOOGLE_PROJECT_ID:
            raise HTTPException(status_code=503, detail="Chirp3 STT not configured: missing GOOGLE_PROJECT_ID.")
        result = await chirp3_transcribe(data, mimetype=mimetype)
    elif provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini STT not configured: missing GOOGLE_API_KEY.")
        result = await gemini_stt_transcribe(data, mimetype=mimetype)
    else:
        # Deepgram is the fallback for every other provider value. Check its
        # credentials up front, like the other branches, so a misconfigured
        # deployment yields a clear 503 instead of a provider-side failure.
        if not DEEPGRAM_API_KEY:
            raise HTTPException(status_code=503, detail="Deepgram STT not configured: missing DEEPGRAM_API_KEY.")
        result = await deepgram_transcribe(data, mimetype=mimetype)
    return JSONResponse(content=result)
async def text_to_speech(req: TTSRequest) -> StreamingResponse:
    """Stream synthesized speech for ``req.text`` as raw PCM audio.

    The response body is mono 16-bit little-endian PCM (``audio/pcm``). The
    sample rate is advertised in the ``X-Sample-Rate`` header, because it
    depends on the provider (``GEMINI_SAMPLE_RATE`` for Gemini,
    ``SAMPLE_RATE`` for Cartesia).

    Raises:
        HTTPException: 400 for blank text; 503 if the chosen provider's
            credentials are not configured.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")
    # Keep request metadata at INFO. The text itself is user content, so it
    # is only logged at DEBUG.
    logger.info("TTS request: provider=%s, text_len=%d", req.provider, len(req.text))
    logger.debug("TTS request text: %r", req.text)

    # Validate configuration before building the StreamingResponse. Once
    # streaming starts, the 200 headers are already sent, so a failure inside
    # the stream could no longer be reported as a proper error status.
    if req.provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini TTS not configured.")
        stream = gemini_synthesize(req.text)
        sample_rate = GEMINI_SAMPLE_RATE
    else:
        if not (CARTESIA_API_KEY and CARTESIA_VOICE_ID):
            raise HTTPException(status_code=503, detail="Cartesia TTS not configured.")
        stream = cartesia_synthesize(req.text)
        sample_rate = SAMPLE_RATE

    return StreamingResponse(
        stream,
        media_type="audio/pcm",
        headers={
            "X-Sample-Rate": str(sample_rate),
            "X-Encoding": "pcm_s16le",
            "X-Channels": "1",
        },
    )
async def voice_ws(
    ws: WebSocket,
    stt_provider: str = Query(default=STT_PROVIDER),
    tts_provider: str = Query(default=TTS_PROVIDER),
    wake_word_enabled: bool = Query(default=WAKE_WORD_ENABLED),
) -> None:
    """Run one full-duplex voice session over a WebSocket.

    Protocol, as implemented here:
      * After accepting, the server sends a ``tts_config`` JSON event that
        describes the PCM format of the audio it will stream back.
      * Binary frames from the client are fed to the pipeline as audio.
      * Text frames are JSON control messages with an ``action`` key.
        ``stop`` ends the session. ``ping`` is answered with a ``pong`` event.
        ``interrupt`` interrupts the pipeline. ``speak`` asks the pipeline to
        speak ``text``. Malformed or unknown messages are ignored.
      * The pipeline sends audio and JSON events back through the
        ``send_audio`` / ``send_event`` callbacks.

    The pipeline is always stopped when the session ends, however it ends.
    On an explicit ``stop`` the server also closes the socket.
    """
    await ws.accept()
    logger.info(
        "Client connected: %s (stt=%s, tts=%s)",
        ws.client, stt_provider, tts_provider,
    )

    async def send_audio(chunk: bytes) -> None:
        # Best effort: a client that has gone away must not crash the sender.
        try:
            await ws.send_bytes(chunk)
        except WebSocketDisconnect:
            pass

    async def send_event(event: dict) -> None:
        try:
            await ws.send_text(json.dumps(event))
        except WebSocketDisconnect:
            pass

    tts_sample_rate = GEMINI_SAMPLE_RATE if tts_provider == "gemini" else SAMPLE_RATE
    await send_event({
        "event": "tts_config",
        "tts_provider": tts_provider,
        "stt_provider": stt_provider,
        "sample_rate": tts_sample_rate,
        "encoding": "pcm_s16le",
        "channels": 1,
    })

    pipeline = VoicePipeline(
        send_audio=send_audio,
        send_event=send_event,
        stt_provider=stt_provider,
        tts_provider=tts_provider,
        wake_word_enabled=wake_word_enabled,
    )
    pipeline.start()
    stop_requested = False
    try:
        while True:
            message = await ws.receive()
            # ``WebSocket.receive()`` *returns* the disconnect message instead
            # of raising WebSocketDisconnect, and calling it again afterwards
            # raises RuntimeError. Detect the disconnect here and leave.
            if message["type"] == "websocket.disconnect":
                logger.info("Client disconnected: %s", ws.client)
                break
            if message.get("bytes"):
                pipeline.feed_audio(message["bytes"])
            elif message.get("text"):
                if await _handle_control_message(message["text"], pipeline, send_event):
                    stop_requested = True
                    break
    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", ws.client)
    finally:
        try:
            await pipeline.stop_async()
        except Exception:
            logger.exception("Error during pipeline stop")
        if stop_requested:
            # Close only after the pipeline has stopped, so that it does not
            # try to send on a socket that is already closed.
            try:
                await ws.close()
            except (RuntimeError, WebSocketDisconnect):
                pass


async def _handle_control_message(raw: str, pipeline: VoicePipeline, send_event) -> bool:
    """Apply one JSON control message from a ``voice_ws`` client.

    Args:
        raw: Text frame payload, expected to be a JSON object with ``action``.
        pipeline: The session's pipeline.
        send_event: The session's coroutine callback that sends a JSON event.

    Returns:
        True when the client asked to stop the session, otherwise False.
        Malformed JSON, non-object payloads and unknown actions are ignored,
        so a single bad frame cannot end the session.
    """
    try:
        msg = json.loads(raw)
    except json.JSONDecodeError:
        logger.debug("Ignoring malformed control message")
        return False
    if not isinstance(msg, dict):
        # e.g. a JSON array or string; ``msg.get`` would raise AttributeError.
        logger.debug("Ignoring non-object control message")
        return False

    action = msg.get("action")
    if action == "stop":
        return True
    if action == "ping":
        await send_event({"event": "pong"})
    elif action == "interrupt":
        await pipeline.interrupt()
    elif action == "speak":
        text = msg.get("text", "")
        # Guard against non-string values (e.g. null), which cannot be stripped.
        if isinstance(text, str) and text.strip():
            await pipeline.speak(text.strip())
    return False
if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string. With
    # reload disabled, uvicorn does not need an import string. Importing
    # "main" would execute this file a second time, because it is already
    # running as __main__: that re-applies the logging config and builds a
    # second app. Passing the object also keeps working if the file is renamed.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False, log_config=LOG_CONFIG)