File size: 2,323 Bytes
1b35d41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Code Migration Environment.
"""

from typing import Literal

from openenv.core.env_server.types import Action, Observation
from pydantic import Field, model_validator


# Closed set of tool identifiers; used as the type of
# ``CodeMigrationAction.tool_name`` so any other string fails validation.
VALID_TOOL_NAMES = Literal[
    "list_dir", "search_dir", "search_file", "view_file",
    "edit_file", "replace_all_in_file", "revert_last",
    "execute_tests", "search_last_log", "view_last_log",
]

# Names of the arguments each tool must receive in ``tool_args``.
# An empty list means the tool can be called without any arguments; the
# action validator also treats a tool missing from this table that way.
_TOOL_REQUIRED_ARGS: dict[str, list[str]] = dict(
    list_dir=[],
    search_dir=["regex_pattern"],
    search_file=["regex_pattern", "file_path"],
    view_file=["file_path", "line_no"],
    edit_file=["file_path", "start_line", "end_line", "replacement_text"],
    replace_all_in_file=["file_path", "regex_pattern", "replacement_string"],
    revert_last=[],
    execute_tests=[],
    search_last_log=["regex_pattern"],
    view_last_log=["line_no"],
)


class CodeMigrationAction(Action):
    """Action for the Code Migration environment — one of 10 tool calls."""

    # Required field: must be one of the names in VALID_TOOL_NAMES.
    tool_name: VALID_TOOL_NAMES = Field(..., description="Name of the tool to invoke")
    # Optional field: each instance gets its own empty dict when omitted.
    tool_args: dict = Field(
        description="Tool-specific arguments as a dictionary",
        default_factory=dict,
    )

    @model_validator(mode="after")
    def validate_tool_args(self) -> "CodeMigrationAction":
        """Check that ``tool_args`` carries every argument the tool requires.

        The required names are looked up in ``_TOOL_REQUIRED_ARGS`` by
        ``tool_name``; a tool with no entry is treated as requiring nothing.
        Only key presence is checked, argument values are not inspected.

        Returns:
            This same instance when no required argument is absent.

        Raises:
            ValueError: If one or more required argument names are not keys
                of ``tool_args``; the message lists the absent names.
        """
        supplied = self.tool_args
        missing = [
            name
            for name in _TOOL_REQUIRED_ARGS.get(self.tool_name, [])
            if name not in supplied
        ]
        if not missing:
            return self
        raise ValueError(
            f"Tool '{self.tool_name}' requires arguments: {missing}"
        )


class CodeMigrationObservation(Observation):
    """Observation from the Code Migration environment."""

    # Text returned by the tool call; empty string when not supplied.
    tool_output: str = Field("", description="String output from the tool")
    # Remaining fields fall back to 0.0, False and a fresh empty dict.
    reward: float = Field(0.0)
    done: bool = Field(False)
    metadata: dict = Field(default_factory=dict)