migratron / code_migration /server /code_migration_environment.py
amrithanandini's picture
integrated backend and frontend
1b35d41
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Code Migration Environment — full OpenEnv-compatible RL environment.
"""
from __future__ import annotations
import atexit
import math
import os
import re
import tempfile
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..dataset_loader import DatasetLoader, Task
from ..docker_sandbox import DockerSandbox
from ..models import CodeMigrationAction, CodeMigrationObservation
from ..prompt import MIGRATION_SYSTEM_PROMPT
from ..repo_manager import RepoManager
from ..tool_executor import ToolExecutor
except ImportError:
from dataset_loader import DatasetLoader, Task
from docker_sandbox import DockerSandbox
from models import CodeMigrationAction, CodeMigrationObservation
from prompt import MIGRATION_SYSTEM_PROMPT
from repo_manager import RepoManager
from tool_executor import ToolExecutor
# ──────────────────────────────────────────────────────────────
# Test log parsing — extract pass/fail counts for continuous reward
# ──────────────────────────────────────────────────────────────
# pytest summary line, e.g. "5 passed, 2 failed in 0.12s"
_RE_PYTEST_PASSED = re.compile(r"(\d+)\s+passed")
_RE_PYTEST_FAILED = re.compile(r"(\d+)\s+failed")
_RE_PYTEST_ERROR = re.compile(r"(\d+)\s+error")
# unittest: "Ran 7 tests" followed by "OK" or "FAILED (failures=2, errors=1)"
_RE_UNITTEST_RAN = re.compile(r"Ran\s+(\d+)\s+tests?")
_RE_UNITTEST_FAIL = re.compile(r"failures=(\d+)")
_RE_UNITTEST_ERR = re.compile(r"errors=(\d+)")


def _parse_test_pass_rate(log_text: str) -> float | None:
    """Extract the fraction of passing tests from a test log.

    Formats are tried in this order; the first one that yields a usable
    total wins:

    1. pytest summary containing an ``N passed`` count
       (``passed / (passed + failed + errors)``).
    2. unittest summary (``Ran N tests`` minus ``failures=`` / ``errors=``).
    3. pytest summary with *no* ``passed`` count but a non-zero ``failed`` /
       ``error`` count, i.e. a run where nothing passed -> ``0.0``.

    Each pattern uses the first occurrence found anywhere in the log.

    Args:
        log_text: Raw test-runner output. May be empty.

    Returns:
        A float in ``[0.0, 1.0]``, or ``None`` when the log is empty or no
        recognizable summary with a non-zero test count is found.
    """
    if not log_text:
        return None

    passed_m = _RE_PYTEST_PASSED.search(log_text)
    failed_m = _RE_PYTEST_FAILED.search(log_text)
    error_m = _RE_PYTEST_ERROR.search(log_text)
    pytest_failed = int(failed_m.group(1)) if failed_m else 0
    pytest_errors = int(error_m.group(1)) if error_m else 0

    # 1) pytest summary that reports at least a "passed" count.
    if passed_m:
        passed = int(passed_m.group(1))
        total = passed + pytest_failed + pytest_errors
        if total > 0:
            return passed / total

    # 2) unittest summary.
    ran_m = _RE_UNITTEST_RAN.search(log_text)
    if ran_m:
        total = int(ran_m.group(1))
        if total == 0:
            return None
        fail_m = _RE_UNITTEST_FAIL.search(log_text)
        err_m = _RE_UNITTEST_ERR.search(log_text)
        failures = int(fail_m.group(1)) if fail_m else 0
        errors = int(err_m.group(1)) if err_m else 0
        # With no failures/errors this is exactly 1.0, so a separate "OK"
        # check is unnecessary. Clamp in case the counts exceed the total.
        return max(0.0, (total - failures - errors) / total)

    # 3) pytest omits "passed" entirely when nothing passed
    #    (e.g. "3 failed in 0.10s"); report that as a 0% pass rate instead of
    #    "unparseable".
    if not passed_m and pytest_failed + pytest_errors > 0:
        return 0.0

    return None
# ──────────────────────────────────────────────────────────────
# Reward function
# ──────────────────────────────────────────────────────────────
def compute_reward(
    *,
    tool_name: str,
    result_output: str,
    result_patch: str | None,
    test_exit_code: int | None,
    test_log: str | None,
    prev_pass_rate: float | None,
    curr_pass_rate: float | None,
    step_count: int,
    max_steps: int,
    is_limit_hit: bool,
) -> float:
    """Compute the reward for a single environment step.

    Design:
      - Intermediate steps are always positive (encourage exploration).
      - Success (tests pass) is a large positive with an efficiency bonus.
      - Hitting a limit without solving is a large negative.

    Terminal rewards:
      -3.0           ``is_limit_hit`` is true (checked first)
      +3.0 to +5.0   ``execute_tests`` exited with code 0; computed as
                     ``5.0 - 2.0 * step_count / max_steps``, floored at 3.0

    Intermediate rewards:
      0.05 + 0.3 * delta   failing test run whose pass rate improved by delta
      0.02                 failing test run, pass rate flat or regressed
      0.03                 failing test run, pass rate not parseable
      0.10                 edit_file / replace_all_in_file that produced a patch
      0.01                 edit_file / replace_all_in_file without a patch
      0.04                 search_file / search_dir whose output mentions
                           "match" but not "no match" (case-insensitive)
      0.01                 search_file / search_dir otherwise
      0.03                 view_file / view_last_log / search_last_log
      0.02                 list_dir / revert_last
      0.01                 any other tool

    Args:
        tool_name: Name of the tool that was invoked.
        result_output: Text output of the tool; ``None`` is treated as empty.
        result_patch: Patch produced by an edit tool, if any.
        test_exit_code: Exit code of the test run (``execute_tests`` only).
        test_log: Full test log. Currently unused; kept for interface
            compatibility.
        prev_pass_rate: Pass rate before this step, if known.
        curr_pass_rate: Pass rate after this step, if known.
        step_count: 1-based index of the current step.
        max_steps: Step budget of the episode (values below 1 are treated as 1).
        is_limit_hit: True when the step/test budget was exhausted unsolved.

    Returns:
        The scalar reward for this step.
    """
    # ── Terminal: hit limits without solving ──
    if is_limit_hit:
        return -3.0

    if tool_name == "execute_tests":
        # ── Terminal: tests pass ──
        if test_exit_code == 0:
            # 5.0 for an instant solve, decaying linearly to 3.0 at max_steps.
            efficiency = 5.0 - 2.0 * (step_count / max(max_steps, 1))
            return max(3.0, efficiency)

        # ── Intermediate: tests ran but did not pass ──
        if curr_pass_rate is None or prev_pass_rate is None:
            return 0.03  # ran tests, can't compare rates — still informative
        delta = curr_pass_rate - prev_pass_rate
        if delta > 0:
            return 0.05 + delta * 0.3  # improved pass rate — good signal
        return 0.02  # no improvement or regression — still positive but small

    # ── Intermediate: edits ──
    if tool_name in ("edit_file", "replace_all_in_file"):
        # A non-empty patch means a real change was applied.
        return 0.10 if result_patch else 0.01

    # ── Intermediate: searches ──
    if tool_name in ("search_file", "search_dir"):
        lowered = (result_output or "").lower()
        found = "match" in lowered and "no match" not in lowered
        return 0.04 if found else 0.01

    # ── Intermediate: information gathering / bookkeeping tools ──
    flat_rewards = {
        "view_file": 0.03,
        "view_last_log": 0.03,
        "search_last_log": 0.03,
        "list_dir": 0.02,
        "revert_last": 0.02,
    }
    # ── Fallback: any other valid action earns the minimum ──
    return flat_rewards.get(tool_name, 0.01)
class CodeMigrationEnvironment(Environment):
    """OpenEnv environment for Python code-migration tasks.

    ``reset`` selects a task, prepares a workspace through ``RepoManager`` and
    runs the tests once through ``DockerSandbox``. ``step`` dispatches one tool
    call through ``ToolExecutor`` and scores it with ``compute_reward``. An
    episode ends when ``execute_tests`` exits with code 0, when the step budget
    is used up, or when the test-execution budget is exceeded.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = False

    def __init__(
        self,
        dataset_path: str | None = None,
        max_steps: int = 200,
        max_test_executions: int = 10,
        container_timeout: int = 600,
        container_memory_limit: str = "16g",
        difficulty_filter: str | None = None,
    ) -> None:
        """Create the environment and its collaborators.

        Args:
            dataset_path: Passed to ``DatasetLoader``.
            max_steps: Maximum number of ``step`` calls per episode.
            max_test_executions: Maximum number of test runs per episode; the
                initial run performed by ``reset`` counts towards it.
            container_timeout: Passed to ``DockerSandbox`` as ``timeout``.
            container_memory_limit: Passed to ``DockerSandbox`` as
                ``memory_limit``.
            difficulty_filter: When truthy, the loader's task list is replaced
                by ``filter_by_difficulty(difficulty_filter)``.
        """
        self._loader = DatasetLoader(dataset_path)
        if difficulty_filter:
            tasks = self._loader.filter_by_difficulty(difficulty_filter)
            self._loader._tasks = tasks
        self._max_steps = max_steps
        self._max_test_executions = max_test_executions
        self._repo_manager = RepoManager()
        self._sandbox = DockerSandbox(
            timeout=container_timeout, memory_limit=container_memory_limit
        )
        self._tool_executor = ToolExecutor()

        # Episode state
        self._current_task: Task | None = None
        self._workspace_dir: str | None = None
        self._image_name: str | None = None
        self._step_count: int = 0
        self._patch_history: list[tuple[str, str]] = []  # (file_path, patch)
        self._last_log_path: str | None = None
        self._last_test_exit_code: int | None = None
        self._last_pass_rate: float | None = None  # continuous test pass rate
        self._num_test_executions: int = 0
        self._done: bool = False
        self._task_index: int = 0  # round-robin cursor used by reset()
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Make sure the workspace and temp log are removed at interpreter exit.
        atexit.register(self._atexit_cleanup)

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------
    def reset(
        self,
        *,
        task_index: int | None = None,
        repo_name: str | None = None,
    ) -> CodeMigrationObservation:
        """Prepare a fresh workspace for the next (or specified) task.

        Task selection priority: ``repo_name``, then ``task_index``, then the
        internal round-robin cursor (which is advanced only in that last case).

        Args:
            task_index: Index into the dataset loader.
            repo_name: Repository name to look up in the dataset loader.

        Returns:
            On success, an observation whose ``tool_output`` is the formatted
            system prompt followed by the initial test log (``reward=0.0``,
            ``done=False``). If task selection or workspace setup fails, an
            observation with ``done=True`` describing the failure; the
            environment is then marked done so ``step`` refuses to run until
            the next successful ``reset``.
        """
        self._cleanup_episode()

        # Select task
        try:
            if repo_name is not None:
                task = self._loader.get_by_repo_name(repo_name)
                if task is None:
                    return self._abort_reset(
                        f"No task found for repo_name: {repo_name}"
                    )
            elif task_index is not None:
                task = self._loader[task_index]
            else:
                task = self._loader[self._task_index]
                self._task_index = (self._task_index + 1) % len(self._loader)
            self._current_task = task
        except Exception as e:
            return self._abort_reset(f"Failed to select task: {e}")

        # Setup workspace
        try:
            self._workspace_dir = self._repo_manager.setup_workspace(task)
        except Exception as e:
            return self._abort_reset(f"Failed to setup workspace: {e}")

        # Derive image name
        escaped_name = task.repo_name.replace("/", "__").lower()
        self._image_name = escaped_name + "_new"

        # Create temp log file (kept on disk; removed by _cleanup_episode)
        tmp = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".log")
        self._last_log_path = tmp.name
        tmp.close()

        # Reset episode state
        self._step_count = 0
        self._patch_history = []
        self._last_test_exit_code = None
        self._last_pass_rate = None
        self._num_test_executions = 0
        self._done = False
        self._state = State(episode_id=str(uuid4()), step_count=0)

        # Initial test run (counts towards the test-execution budget)
        try:
            test_result = self._sandbox.run_tests(self._image_name, self._workspace_dir)
            self._num_test_executions += 1
            if test_result.full_log is not None:
                with open(self._last_log_path, "w", newline="") as f:
                    f.write(test_result.full_log)
            self._last_test_exit_code = test_result.exit_code
            self._last_pass_rate = _parse_test_pass_rate(test_result.full_log or "")
            initial_test_output = (
                "Test execution completed. Here is the test log.\n\n"
                f"<test_log>\n{test_result.truncated_log}\n</test_log>"
            )
        except Exception as e:
            # A failing initial run does not abort the episode; the agent can
            # still explore and re-run the tests.
            initial_test_output = f"Initial test execution failed: {e}"

        # Build system prompt
        system_prompt = MIGRATION_SYSTEM_PROMPT.strip().format(
            python_version=task.migration_target_version,
            dependency_versions=task.dependency_versions,
        )
        combined_output = system_prompt + "\n\n" + initial_test_output
        return CodeMigrationObservation(
            tool_output=combined_output,
            reward=0.0,
            done=False,
            metadata=self._build_metadata("reset"),
        )

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------
    def step(self, action: CodeMigrationAction) -> CodeMigrationObservation:  # type: ignore[override]
        """Execute one tool call and return the observation.

        Args:
            action: The tool invocation (``tool_name`` and ``tool_args``).

        Returns:
            An observation carrying the tool output, the step reward (rounded
            to 4 decimals on the normal path) and the ``done`` flag. When the
            episode is already done, or no episode was successfully started,
            a ``done=True`` observation with ``reward=0.0`` is returned and no
            tool is run.
        """
        tool_name = action.tool_name

        if self._done:
            return CodeMigrationObservation(
                tool_output="Episode is already done. Call reset() to start a new episode.",
                reward=0.0,
                done=True,
                metadata=self._build_metadata(tool_name),
            )

        # Guard against step() before any successful reset(); without a task
        # and a workspace there is nothing to dispatch the tool against.
        if self._current_task is None or self._workspace_dir is None:
            return CodeMigrationObservation(
                tool_output="No active episode. Call reset() to start a new episode.",
                reward=0.0,
                done=True,
                metadata=self._build_metadata(tool_name),
            )

        self._step_count += 1
        self._state.step_count = self._step_count

        # Check step limit
        # NOTE(review): the post-step ">=" check further down already marks the
        # episode done on the last allowed step, so this branch (and its
        # limit-hit penalty) is only reachable when max_steps < 1. Confirm
        # whether the penalty was meant to apply on the final step.
        if self._step_count > self._max_steps:
            return self._limit_hit_observation(
                tool_name, f"Step limit ({self._max_steps}) reached."
            )

        is_test_run = tool_name == "execute_tests"

        # Check test execution limit
        if is_test_run and self._num_test_executions >= self._max_test_executions:
            return self._limit_hit_observation(
                tool_name,
                f"Test execution limit ({self._max_test_executions}) reached.",
            )

        # Determine last_patch for revert
        last_patch = self._patch_history[-1] if self._patch_history else None

        # Dispatch tool
        test_files = [
            tf.strip()
            for tf in (self._current_task.test_files or "").split(",")
            if tf.strip()
        ]
        result = self._tool_executor.execute(
            tool_name=tool_name,
            tool_args=action.tool_args,
            host_repo_dir=self._workspace_dir,
            repo_name=self._current_task.repo_name,
            test_files=test_files,
            image_name=self._image_name,
            last_log_path=self._last_log_path,
            last_patch=last_patch,
            sandbox=self._sandbox,
        )

        # Track patches (only the five most recent are kept)
        if tool_name in ("edit_file", "replace_all_in_file") and result.patch:
            file_path = action.tool_args.get("file_path", "")
            self._patch_history.append((file_path, result.patch))
            self._patch_history = self._patch_history[-5:]

        # Handle revert: drop the reverted patch once the tool reports success
        if (
            tool_name == "revert_last"
            and last_patch is not None
            and "succeeded" in result.output.lower()
            and self._patch_history
        ):
            self._patch_history.pop()

        # Handle execute_tests — update state and parse pass rate.
        # Snapshot the previous rate *before* anything overwrites it so both
        # the reward and the metadata see the true "before" value.
        prev_pass_rate = self._last_pass_rate
        curr_pass_rate = prev_pass_rate
        if is_test_run:
            self._num_test_executions += 1
            if result.full_log is not None and self._last_log_path:
                with open(self._last_log_path, "w", newline="") as f:
                    f.write(result.full_log)
            self._last_test_exit_code = result.exit_code
            curr_pass_rate = _parse_test_pass_rate(result.full_log or "")

        # Compute continuous reward
        reward = compute_reward(
            tool_name=tool_name,
            result_output=result.output,
            result_patch=result.patch,
            test_exit_code=result.exit_code if is_test_run else None,
            test_log=result.full_log if is_test_run else None,
            prev_pass_rate=prev_pass_rate,
            curr_pass_rate=curr_pass_rate,
            step_count=self._step_count,
            max_steps=self._max_steps,
            is_limit_hit=False,
        )

        # Update pass rate after reward computation; an unparseable log keeps
        # the last known rate.
        if is_test_run and curr_pass_rate is not None:
            self._last_pass_rate = curr_pass_rate

        # Check if tests passed
        if is_test_run and result.exit_code == 0:
            self._done = True

        # Check step limit
        if self._step_count >= self._max_steps:
            self._done = True

        metadata = self._build_metadata(tool_name)
        metadata["pass_rate"] = curr_pass_rate
        metadata["prev_pass_rate"] = prev_pass_rate
        return CodeMigrationObservation(
            tool_output=result.output,
            reward=round(reward, 4),
            done=self._done,
            metadata=metadata,
        )

    # ------------------------------------------------------------------
    # state
    # ------------------------------------------------------------------
    @property
    def state(self) -> State:
        """Return the episode ``State`` with its metadata refreshed.

        The metadata is empty when no task has been selected yet.
        """
        meta: dict = {}
        if self._current_task:
            meta.update({
                "repo_name": self._current_task.repo_name,
                "difficulty": self._current_task.difficulty,
                "test_type": self._current_task.test_type,
                "test_count": self._current_task.test_count,
                "num_test_executions": self._num_test_executions,
                "last_test_exit_code": self._last_test_exit_code,
                "last_pass_rate": self._last_pass_rate,
                "migration_target_version": self._current_task.migration_target_version,
                "reproduction_target_version": self._current_task.reproduction_target_version,
            })
        self._state.metadata = meta
        return self._state

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    def _abort_reset(self, message: str) -> CodeMigrationObservation:
        """Mark the environment done and return a failed-reset observation."""
        # Without this, a stale ``_done = False`` from an earlier episode would
        # let step() run against a workspace that has already been cleaned up.
        self._done = True
        return CodeMigrationObservation(tool_output=message, done=True)

    def _limit_hit_observation(
        self, tool_name: str, message: str
    ) -> CodeMigrationObservation:
        """End the episode with the limit-hit reward and the given message."""
        self._done = True
        reward = compute_reward(
            tool_name=tool_name,
            result_output="",
            result_patch=None,
            test_exit_code=None,
            test_log=None,
            prev_pass_rate=self._last_pass_rate,
            curr_pass_rate=self._last_pass_rate,
            step_count=self._step_count,
            max_steps=self._max_steps,
            is_limit_hit=True,
        )
        return CodeMigrationObservation(
            tool_output=message,
            reward=reward,
            done=True,
            metadata=self._build_metadata(tool_name),
        )

    def _build_metadata(self, tool_name: str) -> dict:
        """Build the per-observation metadata dict for ``tool_name``."""
        meta: dict = {
            "step_count": self._step_count,
            "tool_name": tool_name,
            "last_test_exit_code": self._last_test_exit_code,
            "num_test_executions": self._num_test_executions,
        }
        if self._current_task:
            meta["repo_name"] = self._current_task.repo_name
            meta["difficulty"] = self._current_task.difficulty
        return meta

    def _cleanup_episode(self) -> None:
        """Remove the current workspace and temp log file, if any."""
        if self._workspace_dir:
            self._repo_manager.cleanup(self._workspace_dir)
            self._workspace_dir = None
        if self._last_log_path:
            try:
                os.remove(self._last_log_path)
            except OSError:
                # Best-effort: the file may already be gone or be locked.
                pass
            self._last_log_path = None

    def _atexit_cleanup(self) -> None:
        """``atexit`` hook: clean up without ever raising during shutdown."""
        try:
            self._cleanup_episode()
        except Exception:
            # Interpreter is exiting; nothing useful can be done with errors.
            pass