| """API routes for the voice agent.""" | |
| from __future__ import annotations | |
| import json | |
| import uuid | |
| from pathlib import Path | |
| import shutil | |
| import anyio | |
| from fastapi import APIRouter, File, Form, UploadFile, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from starlette.websockets import WebSocketState | |
| from ..core.errors import SpeechError, ValidationError, LLMError | |
| from ..core.logging import get_logger | |
| from ..services.pipeline import VoicePipeline | |
| from ..services.agent.agent_pipeline import AgentPipeline | |
| from ..services.stt import SpeechToTextService | |
| from ..services.vad import SileroVADStream | |
| from ..utils.audio import encode_base64 | |
| from ..core.config import get_settings | |
| router = APIRouter() | |
| MAX_FILE_SIZE_BYTES = 15 * 1024 * 1024 | |
| DEFAULT_STREAM_CONTENT_TYPE = "audio/ogg" | |
| NO_MATCH_REPLY = "Sorry, I didn't catch that. Please try again." | |


async def health() -> dict[str, str]:
    """Health check endpoint."""
    return {"status": "ok"}


async def ws_demo() -> FileResponse:
    """Serve a simple WebSocket streaming demo page."""
    demo_path = Path(__file__).resolve().parent.parent / "utils" / "ws_demo.html"
    return FileResponse(demo_path)


async def agent_upload(files: list[UploadFile] = File(...)) -> JSONResponse:
    """Upload files for local RAG indexing."""
    settings = get_settings()
    data_dir = Path(settings.data_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    saved: list[str] = []
    for file in files:
        content = await file.read()
        if not content or not file.filename:
            continue
        # Keep only the base name so a crafted filename cannot escape data_dir.
        safe_name = Path(file.filename).name
        (data_dir / safe_name).write_bytes(content)
        saved.append(safe_name)
    await AgentPipeline().rebuild_rag()
    return JSONResponse({"status": "ok", "files": saved})


async def agent_reset() -> JSONResponse:
    """Reset local RAG + memory storage."""
    settings = get_settings()
    data_dir = Path(settings.data_dir)
    store_dir = Path(settings.vector_store_dir)
    if data_dir.exists():
        shutil.rmtree(data_dir, ignore_errors=True)
    if store_dir.exists():
        shutil.rmtree(store_dir, ignore_errors=True)
    data_dir.mkdir(parents=True, exist_ok=True)
    store_dir.mkdir(parents=True, exist_ok=True)
    AgentPipeline().reset()
    return JSONResponse({"status": "ok"})


@router.post("/v1/voice/file")
async def voice_file(
    file: UploadFile = File(...),
    prompt: str | None = Form(default=None),
    return_audio: bool = Form(default=True),
    llm_provider: str | None = Form(default=None),
) -> JSONResponse:
    """Process an uploaded audio file and return transcript + reply."""
    request_id = str(uuid.uuid4())
    log = get_logger(request_id=request_id, endpoint="/v1/voice/file")
    if not file:
        raise ValidationError(code="file_missing", message="Audio file is required.")
    audio_bytes = await file.read()
    if not audio_bytes:
        raise ValidationError(code="file_empty", message="Audio file is empty.")
    if len(audio_bytes) > MAX_FILE_SIZE_BYTES:
        raise ValidationError(code="file_too_large", message="File exceeds 15MB limit.")
    pipeline = VoicePipeline()
    if llm_provider and llm_provider not in {
        "foundry_agent",
        "azure_openai",
        "local_agent",
    }:
        raise ValidationError(
            code="llm_provider",
            message="LLM provider must be 'foundry_agent', 'azure_openai', or 'local_agent'.",
        )
    if llm_provider == "local_agent":
        # The local agent orchestrates its own pipeline; generation is
        # delegated to Azure OpenAI.
        agent = AgentPipeline()
        result = await agent.run_audio(
            audio_bytes=audio_bytes,
            filename=file.filename,
            content_type=file.content_type,
            prompt=prompt,
            return_audio=return_audio,
            llm_provider="azure_openai",
        )
    else:
        result = await pipeline.run(
            audio_bytes=audio_bytes,
            filename=file.filename,
            content_type=file.content_type,
            prompt=prompt,
            return_audio=return_audio,
            llm_provider=llm_provider,
        )
    reply_audio_base64 = (
        encode_base64(result.reply_audio) if result.reply_audio else None
    )
    response_body = {
        "transcript": result.transcript,
        "reply_text": result.reply_text,
        "audio_format": "wav",
        "reply_audio_base64": reply_audio_base64,
        "timings_ms": result.timings_ms,
    }
    log.info(
        "voice_request_complete",
        file_name=file.filename,
        file_size=len(audio_bytes),
        timings_ms=result.timings_ms,
        return_audio=return_audio,
    )
    return JSONResponse(response_body)
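

# Minimal client sketch for /v1/voice/file. Illustrative only: the host, port,
# and input file are assumptions, and the third-party `httpx` package is
# required; it is not a dependency of this module.
async def _example_voice_file_client(
    url: str = "http://localhost:8000/v1/voice/file",
) -> None:
    import httpx

    audio = Path("sample.wav").read_bytes()  # hypothetical local recording
    async with httpx.AsyncClient(timeout=60.0) as client:
        response = await client.post(
            url,
            files={"file": ("sample.wav", audio, "audio/wav")},
            data={"return_audio": "false", "llm_provider": "azure_openai"},
        )
    response.raise_for_status()
    body = response.json()
    print(body["transcript"], "->", body["reply_text"])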


@router.websocket("/ws/voice")
async def voice_stream(websocket: WebSocket) -> None:
    """Stream audio over WebSocket, then process on 'stop'."""
    await websocket.accept()
    request_id = str(uuid.uuid4())
    log = get_logger(request_id=request_id, endpoint="/ws/voice")
    buffer = bytearray()
    content_type: str | None = DEFAULT_STREAM_CONTENT_TYPE
    prompt: str | None = None
    return_audio = True
    stt_session = None
    frames_sent: int | None = None
    avg_rms: float | None = None
    llm_provider: str | None = None
    vad_stream: SileroVADStream | None = None
    segment_processing = False
    session_id: str | None = None

    async def _finalize_segment() -> None:
        nonlocal stt_session, segment_processing, vad_stream
        if stt_session is None:
            raise ValidationError(
                code="stt_not_started", message="STT session not started."
            )
        if not buffer:
            return
        segment_processing = True
        try:
            stt_result = await anyio.to_thread.run_sync(stt_session.finish)
        except SpeechError as exc:
            if exc.code in {"stt_empty", "stt_no_match"}:
                # Streaming recognition found nothing; retry once with a batch
                # transcription of the buffered audio before giving up.
                try:
                    stt_result = await anyio.to_thread.run_sync(
                        SpeechToTextService().transcribe,
                        bytes(buffer),
                        None,
                        content_type,
                    )
                except SpeechError as exc_fallback:
                    if exc_fallback.code in {"stt_empty", "stt_no_match"}:
                        await websocket.send_json(
                            {
                                "event": "result",
                                "transcript": "",
                                "reply_text": NO_MATCH_REPLY,
                                "audio_format": "wav",
                                "reply_audio_base64": None,
                                "timings_ms": {
                                    "stt": 0,
                                    "llm": 0,
                                    "tts": 0,
                                    "total": 0,
                                },
                            }
                        )
                        buffer.clear()
                        stt_session = SpeechToTextService().start_streaming(
                            end_silence_ms=1400, initial_silence_ms=5000
                        )
                        vad_stream = SileroVADStream()
                        # Re-arm VAD-driven finalization for the next segment.
                        segment_processing = False
                        return
                    raise
            else:
                raise
        await websocket.send_json(
            {"event": "transcript", "transcript": stt_result.transcript}
        )
        if llm_provider == "local_agent":
            # The local agent orchestrates its own pipeline; generation is
            # delegated to Azure OpenAI.
            agent = AgentPipeline()
            result = await agent.run_with_transcript(
                transcript=stt_result.transcript,
                language=stt_result.language,
                prompt=prompt,
                return_audio=return_audio,
                llm_provider="azure_openai",
                session_id=session_id,
            )
        else:
            pipeline = VoicePipeline()
            result = await pipeline.run(
                audio_bytes=bytes(buffer),
                filename=None,
                content_type=content_type,
                prompt=prompt,
                return_audio=return_audio,
                transcript_override=stt_result.transcript,
                language_override=stt_result.language,
                llm_provider=llm_provider,
            )
        response_body = {
            "event": "result",
            "transcript": result.transcript,
            "reply_text": result.reply_text,
            "audio_format": "wav",
            "reply_audio_base64": None,
            "timings_ms": result.timings_ms,
        }
        log.info(
            "voice_stream_complete",
            bytes_received=len(buffer),
            timings_ms=result.timings_ms,
            return_audio=return_audio,
            content_type=content_type,
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        await websocket.send_json(response_body)
        if result.reply_audio and return_audio:
            # Reply audio goes out as a binary frame, not base64 in the JSON.
            await websocket.send_bytes(result.reply_audio)
        buffer.clear()
        stt_session = SpeechToTextService().start_streaming(
            end_silence_ms=1200, initial_silence_ms=5000
        )
        vad_stream = SileroVADStream()
        segment_processing = False
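
    # Receive-loop protocol (as implemented below):
    #   * "start" (JSON): set options, then arm a streaming STT session + VAD.
    #   * binary frames: buffered and fed to STT and VAD; when VAD reports the
    #     end of speech, the current segment is finalized automatically.
    #   * "segment_end" (JSON): client-driven finalization of the buffered
    #     segment; "stop" finalizes and then ends the session.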
    try:
        while True:
            try:
                message = await websocket.receive()
            except RuntimeError:
                log.info("voice_stream_disconnect")
                break
            if "bytes" in message and message["bytes"] is not None:
                chunk = message["bytes"]
                if stt_session is not None:
                    stt_session.write(chunk)
                buffer.extend(chunk)
                if vad_stream is not None and not segment_processing:
                    decision = vad_stream.update(chunk)
                    if decision.speech_ended:
                        await _finalize_segment()
                if len(buffer) > MAX_FILE_SIZE_BYTES:
                    raise ValidationError(
                        code="file_too_large", message="Stream exceeds 15MB limit."
                    )
                continue
            if "text" in message and message["text"] is not None:
                try:
                    payload = json.loads(message["text"])
                except json.JSONDecodeError as exc:
                    raise ValidationError(
                        code="invalid_message", message="Invalid JSON message."
                    ) from exc
                event = str(payload.get("event", "")).lower()
                if event == "start":
                    content_type = payload.get("content_type") or content_type
                    prompt = payload.get("prompt")
                    return_audio = payload.get("return_audio", True)
                    llm_provider = payload.get("llm_provider", llm_provider)
                    session_id = payload.get("session_id", session_id) or request_id
                    if llm_provider and llm_provider not in {
                        "foundry_agent",
                        "azure_openai",
                        "local_agent",
                    }:
                        raise ValidationError(
                            code="llm_provider",
                            message=(
                                "LLM provider must be 'foundry_agent', "
                                "'azure_openai', or 'local_agent'."
                            ),
                        )
                    stt_session = SpeechToTextService().start_streaming(
                        end_silence_ms=1200, initial_silence_ms=5000
                    )
                    vad_stream = SileroVADStream()
                    continue
                if event == "stop":
                    if not buffer:
                        raise ValidationError(
                            code="file_empty", message="Audio stream is empty."
                        )
                    if stt_session is None:
                        raise ValidationError(
                            code="stt_not_started",
                            message="STT session not started.",
                        )
                    prompt = payload.get("prompt", prompt)
                    return_audio = payload.get("return_audio", return_audio)
                    llm_provider = payload.get("llm_provider", llm_provider)
                    session_id = payload.get("session_id", session_id) or request_id
                    frames_sent = payload.get("frames_sent", frames_sent)
                    avg_rms = payload.get("avg_rms", avg_rms)
                    if llm_provider and llm_provider not in {
                        "foundry_agent",
                        "azure_openai",
                        "local_agent",
                    }:
                        raise ValidationError(
                            code="llm_provider",
                            message=(
                                "LLM provider must be 'foundry_agent', "
                                "'azure_openai', or 'local_agent'."
                            ),
                        )
                    await _finalize_segment()
                    break
                if event == "segment_end":
                    if not buffer:
                        continue
                    prompt = payload.get("prompt", prompt)
                    return_audio = payload.get("return_audio", return_audio)
                    llm_provider = payload.get("llm_provider", llm_provider)
                    session_id = payload.get("session_id", session_id) or request_id
                    frames_sent = payload.get("frames_sent", frames_sent)
                    avg_rms = payload.get("avg_rms", avg_rms)
                    if llm_provider and llm_provider not in {
                        "foundry_agent",
                        "azure_openai",
                        "local_agent",
                    }:
                        raise ValidationError(
                            code="llm_provider",
                            message=(
                                "LLM provider must be 'foundry_agent', "
                                "'azure_openai', or 'local_agent'."
                            ),
                        )
                    if vad_stream is not None and not vad_stream.has_speech():
                        await websocket.send_json(
                            {
                                "event": "result",
                                "transcript": "",
                                "reply_text": NO_MATCH_REPLY,
                                "audio_format": "wav",
                                "reply_audio_base64": None,
                                "timings_ms": {
                                    "stt": 0,
                                    "llm": 0,
                                    "tts": 0,
                                    "total": 0,
                                },
                            }
                        )
                        buffer.clear()
                        vad_stream.reset()
                        continue
                    await _finalize_segment()
                    continue
            raise ValidationError(
                code="invalid_event",
                message="Event must be 'start', 'stop', or 'segment_end'.",
            )
    except WebSocketDisconnect:
        log.info("voice_stream_disconnect")
    except ValidationError as exc:
        log.warning(
            "voice_stream_error",
            code=exc.code,
            message=exc.message,
            bytes_received=len(buffer),
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        if websocket.application_state == WebSocketState.CONNECTED:
            await websocket.send_json(
                {"event": "error", "error": {"code": exc.code, "message": exc.message}}
            )
            await websocket.close()
    except LLMError as exc:
        log.warning(
            "voice_stream_error",
            code=exc.code,
            message=exc.message,
            details=exc.details,
            bytes_received=len(buffer),
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        if websocket.application_state == WebSocketState.CONNECTED:
            await websocket.send_json(
                {
                    "event": "error",
                    "error": {"code": exc.code, "message": exc.message},
                }
            )
            if exc.code != "llm_guardrail":
                # Guardrail refusals keep the connection open; any other LLM
                # failure ends the session.
                await websocket.close()
        return
    except SpeechError as exc:
        log.warning(
            "voice_stream_error",
            code=exc.code,
            message=exc.message,
            details=exc.details,
            bytes_received=len(buffer),
            frames_sent=frames_sent,
            avg_rms=avg_rms,
        )
        if websocket.application_state == WebSocketState.CONNECTED:
            await websocket.send_json(
                {
                    "event": "error",
                    "error": {
                        "code": exc.code,
                        "message": exc.message,
                        "details": exc.details,
                    },
                }
            )
            await websocket.close()
    except Exception as exc:  # pragma: no cover - safety net
        log.error("voice_stream_unhandled", error=repr(exc))
        if websocket.application_state == WebSocketState.CONNECTED:
            await websocket.send_json(
                {
                    "event": "error",
                    "error": {"code": "internal_error", "message": "Server error."},
                }
            )
            await websocket.close()
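

# Minimal client sketch for the /ws/voice protocol above. Illustrative only:
# the host, port, chunk size, and input file are assumptions, and the
# third-party `websockets` package is required; it is not a dependency of
# this module.
async def _example_voice_stream_client(
    url: str = "ws://localhost:8000/ws/voice",
) -> None:
    import websockets

    audio = Path("sample.ogg").read_bytes()  # hypothetical OGG recording
    async with websockets.connect(url) as ws:
        await ws.send(
            json.dumps(
                {"event": "start", "content_type": "audio/ogg", "return_audio": False}
            )
        )
        for offset in range(0, len(audio), 32_768):
            # Binary frames feed the server-side buffer, STT session, and VAD.
            await ws.send(audio[offset : offset + 32_768])
        await ws.send(json.dumps({"event": "stop"}))
        while True:
            message = await ws.recv()
            if isinstance(message, bytes):
                continue  # reply audio (sent only when return_audio is true)
            payload = json.loads(message)
            print(payload)  # "transcript" event, then "result" or "error"
            if payload.get("event") in {"result", "error"}:
                break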