Spaces:
Sleeping
Sleeping
File size: 4,352 Bytes
cf93910 c476eae cf93910 c476eae 537f246 c476eae 537f246 cf93910 c476eae cf93910 c476eae cf93910 c476eae cf93910 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | import { useEffect, useRef, useState, useCallback } from 'react';
import type { ModelMode, PredictionResponse } from '../types';
/**
 * Derive the WebSocket base URL (scheme + host, without a trailing slash).
 *
 * Priority:
 *   1. the `VITE_WS_URL` env var, when set to a non-empty string;
 *   2. in Vite dev mode, `localhost:8000` — the frontend is served on 5173
 *      but FastAPI runs on 8000;
 *   3. otherwise the page's own host (backend co-located, e.g. HF Spaces Docker).
 *
 * For cases 2 and 3 the scheme is `wss` when the page was loaded over HTTPS,
 * and `ws` otherwise.
 */
function _defaultWsUrl(): string {
  const fromEnv: unknown = import.meta.env.VITE_WS_URL;
  // Trailing slashes are stripped so `${WS_URL}/ws/...` never yields a `//` path.
  if (typeof fromEnv === 'string' && fromEnv !== '') return fromEnv.replace(/\/+$/, '');
  const proto = window.location.protocol === 'https:' ? 'wss' : 'ws';
  if (import.meta.env.DEV) return `${proto}://localhost:8000`;
  return `${proto}://${window.location.host}`;
}
const WS_URL = _defaultWsUrl();
/** First reconnect back-off delay in ms; grows ×2 per retry, reset once a socket opens. */
const RECONNECT_BASE_MS = 1_000;
/** Ceiling for the reconnect back-off delay, in ms. */
const MAX_RECONNECT_MS = 30 * 1_000;
/** Outgoing frame cap (frames per second) under normal conditions. */
const MAX_SEND_RATE = 15;
/** Outgoing frame cap (frames per second) while in low-bandwidth mode. */
const LOW_BW_SEND_RATE = 5;
/** Round-trip time in ms; a measured RTT strictly above this enables low-bandwidth mode. */
const LOW_BW_LATENCY_MS = 500;
/** State and actions exposed by the landmark WebSocket hook. */
export interface WebSocketState {
  /** `true` while the underlying socket is open. */
  isConnected: boolean;
  /** Last measured round-trip time in milliseconds (0 until the first measurement). */
  latency: number;
  /** `true` when the last measured round-trip time exceeded the low-bandwidth threshold. */
  lowBandwidth: boolean;
  /** Latest server message that carried a `sign` field, or `null` before one arrives. */
  lastPrediction: PredictionResponse | null;
  /**
   * Send one frame of landmark values to the server.
   * Frames are dropped (no error) when the socket is not open or when the
   * call arrives faster than the current send-rate cap allows.
   */
  sendLandmarks: (
    landmarks: number[],
    options?: {
      /** Forwarded as `image_b64`; presumably an encoded image — confirm with the backend. */
      imageB64?: string;
      /** Forwarded as `model_mode`. */
      modelMode?: ModelMode;
      /** Forwarded as `session_id`; the hook substitutes `'browser'` when omitted. */
      sessionId?: string;
    },
  ) => void;
}
/**
 * WebSocket hook for sending landmark vectors and receiving predictions.
 *
 * Opens a socket to `${WS_URL}/ws/landmarks` on mount. Whenever the connection
 * drops, it reconnects with exponential back-off: the delay starts at
 * RECONNECT_BASE_MS, doubles per retry up to MAX_RECONNECT_MS, and resets on a
 * successful open.
 *
 * Outgoing frames are throttled to MAX_SEND_RATE fps. The cap drops to
 * LOW_BW_SEND_RATE fps while the last measured round-trip time exceeds
 * LOW_BW_LATENCY_MS.
 *
 * Reconnection stops once the component unmounts. Events from a socket that
 * has been replaced (e.g. by React StrictMode's mount → unmount → mount cycle)
 * are ignored, so they cannot clobber the live connection or spawn duplicates.
 *
 * @returns connection state plus a throttled `sendLandmarks` sender.
 */
