Spaces:
Running
Running
"""
WebSocket endpoint. The client sends JSON messages with type "video", "audio",
"face_metrics", or "ping". The server runs the realtime orchestrator at the
cycle cadence and emits one realtime payload per cycle.
"""
from __future__ import annotations

import asyncio
import json
import time
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState

from app.pipeline.realtime import RealtimeOrchestrator, decode_jpeg_b64, decode_pcm_b64
from app.utils.logging import get_logger
# Module-level logger from the project's logging helper (presumably a
# stdlib-style logger — confirm in app.utils.logging).
_log = get_logger(__name__)
# FastAPI router for this module's route(s); it is mounted on the app outside
# this file (not visible here).
router = APIRouter()
class SessionConnection:
    """One websocket plus its orchestrator. Frames buffer in, payloads stream out.

    Two coroutines share this object:

    * ``receive_loop`` reads client JSON messages, feeds the orchestrator
      (latest video frame, audio samples, face metrics) and answers pings.
    * ``cycle_loop`` repeatedly runs ``orchestrator.cycle`` on the latest
      frame, records the payload in ``store`` and sends it to the client.

    Whichever loop exits first sets ``closed`` so the other one stops at its
    next loop check.
    """

    # Face-metric fields accepted from the client, paired with the value used
    # when the client omits that field from a "face_metrics" message.
    _FACE_METRIC_DEFAULTS = (
        ("head_yaw", 0.0),
        ("head_pitch", 0.0),
        ("head_roll", 0.0),
        ("eye_openness", 1.0),
        ("smile", 0.0),
        ("looking_at_camera", 1.0),
    )

    def __init__(self, ws: WebSocket, session_id: str, orchestrator: RealtimeOrchestrator, store):
        """
        Args:
            ws: Websocket both loops read from and write to.
            session_id: Identifier stamped onto every payload and passed to the store.
            orchestrator: Realtime pipeline fed by ``receive_loop`` and run by ``cycle_loop``.
            store: Session store; must provide ``async record_frame(session_id, payload)``.
        """
        self.ws = ws
        self.session_id = session_id
        self.orchestrator = orchestrator
        self.store = store
        # Most recent successfully decoded video frame; None until one arrives.
        # cycle_loop reuses it on every cycle until a newer frame replaces it.
        self.latest_frame_bgr = None
        # Shared stop flag for both loops.
        self.closed = False

    @classmethod
    def _parse_face_metrics(cls, data: dict) -> Optional[dict]:
        """Return the face-metric dict for ``data``, or None if a value is not numeric.

        Missing fields fall back to ``_FACE_METRIC_DEFAULTS``. A present but
        unusable value (``null``, a non-numeric string, a list...) rejects the
        whole message instead of raising into the receive loop.
        """
        try:
            return {name: float(data.get(name, default)) for name, default in cls._FACE_METRIC_DEFAULTS}
        except (TypeError, ValueError):
            return None

    async def _handle_message(self, data: dict) -> None:
        """Dispatch one decoded client message by its "type"; unknown types are ignored."""
        msg_type = data.get("type")
        if msg_type == "video":
            frame = decode_jpeg_b64(data.get("frame", ""))
            if frame is not None:
                self.latest_frame_bgr = frame
        elif msg_type == "audio":
            pcm = decode_pcm_b64(data.get("samples", ""))
            if pcm is not None and pcm.size > 0:
                self.orchestrator.push_audio(pcm)
        elif msg_type == "face_metrics":
            metrics = self._parse_face_metrics(data)
            if metrics is not None:
                self.orchestrator.update_face_metrics(metrics)
        elif msg_type == "ping":
            await self.ws.send_text(json.dumps({"type": "pong", "ts": time.time()}))

    async def receive_loop(self) -> None:
        """Consume client messages until disconnect, an error, or ``closed`` is set.

        Malformed messages (invalid JSON, JSON that is not an object, or
        face metrics with non-numeric values) are skipped rather than ending
        the session.
        """
        try:
            while not self.closed:
                msg = await self.ws.receive_text()
                try:
                    data = json.loads(msg)
                except json.JSONDecodeError:
                    continue
                # Valid JSON can still be a list/number/string; only objects
                # carry a "type" field.
                if not isinstance(data, dict):
                    continue
                await self._handle_message(data)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            _log.warning(f"ws receive loop error: {e}")
        finally:
            self.closed = True

    async def cycle_loop(self) -> None:
        """Run one orchestrator cycle per tick and stream each payload to the client.

        The blocking ``orchestrator.cycle`` call runs in a worker thread via
        ``asyncio.to_thread``, so the event loop keeps serving
        ``receive_loop`` meanwhile. Each tick sleeps for whatever remains of
        ``orchestrator.actual_cycle_ms`` after the cycle, store write and send.
        """
        try:
            while not self.closed:
                t0 = time.perf_counter()
                payload = await asyncio.to_thread(self.orchestrator.cycle, self.latest_frame_bgr)
                payload["session_id"] = self.session_id
                await self.store.record_frame(self.session_id, payload)
                message = {"type": "realtime", **payload}
                # A "type" key inside the payload must not override the
                # envelope type the client dispatches on.
                message["type"] = "realtime"
                await self.ws.send_text(json.dumps(message))
                elapsed_ms = (time.perf_counter() - t0) * 1000.0
                target_ms = self.orchestrator.actual_cycle_ms
                sleep_ms = max(0.0, target_ms - elapsed_ms)
                await asyncio.sleep(sleep_ms / 1000.0)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            _log.warning(f"ws cycle loop error: {e}")
        finally:
            self.closed = True
