Pulastya B committed on
Commit
711046d
·
1 Parent(s): 1470b93

Replace polling with Server-Sent Events (SSE) for real-time progress updates

Browse files
FRRONTEEEND/components/ChatInterface.tsx CHANGED
@@ -51,7 +51,7 @@ export const ChatInterface: React.FC<{ onBack: () => void }> = ({ onBack }) => {
51
  const [showAssets, setShowAssets] = useState(false);
52
  const fileInputRef = useRef<HTMLInputElement>(null);
53
  const scrollRef = useRef<HTMLDivElement>(null);
54
- const progressIntervalRef = useRef<NodeJS.Timeout | null>(null);
55
 
56
  const activeSession = sessions.find(s => s.id === activeSessionId) || sessions[0];
57
 
@@ -61,48 +61,82 @@ export const ChatInterface: React.FC<{ onBack: () => void }> = ({ onBack }) => {
61
  }
62
  }, [activeSession.messages, isTyping]);
63
 
64
- // Poll for progress ONLY when isTyping is true
65
  useEffect(() => {
66
  if (!isTyping) {
67
- if (progressIntervalRef.current) {
68
- clearInterval(progressIntervalRef.current);
69
- progressIntervalRef.current = null;
 
70
  }
71
  setCurrentStep('');
72
  return;
73
  }
74
 
 
 
75
  const sessionKey = activeSessionId || 'default';
 
76
 
77
- const pollProgress = async () => {
 
 
 
 
 
 
 
 
 
 
78
  try {
79
- const API_URL = window.location.origin;
80
- const progressResponse = await fetch(`${API_URL}/api/progress/${sessionKey}`);
81
- if (progressResponse.ok) {
82
- const progressData = await progressResponse.json();
83
- const steps = progressData.steps || [];
84
-
85
- if (steps.length > 0) {
86
- const latestStep = steps[steps.length - 1];
87
- // Format tool name nicely
88
- const toolName = latestStep.tool
89
- .replace(/_/g, ' ')
90
- .replace(/\b\w/g, (l: string) => l.toUpperCase());
91
- setCurrentStep(`Executing: ${toolName}`);
92
- }
93
- }
94
  } catch (err) {
95
- console.error('Progress polling error:', err);
96
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  };
98
 
99
- // Start polling every 1 second when workflow is active
100
- progressIntervalRef.current = setInterval(pollProgress, 1000);
101
 
 
102
  return () => {
103
- if (progressIntervalRef.current) {
104
- clearInterval(progressIntervalRef.current);
105
- progressIntervalRef.current = null;
106
  }
107
  };
108
  }, [isTyping, activeSessionId]);
 
51
  const [showAssets, setShowAssets] = useState(false);
52
  const fileInputRef = useRef<HTMLInputElement>(null);
53
  const scrollRef = useRef<HTMLDivElement>(null);
54
+ const eventSourceRef = useRef<EventSource | null>(null);
55
 
56
  const activeSession = sessions.find(s => s.id === activeSessionId) || sessions[0];
57
 
 
61
  }
62
  }, [activeSession.messages, isTyping]);
63
 
64
+ // Connect to SSE when workflow starts, disconnect when it completes
65
  useEffect(() => {
66
  if (!isTyping) {
67
+ // Close SSE connection when workflow completes
68
+ if (eventSourceRef.current) {
69
+ eventSourceRef.current.close();
70
+ eventSourceRef.current = null;
71
  }
72
  setCurrentStep('');
73
  return;
74
  }
75
 
76
+ // Connect to SSE stream
77
+ const API_URL = window.location.origin;
78
  const sessionKey = activeSessionId || 'default';
79
+ const eventSource = new EventSource(`${API_URL}/api/progress/stream/${sessionKey}`);
80
 
81
+ eventSource.onopen = () => {
82
+ console.log('✅ SSE connection established');
83
+ };
84
+
85
+ // Handle connection event
86
+ eventSource.addEventListener('connected', (e) => {
87
+ console.log('Connected to progress stream:', e.data);
88
+ });
89
+
90
+ // Handle tool start events
91
+ eventSource.addEventListener('tool_start', (e) => {
92
  try {
93
+ const data = JSON.parse(e.data);
94
+ setCurrentStep(data.message || `🔧 Executing: ${data.tool}`);
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  } catch (err) {
96
+ console.error('Error parsing tool_start event:', err);
97
  }
98
+ });
99
+
100
+ // Handle tool complete events
101
+ eventSource.addEventListener('tool_complete', (e) => {
102
+ try {
103
+ const data = JSON.parse(e.data);
104
+ setCurrentStep(data.message || `✓ Completed: ${data.tool}`);
105
+ } catch (err) {
106
+ console.error('Error parsing tool_complete event:', err);
107
+ }
108
+ });
109
+
110
+ // Handle tool error events
111
+ eventSource.addEventListener('tool_error', (e) => {
112
+ try {
113
+ const data = JSON.parse(e.data);
114
+ setCurrentStep(data.message || `❌ Failed: ${data.tool}`);
115
+ } catch (err) {
116
+ console.error('Error parsing tool_error event:', err);
117
+ }
118
+ });
119
+
120
+ // Handle analysis completion
121
+ eventSource.addEventListener('analysis_complete', (e) => {
122
+ console.log('✅ Analysis completed');
123
+ setIsTyping(false); // This will trigger cleanup
124
+ });
125
+
126
+ // Handle errors
127
+ eventSource.onerror = (err) => {
128
+ console.error('SSE error:', err);
129
+ eventSource.close();
130
+ eventSourceRef.current = null;
131
  };
132
 
133
+ eventSourceRef.current = eventSource;
 
134
 
135
+ // Cleanup on unmount or when isTyping changes to false
136
  return () => {
137
+ if (eventSourceRef.current) {
138
+ eventSourceRef.current.close();
139
+ eventSourceRef.current = null;
140
  }
141
  };
142
  }, [isTyping, activeSessionId]);