export function useWebSocket(): WebSocketState {
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectDelay = useRef(RECONNECT_BASE_MS);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const lastSendTime = useRef(0);
  // Timestamp of the latest send; the next incoming message consumes it to estimate RTT.
  const inflightTs = useRef<number | null>(null);
  // Set by the effect cleanup; while true, no new socket is opened or scheduled.
  const disposed = useRef(false);
  const [lastPrediction, setLastPrediction] = useState<PredictionResponse | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [latency, setLatency] = useState(0);
  const [lowBandwidth, setLowBandwidth] = useState(false);

  const connect = useCallback(() => {
    if (disposed.current) return;
    const existing = wsRef.current;
    // Never open a second socket while one is open or still handshaking.
    if (
      existing &&
      (existing.readyState === WebSocket.OPEN || existing.readyState === WebSocket.CONNECTING)
    ) {
      return;
    }

    const ws = new WebSocket(`${WS_URL}/ws/landmarks`);
    wsRef.current = ws;
    // A send made on a previous socket must not be paired with this socket's first reply.
    inflightTs.current = null;

    // Handlers of a socket that is no longer wsRef.current must not touch shared
    // state. Otherwise a late `close` from a replaced socket would null out the
    // live ref and schedule a duplicate reconnect.
    const isCurrent = (): boolean => wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectDelay.current = RECONNECT_BASE_MS;
    };

    ws.onmessage = (evt: MessageEvent) => {
      if (!isCurrent()) return;
      if (inflightTs.current !== null) {
        const rtt = Date.now() - inflightTs.current;
        setLatency(rtt);
        setLowBandwidth(rtt > LOW_BW_LATENCY_MS);
        inflightTs.current = null;
      }
      let data: unknown;
      try {
        data = JSON.parse(evt.data);
      } catch {
        return; // ignore non-JSON frames
      }
      // Only objects carrying a `sign` field are treated as predictions;
      // anything else (error frames, primitives, null) is ignored.
      if (typeof data === 'object' && data !== null && 'sign' in data) {
        setLastPrediction(data as PredictionResponse);
      }
    };

    ws.onclose = () => {
      if (!isCurrent()) return;
      wsRef.current = null;
      setIsConnected(false);
      if (disposed.current) return;
      // Exponential back-off: wait the current delay, then double it for the next attempt.
      reconnectTimer.current = setTimeout(() => {
        reconnectTimer.current = null;
        reconnectDelay.current = Math.min(reconnectDelay.current * 2, MAX_RECONNECT_MS);
        connect();
      }, reconnectDelay.current);
    };

    ws.onerror = (e: Event) => {
      if (!isCurrent()) return;
      console.warn('WebSocket error', e);
      // Route failures through onclose so the reconnect path handles them.
      ws.close();
    };
  }, []);

  useEffect(() => {
    disposed.current = false;
    connect();
    return () => {
      disposed.current = true;
      if (reconnectTimer.current !== null) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      const ws = wsRef.current;
      // Detach before closing so this socket's close/error events are treated as stale.
      wsRef.current = null;
      ws?.close();
    };
  }, [connect]);

  /** Throttled send — drops to 5 fps in low-bandwidth mode (latency > 500 ms). */
  const sendLandmarks = useCallback((
    landmarks: number[],
    options?: {
      sessionId?: string;
      modelMode?: ModelMode;
      imageB64?: string;
    },
  ) => {
    const ws = wsRef.current;
    if (!ws || ws.readyState !== WebSocket.OPEN) return;
    const now = Date.now();
    const effectiveRate = lowBandwidth ? LOW_BW_SEND_RATE : MAX_SEND_RATE;
    const minInterval = 1000 / effectiveRate;
    if (now - lastSendTime.current < minInterval) return;
    lastSendTime.current = now;
    inflightTs.current = now;
    ws.send(JSON.stringify({
      landmarks,
      session_id: options?.sessionId ?? 'browser',
      model_mode: options?.modelMode,
      image_b64: options?.imageB64,
    }));
  }, [lowBandwidth]);

  return { lastPrediction, isConnected, latency, lowBandwidth, sendLandmarks };
}
|