File size: 4,015 Bytes
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Dispatch CodeMigrationAction to the corresponding plain tool function."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

from . import tools as T
from .docker_sandbox import DockerSandbox


@dataclass
class ToolResult:
    """Outcome of one ToolExecutor.execute() call.

    ``output`` is always a string. The other fields are filled in only by
    the tools that produce them and otherwise remain ``None``.
    """

    # Text observation handed back to the caller.
    output: str = ""
    # Patch text returned by the edit-style tools, when present.
    patch: Optional[str] = None
    # Full test log returned by execute_tests, when present.
    full_log: Optional[str] = None
    # Container status returned by execute_tests, when present.
    exit_code: Optional[int] = None


# Registry of dispatchable tools: tool name -> implementation in ``tools``.
# Every key is also the attribute name of the implementing function in that
# module, so the mapping is built by attribute lookup.
_TOOL_FUNCTIONS = {
    name: getattr(T, name)
    for name in (
        "list_dir",
        "search_dir",
        "search_file",
        "view_file",
        "edit_file",
        "replace_all_in_file",
        "revert_last",
        "execute_tests",
        "search_last_log",
        "view_last_log",
    )
}


class ToolExecutor:
    """Dispatch a tool call, injecting environment-managed arguments.

    The caller supplies only tool-specific arguments. Values describing the
    environment (repository path, test files, sandbox limits, log path, ...)
    are written into the call afterwards, so they always take precedence
    over anything the caller passed under the same name.
    """

    # Tools whose implementation receives the host-side repository path.
    _HOST_REPO_DIR_TOOLS = frozenset(
        {
            "list_dir",
            "search_dir",
            "search_file",
            "view_file",
            "edit_file",
            "replace_all_in_file",
            "revert_last",
            "execute_tests",
        }
    )
    # Tools that also receive the list of test files and return a patch.
    _EDIT_TOOLS = frozenset({"edit_file", "replace_all_in_file"})
    # Tools that read the log written by the most recent test run.
    _LAST_LOG_TOOLS = frozenset({"search_last_log", "view_last_log"})

    def execute(
        self,
        tool_name: str,
        tool_args: dict | None,
        *,
        host_repo_dir: str,
        repo_name: str,
        test_files: list[str],
        image_name: str,
        last_log_path: str,
        last_patch: tuple[str, str] | None,
        sandbox: DockerSandbox,
    ) -> ToolResult:
        """Dispatch to the appropriate tool function with injected args.

        Args:
            tool_name: Key into ``_TOOL_FUNCTIONS``.
            tool_args: Caller-chosen keyword arguments for the tool. ``None``
                is treated as "no arguments".
            host_repo_dir: Repository path injected into every tool except
                the last-log tools.
            repo_name: Injected into ``list_dir`` only.
            test_files: Injected into the edit tools only.
            image_name: Injected into ``execute_tests`` only.
            last_log_path: Injected into the last-log tools only.
            last_patch: Injected into ``revert_last`` only.
            sandbox: Supplies ``sec_timeout`` (``sandbox.timeout``) and
                ``mem_limit`` (``sandbox.memory_limit``) for ``execute_tests``.

        Returns:
            A ToolResult. Unknown tools and any exception raised while
            building the call or running the tool are reported in ``output``
            instead of being propagated.
        """
        try:
            fn = _TOOL_FUNCTIONS.get(tool_name)
            if fn is None:
                return ToolResult(
                    output=f"Unknown tool: {tool_name}. Valid tools: {list(_TOOL_FUNCTIONS)}"
                )

            # Start from the caller's args (None means none were given); the
            # injected values below overwrite any key the caller also passed.
            kwargs = dict(tool_args or {})

            if tool_name in self._HOST_REPO_DIR_TOOLS:
                kwargs["host_repo_dir"] = host_repo_dir

            if tool_name == "list_dir":
                kwargs["repo_name"] = repo_name
            elif tool_name in self._EDIT_TOOLS:
                kwargs["test_files"] = test_files
            elif tool_name == "revert_last":
                kwargs["last_patch"] = last_patch
            elif tool_name == "execute_tests":
                kwargs["image_name"] = image_name
                # Read lazily so other tools never touch the sandbox object.
                kwargs["sec_timeout"] = sandbox.timeout
                kwargs["mem_limit"] = sandbox.memory_limit
            elif tool_name in self._LAST_LOG_TOOLS:
                kwargs["last_log_path"] = last_log_path

            return self._to_result(tool_name, fn(**kwargs))

        except Exception as e:
            # Deliberate catch-all: failures are returned as an observation
            # rather than raised. Some exceptions stringify to "", so fall
            # back to the type name to keep the message informative.
            detail = str(e) or type(e).__name__
            return ToolResult(output=f"Tool call failed with error: {detail}")

    def _to_result(self, tool_name: str, observation: object) -> ToolResult:
        """Convert a raw tool return value into a ToolResult.

        Dicts from the edit tools and ``execute_tests`` are unpacked into the
        matching ToolResult fields; anything else is stringified into
        ``output``.
        """
        if isinstance(observation, dict):
            if tool_name in self._EDIT_TOOLS:
                return ToolResult(
                    output=observation.get("message", ""),
                    patch=observation.get("patch"),
                )
            if tool_name == "execute_tests":
                return ToolResult(
                    output=observation.get("test_result", ""),
                    full_log=observation.get("full_log"),
                    exit_code=observation.get("container_status"),
                )

        # String-returning tools (and any unexpected dict) end up here.
        return ToolResult(output=str(observation))