Spaces:
Running
Running
Commit ·
960792d
1
Parent(s): 7edb225
refactor: simplify tool call validation and make interrupts cancel tool execution
Browse files
- Handle finish_reason=length by dropping truncated tool calls before
they enter context
- Single json.loads per tool call instead of 4 redundant parses
- Parsed args flow through to approval/execution without re-parsing
- Replace recover_malformed_tool_calls (90 lines) with inline validation
in agent loop — bad calls get error results, good calls execute
- Make tool execution cancellable: asyncio.wait races the gather task against
the cancel event, so a Ctrl+C/frontend interrupt stops immediately instead
of waiting for all tools to finish
- Keep _patch_dangling_tool_calls as the only safety net in
get_messages() for any remaining edge cases
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- agent/context_manager/manager.py +3 -99
- agent/core/agent_loop.py +97 -115
agent/context_manager/manager.py
CHANGED
|
@@ -133,13 +133,10 @@ class ContextManager:
|
|
| 133 |
def get_messages(self) -> list[Message]:
|
| 134 |
"""Get all messages for sending to LLM.
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
errors or cancellations and would cause the LLM API to reject the
|
| 140 |
-
request.
|
| 141 |
"""
|
| 142 |
-
self.recover_malformed_tool_calls()
|
| 143 |
self._patch_dangling_tool_calls()
|
| 144 |
return self.items
|
| 145 |
|
|
@@ -163,99 +160,6 @@ class ContextManager:
|
|
| 163 |
tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls
|
| 164 |
]
|
| 165 |
|
| 166 |
-
def recover_malformed_tool_calls(self) -> set[str]:
|
| 167 |
-
"""Sanitize malformed tool_call arguments and inject error results.
|
| 168 |
-
|
| 169 |
-
Handles two classes of corruption:
|
| 170 |
-
- **Empty/missing IDs**: Stripped from the assistant message entirely
|
| 171 |
-
(common when streaming is interrupted mid-tool-call).
|
| 172 |
-
- **Malformed JSON arguments**: Replaced with ``"{}"`` and an error
|
| 173 |
-
tool-result is injected asking the agent to retry.
|
| 174 |
-
|
| 175 |
-
This method is idempotent — safe to call from both the agent loop
|
| 176 |
-
(before tool execution) and from :meth:`get_messages` (safety net).
|
| 177 |
-
|
| 178 |
-
Returns:
|
| 179 |
-
Set of tool_call IDs that had malformed arguments.
|
| 180 |
-
"""
|
| 181 |
-
import json
|
| 182 |
-
|
| 183 |
-
malformed_ids: set[str] = set()
|
| 184 |
-
|
| 185 |
-
for msg in self.items:
|
| 186 |
-
if getattr(msg, "role", None) != "assistant":
|
| 187 |
-
continue
|
| 188 |
-
tool_calls = getattr(msg, "tool_calls", None)
|
| 189 |
-
if not tool_calls:
|
| 190 |
-
continue
|
| 191 |
-
self._normalize_tool_calls(msg)
|
| 192 |
-
|
| 193 |
-
# 1. Strip tool_calls with empty/missing IDs (cannot be repaired)
|
| 194 |
-
valid_tcs = []
|
| 195 |
-
for tc in msg.tool_calls:
|
| 196 |
-
if not getattr(tc, "id", None):
|
| 197 |
-
logger.warning(
|
| 198 |
-
"Stripping tool_call with empty ID (name=%s) — likely interrupted stream",
|
| 199 |
-
getattr(tc.function, "name", "?"),
|
| 200 |
-
)
|
| 201 |
-
continue
|
| 202 |
-
valid_tcs.append(tc)
|
| 203 |
-
if len(valid_tcs) != len(msg.tool_calls):
|
| 204 |
-
msg.tool_calls = valid_tcs or None
|
| 205 |
-
|
| 206 |
-
if not msg.tool_calls:
|
| 207 |
-
continue
|
| 208 |
-
|
| 209 |
-
# 2. Fix malformed JSON arguments
|
| 210 |
-
for tc in msg.tool_calls:
|
| 211 |
-
try:
|
| 212 |
-
json.loads(tc.function.arguments)
|
| 213 |
-
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
| 214 |
-
logger.warning(
|
| 215 |
-
"Malformed arguments for tool_call %s (%s): %s",
|
| 216 |
-
tc.id,
|
| 217 |
-
tc.function.name,
|
| 218 |
-
e,
|
| 219 |
-
)
|
| 220 |
-
tc.function.arguments = "{}"
|
| 221 |
-
malformed_ids.add(tc.id)
|
| 222 |
-
|
| 223 |
-
if not malformed_ids:
|
| 224 |
-
return malformed_ids
|
| 225 |
-
|
| 226 |
-
# 3. Inject error results for malformed calls that don't have one yet
|
| 227 |
-
answered_ids = {
|
| 228 |
-
getattr(m, "tool_call_id", None)
|
| 229 |
-
for m in self.items
|
| 230 |
-
if getattr(m, "role", None) == "tool"
|
| 231 |
-
}
|
| 232 |
-
for msg in self.items:
|
| 233 |
-
if getattr(msg, "role", None) != "assistant":
|
| 234 |
-
continue
|
| 235 |
-
tool_calls = getattr(msg, "tool_calls", None)
|
| 236 |
-
if not tool_calls:
|
| 237 |
-
continue
|
| 238 |
-
for tc in msg.tool_calls:
|
| 239 |
-
if tc.id in malformed_ids and tc.id not in answered_ids:
|
| 240 |
-
self.items.append(
|
| 241 |
-
Message(
|
| 242 |
-
role="tool",
|
| 243 |
-
content=(
|
| 244 |
-
f"ERROR: Your tool call to '{tc.function.name}' had malformed "
|
| 245 |
-
f"JSON arguments and was NOT executed. This usually happens "
|
| 246 |
-
f"when the content is too large and gets truncated. "
|
| 247 |
-
f"Please retry with smaller content — for 'write', split the "
|
| 248 |
-
f"file into multiple smaller writes using 'edit' to build up "
|
| 249 |
-
f"the file incrementally."
|
| 250 |
-
),
|
| 251 |
-
tool_call_id=tc.id,
|
| 252 |
-
name=tc.function.name,
|
| 253 |
-
)
|
| 254 |
-
)
|
| 255 |
-
answered_ids.add(tc.id)
|
| 256 |
-
|
| 257 |
-
return malformed_ids
|
| 258 |
-
|
| 259 |
def _patch_dangling_tool_calls(self) -> None:
|
| 260 |
"""Add stub tool results for any tool_calls that lack a matching result.
|
| 261 |
|
|
|
|
| 133 |
def get_messages(self) -> list[Message]:
|
| 134 |
"""Get all messages for sending to LLM.
|
| 135 |
|
| 136 |
+
Patches any dangling tool_calls (assistant messages with tool_calls
|
| 137 |
+
that have no matching tool-result message) so the LLM API doesn't
|
| 138 |
+
reject the request.
|
|
|
|
|
|
|
| 139 |
"""
|
|
|
|
| 140 |
self._patch_dangling_tool_calls()
|
| 141 |
return self.items
|
| 142 |
|
|
|
|
| 160 |
tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls
|
| 161 |
]
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
def _patch_dangling_tool_calls(self) -> None:
|
| 164 |
"""Add stub tool results for any tool_calls that lack a matching result.
|
| 165 |
|
agent/core/agent_loop.py
CHANGED
|
@@ -261,6 +261,7 @@ class Handlers:
|
|
| 261 |
full_content = ""
|
| 262 |
tool_calls_acc: dict[int, dict] = {}
|
| 263 |
token_count = 0
|
|
|
|
| 264 |
|
| 265 |
async for chunk in response:
|
| 266 |
# ── Check cancellation during streaming ──
|
|
@@ -276,6 +277,8 @@ class Handlers:
|
|
| 276 |
continue
|
| 277 |
|
| 278 |
delta = choice.delta
|
|
|
|
|
|
|
| 279 |
|
| 280 |
# Stream text deltas to the frontend
|
| 281 |
if delta.content:
|
|
@@ -316,17 +319,15 @@ class Handlers:
|
|
| 316 |
# ── Stream finished — reconstruct full message ───────
|
| 317 |
content = full_content or None
|
| 318 |
|
| 319 |
-
#
|
| 320 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
tool_calls: list[ToolCall] = []
|
| 322 |
for idx in sorted(tool_calls_acc.keys()):
|
| 323 |
tc_data = tool_calls_acc[idx]
|
| 324 |
-
if not tc_data["id"]:
|
| 325 |
-
logger.warning(
|
| 326 |
-
"Dropping tool_call with empty ID (name=%s) — likely interrupted stream",
|
| 327 |
-
tc_data["function"]["name"],
|
| 328 |
-
)
|
| 329 |
-
continue
|
| 330 |
tool_calls.append(
|
| 331 |
ToolCall(
|
| 332 |
id=tc_data["id"],
|
|
@@ -351,7 +352,23 @@ class Handlers:
|
|
| 351 |
final_response = content
|
| 352 |
break
|
| 353 |
|
| 354 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
assistant_msg = Message(
|
| 356 |
role="assistant",
|
| 357 |
content=content,
|
|
@@ -359,79 +376,49 @@ class Handlers:
|
|
| 359 |
)
|
| 360 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 361 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
# ── Cancellation check: before tool execution ──
|
| 363 |
if session.is_cancelled:
|
| 364 |
break
|
| 365 |
|
| 366 |
-
#
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
# For each malformed tool_call, emit a synthetic tool_call +
|
| 371 |
-
# tool_output-error pair so the frontend has a matching
|
| 372 |
-
# dynamic-tool part instead of an orphan error.
|
| 373 |
-
for tc in tool_calls:
|
| 374 |
-
if tc.id not in malformed_ids:
|
| 375 |
-
continue
|
| 376 |
-
tool_name = tc.function.name
|
| 377 |
-
try:
|
| 378 |
-
tool_args = json.loads(tc.function.arguments)
|
| 379 |
-
except (json.JSONDecodeError, TypeError, ValueError):
|
| 380 |
-
tool_args = {}
|
| 381 |
-
|
| 382 |
-
await session.send_event(
|
| 383 |
-
Event(
|
| 384 |
-
event_type="tool_call",
|
| 385 |
-
data={
|
| 386 |
-
"tool": tool_name,
|
| 387 |
-
"arguments": tool_args,
|
| 388 |
-
"tool_call_id": tc.id,
|
| 389 |
-
},
|
| 390 |
-
)
|
| 391 |
-
)
|
| 392 |
-
await session.send_event(
|
| 393 |
-
Event(
|
| 394 |
-
event_type="tool_output",
|
| 395 |
-
data={
|
| 396 |
-
"tool": tool_name,
|
| 397 |
-
"tool_call_id": tc.id,
|
| 398 |
-
"output": "Malformed tool call — see error in context.",
|
| 399 |
-
"success": False,
|
| 400 |
-
},
|
| 401 |
-
)
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
# Separate tools into those requiring approval and those that don't
|
| 405 |
-
approval_required_tools = []
|
| 406 |
-
non_approval_tools = []
|
| 407 |
-
|
| 408 |
-
for tc in tool_calls:
|
| 409 |
-
if tc.id in malformed_ids:
|
| 410 |
-
continue
|
| 411 |
-
tool_name = tc.function.name
|
| 412 |
-
try:
|
| 413 |
-
tool_args = json.loads(tc.function.arguments)
|
| 414 |
-
except (json.JSONDecodeError, TypeError) as e:
|
| 415 |
-
logger.warning(f"Malformed tool arguments for {tool_name}: {e}")
|
| 416 |
-
tool_args = {}
|
| 417 |
-
|
| 418 |
if _needs_approval(tool_name, tool_args, session.config):
|
| 419 |
-
approval_required_tools.append(tc)
|
| 420 |
else:
|
| 421 |
-
non_approval_tools.append(tc)
|
|
|
|
| 422 |
# Execute non-approval tools (in parallel when possible)
|
| 423 |
if non_approval_tools:
|
| 424 |
-
# 1.
|
| 425 |
parsed_tools: list[
|
| 426 |
-
tuple[
|
| 427 |
] = []
|
| 428 |
-
for tc in non_approval_tools:
|
| 429 |
-
tool_name = tc.function.name
|
| 430 |
-
try:
|
| 431 |
-
tool_args = json.loads(tc.function.arguments)
|
| 432 |
-
except (json.JSONDecodeError, TypeError):
|
| 433 |
-
tool_args = {}
|
| 434 |
-
|
| 435 |
args_valid, error_msg = _validate_tool_args(tool_args)
|
| 436 |
parsed_tools.append(
|
| 437 |
(tc, tool_name, tool_args, args_valid, error_msg)
|
|
@@ -451,14 +438,14 @@ class Handlers:
|
|
| 451 |
)
|
| 452 |
)
|
| 453 |
|
| 454 |
-
# 3. Execute all valid tools in parallel
|
| 455 |
async def _exec_tool(
|
| 456 |
-
tc:
|
| 457 |
name: str,
|
| 458 |
args: dict,
|
| 459 |
valid: bool,
|
| 460 |
err: str,
|
| 461 |
-
) -> tuple[
|
| 462 |
if not valid:
|
| 463 |
return (tc, name, args, err, False)
|
| 464 |
out, ok = await session.tool_router.call_tool(
|
|
@@ -466,13 +453,30 @@ class Handlers:
|
|
| 466 |
)
|
| 467 |
return (tc, name, args, out, ok)
|
| 468 |
|
| 469 |
-
|
| 470 |
*[
|
| 471 |
_exec_tool(tc, name, args, valid, err)
|
| 472 |
for tc, name, args, valid, err in parsed_tools
|
| 473 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
)
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
# 4. Record results and send outputs (order preserved)
|
| 477 |
for tc, tool_name, tool_args, output, success in results:
|
| 478 |
tool_msg = Message(
|
|
@@ -495,56 +499,34 @@ class Handlers:
|
|
| 495 |
)
|
| 496 |
)
|
| 497 |
|
| 498 |
-
# ── Cancellation check: after tool execution ──
|
| 499 |
-
if session.is_cancelled:
|
| 500 |
-
break
|
| 501 |
-
|
| 502 |
# If there are tools requiring approval, ask for batch approval
|
| 503 |
if approval_required_tools:
|
| 504 |
# Prepare batch approval data
|
| 505 |
tools_data = []
|
| 506 |
-
for tc in approval_required_tools:
|
| 507 |
-
tool_name = tc.function.name
|
| 508 |
-
try:
|
| 509 |
-
tool_args = json.loads(tc.function.arguments)
|
| 510 |
-
except (json.JSONDecodeError, TypeError):
|
| 511 |
-
tool_args = {}
|
| 512 |
-
|
| 513 |
# Resolve sandbox file paths for hf_jobs scripts so the
|
| 514 |
# frontend can display & edit the actual file content.
|
| 515 |
-
if tool_name == "hf_jobs" and isinstance(
|
| 516 |
-
tool_args.get("script"), str
|
| 517 |
-
):
|
| 518 |
from agent.tools.sandbox_tool import resolve_sandbox_script
|
| 519 |
-
|
| 520 |
sandbox = getattr(session, "sandbox", None)
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
if content:
|
| 525 |
-
tool_args = {**tool_args, "script": content}
|
| 526 |
-
|
| 527 |
-
tools_data.append(
|
| 528 |
-
{
|
| 529 |
-
"tool": tool_name,
|
| 530 |
-
"arguments": tool_args,
|
| 531 |
-
"tool_call_id": tc.id,
|
| 532 |
-
}
|
| 533 |
-
)
|
| 534 |
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
|
|
|
|
| 544 |
|
| 545 |
-
# Store all approval-requiring tools
|
| 546 |
session.pending_approval = {
|
| 547 |
-
"tool_calls": approval_required_tools,
|
| 548 |
}
|
| 549 |
|
| 550 |
# Return early - wait for EXEC_APPROVAL operation
|
|
|
|
| 261 |
full_content = ""
|
| 262 |
tool_calls_acc: dict[int, dict] = {}
|
| 263 |
token_count = 0
|
| 264 |
+
finish_reason = None
|
| 265 |
|
| 266 |
async for chunk in response:
|
| 267 |
# ── Check cancellation during streaming ──
|
|
|
|
| 277 |
continue
|
| 278 |
|
| 279 |
delta = choice.delta
|
| 280 |
+
if choice.finish_reason:
|
| 281 |
+
finish_reason = choice.finish_reason
|
| 282 |
|
| 283 |
# Stream text deltas to the frontend
|
| 284 |
if delta.content:
|
|
|
|
| 319 |
# ── Stream finished — reconstruct full message ───────
|
| 320 |
content = full_content or None
|
| 321 |
|
| 322 |
+
# If output was truncated, all tool call args are garbage
|
| 323 |
+
if finish_reason == "length" and tool_calls_acc:
|
| 324 |
+
logger.warning("Output truncated (finish_reason=length) — dropping tool calls")
|
| 325 |
+
tool_calls_acc.clear()
|
| 326 |
+
|
| 327 |
+
# Build tool_calls list from accumulated deltas
|
| 328 |
tool_calls: list[ToolCall] = []
|
| 329 |
for idx in sorted(tool_calls_acc.keys()):
|
| 330 |
tc_data = tool_calls_acc[idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
tool_calls.append(
|
| 332 |
ToolCall(
|
| 333 |
id=tc_data["id"],
|
|
|
|
| 352 |
final_response = content
|
| 353 |
break
|
| 354 |
|
| 355 |
+
# Validate tool call args (one json.loads per call, once)
|
| 356 |
+
# and split into good vs bad
|
| 357 |
+
good_tools: list[tuple[ToolCall, str, dict]] = []
|
| 358 |
+
bad_tools: list[ToolCall] = []
|
| 359 |
+
for tc in tool_calls:
|
| 360 |
+
try:
|
| 361 |
+
args = json.loads(tc.function.arguments)
|
| 362 |
+
good_tools.append((tc, tc.function.name, args))
|
| 363 |
+
except (json.JSONDecodeError, TypeError, ValueError):
|
| 364 |
+
logger.warning(
|
| 365 |
+
"Malformed arguments for tool_call %s (%s) — skipping",
|
| 366 |
+
tc.id, tc.function.name,
|
| 367 |
+
)
|
| 368 |
+
tc.function.arguments = "{}"
|
| 369 |
+
bad_tools.append(tc)
|
| 370 |
+
|
| 371 |
+
# Add assistant message with all tool calls to context
|
| 372 |
assistant_msg = Message(
|
| 373 |
role="assistant",
|
| 374 |
content=content,
|
|
|
|
| 376 |
)
|
| 377 |
session.context_manager.add_message(assistant_msg, token_count)
|
| 378 |
|
| 379 |
+
# Add error results for bad tool calls so the LLM
|
| 380 |
+
# knows what happened and can retry differently
|
| 381 |
+
for tc in bad_tools:
|
| 382 |
+
error_msg = (
|
| 383 |
+
f"ERROR: Tool call to '{tc.function.name}' had malformed JSON "
|
| 384 |
+
f"arguments and was NOT executed. Retry with smaller content — "
|
| 385 |
+
f"for 'write', split into multiple smaller writes using 'edit'."
|
| 386 |
+
)
|
| 387 |
+
session.context_manager.add_message(Message(
|
| 388 |
+
role="tool",
|
| 389 |
+
content=error_msg,
|
| 390 |
+
tool_call_id=tc.id,
|
| 391 |
+
name=tc.function.name,
|
| 392 |
+
))
|
| 393 |
+
await session.send_event(Event(
|
| 394 |
+
event_type="tool_call",
|
| 395 |
+
data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id},
|
| 396 |
+
))
|
| 397 |
+
await session.send_event(Event(
|
| 398 |
+
event_type="tool_output",
|
| 399 |
+
data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False},
|
| 400 |
+
))
|
| 401 |
+
|
| 402 |
# ── Cancellation check: before tool execution ──
|
| 403 |
if session.is_cancelled:
|
| 404 |
break
|
| 405 |
|
| 406 |
+
# Separate good tools into approval-required vs auto-execute
|
| 407 |
+
approval_required_tools: list[tuple[ToolCall, str, dict]] = []
|
| 408 |
+
non_approval_tools: list[tuple[ToolCall, str, dict]] = []
|
| 409 |
+
for tc, tool_name, tool_args in good_tools:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
if _needs_approval(tool_name, tool_args, session.config):
|
| 411 |
+
approval_required_tools.append((tc, tool_name, tool_args))
|
| 412 |
else:
|
| 413 |
+
non_approval_tools.append((tc, tool_name, tool_args))
|
| 414 |
+
|
| 415 |
# Execute non-approval tools (in parallel when possible)
|
| 416 |
if non_approval_tools:
|
| 417 |
+
# 1. Validate args upfront
|
| 418 |
parsed_tools: list[
|
| 419 |
+
tuple[ToolCall, str, dict, bool, str]
|
| 420 |
] = []
|
| 421 |
+
for tc, tool_name, tool_args in non_approval_tools:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
args_valid, error_msg = _validate_tool_args(tool_args)
|
| 423 |
parsed_tools.append(
|
| 424 |
(tc, tool_name, tool_args, args_valid, error_msg)
|
|
|
|
| 438 |
)
|
| 439 |
)
|
| 440 |
|
| 441 |
+
# 3. Execute all valid tools in parallel, cancellable
|
| 442 |
async def _exec_tool(
|
| 443 |
+
tc: ToolCall,
|
| 444 |
name: str,
|
| 445 |
args: dict,
|
| 446 |
valid: bool,
|
| 447 |
err: str,
|
| 448 |
+
) -> tuple[ToolCall, str, dict, str, bool]:
|
| 449 |
if not valid:
|
| 450 |
return (tc, name, args, err, False)
|
| 451 |
out, ok = await session.tool_router.call_tool(
|
|
|
|
| 453 |
)
|
| 454 |
return (tc, name, args, out, ok)
|
| 455 |
|
| 456 |
+
gather_task = asyncio.ensure_future(asyncio.gather(
|
| 457 |
*[
|
| 458 |
_exec_tool(tc, name, args, valid, err)
|
| 459 |
for tc, name, args, valid, err in parsed_tools
|
| 460 |
]
|
| 461 |
+
))
|
| 462 |
+
cancel_task = asyncio.ensure_future(session._cancelled.wait())
|
| 463 |
+
|
| 464 |
+
done, _ = await asyncio.wait(
|
| 465 |
+
[gather_task, cancel_task],
|
| 466 |
+
return_when=asyncio.FIRST_COMPLETED,
|
| 467 |
)
|
| 468 |
|
| 469 |
+
if cancel_task in done:
|
| 470 |
+
gather_task.cancel()
|
| 471 |
+
try:
|
| 472 |
+
await gather_task
|
| 473 |
+
except asyncio.CancelledError:
|
| 474 |
+
pass
|
| 475 |
+
break
|
| 476 |
+
|
| 477 |
+
cancel_task.cancel()
|
| 478 |
+
results = gather_task.result()
|
| 479 |
+
|
| 480 |
# 4. Record results and send outputs (order preserved)
|
| 481 |
for tc, tool_name, tool_args, output, success in results:
|
| 482 |
tool_msg = Message(
|
|
|
|
| 499 |
)
|
| 500 |
)
|
| 501 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
# If there are tools requiring approval, ask for batch approval
|
| 503 |
if approval_required_tools:
|
| 504 |
# Prepare batch approval data
|
| 505 |
tools_data = []
|
| 506 |
+
for tc, tool_name, tool_args in approval_required_tools:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
# Resolve sandbox file paths for hf_jobs scripts so the
|
| 508 |
# frontend can display & edit the actual file content.
|
| 509 |
+
if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
|
|
|
|
|
|
|
| 510 |
from agent.tools.sandbox_tool import resolve_sandbox_script
|
|
|
|
| 511 |
sandbox = getattr(session, "sandbox", None)
|
| 512 |
+
resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"])
|
| 513 |
+
if resolved:
|
| 514 |
+
tool_args = {**tool_args, "script": resolved}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
+
tools_data.append({
|
| 517 |
+
"tool": tool_name,
|
| 518 |
+
"arguments": tool_args,
|
| 519 |
+
"tool_call_id": tc.id,
|
| 520 |
+
})
|
| 521 |
+
|
| 522 |
+
await session.send_event(Event(
|
| 523 |
+
event_type="approval_required",
|
| 524 |
+
data={"tools": tools_data, "count": len(tools_data)},
|
| 525 |
+
))
|
| 526 |
|
| 527 |
+
# Store all approval-requiring tools (ToolCall objects for execution)
|
| 528 |
session.pending_approval = {
|
| 529 |
+
"tool_calls": [tc for tc, _, _ in approval_required_tools],
|
| 530 |
}
|
| 531 |
|
| 532 |
# Return early - wait for EXEC_APPROVAL operation
|