akseljoonas HF Staff commited on
Commit
085cd02
·
1 Parent(s): 7ead77c

fix: correct context_length init and emit tool_call events for malformed calls

Browse files

- Initialize context_length to max_context (not token-estimate of system prompt)
and reserve 10k token buffer to prevent overflows
- Emit synthetic tool_call events before tool_output errors for malformed calls
so the frontend renders matching dynamic-tool parts

agent/context_manager/manager.py CHANGED
@@ -85,9 +85,9 @@ class ContextManager:
85
  prompt_file_suffix="system_prompt_v3.yaml",
86
  hf_token=hf_token,
87
  )
88
- self.max_context = max_context
89
  self.compact_size = int(max_context * compact_size)
90
- self.context_length = len(self.system_prompt) // 4
91
  self.untouched_messages = untouched_messages
92
  self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
93
 
@@ -160,8 +160,7 @@ class ContextManager:
160
  if not needs_fix:
161
  return
162
  msg.tool_calls = [
163
- tc if not isinstance(tc, dict) else ToolCall(**tc)
164
- for tc in tool_calls
165
  ]
166
 
167
  def recover_malformed_tool_calls(self) -> set[str]:
@@ -214,7 +213,9 @@ class ContextManager:
214
  except (json.JSONDecodeError, TypeError, ValueError) as e:
215
  logger.warning(
216
  "Malformed arguments for tool_call %s (%s): %s",
217
- tc.id, tc.function.name, e,
 
 
218
  )
219
  tc.function.arguments = "{}"
220
  malformed_ids.add(tc.id)
@@ -268,7 +269,9 @@ class ContextManager:
268
  assistant_msg = None
269
  for i in range(len(self.items) - 1, -1, -1):
270
  msg = self.items[i]
271
- if getattr(msg, "role", None) == "assistant" and getattr(msg, "tool_calls", None):
 
 
272
  assistant_msg = msg
273
  break
274
  # Stop scanning once we hit a user message — anything before
 
85
  prompt_file_suffix="system_prompt_v3.yaml",
86
  hf_token=hf_token,
87
  )
88
+ self.max_context = max_context - 10000
89
  self.compact_size = int(max_context * compact_size)
90
+ self.context_length = max_context
91
  self.untouched_messages = untouched_messages
92
  self.items: list[Message] = [Message(role="system", content=self.system_prompt)]
93
 
 
160
  if not needs_fix:
161
  return
162
  msg.tool_calls = [
163
+ tc if not isinstance(tc, dict) else ToolCall(**tc) for tc in tool_calls
 
164
  ]
165
 
166
  def recover_malformed_tool_calls(self) -> set[str]:
 
213
  except (json.JSONDecodeError, TypeError, ValueError) as e:
214
  logger.warning(
215
  "Malformed arguments for tool_call %s (%s): %s",
216
+ tc.id,
217
+ tc.function.name,
218
+ e,
219
  )
220
  tc.function.arguments = "{}"
221
  malformed_ids.add(tc.id)
 
269
  assistant_msg = None
270
  for i in range(len(self.items) - 1, -1, -1):
271
  msg = self.items[i]
272
+ if getattr(msg, "role", None) == "assistant" and getattr(
273
+ msg, "tool_calls", None
274
+ ):
275
  assistant_msg = msg
276
  break
277
  # Stop scanning once we hit a user message — anything before
agent/core/agent_loop.py CHANGED
@@ -37,7 +37,9 @@ def _resolve_hf_router_params(model_name: str) -> dict:
37
  if not model_name.startswith("huggingface/"):
38
  return {"model": model_name}
39
 
40
- parts = model_name.split("/", 2) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5']
 
 
41
  if len(parts) < 3:
42
  return {"model": model_name}
43
 
@@ -162,8 +164,6 @@ async def _compact_and_notify(session: Session) -> None:
162
  )
163
 
164
 
165
-
166
-
167
  class Handlers:
168
  """Handler functions for each operation type"""
169
 
@@ -178,7 +178,9 @@ class Handlers:
178
  tool_calls = session.pending_approval.get("tool_calls", [])
179
  for tc in tool_calls:
180
  tool_name = tc.function.name
181
- abandon_msg = "Task abandoned — user continued the conversation without approving."
 
 
182
 
183
  # Keep LLM context valid: every tool_call needs a tool result
