NeerajCodz commited on
Commit
4afa792
·
1 Parent(s): 05f6bf1

feat: add WebSocket support for real-time scraper progress updates

Browse files
backend/app/api/routes/episode.py CHANGED
@@ -136,15 +136,36 @@ async def step_episode(request: StepRequest) -> StepResponse:
136
  logger.info(f"Step in episode {request.episode_id}: {request.action.action_type}")
137
 
138
  env = get_environment(request.episode_id)
 
 
 
 
139
 
140
  try:
141
  observation, reward, reward_breakdown, terminated, truncated, info = await env.step(
142
  request.action
143
  )
 
 
 
 
 
 
 
 
 
 
144
 
145
  # Clean up if episode is done
146
  if terminated or truncated:
147
  logger.info(f"Episode {request.episode_id} completed")
 
 
 
 
 
 
 
148
 
149
  return StepResponse(
150
  observation=observation,
@@ -156,6 +177,11 @@ async def step_episode(request: StepRequest) -> StepResponse:
156
  )
157
  except Exception as e:
158
  logger.error(f"Step failed: {e}")
 
 
 
 
 
159
  raise HTTPException(
160
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
161
  detail=f"Step execution failed: {str(e)}",
 
136
  logger.info(f"Step in episode {request.episode_id}: {request.action.action_type}")
137
 
138
  env = get_environment(request.episode_id)
139
+
140
+ # Get WebSocket manager for real-time updates
141
+ from app.api.routes.websocket import get_connection_manager
142
+ ws_manager = get_connection_manager()
143
 
144
  try:
145
  observation, reward, reward_breakdown, terminated, truncated, info = await env.step(
146
  request.action
147
  )
148
+
149
+ # Send real-time progress update via WebSocket
150
+ await ws_manager.send_progress_update(
151
+ episode_id=request.episode_id,
152
+ step=observation.step_number,
153
+ action_type=request.action.action_type.value,
154
+ reward=reward,
155
+ progress=observation.extraction_progress,
156
+ message=f"Executed {request.action.action_type.value}",
157
+ )
158
 
159
  # Clean up if episode is done
160
  if terminated or truncated:
161
  logger.info(f"Episode {request.episode_id} completed")
162
+ state = env.get_state()
163
+ await ws_manager.send_completion(
164
+ episode_id=request.episode_id,
165
+ success=terminated and not truncated,
166
+ total_reward=state.get("total_reward", 0.0),
167
+ extracted_data=state.get("extracted_data", {}),
168
+ )
169
 
170
  return StepResponse(
171
  observation=observation,
 
177
  )
178
  except Exception as e:
179
  logger.error(f"Step failed: {e}")
180
+ await ws_manager.send_error(
181
+ episode_id=request.episode_id,
182
+ error=str(e),
183
+ details={"action_type": request.action.action_type.value},
184
+ )
185
  raise HTTPException(
186
  status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
187
  detail=f"Step execution failed: {str(e)}",
backend/app/api/routes/websocket.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WebSocket support for real-time scraper updates."""
2
+
3
+ import asyncio
4
+ import json
5
+ import logging
6
+ from typing import Any
7
+
8
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
9
+ from fastapi.websockets import WebSocketState
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ router = APIRouter(prefix="/ws", tags=["WebSocket"])
14
+
15
+ # Store active WebSocket connections by episode_id
16
+ _active_connections: dict[str, list[WebSocket]] = {}
17
+
18
+
19
+ class ConnectionManager:
20
+ """Manage WebSocket connections for real-time updates."""
21
+
22
+ def __init__(self):
23
+ self.active_connections: dict[str, list[WebSocket]] = {}
24
+
25
+ async def connect(self, websocket: WebSocket, episode_id: str):
26
+ """Connect a new WebSocket client."""
27
+ await websocket.accept()
28
+ if episode_id not in self.active_connections:
29
+ self.active_connections[episode_id] = []
30
+ self.active_connections[episode_id].append(websocket)
31
+ logger.info(f"WebSocket connected for episode {episode_id}")
32
+
33
+ def disconnect(self, websocket: WebSocket, episode_id: str):
34
+ """Disconnect a WebSocket client."""
35
+ if episode_id in self.active_connections:
36
+ if websocket in self.active_connections[episode_id]:
37
+ self.active_connections[episode_id].remove(websocket)
38
+ if not self.active_connections[episode_id]:
39
+ del self.active_connections[episode_id]
40
+ logger.info(f"WebSocket disconnected for episode {episode_id}")
41
+
42
+ async def send_personal_message(self, message: dict[str, Any], websocket: WebSocket):
43
+ """Send a message to a specific client."""
44
+ try:
45
+ if websocket.client_state == WebSocketState.CONNECTED:
46
+ await websocket.send_json(message)
47
+ except Exception as e:
48
+ logger.error(f"Error sending personal message: {e}")
49
+
50
+ async def broadcast(self, message: dict[str, Any], episode_id: str):
51
+ """Broadcast a message to all clients watching an episode."""
52
+ if episode_id not in self.active_connections:
53
+ return
54
+
55
+ disconnected = []
56
+ for connection in self.active_connections[episode_id]:
57
+ try:
58
+ if connection.client_state == WebSocketState.CONNECTED:
59
+ await connection.send_json(message)
60
+ else:
61
+ disconnected.append(connection)
62
+ except Exception as e:
63
+ logger.error(f"Error broadcasting to client: {e}")
64
+ disconnected.append(connection)
65
+
66
+ # Clean up disconnected clients
67
+ for conn in disconnected:
68
+ self.disconnect(conn, episode_id)
69
+
70
+ async def send_progress_update(
71
+ self,
72
+ episode_id: str,
73
+ step: int,
74
+ action_type: str,
75
+ reward: float,
76
+ progress: float,
77
+ message: str | None = None,
78
+ ):
79
+ """Send a progress update for an episode."""
80
+ update = {
81
+ "type": "progress",
82
+ "episode_id": episode_id,
83
+ "step": step,
84
+ "action_type": action_type,
85
+ "reward": reward,
86
+ "progress": progress,
87
+ "message": message,
88
+ "timestamp": asyncio.get_event_loop().time(),
89
+ }
90
+ await self.broadcast(update, episode_id)
91
+
92
+ async def send_error(self, episode_id: str, error: str, details: dict[str, Any] | None = None):
93
+ """Send an error message."""
94
+ message = {
95
+ "type": "error",
96
+ "episode_id": episode_id,
97
+ "error": error,
98
+ "details": details or {},
99
+ "timestamp": asyncio.get_event_loop().time(),
100
+ }
101
+ await self.broadcast(message, episode_id)
102
+
103
+ async def send_completion(
104
+ self,
105
+ episode_id: str,
106
+ success: bool,
107
+ total_reward: float,
108
+ extracted_data: dict[str, Any],
109
+ ):
110
+ """Send a completion notification."""
111
+ message = {
112
+ "type": "completion",
113
+ "episode_id": episode_id,
114
+ "success": success,
115
+ "total_reward": total_reward,
116
+ "extracted_data": extracted_data,
117
+ "timestamp": asyncio.get_event_loop().time(),
118
+ }
119
+ await self.broadcast(message, episode_id)
120
+
121
+
122
+ # Global connection manager
123
+ manager = ConnectionManager()
124
+
125
+
126
+ @router.websocket("/episode/{episode_id}")
127
+ async def websocket_episode(websocket: WebSocket, episode_id: str):
128
+ """
129
+ WebSocket endpoint for receiving real-time updates about an episode.
130
+
131
+ Clients can connect to this endpoint to receive updates about:
132
+ - Action execution progress
133
+ - Reward changes
134
+ - Extraction progress
135
+ - Errors
136
+ - Episode completion
137
+
138
+ Args:
139
+ websocket: WebSocket connection
140
+ episode_id: ID of the episode to watch
141
+ """
142
+ await manager.connect(websocket, episode_id)
143
+
144
+ try:
145
+ # Send initial connection confirmation
146
+ await manager.send_personal_message(
147
+ {
148
+ "type": "connected",
149
+ "episode_id": episode_id,
150
+ "message": f"Connected to episode {episode_id}",
151
+ },
152
+ websocket,
153
+ )
154
+
155
+ # Keep connection alive and handle incoming messages
156
+ while True:
157
+ try:
158
+ # Receive messages from client (e.g., subscription updates)
159
+ data = await asyncio.wait_for(
160
+ websocket.receive_text(),
161
+ timeout=30.0, # 30 second timeout
162
+ )
163
+
164
+ try:
165
+ message = json.loads(data)
166
+
167
+ # Handle ping/pong for keep-alive
168
+ if message.get("type") == "ping":
169
+ await manager.send_personal_message(
170
+ {"type": "pong", "timestamp": asyncio.get_event_loop().time()},
171
+ websocket,
172
+ )
173
+
174
+ except json.JSONDecodeError:
175
+ logger.warning(f"Invalid JSON received: {data}")
176
+
177
+ except asyncio.TimeoutError:
178
+ # Send a ping to check if client is still connected
179
+ try:
180
+ await manager.send_personal_message(
181
+ {"type": "ping", "timestamp": asyncio.get_event_loop().time()},
182
+ websocket,
183
+ )
184
+ except Exception:
185
+ # Client disconnected
186
+ break
187
+
188
+ except WebSocketDisconnect:
189
+ logger.info(f"Client disconnected from episode {episode_id}")
190
+ except Exception as e:
191
+ logger.error(f"WebSocket error for episode {episode_id}: {e}")
192
+ finally:
193
+ manager.disconnect(websocket, episode_id)
194
+
195
+
196
+ def get_connection_manager() -> ConnectionManager:
197
+ """Get the global connection manager instance."""
198
+ return manager
backend/app/main.py CHANGED
@@ -137,6 +137,10 @@ def create_app() -> FastAPI:
137
  # Import and include providers router
138
  from app.api.routes import providers
139
  app.include_router(providers.router, prefix=api_prefix, tags=["Providers"])
 
 
 
 
140
 
141
  # Serve static files (frontend build)
142
  static_dir = Path(__file__).parent.parent / "static"
 
137
  # Import and include providers router
138
  from app.api.routes import providers
139
  app.include_router(providers.router, prefix=api_prefix, tags=["Providers"])
140
+
141
+ # Import and include WebSocket router
142
+ from app.api.routes import websocket
143
+ app.include_router(websocket.router, tags=["WebSocket"])
144
 
145
  # Serve static files (frontend build)
146
  static_dir = Path(__file__).parent.parent / "static"