amrithanandini's picture
integrated backend and frontend
1b35d41
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Data models for the Code Migration Environment.
"""
from typing import Literal
from openenv.core.env_server.types import Action, Observation
from pydantic import Field, model_validator
# Closed set of tool identifiers accepted in ``CodeMigrationAction.tool_name``.
# Every name here also has an entry in ``_TOOL_REQUIRED_ARGS``.
VALID_TOOL_NAMES = Literal[
    # directory / file inspection
    "list_dir",
    "search_dir",
    "search_file",
    "view_file",
    # file modification
    "edit_file",
    "replace_all_in_file",
    "revert_last",
    # test execution and log inspection
    "execute_tests",
    "search_last_log",
    "view_last_log",
]
# Maps each tool name to the argument keys that must be present in
# ``tool_args``. An empty list means the tool has no mandatory arguments;
# a tool name missing from this mapping is treated the same way by the
# validator (it falls back to an empty list).
_TOOL_REQUIRED_ARGS: dict[str, list[str]] = {
    "list_dir": [],
    "search_dir": ["regex_pattern"],
    "search_file": ["regex_pattern", "file_path"],
    "view_file": ["file_path", "line_no"],
    "edit_file": [
        "file_path",
        "start_line",
        "end_line",
        "replacement_text",
    ],
    "replace_all_in_file": [
        "file_path",
        "regex_pattern",
        "replacement_string",
    ],
    "revert_last": [],
    "execute_tests": [],
    "search_last_log": ["regex_pattern"],
    "view_last_log": ["line_no"],
}
class CodeMigrationAction(Action):
    """A single tool invocation for the Code Migration environment.

    ``tool_name`` must be one of the ten names in ``VALID_TOOL_NAMES``;
    ``tool_args`` carries that tool's arguments and defaults to an empty dict.
    """

    tool_name: VALID_TOOL_NAMES = Field(
        ...,
        description="Name of the tool to invoke",
    )
    tool_args: dict = Field(
        default_factory=dict,
        description="Tool-specific arguments as a dictionary",
    )

    @model_validator(mode="after")
    def validate_tool_args(self) -> "CodeMigrationAction":
        """Reject the action when a required argument key is absent.

        Only key presence in ``tool_args`` is checked; values are not
        inspected. Tools without an entry in ``_TOOL_REQUIRED_ARGS`` are
        treated as having no required arguments.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If one or more required keys are missing; the message
                lists the missing keys in their declared order.
        """
        supplied = self.tool_args
        missing = [
            key
            for key in _TOOL_REQUIRED_ARGS.get(self.tool_name, [])
            if key not in supplied
        ]
        # Guard clause: nothing missing means the action is valid as-is.
        if not missing:
            return self
        raise ValueError(
            f"Tool '{self.tool_name}' requires arguments: {missing}"
        )
class CodeMigrationObservation(Observation):
    """Observation emitted by the Code Migration environment.

    Every field has a default, so an instance can be built with no arguments.
    """

    # Textual result of the tool call; empty string when not supplied.
    tool_output: str = Field(
        default="",
        description="String output from the tool",
    )
    # NOTE(review): reward/done/metadata presumably mirror fields on the base
    # Observation type — confirm against openenv before changing defaults.
    reward: float = Field(default=0.0)
    done: bool = Field(default=False)
    # default_factory gives each instance its own fresh dict.
    metadata: dict = Field(default_factory=dict)