Pulastya B commited on
Commit
e93629e
·
1 Parent(s): 3c36cfe

Fix JSON serialization error for numpy types in SSE stream

Browse files
Files changed (1) hide show
  1. src/api/app.py +22 -3
src/api/app.py CHANGED
@@ -23,6 +23,7 @@ from fastapi.middleware.cors import CORSMiddleware
23
  from pydantic import BaseModel
24
  import asyncio
25
  import json
 
26
 
27
  # Import from parent package
28
  from src.orchestrator import DataScienceCopilot
@@ -32,6 +33,24 @@ from src.progress_manager import progress_manager
32
  logging.basicConfig(level=logging.INFO)
33
  logger = logging.getLogger(__name__)
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  # Initialize FastAPI
36
  app = FastAPI(
37
  title="Data Science Agent API",
@@ -207,13 +226,13 @@ async def stream_progress(session_id: str):
207
  'session_id': session_id
208
  }
209
  print(f"[SSE] SENDING connection event to client")
210
- yield f"data: {json.dumps(connection_event)}\n\n"
211
 
212
  # Send any existing history first (for reconnections)
213
  history = progress_manager.get_history(session_id)
214
  print(f"[SSE] Sending {len(history[-10:])} history events")
215
  for event in history[-10:]: # Send last 10 events
216
- yield f"data: {json.dumps(event)}\n\n"
217
 
218
  print(f"[SSE] Starting event stream loop for session {session_id}")
219
 
@@ -222,7 +241,7 @@ async def stream_progress(session_id: str):
222
  if not queue.empty():
223
  event = queue.get_nowait()
224
  print(f"[SSE] GOT event from queue: {event.get('type')}")
225
- yield f"data: {json.dumps(event)}\n\n"
226
 
227
  # Check if analysis is complete
228
  if event.get('type') == 'analysis_complete':
 
23
  from pydantic import BaseModel
24
  import asyncio
25
  import json
26
+ import numpy as np
27
 
28
  # Import from parent package
29
  from src.orchestrator import DataScienceCopilot
 
33
  logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger(__name__)
35
 
36
+ # JSON serializer that handles numpy types
37
+ def safe_json_dumps(obj):
38
+ """Convert object to JSON string, handling numpy types."""
39
+ def convert(o):
40
+ if isinstance(o, (np.integer, np.int64, np.int32)):
41
+ return int(o)
42
+ elif isinstance(o, (np.floating, np.float64, np.float32)):
43
+ return float(o)
44
+ elif isinstance(o, np.ndarray):
45
+ return o.tolist()
46
+ elif isinstance(o, dict):
47
+ return {k: convert(v) for k, v in o.items()}
48
+ elif isinstance(o, (list, tuple)):
49
+ return [convert(item) for item in o]
50
+ return o
51
+
52
+ return json.dumps(convert(obj))
53
+
54
  # Initialize FastAPI
55
  app = FastAPI(
56
  title="Data Science Agent API",
 
226
  'session_id': session_id
227
  }
228
  print(f"[SSE] SENDING connection event to client")
229
+ yield f"data: {safe_json_dumps(connection_event)}\n\n"
230
 
231
  # Send any existing history first (for reconnections)
232
  history = progress_manager.get_history(session_id)
233
  print(f"[SSE] Sending {len(history[-10:])} history events")
234
  for event in history[-10:]: # Send last 10 events
235
+ yield f"data: {safe_json_dumps(event)}\n\n"
236
 
237
  print(f"[SSE] Starting event stream loop for session {session_id}")
238
 
 
241
  if not queue.empty():
242
  event = queue.get_nowait()
243
  print(f"[SSE] GOT event from queue: {event.get('type')}")
244
+ yield f"data: {safe_json_dumps(event)}\n\n"
245
 
246
  # Check if analysis is complete
247
  if event.get('type') == 'analysis_complete':