Spaces:
Sleeping
Sleeping
File size: 6,921 Bytes
4afa792 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 | """WebSocket support for real-time scraper updates."""
import asyncio
import json
import logging
import time
from typing import Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
# Module logger used for connection lifecycle events and send errors.
logger = logging.getLogger(__name__)
# Router for this module's WebSocket routes; every route registered on it is
# prefixed with "/ws".
router = APIRouter(prefix="/ws", tags=["WebSocket"])
# Store active WebSocket connections by episode_id.
# NOTE(review): nothing in this module reads or writes this dict; the
# ConnectionManager keeps its own ``active_connections`` instead. It may be
# imported elsewhere, so verify before removing.
_active_connections: dict[str, list[WebSocket]] = {}
class ConnectionManager:
    """Manage WebSocket connections for real-time episode updates.

    Sockets are grouped by ``episode_id`` so that an update for one episode
    is only delivered to the clients watching that episode.
    """

    def __init__(self):
        # episode_id -> sockets currently registered for that episode.
        self.active_connections: dict[str, list[WebSocket]] = {}

    @staticmethod
    def _timestamp() -> float:
        """Return wall-clock seconds since the Unix epoch for message payloads.

        The event-loop clock (``loop.time()``) is monotonic with an arbitrary
        origin. A remote client cannot interpret it, so it is not used here.
        """
        return time.time()

    async def connect(self, websocket: WebSocket, episode_id: str):
        """Accept ``websocket`` and register it for updates on ``episode_id``."""
        await websocket.accept()
        self.active_connections.setdefault(episode_id, []).append(websocket)
        logger.info("WebSocket connected for episode %s", episode_id)

    def disconnect(self, websocket: WebSocket, episode_id: str):
        """Unregister ``websocket`` from ``episode_id``; the socket is not closed.

        Safe to call repeatedly for the same socket: :meth:`broadcast` and the
        endpoint's ``finally`` clause may both do so. The episode key is dropped
        once its last socket has been removed.
        """
        connections = self.active_connections.get(episode_id)
        if connections is None or websocket not in connections:
            return
        connections.remove(websocket)
        if not connections:
            del self.active_connections[episode_id]
        logger.info("WebSocket disconnected for episode %s", episode_id)

    async def send_personal_message(
        self, message: dict[str, Any], websocket: WebSocket
    ) -> bool:
        """Send ``message`` as JSON to a single client.

        Best-effort: nothing is sent when the socket is not connected, and send
        errors are logged instead of raised.

        Returns:
            True if the message was handed to the socket, False otherwise.
        """
        if websocket.client_state != WebSocketState.CONNECTED:
            return False
        try:
            await websocket.send_json(message)
        except Exception as e:
            logger.error("Error sending personal message: %s", e)
            return False
        return True

    async def broadcast(self, message: dict[str, Any], episode_id: str):
        """Send ``message`` to every client watching ``episode_id``.

        Clients that are no longer connected, or whose send fails, are
        unregistered after the pass.
        """
        # Iterate a snapshot. Each ``await`` yields to the event loop, and a
        # connect()/disconnect() running in the meantime would otherwise mutate
        # the list being iterated and silently skip clients.
        connections = list(self.active_connections.get(episode_id, ()))
        disconnected = []
        for connection in connections:
            if connection.client_state != WebSocketState.CONNECTED:
                disconnected.append(connection)
                continue
            try:
                await connection.send_json(message)
            except Exception as e:
                logger.error("Error broadcasting to client: %s", e)
                disconnected.append(connection)
        # Clean up clients that could not be reached.
        for connection in disconnected:
            self.disconnect(connection, episode_id)

    async def send_progress_update(
        self,
        episode_id: str,
        step: int,
        action_type: str,
        reward: float,
        progress: float,
        message: str | None = None,
    ):
        """Broadcast a ``"progress"`` message for ``episode_id``.

        Args:
            episode_id: Episode whose watchers receive the update.
            step: Step number included in the payload.
            action_type: Action label included in the payload.
            reward: Reward value included in the payload.
            progress: Progress value included in the payload (its scale is
                set by the caller; not checked here).
            message: Optional human-readable note; sent as ``null`` if omitted.
        """
        update = {
            "type": "progress",
            "episode_id": episode_id,
            "step": step,
            "action_type": action_type,
            "reward": reward,
            "progress": progress,
            "message": message,
            "timestamp": self._timestamp(),
        }
        await self.broadcast(update, episode_id)

    async def send_error(
        self, episode_id: str, error: str, details: dict[str, Any] | None = None
    ):
        """Broadcast an ``"error"`` message; ``details`` defaults to ``{}``."""
        message = {
            "type": "error",
            "episode_id": episode_id,
            "error": error,
            "details": details or {},
            "timestamp": self._timestamp(),
        }
        await self.broadcast(message, episode_id)

    async def send_completion(
        self,
        episode_id: str,
        success: bool,
        total_reward: float,
        extracted_data: dict[str, Any],
    ):
        """Broadcast a ``"completion"`` message carrying the episode's results.

        ``extracted_data`` is embedded as-is, so it must be JSON-serializable
        for ``send_json`` to succeed.
        """
        message = {
            "type": "completion",
            "episode_id": episode_id,
            "success": success,
            "total_reward": total_reward,
            "extracted_data": extracted_data,
            "timestamp": self._timestamp(),
        }
        await self.broadcast(message, episode_id)
