akseljoonas HF Staff Claude Opus 4.6 committed on
Commit
960792d
·
1 Parent(s): 7edb225

refactor: simplify tool call validation and make interrupts cancel tool execution

Browse files

- Handle finish_reason=length by dropping truncated tool calls before
  they enter the context
- Parse each tool call's arguments with a single json.loads instead of
  4 redundant parses
- Pass the parsed args through to approval/execution without re-parsing
- Replace recover_malformed_tool_calls (90 lines) with inline validation
  in the agent loop — bad calls get error results, good calls execute
- Make tool execution cancellable: asyncio.wait races the gather against
  the cancel event, so a Ctrl+C/frontend interrupt stops immediately
  instead of waiting for all tools to finish
- Keep _patch_dangling_tool_calls as the only safety net in
  get_messages() for any remaining edge cases

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

agent/context_manager/manager.py CHANGED
@@ -133,13 +133,10 @@ class ContextManager:
133
  def get_messages(self) -> list[Message]:
134
  """Get all messages for sending to LLM.
135
 
136
- Automatically recovers malformed tool_call arguments and patches
137
- any dangling tool_calls (assistant messages with tool_calls that
138
- have no matching tool-result message). Both can happen after
139
- errors or cancellations and would cause the LLM API to reject the
140
- request.
141
  """
142
- self.recover_malformed_tool_calls()
143
  self._patch_dangling_tool_calls()
144
  return self.items
145
 
@@ -163,99 +160,6 @@ class ContextManager:
163
  tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls
164
  ]
165
 
166
- def recover_malformed_tool_calls(self) -> set[str]:
167
- """Sanitize malformed tool_call arguments and inject error results.
168
-
169
- Handles two classes of corruption:
170
- - **Empty/missing IDs**: Stripped from the assistant message entirely
171
- (common when streaming is interrupted mid-tool-call).
172
- - **Malformed JSON arguments**: Replaced with ``"{}"`` and an error
173
- tool-result is injected asking the agent to retry.
174
-
175
- This method is idempotent — safe to call from both the agent loop
176
- (before tool execution) and from :meth:`get_messages` (safety net).
177
-
178
- Returns:
179
- Set of tool_call IDs that had malformed arguments.
180
- """
181
- import json
182
-
183
- malformed_ids: set[str] = set()
184
-
185
- for msg in self.items:
186
- if getattr(msg, "role", None) != "assistant":
187
- continue
188
- tool_calls = getattr(msg, "tool_calls", None)
189
- if not tool_calls:
190
- continue
191
- self._normalize_tool_calls(msg)
192
-
193
- # 1. Strip tool_calls with empty/missing IDs (cannot be repaired)
194
- valid_tcs = []
195
- for tc in msg.tool_calls:
196
- if not getattr(tc, "id", None):
197
- logger.warning(
198
- "Stripping tool_call with empty ID (name=%s) — likely interrupted stream",
199
- getattr(tc.function, "name", "?"),
200
- )
201
- continue
202
- valid_tcs.append(tc)
203
- if len(valid_tcs) != len(msg.tool_calls):
204
- msg.tool_calls = valid_tcs or None
205
-
206
- if not msg.tool_calls:
207
- continue
208
-
209
- # 2. Fix malformed JSON arguments
210
- for tc in msg.tool_calls:
211
- try:
212
- json.loads(tc.function.arguments)
213
- except (json.JSONDecodeError, TypeError, ValueError) as e:
214
- logger.warning(
215
- "Malformed arguments for tool_call %s (%s): %s",
216
- tc.id,
217
- tc.function.name,
218
- e,
219
- )
220
- tc.function.arguments = "{}"
221
- malformed_ids.add(tc.id)
222
-
223
- if not malformed_ids:
224
- return malformed_ids
225
-
226
- # 3. Inject error results for malformed calls that don't have one yet
227
- answered_ids = {
228
- getattr(m, "tool_call_id", None)
229
- for m in self.items
230
- if getattr(m, "role", None) == "tool"
231
- }
232
- for msg in self.items:
233
- if getattr(msg, "role", None) != "assistant":
234
- continue
235
- tool_calls = getattr(msg, "tool_calls", None)
236
- if not tool_calls:
237
- continue
238
- for tc in msg.tool_calls:
239
- if tc.id in malformed_ids and tc.id not in answered_ids:
240
- self.items.append(
241
- Message(
242
- role="tool",
243
- content=(
244
- f"ERROR: Your tool call to '{tc.function.name}' had malformed "
245
- f"JSON arguments and was NOT executed. This usually happens "
246
- f"when the content is too large and gets truncated. "
247
- f"Please retry with smaller content — for 'write', split the "
248
- f"file into multiple smaller writes using 'edit' to build up "
249
- f"the file incrementally."
250
- ),
251
- tool_call_id=tc.id,
252
- name=tc.function.name,
253
- )
254
- )
255
- answered_ids.add(tc.id)
256
-
257
- return malformed_ids
258
-
259
  def _patch_dangling_tool_calls(self) -> None:
260
  """Add stub tool results for any tool_calls that lack a matching result.
261
 
 
133
  def get_messages(self) -> list[Message]:
134
  """Get all messages for sending to LLM.
135
 
136
+ Patches any dangling tool_calls (assistant messages with tool_calls
137
+ that have no matching tool-result message) so the LLM API doesn't
138
+ reject the request.
 
 
139
  """
 
140
  self._patch_dangling_tool_calls()
141
  return self.items
142
 
 
160
  tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls
161
  ]
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def _patch_dangling_tool_calls(self) -> None:
164
  """Add stub tool results for any tool_calls that lack a matching result.
165
 
agent/core/agent_loop.py CHANGED
@@ -261,6 +261,7 @@ class Handlers:
261
  full_content = ""
262
  tool_calls_acc: dict[int, dict] = {}
263
  token_count = 0
 
264
 
265
  async for chunk in response:
266
  # ── Check cancellation during streaming ──
@@ -276,6 +277,8 @@ class Handlers:
276
  continue
277
 
278
  delta = choice.delta
 
 
279
 
280
  # Stream text deltas to the frontend
281
  if delta.content:
@@ -316,17 +319,15 @@ class Handlers:
316
  # ── Stream finished — reconstruct full message ───────
317
  content = full_content or None
318
 
319
- # Build tool_calls list from accumulated deltas,
320
- # dropping any with empty IDs (from interrupted streams)
 
 
 
 
321
  tool_calls: list[ToolCall] = []
322
  for idx in sorted(tool_calls_acc.keys()):
323
  tc_data = tool_calls_acc[idx]