184
  tool_msg = Message(
@@ -364,21 +366,40 @@ class Handlers:
364
  # Recover any malformed tool calls (sanitize JSON + inject
365
  # error results). Returns IDs to skip during execution.
366
  malformed_ids = session.context_manager.recover_malformed_tool_calls()
367
- for mid in malformed_ids:
368
- await session.send_event(
369
- Event(
370
- event_type="tool_output",
371
- data={
372
- "tool": next(
373
- (tc.function.name for tc in tool_calls if tc.id == mid),
374
- "unknown",
375
- ),
376
- "tool_call_id": mid,
377
- "output": "Malformed tool call — see error in context.",
378
- "success": False,
379
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  )
381
- )
382
 
383
  # Separate tools into those requiring approval and those that don't
384
  approval_required_tools = []
@@ -491,10 +512,15 @@ class Handlers:
491
 
492
  # Resolve sandbox file paths for hf_jobs scripts so the
493
  # frontend can display & edit the actual file content.
494
- if tool_name == "hf_jobs" and isinstance(tool_args.get("script"), str):
 
 
495
  from agent.tools.sandbox_tool import resolve_sandbox_script
 
496
  sandbox = getattr(session, "sandbox", None)
497
- content, _ = await resolve_sandbox_script(sandbox, tool_args["script"])
 
 
498
  if content:
499
  tool_args = {**tool_args, "script": content}
500
 
@@ -596,7 +622,9 @@ class Handlers:
596
  approval_map = {a["tool_call_id"]: a for a in approvals}
597
  for a in approvals:
598
  if a.get("edited_script"):
599
- logger.info(f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)")
 
 
600
 
601
  # Separate approved and rejected tool calls
602
  approved_tasks = []
@@ -742,7 +770,9 @@ class Handlers:
742
  # Ensure feedback is a string and sanitize any problematic characters
743
  feedback_str = str(user_feedback).strip()
744
  # Remove any control characters that might break JSON parsing
745
- feedback_str = "".join(char for char in feedback_str if ord(char) >= 32 or char in "\n\t")
 
 
746
  rejection_msg += f". User feedback: {feedback_str}"
747
 
748
  # Ensure rejection_msg is a clean string
@@ -837,7 +867,9 @@ async def submission_loop(
837
  """
838
 
839
  # Create session with tool router
840
- session = Session(event_queue, config=config, tool_router=tool_router, hf_token=hf_token)
 
 
841
  if session_holder is not None:
842
  session_holder[0] = session
843
  logger.info("Agent loop started")
 
37
  if not model_name.startswith("huggingface/"):
38
  return {"model": model_name}
39
 
40
+ parts = model_name.split(
41
+ "/", 2
42
+ ) # ['huggingface', 'novita', 'moonshotai/kimi-k2.5']
43
  if len(parts) < 3:
44
  return {"model": model_name}
45
 
 
164
  )
165
 
166
 
 
 
167
  class Handlers:
168
  """Handler functions for each operation type"""
169
 
 
178
  tool_calls = session.pending_approval.get("tool_calls", [])
179
  for tc in tool_calls:
180
  tool_name = tc.function.name
181
+ abandon_msg = (
182
+ "Task abandoned — user continued the conversation without approving."
183
+ )
184
 
185
  # Keep LLM context valid: every tool_call needs a tool result
186
  tool_msg = Message(
 
366
  # Recover any malformed tool calls (sanitize JSON + inject
367
  # error results). Returns IDs to skip during execution.
368
  malformed_ids = session.context_manager.recover_malformed_tool_calls()
369
+ if malformed_ids:
370
+ # For each malformed tool_call, emit a synthetic tool_call +
371
+ # tool_output-error pair so the frontend has a matching
372
+ # dynamic-tool part instead of an orphan error.
373
+ for tc in tool_calls:
374
+ if tc.id not in malformed_ids:
375
+ continue
376
+ tool_name = tc.function.name
377
+ try:
378
+ tool_args = json.loads(tc.function.arguments)
379
+ except (json.JSONDecodeError, TypeError, ValueError):
380
+ tool_args = {}
381
+
382
+ await session.send_event(
383
+ Event(
384
+ event_type="tool_call",
385
+ data={
386
+ "tool": tool_name,
387
+ "arguments": tool_args,
388
+ "tool_call_id": tc.id,
389
+ },
390
+ )
391
+ )
392
+ await session.send_event(
393
+ Event(
394
+ event_type="tool_output",
395
+ data={
396
+ "tool": tool_name,
397
+ "tool_call_id": tc.id,
398
+ "output": "Malformed tool call — see error in context.",
399
+ "success": False,
400
+ },
401
+ )
402
  )
 
403
 
404
  # Separate tools into those requiring approval and those that don't
405
  approval_required_tools = []
 
512
 
513
  # Resolve sandbox file paths for hf_jobs scripts so the
514
  # frontend can display & edit the actual file content.
515
+ if tool_name == "hf_jobs" and isinstance(
516
+ tool_args.get("script"), str
517
+ ):
518
  from agent.tools.sandbox_tool import resolve_sandbox_script
519
+
520
  sandbox = getattr(session, "sandbox", None)
521
+ content, _ = await resolve_sandbox_script(
522
+ sandbox, tool_args["script"]
523
+ )
524
  if content:
525
  tool_args = {**tool_args, "script": content}
526
 
 
622
  approval_map = {a["tool_call_id"]: a for a in approvals}
623
  for a in approvals:
624
  if a.get("edited_script"):
625
+ logger.info(
626
+ f"Received edited script for tool_call {a['tool_call_id']} ({len(a['edited_script'])} chars)"
627
+ )
628
 
629
  # Separate approved and rejected tool calls
630
  approved_tasks = []
 
770
  # Ensure feedback is a string and sanitize any problematic characters
771
  feedback_str = str(user_feedback).strip()
772
  # Remove any control characters that might break JSON parsing
773
+ feedback_str = "".join(
774
+ char for char in feedback_str if ord(char) >= 32 or char in "\n\t"
775
+ )
776
  rejection_msg += f". User feedback: {feedback_str}"
777
 
778
  # Ensure rejection_msg is a clean string
 
867
  """
868
 
869
  # Create session with tool router
870
+ session = Session(
871
+ event_queue, config=config, tool_router=tool_router, hf_token=hf_token
872
+ )
873
  if session_holder is not None:
874
  session_holder[0] = session
875
  logger.info("Agent loop started")