# Global connection manager: a single ConnectionManager created at import time,
# used by the WebSocket endpoint and returned by get_connection_manager().
manager = ConnectionManager()
@router.websocket("/episode/{episode_id}")
async def websocket_episode(websocket: WebSocket, episode_id: str):
    """
    WebSocket endpoint for receiving real-time updates about an episode.

    Clients can connect to this endpoint to receive updates about:
    - Action execution progress
    - Reward changes
    - Extraction progress
    - Errors
    - Episode completion

    Protocol handled directly by this endpoint:
    - On connect the server sends ``{"type": "connected", ...}``.
    - A client message ``{"type": "ping"}`` is answered with
      ``{"type": "pong", "timestamp": ...}``.
    - If the client stays silent for 30 seconds, the server sends
      ``{"type": "ping", ...}``. If that send fails, the connection is dropped.
    - Other well-formed messages are ignored. Non-JSON messages are logged.

    Args:
        websocket: WebSocket connection
        episode_id: ID of the episode to watch
    """
    await manager.connect(websocket, episode_id)
    try:
        # Send initial connection confirmation
        await manager.send_personal_message(
            {
                "type": "connected",
                "episode_id": episode_id,
                "message": f"Connected to episode {episode_id}",
            },
            websocket,
        )
        # Keep connection alive and handle incoming messages
        while True:
            try:
                data = await asyncio.wait_for(
                    websocket.receive_text(),
                    timeout=30.0,  # 30 second idle timeout
                )
            except asyncio.TimeoutError:
                # Probe the client. Send directly rather than through
                # manager.send_personal_message, which logs and swallows
                # errors, so that a dead connection actually ends the loop.
                try:
                    await websocket.send_json(
                        {"type": "ping", "timestamp": time.time()}
                    )
                except Exception:
                    logger.info(
                        "Keep-alive ping failed for episode %s", episode_id
                    )
                    break
                continue

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                # Truncate: the payload is client-controlled and may be large.
                logger.warning("Invalid JSON received: %.200s", data)
                continue

            # Valid JSON need not be an object. Calling ``.get`` on a list or
            # string would raise and tear down the connection.
            if isinstance(message, dict) and message.get("type") == "ping":
                await manager.send_personal_message(
                    {"type": "pong", "timestamp": time.time()},
                    websocket,
                )
    except WebSocketDisconnect:
        logger.info("Client disconnected from episode %s", episode_id)
    except Exception:
        logger.exception("WebSocket error for episode %s", episode_id)
    finally:
        manager.disconnect(websocket, episode_id)
def get_connection_manager() -> ConnectionManager:
    """Return the module-level ``manager`` instance.

    The same ConnectionManager object, created once at import time, is
    returned on every call.
    """
    shared_instance = manager
    return shared_instance
|