modelforge-backend / backend /services /socket_manager.py
ModelForge CI
deploy: 2026-06-19 19:24 UTC
6761f70
Raw
History Blame Contribute Delete
2.64 kB
"""
WebSocket Connection Manager for real-time training metrics.
"""
import sys
from pathlib import Path
from typing import Dict, Set
from fastapi import WebSocket
from collections import defaultdict
# Append the directory one level above this file's folder to sys.path.
# It is appended (not inserted first), so existing sys.path entries keep priority.
# NOTE(review): presumably this lets sibling backend packages be imported when
# the module is loaded outside its package — confirm it is still required.
sys.path.append(str(Path(__file__).parent.parent))
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, Set[WebSocket]] = {}
self.message_history = defaultdict(list)
self._state: Dict[str, dict] = {}
async def connect(self, client_id: str, websocket: WebSocket):
await websocket.accept()
if client_id not in self.active_connections:
self.active_connections[client_id] = set()
self.active_connections[client_id].add(websocket)
if self.message_history[client_id]:
for msg in self.message_history[client_id]:
await websocket.send_json(msg)
def disconnect(self, client_id: str, websocket: WebSocket | None = None):
if client_id not in self.active_connections:
return
if websocket is not None:
self.active_connections[client_id].discard(websocket)
else:
self.active_connections[client_id].clear()
if not self.active_connections[client_id]:
del self.active_connections[client_id]
def get_state(self, client_id: str) -> dict | None:
return self._state.get(client_id)
async def broadcast_json(self, client_id: str, data: dict):
self._state[client_id] = data
self.message_history[client_id].append(data)
if len(self.message_history[client_id]) > 100:
self.message_history[client_id] = self.message_history[client_id][-100:]
if client_id in self.active_connections:
disconnected = set()
for connection in self.active_connections[client_id]:
try:
await connection.send_json(data)
except Exception:
disconnected.add(connection)
for dead in disconnected:
self.disconnect(client_id, dead)
async def send_status(self, client_id: str, status: str, message: str = ""):
await self.broadcast_json(client_id, {"type": "status", "status": status, "message": message})
async def send_completion(self, client_id: str, success: bool, model_path: str = ""):
await self.broadcast_json(client_id, {
"type": "completion",
"success": success,
"model_path": model_path,
"message": "Training completed successfully!" if success else "Training failed"
})
# Module-level ConnectionManager instance, created once when this module is imported.
# NOTE(review): presumably imported by the WebSocket routes and the training
# code so they share one registry — confirm against callers.
manager = ConnectionManager()