Spaces:
Running
Running
Amir Mahla
commited on
Commit
·
975f40e
1
Parent(s):
3cf734e
FIX race condition
Browse files
cua2-core/src/cua2_core/models/models.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
-
import threading
|
| 4 |
from datetime import datetime
|
| 5 |
from typing import Annotated, Literal, Optional
|
| 6 |
from uuid import uuid4
|
|
@@ -269,51 +269,58 @@ class ActiveTask(BaseModel):
|
|
| 269 |
timestamp: datetime = datetime.now()
|
| 270 |
steps: list[AgentStep] = []
|
| 271 |
traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
|
| 272 |
-
_file_lock:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
|
| 274 |
@property
|
| 275 |
def trace_path(self):
|
| 276 |
"""Trace path"""
|
| 277 |
return f"data/trace-{self.message_id}-{self.model_id.replace('/', '-')}"
|
| 278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
@model_validator(mode="after")
|
| 280 |
def store_model(self):
|
| 281 |
-
"""Validate model ID"""
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
os.makedirs(self.trace_path, exist_ok=True)
|
| 285 |
-
with open(f"{self.trace_path}/tasks.json", "w") as f:
|
| 286 |
-
json.dump(
|
| 287 |
-
self.model_dump(
|
| 288 |
-
mode="json",
|
| 289 |
-
exclude={"_file_locks"},
|
| 290 |
-
context={"actions_as_json": True, "image_as_path": True},
|
| 291 |
-
),
|
| 292 |
-
f,
|
| 293 |
-
indent=2,
|
| 294 |
-
)
|
| 295 |
return self
|
| 296 |
|
| 297 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
"""Update step"""
|
| 299 |
-
with self.
|
| 300 |
if int(step.stepId) <= len(self.steps):
|
| 301 |
self.steps[int(step.stepId) - 1] = step
|
| 302 |
else:
|
| 303 |
self.steps.append(step)
|
| 304 |
self.traceMetadata.numberOfSteps = len(self.steps)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
exclude={"_file_locks"},
|
| 310 |
-
context={"actions_as_json": True, "image_as_path": True},
|
| 311 |
-
),
|
| 312 |
-
f,
|
| 313 |
-
indent=2,
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
def update_trace_metadata(
|
| 317 |
self,
|
| 318 |
step_input_tokens_used: int | None = None,
|
| 319 |
step_output_tokens_used: int | None = None,
|
|
@@ -327,7 +334,7 @@ class ActiveTask(BaseModel):
|
|
| 327 |
user_evaluation: Literal["success", "failed", "not_evaluated"] | None = None,
|
| 328 |
):
|
| 329 |
"""Update trace metadata"""
|
| 330 |
-
with self.
|
| 331 |
if step_input_tokens_used is not None:
|
| 332 |
self.traceMetadata.inputTokensUsed += step_input_tokens_used
|
| 333 |
if step_output_tokens_used is not None:
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
import json
|
| 3 |
import os
|
|
|
|
| 4 |
from datetime import datetime
|
| 5 |
from typing import Annotated, Literal, Optional
|
| 6 |
from uuid import uuid4
|
|
|
|
| 269 |
timestamp: datetime = datetime.now()
|
| 270 |
steps: list[AgentStep] = []
|
| 271 |
traceMetadata: AgentTraceMetadata = AgentTraceMetadata()
|
| 272 |
+
_file_lock: asyncio.Lock | None = PrivateAttr(default=None)
|
| 273 |
+
|
| 274 |
+
def _get_lock(self) -> asyncio.Lock:
|
| 275 |
+
"""Get or create the async lock (lazy initialization)"""
|
| 276 |
+
if self._file_lock is None:
|
| 277 |
+
self._file_lock = asyncio.Lock()
|
| 278 |
+
return self._file_lock
|
| 279 |
|
| 280 |
@property
|
| 281 |
def trace_path(self):
|
| 282 |
"""Trace path"""
|
| 283 |
return f"data/trace-{self.message_id}-{self.model_id.replace('/', '-')}"
|
| 284 |
|
| 285 |
+
def _write_to_file_sync(self):
|
| 286 |
+
"""Synchronous file write helper (used in async context via to_thread)"""
|
| 287 |
+
self.traceMetadata.traceId = self.message_id
|
| 288 |
+
os.makedirs(self.trace_path, exist_ok=True)
|
| 289 |
+
with open(f"{self.trace_path}/tasks.json", "w") as f:
|
| 290 |
+
json.dump(
|
| 291 |
+
self.model_dump(
|
| 292 |
+
mode="json",
|
| 293 |
+
exclude={"_file_lock", "_lock_initialized"},
|
| 294 |
+
context={"actions_as_json": True, "image_as_path": True},
|
| 295 |
+
),
|
| 296 |
+
f,
|
| 297 |
+
indent=2,
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
@model_validator(mode="after")
|
| 301 |
def store_model(self):
|
| 302 |
+
"""Validate model ID - creates directory, but file write is deferred to async method"""
|
| 303 |
+
self.traceMetadata.traceId = self.message_id
|
| 304 |
+
os.makedirs(self.trace_path, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
return self
|
| 306 |
|
| 307 |
+
async def save_to_file(self):
|
| 308 |
+
"""Async method to save task data to file"""
|
| 309 |
+
async with self._get_lock():
|
| 310 |
+
await asyncio.to_thread(self._write_to_file_sync)
|
| 311 |
+
|
| 312 |
+
async def update_step(self, step: AgentStep):
|
| 313 |
"""Update step"""
|
| 314 |
+
async with self._get_lock():
|
| 315 |
if int(step.stepId) <= len(self.steps):
|
| 316 |
self.steps[int(step.stepId) - 1] = step
|
| 317 |
else:
|
| 318 |
self.steps.append(step)
|
| 319 |
self.traceMetadata.numberOfSteps = len(self.steps)
|
| 320 |
+
# Use to_thread for file I/O to avoid blocking
|
| 321 |
+
await asyncio.to_thread(self._write_to_file_sync)
|
| 322 |
+
|
| 323 |
+
async def update_trace_metadata(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
self,
|
| 325 |
step_input_tokens_used: int | None = None,
|
| 326 |
step_output_tokens_used: int | None = None,
|
|
|
|
| 334 |
user_evaluation: Literal["success", "failed", "not_evaluated"] | None = None,
|
| 335 |
):
|
| 336 |
"""Update trace metadata"""
|
| 337 |
+
async with self._get_lock():
|
| 338 |
if step_input_tokens_used is not None:
|
| 339 |
self.traceMetadata.inputTokensUsed += step_input_tokens_used
|
| 340 |
if step_output_tokens_used is not None:
|
cua2-core/src/cua2_core/routes/routes.py
CHANGED
|
@@ -74,7 +74,7 @@ async def update_trace_step(
|
|
| 74 |
):
|
| 75 |
"""Update a specific step in a trace (e.g., update step evaluation)"""
|
| 76 |
try:
|
| 77 |
-
agent_service.update_trace_step(
|
| 78 |
trace_id=trace_id,
|
| 79 |
step_id=step_id,
|
| 80 |
step_evaluation=request.step_evaluation,
|
|
@@ -99,7 +99,7 @@ async def update_trace_evaluation(
|
|
| 99 |
):
|
| 100 |
"""Update the user evaluation for a trace (overall task feedback)"""
|
| 101 |
try:
|
| 102 |
-
agent_service.update_trace_evaluation(
|
| 103 |
trace_id=trace_id,
|
| 104 |
user_evaluation=request.user_evaluation,
|
| 105 |
)
|
|
|
|
| 74 |
):
|
| 75 |
"""Update a specific step in a trace (e.g., update step evaluation)"""
|
| 76 |
try:
|
| 77 |
+
await agent_service.update_trace_step(
|
| 78 |
trace_id=trace_id,
|
| 79 |
step_id=step_id,
|
| 80 |
step_evaluation=request.step_evaluation,
|
|
|
|
| 99 |
):
|
| 100 |
"""Update the user evaluation for a trace (overall task feedback)"""
|
| 101 |
try:
|
| 102 |
+
await agent_service.update_trace_evaluation(
|
| 103 |
trace_id=trace_id,
|
| 104 |
user_evaluation=request.user_evaluation,
|
| 105 |
)
|
cua2-core/src/cua2_core/services/agent_service.py
CHANGED
|
@@ -104,9 +104,13 @@ class AgentService:
|
|
| 104 |
"""
|
| 105 |
Update the archival service with current active task IDs.
|
| 106 |
Should be called whenever tasks are added or removed.
|
|
|
|
|
|
|
| 107 |
"""
|
| 108 |
if self.archival_service.is_alive():
|
| 109 |
-
|
|
|
|
|
|
|
| 110 |
|
| 111 |
async def create_id_and_sandbox(self, websocket: WebSocket) -> str:
|
| 112 |
"""Create a new ID and sandbox"""
|
|
@@ -174,8 +178,8 @@ class AgentService:
|
|
| 174 |
self.active_tasks[trace_id] = active_task
|
| 175 |
self.last_screenshot[trace_id] = None
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
|
| 180 |
asyncio.create_task(self._agent_processing(trace_id))
|
| 181 |
|
|
@@ -351,13 +355,13 @@ class AgentService:
|
|
| 351 |
|
| 352 |
novnc_active = False
|
| 353 |
|
| 354 |
-
self.active_tasks[message_id].update_trace_metadata(
|
| 355 |
final_state=final_state,
|
| 356 |
completed=True,
|
| 357 |
)
|
| 358 |
|
| 359 |
if message_id in self.active_tasks:
|
| 360 |
-
self.active_tasks[message_id].
|
| 361 |
|
| 362 |
# Clean up
|
| 363 |
async with self._lock:
|
|
@@ -370,8 +374,8 @@ class AgentService:
|
|
| 370 |
if message_id in self.last_screenshot:
|
| 371 |
del self.last_screenshot[message_id]
|
| 372 |
|
| 373 |
-
|
| 374 |
-
|
| 375 |
|
| 376 |
# Always release sandbox back to the pool, even if it's still in "creating" state
|
| 377 |
# This handles cases where acquire_sandbox was called but sandbox never became ready
|
|
@@ -469,14 +473,23 @@ class AgentService:
|
|
| 469 |
step_evaluation="neutral",
|
| 470 |
)
|
| 471 |
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
)
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
|
| 481 |
websocket = self.task_websockets.get(message_id)
|
| 482 |
if websocket and websocket.client_state == WebSocketState.CONNECTED:
|
|
@@ -529,7 +542,7 @@ class AgentService:
|
|
| 529 |
# Re-raise to ensure error is logged
|
| 530 |
raise
|
| 531 |
|
| 532 |
-
def update_trace_step(
|
| 533 |
self,
|
| 534 |
trace_id: str,
|
| 535 |
step_id: str,
|
|
@@ -559,7 +572,8 @@ class AgentService:
|
|
| 559 |
step_index = int(step_id) - 1
|
| 560 |
if 0 <= step_index < len(active_task.steps):
|
| 561 |
active_task.steps[step_index].step_evaluation = step_evaluation
|
| 562 |
-
active_task.update_step(active_task.steps[step_index])
|
|
|
|
| 563 |
else:
|
| 564 |
raise ValueError(f"Step {step_id} not found in trace")
|
| 565 |
except (ValueError, TypeError) as e:
|
|
@@ -602,7 +616,7 @@ class AgentService:
|
|
| 602 |
except (ValueError, KeyError, TypeError) as e:
|
| 603 |
raise ValueError(f"Error processing step update: {e}")
|
| 604 |
|
| 605 |
-
def update_trace_evaluation(
|
| 606 |
self,
|
| 607 |
trace_id: str,
|
| 608 |
user_evaluation: Literal["success", "failed", "not_evaluated"],
|
|
@@ -622,7 +636,7 @@ class AgentService:
|
|
| 622 |
|
| 623 |
if active_task:
|
| 624 |
# Task is still active
|
| 625 |
-
active_task.update_trace_metadata(user_evaluation=user_evaluation)
|
| 626 |
else:
|
| 627 |
# Task is not active, try to load from file
|
| 628 |
data_dir = "data"
|
|
@@ -657,7 +671,7 @@ class AgentService:
|
|
| 657 |
async def stop_task(self, trace_id: str):
|
| 658 |
"""Stop a task"""
|
| 659 |
if trace_id in self.active_tasks:
|
| 660 |
-
self.active_tasks[trace_id].update_trace_metadata(
|
| 661 |
completed=True,
|
| 662 |
)
|
| 663 |
|
|
@@ -687,7 +701,7 @@ class AgentService:
|
|
| 687 |
try:
|
| 688 |
# Mark task as completed to stop the agent (if task exists)
|
| 689 |
if message_id in self.active_tasks:
|
| 690 |
-
self.active_tasks[message_id].update_trace_metadata(
|
| 691 |
completed=True,
|
| 692 |
)
|
| 693 |
logger.info(
|
|
|
|
| 104 |
"""
|
| 105 |
Update the archival service with current active task IDs.
|
| 106 |
Should be called whenever tasks are added or removed.
|
| 107 |
+
Note: This should be called while holding self._lock to ensure consistent snapshot.
|
| 108 |
+
The archival service update itself is fast and non-blocking.
|
| 109 |
"""
|
| 110 |
if self.archival_service.is_alive():
|
| 111 |
+
# Create a snapshot of active task IDs (should be called with lock held)
|
| 112 |
+
active_task_ids = set(self.active_tasks.keys())
|
| 113 |
+
self.archival_service.update_active_tasks(active_task_ids)
|
| 114 |
|
| 115 |
async def create_id_and_sandbox(self, websocket: WebSocket) -> str:
|
| 116 |
"""Create a new ID and sandbox"""
|
|
|
|
| 178 |
self.active_tasks[trace_id] = active_task
|
| 179 |
self.last_screenshot[trace_id] = None
|
| 180 |
|
| 181 |
+
# Update archival service with new active task (while holding lock)
|
| 182 |
+
self._update_archival_active_tasks()
|
| 183 |
|
| 184 |
asyncio.create_task(self._agent_processing(trace_id))
|
| 185 |
|
|
|
|
| 355 |
|
| 356 |
novnc_active = False
|
| 357 |
|
| 358 |
+
await self.active_tasks[message_id].update_trace_metadata(
|
| 359 |
final_state=final_state,
|
| 360 |
completed=True,
|
| 361 |
)
|
| 362 |
|
| 363 |
if message_id in self.active_tasks:
|
| 364 |
+
await self.active_tasks[message_id].save_to_file()
|
| 365 |
|
| 366 |
# Clean up
|
| 367 |
async with self._lock:
|
|
|
|
| 374 |
if message_id in self.last_screenshot:
|
| 375 |
del self.last_screenshot[message_id]
|
| 376 |
|
| 377 |
+
# Update archival service after task removal (while holding lock)
|
| 378 |
+
self._update_archival_active_tasks()
|
| 379 |
|
| 380 |
# Always release sandbox back to the pool, even if it's still in "creating" state
|
| 381 |
# This handles cases where acquire_sandbox was called but sandbox never became ready
|
|
|
|
| 473 |
step_evaluation="neutral",
|
| 474 |
)
|
| 475 |
|
| 476 |
+
# Schedule async operations in the event loop (callback runs in worker thread)
|
| 477 |
+
future1 = asyncio.run_coroutine_threadsafe(
|
| 478 |
+
self.active_tasks[message_id].update_trace_metadata(
|
| 479 |
+
step_input_tokens_used=memory_step.token_usage.input_tokens,
|
| 480 |
+
step_output_tokens_used=memory_step.token_usage.output_tokens,
|
| 481 |
+
step_duration=memory_step.timing.duration,
|
| 482 |
+
step_numberOfSteps=1,
|
| 483 |
+
),
|
| 484 |
+
loop,
|
| 485 |
)
|
| 486 |
+
future2 = asyncio.run_coroutine_threadsafe(
|
| 487 |
+
self.active_tasks[message_id].update_step(step),
|
| 488 |
+
loop,
|
| 489 |
+
)
|
| 490 |
+
# Wait for both to complete
|
| 491 |
+
future1.result()
|
| 492 |
+
future2.result()
|
| 493 |
|
| 494 |
websocket = self.task_websockets.get(message_id)
|
| 495 |
if websocket and websocket.client_state == WebSocketState.CONNECTED:
|
|
|
|
| 542 |
# Re-raise to ensure error is logged
|
| 543 |
raise
|
| 544 |
|
| 545 |
+
async def update_trace_step(
|
| 546 |
self,
|
| 547 |
trace_id: str,
|
| 548 |
step_id: str,
|
|
|
|
| 572 |
step_index = int(step_id) - 1
|
| 573 |
if 0 <= step_index < len(active_task.steps):
|
| 574 |
active_task.steps[step_index].step_evaluation = step_evaluation
|
| 575 |
+
await active_task.update_step(active_task.steps[step_index])
|
| 576 |
+
return active_task.steps[step_index]
|
| 577 |
else:
|
| 578 |
raise ValueError(f"Step {step_id} not found in trace")
|
| 579 |
except (ValueError, TypeError) as e:
|
|
|
|
| 616 |
except (ValueError, KeyError, TypeError) as e:
|
| 617 |
raise ValueError(f"Error processing step update: {e}")
|
| 618 |
|
| 619 |
+
async def update_trace_evaluation(
|
| 620 |
self,
|
| 621 |
trace_id: str,
|
| 622 |
user_evaluation: Literal["success", "failed", "not_evaluated"],
|
|
|
|
| 636 |
|
| 637 |
if active_task:
|
| 638 |
# Task is still active
|
| 639 |
+
await active_task.update_trace_metadata(user_evaluation=user_evaluation)
|
| 640 |
else:
|
| 641 |
# Task is not active, try to load from file
|
| 642 |
data_dir = "data"
|
|
|
|
| 671 |
async def stop_task(self, trace_id: str):
|
| 672 |
"""Stop a task"""
|
| 673 |
if trace_id in self.active_tasks:
|
| 674 |
+
await self.active_tasks[trace_id].update_trace_metadata(
|
| 675 |
completed=True,
|
| 676 |
)
|
| 677 |
|
|
|
|
| 701 |
try:
|
| 702 |
# Mark task as completed to stop the agent (if task exists)
|
| 703 |
if message_id in self.active_tasks:
|
| 704 |
+
await self.active_tasks[message_id].update_trace_metadata(
|
| 705 |
completed=True,
|
| 706 |
)
|
| 707 |
logger.info(
|
cua2-core/src/cua2_core/services/sandbox_service.py
CHANGED
|
@@ -158,10 +158,14 @@ class SandboxService:
|
|
| 158 |
asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
|
| 159 |
return
|
| 160 |
|
| 161 |
-
# Check capacity before adding
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
| 163 |
print(
|
| 164 |
-
f"Pool at capacity
|
|
|
|
| 165 |
)
|
| 166 |
asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
|
| 167 |
return
|
|
|
|
| 158 |
asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
|
| 159 |
return
|
| 160 |
|
| 161 |
+
# Check total capacity before adding (sandboxes + other pending creations)
|
| 162 |
+
# Note: We already removed this session_hash from pending, so we check
|
| 163 |
+
# if adding it to sandboxes would exceed capacity
|
| 164 |
+
total_count = len(self.sandboxes) + len(self.pending)
|
| 165 |
+
if total_count >= self.max_sandboxes:
|
| 166 |
print(
|
| 167 |
+
f"Pool at capacity ({total_count}/{self.max_sandboxes}), "
|
| 168 |
+
f"killing newly created sandbox for {session_hash}"
|
| 169 |
)
|
| 170 |
asyncio.create_task(self._kill_sandbox_safe(desktop, session_hash))
|
| 171 |
return
|
cua2-core/tests/test_routes.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from unittest.mock import Mock
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
|
|
@@ -15,7 +15,9 @@ def mock_agent_service():
|
|
| 15 |
"""Fixture to create a mocked AgentService"""
|
| 16 |
service = Mock(spec=AgentService)
|
| 17 |
service.active_tasks = {}
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
return service
|
| 20 |
|
| 21 |
|
|
@@ -112,8 +114,8 @@ class TestUpdateTraceStep:
|
|
| 112 |
step_id = "1"
|
| 113 |
request_data = {"step_evaluation": "like"}
|
| 114 |
|
| 115 |
-
# Mock the service method to succeed
|
| 116 |
-
|
| 117 |
|
| 118 |
response = client.patch(
|
| 119 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
|
@@ -136,8 +138,6 @@ class TestUpdateTraceStep:
|
|
| 136 |
step_id = "2"
|
| 137 |
request_data = {"step_evaluation": "dislike"}
|
| 138 |
|
| 139 |
-
mock_agent_service.update_trace_step.return_value = None
|
| 140 |
-
|
| 141 |
response = client.patch(
|
| 142 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 143 |
)
|
|
@@ -154,8 +154,6 @@ class TestUpdateTraceStep:
|
|
| 154 |
step_id = "3"
|
| 155 |
request_data = {"step_evaluation": "neutral"}
|
| 156 |
|
| 157 |
-
mock_agent_service.update_trace_step.return_value = None
|
| 158 |
-
|
| 159 |
response = client.patch(
|
| 160 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 161 |
)
|
|
@@ -186,8 +184,8 @@ class TestUpdateTraceStep:
|
|
| 186 |
request_data = {"step_evaluation": "like"}
|
| 187 |
|
| 188 |
# Mock the service to raise ValueError
|
| 189 |
-
mock_agent_service.update_trace_step
|
| 190 |
-
"Invalid step_id format"
|
| 191 |
)
|
| 192 |
|
| 193 |
response = client.patch(
|
|
@@ -204,8 +202,8 @@ class TestUpdateTraceStep:
|
|
| 204 |
request_data = {"step_evaluation": "like"}
|
| 205 |
|
| 206 |
# Mock the service to raise FileNotFoundError
|
| 207 |
-
mock_agent_service.update_trace_step
|
| 208 |
-
"Trace not found"
|
| 209 |
)
|
| 210 |
|
| 211 |
response = client.patch(
|
|
@@ -222,8 +220,8 @@ class TestUpdateTraceStep:
|
|
| 222 |
request_data = {"step_evaluation": "like"}
|
| 223 |
|
| 224 |
# Mock the service to raise ValueError for step not found
|
| 225 |
-
mock_agent_service.update_trace_step
|
| 226 |
-
"Step 999 not found in trace"
|
| 227 |
)
|
| 228 |
|
| 229 |
response = client.patch(
|
|
@@ -251,8 +249,6 @@ class TestUpdateTraceStep:
|
|
| 251 |
step_id = "1"
|
| 252 |
request_data = {"step_evaluation": "like"}
|
| 253 |
|
| 254 |
-
mock_agent_service.update_trace_step.return_value = None
|
| 255 |
-
|
| 256 |
response = client.patch(
|
| 257 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 258 |
)
|
|
@@ -269,8 +265,6 @@ class TestUpdateTraceStep:
|
|
| 269 |
step_id = "1"
|
| 270 |
request_data = {"step_evaluation": "like"}
|
| 271 |
|
| 272 |
-
mock_agent_service.update_trace_step.return_value = None
|
| 273 |
-
|
| 274 |
response = client.patch(
|
| 275 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 276 |
)
|
|
@@ -294,8 +288,7 @@ class TestRoutesIntegration:
|
|
| 294 |
|
| 295 |
def test_update_step_endpoint_available(self, client, mock_agent_service):
|
| 296 |
"""Test that update step endpoint is available"""
|
| 297 |
-
|
| 298 |
-
|
| 299 |
response = client.patch(
|
| 300 |
"/traces/test/steps/1", json={"step_evaluation": "like"}
|
| 301 |
)
|
|
|
|
| 1 |
+
from unittest.mock import AsyncMock, Mock
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
from cua2_core.models.models import AvailableModelsResponse, UpdateStepResponse
|
|
|
|
| 15 |
"""Fixture to create a mocked AgentService"""
|
| 16 |
service = Mock(spec=AgentService)
|
| 17 |
service.active_tasks = {}
|
| 18 |
+
# update_trace_step is now async, so use AsyncMock
|
| 19 |
+
service.update_trace_step = AsyncMock(return_value=None)
|
| 20 |
+
service.update_trace_evaluation = AsyncMock(return_value=None)
|
| 21 |
return service
|
| 22 |
|
| 23 |
|
|
|
|
| 114 |
step_id = "1"
|
| 115 |
request_data = {"step_evaluation": "like"}
|
| 116 |
|
| 117 |
+
# Mock the service method to succeed (already set up as AsyncMock in fixture)
|
| 118 |
+
pass
|
| 119 |
|
| 120 |
response = client.patch(
|
| 121 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
|
|
|
| 138 |
step_id = "2"
|
| 139 |
request_data = {"step_evaluation": "dislike"}
|
| 140 |
|
|
|
|
|
|
|
| 141 |
response = client.patch(
|
| 142 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 143 |
)
|
|
|
|
| 154 |
step_id = "3"
|
| 155 |
request_data = {"step_evaluation": "neutral"}
|
| 156 |
|
|
|
|
|
|
|
| 157 |
response = client.patch(
|
| 158 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 159 |
)
|
|
|
|
| 184 |
request_data = {"step_evaluation": "like"}
|
| 185 |
|
| 186 |
# Mock the service to raise ValueError
|
| 187 |
+
mock_agent_service.update_trace_step = AsyncMock(
|
| 188 |
+
side_effect=ValueError("Invalid step_id format")
|
| 189 |
)
|
| 190 |
|
| 191 |
response = client.patch(
|
|
|
|
| 202 |
request_data = {"step_evaluation": "like"}
|
| 203 |
|
| 204 |
# Mock the service to raise FileNotFoundError
|
| 205 |
+
mock_agent_service.update_trace_step = AsyncMock(
|
| 206 |
+
side_effect=FileNotFoundError("Trace not found")
|
| 207 |
)
|
| 208 |
|
| 209 |
response = client.patch(
|
|
|
|
| 220 |
request_data = {"step_evaluation": "like"}
|
| 221 |
|
| 222 |
# Mock the service to raise ValueError for step not found
|
| 223 |
+
mock_agent_service.update_trace_step = AsyncMock(
|
| 224 |
+
side_effect=ValueError("Step 999 not found in trace")
|
| 225 |
)
|
| 226 |
|
| 227 |
response = client.patch(
|
|
|
|
| 249 |
step_id = "1"
|
| 250 |
request_data = {"step_evaluation": "like"}
|
| 251 |
|
|
|
|
|
|
|
| 252 |
response = client.patch(
|
| 253 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 254 |
)
|
|
|
|
| 265 |
step_id = "1"
|
| 266 |
request_data = {"step_evaluation": "like"}
|
| 267 |
|
|
|
|
|
|
|
| 268 |
response = client.patch(
|
| 269 |
f"/traces/{trace_id}/steps/{step_id}", json=request_data
|
| 270 |
)
|
|
|
|
| 288 |
|
| 289 |
def test_update_step_endpoint_available(self, client, mock_agent_service):
|
| 290 |
"""Test that update step endpoint is available"""
|
| 291 |
+
# Mock is already set up as AsyncMock in fixture
|
|
|
|
| 292 |
response = client.patch(
|
| 293 |
"/traces/test/steps/1", json={"step_evaluation": "like"}
|
| 294 |
)
|