autonomous-code-refactoring-env / openenv_interface.py
PRANAV05092003's picture
Updated structure and fixed module import issue
8d66fec
from __future__ import annotations
import os
import sys
from typing import Any, Dict, Optional, Tuple
# Put this file's directory at the front of the import path so that sibling
# top-level modules/packages can be imported regardless of the current
# working directory.
PROJECT_ROOT = os.path.split(os.path.abspath(__file__))[0]
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
try:
    from openenv.env import Env as OpenEnvBase
except Exception:  # pragma: no cover
    # The `openenv` package could not be imported; fall back to a minimal
    # stand-in so subclasses can still be defined and constructed.
    class OpenEnvBase:
        """Placeholder base class used when `openenv.env.Env` is unavailable."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            """Accept and ignore any positional or keyword arguments."""
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
from acre.env.refactor_env import RefactorEnv
from acre.tasks.task_registry import TaskRegistry
from models import ActionModel, ObservationModel, RewardModel, StateResponse
class OpenEnvRefactorEnv(OpenEnvBase):
    """
    Canonical OpenEnv interface for ACRE.

    This wrapper keeps the strict hackathon contract:
    - reset() -> ObservationModel
    - step(action) -> (ObservationModel, RewardModel, done, info)
    - state() -> StateResponse

    It delegates all environment dynamics to a wrapped ``RefactorEnv`` and
    only converts that environment's raw outputs into the typed models.
    """

    def __init__(
        self,
        *,
        env: Optional[RefactorEnv] = None,
        registry: Optional[TaskRegistry] = None,
    ) -> None:
        """Create the wrapper.

        Args:
            env: Environment to wrap. A new ``RefactorEnv()`` is created only
                when this is ``None``.
            registry: Task registry used to resolve ``task_id`` values. A new
                ``TaskRegistry()`` is created only when this is ``None``.
        """
        super().__init__(
            name="ACRE",
            state_space="ObservationModel",
            action_space="ActionModel",
            episode_max_length=RefactorEnv.MAX_STEPS,
        )
        # Compare against None explicitly: an injected object that happens to
        # be falsy (for example a container-like registry with no entries)
        # must not be silently replaced by a fresh default instance.
        self._env = env if env is not None else RefactorEnv()
        self._registry = registry if registry is not None else TaskRegistry()
        self._task_id: Optional[str] = None
        self._last_reset_info: Dict[str, Any] = {}

    @property
    def action_meanings(self) -> Dict[int, str]:
        """Return the wrapped environment's ``ACTION_MEANINGS`` mapping."""
        return self._env.ACTION_MEANINGS

    @property
    def last_reset_info(self) -> Dict[str, Any]:
        """Return a shallow copy of the info dict from the latest ``reset()``."""
        return dict(self._last_reset_info)

    def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
        """Install the dataset for the next episode on the wrapped env.

        Explicit ``code`` takes precedence over ``task_id``. When neither
        yields any source, the environment's current dataset is left as is.

        Args:
            task_id: Identifier looked up in the registry when ``code`` is
                not supplied.
            code: Source text to use directly as a single-sample dataset.

        Raises:
            ValueError: If ``task_id`` is given (without ``code``) and the
                registry has no such task.
        """
        initial_code = code
        if initial_code is None and task_id:
            task = self._registry.get_task(task_id)
            if task is None:
                raise ValueError(f"Task '{task_id}' not found")
            # Load a multi-sample dataset for this task. Sample selection is
            # deterministic given the `seed` passed to `reset()`.
            samples = list(getattr(task, "samples", []) or [])
            if samples:
                self._env.dataset = CodeSampleDataset(
                    [
                        CodeSample(
                            id=f"{task_id}:{i}",
                            language="python",
                            code=str(src),
                        )
                        for i, src in enumerate(samples)
                    ]
                )
                return None
            # No samples on the task: fall back to its single initial snippet.
            initial_code = task.initial_code
        if initial_code is None:
            # Nothing to load; keep whatever dataset the env already has.
            return None
        self._env.dataset = CodeSampleDataset(
            [
                CodeSample(
                    id=task_id or "custom",
                    language="python",
                    code=initial_code,
                )
            ]
        )
        return None

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        task_id: Optional[str] = None,
        code: Optional[str] = None,
    ) -> ObservationModel:
        """Start a new episode and return its first observation.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            task_id: Task whose source should be loaded (see
                ``_load_episode_source``).
            code: Explicit source text; overrides the task's own source.

        Returns:
            The initial observation converted to an ``ObservationModel``.

        Raises:
            ValueError: If ``task_id`` cannot be resolved. In that case the
                previously recorded task id is left unchanged.
        """
        self._load_episode_source(task_id=task_id, code=code)
        # Record the task only after its source loaded successfully, so a
        # failed lookup does not make state() report a task never loaded.
        self._task_id = task_id
        observation, info = self._env.reset(seed=seed)
        self._last_reset_info = dict(info)
        return ObservationModel.from_vector(observation.tolist())

    def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
        """Apply one action to the wrapped environment.

        Args:
            action: Either an ``ActionModel`` (its ``action`` field is used)
                or a value convertible with ``int()``.

        Returns:
            ``(observation, reward, done, info)`` where ``done`` is true when
            the wrapped env reports either termination or truncation, and
            ``info`` is a shallow copy of the env's info dict.
        """
        action_value = action.action if isinstance(action, ActionModel) else int(action)
        observation, raw_reward, terminated, truncated, info = self._env.step(action_value)
        # `or` fallbacks cover keys that are present but set to None, which
        # would otherwise make float()/dict() raise TypeError.
        reward = RewardModel(
            raw=float(raw_reward),
            normalized=float(info.get("normalized_reward") or 0.0),
            components=dict(info.get("reward_components") or {}),
        )
        done = bool(terminated or truncated)
        return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)

    def state(self) -> StateResponse:
        """Return a typed snapshot of the wrapped environment's state.

        Missing keys in the env's raw state dict fall back to neutral
        defaults (empty string, zero, ``False``, a four-zero observation).
        """
        raw_state = self._env.state()
        observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
        observation = ObservationModel.from_vector(observation_vector)
        return StateResponse(
            current_code=str(raw_state.get("current_code", "")),
            episode_steps=int(raw_state.get("episode_steps", 0)),
            max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
            complexity=float(raw_state.get("complexity", 0.0)),
            last_runtime=float(raw_state.get("last_runtime", 0.0)),
            last_error=bool(raw_state.get("last_error", False)),
            sample_id=raw_state.get("sample_id"),
            language=raw_state.get("language"),
            task_id=self._task_id,
            observation=observation,
            observation_vector=observation.to_vector(),
            action_meanings=dict(raw_state.get("action_meanings", {})),
        )