"""WebSocket infrastructure for real-time updates.

Provides:
- /ws/tracking  — Live vehicle GPS positions
- /ws/shipments — Shipment status change notifications
- /ws/dashboard — KPI refresh stream

Uses FastAPI native WebSocket support with connection management.
"""

import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, Set

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

logger = logging.getLogger("fairrelay.websocket")
router = APIRouter(tags=["WebSocket"])


def _utc_timestamp() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    ``datetime.utcnow()`` is deprecated since Python 3.12.  The tzinfo is
    stripped on purpose so the emitted string keeps the exact format that
    ``datetime.utcnow().isoformat()`` produced (no ``+00:00`` suffix).
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


class ConnectionManager:
    """Manages WebSocket connections grouped by channel."""

    def __init__(self) -> None:
        # channel name -> sockets currently registered on that channel
        self._connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, channel: str) -> None:
        """Accept the handshake and register the socket under ``channel``."""
        await websocket.accept()
        members = self._connections.setdefault(channel, set())
        members.add(websocket)
        logger.info("WS connected: %s (total: %d)", channel, len(members))

    def disconnect(self, websocket: WebSocket, channel: str) -> None:
        """Unregister the socket from ``channel``; unknown channels are ignored.

        The (possibly empty) channel entry is kept, so ``stats`` keeps
        reporting the channel with a count of 0.
        """
        members = self._connections.get(channel)
        if members is None:
            return
        members.discard(websocket)
        logger.info("WS disconnected: %s (remaining: %d)", channel, len(members))

    async def broadcast(self, channel: str, message: dict) -> None:
        """Send ``message`` as JSON to every connection in ``channel``.

        Delivery is best-effort: a socket whose send raises is dropped from
        the channel and the remaining sockets are still served.
        """
        members = self._connections.get(channel)
        if not members:
            return

        dead: Set[WebSocket] = set()
        # Iterate over a snapshot: every ``await`` below yields to the event
        # loop, where connect()/disconnect() may mutate the live set.  Looping
        # over the live set would raise "Set changed size during iteration".
        for ws in tuple(members):
            try:
                await ws.send_json(message)
            except Exception:
                # Deliberately broad: any send failure marks the socket dead.
                logger.debug("WS send failed on %s; dropping connection", channel, exc_info=True)
                dead.add(ws)

        # Clean up dead connections
        members.difference_update(dead)

    @property
    def stats(self) -> Dict[str, int]:
        """Number of registered connections per channel."""
        return {channel: len(members) for channel, members in self._connections.items()}


# Global connection manager
manager = ConnectionManager()


def _is_ping(raw: str) -> bool:
    """Return True when ``raw`` is a JSON object whose "action" is "ping".

    Empty frames, malformed JSON and valid JSON that is not an object
    (list, string, number, ...) are all treated as "not a ping" instead of
    raising, because the payload comes straight from the client.
    """
    if not raw:
        return False
    try:
        payload = json.loads(raw)
    except json.JSONDecodeError:
        return False
    return isinstance(payload, dict) and payload.get("action") == "ping"


async def _serve_channel(
    websocket: WebSocket,
    channel: str,
    idle_timeout: float,
    pong_factory: Callable[[], Dict[str, Any]],
) -> None:
    """Shared receive loop for the channel endpoints.

    Registers the socket on ``channel``, answers client pings with the
    payload built by ``pong_factory`` and ignores every other message.  The
    loop ends when the client disconnects or stays silent for longer than
    ``idle_timeout`` seconds; the socket is unregistered in either case.
    """
    await manager.connect(websocket, channel)
    try:
        while True:
            raw = await asyncio.wait_for(websocket.receive_text(), timeout=idle_timeout)
            if _is_ping(raw):
                await websocket.send_json(pong_factory())
    except (WebSocketDisconnect, asyncio.TimeoutError):
        # Normal termination: client went away or sent nothing in time.
        pass
    finally:
        manager.disconnect(websocket, channel)


@router.websocket("/ws/tracking")
async def ws_tracking(websocket: WebSocket):
    """
    Live vehicle tracking stream.

    Clients receive GPS updates for active vehicles.
    Format: {"type": "gps_update", "vehicles": [...]}

    Clients must send {"action": "ping"} at least every 30 seconds to stay
    connected; each ping is answered with {"type": "pong", "ts": ...}.
    NOTE(review): a {"action": "subscribe", "vehicle_ids": [...]} filter
    message was mentioned here originally, but only "ping" is handled.
    """
    await _serve_channel(
        websocket,
        "tracking",
        30.0,
        lambda: {"type": "pong", "ts": _utc_timestamp()},
    )


@router.websocket("/ws/shipments")
async def ws_shipments(websocket: WebSocket):
    """
    Shipment status change notifications.

    Format: {"type": "status_change", "shipment_id": "...", "old_status": "...", "new_status": "..."}

    Clients must send {"action": "ping"} at least every 60 seconds to stay
    connected; each ping is answered with {"type": "pong"}.
    """
    await _serve_channel(websocket, "shipments", 60.0, lambda: {"type": "pong"})


@router.websocket("/ws/dashboard")
async def ws_dashboard(websocket: WebSocket):
    """
    Dashboard KPI refresh stream.

    Format: {"type": "kpi_update", "metrics": {...}}

    Clients must send {"action": "ping"} at least every 60 seconds to stay
    connected; each ping is answered with {"type": "pong"}.
    NOTE(review): this was documented as pushing metrics every 30 seconds,
    but nothing in this module schedules that; updates are sent whenever
    broadcast_kpi_update() / broadcast_alert() are called — confirm the
    cadence with the caller.
    """
    await _serve_channel(websocket, "dashboard", 60.0, lambda: {"type": "pong"})


@router.get("/ws/status", tags=["WebSocket"])
async def ws_status():
    """Get WebSocket connection stats."""
    # Read the property once so both fields describe the same snapshot.
    snapshot = manager.stats
    return {"connections": snapshot, "channels": list(snapshot)}


# ═══ Broadcast Utilities (call from other services) ═══

async def broadcast_gps_update(vehicles: List[Dict]):
    """Push GPS update to all tracking subscribers."""
    await manager.broadcast("tracking", {
        "type": "gps_update",
        "timestamp": _utc_timestamp(),
        "vehicles": vehicles,
    })


async def broadcast_shipment_status(
    shipment_id: str,
    old_status: str,
    new_status: str,
    metadata: Optional[dict] = None,
):
    """Push shipment status change to subscribers.

    ``metadata`` is optional; a missing or empty value is sent as ``{}``.
    """
    await manager.broadcast("shipments", {
        "type": "status_change",
        "timestamp": _utc_timestamp(),
        "shipment_id": shipment_id,
        "old_status": old_status,
        "new_status": new_status,
        "metadata": metadata or {},
    })


async def broadcast_kpi_update(metrics: Dict):
    """Push KPI update to dashboard subscribers."""
    await manager.broadcast("dashboard", {
        "type": "kpi_update",
        "timestamp": _utc_timestamp(),
        "metrics": metrics,
    })


async def broadcast_alert(
    alert_type: str,
    message: str,
    severity: str = "warning",
    data: Optional[dict] = None,
):
    """Push alert to all dashboard subscribers.

    ``data`` is optional; a missing or empty value is sent as ``{}``.
    """
    await manager.broadcast("dashboard", {
        "type": "alert",
        "timestamp": _utc_timestamp(),
        "alert_type": alert_type,
        "message": message,
        "severity": severity,
        "data": data or {},
    })