# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. """ Code Migration Environment — full OpenEnv-compatible RL environment. """ from __future__ import annotations import atexit import math import os import re import tempfile from uuid import uuid4 from openenv.core.env_server.interfaces import Environment from openenv.core.env_server.types import State try: from ..dataset_loader import DatasetLoader, Task from ..docker_sandbox import DockerSandbox from ..models import CodeMigrationAction, CodeMigrationObservation from ..prompt import MIGRATION_SYSTEM_PROMPT from ..repo_manager import RepoManager from ..tool_executor import ToolExecutor except ImportError: from dataset_loader import DatasetLoader, Task from docker_sandbox import DockerSandbox from models import CodeMigrationAction, CodeMigrationObservation from prompt import MIGRATION_SYSTEM_PROMPT from repo_manager import RepoManager from tool_executor import ToolExecutor # ────────────────────────────────────────────────────────────── # Test log parsing — extract pass/fail counts for continuous reward # ────────────────────────────────────────────────────────────── # pytest: "5 passed, 2 failed in 0.12s" _RE_PYTEST_PASSED = re.compile(r"(\d+)\s+passed") _RE_PYTEST_FAILED = re.compile(r"(\d+)\s+failed") _RE_PYTEST_ERROR = re.compile(r"(\d+)\s+error") # unittest: "Ran 7 tests" + "OK" or "FAILED (failures=2, errors=1)" _RE_UNITTEST_RAN = re.compile(r"Ran\s+(\d+)\s+tests?") _RE_UNITTEST_FAIL = re.compile(r"failures=(\d+)") _RE_UNITTEST_ERR = re.compile(r"errors=(\d+)") def _parse_test_pass_rate(log_text: str) -> float | None: """Extract the fraction of tests passing from a test log. Returns a float in [0.0, 1.0] or None if we can't parse. """ if not log_text: return None # Try pytest format first passed_m = _RE_PYTEST_PASSED.search(log_text) failed_m = _RE_PYTEST_FAILED.search(log_text) error_m = _RE_PYTEST_ERROR.search(log_text) if passed_m: passed = int(passed_m.group(1)) failed = int(failed_m.group(1)) if failed_m else 0 errors = int(error_m.group(1)) if error_m else 0 total = passed + failed + errors if total > 0: return passed / total # Try unittest format ran_m = _RE_UNITTEST_RAN.search(log_text) if ran_m: total = int(ran_m.group(1)) if total == 0: return None fail_m = _RE_UNITTEST_FAIL.search(log_text) err_m = _RE_UNITTEST_ERR.search(log_text) failures = int(fail_m.group(1)) if fail_m else 0 errors = int(err_m.group(1)) if err_m else 0 # If "OK" appears and no failures/errors, all passed if "OK" in log_text and failures == 0 and errors == 0: return 1.0 passed = total - failures - errors return max(0.0, passed / total) return None # ────────────────────────────────────────────────────────────── # Reward function # ────────────────────────────────────────────────────────────── def compute_reward( *, tool_name: str, result_output: str, result_patch: str | None, test_exit_code: int | None, test_log: str | None, prev_pass_rate: float | None, curr_pass_rate: float | None, step_count: int, max_steps: int, is_limit_hit: bool, ) -> float: """Compute reward for a single step. Design: - Intermediate steps: ALWAYS positive (encourage exploration) - Success (tests pass): large positive, with efficiency bonus - Terminal failure (hit limit without solving): large negative Intermediate reward scale (per step): 0.10 successful edit 0.08 test run that improved pass rate 0.05 test run (no improvement but informative) 0.04 found useful search matches 0.03 viewed file / gathered info 0.02 any other valid action 0.01 minimum (even failed actions get a tiny positive) Terminal rewards: +1.0 to +2.0 tests pass (higher = fewer steps used) -1.0 hit step/test limit without passing """ # ── Terminal: hit limits without solving ── if is_limit_hit: return -3.0 # ── Terminal: tests pass ── if tool_name == "execute_tests" and test_exit_code == 0: # Big positive: 5.0 at step 1, down to 3.0 at max_steps efficiency = 5.0 - 2.0 * (step_count / max(max_steps, 1)) return max(3.0, efficiency) # ── Intermediate: execute_tests (didn't pass) ── if tool_name == "execute_tests": if curr_pass_rate is not None and prev_pass_rate is not None: delta = curr_pass_rate - prev_pass_rate if delta > 0: # Improved pass rate — good signal return 0.05 + delta * 0.3 # 0.05 to ~0.35 else: # No improvement or regression — still positive but small return 0.02 return 0.03 # ran tests, can't parse rate — still informative # ── Intermediate: successful edit ── if tool_name in ("edit_file", "replace_all_in_file"): if result_patch: return 0.10 # applied a real change return 0.01 # edit refused but still a valid action # ── Intermediate: search found results ── if tool_name in ("search_file", "search_dir"): if "match" in result_output.lower() and "no match" not in result_output.lower(): return 0.04 # found something useful return 0.01 # searched, found nothing — still exploring # ── Intermediate: information gathering ── if tool_name in ("view_file", "view_last_log", "search_last_log"): return 0.03 if tool_name == "list_dir": return 0.02 if tool_name == "revert_last": return 0.02 # ── Fallback: any valid action ── return 0.01 class CodeMigrationEnvironment(Environment): """OpenEnv environment for Python code-migration tasks.""" SUPPORTS_CONCURRENT_SESSIONS: bool = False def __init__( self, dataset_path: str | None = None, max_steps: int = 200, max_test_executions: int = 10, container_timeout: int = 600, container_memory_limit: str = "16g", difficulty_filter: str | None = None, ) -> None: self._loader = DatasetLoader(dataset_path) if difficulty_filter: tasks = self._loader.filter_by_difficulty(difficulty_filter) self._loader._tasks = tasks self._max_steps = max_steps self._max_test_executions = max_test_executions self._repo_manager = RepoManager() self._sandbox = DockerSandbox( timeout=container_timeout, memory_limit=container_memory_limit ) self._tool_executor = ToolExecutor() # Episode state self._current_task: Task | None = None self._workspace_dir: str | None = None self._image_name: str | None = None self._step_count: int = 0 self._patch_history: list[tuple[str, str]] = [] self._last_log_path: str | None = None self._last_test_exit_code: int | None = None self._last_pass_rate: float | None = None # continuous test pass rate self._num_test_executions: int = 0 self._done: bool = False self._task_index: int = 0 self._state = State(episode_id=str(uuid4()), step_count=0) atexit.register(self._atexit_cleanup) # ------------------------------------------------------------------ # reset # ------------------------------------------------------------------ def reset( self, *, task_index: int | None = None, repo_name: str | None = None, ) -> CodeMigrationObservation: """Prepare a fresh workspace for the next (or specified) task.""" self._cleanup_episode() # Select task try: if repo_name is not None: task = self._loader.get_by_repo_name(repo_name) if task is None: return CodeMigrationObservation( tool_output=f"No task found for repo_name: {repo_name}", done=True, ) elif task_index is not None: task = self._loader[task_index] else: task = self._loader[self._task_index] self._task_index = (self._task_index + 1) % len(self._loader) self._current_task = task except Exception as e: return CodeMigrationObservation( tool_output=f"Failed to select task: {e}", done=True, ) # Setup workspace try: self._workspace_dir = self._repo_manager.setup_workspace(task) except Exception as e: return CodeMigrationObservation( tool_output=f"Failed to setup workspace: {e}", done=True, ) # Derive image name escaped_name = task.repo_name.replace("/", "__").lower() self._image_name = escaped_name + "_new" # Create temp log file tmp = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".log") self._last_log_path = tmp.name tmp.close() # Reset episode state self._step_count = 0 self._patch_history = [] self._last_test_exit_code = None self._last_pass_rate = None self._num_test_executions = 0 self._done = False self._state = State(episode_id=str(uuid4()), step_count=0) # Initial test run try: test_result = self._sandbox.run_tests(self._image_name, self._workspace_dir) self._num_test_executions += 1 if test_result.full_log is not None: with open(self._last_log_path, "w", newline="") as f: f.write(test_result.full_log) self._last_test_exit_code = test_result.exit_code self._last_pass_rate = _parse_test_pass_rate(test_result.full_log or "") initial_test_output = ( "Test execution completed. Here is the test log.\n\n" f"\n{test_result.truncated_log}\n" ) except Exception as e: initial_test_output = f"Initial test execution failed: {e}" # Build system prompt system_prompt = MIGRATION_SYSTEM_PROMPT.strip().format( python_version=task.migration_target_version, dependency_versions=task.dependency_versions, ) combined_output = system_prompt + "\n\n" + initial_test_output return CodeMigrationObservation( tool_output=combined_output, reward=0.0, done=False, metadata=self._build_metadata("reset"), ) # ------------------------------------------------------------------ # step # ------------------------------------------------------------------ def step(self, action: CodeMigrationAction) -> CodeMigrationObservation: # type: ignore[override] """Execute one tool call and return the observation.""" if self._done: return CodeMigrationObservation( tool_output="Episode is already done. Call reset() to start a new episode.", reward=0.0, done=True, metadata=self._build_metadata(action.tool_name), ) self._step_count += 1 self._state.step_count = self._step_count # Check step limit if self._step_count > self._max_steps: self._done = True reward = compute_reward( tool_name=action.tool_name, result_output="", result_patch=None, test_exit_code=None, test_log=None, prev_pass_rate=self._last_pass_rate, curr_pass_rate=self._last_pass_rate, step_count=self._step_count, max_steps=self._max_steps, is_limit_hit=True, ) return CodeMigrationObservation( tool_output=f"Step limit ({self._max_steps}) reached.", reward=reward, done=True, metadata=self._build_metadata(action.tool_name), ) # Check test execution limit if action.tool_name == "execute_tests": if self._num_test_executions >= self._max_test_executions: self._done = True reward = compute_reward( tool_name=action.tool_name, result_output="", result_patch=None, test_exit_code=None, test_log=None, prev_pass_rate=self._last_pass_rate, curr_pass_rate=self._last_pass_rate, step_count=self._step_count, max_steps=self._max_steps, is_limit_hit=True, ) return CodeMigrationObservation( tool_output=f"Test execution limit ({self._max_test_executions}) reached.", reward=reward, done=True, metadata=self._build_metadata(action.tool_name), ) # Determine last_patch for revert last_patch = self._patch_history[-1] if self._patch_history else None # Dispatch tool test_files = [ tf.strip() for tf in (self._current_task.test_files or "").split(",") if tf.strip() ] result = self._tool_executor.execute( tool_name=action.tool_name, tool_args=action.tool_args, host_repo_dir=self._workspace_dir, repo_name=self._current_task.repo_name, test_files=test_files, image_name=self._image_name, last_log_path=self._last_log_path, last_patch=last_patch, sandbox=self._sandbox, ) # Track patches if action.tool_name in ("edit_file", "replace_all_in_file") and result.patch: file_path = action.tool_args.get("file_path", "") self._patch_history.append((file_path, result.patch)) self._patch_history = self._patch_history[-5:] # Handle revert if action.tool_name == "revert_last" and last_patch is not None: if "succeeded" in result.output.lower(): if self._patch_history: self._patch_history.pop() # Handle execute_tests — update state and parse pass rate curr_pass_rate = self._last_pass_rate if action.tool_name == "execute_tests": self._num_test_executions += 1 if result.full_log is not None and self._last_log_path: with open(self._last_log_path, "w", newline="") as f: f.write(result.full_log) self._last_test_exit_code = result.exit_code curr_pass_rate = _parse_test_pass_rate(result.full_log or "") # Compute continuous reward reward = compute_reward( tool_name=action.tool_name, result_output=result.output, result_patch=result.patch, test_exit_code=result.exit_code if action.tool_name == "execute_tests" else None, test_log=result.full_log if action.tool_name == "execute_tests" else None, prev_pass_rate=self._last_pass_rate, curr_pass_rate=curr_pass_rate, step_count=self._step_count, max_steps=self._max_steps, is_limit_hit=False, ) # Update pass rate after reward computation if action.tool_name == "execute_tests" and curr_pass_rate is not None: self._last_pass_rate = curr_pass_rate # Check if tests passed if action.tool_name == "execute_tests" and result.exit_code == 0: self._done = True # Check step limit if self._step_count >= self._max_steps: self._done = True metadata = self._build_metadata(action.tool_name) metadata["pass_rate"] = curr_pass_rate metadata["prev_pass_rate"] = self._last_pass_rate return CodeMigrationObservation( tool_output=result.output, reward=round(reward, 4), done=self._done, metadata=metadata, ) # ------------------------------------------------------------------ # state # ------------------------------------------------------------------ @property def state(self) -> State: meta: dict = {} if self._current_task: meta.update({ "repo_name": self._current_task.repo_name, "difficulty": self._current_task.difficulty, "test_type": self._current_task.test_type, "test_count": self._current_task.test_count, "num_test_executions": self._num_test_executions, "last_test_exit_code": self._last_test_exit_code, "last_pass_rate": self._last_pass_rate, "migration_target_version": self._current_task.migration_target_version, "reproduction_target_version": self._current_task.reproduction_target_version, }) self._state.metadata = meta return self._state # ------------------------------------------------------------------ # helpers # ------------------------------------------------------------------ def _build_metadata(self, tool_name: str) -> dict: meta: dict = { "step_count": self._step_count, "tool_name": tool_name, "last_test_exit_code": self._last_test_exit_code, "num_test_executions": self._num_test_executions, } if self._current_task: meta["repo_name"] = self._current_task.repo_name meta["difficulty"] = self._current_task.difficulty return meta def _cleanup_episode(self) -> None: if self._workspace_dir: self._repo_manager.cleanup(self._workspace_dir) self._workspace_dir = None if self._last_log_path and os.path.exists(self._last_log_path): try: os.remove(self._last_log_path) except Exception: pass self._last_log_path = None def _atexit_cleanup(self) -> None: try: self._cleanup_episode() except Exception: pass