// Sanket-Setu — frontend/src/hooks/useWebSocket.ts
// Last commit by devrajsinh2012 (c476eae):
//   "fix: model paths (.pth), landmark normalization, WS URL, GPU fallback;
//    add ModelSelector; mobile layout improvements"
import { useEffect, useRef, useState, useCallback } from 'react';
import type { ModelMode, PredictionResponse } from '../types';
// Derive WebSocket base URL.
// Priority: VITE_WS_URL env var → dev fallback (port 8000) → same host (production).
/**
 * Resolve the WebSocket base URL (scheme + host, no trailing slash).
 *
 * @param env Vite env vars to read; defaults to `import.meta.env`. Injectable for tests.
 * @param loc Page location to derive scheme/host from; defaults to `window.location`.
 * @returns A base URL such as `wss://example.com`; callers append the path.
 */
function _defaultWsUrl(
  env: { VITE_WS_URL?: unknown; DEV?: boolean } = import.meta.env,
  loc: Pick<Location, 'protocol' | 'hostname' | 'host'> = window.location,
): string {
  const configured = env.VITE_WS_URL;
  if (typeof configured === 'string' && configured !== '') {
    // Strip trailing slashes so `${base}/ws/landmarks` never becomes `//ws/landmarks`.
    return configured.replace(/\/+$/, '');
  }
  const proto = loc.protocol === 'https:' ? 'wss' : 'ws';
  // In Vite dev mode the frontend is served on 5173 but FastAPI runs on 8000.
  // Use the page's hostname rather than `localhost` so devices on the LAN
  // (e.g. a phone testing the mobile layout) reach the dev machine's backend.
  if (env.DEV) return `${proto}://${loc.hostname}:8000`;
  // In production the backend is co-located (HF Spaces Docker).
  return `${proto}://${loc.host}`;
}
/** WebSocket base URL, resolved once at module load. */
const WS_URL = _defaultWsUrl();
/** Initial reconnect delay (ms); doubled per reconnect attempt, reset once a socket opens. */
const RECONNECT_BASE_MS = 1000;
/** Upper bound (ms) for the exponential reconnect back-off. */
const MAX_RECONNECT_MS = 30_000;
const MAX_SEND_RATE = 15; // frames/sec — normal
const LOW_BW_SEND_RATE = 5; // frames/sec — high-latency fallback
const LOW_BW_LATENCY_MS = 500; // round-trip time (ms) above which low-bandwidth mode activates
/** State and actions exposed by `useWebSocket`. */
export interface WebSocketState {
  /** Latest server message that carried a `sign` field; null until one arrives. */
  lastPrediction: PredictionResponse | null;
  /** True while the underlying socket is open. */
  isConnected: boolean;
  /** Most recent measured round-trip time in milliseconds (0 before the first sample). */
  latency: number;
  /** True while the measured round-trip time is above the low-bandwidth threshold. */
  lowBandwidth: boolean;
  /**
   * Send one landmark frame. Frames are dropped silently when the socket is
   * not open or when the call exceeds the current send-rate limit.
   */
  sendLandmarks: (
    landmarks: number[],
    options?: { sessionId?: string; modelMode?: ModelMode; imageB64?: string },
  ) => void;
}
/**
 * WebSocket hook for sending landmark vectors and receiving predictions.
 *
 * - Opens `${WS_URL}/ws/landmarks` on mount and closes it on unmount.
 * - Reconnects with exponential back-off (RECONNECT_BASE_MS doubling up to
 *   MAX_RECONNECT_MS); the delay resets once a socket opens.
 * - Throttles sends to MAX_SEND_RATE fps, or LOW_BW_SEND_RATE fps while the
 *   measured round-trip time exceeds LOW_BW_LATENCY_MS.
 *
 * Latency uses a single probe. The send time of one frame is recorded, and
 * the next incoming message completes the sample. Stamping every send instead
 * would measure from the latest frame and hide latency larger than the send
 * interval.
 * NOTE(review): assumes the server replies to frames in order — confirm
 * against the backend.
 */
