Spaces:
Running
Running
A-Mahla
committed on
FIX callback (#26)
Browse files
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -25,7 +25,7 @@ from cua2_core.websocket.websocket_manager import WebSocketException, WebSocketM
|
|
| 25 |
from e2b_desktop import Sandbox, TimeoutException
|
| 26 |
from fastapi import WebSocket
|
| 27 |
from PIL import Image
|
| 28 |
-
from smolagents import ActionStep,
|
| 29 |
from starlette.websockets import WebSocketState
|
| 30 |
|
| 31 |
logger = logging.getLogger(__name__)
|
|
@@ -50,7 +50,7 @@ class AgentService:
|
|
| 50 |
self.websocket_manager: WebSocketManager = websocket_manager
|
| 51 |
self.task_websockets: dict[str, WebSocket] = {}
|
| 52 |
self.sandbox_service: SandboxService = sandbox_service
|
| 53 |
-
self.last_screenshot: dict[str,
|
| 54 |
self._lock = asyncio.Lock()
|
| 55 |
self.max_sandboxes = int(600 / num_workers)
|
| 56 |
self._archival_lock_file: IO[str] | None = None
|
|
@@ -222,10 +222,7 @@ class AgentService:
|
|
| 222 |
step_filename = f"{message_id}-1"
|
| 223 |
screenshot_bytes = agent.desktop.screenshot()
|
| 224 |
image = Image.open(BytesIO(screenshot_bytes))
|
| 225 |
-
|
| 226 |
-
image.save(screenshot_path)
|
| 227 |
-
|
| 228 |
-
self.last_screenshot[message_id] = image
|
| 229 |
|
| 230 |
await asyncio.to_thread(
|
| 231 |
agent.run,
|
|
@@ -364,32 +361,16 @@ class AgentService:
|
|
| 364 |
):
|
| 365 |
time.sleep(3)
|
| 366 |
|
| 367 |
-
image = self.last_screenshot[message_id]
|
| 368 |
-
assert image is not None
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
agent.memory.steps
|
| 372 |
-
): # Remove previous screenshots from logs for lean processing
|
| 373 |
-
if (
|
| 374 |
-
isinstance(previous_memory_step, ActionStep)
|
| 375 |
-
and previous_memory_step.step_number is not None
|
| 376 |
-
and previous_memory_step.step_number <= memory_step.step_number - 1
|
| 377 |
-
):
|
| 378 |
-
previous_memory_step.observations_images = None
|
| 379 |
-
elif isinstance(previous_memory_step, TaskStep):
|
| 380 |
-
previous_memory_step.task_images = None
|
| 381 |
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
image.save(buffered, format="PNG")
|
| 388 |
-
image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
|
| 389 |
-
del buffered
|
| 390 |
-
del image
|
| 391 |
-
else:
|
| 392 |
-
image_base64 = None
|
| 393 |
|
| 394 |
step = AgentStep(
|
| 395 |
traceId=message_id,
|
|
@@ -428,13 +409,26 @@ class AgentService:
|
|
| 428 |
if self.active_tasks[message_id].traceMetadata.completed:
|
| 429 |
raise AgentStopException("Task not completed")
|
| 430 |
|
| 431 |
-
step_filename = f"{message_id}-{memory_step.step_number}"
|
| 432 |
screenshot_bytes = agent.desktop.screenshot()
|
| 433 |
image = Image.open(BytesIO(screenshot_bytes))
|
| 434 |
-
|
| 435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
del self.last_screenshot[message_id]
|
| 437 |
-
self.last_screenshot[message_id] = image
|
| 438 |
|
| 439 |
await self._agent_runner(message_id, step_callback)
|
| 440 |
|
|
|
|
| 25 |
from e2b_desktop import Sandbox, TimeoutException
|
| 26 |
from fastapi import WebSocket
|
| 27 |
from PIL import Image
|
| 28 |
+
from smolagents import ActionStep, AgentMaxStepsError, TaskStep
|
| 29 |
from starlette.websockets import WebSocketState
|
| 30 |
|
| 31 |
logger = logging.getLogger(__name__)
|
|
|
|
| 50 |
self.websocket_manager: WebSocketManager = websocket_manager
|
| 51 |
self.task_websockets: dict[str, WebSocket] = {}
|
| 52 |
self.sandbox_service: SandboxService = sandbox_service
|
| 53 |
+
self.last_screenshot: dict[str, tuple[Image.Image, str] | None] = {}
|
| 54 |
self._lock = asyncio.Lock()
|
| 55 |
self.max_sandboxes = int(600 / num_workers)
|
| 56 |
self._archival_lock_file: IO[str] | None = None
|
|
|
|
| 222 |
step_filename = f"{message_id}-1"
|
| 223 |
screenshot_bytes = agent.desktop.screenshot()
|
| 224 |
image = Image.open(BytesIO(screenshot_bytes))
|
| 225 |
+
self.last_screenshot[message_id] = (image, step_filename)
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
await asyncio.to_thread(
|
| 228 |
agent.run,
|
|
|
|
| 361 |
):
|
| 362 |
time.sleep(3)
|
| 363 |
|
| 364 |
+
image, step_filename = self.last_screenshot[message_id] # type: ignore
|
| 365 |
+
assert image is not None and step_filename is not None
|
| 366 |
+
screenshot_path = os.path.join(agent.data_dir, f"{step_filename}.png")
|
| 367 |
+
image.save(screenshot_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
+
buffered = BytesIO()
|
| 370 |
+
image.save(buffered, format="PNG")
|
| 371 |
+
image_base64 = f"data:image/png;base64,{base64.b64encode(buffered.getvalue()).decode('utf-8')}"
|
| 372 |
+
del buffered
|
| 373 |
+
del image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
step = AgentStep(
|
| 376 |
traceId=message_id,
|
|
|
|
| 409 |
if self.active_tasks[message_id].traceMetadata.completed:
|
| 410 |
raise AgentStopException("Task not completed")
|
| 411 |
|
| 412 |
+
step_filename = f"{message_id}-{memory_step.step_number + 1}"
|
| 413 |
screenshot_bytes = agent.desktop.screenshot()
|
| 414 |
image = Image.open(BytesIO(screenshot_bytes))
|
| 415 |
+
|
| 416 |
+
for previous_memory_step in (
|
| 417 |
+
agent.memory.steps
|
| 418 |
+
): # Remove previous screenshots from logs for lean processing
|
| 419 |
+
if (
|
| 420 |
+
isinstance(previous_memory_step, ActionStep)
|
| 421 |
+
and previous_memory_step.step_number is not None
|
| 422 |
+
and previous_memory_step.step_number <= memory_step.step_number
|
| 423 |
+
):
|
| 424 |
+
previous_memory_step.observations_images = None
|
| 425 |
+
elif isinstance(previous_memory_step, TaskStep):
|
| 426 |
+
previous_memory_step.task_images = None
|
| 427 |
+
|
| 428 |
+
memory_step.observations_images = [image.copy()]
|
| 429 |
+
|
| 430 |
del self.last_screenshot[message_id]
|
| 431 |
+
self.last_screenshot[message_id] = (image, step_filename)
|
| 432 |
|
| 433 |
await self._agent_runner(message_id, step_callback)
|
| 434 |
|