async def _reject_websocket(websocket: WebSocket, code: int, reason: str) -> None:
    """Refuse a connection with a close code and reason the client can actually see.

    Per the ASGI spec, a close sent before ``accept`` makes the server fail the
    handshake with HTTP 403, which discards any custom code or reason.
    Accepting first and then closing delivers them to the client.
    """
    await websocket.accept()
    await websocket.close(code=code, reason=reason)


async def _close_if_open(websocket: WebSocket) -> None:
    """Best-effort normal close if neither side has closed the socket yet."""
    if (
        websocket.client_state != WebSocketState.CONNECTED
        or websocket.application_state != WebSocketState.CONNECTED
    ):
        return
    try:
        await websocket.close()
    except Exception:
        # The peer can vanish between the state check and the send; there is
        # nothing left to do for this connection either way.
        pass


async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Serve the realtime stream for ``session_id`` over ``websocket``.

    Closes with code 4404 if the session store does not report the session as
    live, and with 4503 if any required model on ``app.state`` is missing.
    Otherwise it builds a ``RealtimeOrchestrator`` from ``app.state.settings``,
    runs the receive and cycle loops concurrently until either one finishes,
    then stops both and closes the socket if it is still open.
    """
    # NOTE(review): no ``@router.websocket(...)`` registration is visible for
    # this function; confirm it is attached to ``router`` elsewhere (or that a
    # decorator line was lost).
    state = websocket.app.state
    if not state.session_store.is_live(session_id):
        await _reject_websocket(websocket, 4404, "Session not active")
        return
    required_models = ("face_predictor", "voice_predictor", "fusion_predictor", "transcriber")
    if not all(getattr(state, name, None) for name in required_models):
        await _reject_websocket(websocket, 4503, "Models not loaded")
        return
    settings = state.settings
    orchestrator = RealtimeOrchestrator(
        face_predictor=state.face_predictor,
        voice_predictor=state.voice_predictor,
        fusion_predictor=state.fusion_predictor,
        transcriber=state.transcriber,
        target_cycle_ms=settings.target_cycle_ms,
        max_cycle_ms=settings.adaptive_cycle_max_ms,
        transcription_interval_cycles=settings.transcription_interval_cycles,
    )
    await websocket.accept()
    conn = SessionConnection(websocket, session_id, orchestrator, state.session_store)
    tasks = (
        asyncio.create_task(conn.receive_loop()),
        asyncio.create_task(conn.cycle_loop()),
    )
    try:
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        conn.closed = True
        for task in tasks:
            task.cancel()  # no-op for the task that already finished
        # Both loops log their own errors; return_exceptions collects the
        # resulting CancelledError (or any stray error) instead of raising it.
        await asyncio.gather(*tasks, return_exceptions=True)
        await _close_if_open(websocket)