Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Code Migration Environment — full OpenEnv-compatible RL environment. | |
| """ | |
| from __future__ import annotations | |
| import atexit | |
| import math | |
| import os | |
| import re | |
| import tempfile | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..dataset_loader import DatasetLoader, Task | |
| from ..docker_sandbox import DockerSandbox | |
| from ..models import CodeMigrationAction, CodeMigrationObservation | |
| from ..prompt import MIGRATION_SYSTEM_PROMPT | |
| from ..repo_manager import RepoManager | |
| from ..tool_executor import ToolExecutor | |
| except ImportError: | |
| from dataset_loader import DatasetLoader, Task | |
| from docker_sandbox import DockerSandbox | |
| from models import CodeMigrationAction, CodeMigrationObservation | |
| from prompt import MIGRATION_SYSTEM_PROMPT | |
| from repo_manager import RepoManager | |
| from tool_executor import ToolExecutor | |
| # ────────────────────────────────────────────────────────────── | |
| # Test log parsing — extract pass/fail counts for continuous reward | |
| # ────────────────────────────────────────────────────────────── | |
| # pytest: "5 passed, 2 failed in 0.12s" | |
| _RE_PYTEST_PASSED = re.compile(r"(\d+)\s+passed") | |
| _RE_PYTEST_FAILED = re.compile(r"(\d+)\s+failed") | |
| _RE_PYTEST_ERROR = re.compile(r"(\d+)\s+error") | |
| # unittest: "Ran 7 tests" + "OK" or "FAILED (failures=2, errors=1)" | |
| _RE_UNITTEST_RAN = re.compile(r"Ran\s+(\d+)\s+tests?") | |
| _RE_UNITTEST_FAIL = re.compile(r"failures=(\d+)") | |
| _RE_UNITTEST_ERR = re.compile(r"errors=(\d+)") | |
| def _parse_test_pass_rate(log_text: str) -> float | None: | |
| """Extract the fraction of tests passing from a test log. | |
| Returns a float in [0.0, 1.0] or None if we can't parse. | |
| """ | |
| if not log_text: | |
| return None | |
| # Try pytest format first | |
| passed_m = _RE_PYTEST_PASSED.search(log_text) | |
| failed_m = _RE_PYTEST_FAILED.search(log_text) | |
| error_m = _RE_PYTEST_ERROR.search(log_text) | |
| if passed_m: | |
| passed = int(passed_m.group(1)) | |
| failed = int(failed_m.group(1)) if failed_m else 0 | |
| errors = int(error_m.group(1)) if error_m else 0 | |
| total = passed + failed + errors | |
| if total > 0: | |
| return passed / total | |
| # Try unittest format | |
| ran_m = _RE_UNITTEST_RAN.search(log_text) | |
| if ran_m: | |
| total = int(ran_m.group(1)) | |
| if total == 0: | |
| return None | |
| fail_m = _RE_UNITTEST_FAIL.search(log_text) | |
| err_m = _RE_UNITTEST_ERR.search(log_text) | |
| failures = int(fail_m.group(1)) if fail_m else 0 | |
| errors = int(err_m.group(1)) if err_m else 0 | |
| # If "OK" appears and no failures/errors, all passed | |
| if "OK" in log_text and failures == 0 and errors == 0: | |
| return 1.0 | |
| passed = total - failures - errors | |
| return max(0.0, passed / total) | |
| return None | |
# ──────────────────────────────────────────────────────────────
# Reward function
# ──────────────────────────────────────────────────────────────
def compute_reward(
    *,
    tool_name: str,
    result_output: str,
    result_patch: str | None,
    test_exit_code: int | None,
    test_log: str | None,
    prev_pass_rate: float | None,
    curr_pass_rate: float | None,
    step_count: int,
    max_steps: int,
    is_limit_hit: bool,
) -> float:
    """Compute the reward for a single environment step.

    Design:
      - Intermediate steps always get a small positive reward (encourage exploration).
      - Passing tests gets a large positive reward that shrinks with steps used.
      - Hitting a step/test limit without solving gets a large negative reward.

    Terminal rewards:
      -3.0          ``is_limit_hit`` is True (checked before anything else)
      +3.0 .. +5.0  ``execute_tests`` with exit code 0:
                    ``5.0 - 2.0 * step_count / max_steps``, floored at 3.0

    Intermediate rewards:
      execute_tests (exit code != 0):
        0.05 + 0.3 * delta   pass rate improved by ``delta > 0`` (at most 0.35)
        0.02                 pass rate unchanged or regressed
        0.03                 either pass rate unknown (log not parseable)
      edit_file / replace_all_in_file:       0.10 with a patch, else 0.01
      search_file / search_dir:              0.04 if the output mentions "match"
                                             but not "no match", else 0.01
      view_file / view_last_log / search_last_log:  0.03
      list_dir / revert_last:                0.02
      any other tool:                        0.01

    ``test_log`` is accepted for interface stability but is not used.
    """
    # ── Terminal: hit limits without solving ──
    if is_limit_hit:
        return -3.0

    if tool_name == "execute_tests":
        # ── Terminal: tests pass ──
        if test_exit_code == 0:
            # 5.0 at step 0, falling linearly to 3.0 at max_steps (floored there).
            efficiency = 5.0 - 2.0 * (step_count / max(max_steps, 1))
            return max(3.0, efficiency)
        # ── Intermediate: tests ran but did not pass ──
        if curr_pass_rate is None or prev_pass_rate is None:
            return 0.03  # can't compare rates, but the run is still informative
        delta = curr_pass_rate - prev_pass_rate
        if delta > 0:
            return 0.05 + delta * 0.3
        return 0.02  # no improvement or a regression: small but positive

    # ── Intermediate: edits ──
    if tool_name in ("edit_file", "replace_all_in_file"):
        # A patch means the edit was actually applied; a refused edit is still valid.
        return 0.10 if result_patch else 0.01

    # ── Intermediate: searches ──
    if tool_name in ("search_file", "search_dir"):
        lowered = result_output.lower()
        found_something = "match" in lowered and "no match" not in lowered
        return 0.04 if found_something else 0.01

    # ── Intermediate: information gathering ──
    if tool_name in ("view_file", "view_last_log", "search_last_log"):
        return 0.03
    if tool_name in ("list_dir", "revert_last"):
        return 0.02

    # ── Fallback: any valid action ──
    return 0.01
class CodeMigrationEnvironment(Environment):
    """OpenEnv environment for Python code-migration tasks.

    Each episode sets up a workspace for one dataset task, runs its tests once,
    and then executes one agent tool call per ``step()`` until the tests pass
    (exit code 0) or a step / test-execution limit is reached. Per-step rewards
    come from ``compute_reward``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(
        self,
        dataset_path: str | None = None,
        max_steps: int = 200,
        max_test_executions: int = 10,
        container_timeout: int = 600,
        container_memory_limit: str = "16g",
        difficulty_filter: str | None = None,
    ) -> None:
        """Create the environment; no episode is active until ``reset()``.

        Args:
            dataset_path: Passed to ``DatasetLoader``.
            max_steps: Maximum number of ``step()`` calls per episode.
            max_test_executions: Maximum number of test runs per episode,
                including the initial run performed by ``reset()``.
            container_timeout: Passed to ``DockerSandbox`` as ``timeout``
                (presumably seconds; confirm in DockerSandbox).
            container_memory_limit: Passed to ``DockerSandbox`` as ``memory_limit``.
            difficulty_filter: If set, only tasks returned by
                ``DatasetLoader.filter_by_difficulty`` are kept.
        """
        self._loader = DatasetLoader(dataset_path)
        if difficulty_filter:
            # NOTE(review): overwrites DatasetLoader's private task list; a public
            # filtering API on the loader would be less fragile.
            self._loader._tasks = self._loader.filter_by_difficulty(difficulty_filter)
        self._max_steps = max_steps
        self._max_test_executions = max_test_executions
        self._repo_manager = RepoManager()
        self._sandbox = DockerSandbox(
            timeout=container_timeout, memory_limit=container_memory_limit
        )
        self._tool_executor = ToolExecutor()

        # Episode state
        self._current_task: Task | None = None
        self._workspace_dir: str | None = None
        self._image_name: str | None = None
        self._step_count: int = 0
        # (file_path, patch) of recently applied edits, newest last, capped at 5.
        self._patch_history: list[tuple[str, str]] = []
        self._last_log_path: str | None = None
        self._last_test_exit_code: int | None = None
        self._last_pass_rate: float | None = None  # continuous test pass rate
        self._num_test_executions: int = 0
        # True whenever no episode is active (before the first successful reset,
        # after a failed reset, or after the episode ended); step() then refuses.
        self._done: bool = True
        self._task_index: int = 0  # round-robin cursor used by reset() without arguments
        self._state = State(episode_id=str(uuid4()), step_count=0)

        atexit.register(self._atexit_cleanup)

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(
        self,
        *,
        task_index: int | None = None,
        repo_name: str | None = None,
    ) -> CodeMigrationObservation:
        """Prepare a fresh workspace for the next (or specified) task.

        Task selection: ``repo_name`` if given, else ``task_index`` if given,
        else the next task in round-robin order.

        Returns:
            An observation holding the formatted system prompt followed by the
            initial test log. If task selection or workspace setup fails, an
            observation with ``done=True`` and an error message is returned and
            the environment stays inactive until the next successful reset.
        """
        self._cleanup_episode()
        # Stay inactive until setup completes, so a failed reset cannot leave
        # step() running against a stale task whose workspace was just deleted.
        self._done = True
        self._current_task = None

        # Select task
        try:
            if repo_name is not None:
                task = self._loader.get_by_repo_name(repo_name)
                if task is None:
                    return CodeMigrationObservation(
                        tool_output=f"No task found for repo_name: {repo_name}",
                        done=True,
                    )
            elif task_index is not None:
                task = self._loader[task_index]
            else:
                task = self._loader[self._task_index]
                self._task_index = (self._task_index + 1) % len(self._loader)
            self._current_task = task
        except Exception as e:
            return CodeMigrationObservation(
                tool_output=f"Failed to select task: {e}", done=True,
            )

        # Setup workspace
        try:
            self._workspace_dir = self._repo_manager.setup_workspace(task)
        except Exception as e:
            return CodeMigrationObservation(
                tool_output=f"Failed to setup workspace: {e}", done=True,
            )

        # Image name derived from the repo: "Owner/Repo" -> "owner__repo_new".
        self._image_name = task.repo_name.replace("/", "__").lower() + "_new"

        # Empty temp file that will hold the most recent full test log.
        with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".log") as tmp:
            self._last_log_path = tmp.name

        # Reset episode state
        self._step_count = 0
        self._patch_history = []
        self._last_test_exit_code = None
        self._last_pass_rate = None
        self._num_test_executions = 0
        self._done = False
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Initial test run (counts toward max_test_executions).
        try:
            test_result = self._sandbox.run_tests(self._image_name, self._workspace_dir)
            self._num_test_executions += 1
            if test_result.full_log is not None:
                self._write_last_log(test_result.full_log)
            self._last_test_exit_code = test_result.exit_code
            self._last_pass_rate = _parse_test_pass_rate(test_result.full_log or "")
            initial_test_output = (
                "Test execution completed. Here is the test log.\n\n"
                f"<test_log>\n{test_result.truncated_log}\n</test_log>"
            )
        except Exception as e:
            initial_test_output = f"Initial test execution failed: {e}"

        # Build system prompt
        system_prompt = MIGRATION_SYSTEM_PROMPT.strip().format(
            python_version=task.migration_target_version,
            dependency_versions=task.dependency_versions,
        )
        combined_output = system_prompt + "\n\n" + initial_test_output
        return CodeMigrationObservation(
            tool_output=combined_output,
            reward=0.0,
            done=False,
            metadata=self._build_metadata("reset"),
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: CodeMigrationAction) -> CodeMigrationObservation:  # type: ignore[override]
        """Execute one tool call and return the observation.

        The episode ends when ``execute_tests`` exits with code 0, when an
        ``execute_tests`` call would exceed ``max_test_executions``, or on the
        ``max_steps``-th step. Ending on a limit without passing tests yields
        the limit-hit reward from ``compute_reward``.
        """
        if self._done:
            return CodeMigrationObservation(
                tool_output="Episode is already done. Call reset() to start a new episode.",
                reward=0.0, done=True,
                metadata=self._build_metadata(action.tool_name),
            )

        self._step_count += 1
        self._state.step_count = self._step_count

        # Defensive: normally unreachable, since the max_steps-th step below
        # already ends the episode.
        if self._step_count > self._max_steps:
            return self._limit_hit_observation(
                action.tool_name, f"Step limit ({self._max_steps}) reached."
            )

        # Check test execution limit
        if (action.tool_name == "execute_tests"
                and self._num_test_executions >= self._max_test_executions):
            return self._limit_hit_observation(
                action.tool_name,
                f"Test execution limit ({self._max_test_executions}) reached.",
            )

        # Newest applied patch, handed to the executor for revert_last.
        last_patch = self._patch_history[-1] if self._patch_history else None

        # Dispatch tool
        test_files = [
            tf.strip()
            for tf in (self._current_task.test_files or "").split(",")
            if tf.strip()
        ]
        result = self._tool_executor.execute(
            tool_name=action.tool_name,
            tool_args=action.tool_args,
            host_repo_dir=self._workspace_dir,
            repo_name=self._current_task.repo_name,
            test_files=test_files,
            image_name=self._image_name,
            last_log_path=self._last_log_path,
            last_patch=last_patch,
            sandbox=self._sandbox,
        )

        # Track patches
        if action.tool_name in ("edit_file", "replace_all_in_file") and result.patch:
            file_path = action.tool_args.get("file_path", "")
            self._patch_history.append((file_path, result.patch))
            self._patch_history = self._patch_history[-5:]

        # A successful revert (detected via its output text) undoes the newest patch.
        if (action.tool_name == "revert_last"
                and last_patch is not None
                and "succeeded" in result.output.lower()
                and self._patch_history):
            self._patch_history.pop()

        # Handle execute_tests — update state and parse pass rate.
        # prev_pass_rate is captured before any update so metadata reports it correctly.
        is_test_run = action.tool_name == "execute_tests"
        prev_pass_rate = self._last_pass_rate
        curr_pass_rate = prev_pass_rate
        if is_test_run:
            self._num_test_executions += 1
            if result.full_log is not None:
                self._write_last_log(result.full_log)
            self._last_test_exit_code = result.exit_code
            curr_pass_rate = _parse_test_pass_rate(result.full_log or "")

        solved = is_test_run and result.exit_code == 0
        # This is the last step the budget allows; finishing it unsolved is a limit hit.
        budget_exhausted = self._step_count >= self._max_steps

        # Compute continuous reward
        reward = compute_reward(
            tool_name=action.tool_name,
            result_output=result.output,
            result_patch=result.patch,
            test_exit_code=result.exit_code if is_test_run else None,
            test_log=result.full_log if is_test_run else None,
            prev_pass_rate=prev_pass_rate,
            curr_pass_rate=curr_pass_rate,
            step_count=self._step_count,
            max_steps=self._max_steps,
            is_limit_hit=budget_exhausted and not solved,
        )

        # Keep the last known rate when this run's log could not be parsed.
        if is_test_run and curr_pass_rate is not None:
            self._last_pass_rate = curr_pass_rate

        if solved or budget_exhausted:
            self._done = True

        metadata = self._build_metadata(action.tool_name)
        metadata["pass_rate"] = curr_pass_rate
        metadata["prev_pass_rate"] = prev_pass_rate
        return CodeMigrationObservation(
            tool_output=result.output,
            reward=round(reward, 4),
            done=self._done,
            metadata=metadata,
        )

    # ------------------------------------------------------------------
    # state
    # ------------------------------------------------------------------
    def state(self) -> State:
        """Return the episode ``State`` with its metadata refreshed from the current task.

        NOTE(review): OpenEnv's ``Environment`` base may declare ``state`` as a
        property; if so this should be decorated with ``@property``. Confirm
        against the installed openenv version and existing callers first.
        """
        meta: dict = {}
        if self._current_task:
            meta.update({
                "repo_name": self._current_task.repo_name,
                "difficulty": self._current_task.difficulty,
                "test_type": self._current_task.test_type,
                "test_count": self._current_task.test_count,
                "num_test_executions": self._num_test_executions,
                "last_test_exit_code": self._last_test_exit_code,
                "last_pass_rate": self._last_pass_rate,
                "migration_target_version": self._current_task.migration_target_version,
                "reproduction_target_version": self._current_task.reproduction_target_version,
            })
        self._state.metadata = meta
        return self._state

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _build_metadata(self, tool_name: str) -> dict:
        """Return the per-observation metadata dict for a call to *tool_name*."""
        meta: dict = {
            "step_count": self._step_count,
            "tool_name": tool_name,
            "last_test_exit_code": self._last_test_exit_code,
            "num_test_executions": self._num_test_executions,
        }
        if self._current_task:
            meta["repo_name"] = self._current_task.repo_name
            meta["difficulty"] = self._current_task.difficulty
        return meta

    def _limit_hit_observation(self, tool_name: str, message: str) -> CodeMigrationObservation:
        """End the episode because a limit was hit and return the penalty observation."""
        self._done = True
        reward = compute_reward(
            tool_name=tool_name, result_output="", result_patch=None,
            test_exit_code=None, test_log=None,
            prev_pass_rate=self._last_pass_rate, curr_pass_rate=self._last_pass_rate,
            step_count=self._step_count, max_steps=self._max_steps,
            is_limit_hit=True,
        )
        return CodeMigrationObservation(
            tool_output=message,
            reward=reward, done=True,
            metadata=self._build_metadata(tool_name),
        )

    def _write_last_log(self, text: str) -> None:
        """Overwrite the temp log file with *text* (no-op if no log file exists)."""
        if not self._last_log_path:
            return
        # newline="" writes the log's line endings through untranslated.
        with open(self._last_log_path, "w", newline="") as f:
            f.write(text)

    def _cleanup_episode(self) -> None:
        """Remove the current workspace and temp log file, if any."""
        if self._workspace_dir:
            self._repo_manager.cleanup(self._workspace_dir)
            self._workspace_dir = None
        if self._last_log_path:
            try:
                os.remove(self._last_log_path)
            except OSError:
                pass  # best-effort: the file may already be gone
            self._last_log_path = None

    def _atexit_cleanup(self) -> None:
        """Best-effort cleanup at interpreter exit; never raises."""
        try:
            self._cleanup_episode()
        except Exception:
            # Interpreter is shutting down; there is nobody left to report to.
            pass