meta-rl-dsa-solver / env /adapt_env.py
s-shah4
Add V1 env
4433dc8
from __future__ import annotations
import ast
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from env.executor import run_code
from env.test_cases import load_problem, split_test_cases
from models import AdaptAction, AdaptObservation, AdaptState
FORBIDDEN_IMPORTS = {"os", "pathlib", "shutil", "socket", "subprocess"}
class AdaptEnvironment(Environment[AdaptAction, AdaptObservation, AdaptState]):
SUPPORTS_CONCURRENT_SESSIONS = True
def __init__(self) -> None:
super().__init__()
self._state = AdaptState(episode_id=str(uuid4()), step_count=0)
self.problem: dict[str, Any] = {}
self.test_cases: list[dict[str, str]] = []
self.visible_tests: list[dict[str, str]] = []
self.hidden_tests: list[dict[str, str]] = []
self.last_results: list[dict[str, Any]] = []
def reset(
self,
seed: int | None = None,
episode_id: str | None = None,
problem_id: str | None = None,
difficulty: str | None = None,
**_: Any,
) -> AdaptObservation:
del seed
self.problem = load_problem(problem_id=problem_id, difficulty=difficulty)
self.test_cases = [dict(test_case) for test_case in self.problem["test_cases"]]
self.visible_tests, self.hidden_tests = split_test_cases(self.test_cases)
self.last_results = []
self._state = AdaptState(
episode_id=episode_id or str(uuid4()),
step_count=0,
problem_id=self.problem["problem_id"],
difficulty=self.problem["difficulty"],
)
return self._build_observation(
reward=0.0,
done=False,
feedback="Submit Python code that reads stdin and prints the required answer.",
)
def step(
self,
action: AdaptAction,
timeout_s: float | None = None,
**_: Any,
) -> AdaptObservation:
del timeout_s
if not self.problem:
self.reset()
self._state.step_count += 1
syntax_ok, syntax_error = self._check_syntax(action.code)
if not syntax_ok:
observation = self._build_observation(
reward=0.0,
done=True,
feedback=f"Syntax error: {syntax_error}",
syntax_valid=False,
execution_status="syntax_error",
)
self._record_metrics(observation)
return observation
safety_ok, safety_error = self._check_safety(action.code)
if not safety_ok:
observation = self._build_observation(
reward=0.0,
done=True,
feedback=safety_error,
syntax_valid=True,
execution_status="safety_violation",
)
self._record_metrics(observation)
return observation
run_results = self._run_all_tests(action.code)
self.last_results = run_results
metrics = self._score_results(run_results)
verifier_reward, verifier_metadata = self._try_verify(action.code)
if verifier_reward is not None:
metrics["reward"] = max(metrics["reward"], verifier_reward)
if verifier_metadata.get("feedback"):
metrics["feedback"] = str(verifier_metadata["feedback"])
observation = self._build_observation(
reward=metrics["reward"],
done=True,
feedback=metrics["feedback"],
pass_rate=metrics["pass_rate"],
visible_pass_rate=metrics["visible_pass_rate"],
hidden_pass_rate=metrics["hidden_pass_rate"],
syntax_valid=True,
execution_status=metrics["execution_status"],
timeout_count=metrics["timeout_count"],
runtime_error_count=metrics["runtime_error_count"],
format_compliance=metrics["format_compliance"],
reward_components=metrics["reward_components"],
)
self._record_metrics(observation)
return observation
@property
def state(self) -> AdaptState:
return self._state
def _build_observation(
self,
reward: float,
done: bool,
feedback: str,
pass_rate: float = 0.0,
visible_pass_rate: float = 0.0,
hidden_pass_rate: float = 0.0,
syntax_valid: bool = True,
execution_status: str = "not_run",
timeout_count: int = 0,
runtime_error_count: int = 0,
format_compliance: float = 0.0,
reward_components: dict[str, float] | None = None,
) -> AdaptObservation:
return AdaptObservation(
problem_id=self.problem["problem_id"],
difficulty=self.problem["difficulty"],
problem=self.problem["problem"],
input_format=self.problem["input_format"],
constraints=self.problem["constraints"],
examples=self.problem["examples"],
visible_tests=self.visible_tests,
feedback=feedback,
pass_rate=pass_rate,
visible_pass_rate=visible_pass_rate,
hidden_pass_rate=hidden_pass_rate,
syntax_valid=syntax_valid,
execution_status=execution_status,
timeout_count=timeout_count,
runtime_error_count=runtime_error_count,
format_compliance=format_compliance,
reward_components=reward_components or {},
reward=round(max(0.0, min(1.0, reward)), 4),
done=done,
)
def _run_all_tests(self, code: str) -> list[dict[str, Any]]:
results = []
visible_count = len(self.visible_tests)
for index, test_case in enumerate(self.test_cases):
execution = run_code(code, test_case["input"])
actual = str(execution["stdout"]).strip()
expected = test_case["output"].strip()
results.append(
{
"index": index,
"split": "visible" if index < visible_count else "hidden",
"input": test_case["input"] if index < visible_count else None,
"expected": expected if index < visible_count else None,
"actual": actual if index < visible_count else None,
"stderr": str(execution["stderr"]).strip(),
"exit_code": int(execution["exit_code"]),
"timed_out": bool(execution.get("timed_out", False)),
"passed": execution["exit_code"] == 0 and actual == expected,
"format_ok": execution["exit_code"] == 0 and actual != "",
}
)
return results
def _score_results(self, run_results: list[dict[str, Any]]) -> dict[str, Any]:
total = len(run_results)
visible = [result for result in run_results if result["split"] == "visible"]
hidden = [result for result in run_results if result["split"] == "hidden"]
pass_rate = self._pass_rate(run_results)
visible_pass_rate = self._pass_rate(visible)
hidden_pass_rate = self._pass_rate(hidden)
timeout_count = sum(1 for result in run_results if result["timed_out"])
runtime_error_count = sum(
1
for result in run_results
if result["exit_code"] != 0 and not result["timed_out"]
)
format_compliance = (
sum(1 for result in run_results if result["format_ok"]) / total
if total
else 0.0
)
timeout_rate = timeout_count / total if total else 0.0
runtime_error_rate = runtime_error_count / total if total else 0.0
reward_components = {
"correctness": 0.8 * pass_rate,
"syntax": 0.05,
"execution": 0.05 if runtime_error_count == 0 and timeout_count == 0 else 0.0,
"format": 0.1 * format_compliance,
"timeout_penalty": -0.2 * timeout_rate,
"runtime_penalty": -0.1 * runtime_error_rate,
}
reward = max(0.0, min(1.0, sum(reward_components.values())))
if timeout_count:
status = "timeout"
elif runtime_error_count:
status = "runtime_error"
else:
status = "completed"
return {
"reward": round(reward, 4),
"feedback": self._build_feedback(run_results, pass_rate),
"pass_rate": round(pass_rate, 4),
"visible_pass_rate": round(visible_pass_rate, 4),
"hidden_pass_rate": round(hidden_pass_rate, 4),
"timeout_count": timeout_count,
"runtime_error_count": runtime_error_count,
"format_compliance": round(format_compliance, 4),
"execution_status": status,
"reward_components": {
key: round(value, 4) for key, value in reward_components.items()
},
}
def _build_feedback(self, run_results: list[dict[str, Any]], pass_rate: float) -> str:
for result in run_results:
if result["timed_out"]:
label = self._safe_test_label(result)
return f"Timed out on {label}."
if result["exit_code"] != 0:
label = self._safe_test_label(result)
error = result["stderr"] or "runtime error"
return f"Runtime error on {label}: {error}"
if not result["passed"] and result["split"] == "visible":
return (
f"Failed on visible input {str(result['input']).strip()}: "
f"expected {result['expected']}, got {result['actual']}"
)
if not result["passed"]:
return f"Failed on hidden test {result['index'] + 1}."
return f"All tests passed. Pass rate: {pass_rate:.2f}"
def _record_metrics(self, observation: AdaptObservation) -> None:
self._state.last_reward = float(observation.reward or 0.0)
self._state.last_pass_rate = observation.pass_rate
self._state.last_feedback = observation.feedback
self._state.recent_metrics = {
"visible_pass_rate": observation.visible_pass_rate,
"hidden_pass_rate": observation.hidden_pass_rate,
"execution_status": observation.execution_status,
"timeout_count": observation.timeout_count,
"runtime_error_count": observation.runtime_error_count,
"format_compliance": observation.format_compliance,
"reward_components": dict(observation.reward_components),
}
def _try_verify(self, code: str) -> tuple[float | None, dict[str, Any]]:
try:
from verifier.verifier import verify
except ImportError:
return None, {}
try:
reward, metadata = verify(code, self.test_cases)
except Exception as exc:
return None, {"feedback": f"Verifier unavailable: {exc}"}
return float(reward), metadata or {}
def _check_syntax(self, code: str) -> tuple[bool, str]:
try:
ast.parse(code)
except SyntaxError as exc:
return False, str(exc)
return True, ""
def _check_safety(self, code: str) -> tuple[bool, str]:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
root_name = alias.name.split(".", 1)[0]
if root_name in FORBIDDEN_IMPORTS:
return False, f"Forbidden import: {root_name}"
if isinstance(node, ast.ImportFrom):
root_name = (node.module or "").split(".", 1)[0]
if root_name in FORBIDDEN_IMPORTS:
return False, f"Forbidden import: {root_name}"
return True, ""
def _pass_rate(self, results: list[dict[str, Any]]) -> float:
if not results:
return 0.0
return sum(1 for result in results if result["passed"]) / len(results)
def _safe_test_label(self, result: dict[str, Any]) -> str:
if result["split"] == "visible":
return f"visible input {str(result['input']).strip()}"
return f"hidden test {result['index'] + 1}"