black-yt commited on
Commit
75ff73e
·
1 Parent(s): f7e066b

Sync ResearchHarness runtime update

Browse files
Files changed (3) hide show
  1. VERSION +1 -1
  2. agent_base/react_agent.py +62 -5
  3. agent_base/tools/custom.py +188 -0
VERSION CHANGED
@@ -1 +1 @@
1
- v0.0.38
 
1
+ v0.0.39
agent_base/react_agent.py CHANGED
@@ -20,6 +20,7 @@ from agent_base.provider_compat import apply_sampling_params
20
  from agent_base.prompt import composed_system_prompt
21
  from agent_base.session_state import AgentSessionState, CompactionRecord, persist_session_state, resolve_session_state_path
22
  from agent_base.trace_utils import FlatTraceWriter
 
23
  from agent_base.tools.tooling import normalize_workspace_root
24
  from agent_base.tools.tool_extra import StrReplaceEditor
25
  from agent_base.tools.tool_file import Edit, Glob, Grep, Read, ReadImage, ReadPDF, Write
@@ -645,6 +646,16 @@ def tool_execution_batches(tool_names: Sequence[str]) -> list[list[int]]:
645
  return batches
646
 
647
 
 
 
 
 
 
 
 
 
 
 
648
  class MultiTurnReactAgent(BaseAgent):
649
  def __init__(
650
  self,
@@ -652,16 +663,34 @@ class MultiTurnReactAgent(BaseAgent):
652
  llm: Optional[Dict] = None,
653
  trace_dir: Optional[str] = None,
654
  role_prompt: Optional[str] = None,
 
 
655
  max_llm_calls: Optional[int] = None,
656
  max_rounds: Optional[int] = None,
657
  max_runtime_seconds: Optional[int] = None,
658
  ):
659
  if not isinstance(llm, dict):
660
  raise ValueError("llm must be a dict configuration.")
 
 
 
 
 
661
  requested_tools = self.resolve_function_list(function_list)
662
  if requested_tools is None:
663
  requested_tools = list(AVAILABLE_TOOL_MAP.keys())
664
- unknown_tools = [tool for tool in requested_tools if tool not in ALL_TOOL_MAP]
 
 
 
 
 
 
 
 
 
 
 
665
  if unknown_tools:
666
  raise ValueError(f"Unknown tools requested: {unknown_tools}")
667
  if "model" not in llm or not str(llm["model"]).strip():
@@ -669,7 +698,7 @@ class MultiTurnReactAgent(BaseAgent):
669
  if "generate_cfg" not in llm or not isinstance(llm["generate_cfg"], dict):
670
  raise ValueError('llm["generate_cfg"] must be a dict.')
671
 
672
- self.tool_map = {tool_name: ALL_TOOL_MAP[tool_name] for tool_name in requested_tools}
673
  self.tool_names = list(self.tool_map.keys())
674
  self.model = str(llm["model"])
675
  self.llm_generate_cfg = llm["generate_cfg"]
@@ -677,6 +706,7 @@ class MultiTurnReactAgent(BaseAgent):
677
  self.trace_path: Optional[Path] = None
678
  self.session_state_path: Optional[Path] = None
679
  self.role_prompt = self.resolve_role_prompt(role_prompt)
 
680
  self.max_llm_calls = int(max_llm_calls) if max_llm_calls is not None else max_llm_calls_per_run()
681
  self.max_rounds = int(max_rounds) if max_rounds is not None else max_agent_rounds()
682
  self.max_runtime_seconds = (
@@ -873,9 +903,34 @@ class MultiTurnReactAgent(BaseAgent):
873
  )
874
  return token_count
875
 
876
- def run(self, prompt: str, workspace_root: Optional[str] = None) -> str:
 
 
 
 
 
877
  """Run the agent on one prompt and return only the final result text."""
