File size: 4,352 Bytes
cf93910
c476eae
cf93910
c476eae
 
537f246
 
 
c476eae
 
 
537f246
 
 
cf93910
 
 
 
 
 
 
 
 
 
 
c476eae
 
 
 
 
 
 
 
cf93910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c476eae
 
 
 
 
 
 
 
cf93910
 
 
 
 
 
 
 
 
 
c476eae
 
 
 
 
 
cf93910
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import { useEffect, useRef, useState, useCallback } from 'react';
import type { ModelMode, PredictionResponse } from '../types';

// Derive WebSocket base URL.
// Priority: VITE_WS_URL env var → dev fallback (localhost:8000) → same host (production).
/**
 * Resolves the WebSocket base URL (scheme + host, no path).
 *
 * An explicit `VITE_WS_URL` always wins. Otherwise the scheme mirrors the
 * page (`wss` on HTTPS, `ws` elsewhere) and the host is either the FastAPI
 * dev server on localhost:8000 (Vite dev mode, frontend on 5173) or the
 * page's own host (co-located backend in production, e.g. HF Spaces Docker).
 */
function _defaultWsUrl(): string {
  if (import.meta.env.VITE_WS_URL) {
    return import.meta.env.VITE_WS_URL as string;
  }
  const scheme = window.location.protocol === 'https:' ? 'wss' : 'ws';
  const host = import.meta.env.DEV ? 'localhost:8000' : window.location.host;
  return `${scheme}://${host}`;
}
/** WebSocket base URL, resolved once at module load. */
const WS_URL = _defaultWsUrl();
/** First reconnect delay; doubled on each reconnect attempt, reset when a connection opens. */
const RECONNECT_BASE_MS    = 1000;
/** Upper bound for the exponential reconnect back-off. */
const MAX_RECONNECT_MS     = 30_000;
const MAX_SEND_RATE        = 15; // frames/sec — normal
const LOW_BW_SEND_RATE     = 5;  // frames/sec — high-latency fallback
const LOW_BW_LATENCY_MS    = 500; // round-trip ms above which low-bandwidth mode activates

/** State and actions exposed by the landmark WebSocket hook. */
export interface WebSocketState {
  /** Latest server message that contained a `sign` field; `null` until one arrives. */
  lastPrediction: PredictionResponse | null;
  /** Whether the socket is currently open. */
  isConnected: boolean;
  /** Last measured send-to-reply round trip, in milliseconds. */
  latency: number;
  /** Whether the last measured round trip exceeded the low-bandwidth threshold. */
  lowBandwidth: boolean;
  /** Sends one landmark frame; dropped when disconnected or throttled. */
  sendLandmarks: (
    landmarks: number[],
    options?: { sessionId?: string; modelMode?: ModelMode; imageB64?: string },
  ) => void;
}

/**
 * WebSocket hook for sending landmark vectors and receiving predictions.
 *
 * - Opens `${WS_URL}/ws/landmarks` on mount and closes it on unmount.
 * - After an unexpected close, reconnects with exponential back-off
 *   (RECONNECT_BASE_MS doubling up to MAX_RECONNECT_MS); the delay resets
 *   once a connection opens.
 * - Throttles outgoing frames to MAX_SEND_RATE fps, dropping to
 *   LOW_BW_SEND_RATE fps while the measured round trip exceeds
 *   LOW_BW_LATENCY_MS. Frames over the limit are dropped, not queued.
 *
 * Only the socket currently stored in `wsRef` may touch hook state. Late
 * events from a socket that was replaced or torn down (e.g. React StrictMode's
 * mount → unmount → mount in development) therefore cannot clobber the live
 * connection or schedule reconnects after unmount.
 */
