black-yt committed on
Commit
b4eb71d
·
1 Parent(s): a47195c

Sync tool execution semantics

Browse files
VERSION CHANGED
@@ -1 +1 @@
1
- v0.0.35
 
1
+ v0.0.36
agent_base/prompts/system_base.md CHANGED
@@ -92,6 +92,7 @@ You are a capable all-purpose AI assistant. You do far more than simple question
92
  - ask the human user for essential missing information -> `AskUser`
93
  - persistent interactive shell state -> `Terminal*`
94
  - Search results and scholar results are discovery aids. They are not page-verification evidence by themselves.
 
95
  - Prefer `Bash` over `Terminal*` unless persistent interactive shell state is genuinely required.
96
 
97
  ## Human Clarification Workflow
 
92
  - ask the human user for essential missing information -> `AskUser`
93
  - persistent interactive shell state -> `Terminal*`
94
  - Search results and scholar results are discovery aids. They are not page-verification evidence by themselves.
95
+ - Each tool call should express one clear request. For independent read-only work, such as multiple searches, multiple page fetches, or multiple file reads, issue multiple tool calls in the same assistant turn rather than packing several requests into one tool argument.
96
  - Prefer `Bash` over `Terminal*` unless persistent interactive shell state is genuinely required.
97
 
98
  ## Human Clarification Workflow
agent_base/react_agent.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
  from contextlib import contextmanager
3
  import json
4
  import os
@@ -79,6 +80,18 @@ DEFAULT_TEMPERATURE = 0.6
79
  DEFAULT_TOP_P = 0.95
80
  DEFAULT_PRESENCE_PENALTY = 1.1
81
  DEFAULT_LLM_TIMEOUT_SECONDS = 600.0
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
 
84
  def default_model_name() -> str:
@@ -613,6 +626,26 @@ def execute_tool_by_name(tool_map: dict[str, Any], tool_name: str, tool_args: An
613
  return tool.call(tool_args, **kwargs)
614
 
615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  class MultiTurnReactAgent(BaseAgent):
617
  def __init__(
618
  self,
@@ -1216,38 +1249,64 @@ class MultiTurnReactAgent(BaseAgent):
1216
  tool_turn_message_start = len(messages)
1217
  messages.append(assistant_message)
1218
  deferred_image_contexts: list[tuple[str, str, Any, Any, dict[str, Any]]] = []
 
1219
  for tool_call, tool_arguments in zip(assistant_tool_calls, assistant_tool_arguments):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1220
  if remaining_runtime_seconds(runtime_deadline) is not None and remaining_runtime_seconds(runtime_deadline) <= 0:
1221
  result_text = "No result found before the maximum agent runtime limit."
1222
  termination = f"agent runtime limit reached: {agent_runtime_limit}s"
1223
  return finalize(result_text, termination, error=termination)
1224
- tool_call_id = str(tool_call.get("id", ""))
1225
- function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {}
1226
- tool_name = str(function_block.get("name", ""))
1227
  try:
1228
- result = self.custom_call_tool(
1229
- tool_name,
1230
- tool_arguments,
1231
- workspace_root=resolved_workspace_root,
1232
- runtime_deadline=runtime_deadline,
1233
- model_name=self.model,
1234
  )
 
 
 
 
 
 
1235
  except KeyboardInterrupt:
1236
  messages = messages[:tool_turn_message_start]
1237
  return finalize_interrupted()
1238
- tool_result_text = tool_result_message_content(result)
1239
- messages.append(api_tool_message(tool_call_id, result))
1240
- trace_writer.append(
1241
- role="tool",
1242
- text=tool_result_text,
1243
- turn_index=round_index,
1244
- tool_call_ids=[tool_call_id],
1245
- tool_names=[tool_name],
1246
- tool_arguments=[tool_arguments],
1247
- )
1248
- extra_image_context = image_context_message(result, self.model)
1249
- if extra_image_context is not None:
1250
- deferred_image_contexts.append((tool_call_id, tool_name, tool_arguments, result, extra_image_context))
 
 
 
 
 
1251
  for tool_call_id, tool_name, tool_arguments, result, extra_image_context in deferred_image_contexts:
1252
  messages.append(extra_image_context)
1253
  trace_writer.append(
 
1
  import argparse
2
+ from concurrent.futures import ThreadPoolExecutor
3
  from contextlib import contextmanager
4
  import json
5
  import os
 
80
  DEFAULT_TOP_P = 0.95
81
  DEFAULT_PRESENCE_PENALTY = 1.1
82
  DEFAULT_LLM_TIMEOUT_SECONDS = 600.0
83
+ MAX_PARALLEL_READ_TOOL_CALLS = 3
84
+ PARALLEL_READ_TOOL_NAMES = frozenset(
85
+ {
86
+ "Glob",
87
+ "Grep",
88
+ "Read",
89
+ "ReadImage",
90
+ "WebSearch",
91
+ "ScholarSearch",
92
+ "WebFetch",
93
+ }
94
+ )
95
 
96
 
97
  def default_model_name() -> str:
 
626
  return tool.call(tool_args, **kwargs)
627
 
628
 
629
+ def can_parallelize_tool_name(tool_name: str) -> bool:
630
+ return tool_name in PARALLEL_READ_TOOL_NAMES
631
+
632
+
633
+ def tool_execution_batches(tool_names: Sequence[str]) -> list[list[int]]:
634
+ batches: list[list[int]] = []
635
+ read_batch: list[int] = []
636
+ for index, tool_name in enumerate(tool_names):
637
+ if can_parallelize_tool_name(tool_name):
638
+ read_batch.append(index)
639
+ continue
640
+ if read_batch:
641
+ batches.append(read_batch)
642
+ read_batch = []
643
+ batches.append([index])
644
+ if read_batch:
645
+ batches.append(read_batch)
646
+ return batches
647
+
648
+
649
  class MultiTurnReactAgent(BaseAgent):
650
  def __init__(
651
  self,
 
1249
  tool_turn_message_start = len(messages)
1250
  messages.append(assistant_message)
1251
  deferred_image_contexts: list[tuple[str, str, Any, Any, dict[str, Any]]] = []
1252
+ tool_call_items: list[dict[str, Any]] = []
1253
  for tool_call, tool_arguments in zip(assistant_tool_calls, assistant_tool_arguments):
1254
+ function_block = tool_call.get("function", {}) if isinstance(tool_call, dict) else {}
1255
+ tool_call_items.append(
1256
+ {
1257
+ "tool_call_id": str(tool_call.get("id", "")),
1258
+ "tool_name": str(function_block.get("name", "")),
1259
+ "tool_arguments": tool_arguments,
1260
+ }
1261
+ )
1262
+
1263
+ def execute_tool_item(item: dict[str, Any]) -> tuple[dict[str, Any], Any]:
1264
+ result = self.custom_call_tool(
1265
+ str(item["tool_name"]),
1266
+ item["tool_arguments"],
1267
+ workspace_root=resolved_workspace_root,
1268
+ runtime_deadline=runtime_deadline,
1269
+ model_name=self.model,
1270
+ )
1271
+ return item, result
1272
+
1273
+ for batch_indexes in tool_execution_batches([str(item["tool_name"]) for item in tool_call_items]):
1274
  if remaining_runtime_seconds(runtime_deadline) is not None and remaining_runtime_seconds(runtime_deadline) <= 0:
1275
  result_text = "No result found before the maximum agent runtime limit."
1276
  termination = f"agent runtime limit reached: {agent_runtime_limit}s"
1277
  return finalize(result_text, termination, error=termination)
1278
+ batch_items = [tool_call_items[index] for index in batch_indexes]
 
 
1279
  try:
1280
+ should_run_parallel = len(batch_items) > 1 and all(
1281
+ can_parallelize_tool_name(str(item["tool_name"])) for item in batch_items
 
 
 
 
1282
  )
1283
+ if should_run_parallel:
1284
+ max_workers = min(MAX_PARALLEL_READ_TOOL_CALLS, len(batch_items))
1285
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
1286
+ batch_results = list(executor.map(execute_tool_item, batch_items))
1287
+ else:
1288
+ batch_results = [execute_tool_item(item) for item in batch_items]
1289
  except KeyboardInterrupt:
1290
  messages = messages[:tool_turn_message_start]
1291
  return finalize_interrupted()
1292
+
1293
+ for item, result in batch_results:
1294
+ tool_call_id = str(item["tool_call_id"])
1295
+ tool_name = str(item["tool_name"])
1296
+ tool_arguments = item["tool_arguments"]
1297
+ tool_result_text = tool_result_message_content(result)
1298
+ messages.append(api_tool_message(tool_call_id, result))
1299
+ trace_writer.append(
1300
+ role="tool",
1301
+ text=tool_result_text,
1302
+ turn_index=round_index,
1303
+ tool_call_ids=[tool_call_id],
1304
+ tool_names=[tool_name],
1305
+ tool_arguments=[tool_arguments],
1306
+ )
1307
+ extra_image_context = image_context_message(result, self.model)
1308
+ if extra_image_context is not None:
1309
+ deferred_image_contexts.append((tool_call_id, tool_name, tool_arguments, result, extra_image_context))
1310
  for tool_call_id, tool_name, tool_arguments, result, extra_image_context in deferred_image_contexts:
1311
  messages.append(extra_image_context)
1312
  trace_writer.append(
agent_base/tools/tool_web.py CHANGED
@@ -4,7 +4,6 @@ import os
4
  import re
5
  import sys
6
  import time
7
- from concurrent.futures import ThreadPoolExecutor
8
  from typing import Optional, Union
9
 
10
  import requests
@@ -61,17 +60,13 @@ def _clean_webpage_text(text: str) -> str:
61
 
62
  class WebSearch(ToolBase):
63
  name = "WebSearch"
64
- description = "Perform Google web searches and return the top results. Accepts multiple complementary queries."
65
  parameters = {
66
  "type": "object",
67
  "properties": {
68
  "query": {
69
- "type": "array",
70
- "items": {
71
- "type": "string",
72
- },
73
- "minItems": 1,
74
- "description": "Array of query strings. Include multiple complementary search queries in a single call.",
75
  },
76
  },
77
  "required": ["query"],
@@ -166,27 +161,21 @@ class WebSearch(ToolBase):
166
  except ValueError as exc:
167
  return f"[WebSearch] {exc}"
168
 
169
- if isinstance(query, list):
170
- with ThreadPoolExecutor(max_workers=3) as executor:
171
- responses = list(executor.map(self.search_with_serp, query))
172
- response = "\n=======\n".join(responses)
173
- else:
174
- return "[WebSearch] 'query' must be a list of strings."
175
 
176
- return response
177
 
178
 
179
  class ScholarSearch(ToolBase):
180
  name = "ScholarSearch"
181
- description = "Search academic sources through Google Scholar and return relevant publication results."
182
  parameters = {
183
  "type": "object",
184
  "properties": {
185
  "query": {
186
- "type": "array",
187
- "items": {"type": "string", "description": "The search query."},
188
- "minItems": 1,
189
- "description": "The list of search queries for Google Scholar.",
190
  },
191
  },
192
  "required": ["query"],
@@ -264,13 +253,9 @@ class ScholarSearch(ToolBase):
264
  except ValueError as exc:
265
  return f"[ScholarSearch] {exc}"
266
 
267
- if isinstance(query, list):
268
- with ThreadPoolExecutor(max_workers=3) as executor:
269
- response = list(executor.map(self.google_scholar_with_serp, query))
270
- response = "\n=======\n".join(response)
271
- else:
272
- return "[ScholarSearch] 'query' must be a list of strings."
273
- return response
274
 
275
 
276
  class WebFetch(ToolBase):
@@ -475,9 +460,9 @@ def main(argv: Optional[list[str]] = None) -> int:
475
  load_dotenv(PROJECT_ROOT / ".env")
476
 
477
  if args.tool == "search":
478
- result = WebSearch().call({"query": [" ".join(args.query)]})
479
  elif args.tool == "scholar":
480
- result = ScholarSearch().call({"query": [" ".join(args.query)]})
481
  else:
482
  result = WebFetch().call(
483
  {
 
4
  import re
5
  import sys
6
  import time
 
7
  from typing import Optional, Union
8
 
9
  import requests
 
60
 
61
  class WebSearch(ToolBase):
62
  name = "WebSearch"
63
+ description = "Perform one Google web search and return the top results. Call WebSearch multiple times for multiple queries."
64
  parameters = {
65
  "type": "object",
66
  "properties": {
67
  "query": {
68
+ "type": "string",
69
+ "description": "The search query.",
 
 
 
 
70
  },
71
  },
72
  "required": ["query"],
 
161
  except ValueError as exc:
162
  return f"[WebSearch] {exc}"
163
 
164
+ if not isinstance(query, str) or not query.strip():
165
+ return "[WebSearch] 'query' must be a non-empty string."
 
 
 
 
166
 
167
+ return self.search_with_serp(query.strip())
168
 
169
 
170
  class ScholarSearch(ToolBase):
171
  name = "ScholarSearch"
172
+ description = "Run one academic search through Google Scholar and return relevant publication results. Call ScholarSearch multiple times for multiple queries."
173
  parameters = {
174
  "type": "object",
175
  "properties": {
176
  "query": {
177
+ "type": "string",
178
+ "description": "The search query for Google Scholar.",
 
 
179
  },
180
  },
181
  "required": ["query"],
 
253
  except ValueError as exc:
254
  return f"[ScholarSearch] {exc}"
255
 
256
+ if not isinstance(query, str) or not query.strip():
257
+ return "[ScholarSearch] 'query' must be a non-empty string."
258
+ return self.google_scholar_with_serp(query.strip())
 
 
 
 
259
 
260
 
261
  class WebFetch(ToolBase):
 
460
  load_dotenv(PROJECT_ROOT / ".env")
461
 
462
  if args.tool == "search":
463
+ result = WebSearch().call({"query": " ".join(args.query)})
464
  elif args.tool == "scholar":
465
+ result = ScholarSearch().call({"query": " ".join(args.query)})
466
  else:
467
  result = WebFetch().call(
468
  {