src/api/app.py CHANGED
@@ -17,10 +17,12 @@ from dotenv import load_dotenv
17
  load_dotenv()
18
 
19
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
20
- from fastapi.responses import JSONResponse, FileResponse
21
  from fastapi.staticfiles import StaticFiles
22
  from fastapi.middleware.cors import CORSMiddleware
23
  from pydantic import BaseModel
 
 
24
 
25
  # Import from parent package
26
  from src.orchestrator import DataScienceCopilot
@@ -49,9 +51,81 @@ app.add_middleware(
49
  # Agent itself is stateless - no conversation memory between requests
50
  agent: Optional[DataScienceCopilot] = None
51
 
52
- # Global progress tracking (in-memory for simplicity)
53
  progress_store: Dict[str, List[Dict[str, Any]]] = {}
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # Mount static files for React frontend
56
  frontend_path = Path(__file__).parent.parent.parent / "FRRONTEEEND" / "dist"
57
  if frontend_path.exists():
@@ -95,13 +169,75 @@ async def root():
95
 
96
  @app.get("/api/progress/{session_id}")
97
  async def get_progress(session_id: str):
98
- """Get progress updates for a specific session."""
99
  return {
100
  "session_id": session_id,
101
- "steps": progress_store.get(session_id, [])
 
102
  }
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  @app.get("/health")
106
  async def health_check():
107
  """
@@ -172,12 +308,31 @@ async def run_analysis(
172
  progress_store[session_key] = []
173
 
174
  def progress_callback(tool_name: str, status: str):
175
- """Callback to track progress"""
 
176
  progress_store[session_key].append({
177
  "tool": tool_name,
178
  "status": status,
179
  "timestamp": time.time()
180
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # Set progress callback on existing agent
183
  agent.progress_callback = progress_callback
@@ -194,6 +349,20 @@ async def run_analysis(
194
 
195
  logger.info(f"Follow-up analysis completed: {result.get('status')}")
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  # Make result JSON serializable
198
  def make_json_serializable(obj):
199
  if isinstance(obj, dict):
@@ -267,12 +436,31 @@ async def run_analysis(
267
  progress_store[session_key] = []
268
 
269
  def progress_callback(tool_name: str, status: str):
270
- """Callback to track progress"""
 
271
  progress_store[session_key].append({
272
  "tool": tool_name,
273
  "status": status,
274
  "timestamp": time.time()
275
  })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  # Set progress callback on existing agent
278
  agent.progress_callback = progress_callback
@@ -289,6 +477,20 @@ async def run_analysis(
289
 
290
  logger.info(f"Analysis completed: {result.get('status')}")
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  # Filter out non-JSON-serializable objects (like matplotlib/plotly Figures)
293
  def make_json_serializable(obj):
294
  """Recursively convert objects to JSON-serializable format."""
 
17
  load_dotenv()
18
 
19
  from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
20
+ from fastapi.responses import JSONResponse, FileResponse, StreamingResponse
21
  from fastapi.staticfiles import StaticFiles
22
  from fastapi.middleware.cors import CORSMiddleware
23
  from pydantic import BaseModel
24
+ import asyncio
25
+ import json
26
 
27
  # Import from parent package
28
  from src.orchestrator import DataScienceCopilot
 
51
  # Agent itself is stateless - no conversation memory between requests
52
  agent: Optional[DataScienceCopilot] = None
53
 
54
+ # Global progress tracking with SSE support
55
  progress_store: Dict[str, List[Dict[str, Any]]] = {}
56
 
57
+ # SSE event queues for real-time streaming
58
+ class ProgressEventManager:
59
+ """Manages SSE connections and progress events for real-time updates."""
60
+
61
+ def __init__(self):
62
+ self.active_streams: Dict[str, List[asyncio.Queue]] = {}
63
+ self.session_status: Dict[str, Dict[str, Any]] = {}
64
+
65
+ def create_stream(self, session_id: str) -> asyncio.Queue:
66
+ """Create a new SSE stream for a session."""
67
+ if session_id not in self.active_streams:
68
+ self.active_streams[session_id] = []
69
+
70
+ queue = asyncio.Queue()
71
+ self.active_streams[session_id].append(queue)
72
+ return queue
73
+
74
+ def remove_stream(self, session_id: str, queue: asyncio.Queue):
75
+ """Remove an SSE stream when client disconnects."""
76
+ if session_id in self.active_streams:
77
+ try:
78
+ self.active_streams[session_id].remove(queue)
79
+ if not self.active_streams[session_id]:
80
+ del self.active_streams[session_id]
81
+ except (ValueError, KeyError):
82
+ pass
83
+
84
+ async def send_event(self, session_id: str, event_type: str, data: Dict[str, Any]):
85
+ """Send an event to all connected clients for a session."""
86
+ if session_id not in self.active_streams:
87
+ return
88
+
89
+ # Store current status
90
+ self.session_status[session_id] = {
91
+ "type": event_type,
92
+ "data": data,
93
+ "timestamp": time.time()
94
+ }
95
+
96
+ # Send to all connected streams
97
+ dead_queues = []
98
+ for queue in self.active_streams[session_id]:
99
+ try:
100
+ await asyncio.wait_for(queue.put((event_type, data)), timeout=1.0)
101
+ except (asyncio.TimeoutError, Exception):
102
+ dead_queues.append(queue)
103
+
104
+ # Clean up dead queues
105
+ for queue in dead_queues:
106
+ self.remove_stream(session_id, queue)
107
+
108
+ def get_current_status(self, session_id: str) -> Optional[Dict[str, Any]]:
109
+ """Get the current status for a session."""
110
+ return self.session_status.get(session_id)
111
+
112
+ def clear_session(self, session_id: str):
113
+ """Clear all data for a session."""
114
+ if session_id in self.active_streams:
115
+ # Close all queues
116
+ for queue in self.active_streams[session_id]:
117
+ try:
118
+ queue.put_nowait(("complete", {}))
119
+ except:
120
+ pass
121
+ del self.active_streams[session_id]
122
+
123
+ if session_id in self.session_status:
124
+ del self.session_status[session_id]
125
+
126
+ # Global event manager
127
+ event_manager = ProgressEventManager()
128
+
129
  # Mount static files for React frontend
130
  frontend_path = Path(__file__).parent.parent.parent / "FRRONTEEEND" / "dist"
131
  if frontend_path.exists():
 
169
 
170
  @app.get("/api/progress/{session_id}")
171
  async def get_progress(session_id: str):
172
+ """Get progress updates for a specific session (legacy polling endpoint)."""
173
  return {
174
  "session_id": session_id,
175
+ "steps": progress_store.get(session_id, []),
176
+ "current": event_manager.get_current_status(session_id)
177
  }
178
 
179
 
180
+ @app.get("/api/progress/stream/{session_id}")
181
+ async def stream_progress(session_id: str):
182
+ """Stream real-time progress updates using Server-Sent Events (SSE).
183
+
184
+ This replaces the polling mechanism with a persistent connection that
185
+ receives events as they happen during workflow execution.
186
+
187
+ Events:
188
+ - tool_start: When a tool begins execution
189
+ - tool_complete: When a tool finishes successfully
190
+ - tool_error: When a tool fails
191
+ - analysis_complete: When the entire workflow finishes
192
+ - status_update: General status messages
193
+ """
194
+ async def event_generator():
195
+ queue = event_manager.create_stream(session_id)
196
+
197
+ try:
198
+ # Send initial connection event
199
+ yield f"event: connected\ndata: {{\"session_id\": \"{session_id}\"}}\n\n"
200
+
201
+ # Send current status if exists
202
+ current = event_manager.get_current_status(session_id)
203
+ if current:
204
+ yield f"event: {current['type']}\ndata: {json.dumps(current['data'])}\n\n"
205
+
206
+ # Stream events as they arrive
207
+ while True:
208
+ try:
209
+ # Wait for next event with timeout
210
+ event_type, data = await asyncio.wait_for(queue.get(), timeout=30.0)
211
+
212
+ # Check for completion signal
213
+ if event_type == "complete":
214
+ yield f"event: analysis_complete\ndata: {{\"status\": \"completed\"}}\n\n"
215
+ break
216
+
217
+ # Send the event
218
+ yield f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
219
+
220
+ except asyncio.TimeoutError:
221
+ # Send keepalive ping
222
+ yield f": keepalive\n\n"
223
+ continue
224
+
225
+ except asyncio.CancelledError:
226
+ logger.info(f"SSE stream cancelled for session {session_id}")
227
+ finally:
228
+ event_manager.remove_stream(session_id, queue)
229
+
230
+ return StreamingResponse(
231
+ event_generator(),
232
+ media_type="text/event-stream",
233
+ headers={
234
+ "Cache-Control": "no-cache",
235
+ "Connection": "keep-alive",
236
+ "X-Accel-Buffering": "no" # Disable nginx buffering
237
+ }
238
+ )
239
+
240
+
241
  @app.get("/health")
242
  async def health_check():
243
  """
 
308
  progress_store[session_key] = []
309
 
310
  def progress_callback(tool_name: str, status: str):
311
+ """Callback to track progress and send SSE events"""
312
+ # Store in legacy progress store
313
  progress_store[session_key].append({
314
  "tool": tool_name,
315
  "status": status,
316
  "timestamp": time.time()
317
  })
318
+
319
+ # Send SSE event asynchronously
320
+ event_type = "tool_start" if status == "running" else "tool_complete" if status == "completed" else "tool_error"
321
+ event_data = {
322
+ "tool": tool_name,
323
+ "status": status,
324
+ "message": f"🔧 Executing: {tool_name.replace('_', ' ').title()}" if status == "running" else
325
+ f"✓ Completed: {tool_name.replace('_', ' ').title()}" if status == "completed" else
326
+ f"❌ Failed: {tool_name.replace('_', ' ').title()}",
327
+ "timestamp": time.time()
328
+ }
329
+
330
+ # Schedule the async event send
331
+ try:
332
+ asyncio.create_task(event_manager.send_event(session_key, event_type, event_data))
333
+ except RuntimeError:
334
+ # If no event loop, we're in sync context - that's ok, legacy polling still works
335
+ pass
336
 
337
  # Set progress callback on existing agent
338
  agent.progress_callback = progress_callback
 
349
 
350
  logger.info(f"Follow-up analysis completed: {result.get('status')}")
351
 
352
+ # Send completion event via SSE
353
+ try:
354
+ asyncio.create_task(event_manager.send_event(
355
+ session_key,
356
+ "analysis_complete",
357
+ {
358
+ "status": result.get("status"),
359
+ "message": "✅ Analysis completed successfully!",
360
+ "timestamp": time.time()
361
+ }
362
+ ))
363
+ except RuntimeError:
364
+ pass
365
+
366
  # Make result JSON serializable
367
  def make_json_serializable(obj):
368
  if isinstance(obj, dict):
 
436
  progress_store[session_key] = []
437
 
438
  def progress_callback(tool_name: str, status: str):
439
+ """Callback to track progress and send SSE events"""
440
+ # Store in legacy progress store
441
  progress_store[session_key].append({
442
  "tool": tool_name,
443
  "status": status,
444
  "timestamp": time.time()
445
  })
446
+
447
+ # Send SSE event asynchronously
448
+ event_type = "tool_start" if status == "running" else "tool_complete" if status == "completed" else "tool_error"
449
+ event_data = {
450
+ "tool": tool_name,
451
+ "status": status,
452
+ "message": f"🔧 Executing: {tool_name.replace('_', ' ').title()}" if status == "running" else
453
+ f"✓ Completed: {tool_name.replace('_', ' ').title()}" if status == "completed" else
454
+ f"❌ Failed: {tool_name.replace('_', ' ').title()}",
455
+ "timestamp": time.time()
456
+ }
457
+
458
+ # Schedule the async event send
459
+ try:
460
+ asyncio.create_task(event_manager.send_event(session_key, event_type, event_data))
461
+ except RuntimeError:
462
+ # If no event loop, we're in sync context - that's ok, legacy polling still works
463
+ pass
464
 
465
  # Set progress callback on existing agent
466
  agent.progress_callback = progress_callback
 
477
 
478
  logger.info(f"Analysis completed: {result.get('status')}")
479
 
480
+ # Send completion event via SSE
481
+ try:
482
+ asyncio.create_task(event_manager.send_event(
483
+ session_key,
484
+ "analysis_complete",
485
+ {
486
+ "status": result.get("status"),
487
+ "message": "✅ Analysis completed successfully!",
488
+ "timestamp": time.time()
489
+ }
490
+ ))
491
+ except RuntimeError:
492
+ pass
493
+
494
  # Filter out non-JSON-serializable objects (like matplotlib/plotly Figures)
495
  def make_json_serializable(obj):
496
  """Recursively convert objects to JSON-serializable format."""