324
- if not tc_data["id"]:
325
- logger.warning(
326
- "Dropping tool_call with empty ID (name=%s) — likely interrupted stream",
327
- tc_data["function"]["name"],
328
- )
329
- continue
330
  tool_calls.append(
331
  ToolCall(
332
  id=tc_data["id"],
@@ -351,7 +352,23 @@ class Handlers:
351
  final_response = content
352
  break
353
 
354
- # Add assistant message with tool calls to history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
  assistant_msg = Message(
356
  role="assistant",
357
  content=content,
@@ -359,79 +376,49 @@ class Handlers:
359
  )
360
  session.context_manager.add_message(assistant_msg, token_count)
361
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  # ── Cancellation check: before tool execution ──
363
  if session.is_cancelled:
364
  break
365
 
366
- # Recover any malformed tool calls (sanitize JSON + inject
367
- # error results). Returns IDs to skip during execution.
368
- malformed_ids = session.context_manager.recover_malformed_tool_calls()
369
- if malformed_ids:
370
- # For each malformed tool_call, emit a synthetic tool_call +
371
- # tool_output-error pair so the frontend has a matching
372
- # dynamic-tool part instead of an orphan error.
373
- for tc in tool_calls:
374
- if tc.id not in malformed_ids:
375
- continue
376
- tool_name = tc.function.name
377
- try:
378
- tool_args = json.loads(tc.function.arguments)
379
- except (json.JSONDecodeError, TypeError, ValueError):
380
- tool_args = {}
381
-
382
- await session.send_event(
383
- Event(
384
- event_type="tool_call",
385
- data={
386
- "tool": tool_name,
387
- "arguments": tool_args,
388
- "tool_call_id": tc.id,
389
- },
390
- )
391
- )
392
- await session.send_event(
393
- Event(
394
- event_type="tool_output",
395
- data={
396
- "tool": tool_name,
397
- "tool_call_id": tc.id,
398
- "output": "Malformed tool call — see error in context.",
399
- "success": False,
400
- },
401
- )
402
- )
403
-
404
- # Separate tools into those requiring approval and those that don't
405
- approval_required_tools = []
406
- non_approval_tools = []
407
-
408
- for tc in tool_calls:
409
- if tc.id in malformed_ids:
410
- continue
411
- tool_name = tc.function.name
412
- try:
413
- tool_args = json.loads(tc.function.arguments)
414
- except (json.JSONDecodeError, TypeError) as e:
415
- logger.warning(f"Malformed tool arguments for {tool_name}: {e}")
416
- tool_args = {}
417
-
418
  if _needs_approval(tool_name, tool_args, session.config):
419
- approval_required_tools.append(tc)
420
  else:
421
- non_approval_tools.append(tc)
 
422
  # Execute non-approval tools (in parallel when possible)
423
  if non_approval_tools:
424
- # 1. Parse args and validate upfront
425
  parsed_tools: list[
426
- tuple[ChatCompletionMessageToolCall, str, dict, bool, str]
427
  ] = []
428
- for tc in non_approval_tools:
429
- tool_name = tc.function.name
430
- try:
431
- tool_args = json.loads(tc.function.arguments)
432
- except (json.JSONDecodeError, TypeError):
433
- tool_args = {}
434
-
435
  args_valid, error_msg = _validate_tool_args(tool_args)
436
  parsed_tools.append(
437
  (tc, tool_name, tool_args, args_valid, error_msg)
@@ -451,14 +438,14 @@ class Handlers:
451
  )
452
  )
453
 
454
- # 3. Execute all valid tools in parallel
455
  async def _exec_tool(
456
- tc: ChatCompletionMessageToolCall,
457
  name: str,
458
  args: dict,
459
  valid: bool,
460
  err: str,
461
- ) -> tuple[ChatCompletionMessageToolCall, str, dict, str, bool]:
462
  if not valid:
463
  return (tc, name, args, err, False)
464
  out, ok = await session.tool_router.call_tool(
@@ -466,13 +453,30 @@ class Handlers:
466
  )
467
  return (tc, name, args, out, ok)
468
 
469
- results = await asyncio.gather(
470
  *[
471
  _exec_tool(tc, name, args, valid, err)
472
  for tc, name, args, valid, err in parsed_tools
473
  ]
 
 
 
 
 
 
474
  )
475
 
 
 
 
 
 
 
 
 
 
 
 
476
  # 4. Record results and send outputs (order preserved)
477
  for tc, tool_name, tool_args, output, success in results:
478
  tool_msg = Message(
@@ -495,56 +499,34 @@ class Handlers:
495
  )
496
  )
497
 
498
- # ── Cancellation check: after tool execution ──
499
- if session.is_cancelled:
500
- break
501
-
502
  # If there are tools requiring approval, ask for batch approval
503
  if approval_required_tools:
504
  # Prepare batch approval data
505
  tools_data = []
506
- for tc in approval_required_tools:
507
- tool_name = tc.function.name
508
- try:
509
- tool_args = json.loads(tc.function.arguments)
510
- except (json.JSONDecodeError, TypeError):
511
- tool_args = {}
512
-
513
  # Resolve sandbox file paths for hf_jobs scripts so the
514
  # frontend can display & edit the actual file content.
515
- if tool_name == "hf_jobs" and isinstance(
516
- tool_args.get("script"), str
517
- ):
518
  from agent.tools.sandbox_tool import resolve_sandbox_script