export function useWebSocket(): WebSocketState {
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectDelay = useRef(RECONNECT_BASE_MS);
  const reconnectTimer = useRef<ReturnType<typeof setTimeout> | null>(null);
  const lastSendTime = useRef(0);
  // Send time of the frame whose reply completes the next RTT sample (null = no probe in flight).
  const inflightTs = useRef<number | null>(null);
  // False after unmount so late socket events cannot schedule reconnects.
  const mountedRef = useRef(false);
  // Mirrors `lowBandwidth` so `sendLandmarks` keeps a stable identity.
  const lowBandwidthRef = useRef(false);

  const [lastPrediction, setLastPrediction] = useState<PredictionResponse | null>(null);
  const [isConnected, setIsConnected] = useState(false);
  const [latency, setLatency] = useState(0);
  const [lowBandwidth, setLowBandwidth] = useState(false);

  const connect = useCallback(() => {
    if (!mountedRef.current) return;
    const existing = wsRef.current;
    // Don't open a second socket while one is open or still handshaking.
    if (
      existing &&
      (existing.readyState === WebSocket.OPEN || existing.readyState === WebSocket.CONNECTING)
    ) {
      return;
    }

    const ws = new WebSocket(`${WS_URL}/ws/landmarks`);
    wsRef.current = ws;
    // Events from a socket that is no longer current (torn down on unmount or
    // replaced) must not touch shared state or schedule reconnects.
    const isCurrent = () => wsRef.current === ws;

    ws.onopen = () => {
      if (!isCurrent()) return;
      setIsConnected(true);
      reconnectDelay.current = RECONNECT_BASE_MS;
    };

    ws.onmessage = (evt) => {
      if (!isCurrent()) return;
      if (inflightTs.current !== null) {
        const rtt = Date.now() - inflightTs.current;
        inflightTs.current = null;
        const slow = rtt > LOW_BW_LATENCY_MS;
        lowBandwidthRef.current = slow;
        setLatency(rtt);
        setLowBandwidth(slow);
      }
      let data: unknown;
      try {
        data = JSON.parse(evt.data);
      } catch {
        return; // ignore non-JSON messages
      }
      // Only messages carrying a `sign` field count as predictions; other JSON
      // (e.g. error frames) is ignored.
      if (typeof data === 'object' && data !== null && 'sign' in data) {
        setLastPrediction(data as PredictionResponse);
      }
    };

    ws.onclose = () => {
      if (!isCurrent()) return;
      wsRef.current = null;
      inflightTs.current = null; // the probe's reply will never arrive on this socket
      setIsConnected(false);
      if (!mountedRef.current) return;
      // Exponential back-off reconnect
      reconnectTimer.current = setTimeout(() => {
        reconnectTimer.current = null;
        reconnectDelay.current = Math.min(reconnectDelay.current * 2, MAX_RECONNECT_MS);
        connect();
      }, reconnectDelay.current);
    };

    ws.onerror = (e) => {
      console.warn('WebSocket error', e);
      ws.close();
    };
  }, []);

  useEffect(() => {
    mountedRef.current = true;
    connect();
    return () => {
      mountedRef.current = false;
      if (reconnectTimer.current !== null) {
        clearTimeout(reconnectTimer.current);
        reconnectTimer.current = null;
      }
      // Detach before closing so this socket's close event is ignored. It then
      // cannot reconnect after unmount or clobber a socket opened by a later
      // mount (e.g. React StrictMode's double-invoked effects).
      const ws = wsRef.current;
      wsRef.current = null;
      inflightTs.current = null;
      ws?.close();
      setIsConnected(false);
    };
  }, [connect]);

  /** Throttled send — MAX_SEND_RATE fps normally, LOW_BW_SEND_RATE fps in low-bandwidth mode. */
  const sendLandmarks = useCallback<WebSocketState['sendLandmarks']>((landmarks, options) => {
    const ws = wsRef.current;
    if (!ws || ws.readyState !== WebSocket.OPEN) return;
    const now = Date.now();
    const effectiveRate = lowBandwidthRef.current ? LOW_BW_SEND_RATE : MAX_SEND_RATE;
    const minInterval = 1000 / effectiveRate;
    if (now - lastSendTime.current < minInterval) return;
    lastSendTime.current = now;
    // Start a latency probe only if none is outstanding (see hook docs).
    if (inflightTs.current === null) inflightTs.current = now;
    ws.send(JSON.stringify({
      landmarks,
      session_id: options?.sessionId ?? 'browser',
      model_mode: options?.modelMode,
      image_b64: options?.imageB64,
    }));
  }, []);

  return { lastPrediction, isConnected, latency, lowBandwidth, sendLandmarks };
}