A-Mahla committed on
Commit
d47b6ac
·
unverified ·
1 Parent(s): 1a862ff

FIX callback (#26)

Browse files
cua2-core/src/cua2_core/services/agent_service.py CHANGED
@@ -25,7 +25,7 @@ from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketM
25
  from e2b_desktop import Sandbox, TimeoutException
26
  from fastapi import WebSocket
27
  from PIL import Image
28
- from smolagents import ActionStep, AgentImage, AgentMaxStepsError, TaskStep
29
  from starlette.websockets import WebSocketState
30
 
31
  logger = logging.getLogger(__name__)
@@ -50,7 +50,7 @@ class AgentService:
50
  self.websocket_manager: WebSocketManager = websocket_manager
51
  self.task_websockets: dict[str, WebSocket] = {}
52
  self.sandbox_service: SandboxService = sandbox_service
53
- self.last_screenshot: dict[str, AgentImage | None] = {}
54
  self._lock = asyncio.Lock()
55
  self.max_sandboxes = int(600 / num_workers)
56
  self._archival_lock_file: IO[str] | None = None
@@ -222,10 +222,7 @@ class AgentService:
222
  step_filename = f"{message_id}-1"
223
  screenshot_bytes = agent.desktop.screenshot()
224
  image = Image.open(BytesIO(screenshot_bytes))
225
- screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
226
- image.save(screenshot_path)
227
-
228
- self.last_screenshot[message_id] = image
229
 
230
  await asyncio.to_thread(
231
  agent.run,
@@ -364,32 +361,16 @@ class AgentService:
364
  ):
365
  time.sleep(3)
366
 
367
- image = self.last_screenshot[message_id]
368
- assert image is not None
369
-
370
- for previous_memory_step in (
371
- agent.memory.steps
372
- ): # Remove previous screenshots from logs for lean processing
373
- if (
374
- isinstance(previous_memory_step, ActionStep)
375
- and previous_memory_step.step_number is not None
376
- and previous_memory_step.step_number <= memory_step.step_number - 1
377
- ):
378
- previous_memory_step.observations_images = None
379
- elif isinstance(previous_memory_step, TaskStep):
380
- previous_memory_step.task_images = None
381
 
382
- memory_step.observations_images = [image.copy()]
383
-
384
- if memory_step.observations_images:
385
- image = memory_step.observations_images[0]
386
- buffered = BytesIO()
387
- image.save(buffered, format="PNG")
388
- image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
389
- del buffered
390
- del image
391
- else:
392
- image_base64 = None
393
 
394
  step = AgentStep(
395
  traceId=message_id,
@@ -428,13 +409,26 @@ class AgentService:
428
  if self.active_tasks[message_id].traceMetadata.completed:
429
  raise AgentStopException("Task not completed")
430
 
431
- step_filename = f"{message_id}-{memory_step.step_number}"
432
  screenshot_bytes = agent.desktop.screenshot()
433
  image = Image.open(BytesIO(screenshot_bytes))
434
- screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
435
- image.save(screenshot_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  del self.last_screenshot[message_id]
437
- self.last_screenshot[message_id] = image
438
 
439
  await self._agent_runner(message_id, step_callback)
440
 
 
25
  from e2b_desktop import Sandbox, TimeoutException
26
  from fastapi import WebSocket
27
  from PIL import Image
28
+ from smolagents import ActionStep, AgentMaxStepsError, TaskStep
29
  from starlette.websockets import WebSocketState
30
 
31
  logger = logging.getLogger(__name__)
 
50
  self.websocket_manager: WebSocketManager = websocket_manager
51
  self.task_websockets: dict[str, WebSocket] = {}
52
  self.sandbox_service: SandboxService = sandbox_service
53
+ self.last_screenshot: dict[str, tuple[Image.Image, str] | None] = {}
54
  self._lock = asyncio.Lock()
55
  self.max_sandboxes = int(600 / num_workers)
56
  self._archival_lock_file: IO[str] | None = None
 
222
  step_filename = f"{message_id}-1"
223
  screenshot_bytes = agent.desktop.screenshot()
224
  image = Image.open(BytesIO(screenshot_bytes))
225
+ self.last_screenshot[message_id] = (image, step_filename)
 
 
 
226
 
227
  await asyncio.to_thread(
228
  agent.run,
 
361
  ):
362
  time.sleep(3)
363
 
364
+ image, step_filename = self.last_screenshot[message_id] # type: ignore
365
+ assert image is not None and step_filename is not None
366
+ screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
367
+ image.save(screenshot_path)
 
 
 
 
 
 
 
 
 
 
368
 
369
+ buffered = BytesIO()
370
+ image.save(buffered, format="PNG")
371
+ image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
372
+ del buffered
373
+ del image
 
 
 
 
 
 
374
 
375
  step = AgentStep(
376
  traceId=message_id,
 
409
  if self.active_tasks[message_id].traceMetadata.completed:
410
  raise AgentStopException("Task not completed")
411
 
412
+ step_filename = f"{message_id}-{memory_step.step_number + 1}"
413
  screenshot_bytes = agent.desktop.screenshot()
414
  image = Image.open(BytesIO(screenshot_bytes))
415
+
416
+ for previous_memory_step in (
417
+ agent.memory.steps
418
+ ): # Remove previous screenshots from logs for lean processing
419
+ if (
420
+ isinstance(previous_memory_step, ActionStep)
421
+ and previous_memory_step.step_number is not None
422
+ and previous_memory_step.step_number <= memory_step.step_number
423
+ ):
424
+ previous_memory_step.observations_images = None
425
+ elif isinstance(previous_memory_step, TaskStep):
426
+ previous_memory_step.task_images = None
427
+
428
+ memory_step.observations_images = [image.copy()]
429
+
430
  del self.last_screenshot[message_id]
431
+ self.last_screenshot[message_id] = (image, step_filename)
432
 
433
  await self._agent_runner(message_id, step_callback)
434