519
-
520
  sandbox = getattr(session, "sandbox", None)
521
- content, _ = await resolve_sandbox_script(
522
- sandbox, tool_args["script"]
523
- )
524
- if content:
525
- tool_args = {**tool_args, "script": content}
526
-
527
- tools_data.append(
528
- {
529
- "tool": tool_name,
530
- "arguments": tool_args,
531
- "tool_call_id": tc.id,
532
- }
533
- )
534
 
535
- await session.send_event(
536
- Event(
537
- event_type="approval_required",
538
- data={
539
- "tools": tools_data, # Batch of tools
540
- "count": len(tools_data),
541
- },
542
- )
543
- )
 
544
 
545
- # Store all approval-requiring tools
546
  session.pending_approval = {
547
- "tool_calls": approval_required_tools,
548
  }
549
 
550
  # Return early - wait for EXEC_APPROVAL operation
 
261
  full_content = ""
262
  tool_calls_acc: dict[int, dict] = {}
263
  token_count = 0
264
+ finish_reason = None
265
 
266
  async for chunk in response:
267
  # ── Check cancellation during streaming ──
 
277
  continue
278
 
279
  delta = choice.delta
280
+ if choice.finish_reason:
281
+ finish_reason = choice.finish_reason
282
 
283
  # Stream text deltas to the frontend
284
  if delta.content:
 
319
  # ── Stream finished — reconstruct full message ───────
320
  content = full_content or None
321
 
322
+ # If output was truncated, all tool call args are garbage
323
+ if finish_reason == "length" and tool_calls_acc:
324
+ logger.warning("Output truncated (finish_reason=length) — dropping tool calls")
325
+ tool_calls_acc.clear()
326
+
327
+ # Build tool_calls list from accumulated deltas
328
  tool_calls: list[ToolCall] = []
329
  for idx in sorted(tool_calls_acc.keys()):
330
  tc_data = tool_calls_acc[idx]
 
 
 
 
 
 
