"""API routes for the voice agent."""

from __future__ import annotations

import json
import uuid
from pathlib import Path
import shutil

import anyio
from fastapi import APIRouter, File, Form, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import FileResponse, JSONResponse
from starlette.websockets import WebSocketState

from ..core.errors import SpeechError, ValidationError, LLMError
from ..core.logging import get_logger
from ..services.pipeline import VoicePipeline
from ..services.agent.agent_pipeline import AgentPipeline
from ..services.stt import SpeechToTextService
from ..services.vad import SileroVADStream
from ..utils.audio import encode_base64
from ..core.config import get_settings

router = APIRouter()

MAX_FILE_SIZE_BYTES = 15 * 1024 * 1024
DEFAULT_STREAM_CONTENT_TYPE = "audio/ogg"
NO_MATCH_REPLY = "Sorry, I didn't catch that. Please try again."

# LLM providers accepted by both the HTTP and the WebSocket voice endpoints.
ALLOWED_LLM_PROVIDERS = frozenset({"foundry_agent", "azure_openai", "local_agent"})
_LLM_PROVIDER_MESSAGE = (
    "LLM provider must be 'foundry_agent', 'azure_openai', or 'local_agent'."
)

# SpeechError codes that mean "no speech recognized" rather than a hard failure.
_NO_SPEECH_CODES = frozenset({"stt_empty", "stt_no_match"})

# Values passed to SpeechToTextService().start_streaming for each new segment.
STT_END_SILENCE_MS = 1200
# NOTE(review): the restart after a no-match segment uses 1400 ms instead of
# STT_END_SILENCE_MS; kept as-is, confirm whether the difference is intended.
STT_NO_MATCH_END_SILENCE_MS = 1400
STT_INITIAL_SILENCE_MS = 5000


def _validate_llm_provider(llm_provider: object) -> None:
    """Raise ValidationError if *llm_provider* is set but not a supported provider.

    Falsy values (``None``, ``""``) are accepted unchanged. Non-string values
    (possible when the value comes from a JSON payload) are rejected instead of
    failing with a ``TypeError`` on the membership test.
    """
    if llm_provider and (
        not isinstance(llm_provider, str) or llm_provider not in ALLOWED_LLM_PROVIDERS
    ):
        raise ValidationError(code="llm_provider", message=_LLM_PROVIDER_MESSAGE)


def _no_match_result() -> dict[str, object]:
    """Build the WebSocket ``result`` event sent when no speech was recognized."""
    return {
        "event": "result",
        "transcript": "",
        "reply_text": NO_MATCH_REPLY,
        "audio_format": "wav",
        "reply_audio_base64": None,
        "timings_ms": {"stt": 0, "llm": 0, "tts": 0, "total": 0},
    }


def _safe_upload_name(filename: str | None) -> str | None:
    """Return the bare file name of an upload, or None if it cannot be used safely.

    Client-supplied names may carry directory components (``../x``, ``/etc/x``,
    ``..\\x``). Only the final path component is kept so that writes stay inside
    the target directory.
    """
    if not filename:
        return None
    name = Path(filename.replace("\\", "/")).name
    if name in {"", ".", ".."} or "\x00" in name:
        return None
    return name


@router.get("/health")
async def health() -> dict[str, str]:
    """Health check endpoint."""
    return {"status": "ok"}


@router.get("/ws-demo")
async def ws_demo() -> FileResponse:
    """Serve a simple WebSocket streaming demo page."""
    demo_path = Path(__file__).resolve().parent.parent / "utils" / "ws_demo.html"
    return FileResponse(demo_path)


