| from __future__ import annotations |
|
|
| import ast |
| from typing import Any |
| from uuid import uuid4 |
|
|
| from openenv.core.env_server.interfaces import Environment |
|
|
| from env.executor import run_code |
| from env.test_cases import load_problem, split_test_cases |
| from models import AdaptAction, AdaptObservation, AdaptState |
|
|
|
|
# Top-level module names that a submission may not import; compared against
# the root package of every ``import`` / ``from ... import`` statement.
FORBIDDEN_IMPORTS = {
    "os",
    "pathlib",
    "shutil",
    "socket",
    "subprocess",
}
|
|
|
|
class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
    """Single-step environment that grades a submitted Python program.

    ``reset`` loads a problem and splits its test cases into visible and
    hidden sets. ``step`` checks the submission's syntax and imports, runs it
    against every test case through ``run_code``, scores the results and
    returns an observation. Every path through ``step`` sets ``done=True``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        """Start with a fresh episode id and no problem loaded."""
        super().__init__()
        self._state = AdaptState(episode_id=str(uuid4()), step_count=0)
        self.problem: dict[str, Any] = {}
        self.test_cases: list[dict[str, str]] = []
        self.visible_tests: list[dict[str, str]] = []
        self.hidden_tests: list[dict[str, str]] = []
        self.last_results: list[dict[str, Any]] = []

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        problem_id: str | None = None,
        difficulty: str | None = None,
        **_: Any,
    ) -> AdaptObservation:
        """Load a problem and begin a new episode.

        Args:
            seed: Accepted but unused; it is discarded immediately.
            episode_id: Id for the new episode; a random UUID when omitted.
            problem_id: Forwarded to ``load_problem``.
            difficulty: Forwarded to ``load_problem``.
            **_: Extra keyword arguments are ignored.

        Returns:
            The initial observation (reward 0.0, ``done=False``).
        """
        del seed
        self.problem = load_problem(problem_id=problem_id, difficulty=difficulty)
        # Copy each case so later changes here cannot alter the loaded problem.
        self.test_cases = [dict(test_case) for test_case in self.problem["test_cases"]]
        self.visible_tests, self.hidden_tests = split_test_cases(self.test_cases)
        self.last_results = []
        self._state = AdaptState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            problem_id=self.problem["problem_id"],
            difficulty=self.problem["difficulty"],
        )
        return self._build_observation(
            reward=0.0,
            done=False,
            feedback="Submit Python code that reads stdin and prints the required answer.",
        )

    def step(
        self,
        action: AdaptAction,
        timeout_s: float | None = None,
        **_: Any,
    ) -> AdaptObservation:
        """Grade ``action.code`` and return the resulting observation.

        The submission is rejected with reward 0.0 if it fails the syntax
        check or the import check. Otherwise it is run on all test cases and
        scored; an optional verifier may raise (never lower) the reward and
        replace the feedback text.

        Args:
            action: Carries the submitted source in ``action.code``.
            timeout_s: Accepted but unused; it is discarded immediately.
            **_: Extra keyword arguments are ignored.

        Returns:
            An observation with ``done=True``.
        """
        del timeout_s
        if not self.problem:
            # No problem loaded yet: fall back to a default reset.
            self.reset()

        self._state.step_count += 1
        syntax_ok, syntax_error = self._check_syntax(action.code)
        if not syntax_ok:
            observation = self._build_observation(
                reward=0.0,
                done=True,
                feedback=f"Syntax error: {syntax_error}",
                syntax_valid=False,
                execution_status="syntax_error",
            )
            self._record_metrics(observation)
            return observation

        safety_ok, safety_error = self._check_safety(action.code)
        if not safety_ok:
            observation = self._build_observation(
                reward=0.0,
                done=True,
                feedback=safety_error,
                syntax_valid=True,
                execution_status="safety_violation",
            )
            self._record_metrics(observation)
            return observation

        run_results = self._run_all_tests(action.code)
        self.last_results = run_results
        metrics = self._score_results(run_results)
        verifier_reward, verifier_metadata = self._try_verify(action.code)
        if verifier_reward is not None:
            # The verifier can only improve on the locally computed reward.
            metrics["reward"] = max(metrics["reward"], verifier_reward)
            if verifier_metadata.get("feedback"):
                metrics["feedback"] = str(verifier_metadata["feedback"])

        observation = self._build_observation(
            reward=metrics["reward"],
            done=True,
            feedback=metrics["feedback"],
            pass_rate=metrics["pass_rate"],
            visible_pass_rate=metrics["visible_pass_rate"],
            hidden_pass_rate=metrics["hidden_pass_rate"],
            syntax_valid=True,
            execution_status=metrics["execution_status"],
            timeout_count=metrics["timeout_count"],
            runtime_error_count=metrics["runtime_error_count"],
            format_compliance=metrics["format_compliance"],
            reward_components=metrics["reward_components"],
        )
        self._record_metrics(observation)
        return observation

    @property
    def state(self) -> AdaptState:
        """Return the current episode state object."""
        return self._state

    def _build_observation(
        self,
        reward: float,
        done: bool,
        feedback: str,
        pass_rate: float = 0.0,
        visible_pass_rate: float = 0.0,
        hidden_pass_rate: float = 0.0,
        syntax_valid: bool = True,
        execution_status: str = "not_run",
        timeout_count: int = 0,
        runtime_error_count: int = 0,
        format_compliance: float = 0.0,
        reward_components: dict[str, float] | None = None,
    ) -> AdaptObservation:
        """Assemble an observation from the loaded problem plus the given metrics.

        The reward is clamped to [0.0, 1.0] and rounded to four decimals.
        Only the visible tests are included; hidden tests are not.
        Requires a loaded problem (``self.problem`` must be populated).
        """
        clamped_reward = max(0.0, min(1.0, reward))
        return AdaptObservation(
            problem_id=self.problem["problem_id"],
            difficulty=self.problem["difficulty"],
            problem=self.problem["problem"],
            input_format=self.problem["input_format"],
            constraints=self.problem["constraints"],
            examples=self.problem["examples"],
            visible_tests=self.visible_tests,
            feedback=feedback,
            pass_rate=pass_rate,
            visible_pass_rate=visible_pass_rate,
            hidden_pass_rate=hidden_pass_rate,
            syntax_valid=syntax_valid,
            execution_status=execution_status,
            timeout_count=timeout_count,
            runtime_error_count=runtime_error_count,
            format_compliance=format_compliance,
            reward_components=reward_components or {},
            reward=round(clamped_reward, 4),
            done=done,
        )

    def _run_all_tests(self, code: str) -> list[dict[str, Any]]:
        """Run ``code`` on every test case and return one result dict per case.

        The value returned by ``run_code`` is read as a mapping with
        ``stdout``, ``stderr``, ``exit_code`` and, optionally, ``timed_out``.
        For hidden tests the ``input``, ``expected`` and ``actual`` fields are
        set to ``None``; ``stderr`` is kept for every test.
        """
        results = []
        # NOTE(review): labelling by index assumes split_test_cases puts the
        # visible cases first, in their original order — confirm against it.
        visible_count = len(self.visible_tests)
        for index, test_case in enumerate(self.test_cases):
            execution = run_code(code, test_case["input"])
            actual = str(execution["stdout"]).strip()
            expected = test_case["output"].strip()
            is_visible = index < visible_count
            exited_cleanly = execution["exit_code"] == 0
            results.append(
                {
                    "index": index,
                    "split": "visible" if is_visible else "hidden",
                    "input": test_case["input"] if is_visible else None,
                    "expected": expected if is_visible else None,
                    "actual": actual if is_visible else None,
                    "stderr": str(execution["stderr"]).strip(),
                    "exit_code": int(execution["exit_code"]),
                    "timed_out": bool(execution.get("timed_out", False)),
                    # Passing needs a clean exit and an exact match after
                    # stripping surrounding whitespace.
                    "passed": exited_cleanly and actual == expected,
                    # Format is considered fine when anything was printed.
                    "format_ok": exited_cleanly and actual != "",
                }
            )
        return results

    def _score_results(self, run_results: list[dict[str, Any]]) -> dict[str, Any]:
        """Turn per-test results into a reward, status, feedback and rates.

        Reward components, summed and clamped to [0.0, 1.0]:
        0.8 * pass rate, a flat 0.05 for syntax, 0.05 when no test errored or
        timed out, 0.1 * format compliance, -0.2 * timeout rate and
        -0.1 * runtime-error rate. All returned floats are rounded to four
        decimals.
        """
        total = len(run_results)
        visible = [result for result in run_results if result["split"] == "visible"]
        hidden = [result for result in run_results if result["split"] == "hidden"]
        pass_rate = self._pass_rate(run_results)
        visible_pass_rate = self._pass_rate(visible)
        hidden_pass_rate = self._pass_rate(hidden)
        timeout_count = sum(1 for result in run_results if result["timed_out"])
        # A timed-out run is counted as a timeout only, not as a runtime error.
        runtime_error_count = sum(
            1
            for result in run_results
            if result["exit_code"] != 0 and not result["timed_out"]
        )
        format_compliance = (
            sum(1 for result in run_results if result["format_ok"]) / total
            if total
            else 0.0
        )
        timeout_rate = timeout_count / total if total else 0.0
        runtime_error_rate = runtime_error_count / total if total else 0.0
        reward_components = {
            "correctness": 0.8 * pass_rate,
            "syntax": 0.05,
            "execution": 0.05 if runtime_error_count == 0 and timeout_count == 0 else 0.0,
            "format": 0.1 * format_compliance,
            "timeout_penalty": -0.2 * timeout_rate,
            "runtime_penalty": -0.1 * runtime_error_rate,
        }
        reward = max(0.0, min(1.0, sum(reward_components.values())))

        # Timeouts take precedence over runtime errors in the reported status.
        if timeout_count:
            status = "timeout"
        elif runtime_error_count:
            status = "runtime_error"
        else:
            status = "completed"

        return {
            "reward": round(reward, 4),
            "feedback": self._build_feedback(run_results, pass_rate),
            "pass_rate": round(pass_rate, 4),
            "visible_pass_rate": round(visible_pass_rate, 4),
            "hidden_pass_rate": round(hidden_pass_rate, 4),
            "timeout_count": timeout_count,
            "runtime_error_count": runtime_error_count,
            "format_compliance": round(format_compliance, 4),
            "execution_status": status,
            "reward_components": {
                key: round(value, 4) for key, value in reward_components.items()
            },
        }

    def _build_feedback(self, run_results: list[dict[str, Any]], pass_rate: float) -> str:
        """Describe the first failing test, or report overall success.

        Results are scanned in order; the first timeout, runtime error or
        wrong answer determines the message. Input, expected and actual
        values are shown only for visible tests.
        """
        if not run_results:
            # With nothing executed, claiming "All tests passed" would be
            # misleading, so say so explicitly.
            return "No test cases were run."

        for result in run_results:
            if result["timed_out"]:
                label = self._safe_test_label(result)
                return f"Timed out on {label}."

            if result["exit_code"] != 0:
                label = self._safe_test_label(result)
                error = result["stderr"] or "runtime error"
                # NOTE(review): stderr is echoed for hidden tests as well; a
                # submission that writes its input to stderr and exits
                # non-zero would surface hidden data here — confirm this is
                # acceptable.
                return f"Runtime error on {label}: {error}"

            if not result["passed"] and result["split"] == "visible":
                return (
                    f"Failed on visible input {str(result['input']).strip()}: "
                    f"expected {result['expected']}, got {result['actual']}"
                )

            if not result["passed"]:
                return f"Failed on hidden test {result['index'] + 1}."

        return f"All tests passed. Pass rate: {pass_rate:.2f}"

    def _record_metrics(self, observation: AdaptObservation) -> None:
        """Copy the headline numbers of ``observation`` onto the episode state."""
        self._state.last_reward = float(observation.reward or 0.0)
        self._state.last_pass_rate = observation.pass_rate
        self._state.last_feedback = observation.feedback
        self._state.recent_metrics = {
            "visible_pass_rate": observation.visible_pass_rate,
            "hidden_pass_rate": observation.hidden_pass_rate,
            "execution_status": observation.execution_status,
            "timeout_count": observation.timeout_count,
            "runtime_error_count": observation.runtime_error_count,
            "format_compliance": observation.format_compliance,
            # Copied so the state does not share the observation's dict.
            "reward_components": dict(observation.reward_components),
        }

    def _try_verify(self, code: str) -> tuple[float | None, dict[str, Any]]:
        """Score ``code`` with the optional ``verifier`` package, if available.

        Returns:
            ``(reward, metadata)`` on success. ``(None, {})`` when the
            verifier cannot be imported. ``(None, {"feedback": ...})`` when
            the verifier raises or returns a reward that is not convertible
            to ``float``.
        """
        try:
            from verifier.verifier import verify
        except ImportError:
            return None, {}

        try:
            reward, metadata = verify(code, self.test_cases)
            # Converted inside the ``try`` so a malformed reward degrades to
            # "verifier unavailable" instead of crashing ``step``.
            verified_reward = float(reward)
        except Exception as exc:
            # Deliberate best-effort: the verifier is optional, so any failure
            # falls back to the locally computed score.
            return None, {"feedback": f"Verifier unavailable: {exc}"}

        return verified_reward, metadata or {}

    def _check_syntax(self, code: str) -> tuple[bool, str]:
        """Return ``(True, "")`` if ``code`` parses, else ``(False, message)``."""
        try:
            ast.parse(code)
        except SyntaxError as exc:
            return False, str(exc)
        except (ValueError, RecursionError, MemoryError) as exc:
            # ast.parse can also reject source with ValueError (NUL bytes on
            # some Python versions) or exhaust the parser on pathologically
            # nested input; treat these as unparseable rather than crashing.
            return False, str(exc) or type(exc).__name__
        return True, ""

    def _check_safety(self, code: str) -> tuple[bool, str]:
        """Reject ``code`` that imports a module listed in ``FORBIDDEN_IMPORTS``.

        Only the root package of ``import`` and ``from ... import`` statements
        is examined. ``code`` must already be syntactically valid; it is
        parsed again here.

        Returns:
            ``(True, "")`` when no forbidden import is found, otherwise
            ``(False, "Forbidden import: <name>")``.
        """
        # NOTE(review): dynamic imports (``__import__``, ``importlib``) are not
        # detected by this static check; it should not be the only safeguard.
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    root_name = alias.name.split(".", 1)[0]
                    if root_name in FORBIDDEN_IMPORTS:
                        return False, f"Forbidden import: {root_name}"

            if isinstance(node, ast.ImportFrom):
                # ``node.module`` is None for ``from . import x``.
                root_name = (node.module or "").split(".", 1)[0]
                if root_name in FORBIDDEN_IMPORTS:
                    return False, f"Forbidden import: {root_name}"

        return True, ""

    def _pass_rate(self, results: list[dict[str, Any]]) -> float:
        """Return the fraction of ``results`` that passed (0.0 when empty)."""
        if not results:
            return 0.0
        return sum(1 for result in results if result["passed"]) / len(results)

    def _safe_test_label(self, result: dict[str, Any]) -> str:
        """Name a test for feedback without revealing hidden inputs.

        Visible tests are labelled by their input; hidden tests by their
        1-based position.
        """
        if result["split"] == "visible":
            return f"visible input {str(result['input']).strip()}"
        return f"hidden test {result['index'] + 1}"
|
|