Spaces:
Sleeping
Sleeping
File size: 4,015 Bytes
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Dispatch CodeMigrationAction to the corresponding plain tool function."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional
from . import tools as T
from .docker_sandbox import DockerSandbox
@dataclass
class ToolResult:
    """Outcome of a single ``ToolExecutor.execute()`` call.

    ``output`` is always set (possibly to an error message). The remaining
    fields default to ``None`` and are only filled in by the tools that
    produce them.
    """

    # Observation text handed back to the caller.
    output: str = field(default="")
    # Value of the ``"patch"`` key returned by edit_file / replace_all_in_file.
    patch: Optional[str] = field(default=None)
    # Value of the ``"full_log"`` key returned by execute_tests.
    full_log: Optional[str] = field(default=None)
    # Value of the ``"container_status"`` key returned by execute_tests.
    exit_code: Optional[int] = field(default=None)
# Registry of dispatchable tools: tool_name -> plain function in ``tools``.
# Every tool is exposed under the same name it has in the ``tools`` module,
# so the registry is built from the names alone. Insertion order matters:
# it is the order listed in the "Unknown tool" error message.
_TOOL_FUNCTIONS = {
    name: getattr(T, name)
    for name in (
        "list_dir",
        "search_dir",
        "search_file",
        "view_file",
        "edit_file",
        "replace_all_in_file",
        "revert_last",
        "execute_tests",
        "search_last_log",
        "view_last_log",
    )
}
class ToolExecutor:
    """Dispatch a tool call, injecting environment-managed arguments.

    The caller (typically a model) supplies ``tool_args``; the environment
    supplies everything else (repository location, test files, sandbox
    limits, last log / last patch). ``execute`` merges the two, calls the
    matching entry of ``_TOOL_FUNCTIONS`` and wraps the return value in a
    ``ToolResult``.
    """

    def execute(
        self,
        tool_name: str,
        tool_args: dict | None,
        *,
        host_repo_dir: str,
        repo_name: str,
        test_files: list[str],
        image_name: str,
        last_log_path: str,
        last_patch: tuple[str, str] | None,
        sandbox: DockerSandbox,
    ) -> ToolResult:
        """Dispatch to the appropriate tool function with injected args.

        Args:
            tool_name: Key into ``_TOOL_FUNCTIONS``.
            tool_args: Caller-chosen keyword arguments for the tool. ``None``
                is treated as "no arguments".
            host_repo_dir: Injected for every currently registered tool
                except ``search_last_log`` / ``view_last_log``.
            repo_name: Injected for ``list_dir`` only.
            test_files: Injected for ``edit_file`` / ``replace_all_in_file``.
            image_name: Injected for ``execute_tests``.
            last_log_path: Injected for ``search_last_log`` / ``view_last_log``.
            last_patch: Injected for ``revert_last``.
            sandbox: Its ``timeout`` and ``memory_limit`` are injected for
                ``execute_tests`` as ``sec_timeout`` / ``mem_limit``.

        Returns:
            A ``ToolResult``. Unknown tool names, and any ``Exception`` raised
            while building arguments, running the tool or converting its
            result, are reported in ``ToolResult.output`` instead of raised.
        """
        try:
            fn = _TOOL_FUNCTIONS.get(tool_name)
            if fn is None:
                return ToolResult(
                    output=f"Unknown tool: {tool_name}. Valid tools: {list(_TOOL_FUNCTIONS.keys())}"
                )

            # Copy so the caller's dict is never mutated; a missing payload
            # (None) is valid for tools whose arguments are all injected.
            kwargs = {} if tool_args is None else dict(tool_args)

            # Environment-managed values are written after the caller's args,
            # so they override any same-named key the caller supplied.
            if tool_name == "list_dir":
                kwargs["host_repo_dir"] = host_repo_dir
                kwargs["repo_name"] = repo_name
            elif tool_name in ("search_dir", "search_file", "view_file"):
                kwargs["host_repo_dir"] = host_repo_dir
            elif tool_name in ("edit_file", "replace_all_in_file"):
                kwargs["host_repo_dir"] = host_repo_dir
                kwargs["test_files"] = test_files
            elif tool_name == "revert_last":
                kwargs["last_patch"] = last_patch
                kwargs["host_repo_dir"] = host_repo_dir
            elif tool_name == "execute_tests":
                kwargs["host_repo_dir"] = host_repo_dir
                kwargs["image_name"] = image_name
                kwargs["sec_timeout"] = sandbox.timeout
                kwargs["mem_limit"] = sandbox.memory_limit
            elif tool_name in ("search_last_log", "view_last_log"):
                kwargs["last_log_path"] = last_log_path

            return self._to_tool_result(tool_name, fn(**kwargs))
        except Exception as e:
            # Boundary catch: a failing tool becomes an observation rather
            # than crashing the caller. Fall back to the exception type when
            # the message is empty (e.g. ``raise TimeoutError()``) so the
            # observation is never blank.
            detail = str(e) or type(e).__name__
            return ToolResult(output=f"Tool call failed with error: {detail}")

    @staticmethod
    def _to_tool_result(tool_name: str, observation: object) -> ToolResult:
        """Normalise a tool's raw return value into a ``ToolResult``."""
        # Dict-returning tools carry structured fields beyond the text output.
        if isinstance(observation, dict):
            if tool_name in ("edit_file", "replace_all_in_file"):
                return ToolResult(
                    output=observation.get("message", ""),
                    patch=observation.get("patch"),
                )
            if tool_name == "execute_tests":
                return ToolResult(
                    output=observation.get("test_result", ""),
                    full_log=observation.get("full_log"),
                    exit_code=observation.get("container_status"),
                )
        # String-returning tools (and any other shape) are stringified.
        return ToolResult(output=str(observation))
|