interview-mirror / app /api /websocket.py
AliAbouelazm's picture
harsher scoring + interactive controls
65f5325
"""
WebSocket endpoint. The client sends JSON messages with type "video", "audio",
"face_metrics", or "ping". The server runs the realtime orchestrator at the
cycle cadence and emits one realtime payload per cycle.
"""
from __future__ import annotations
import asyncio
import json
import time
from typing import Optional
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from app.pipeline.realtime import RealtimeOrchestrator, decode_jpeg_b64, decode_pcm_b64
from app.utils.logging import get_logger
_log = get_logger(__name__)
router = APIRouter()
class SessionConnection:
    """One websocket plus its orchestrator. Frames buffer in, payloads stream out.

    Two coroutines share an instance: ``receive_loop`` consumes client
    messages and ``cycle_loop`` runs the orchestrator and streams its payloads
    back. Either loop sets ``closed`` when it ends, which stops the other one
    at its next ``while`` check.
    """

    # (field, default) pairs read from a "face_metrics" message. The default
    # is used only when the client omits the field entirely.
    _FACE_METRIC_DEFAULTS = (
        ("head_yaw", 0.0),
        ("head_pitch", 0.0),
        ("head_roll", 0.0),
        ("eye_openness", 1.0),
        ("smile", 0.0),
        ("looking_at_camera", 1.0),
    )

    def __init__(self, ws: WebSocket, session_id: str, orchestrator: RealtimeOrchestrator, store):
        """Bind the socket, session id, orchestrator and session store together."""
        self.ws = ws
        self.session_id = session_id
        self.orchestrator = orchestrator
        self.store = store
        # Most recent successfully decoded video frame; None until one arrives.
        self.latest_frame_bgr = None
        # Shared stop flag for both loops.
        self.closed = False

    @classmethod
    def _parse_face_metrics(cls, data: dict):
        """Return the face-metric fields of *data* as floats, or None if any is invalid.

        A field that is present but not convertible (JSON ``null``, a
        non-numeric string, a nested object, an integer too large for a float)
        invalidates the whole message rather than being replaced by a default.
        """
        try:
            return {
                field: float(data.get(field, default))
                for field, default in cls._FACE_METRIC_DEFAULTS
            }
        except (TypeError, ValueError, OverflowError):
            return None

    async def _dispatch(self, data: dict) -> None:
        """Route one decoded client message by its ``type`` field.

        Unknown types are ignored.
        """
        msg_type = data.get("type")
        if msg_type == "video":
            frame = decode_jpeg_b64(data.get("frame", ""))
            if frame is not None:
                self.latest_frame_bgr = frame
        elif msg_type == "audio":
            pcm = decode_pcm_b64(data.get("samples", ""))
            if pcm is not None and pcm.size > 0:
                self.orchestrator.push_audio(pcm)
        elif msg_type == "face_metrics":
            metrics = self._parse_face_metrics(data)
            # A malformed metrics message is dropped so that a single bad
            # value cannot tear down the whole session.
            if metrics is not None:
                self.orchestrator.update_face_metrics(metrics)
        elif msg_type == "ping":
            await self.ws.send_text(json.dumps({"type": "pong", "ts": time.time()}))

    async def receive_loop(self) -> None:
        """Read client messages until the socket closes or ``closed`` is set.

        Messages that are not valid JSON, or whose top-level JSON value is not
        an object, are skipped. A client disconnect ends the loop quietly; any
        other error is logged as a warning. ``closed`` is always set on exit.
        """
        try:
            while not self.closed:
                msg = await self.ws.receive_text()
                try:
                    data = json.loads(msg)
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object (list, string, number, ...)
                # has no "type" field to dispatch on.
                if not isinstance(data, dict):
                    continue
                await self._dispatch(data)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            _log.warning(f"ws receive loop error: {e}")
        finally:
            self.closed = True

    async def cycle_loop(self) -> None:
        """Run orchestrator cycles and stream one realtime payload per cycle.

        Each iteration runs ``orchestrator.cycle`` in a worker thread with the
        latest frame (``None`` if no frame has arrived yet), tags the payload
        with the session id, records it in the store, sends it to the client,
        then sleeps for whatever remains of the orchestrator's current cycle
        budget. ``closed`` is always set on exit.
        """
        try:
            while not self.closed:
                t0 = time.perf_counter()
                payload = await asyncio.to_thread(self.orchestrator.cycle, self.latest_frame_bgr)
                payload["session_id"] = self.session_id
                await self.store.record_frame(self.session_id, payload)
                await self.ws.send_text(json.dumps({"type": "realtime", **payload}))
                elapsed_ms = (time.perf_counter() - t0) * 1000.0
                # Re-read every cycle; presumably the orchestrator adapts this
                # value at runtime -- confirm against RealtimeOrchestrator.
                target_ms = self.orchestrator.actual_cycle_ms
                sleep_ms = max(0.0, target_ms - elapsed_ms)
                await asyncio.sleep(sleep_ms / 1000.0)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            _log.warning(f"ws cycle loop error: {e}")
        finally:
            self.closed = True
@router.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Serve one realtime session over a websocket.

    The connection is refused when the session is not live in the session
    store or when any of the required predictors/transcriber is missing from
    application state. Otherwise a per-connection orchestrator is built, the
    socket is accepted, and the receive and cycle loops run concurrently until
    either one finishes; the other is then cancelled and awaited.
    """
    app_state = websocket.app.state

    # NOTE(review): both refusals below call close() before accept(). Under
    # ASGI the server is expected to answer the handshake with HTTP 403 in
    # that case, so the custom code/reason may never reach the client --
    # confirm this is the intended behaviour.
    if not app_state.session_store.is_live(session_id):
        await websocket.close(code=4404, reason="Session not active")
        return

    required_models = ("face_predictor", "voice_predictor", "fusion_predictor", "transcriber")
    if any(not getattr(app_state, attr, None) for attr in required_models):
        await websocket.close(code=4503, reason="Models not loaded")
        return

    cfg = app_state.settings
    orchestrator = RealtimeOrchestrator(
        face_predictor=app_state.face_predictor,
        voice_predictor=app_state.voice_predictor,
        fusion_predictor=app_state.fusion_predictor,
        transcriber=app_state.transcriber,
        target_cycle_ms=cfg.target_cycle_ms,
        max_cycle_ms=cfg.adaptive_cycle_max_ms,
        transcription_interval_cycles=cfg.transcription_interval_cycles,
    )

    await websocket.accept()
    conn = SessionConnection(websocket, session_id, orchestrator, app_state.session_store)

    # Receive first, cycle second: same creation order as the loops' roles.
    loops = (
        asyncio.create_task(conn.receive_loop()),
        asyncio.create_task(conn.cycle_loop()),
    )
    try:
        # Whichever loop ends first (disconnect or error) ends the session.
        await asyncio.wait(loops, return_when=asyncio.FIRST_COMPLETED)
    finally:
        conn.closed = True
        for pending in loops:
            if not pending.done():
                pending.cancel()
            # Drain the task so cancellation completes and no exception is
            # left unretrieved; failures here are deliberately ignored.
            try:
                await pending
            except (asyncio.CancelledError, Exception):
                pass