File size: 5,022 Bytes
8ca89f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65f5325
 
 
 
 
 
 
 
 
8ca89f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
WebSocket endpoint. The client sends JSON messages with type "video", "audio",
"face_metrics", or "ping". The server runs the realtime orchestrator at the cycle
cadence and emits one realtime payload per cycle.
"""
from __future__ import annotations

import asyncio
import json
import time
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from app.pipeline.realtime import RealtimeOrchestrator, decode_jpeg_b64, decode_pcm_b64
from app.utils.logging import get_logger


# Module logger; used by the connection loops below to report unexpected errors.
_log = get_logger(__name__)

# Router holding this module's websocket route; presumably included into the
# FastAPI app elsewhere — confirm at the app factory.
router = APIRouter()


class SessionConnection:
    """One websocket plus its orchestrator. Frames buffer in, payloads stream out.

    ``receive_loop`` ingests client messages (video frames, audio samples, face
    metrics, pings) and ``cycle_loop`` periodically runs the orchestrator on the
    most recent frame, streaming each payload back to the client. Both loops
    check ``closed`` at the top of every iteration and set it when they exit.
    """

    # (key, default) pairs read from a "face_metrics" message, in the order they
    # are forwarded to the orchestrator.
    _FACE_METRIC_FIELDS = (
        ("head_yaw", 0.0),
        ("head_pitch", 0.0),
        ("head_roll", 0.0),
        ("eye_openness", 1.0),
        ("smile", 0.0),
        ("looking_at_camera", 1.0),
    )

    def __init__(self, ws: WebSocket, session_id: str, orchestrator: RealtimeOrchestrator, store):
        self.ws = ws
        self.session_id = session_id
        self.orchestrator = orchestrator
        self.store = store
        # Most recent successfully decoded video frame; None until one arrives.
        self.latest_frame_bgr = None
        self.closed = False

    @classmethod
    def _parse_face_metrics(cls, data: dict) -> Optional[dict]:
        """Return the face metrics in ``data`` coerced to floats, or None if any value is not numeric.

        Missing keys fall back to the defaults in ``_FACE_METRIC_FIELDS``.
        """
        try:
            return {key: float(data.get(key, default)) for key, default in cls._FACE_METRIC_FIELDS}
        except (TypeError, ValueError):
            # e.g. null, a string that is not a number, or a nested list/object.
            return None

    async def receive_loop(self) -> None:
        """Consume client messages until the socket disconnects or ``closed`` is set.

        Malformed messages (invalid JSON, JSON that is not an object, face metrics
        with non-numeric values) are dropped one at a time rather than ending the
        session. Unknown message types are ignored. Any other error is logged and
        ends the loop; ``closed`` is always set on exit.
        """
        try:
            while not self.closed:
                msg = await self.ws.receive_text()
                try:
                    data = json.loads(msg)
                except json.JSONDecodeError:
                    continue
                # Valid JSON arrays/scalars have no .get(); skip them instead of
                # letting an AttributeError reach the catch-all and close the socket.
                if not isinstance(data, dict):
                    continue

                msg_type = data.get("type")
                if msg_type == "video":
                    frame = decode_jpeg_b64(data.get("frame", ""))
                    if frame is not None:
                        self.latest_frame_bgr = frame
                elif msg_type == "audio":
                    pcm = decode_pcm_b64(data.get("samples", ""))
                    if pcm is not None and pcm.size > 0:
                        self.orchestrator.push_audio(pcm)
                elif msg_type == "face_metrics":
                    metrics = self._parse_face_metrics(data)
                    if metrics is not None:
                        self.orchestrator.update_face_metrics(metrics)
                elif msg_type == "ping":
                    await self.ws.send_text(json.dumps({"type": "pong", "ts": time.time()}))
        except WebSocketDisconnect:
            pass
        except Exception as e:
            _log.warning(f"ws receive loop error: {e}")
        finally:
            self.closed = True

    async def cycle_loop(self) -> None:
        """Run one orchestrator cycle per tick and stream its payload to the client.

        Each cycle runs ``orchestrator.cycle`` in a worker thread on the latest
        decoded frame (None until one has arrived). It tags the payload with the
        session id, records it through ``store.record_frame`` and sends it as a
        ``"realtime"`` message. The loop then sleeps for whatever remains of
        ``orchestrator.actual_cycle_ms``. Errors are logged and end the loop;
        ``closed`` is always set on exit.
        """
        try:
            while not self.closed:
                t0 = time.perf_counter()
                # The orchestrator call is blocking, so keep it off the event loop.
                payload = await asyncio.to_thread(self.orchestrator.cycle, self.latest_frame_bgr)
                payload["session_id"] = self.session_id
                await self.store.record_frame(self.session_id, payload)
                await self.ws.send_text(json.dumps({"type": "realtime", **payload}))
                elapsed_ms = (time.perf_counter() - t0) * 1000.0
                # Target is re-read every cycle because the orchestrator may adapt it.
                target_ms = self.orchestrator.actual_cycle_ms
                sleep_ms = max(0.0, target_ms - elapsed_ms)
                await asyncio.sleep(sleep_ms / 1000.0)
        except WebSocketDisconnect:
            pass
        except Exception as e:
            _log.warning(f"ws cycle loop error: {e}")
        finally:
            self.closed = True


@router.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Serve one realtime session over a websocket.

    The connection is closed with code 4404 if ``session_id`` is not live, or
    4503 if any required model is missing (or falsy) on ``app.state``. Otherwise
    a dedicated ``RealtimeOrchestrator`` is built for this connection. The
    receive and cycle loops then run concurrently until either one finishes,
    after which the other is cancelled and awaited.
    """
    state = websocket.app.state

    # Per the ASGI spec, a close sent before accept() is turned into an HTTP 403
    # handshake rejection and the application close code/reason are discarded.
    # Accept first so the client actually receives 4404/4503.
    if not state.session_store.is_live(session_id):
        await websocket.accept()
        await websocket.close(code=4404, reason="Session not active")
        return

    required_models = ("face_predictor", "voice_predictor", "fusion_predictor", "transcriber")
    if not all(getattr(state, name, None) for name in required_models):
        await websocket.accept()
        await websocket.close(code=4503, reason="Models not loaded")
        return

    settings = state.settings
    orchestrator = RealtimeOrchestrator(
        face_predictor=state.face_predictor,
        voice_predictor=state.voice_predictor,
        fusion_predictor=state.fusion_predictor,
        transcriber=state.transcriber,
        target_cycle_ms=settings.target_cycle_ms,
        max_cycle_ms=settings.adaptive_cycle_max_ms,
        transcription_interval_cycles=settings.transcription_interval_cycles,
    )

    await websocket.accept()
    conn = SessionConnection(websocket, session_id, orchestrator, state.session_store)

    receive_task = asyncio.create_task(conn.receive_loop())
    cycle_task = asyncio.create_task(conn.cycle_loop())
    try:
        await asyncio.wait({receive_task, cycle_task}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        conn.closed = True
        pending = [task for task in (receive_task, cycle_task) if not task.done()]
        for task in pending:
            task.cancel()
        # Let the cancelled loops unwind. return_exceptions=True absorbs their
        # CancelledError/errors. A cancellation of this handler itself still
        # propagates; awaiting each task under `except CancelledError` would
        # swallow it.
        await asyncio.gather(*pending, return_exceptions=True)