Amir Mahla commited on
Commit
c5fe3f3
·
1 Parent(s): 8f218a6

FIX sandbox limit reached to often

Browse files
cua2-core/src/cua2_core/app.py CHANGED
@@ -21,13 +21,14 @@ async def lifespan(app: FastAPI):
21
  if not os.getenv("HF_TOKEN"):
22
  raise ValueError("HF_TOKEN is not set")
23
 
24
- num_workers = int(os.getenv("NUM_WORKERS", "1"))
 
25
 
26
  websocket_manager = WebSocketManager()
27
 
28
- sandbox_service = SandboxService()
29
 
30
- agent_service = AgentService(websocket_manager, sandbox_service, num_workers)
31
 
32
  # Start periodic cleanup of stuck sandboxes
33
  sandbox_service.start_periodic_cleanup()
 
21
  if not os.getenv("HF_TOKEN"):
22
  raise ValueError("HF_TOKEN is not set")
23
 
24
+ num_workers = int(os.getenv("NUM_WORKERS", "12"))
25
+ max_sandboxes = int(600 / num_workers)
26
 
27
  websocket_manager = WebSocketManager()
28
 
29
+ sandbox_service = SandboxService(max_sandboxes=max_sandboxes)
30
 
31
+ agent_service = AgentService(websocket_manager, sandbox_service, max_sandboxes)
32
 
33
  # Start periodic cleanup of stuck sandboxes
34
  sandbox_service.start_periodic_cleanup()
cua2-core/src/cua2_core/routes/websocket.py CHANGED
@@ -21,9 +21,10 @@ async def websocket_endpoint(websocket: WebSocket):
21
  await websocket_manager.connect(websocket)
22
 
23
  try:
24
- welcome_message = HeartbeatEvent(
25
- uuid=await agent_service.create_id_and_sandbox(websocket)
26
- )
 
27
  await websocket_manager.send_message(welcome_message, websocket)
28
 
29
  # Keep the connection alive and wait for messages
 
21
  await websocket_manager.connect(websocket)
22
 
23
  try:
24
+ # Create ID and acquire sandbox - this adds uuid to task_websockets
25
+ # If this fails, the finally block will clean up via cleanup_tasks_for_websocket
26
+ uuid = await agent_service.create_id_and_sandbox(websocket)
27
+ welcome_message = HeartbeatEvent(uuid=uuid)
28
  await websocket_manager.send_message(welcome_message, websocket)
29
 
30
  # Keep the connection alive and wait for messages
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -45,7 +45,7 @@ class AgentService:
45
  self,
46
  websocket_manager: WebSocketManager,
47
  sandbox_service: SandboxService,
48
- num_workers: int,
49
  ):
50
  self.active_tasks: dict[str, ActiveTask] = {}
51
  self.websocket_manager: WebSocketManager = websocket_manager
@@ -53,7 +53,7 @@ class AgentService:
53
  self.sandbox_service: SandboxService = sandbox_service
54
  self.last_screenshot: dict[str, tuple[Image.Image, str] | None] = {}
55
  self._lock = asyncio.Lock()
56
- self.max_sandboxes = int(600 / num_workers)
57
  self._archival_lock_file: IO[str] | None = None
58
 
59
  # Initialize archival service in dedicated process
@@ -296,141 +296,161 @@ class AgentService:
296
  # Update archival service after task removal
297
  self._update_archival_active_tasks()
298
 
299
- # Release sandbox back to the pool
300
- if sandbox:
 
301
  await self.sandbox_service.release_sandbox(message_id)
 
 
 
 
302
 
303
  async def _agent_processing(
304
  self,
305
  message_id: str,
306
  ):
307
  """Process the user task with the appropriate agent"""
 
 
 
308
 
309
- # Set up log file for this task
310
- active_task = self.active_tasks[message_id]
311
-
312
- # Ensure the directory exists
313
- os.makedirs(active_task.trace_path, exist_ok=True)
314
 
315
- # Capture the event loop reference in the async context
316
- # This will be used in the callback to safely schedule coroutines from the worker thread
317
- loop = asyncio.get_running_loop()
318
 
319
- def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
320
- assert memory_step.step_number is not None
321
 
322
- if memory_step.step_number > agent.max_steps:
323
- raise AgentStopException("Max steps reached")
324
 
325
- if self.active_tasks[message_id].traceMetadata.completed:
326
- raise AgentStopException("Task not completed")
327
 