@router.post("/v1/agent/upload")
async def agent_upload(files: list[UploadFile] = File(...)) -> JSONResponse:
    """Upload files for local RAG indexing.

    Each non-empty file is written into ``settings.data_dir`` under its bare
    file name, with any directory components stripped. Empty files and files
    without a usable name are skipped. The RAG index is then rebuilt.

    Returns:
        JSON ``{"status": "ok", "files": [...]}`` listing the saved file names.
    """
    settings = get_settings()
    data_dir = Path(settings.data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    saved: list[str] = []
    for file in files:
        name = _safe_upload_name(file.filename)
        if name is None:
            continue
        content = await file.read()
        if not content:
            continue
        (data_dir / name).write_bytes(content)
        saved.append(name)
    await AgentPipeline().rebuild_rag()
    return JSONResponse({"status": "ok", "files": saved})


@router.post("/v1/agent/reset")
async def agent_reset() -> JSONResponse:
    """Reset local RAG + memory storage.

    Deletes the data and vector-store directories (errors ignored), recreates
    them empty, and calls ``AgentPipeline().reset()``.
    """
    settings = get_settings()
    directories = (Path(settings.data_dir), Path(settings.vector_store_dir))
    for directory in directories:
        if directory.exists():
            shutil.rmtree(directory, ignore_errors=True)
    for directory in directories:
        directory.mkdir(parents=True, exist_ok=True)
    AgentPipeline().reset()
    return JSONResponse({"status": "ok"})


@router.post("/v1/voice/file")
async def voice_file(
    file: UploadFile = File(...),
    prompt: str | None = Form(default=None),
    return_audio: bool = Form(default=True),
    llm_provider: str | None = Form(default=None),
) -> JSONResponse:
    """Process uploaded audio file and return transcript + reply.

    ``llm_provider == "local_agent"`` routes through ``AgentPipeline.run_audio``
    (with ``azure_openai`` as its LLM). Any other value goes to
    ``VoicePipeline.run``.

    Raises:
        ValidationError: missing, empty, or oversized file, or an unsupported
            ``llm_provider``.
    """
    request_id = str(uuid.uuid4())
    log = get_logger(request_id=request_id, endpoint="/v1/voice/file")
    if not file:
        raise ValidationError(code="file_missing", message="Audio file is required.")
    # Read at most one byte past the limit so oversized uploads are rejected
    # without loading them fully into memory.
    audio_bytes = await file.read(MAX_FILE_SIZE_BYTES + 1)
    if not audio_bytes:
        raise ValidationError(code="file_empty", message="Audio file is empty.")
    if len(audio_bytes) > MAX_FILE_SIZE_BYTES:
        raise ValidationError(code="file_too_large", message="File exceeds 15MB limit.")
    _validate_llm_provider(llm_provider)

    if llm_provider == "local_agent":
        result = await AgentPipeline().run_audio(
            audio_bytes=audio_bytes,
            filename=file.filename,
            content_type=file.content_type,
            prompt=prompt,
            return_audio=return_audio,
            llm_provider="azure_openai",
        )
    else:
        result = await VoicePipeline().run(
            audio_bytes=audio_bytes,
            filename=file.filename,
            content_type=file.content_type,
            prompt=prompt,
            return_audio=return_audio,
            llm_provider=llm_provider,
        )

    reply_audio_base64 = (
        encode_base64(result.reply_audio) if result.reply_audio else None
    )
    response_body = {
        "transcript": result.transcript,
        "reply_text": result.reply_text,
        "audio_format": "wav",
        "reply_audio_base64": reply_audio_base64,
        "timings_ms": result.timings_ms,
    }
    log.info(
        "voice_request_complete",
        file_name=file.filename,
        file_size=len(audio_bytes),
        timings_ms=result.timings_ms,
        return_audio=return_audio,
    )
    return JSONResponse(response_body)


@router.websocket("/ws/voice")
async def voice_stream(websocket: WebSocket) -> None:
    """Stream audio over WebSocket and reply once per speech segment.

    Client -> server messages:
      * binary frames: audio chunks. They are fed to the streaming STT session
        (once started) and buffered. Server-side VAD may finalize a segment on
        its own.
      * ``{"event": "start", ...}``: set content_type / prompt / return_audio /
        llm_provider / session_id, then open the STT session and VAD stream.
      * ``{"event": "segment_end", ...}``: finalize the buffered segment now.
      * ``{"event": "stop", ...}``: finalize the buffered segment and end.

    Server -> client, per segment: a ``transcript`` event, then a ``result``
    event, then the reply audio as a binary frame when there is reply audio and
    ``return_audio`` is set. An ``error`` event is sent on failure.
    """
    await websocket.accept()
    request_id = str(uuid.uuid4())
    log = get_logger(request_id=request_id, endpoint="/ws/voice")
    buffer = bytearray()
    content_type: str | None = DEFAULT_STREAM_CONTENT_TYPE
    prompt: str | None = None
    return_audio = True
    stt_session = None
    frames_sent: int | None = None
    avg_rms: float | None = None
    llm_provider: str | None = None
    vad_stream: SileroVADStream | None = None
    segment_processing = False
    session_id: str | None = None

    def _open_streams(end_silence_ms: int = STT_END_SILENCE_MS) -> None:
        """Start a fresh streaming STT session and VAD stream for the next segment."""
        nonlocal stt_session, vad_stream
        stt_session = SpeechToTextService().start_streaming(
            end_silence_ms=end_silence_ms, initial_silence_ms=STT_INITIAL_SILENCE_MS
        )
        vad_stream = SileroVADStream()

    def _apply_segment_options(payload: dict) -> None:
        """Update options from a 'stop'/'segment_end' payload; missing keys keep current values."""
        nonlocal prompt, return_audio, llm_provider, session_id, frames_sent, avg_rms
        prompt = payload.get("prompt", prompt)
        return_audio = payload.get("return_audio", return_audio)
        llm_provider = payload.get("llm_provider", llm_provider)
        session_id = payload.get("session_id", session_id) or request_id
        frames_sent = payload.get("frames_sent", frames_sent)
        avg_rms = payload.get("avg_rms", avg_rms)
        _validate_llm_provider(llm_provider)

    async def _transcribe_segment(session):
        """Return the STT result for the buffered segment, or None if no speech was recognized.

        Tries the streaming session first. On a no-speech error it retries once
        with a one-shot transcription of the whole buffer. Other SpeechErrors
        propagate.
        """
        try:
            return await anyio.to_thread.run_sync(session.finish)
        except SpeechError as exc:
            if exc.code not in _NO_SPEECH_CODES:
                raise
        try:
            return await anyio.to_thread.run_sync(
                SpeechToTextService().transcribe,
                bytes(buffer),
                None,
                content_type,
            )
        except SpeechError as exc:
            if exc.code not in _NO_SPEECH_CODES:
                raise
        return None

    async def _finalize_segment() -> None:
        """Transcribe the buffered segment, run the LLM/TTS pipeline, and reply.

        Afterwards the buffer is cleared and new STT/VAD streams are opened for
        the next segment.
        """
        nonlocal segment_processing
        if stt_session is None:
            raise ValidationError(
                code="stt_not_started", message="STT session not started."
            )
        if not buffer:
            return
        segment_processing = True
        try:
            stt_result = await _transcribe_segment(stt_session)
            if stt_result is None:
                await websocket.send_json(_no_match_result())
                buffer.clear()
                _open_streams(STT_NO_MATCH_END_SILENCE_MS)
                return

            await websocket.send_json(
                {"event": "transcript", "transcript": stt_result.transcript}
            )
            if llm_provider == "local_agent":
                result = await AgentPipeline().run_with_transcript(
                    transcript=stt_result.transcript,
                    language=stt_result.language,
                    prompt=prompt,
                    return_audio=return_audio,
                    llm_provider="azure_openai",
                    session_id=session_id,
                )
            else:
                result = await VoicePipeline().run(
                    audio_bytes=bytes(buffer),
                    filename=None,
                    content_type=content_type,
                    prompt=prompt,
                    return_audio=return_audio,
                    transcript_override=stt_result.transcript,
                    language_override=stt_result.language,
                    llm_provider=llm_provider,
                )
            response_body = {
                "event": "result",
                "transcript": result.transcript,
                "reply_text": result.reply_text,
                "audio_format": "wav",
                # Reply audio is sent as a separate binary frame below.
                "reply_audio_base64": None,
                "timings_ms": result.timings_ms,
            }
            log.info(
                "voice_stream_complete",
                bytes_received=len(buffer),
                timings_ms=result.timings_ms,
                return_audio=return_audio,
                content_type=content_type,
                frames_sent=frames_sent,
                avg_rms=avg_rms,
            )
            await websocket.send_json(response_body)
            if result.reply_audio and return_audio:
                await websocket.send_bytes(result.reply_audio)
            buffer.clear()
            _open_streams()
        finally:
            # Always re-arm VAD-driven finalization. Before this, the no-match
            # early return left the flag True, which disabled auto-segmentation
            # for the rest of the session.
            segment_processing = False

    async def _report_error(error: dict[str, object], *, close: bool = True) -> None:
        """Send an ``error`` event if both sides are still connected, then optionally close."""
        if (
            websocket.application_state == WebSocketState.CONNECTED
            and websocket.client_state == WebSocketState.CONNECTED
        ):
            await websocket.send_json({"event": "error", "error": error})
            if close:
                await websocket.close()

    try:
        while True:
            try:
                message = await websocket.receive()
            except RuntimeError:
                # receive() raises once a disconnect message was already consumed.
                log.info("voice_stream_disconnect")
                break
            if message.get("type") == "websocket.disconnect":
                log.info("voice_stream_disconnect")
                break

            chunk = message.get("bytes")
            if chunk is not None:
                if stt_session is not None:
                    stt_session.write(chunk)
                buffer.extend(chunk)
                # Enforce the limit before an oversized segment reaches the pipeline.
                if len(buffer) > MAX_FILE_SIZE_BYTES:
                    raise ValidationError(
                        code="file_too_large", message="Stream exceeds 15MB limit."
                    )
                if vad_stream is not None and not segment_processing:
                    if vad_stream.update(chunk).speech_ended:
                        await _finalize_segment()
                continue

            text = message.get("text")
            if text is None:
                continue
            try:
                payload = json.loads(text)
            except json.JSONDecodeError as exc:
                raise ValidationError(
                    code="invalid_message", message="Invalid JSON message."
                ) from exc
            if not isinstance(payload, dict):
                raise ValidationError(
                    code="invalid_message", message="Invalid JSON message."
                )
            event = str(payload.get("event", "")).lower()

            if event == "start":
                content_type = payload.get("content_type") or content_type
                prompt = payload.get("prompt")
                return_audio = payload.get("return_audio", True)
                llm_provider = payload.get("llm_provider", llm_provider)
                session_id = payload.get("session_id", session_id) or request_id
                _validate_llm_provider(llm_provider)
                _open_streams()
                continue

            if event == "stop":
                if not buffer:
                    raise ValidationError(
                        code="file_empty", message="Audio stream is empty."
                    )
                if stt_session is None:
                    raise ValidationError(
                        code="stt_not_started",
                        message="STT session not started.",
                    )
                _apply_segment_options(payload)
                await _finalize_segment()
                break

            if event == "segment_end":
                if not buffer:
                    continue
                _apply_segment_options(payload)
                if vad_stream is not None and not vad_stream.has_speech():
                    await websocket.send_json(_no_match_result())
                    buffer.clear()
                    # NOTE(review): the STT session is not restarted here, unlike
                    # the other reset paths; confirm that is intended.
                    vad_stream.reset()
                    continue
                await _finalize_segment()
                continue

            raise ValidationError(
                code="invalid_event",
                message="Event must be 'start', 'stop', or 'segment_end'.",
            )
    except WebSocketDisconnect:
        log.info("voice_stream_disconnect")
    except ValidationError as exc:
        log.warning(
            "voice_stream_error",
            code=exc.code,
            message=exc.message,
            bytes_received=len(buffer),
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        await _report_error({"code": exc.code, "message": exc.message})
    except LLMError as exc:
        log.warning(
            "voice_stream_error",
            code=exc.code,
            message=exc.message,
            details=exc.details,
            bytes_received=len(buffer),
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        # NOTE(review): guardrail errors skip the explicit close, but the handler
        # returns right after, so the session still ends. Confirm the intended UX.
        await _report_error(
            {"code": exc.code, "message": exc.message},
            close=exc.code != "llm_guardrail",
        )
        return
    except SpeechError as exc:
        log.warning(
            "voice_stream_error",
            code=exc.code,
            message=exc.message,
            details=exc.details,
            bytes_received=len(buffer),
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        await _report_error(
            {"code": exc.code, "message": exc.message, "details": exc.details}
        )
    except Exception as exc:  # pragma: no cover - safety net
        log.error("voice_stream_unhandled", error=repr(exc))
        await _report_error({"code": "internal_error", "message": "Server error."})