| from __future__ import annotations |
|
|
| import os |
| import sys |
| from typing import Any, Dict, Optional, Tuple |
|
|
# Absolute path of the directory that contains this module.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Prepend that directory to the import search path (only once) so the
# project-local imports further down resolve regardless of the caller's
# working directory.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
|
try:
    from openenv.env import Env as OpenEnvBase
except Exception:
    # Any failure while importing openenv (not only a missing package) lands
    # here, so the module stays importable without that dependency.
    class OpenEnvBase:
        """Minimal stand-in used when ``openenv.env.Env`` cannot be imported."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            """Accept and ignore every positional and keyword argument."""
|
|
| from acre.datasets.code_samples import CodeSample, CodeSampleDataset |
| from acre.env.refactor_env import RefactorEnv |
| from acre.tasks.task_registry import TaskRegistry |
| from models import ActionModel, ObservationModel, RewardModel, StateResponse |
|
|
|
|
class OpenEnvRefactorEnv(OpenEnvBase):
    """
    Canonical OpenEnv interface for ACRE.

    This wrapper keeps the strict hackathon contract:
    - reset() -> ObservationModel
    - step(action) -> (ObservationModel, RewardModel, done, info)
    - state() -> StateResponse

    All environment work is delegated to a wrapped ``RefactorEnv``; this
    class only selects the episode's source code and converts the wrapped
    environment's raw outputs into the model types listed above.
    """

    def __init__(
        self,
        *,
        env: Optional[RefactorEnv] = None,
        registry: Optional[TaskRegistry] = None,
    ) -> None:
        """Create the wrapper.

        Args:
            env: Environment to wrap. A new ``RefactorEnv()`` is created
                when this is ``None``.
            registry: Task registry used to resolve ``task_id`` values in
                ``reset``. A new ``TaskRegistry()`` is created when this is
                ``None``.
        """
        super().__init__(
            name="ACRE",
            state_space="ObservationModel",
            action_space="ActionModel",
            episode_max_length=RefactorEnv.MAX_STEPS,
        )
        # Compare against None instead of relying on truthiness: an object
        # the caller passed explicitly must be used even if it evaluates as
        # falsy (for example an empty container-like registry).
        self._env = env if env is not None else RefactorEnv()
        self._registry = registry if registry is not None else TaskRegistry()
        self._task_id: Optional[str] = None
        self._last_reset_info: Dict[str, Any] = {}

    @property
    def action_meanings(self) -> Dict[int, str]:
        """Return the wrapped environment's ``ACTION_MEANINGS`` attribute."""
        return self._env.ACTION_MEANINGS

    @property
    def last_reset_info(self) -> Dict[str, Any]:
        """Return a shallow copy of the info dict from the latest ``reset``."""
        return dict(self._last_reset_info)

    def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
        """Point the wrapped environment's dataset at the requested source.

        Precedence:
        1. An explicit ``code`` string becomes a single-sample dataset.
        2. Otherwise, a truthy ``task_id`` is looked up in the registry; the
           task's ``samples`` (if any) become the dataset, else its
           ``initial_code`` is used as a single sample.
        3. If neither yields any code, the dataset is left untouched.

        Args:
            task_id: Identifier of a registered task, or ``None``.
            code: Source code to use directly, or ``None``.

        Raises:
            ValueError: If ``code`` is ``None`` and ``task_id`` is not known
                to the registry.
        """
        initial_code = code
        if initial_code is None and task_id:
            task = self._registry.get_task(task_id)
            if task is None:
                raise ValueError(f"Task '{task_id}' not found")

            # ``samples`` is optional on a task; a missing or None attribute
            # is treated the same as an empty collection.
            samples = list(getattr(task, "samples", []) or [])
            if samples:
                self._env.dataset = CodeSampleDataset(
                    [
                        CodeSample(
                            id=f"{task_id}:{i}",
                            language="python",
                            code=str(src),
                        )
                        for i, src in enumerate(samples)
                    ]
                )
                return None
            initial_code = task.initial_code

        if initial_code is None:
            # Nothing was requested (or the task carries no code): keep
            # whatever dataset the wrapped environment already has.
            return None

        self._env.dataset = CodeSampleDataset(
            [
                CodeSample(
                    id=task_id or "custom",
                    language="python",
                    code=initial_code,
                )
            ]
        )
        return None

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        task_id: Optional[str] = None,
        code: Optional[str] = None,
    ) -> ObservationModel:
        """Start a new episode and return its first observation.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            task_id: Optional registered task to draw the source code from.
            code: Optional explicit source code; takes precedence over the
                code attached to ``task_id``.

        Returns:
            The initial observation converted to an ``ObservationModel``.

        Raises:
            ValueError: If ``task_id`` must be resolved and is not registered.
        """
        self._load_episode_source(task_id=task_id, code=code)
        # Record the task id only after the source loaded successfully, so a
        # failed reset (unknown task) does not leave state() reporting a task
        # that was never loaded.
        self._task_id = task_id
        observation, info = self._env.reset(seed=seed)
        self._last_reset_info = dict(info)
        return ObservationModel.from_vector(observation.tolist())

    def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
        """Apply one action and return the contract's four-tuple.

        Args:
            action: Either an ``ActionModel`` (its ``action`` attribute is
                used) or a value convertible with ``int()``.

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is true when
            the wrapped environment reports either termination or truncation
            and ``info`` is a shallow copy of the wrapped environment's info.
        """
        action_value = action.action if isinstance(action, ActionModel) else int(action)
        observation, raw_reward, terminated, truncated, info = self._env.step(action_value)
        reward = RewardModel(
            raw=float(raw_reward),
            # Both entries are optional in ``info``; defaults apply when a
            # key is absent.
            normalized=float(info.get("normalized_reward", 0.0)),
            components=dict(info.get("reward_components", {})),
        )
        # The contract exposes a single flag, so the two end conditions are
        # collapsed into one.
        done = bool(terminated or truncated)
        return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)

    def state(self) -> StateResponse:
        """Return a snapshot of the wrapped environment as a ``StateResponse``.

        Every field is read from the mapping returned by the wrapped
        environment's ``state()`` with a default for missing keys; the
        observation falls back to a four-element zero vector. ``task_id``
        comes from this wrapper, not from the wrapped environment.
        """
        raw_state = self._env.state()
        observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
        observation = ObservationModel.from_vector(observation_vector)
        return StateResponse(
            current_code=str(raw_state.get("current_code", "")),
            episode_steps=int(raw_state.get("episode_steps", 0)),
            max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
            complexity=float(raw_state.get("complexity", 0.0)),
            last_runtime=float(raw_state.get("last_runtime", 0.0)),
            last_error=bool(raw_state.get("last_error", False)),
            sample_id=raw_state.get("sample_id"),
            language=raw_state.get("language"),
            task_id=self._task_id,
            observation=observation,
            observation_vector=observation.to_vector(),
            action_meanings=dict(raw_state.get("action_meanings", {})),
        )
|
|