328
- model_output = (
329
- memory_step.model_output_message.content
330
- if memory_step.model_output_message
331
- else None
332
- )
333
- if isinstance(memory_step.error, AgentMaxStepsError):
334
- model_output = memory_step.action_output
335
-
336
- thought = (
337
- model_output.split("```")[0].replace("\nAction:\n", "")
338
- if model_output
339
- and (
340
- memory_step.error is None
341
- or isinstance(memory_step.error, AgentMaxStepsError)
342
  )
343
- else None
344
- )
345
-
346
- if model_output is not None:
347
- action_sequence = model_output.split("```")[1]
348
- else:
349
- action_sequence = (
350
- """The task failed due to an error""" # TODO: To Handle in front
 
 
 
351
  )
352
 
353
- agent_actions = (
354
- AgentAction.from_function_calls(parse_function_call(action_sequence))
355
- if action_sequence
356
- else None
357
- )
358
-
359
- # if not (
360
- # agent_actions is not None
361
- # and not any(action.function_name == "wait" for action in agent_actions)
362
- # ):
363
- time.sleep(3)
364
-
365
- image, step_filename = self.last_screenshot[message_id] # type: ignore
366
- assert image is not None and step_filename is not None
367
- screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
368
- image.save(screenshot_path)
369
-
370
- buffered = BytesIO()
371
- image.save(buffered, format="PNG")
372
- image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
373
- del buffered
374
- del image
375
-
376
- if memory_step.token_usage is not None:
377
- step = AgentStep(
378
- traceId=message_id,
379
- stepId=str(memory_step.step_number),
380
- image=image_base64,
381
- thought=thought,
382
- actions=agent_actions,
383
- error=memory_step.error.message if memory_step.error else None,
384
- duration=memory_step.timing.duration,
385
- inputTokensUsed=memory_step.token_usage.input_tokens,
386
- outputTokensUsed=memory_step.token_usage.output_tokens,
387
- step_evaluation="neutral",
388
- )
389
 
390
- self.active_tasks[message_id].update_trace_metadata(
391
- step_input_tokens_used=memory_step.token_usage.input_tokens,
392
- step_output_tokens_used=memory_step.token_usage.output_tokens,
393
- step_duration=memory_step.timing.duration,
394
- step_numberOfSteps=1,
 
395
  )
396
 
397
- self.active_tasks[message_id].update_step(step)
398
-
399
- websocket = self.task_websockets.get(message_id)
400
- if websocket and websocket.client_state == WebSocketState.CONNECTED:
401
- future = asyncio.run_coroutine_threadsafe(
402
- self.websocket_manager.send_agent_progress(
403
- step=step,
404
- metadata=self.active_tasks[message_id].traceMetadata,
405
- websocket=websocket,
406
- ),
407
- loop,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  )
409
- future.result()
410
-
411
- if self.active_tasks[message_id].traceMetadata.completed:
412
- raise AgentStopException("Task not completed")
413
 
414
- step_filename = f"{message_id}-{memory_step.step_number + 1}"
415
- screenshot_bytes = agent.desktop.screenshot()
416
- original_image = Image.open(BytesIO(screenshot_bytes))
417
- image = compress_image_to_max_size(original_image, max_size_kb=500)
418
- del original_image
419
-
420
- for previous_memory_step in (
421
- agent.memory.steps
422
- ): # Remove previous screenshots from logs for lean processing
423
- if isinstance(previous_memory_step, ActionStep):
424
- previous_memory_step.observations_images = None
425
- elif isinstance(previous_memory_step, TaskStep):
426
- previous_memory_step.task_images = None
427
-
428
- memory_step.observations_images = [image.copy()]
429
-
430
- del self.last_screenshot[message_id]
431
- self.last_screenshot[message_id] = (image, step_filename)
432
 
