Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Dispatch CodeMigrationAction to the corresponding plain tool function.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Optional | |
| from . import tools as T | |
| from .docker_sandbox import DockerSandbox | |
@dataclass
class ToolResult:
    """Result returned by ToolExecutor.execute().

    Declared as a dataclass so it can be built with keyword arguments
    (``ToolResult(output=...)``), which is how ToolExecutor.execute()
    constructs every result. Fields that do not apply to a given tool keep
    their defaults.
    """

    # Text reported for the call: a tool's message / test result, the
    # stringified return value, or an error description.
    output: str = ""
    # "patch" entry returned by the file-editing tools, when present.
    patch: str | None = None
    # "full_log" entry returned by execute_tests, when present.
    full_log: str | None = None
    # "container_status" entry returned by execute_tests, when present.
    exit_code: int | None = None
# Registry of dispatchable tools: tool_name → callable from the tools module.
# Each tool is registered under the same name as its function, so the table
# is derived from the names; a name missing from the module fails at import
# time with AttributeError. Key order matches the listing order below.
_TOOL_FUNCTIONS = {
    name: getattr(T, name)
    for name in (
        "list_dir",
        "search_dir",
        "search_file",
        "view_file",
        "edit_file",
        "replace_all_in_file",
        "revert_last",
        "execute_tests",
        "search_last_log",
        "view_last_log",
    )
}
class ToolExecutor:
    """Dispatch a tool call, injecting environment-managed arguments."""

    def execute(
        self,
        tool_name: str,
        tool_args: dict,
        *,
        host_repo_dir: str,
        repo_name: str,
        test_files: list[str],
        image_name: str,
        last_log_path: str,
        last_patch: tuple[str, str] | None,
        sandbox: DockerSandbox,
    ) -> ToolResult:
        """Look up ``tool_name``, call it, and wrap the outcome in a ToolResult.

        The call receives ``tool_args`` plus the environment-managed
        arguments for that tool; an injected value replaces a caller-supplied
        entry with the same key.

        Args:
            tool_name: Key into ``_TOOL_FUNCTIONS``.
            tool_args: Caller-supplied arguments; copied, never mutated.
            host_repo_dir: Injected for every tool except the two log tools.
            repo_name: Injected for ``list_dir``.
            test_files: Injected for ``edit_file`` / ``replace_all_in_file``.
            image_name: Injected for ``execute_tests``.
            last_log_path: Injected for ``search_last_log`` / ``view_last_log``.
            last_patch: Injected for ``revert_last``.
            sandbox: Read only for ``execute_tests``; supplies ``timeout``
                (as ``sec_timeout``) and ``memory_limit`` (as ``mem_limit``).

        Returns:
            A ToolResult. Unknown tools and any exception raised while
            preparing or running the call are reported through ``output``
            rather than raised.
        """
        try:
            tool_fn = _TOOL_FUNCTIONS.get(tool_name)
            if tool_fn is None:
                return ToolResult(
                    output=f"Unknown tool: {tool_name}. Valid tools: {list(_TOOL_FUNCTIONS.keys())}"
                )

            # Copy first so the caller's dict is untouched, then overlay the
            # environment-managed values for this particular tool.
            call_args = dict(tool_args)
            if tool_name == "list_dir":
                call_args.update(host_repo_dir=host_repo_dir, repo_name=repo_name)
            elif tool_name in ("search_dir", "search_file", "view_file"):
                call_args.update(host_repo_dir=host_repo_dir)
            elif tool_name in ("edit_file", "replace_all_in_file"):
                call_args.update(host_repo_dir=host_repo_dir, test_files=test_files)
            elif tool_name == "revert_last":
                call_args.update(last_patch=last_patch, host_repo_dir=host_repo_dir)
            elif tool_name == "execute_tests":
                # The sandbox is consulted only here, so other tools never
                # touch its attributes.
                call_args.update(
                    host_repo_dir=host_repo_dir,
                    image_name=image_name,
                    sec_timeout=sandbox.timeout,
                    mem_limit=sandbox.memory_limit,
                )
            elif tool_name in ("search_last_log", "view_last_log"):
                call_args.update(last_log_path=last_log_path)

            observation = tool_fn(**call_args)

            # Anything that is not a dict is reported as its string form.
            if not isinstance(observation, dict):
                return ToolResult(output=str(observation))

            if tool_name in ("edit_file", "replace_all_in_file"):
                return ToolResult(
                    output=observation.get("message", ""),
                    patch=observation.get("patch"),
                )
            if tool_name == "execute_tests":
                return ToolResult(
                    output=observation.get("test_result", ""),
                    full_log=observation.get("full_log"),
                    exit_code=observation.get("container_status"),
                )

            # A dict from any other tool is stringified like a plain value.
            return ToolResult(output=str(observation))
        except Exception as exc:
            # Failures become an observation for the caller instead of
            # propagating out of the dispatcher.
            return ToolResult(output=f"Tool call failed with error: {str(exc)}")