# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """Dispatch CodeMigrationAction to the corresponding plain tool function.""" from __future__ import annotations from dataclasses import dataclass, field from typing import Optional from . import tools as T from .docker_sandbox import DockerSandbox @dataclass class ToolResult: """Result returned by ToolExecutor.execute().""" output: str = "" patch: str | None = None full_log: str | None = None exit_code: int | None = None # Map tool_name → callable _TOOL_FUNCTIONS = { "list_dir": T.list_dir, "search_dir": T.search_dir, "search_file": T.search_file, "view_file": T.view_file, "edit_file": T.edit_file, "replace_all_in_file": T.replace_all_in_file, "revert_last": T.revert_last, "execute_tests": T.execute_tests, "search_last_log": T.search_last_log, "view_last_log": T.view_last_log, } class ToolExecutor: """Dispatch a tool call, injecting environment-managed arguments.""" def execute( self, tool_name: str, tool_args: dict, *, host_repo_dir: str, repo_name: str, test_files: list[str], image_name: str, last_log_path: str, last_patch: tuple[str, str] | None, sandbox: DockerSandbox, ) -> ToolResult: """Dispatch to the appropriate tool function with injected args.""" try: fn = _TOOL_FUNCTIONS.get(tool_name) if fn is None: return ToolResult( output=f"Unknown tool: {tool_name}. Valid tools: {list(_TOOL_FUNCTIONS.keys())}" ) # Build the full kwargs: start with user-provided args, then inject kwargs = dict(tool_args) # Inject per the action→tool argument mapping table if tool_name == "list_dir": kwargs["host_repo_dir"] = host_repo_dir kwargs["repo_name"] = repo_name elif tool_name in ("search_dir",): kwargs["host_repo_dir"] = host_repo_dir elif tool_name in ("search_file", "view_file"): kwargs["host_repo_dir"] = host_repo_dir elif tool_name in ("edit_file", "replace_all_in_file"): kwargs["host_repo_dir"] = host_repo_dir kwargs["test_files"] = test_files elif tool_name == "revert_last": kwargs["last_patch"] = last_patch kwargs["host_repo_dir"] = host_repo_dir elif tool_name == "execute_tests": kwargs["host_repo_dir"] = host_repo_dir kwargs["image_name"] = image_name kwargs["sec_timeout"] = sandbox.timeout kwargs["mem_limit"] = sandbox.memory_limit elif tool_name in ("search_last_log", "view_last_log"): kwargs["last_log_path"] = last_log_path observation = fn(**kwargs) # Handle dict-returning tools (edit_file, replace_all_in_file, execute_tests) if isinstance(observation, dict): if tool_name in ("edit_file", "replace_all_in_file"): return ToolResult( output=observation.get("message", ""), patch=observation.get("patch"), ) elif tool_name == "execute_tests": return ToolResult( output=observation.get("test_result", ""), full_log=observation.get("full_log"), exit_code=observation.get("container_status"), ) # String-returning tools return ToolResult(output=str(observation)) except Exception as e: return ToolResult(output=f"Tool call failed with error: {str(e)}")