Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| RlCodeFix Environment - OpenEnv-compliant server wrapper. | |
| Wraps the core CodeEnv engine and exposes it through the OpenEnv | |
| HTTP/WebSocket interface via create_app(). | |
| Episode lifecycle: | |
| reset() -> loads the next task in the easy -> medium -> hard cycle (or the difficulty forced via TRACERL_TASK) | |
| step() -> dispatches apply_patch / run_tests / get_logs | |
| state -> returns current episode_id + step_count | |
| """ | |
| from uuid import uuid4 | |
| import os | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation | |
| from rl_code_fix_env.src.environment.environment import CodeEnv | |
| _DIFFICULTY_CYCLE = ["easy", "medium", "hard"] | |
class RlCodeFixEnvironment(Environment):
    """
    OpenEnv-compliant wrapper around CodeEnv.

    Exposes reset / step / state to the HTTP server produced by create_app().

    Task difficulty is NOT random: each instance walks through
    ``_DIFFICULTY_CYCLE`` in order, advancing one position per reset, so the
    agent sees a variety of problems across episodes. If the ``TRACERL_TASK``
    environment variable names one of the cycle's difficulties, that
    difficulty is used for every reset instead.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Create the core engine and an initial (pre-reset) episode state."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._engine = CodeEnv()  # core engine
        # Position in _DIFFICULTY_CYCLE for the next non-forced reset.
        self._difficulty_idx = 0

    def _select_difficulty(self) -> str:
        """Return the difficulty to use for the next episode.

        The ``TRACERL_TASK`` environment variable is read on every call,
        stripped and lower-cased. If it matches an entry of
        ``_DIFFICULTY_CYCLE`` it is returned and the cycle position is left
        untouched. Otherwise the entry at the current cycle position is
        returned and the position advances by one.
        """
        forced = (os.getenv("TRACERL_TASK") or "").strip().lower()
        if forced in _DIFFICULTY_CYCLE:
            return forced
        difficulty = _DIFFICULTY_CYCLE[self._difficulty_idx % len(_DIFFICULTY_CYCLE)]
        self._difficulty_idx += 1
        return difficulty

    @staticmethod
    def _to_observation(obs_dict, *, done, reward) -> CodeFixerObservation:
        """Convert an engine observation dict into a CodeFixerObservation.

        Shared by reset() and step() so both build the observation from the
        same keys ("code", "logs", "test_score", "total_tests", "steps") with
        the same float coercions.
        """
        return CodeFixerObservation(
            code=obs_dict["code"],
            logs=obs_dict["logs"],
            test_score=float(obs_dict["test_score"]),
            total_tests=obs_dict["total_tests"],
            steps=obs_dict["steps"],
            done=done,
            reward=float(reward),
        )

    def reset(self) -> CodeFixerObservation:
        """Start a new episode and return its initial observation.

        Picks the difficulty via _select_difficulty(), replaces the episode
        state with a fresh episode_id and a zero step_count, then resets the
        core engine with that difficulty. The returned observation always has
        done=False and reward=0.0.
        """
        difficulty = self._select_difficulty()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        obs_dict = self._engine.reset(difficulty=difficulty)
        return self._to_observation(obs_dict, done=False, reward=0.0)

    def step(self, action: CodeFixerAction) -> CodeFixerObservation:  # type: ignore[override]
        """Dispatch action to the core engine and return observation.

        The step counter is incremented before the engine is called. The
        action is forwarded as a ``{"type", "payload"}`` dict; the fourth
        element of the engine's return tuple is ignored.
        """
        self._state.step_count += 1
        obs_dict, reward, done, _ = self._engine.step(
            {"type": action.type, "payload": action.payload}
        )
        return self._to_observation(obs_dict, done=done, reward=reward)

    # NOTE(review): upstream OpenEnv appears to declare `state` as a property
    # on the Environment base class; here it is a plain method. Confirm which
    # form the server expects before changing it, since callers differ
    # (`env.state` vs `env.state()`).
    def state(self) -> State:
        """Current episode_id and step_count."""
        return self._state