331
  tool_calls.append(
332
  ToolCall(
333
  id=tc_data["id"],
 
352
  final_response = content
353
  break
354
 
355
+ # Validate tool call args (one json.loads per call, once)
356
+ # and split into good vs bad
357
+ good_tools: list[tuple[ToolCall, str, dict]] = []
358
+ bad_tools: list[ToolCall] = []
359
+ for tc in tool_calls:
360
+ try:
361
+ args = json.loads(tc.function.arguments)
362
+ good_tools.append((tc, tc.function.name, args))
363
+ except (json.JSONDecodeError, TypeError, ValueError):
364
+ logger.warning(
365
+ "Malformed arguments for tool_call %s (%s) — skipping",
366
+ tc.id, tc.function.name,
367
+ )
368
+ tc.function.arguments = "{}"
369
+ bad_tools.append(tc)
370
+
371
+ # Add assistant message with all tool calls to context
372
  assistant_msg = Message(
373
  role="assistant",
374
  content=content,
 
376
  )
377
  session.context_manager.add_message(assistant_msg, token_count)
378
 
379
+ # Add error results for bad tool calls so the LLM
380
+ # knows what happened and can retry differently
381
+ for tc in bad_tools:
382
+ error_msg = (
383
+ f"ERROR: Tool call to '{tc.function.name}' had malformed JSON "
384
+ f"arguments and was NOT executed. Retry with smaller content — "
385
+ f"for 'write', split into multiple smaller writes using 'edit'."
386
+ )
387
+ session.context_manager.add_message(Message(
388
+ role="tool",
389
+ content=error_msg,
390
+ tool_call_id=tc.id,
391
+ name=tc.function.name,
392
+ ))
393
+ await session.send_event(Event(
394
+ event_type="tool_call",
395
+ data={"tool": tc.function.name, "arguments": {}, "tool_call_id": tc.id},
396
+ ))
397
+ await session.send_event(Event(
398
+ event_type="tool_output",
399
+ data={"tool": tc.function.name, "tool_call_id": tc.id, "output": error_msg, "success": False},
400
+ ))
401
+
402
  # ── Cancellation check: before tool execution ──
403
  if session.is_cancelled:
404
  break
405
 
406
+ # Separate good tools into approval-required vs auto-execute
407
+ approval_required_tools: list[tuple[ToolCall, str, dict]] = []
408
+ non_approval_tools: list[tuple[ToolCall, str, dict]] = []
409
+ for tc, tool_name, tool_args in good_tools:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
  if _needs_approval(tool_name, tool_args, session.config):
411
+ approval_required_tools.append((tc, tool_name, tool_args))
412
  else:
413
+ non_approval_tools.append((tc, tool_name, tool_args))
414
+
415
  # Execute non-approval tools (in parallel when possible)
416
  if non_approval_tools:
417
+ # 1. Validate args upfront
418
  parsed_tools: list[
419
+ tuple[ToolCall, str, dict, bool, str]
420
  ] = []
421
+ for tc, tool_name, tool_args in non_approval_tools:
 
 
 
 
 
 
422
  args_valid, error_msg = _validate_tool_args(tool_args)
423
  parsed_tools.append(
424
  (tc, tool_name, tool_args, args_valid, error_msg)
 
438
  )
439
  )
440
 
441
+ # 3. Execute all valid tools in parallel, cancellable
442
  async def _exec_tool(
443
+ tc: ToolCall,
444
  name: str,
445
  args: dict,
446
  valid: bool,
447
  err: str,
448
+ ) -> tuple[ToolCall, str, dict, str, bool]:
449
  if not valid:
450
  return (tc, name, args, err, False)
451
  out, ok = await session.tool_router.call_tool(
 
453
  )
454
  return (tc, name, args, out, ok)
455
 
456
+ gather_task = asyncio.ensure_future(asyncio.gather(
457
  *[
458
  _exec_tool(tc, name, args, valid, err)
459
  for tc, name, args, valid, err in parsed_tools
460
  ]
461
+ ))
462
+ cancel_task = asyncio.ensure_future(session._cancelled.wait())
463
+
464
+ done, _ = await asyncio.wait(
465
+ [gather_task, cancel_task],
466
+ return_when=asyncio.FIRST_COMPLETED,
467
  )
468
 
469
+ if cancel_task in done:
470
+ gather_task.cancel()
471
+ try:
472
+ await gather_task
473
+ except asyncio.CancelledError:
474
+ pass
475
+ break
476
+
477
+ cancel_task.cancel()
478
+ results = gather_task.result()
479
+
480
  # 4. Record results and send outputs (order preserved)
481
  for tc, tool_name, tool_args, output, success in results:
482
  tool_msg = Message(
 
499
  )
500
  )
501
 
 
 
 
 
502
  # If there are tools requiring approval, ask for batch approval
503
  if approval_required_tools:
504
  # Prepare batch approval data
505
  tools_data = []
506
+ for tc, tool_name, tool_args in approval_required_tools:
 
 
 
 
 
 
507
  # Resolve sandbox file paths for hf_jobs scripts so the
508
  # frontend can display & edit the actual file content.
509
+ if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
 
 
510
  from agent.tools.sandbox_tool import resolve_sandbox_script
 
511
  sandbox = getattr(session, "sandbox", None)
512
+ resolved, _ = await resolve_sandbox_script(sandbox, tool_args["script"])
513
+ if resolved:
514
+ tool_args = {**tool_args, "script": resolved}
 
 
 
 
 
 
 
 
 
 
515
 
516
+ tools_data.append({
517
+ "tool": tool_name,
518
+ "arguments": tool_args,
519
+ "tool_call_id": tc.id,
520
+ })
521
+
522
+ await session.send_event(Event(
523
+ event_type="approval_required",
524
+ data={"tools": tools_data, "count": len(tools_data)},
525
+ ))
526
 
527
+ # Store all approval-requiring tools (ToolCall objects for execution)
528
  session.pending_approval = {
529
+ "tool_calls": [tc for tc, _, _ in approval_required_tools],
530
  }
531
 
532
  # Return early - wait for EXEC_APPROVAL operation