878
- return self._run_session(prompt, workspace_root=workspace_root)["result_text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
879
 
880
  def _run_session(
881
  self,
@@ -891,7 +946,9 @@ class MultiTurnReactAgent(BaseAgent):
891
  raise ValueError("prompt must be a non-empty string.")
892
 
893
  prompt_text = prompt.strip()
894
- resolved_workspace_root = normalize_workspace_root(workspace_root)
 
 
895
  start_time = time.time()
896
  trace_dir = self.trace_dir
897
  cur_date = today_date()
 
20
  from agent_base.prompt import composed_system_prompt
21
  from agent_base.session_state import AgentSessionState, CompactionRecord, persist_session_state, resolve_session_state_path
22
  from agent_base.trace_utils import FlatTraceWriter
23
+ from agent_base.tools.custom import build_custom_tool_map
24
  from agent_base.tools.tooling import normalize_workspace_root
25
  from agent_base.tools.tool_extra import StrReplaceEditor
26
  from agent_base.tools.tool_file import Edit, Glob, Grep, Read, ReadImage, ReadPDF, Write
 
646
  return batches
647
 
648
 
649
+ def normalized_image_inputs(images: Optional[str | Path | Sequence[str | Path]]) -> list[str | Path]:
650
+ if images is None:
651
+ return []
652
+ if isinstance(images, (str, Path)):
653
+ return [images]
654
+ if isinstance(images, Sequence) and not isinstance(images, (str, bytes)):
655
+ return list(images)
656
+ raise ValueError("images must be a path or a sequence of paths.")
657
+
658
+
659
  class MultiTurnReactAgent(BaseAgent):
660
  def __init__(
661
  self,
 
663
  llm: Optional[Dict] = None,
664
  trace_dir: Optional[str] = None,
665
  role_prompt: Optional[str] = None,
666
+ workspace_root: Optional[str] = None,
667
+ custom_tools: Optional[Sequence[Any]] = None,
668
  max_llm_calls: Optional[int] = None,
669
  max_rounds: Optional[int] = None,
670
  max_runtime_seconds: Optional[int] = None,
671
  ):
672
  if not isinstance(llm, dict):
673
  raise ValueError("llm must be a dict configuration.")
674
+ custom_tool_map = build_custom_tool_map(custom_tools)
675
+ conflicting_tools = [name for name in custom_tool_map if name in ALL_TOOL_MAP]
676
+ if conflicting_tools:
677
+ raise ValueError(f"Custom tool names conflict with built-in tools: {conflicting_tools}")
678
+ tool_registry = {**ALL_TOOL_MAP, **custom_tool_map}
679
  requested_tools = self.resolve_function_list(function_list)
680
  if requested_tools is None:
681
  requested_tools = list(AVAILABLE_TOOL_MAP.keys())
682
+ for tool_name in custom_tool_map:
683
+ if tool_name not in requested_tools:
684
+ requested_tools.append(tool_name)
685
+ duplicate_tools: list[str] = []
686
+ seen_tools: set[str] = set()
687
+ for tool_name in requested_tools:
688
+ if tool_name in seen_tools and tool_name not in duplicate_tools:
689
+ duplicate_tools.append(tool_name)
690
+ seen_tools.add(tool_name)
691
+ if duplicate_tools:
692
+ raise ValueError(f"Duplicate tools requested: {duplicate_tools}")
693
+ unknown_tools = [tool for tool in requested_tools if tool not in tool_registry]
694
  if unknown_tools:
695
  raise ValueError(f"Unknown tools requested: {unknown_tools}")
696
  if "model" not in llm or not str(llm["model"]).strip():
 
698
  if "generate_cfg" not in llm or not isinstance(llm["generate_cfg"], dict):
699
  raise ValueError('llm["generate_cfg"] must be a dict.')
700
 
701
+ self.tool_map = {tool_name: tool_registry[tool_name] for tool_name in requested_tools}
702
  self.tool_names = list(self.tool_map.keys())
703
  self.model = str(llm["model"])
704
  self.llm_generate_cfg = llm["generate_cfg"]
 
706
  self.trace_path: Optional[Path] = None
707
  self.session_state_path: Optional[Path] = None
708
  self.role_prompt = self.resolve_role_prompt(role_prompt)
709
+ self.workspace_root = normalize_workspace_root(workspace_root) if workspace_root else None
710
  self.max_llm_calls = int(max_llm_calls) if max_llm_calls is not None else max_llm_calls_per_run()
711
  self.max_rounds = int(max_rounds) if max_rounds is not None else max_agent_rounds()
712
  self.max_runtime_seconds = (
 
903
  )
904
  return token_count
905
 
906
+ def run(
907
+ self,
908
+ prompt: str,
909
+ workspace_root: Optional[str] = None,
910
+ images: Optional[str | Path | Sequence[str | Path]] = None,
911
+ ) -> str:
912
  """Run the agent on one prompt and return only the final result text."""
913
+ resolved_workspace_root = normalize_workspace_root(
914
+ workspace_root if workspace_root is not None else self.workspace_root
915
+ )
916
+ run_prompt = prompt
917
+ initial_content_parts: list[dict[str, Any]] = []
918
+ saved_image_paths: list[str] = []
919
+ for image_index, image_path in enumerate(normalized_image_inputs(images)):
920
+ saved_path, data_url = stage_image_file_for_input(
921
+ image_path,
922
+ workspace_root=resolved_workspace_root,
923
+ image_index=image_index,
924
+ )
925
+ saved_image_paths.append(saved_path)
926
+ initial_content_parts.extend(image_input_content_parts(data_url, saved_path))
927
+ if saved_image_paths:
928
+ run_prompt = append_saved_image_paths_to_prompt(prompt, saved_image_paths)
929
+ return self._run_session(
930
+ run_prompt,
931
+ workspace_root=str(resolved_workspace_root),
932
+ initial_content_parts=initial_content_parts or None,
933
+ )["result_text"]
934
 
935
  def _run_session(
936
  self,
 
946
  raise ValueError("prompt must be a non-empty string.")
947
 
948
  prompt_text = prompt.strip()
949
+ resolved_workspace_root = normalize_workspace_root(
950
+ workspace_root if workspace_root is not None else self.workspace_root
951
+ )
952
  start_time = time.time()
953
  trace_dir = self.trace_dir
954
  cur_date = today_date()
agent_base/tools/custom.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python function tools for the public ResearchHarness embedding API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import inspect
6
+ import re
7
+ from collections.abc import Callable, Sequence as AbcSequence
8
+ from types import UnionType
9
+ from typing import Any, Literal, Sequence, Union, get_args, get_origin, get_type_hints
10
+
11
+ from agent_base.tools.tooling import ToolBase
12
+
13
+
14
+ TOOL_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_-]{0,63}$")
15
+ CONTEXT_PARAMETER_NAMES = frozenset({"workspace_root", "runtime_deadline", "model_name"})
16
+
17
+
18
+ class FunctionTool(ToolBase):
19
+ """ToolBase adapter for a validated Python function."""
20
+
21
+ def __init__(self, func: Callable[..., Any], *, name: str | None = None, description: str | None = None):
22
+ self.func = func
23
+ self._context_parameters: set[str] = set()
24
+ self.name = _resolve_tool_name(func, name)
25
+ self.description = _resolve_tool_description(func, description)
26
+ self.parameters = _schema_from_signature(func, self._context_parameters)
27
+ super().__init__()
28
+
29
+ def call(self, params: str | dict[str, Any], **kwargs: Any) -> Any:
30
+ parsed = self.parse_json_args(params)
31
+ call_kwargs = dict(parsed)
32
+ for name in self._context_parameters:
33
+ if name in kwargs:
34
+ call_kwargs[name] = kwargs[name]
35
+ return self.func(**call_kwargs)
36
+
37
+
38
+ def tool(
39
+ func: Callable[..., Any] | None = None,
40
+ *,
41
+ name: str | None = None,
42
+ description: str | None = None,
43
+ ) -> Callable[..., Any]:
44
+ """Mark a Python function as a ResearchHarness custom tool.
45
+
46
+ The decorated function remains directly callable. ResearchHarness converts it
47
+ into a ToolBase instance when passed to create_agent(tools=[...]).
48
+ """
49
+
50
+ def decorate(inner: Callable[..., Any]) -> Callable[..., Any]:
51
+ if not callable(inner):
52
+ raise TypeError("@tool can only decorate a callable.")
53
+ setattr(inner, "__researchharness_tool__", {"name": name, "description": description})
54
+ return inner
55
+
56
+ if func is None:
57
+ return decorate
58
+ return decorate(func)
59
+
60
+
61
+ def build_custom_tool_map(custom_tools: Sequence[Any] | None) -> dict[str, ToolBase]:
62
+ """Validate and instantiate user-provided custom tools."""
63
+
64
+ resolved: dict[str, ToolBase] = {}
65
+ for item in custom_tools or []:
66
+ tool_obj = _coerce_custom_tool(item)
67
+ if tool_obj.name in resolved:
68
+ raise ValueError(f"Duplicate custom tool name: {tool_obj.name}")
69
+ resolved[tool_obj.name] = tool_obj
70
+ return resolved
71
+
72
+
73
+ def _coerce_custom_tool(item: Any) -> ToolBase:
74
+ if isinstance(item, ToolBase):
75
+ return item
76
+ if callable(item):
77
+ metadata = getattr(item, "__researchharness_tool__", None)
78
+ if not isinstance(metadata, dict):
79
+ raise ValueError(
80
+ f"Custom tool function {getattr(item, '__name__', item)!r} must be decorated with @researchharness.tool."
81
+ )
82
+ return FunctionTool(
83
+ item,
84
+ name=metadata.get("name"),
85
+ description=metadata.get("description"),
86
+ )
87
+ raise ValueError(f"Custom tool must be a decorated function or ToolBase instance, got {type(item).__name__}.")
88
+
89
+
90
+ def _resolve_tool_name(func: Callable[..., Any], override: str | None) -> str:
91
+ name = str(override or getattr(func, "__name__", "")).strip()
92
+ if not name:
93
+ raise ValueError("Custom tool name must be non-empty.")
94
+ if not TOOL_NAME_RE.fullmatch(name):
95
+ raise ValueError(
96
+ f"Invalid custom tool name {name!r}. Use 1-64 characters: letters, numbers, underscore, or hyphen; start with a letter or underscore."
97
+ )
98
+ return name
99
+
100
+
101
+ def _resolve_tool_description(func: Callable[..., Any], override: str | None) -> str:
102
+ description = str(override or inspect.getdoc(func) or "").strip()
103
+ if not description:
104
+ raise ValueError(f"Custom tool {getattr(func, '__name__', '<callable>')!r} must have a docstring or description.")
105
+ return description
106
+
107
+
108
+ def _schema_from_signature(func: Callable[..., Any], context_parameters: set[str]) -> dict[str, Any]:
109
+ signature = inspect.signature(func)
110
+ try:
111
+ hints = get_type_hints(func)
112
+ except Exception as exc:
113
+ raise ValueError(f"Could not resolve type hints for custom tool {func.__name__}: {exc}") from exc
114
+
115
+ properties: dict[str, Any] = {}
116
+ required: list[str] = []
117
+ for param in signature.parameters.values():
118
+ if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
119
+ raise ValueError(f"Custom tool {func.__name__} may not use *args or **kwargs.")
120
+ if param.kind == inspect.Parameter.POSITIONAL_ONLY:
121
+ raise ValueError(f"Custom tool {func.__name__} may not use positional-only parameters.")
122
+ if param.name in CONTEXT_PARAMETER_NAMES:
123
+ if param.kind is not inspect.Parameter.KEYWORD_ONLY:
124
+ raise ValueError(f"Context parameter {param.name!r} in custom tool {func.__name__} must be keyword-only.")
125
+ context_parameters.add(param.name)
126
+ continue
127
+ if param.name not in hints:
128
+ raise ValueError(f"Custom tool {func.__name__} parameter {param.name!r} must have a type annotation.")
129
+ schema, nullable = _annotation_to_schema(hints[param.name], f"{func.__name__}.{param.name}")
130
+ if param.default is inspect.Parameter.empty and not nullable:
131
+ required.append(param.name)
132
+ elif param.default is not inspect.Parameter.empty:
133
+ schema["default"] = param.default
134
+ properties[param.name] = schema
135
+
136
+ return {
137
+ "type": "object",
138
+ "properties": properties,
139
+ "required": required,
140
+ "additionalProperties": False,
141
+ }
142
+
143
+
144
+ def _annotation_to_schema(annotation: Any, label: str) -> tuple[dict[str, Any], bool]:
145
+ origin = get_origin(annotation)
146
+ args = get_args(annotation)
147
+
148
+ if annotation is Any:
149
+ raise ValueError(f"Custom tool parameter {label} may not use Any; use a concrete JSON-compatible type.")
150
+ if origin in (UnionType, Union):
151
+ non_none = [arg for arg in args if arg is not type(None)]
152
+ if len(non_none) == 1 and len(non_none) != len(args):
153
+ schema, _ = _annotation_to_schema(non_none[0], label)
154
+ return schema, True
155
+ raise ValueError(f"Custom tool parameter {label} uses an unsupported union type.")
156
+ if origin is Literal:
157
+ values = list(args)
158
+ if not values:
159
+ raise ValueError(f"Custom tool parameter {label} uses an empty Literal.")
160
+ value_types = {type(value) for value in values}
161
+ if len(value_types) != 1 or next(iter(value_types)) not in {str, int, float, bool}:
162
+ raise ValueError(f"Custom tool parameter {label} uses unsupported Literal values.")
163
+ schema, _ = _annotation_to_schema(type(values[0]), label)
164
+ schema["enum"] = values
165
+ return schema, False
166
+ if annotation is str:
167
+ return {"type": "string"}, False
168
+ if annotation is int:
169
+ return {"type": "integer"}, False
170
+ if annotation is float:
171
+ return {"type": "number"}, False
172
+ if annotation is bool:
173
+ return {"type": "boolean"}, False
174
+ if annotation is dict:
175
+ return {"type": "object"}, False
176
+ if annotation in (list, tuple):
177
+ return {"type": "array"}, False
178
+ if origin in (list, tuple, Sequence, AbcSequence):
179
+ item_schema: dict[str, Any] = {}
180
+ if args and args[0] is not Ellipsis:
181
+ item_schema, _ = _annotation_to_schema(args[0], label)
182
+ return {"type": "array", "items": item_schema}, False
183
+ if origin is dict:
184
+ key_type = args[0] if args else str
185
+ if key_type is not str:
186
+ raise ValueError(f"Custom tool parameter {label} dict keys must be str.")
187
+ return {"type": "object"}, False
188
+ raise ValueError(f"Custom tool parameter {label} has unsupported type annotation: {annotation!r}")