433
- await self._agent_runner(message_id, step_callback)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
  def update_trace_step(
436
  self,
@@ -568,6 +588,9 @@ class AgentService:
568
  """
569
  Clean up all tasks associated with a disconnected websocket.
570
  This will stop the tasks and release their sandboxes.
 
 
 
571
  """
572
  tasks_to_cleanup = []
573
 
@@ -579,11 +602,13 @@ class AgentService:
579
  logger.info(
580
  f"Marking task {message_id} for cleanup due to websocket disconnect"
581
  )
 
 
582
 
583
  # Cleanup each task
584
  for message_id in tasks_to_cleanup:
585
  try:
586
- # Mark task as completed to stop the agent
587
  if message_id in self.active_tasks:
588
  self.active_tasks[message_id].update_trace_metadata(
589
  completed=True,
@@ -592,7 +617,9 @@ class AgentService:
592
  f"Stopped task {message_id} due to websocket disconnect"
593
  )
594
 
595
- # Release the sandbox immediately
 
 
596
  await self.sandbox_service.release_sandbox(message_id)
597
  logger.info(
598
  f"Released sandbox for task {message_id} due to websocket disconnect"
 
45
  self,
46
  websocket_manager: WebSocketManager,
47
  sandbox_service: SandboxService,
48
+ max_sandboxes: int,
49
  ):
50
  self.active_tasks: dict[str, ActiveTask] = {}
51
  self.websocket_manager: WebSocketManager = websocket_manager
 
53
  self.sandbox_service: SandboxService = sandbox_service
54
  self.last_screenshot: dict[str, tuple[Image.Image, str] | None] = {}
55
  self._lock = asyncio.Lock()
56
+ self.max_sandboxes = max_sandboxes
57
  self._archival_lock_file: IO[str] | None = None
58
 
59
  # Initialize archival service in dedicated process
 
296
  # Update archival service after task removal
297
  self._update_archival_active_tasks()
298
 
299
+ # Always release sandbox back to the pool, even if it's still in "creating" state
300
+ # This handles cases where acquire_sandbox was called but sandbox never became ready
301
+ try:
302
  await self.sandbox_service.release_sandbox(message_id)
303
+ except Exception as e:
304
+ logger.error(
305
+ f"Error releasing sandbox for {message_id}: {e}", exc_info=True
306
+ )
307
 
308
  async def _agent_processing(
309
  self,
310
  message_id: str,
311
  ):
312
  """Process the user task with the appropriate agent"""
313
+ try:
314
+ # Set up log file for this task
315
+ active_task = self.active_tasks[message_id]
316
 
317
+ # Ensure the directory exists
318
+ os.makedirs(active_task.trace_path, exist_ok=True)
 
 
 
319
 
320
+ # Capture the event loop reference in the async context
321
+ # This will be used in the callback to safely schedule coroutines from the worker thread
322
+ loop = asyncio.get_running_loop()
323
 
324
+ def step_callback(memory_step: ActionStep, agent: E2BVisionAgent):
325
+ assert memory_step.step_number is not None
326
 
327
+ if memory_step.step_number > agent.max_steps:
328
+ raise AgentStopException("Max steps reached")
329
 
330
+ if self.active_tasks[message_id].traceMetadata.completed:
331
+ raise AgentStopException("Task not completed")
332
 
333
+ model_output = (
334
+ memory_step.model_output_message.content
335
+ if memory_step.model_output_message
336
+ else None
 
 
 
 
 
 
 
 
 
 
337
  )
338
+ if isinstance(memory_step.error, AgentMaxStepsError):
339
+ model_output = memory_step.action_output
340
+
341
+ thought = (
342
+ model_output.split("```")[0].replace("\nAction:\n", "")
343
+ if model_output
344
+ and (
345
+ memory_step.error is None
346
+ or isinstance(memory_step.error, AgentMaxStepsError)
347
+ )
348
+ else None
349
  )
350
 
351
+ if model_output is not None:
352
+ action_sequence = model_output.split("```")[1]
353
+ else:
354
+ action_sequence = """The task failed due to an error""" # TODO: To Handle in front
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
+ agent_actions = (
357
+ AgentAction.from_function_calls(
358
+ parse_function_call(action_sequence)
359
+ )
360
+ if action_sequence
361
+ else None
362
  )
363
 
364
+ # if not (
365
+ # agent_actions is not None
366
+ # and not any(action.function_name == "wait" for action in agent_actions)
367
+ # ):
368
+ time.sleep(3)
369
+
370
+ image, step_filename = self.last_screenshot[message_id] # type: ignore
371
+ assert image is not None and step_filename is not None
372
+ screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
373
+ image.save(screenshot_path)
374
+
375
+ buffered = BytesIO()
376
+ image.save(buffered, format="PNG")
377
+ image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
378
+ del buffered
379
+ del image
380
+
381
+ if memory_step.token_usage is not None:
382
+ step = AgentStep(
383
+ traceId=message_id,
384
+ stepId=str(memory_step.step_number),
385
+ image=image_base64,
386
+ thought=thought,
387
+ actions=agent_actions,
388
+ error=memory_step.error.message if memory_step.error else None,
389
+ duration=memory_step.timing.duration,
390
+ inputTokensUsed=memory_step.token_usage.input_tokens,
391
+ outputTokensUsed=memory_step.token_usage.output_tokens,
392
+ step_evaluation="neutral",
393
  )
 
 
 
 
394
 
395
+ self.active_tasks[message_id].update_trace_metadata(
396
+ step_input_tokens_used=memory_step.token_usage.input_tokens,
397
+ step_output_tokens_used=memory_step.token_usage.output_tokens,
398
+ step_duration=memory_step.timing.duration,
399
+ step_numberOfSteps=1,
400
+ )
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
+ self.active_tasks[message_id].update_step(step)
403
+
404
+ websocket = self.task_websockets.get(message_id)
405
+ if websocket and websocket.client_state == WebSocketState.CONNECTED:
406
+ future = asyncio.run_coroutine_threadsafe(
407
+ self.websocket_manager.send_agent_progress(
408
+ step=step,
409
+ metadata=self.active_tasks[message_id].traceMetadata,
410
+ websocket=websocket,
411
+ ),
412
+ loop,
413
+ )
414
+ future.result()
415
+
416
+ if self.active_tasks[message_id].traceMetadata.completed:
417
+ raise AgentStopException("Task not completed")
418
+
419
+ step_filename = f"{message_id}-{memory_step.step_number + 1}"
420
+ screenshot_bytes = agent.desktop.screenshot()
421
+ original_image = Image.open(BytesIO(screenshot_bytes))
422
+ image = compress_image_to_max_size(original_image, max_size_kb=500)
423
+ del original_image
424
+
425
+ for previous_memory_step in (
426
+ agent.memory.steps
427
+ ): # Remove previous screenshots from logs for lean processing
428
+ if isinstance(previous_memory_step, ActionStep):
429
+ previous_memory_step.observations_images = None
430
+ elif isinstance(previous_memory_step, TaskStep):
431
+ previous_memory_step.task_images = None
432
+
433
+ memory_step.observations_images = [image.copy()]
434
+
435
+ del self.last_screenshot[message_id]
436
+ self.last_screenshot[message_id] = (image, step_filename)
437
+
438
+ await self._agent_runner(message_id, step_callback)
439
+ except Exception as e:
440
+ # If _agent_processing fails before _agent_runner is called,
441
+ # we still need to release the sandbox that was acquired in create_id_and_sandbox
442
+ logger.error(
443
+ f"Error in _agent_processing for {message_id}: {e}", exc_info=True
444
+ )
445
+ try:
446
+ await self.sandbox_service.release_sandbox(message_id)
447
+ except Exception as release_error:
448
+ logger.error(
449
+ f"Error releasing sandbox in _agent_processing cleanup for {message_id}: {release_error}",
450
+ exc_info=True,
451
+ )
452
+ # Re-raise to ensure error is logged
453
+ raise
454
 
455
  def update_trace_step(
456
  self,
 
588
  """
589
  Clean up all tasks associated with a disconnected websocket.
590
  This will stop the tasks and release their sandboxes.
591
+
592
+ Note: This also cleans up sandboxes that were acquired in create_id_and_sandbox
593
+ but never had a task created (e.g., if websocket disconnects before process_user_task).
594
  """
595
  tasks_to_cleanup = []
596
 
 
602
  logger.info(
603
  f"Marking task {message_id} for cleanup due to websocket disconnect"
604
  )
605
+ # Remove from task_websockets immediately to prevent double cleanup
606
+ del self.task_websockets[message_id]
607
 
608
  # Cleanup each task
609
  for message_id in tasks_to_cleanup:
610
  try:
611
+ # Mark task as completed to stop the agent (if task exists)
612
  if message_id in self.active_tasks:
613
  self.active_tasks[message_id].update_trace_metadata(
614
  completed=True,
 
617
  f"Stopped task {message_id} due to websocket disconnect"
618
  )
619
 
620
+ # Always release the sandbox, even if no task was created
621
+ # This handles the case where create_id_and_sandbox succeeded but
622
+ # process_user_task was never called
623
  await self.sandbox_service.release_sandbox(message_id)
624
  logger.info(
625
  f"Released sandbox for task {message_id} due to websocket disconnect"
cua2-core/src/cua2_core/services/agent_utils/instruction_utils/pregenerated_instructions.json CHANGED
@@ -321,4 +321,4 @@
321
  "Navigate to weather.com and enter 'New York” in the search bar to retrieve the current weather conditions for that location.",
322
  "Create a new directory named 'my_folder' in the current working directory.",
323
  "Open LibreOffice Writer, create a new blank document, type the following text: 'Automation Task Completed Successfully.', then save the document as 'Automation_Report.odt” on your desktop."
324
- ]
 
321
  "Navigate to weather.com and enter 'New York” in the search bar to retrieve the current weather conditions for that location.",
322
  "Create a new directory named 'my_folder' in the current working directory.",
323
  "Open LibreOffice Writer, create a new blank document, type the following text: 'Automation Task Completed Successfully.', then save the document as 'Automation_Report.odt” on your desktop."
324
+ ]