# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data models for the Code Migration Environment.

An action is a single tool invocation: a ``tool_name`` drawn from
``VALID_TOOL_NAMES`` plus a dict of ``tool_args``. An observation carries
the tool's string output along with reward/done/metadata fields.
"""

from typing import Literal

from openenv.core.env_server.types import Action, Observation
from pydantic import Field, model_validator

# Closed set of tool names an agent may invoke. The member order is kept
# as-is because pydantic emits it verbatim into the JSON-schema enum.
VALID_TOOL_NAMES = Literal[
    "list_dir",
    "search_dir",
    "search_file",
    "view_file",
    "edit_file",
    "replace_all_in_file",
    "revert_last",
    "execute_tests",
    "search_last_log",
    "view_last_log",
]

# Required arguments per tool (tools not listed here have no required args)
# Every tool in VALID_TOOL_NAMES currently has an entry; an empty list means
# the tool takes no mandatory arguments. The list order is also the order in
# which missing arguments are reported by CodeMigrationAction's validator.
_TOOL_REQUIRED_ARGS: dict[str, list[str]] = {
    "list_dir": [],
    "search_dir": ["regex_pattern"],
    "search_file": ["regex_pattern", "file_path"],
    "view_file": ["file_path", "line_no"],
    "edit_file": ["file_path", "start_line", "end_line", "replacement_text"],
    "replace_all_in_file": ["file_path", "regex_pattern", "replacement_string"],
    "revert_last": [],
    "execute_tests": [],
    "search_last_log": ["regex_pattern"],
    "view_last_log": ["line_no"],
}


class CodeMigrationAction(Action):
    """Action for the Code Migration environment — one of 10 tool calls."""

    # Required (``...``); pydantic rejects any value outside VALID_TOOL_NAMES.
    tool_name: VALID_TOOL_NAMES = Field(
        ..., description="Name of the tool to invoke"
    )
    # Free-form mapping; only the *presence* of required keys is checked
    # below — value types are not validated here.
    tool_args: dict = Field(
        default_factory=dict,
        description="Tool-specific arguments as a dictionary",
    )

    @model_validator(mode="after")
    def validate_tool_args(self) -> "CodeMigrationAction":
        """Ensure every argument required by ``tool_name`` is supplied.

        Runs after field validation. Looks up the mandatory argument names
        for ``tool_name`` in ``_TOOL_REQUIRED_ARGS`` (unknown names fall back
        to "nothing required") and checks each one is a key of ``tool_args``.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: naming the tool and listing the absent argument names
                in declaration order. Pydantic re-raises errors from
                validators as a ``ValidationError``.
        """
        supplied = self.tool_args
        expected = _TOOL_REQUIRED_ARGS.get(self.tool_name, [])
        absent = [name for name in expected if name not in supplied]

        # Guard clause: nothing missing means the action is well-formed.
        if not absent:
            return self

        raise ValueError(
            f"Tool '{self.tool_name}' requires arguments: {absent}"
        )


class CodeMigrationObservation(Observation):
    """Observation from the Code Migration environment."""

    # Textual result of the tool call; empty string by default.
    tool_output: str = Field(default="", description="String output from the tool")
    # NOTE(review): reward/done/metadata may redeclare fields of the
    # Observation base class (not visible here) — confirm the narrowed
    # types and defaults below are intended.
    reward: float = Field(default=0.0)
    done: bool = Field(default=False)
    metadata: dict = Field(default_factory=dict)