fairrelay-production / brain /app /api /websocket.py
lordvisorad's picture
Upload brain/app/api/websocket.py with huggingface_hub
ae8ec4c verified
"""
WebSocket Infrastructure for Real-Time Updates.
Provides:
- /ws/tracking — Live vehicle GPS positions
- /ws/shipments — Shipment status change notifications
- /ws/dashboard — KPI refresh stream
Uses FastAPI native WebSocket support with connection management.
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Dict, List, Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
# Module logger. It uses the fixed name "fairrelay.websocket" rather than
# __name__, so logging configuration must target that name to tune this module.
logger = logging.getLogger("fairrelay.websocket")
# Router carrying the /ws/* WebSocket endpoints and the /ws/status HTTP endpoint
# defined in this module; presumably included into the main FastAPI app
# elsewhere — TODO confirm.
router = APIRouter(tags=["WebSocket"])
class ConnectionManager:
    """Manages WebSocket connections grouped by channel.

    Each channel name maps to a set of registered sockets. A channel is created
    on its first ``connect`` and is never removed, so ``stats`` can report a
    channel with zero connections.
    """

    def __init__(self):
        # channel name -> sockets currently registered on that channel
        self._connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, channel: str):
        """Accept the WebSocket handshake, then register the socket on *channel*."""
        await websocket.accept()
        conns = self._connections.setdefault(channel, set())
        conns.add(websocket)
        logger.info("WS connected: %s (total: %d)", channel, len(conns))

    def disconnect(self, websocket: WebSocket, channel: str):
        """Unregister *websocket* from *channel*.

        Unknown channels and sockets that are not registered are ignored, so this
        is safe to call from a ``finally`` clause.
        """
        conns = self._connections.get(channel)
        if conns is None:
            return
        conns.discard(websocket)
        logger.info("WS disconnected: %s (remaining: %d)", channel, len(conns))

    async def broadcast(self, channel: str, message: dict):
        """Send *message* as JSON to every socket registered on *channel*.

        A socket whose send raises is treated as dead and unregistered once the
        pass is complete. Broadcasting to an unknown or empty channel is a no-op.
        """
        conns = self._connections.get(channel)
        if not conns:
            return
        dead = set()
        # Iterate over a snapshot. Every ``await`` below yields to the event loop,
        # and a concurrent connect()/disconnect() on this channel would otherwise
        # mutate the set mid-iteration and raise "Set changed size during
        # iteration", aborting the broadcast for everyone.
        for ws in list(conns):
            try:
                await ws.send_json(message)
            except Exception as exc:
                # Best-effort fan-out: one broken client must not stop the others.
                logger.info("WS send failed on %s, dropping connection: %r", channel, exc)
                dead.add(ws)
        conns.difference_update(dead)

    @property
    def stats(self) -> Dict[str, int]:
        """Return the number of registered connections per channel, including empty channels."""
        return {ch: len(conns) for ch, conns in self._connections.items()}
# Global connection registry shared by the /ws/* endpoints and the broadcast_*
# helpers in this module. Its state is an in-memory dict, so a broadcast only
# reaches clients connected to this same process — NOTE(review): confirm the app
# runs a single worker, or add a shared pub/sub layer.
manager = ConnectionManager()
async def _handle_client_message(websocket: WebSocket, data: str, pong_with_ts: bool) -> None:
    """Answer one client text frame; only ``{"action": "ping"}`` is acted on.

    Empty frames, non-JSON text and JSON values that are not objects are
    ignored, so a misbehaving client cannot crash its own handler.
    """
    if not data:
        return
    try:
        msg = json.loads(data)
    except json.JSONDecodeError:
        return
    # json.loads may return a list/number/string; only objects carry an "action".
    if not isinstance(msg, dict) or msg.get("action") != "ping":
        return
    pong = {"type": "pong"}
    if pong_with_ts:
        pong["ts"] = datetime.utcnow().isoformat()
    await websocket.send_json(pong)


async def _serve_channel(
    websocket: WebSocket,
    channel: str,
    idle_timeout: float,
    pong_with_ts: bool = False,
) -> None:
    """Register *websocket* on *channel* and run its receive loop until it ends.

    The connection is kept only while the client keeps talking. If no text
    frame arrives within *idle_timeout* seconds, the socket is closed (code
    1000) and unregistered. Listen-only clients must therefore send something,
    e.g. a ping, more often than that. Outbound events are delivered separately
    through ``manager.broadcast``.
    """
    await manager.connect(websocket, channel)
    try:
        while True:
            try:
                data = await asyncio.wait_for(websocket.receive_text(), timeout=idle_timeout)
            except asyncio.TimeoutError:
                logger.info("WS idle timeout: %s (no message for %ss), closing", channel, idle_timeout)
                # Close explicitly so the client gets a proper close frame and knows
                # to reconnect. Otherwise the server is left to tear down an
                # abandoned socket.
                try:
                    await websocket.close(code=1000)
                except Exception:
                    # Best effort: the peer may have gone away at the same moment.
                    logger.debug("WS close after idle timeout failed: %s", channel, exc_info=True)
                return
            await _handle_client_message(websocket, data, pong_with_ts)
    except WebSocketDisconnect:
        pass
    finally:
        manager.disconnect(websocket, channel)


@router.websocket("/ws/tracking")
async def ws_tracking(websocket: WebSocket):
    """
    Live vehicle tracking stream.
    Clients receive the events published by broadcast_gps_update():
    {"type": "gps_update", "timestamp": "...", "vehicles": [...]}
    Clients must send a message at least every 30 seconds, e.g.
    {"action": "ping"} (answered with {"type": "pong", "ts": "..."}), or the
    connection is closed.
    TODO: a {"action": "subscribe", "vehicle_ids": [...]} filter was planned
    but is not implemented; every client receives every broadcast.
    """
    await _serve_channel(websocket, "tracking", idle_timeout=30.0, pong_with_ts=True)


@router.websocket("/ws/shipments")
async def ws_shipments(websocket: WebSocket):
    """
    Shipment status change notifications.
    Clients receive the events published by broadcast_shipment_status():
    {"type": "status_change", "timestamp": "...", "shipment_id": "...",
     "old_status": "...", "new_status": "...", "metadata": {...}}
    Clients must send a message at least every 60 seconds, e.g.
    {"action": "ping"} (answered with {"type": "pong"}), or the connection is
    closed.
    """
    await _serve_channel(websocket, "shipments", idle_timeout=60.0)


@router.websocket("/ws/dashboard")
async def ws_dashboard(websocket: WebSocket):
    """
    Dashboard KPI refresh stream.
    Clients receive the events published by broadcast_kpi_update():
    {"type": "kpi_update", "timestamp": "...", "metrics": {...}}
    and by broadcast_alert():
    {"type": "alert", "timestamp": "...", "alert_type": "...", "message": "...",
     "severity": "...", "data": {...}}
    This handler does not schedule pushes itself; any refresh cadence (e.g.
    every 30 seconds) must come from whoever calls broadcast_kpi_update() —
    TODO confirm.
    Clients must send a message at least every 60 seconds, e.g.
    {"action": "ping"} (answered with {"type": "pong"}), or the connection is
    closed.
    """
    await _serve_channel(websocket, "dashboard", idle_timeout=60.0)
@router.get("/ws/status", tags=["WebSocket"])
async def ws_status():
    """Report live WebSocket connection counts.

    Returns a mapping with ``connections`` (channel name -> number of registered
    sockets) and ``channels`` (the channel names, in the same order).
    """
    per_channel = manager.stats
    channel_names = [name for name in per_channel]
    return {"connections": per_channel, "channels": channel_names}
# ═══ Broadcast Utilities (call from other services) ═══
async def broadcast_gps_update(vehicles: List[Dict]):
    """Publish a ``gps_update`` event to every ``tracking`` subscriber.

    Args:
        vehicles: Vehicle records forwarded to clients unchanged; their schema
            is defined by the caller.
    """
    event = dict(
        type="gps_update",
        timestamp=datetime.utcnow().isoformat(),
        vehicles=vehicles,
    )
    await manager.broadcast("tracking", event)
async def broadcast_shipment_status(shipment_id: str, old_status: str, new_status: str, metadata: dict = None):
    """Publish a ``status_change`` event to every ``shipments`` subscriber.

    Args:
        shipment_id: Identifier of the shipment whose status changed.
        old_status: Status before the change.
        new_status: Status after the change.
        metadata: Optional extra fields. A falsy value is sent as an empty dict.
    """
    event = dict(
        type="status_change",
        timestamp=datetime.utcnow().isoformat(),
        shipment_id=shipment_id,
        old_status=old_status,
        new_status=new_status,
        metadata=metadata if metadata else {},
    )
    await manager.broadcast("shipments", event)
async def broadcast_kpi_update(metrics: Dict):
    """Publish a ``kpi_update`` event to every ``dashboard`` subscriber.

    Args:
        metrics: KPI values forwarded to clients unchanged.
    """
    stamp = datetime.utcnow().isoformat()
    event = dict(type="kpi_update", timestamp=stamp, metrics=metrics)
    await manager.broadcast("dashboard", event)
async def broadcast_alert(alert_type: str, message: str, severity: str = "warning", data: dict = None):
    """Publish an ``alert`` event to every ``dashboard`` subscriber.

    Args:
        alert_type: Caller-defined category of the alert.
        message: Human-readable alert text.
        severity: Severity label; defaults to "warning".
        data: Optional extra payload. A falsy value is sent as an empty dict.
    """
    event = dict(
        type="alert",
        timestamp=datetime.utcnow().isoformat(),
        alert_type=alert_type,
        message=message,
        severity=severity,
        data=data if data else {},
    )
    await manager.broadcast("dashboard", event)