export function useWebSocket(): WebSocketState {
  const wsRef          = useRef<WebSocket | null>(null);
  const reconnectDelay = useRef(RECONNECT_BASE_MS);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const lastSendTime   = useRef(0);
  // True between effect setup and cleanup; reconnects are only allowed while set.
  const activeRef      = useRef(false);
  // Send time of the frame awaiting a reply. Only the most recent send is tracked.
  const inflightTs     = useRef<number | null>(null);

  const [lastPrediction, setLastPrediction] = useState<PredictionResponse | null>(null);
  const [isConnected,    setIsConnected]    = useState(false);
  const [latency,        setLatency]        = useState(0);
  const [lowBandwidth,   setLowBandwidth]   = useState(false);

  const connect = useCallback(() => {
    if (!activeRef.current) return;

    // Never open a second socket while one is open or still handshaking.
    const existing = wsRef.current;
    if (
      existing &&
      (existing.readyState === WebSocket.OPEN || existing.readyState === WebSocket.CONNECTING)
    ) {
      return;
    }

    const ws = new WebSocket(`${WS_URL}/ws/landmarks`);
    wsRef.current = ws;
    // Events from a socket that is no longer current are ignored.
    const isCurrent = () => wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectDelay.current = RECONNECT_BASE_MS;
      // Drop any timestamp left from a previous connection so the first reply
      // here is not measured against a send that never got answered.
      inflightTs.current = null;
    };

    ws.onmessage = (evt) => {
      if (!isCurrent()) return;
      if (inflightTs.current !== null) {
        const rtt = Date.now() - inflightTs.current;
        setLatency(rtt);
        setLowBandwidth(rtt > LOW_BW_LATENCY_MS);
        inflightTs.current = null;
      }
      let data: unknown;
      try {
        data = JSON.parse(evt.data);
      } catch {
        return; // ignore non-JSON messages
      }
      // Only objects carrying a `sign` field count as predictions; other
      // payloads (e.g. error frames) are ignored. The shape is otherwise trusted.
      if (typeof data === 'object' && data !== null && 'sign' in data) {
        setLastPrediction(data as PredictionResponse);
      }
    };

    ws.onclose = () => {
      // A stale socket closing must not null out or reconnect over its successor.
      if (!isCurrent()) return;
      wsRef.current = null;
      inflightTs.current = null;
      setIsConnected(false);
      if (!activeRef.current) return;
      // Exponential back-off reconnect
      if (reconnectTimer.current) clearTimeout(reconnectTimer.current);
      reconnectTimer.current = setTimeout(() => {
        reconnectTimer.current = null;
        reconnectDelay.current = Math.min(reconnectDelay.current * 2, MAX_RECONNECT_MS);
        connect();
      }, reconnectDelay.current);
    };

    ws.onerror = (e) => {
      if (isCurrent()) console.warn('WebSocket error', e);
      // onclose follows and decides whether to reconnect.
      ws.close();
    };
  }, []);

  useEffect(() => {
    activeRef.current = true;
    connect();
    return () => {
      activeRef.current = false;
      if (reconnectTimer.current) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      // Detach before closing so this socket's onclose sees itself as stale
      // and does not schedule a reconnect after unmount.
      const ws = wsRef.current;
      wsRef.current = null;
      ws?.close();
      setIsConnected(false);
    };
  }, [connect]);

  /** Throttled send — adapts to 5fps in low-bandwidth mode (latency > 500ms); excess frames are dropped. */
  const sendLandmarks = useCallback((
    landmarks: number[],
    options?: {
      sessionId?: string;
      modelMode?: ModelMode;
      imageB64?: string;
    },
  ) => {
    const ws = wsRef.current;
    if (!ws || ws.readyState !== WebSocket.OPEN) return;

    const now = Date.now();
    const effectiveRate = lowBandwidth ? LOW_BW_SEND_RATE : MAX_SEND_RATE;
    const minInterval = 1000 / effectiveRate;
    if (now - lastSendTime.current < minInterval) return;
    lastSendTime.current = now;

    inflightTs.current = now;
    ws.send(JSON.stringify({
      landmarks,
      session_id: options?.sessionId ?? 'browser',
      model_mode: options?.modelMode,
      image_b64: options?.imageB64,
    }));
  }, [lowBandwidth]);

  return { lastPrediction, isConnected, latency, lowBandwidth, sendLandmarks };
}