from __future__ import annotations
import os
import sys
from typing import Any, Dict, Optional, Tuple
# Absolute path of the directory that contains this module.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

# Put that directory at the front of the module search path, but only once,
# so repeated imports of this module never add duplicate entries.
# NOTE(review): presumably this is what lets the project-local imports further
# down (`acre`, `models`) resolve regardless of the working directory - confirm.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
try:
    from openenv.env import Env as OpenEnvBase
except Exception:  # pragma: no cover
    # The import failed for whatever reason: define a stand-in base class
    # under the same name so code that subclasses ``OpenEnvBase`` still loads.
    class OpenEnvBase:
        """Fallback base class used when ``openenv.env.Env`` cannot be imported."""

        def __init__(self, *args: Any, **kwargs: Any) -> None:
            """Accept any positional and keyword arguments and ignore them."""
from acre.datasets.code_samples import CodeSample, CodeSampleDataset
from acre.env.refactor_env import RefactorEnv
from acre.tasks.task_registry import TaskRegistry
from models import ActionModel, ObservationModel, RewardModel, StateResponse
class OpenEnvRefactorEnv(OpenEnvBase):
    """
    Canonical OpenEnv interface for ACRE.

    This wrapper keeps the strict hackathon contract:
    - reset() -> ObservationModel
    - step(action) -> (ObservationModel, RewardModel, done, info)
    - state() -> StateResponse

    All work is delegated to a wrapped ``RefactorEnv``; this class only selects
    the source code for an episode and converts the wrapped environment's raw
    outputs into the typed models listed above.
    """

    def __init__(
        self,
        *,
        env: Optional[RefactorEnv] = None,
        registry: Optional[TaskRegistry] = None,
    ) -> None:
        """Create the wrapper.

        Args:
            env: Environment to wrap. A fresh ``RefactorEnv()`` is created when
                omitted.
            registry: Registry used to resolve ``task_id`` values passed to
                ``reset()``. A fresh ``TaskRegistry()`` is created when omitted.
        """
        super().__init__(
            name="ACRE",
            state_space="ObservationModel",
            action_space="ActionModel",
            episode_max_length=RefactorEnv.MAX_STEPS,
        )
        # Compare against None instead of relying on truthiness: an explicitly
        # supplied object that happens to evaluate falsy (e.g. a container-like
        # registry that is currently empty) must not be silently replaced by a
        # default instance.
        self._env = RefactorEnv() if env is None else env
        self._registry = TaskRegistry() if registry is None else registry
        # Task id recorded by the most recent successful source load in reset().
        self._task_id: Optional[str] = None
        # Copy of the info mapping returned by the wrapped env's last reset().
        self._last_reset_info: Dict[str, Any] = {}

    @property
    def action_meanings(self) -> Dict[int, str]:
        """Return a copy of the wrapped environment's ``ACTION_MEANINGS``."""
        # Hand out a copy, like `last_reset_info` and `state()` do, so callers
        # cannot mutate the environment's own table through this property.
        return dict(self._env.ACTION_MEANINGS)

    @property
    def last_reset_info(self) -> Dict[str, Any]:
        """Return a shallow copy of the info captured by the last ``reset()``."""
        return dict(self._last_reset_info)

    def _load_episode_source(self, *, task_id: Optional[str], code: Optional[str]) -> None:
        """Install the dataset the next episode should draw its code from.

        Resolution order:

        1. ``code`` given: a single-sample dataset holding exactly that code,
           with id ``task_id`` (or ``"custom"`` when no task id is given).
        2. Otherwise, a non-empty ``task_id``: the task is looked up in the
           registry. If it exposes a non-empty ``samples`` attribute, one
           sample per entry is installed with ids ``"<task_id>:<index>"``;
           if not, its ``initial_code`` is installed as a single sample.
        3. Otherwise (or when the task has no code at all): the wrapped
           environment's current dataset is left untouched.

        Raises:
            ValueError: If ``task_id`` is not known to the registry.
        """
        entries: list[tuple[str, str]] = []  # (sample id, source code) pairs

        if code is not None:
            # Explicit code always wins; the registry is not consulted.
            entries.append((task_id or "custom", code))
        elif task_id:
            task = self._registry.get_task(task_id)
            if task is None:
                raise ValueError(f"Task '{task_id}' not found")

            # Load a multi-sample dataset for this task. Sample selection is
            # deterministic given the `seed` passed to `reset()`.
            samples = list(getattr(task, "samples", []) or [])
            if samples:
                entries.extend(
                    (f"{task_id}:{index}", str(source))
                    for index, source in enumerate(samples)
                )
            else:
                fallback_code = task.initial_code
                if fallback_code is not None:
                    entries.append((task_id, fallback_code))

        if not entries:
            # Nothing to install: keep whatever dataset the env already has.
            return None

        self._env.dataset = CodeSampleDataset(
            [
                CodeSample(id=sample_id, language="python", code=source)
                for sample_id, source in entries
            ]
        )
        return None

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        task_id: Optional[str] = None,
        code: Optional[str] = None,
    ) -> ObservationModel:
        """Start a new episode and return its first observation.

        Args:
            seed: Forwarded unchanged to the wrapped environment's ``reset``.
            task_id: Optional task whose code should be loaded for the episode.
            code: Optional explicit source code; takes precedence over the
                task's own code.

        Returns:
            The initial observation, built from the wrapped environment's
            observation vector.

        Raises:
            ValueError: If ``task_id`` is not known to the registry.
        """
        # Load the source before recording the task id. The load raises for an
        # unknown task, and in that case `state()` must keep reporting the task
        # of the episode that is still loaded, not the one that failed.
        self._load_episode_source(task_id=task_id, code=code)
        self._task_id = task_id

        observation, info = self._env.reset(seed=seed)
        self._last_reset_info = dict(info)
        return ObservationModel.from_vector(observation.tolist())

    def step(self, action: int | ActionModel) -> Tuple[ObservationModel, RewardModel, bool, Dict[str, Any]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        Args:
            action: Either an ``ActionModel`` (its ``action`` attribute is
                used) or anything convertible with ``int()``.

        Returns:
            A 4-tuple of the new observation, a ``RewardModel`` built from the
            raw reward plus the ``normalized_reward`` / ``reward_components``
            entries of ``info`` (defaulting to ``0.0`` / ``{}``), a ``done``
            flag, and a shallow copy of ``info``.
        """
        action_value = action.action if isinstance(action, ActionModel) else int(action)
        observation, raw_reward, terminated, truncated, info = self._env.step(action_value)

        reward = RewardModel(
            raw=float(raw_reward),
            normalized=float(info.get("normalized_reward", 0.0)),
            components=dict(info.get("reward_components", {})),
        )
        # The contract exposes one flag: the episode is over when the wrapped
        # environment reports either termination or truncation.
        done = bool(terminated or truncated)
        return ObservationModel.from_vector(observation.tolist()), reward, done, dict(info)

    def state(self) -> StateResponse:
        """Return a typed snapshot of the wrapped environment's current state.

        Every field is read from the mapping returned by the wrapped
        environment's ``state()``; missing keys fall back to neutral defaults
        (empty string, zero, ``False``, ``None``, a four-zero observation
        vector, or ``RefactorEnv.MAX_STEPS`` for ``max_steps``). ``task_id`` is
        the value recorded by this wrapper in ``reset()``.
        """
        raw_state = self._env.state()
        observation_vector = list(raw_state.get("observation", [0.0, 0.0, 0.0, 0.0]))
        observation = ObservationModel.from_vector(observation_vector)
        return StateResponse(
            current_code=str(raw_state.get("current_code", "")),
            episode_steps=int(raw_state.get("episode_steps", 0)),
            max_steps=int(raw_state.get("max_steps", RefactorEnv.MAX_STEPS)),
            complexity=float(raw_state.get("complexity", 0.0)),
            last_runtime=float(raw_state.get("last_runtime", 0.0)),
            last_error=bool(raw_state.get("last_error", False)),
            sample_id=raw_state.get("sample_id"),
            language=raw_state.get("language"),
            task_id=self._task_id,
            observation=observation,
            observation_vector=observation.to_vector(),
            action_meanings=dict(raw_state.get("action_meanings", {})),
        )
|