rl_code_fix_env / server /rl_code_fix_env_environment.py
Viraj0112's picture
Upload folder using huggingface_hub
03a907a verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
RlCodeFix Environment - OpenEnv-compliant server wrapper.

Wraps the core CodeEnv engine and exposes it through the OpenEnv
HTTP/WebSocket interface via create_app().

Episode lifecycle:
  - reset() loads the next task in the easy -> medium -> hard rotation,
    or the difficulty forced by the TRACERL_TASK environment variable.
  - step() dispatches apply_patch / run_tests / get_logs actions.
  - state returns the current episode_id and step_count.
"""
from uuid import uuid4
import os
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation
from rl_code_fix_env.src.environment.environment import CodeEnv
_DIFFICULTY_CYCLE = ["easy", "medium", "hard"]
class RlCodeFixEnvironment(Environment):
    """
    OpenEnv-compliant wrapper around CodeEnv.

    Exposes reset / step / state to the HTTP server produced by create_app().

    Task difficulty rotates deterministically through ``_DIFFICULTY_CYCLE``
    (easy -> medium -> hard -> easy ...) on each reset, so the agent sees a
    variety of problems across episodes. Setting the ``TRACERL_TASK``
    environment variable to one of those difficulties pins every reset to
    that difficulty instead.
    """

    # Each instance owns its own CodeEnv engine and episode State, so no
    # per-episode data held on this object is shared between instances.
    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        # Let the OpenEnv base class run its own initialisation (if any)
        # before this wrapper sets up its state.
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._engine = CodeEnv()  # core engine
        # Position in _DIFFICULTY_CYCLE used by the next non-forced reset.
        self._difficulty_idx = 0

    def _select_difficulty(self) -> str:
        """
        Return the difficulty to use for the next episode.

        If ``TRACERL_TASK`` is set to a value in ``_DIFFICULTY_CYCLE``, that
        value is returned and the rotation index is not advanced. The value is
        matched case-insensitively, ignoring surrounding whitespace. Any other
        value, including an unset or empty variable, falls back to the
        round-robin rotation, which advances by one position per call.
        """
        forced = (os.getenv("TRACERL_TASK") or "").strip().lower()
        if forced in _DIFFICULTY_CYCLE:
            return forced
        cycle_len = len(_DIFFICULTY_CYCLE)
        difficulty = _DIFFICULTY_CYCLE[self._difficulty_idx % cycle_len]
        self._difficulty_idx += 1
        return difficulty

    @staticmethod
    def _to_observation(obs_dict, *, done, reward) -> CodeFixerObservation:
        """
        Build a CodeFixerObservation from an engine observation dict.

        Reads the ``code``, ``logs``, ``test_score``, ``total_tests`` and
        ``steps`` keys of *obs_dict*. ``test_score`` and *reward* are
        coerced to float; *done* is passed through unchanged.
        """
        return CodeFixerObservation(
            code=obs_dict["code"],
            logs=obs_dict["logs"],
            test_score=float(obs_dict["test_score"]),
            total_tests=obs_dict["total_tests"],
            steps=obs_dict["steps"],
            done=done,
            reward=float(reward),
        )

    def reset(self) -> CodeFixerObservation:
        """
        Start a new episode and return its initial observation.

        Picks the difficulty via ``_select_difficulty`` (round-robin, or the
        ``TRACERL_TASK`` override). It then assigns a fresh episode id with a
        zeroed step counter and asks the engine to load a task of that
        difficulty. The returned observation always has ``done=False`` and
        ``reward=0.0``.
        """
        difficulty = self._select_difficulty()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        obs_dict = self._engine.reset(difficulty=difficulty)
        return self._to_observation(obs_dict, done=False, reward=0.0)

    def step(self, action: CodeFixerAction) -> CodeFixerObservation:  # type: ignore[override]
        """
        Forward *action* to the core engine and return the resulting observation.

        The action is passed to ``CodeEnv.step`` as
        ``{"type": action.type, "payload": action.payload}``. The engine's
        ``(obs, reward, done, info)`` result is converted to a
        CodeFixerObservation; ``info`` is discarded.
        """
        obs_dict, reward, done, _info = self._engine.step(
            {"type": action.type, "payload": action.payload}
        )
        # Count the step only after the engine call returns, so an action
        # that raises is not reflected in ``state.step_count``.
        self._state.step_count += 1
        return self._to_observation(obs_dict, done=done, reward=reward)

    @property
    def state(self) -> State:
        """Return the live State object (current episode_id and step_count)."""
        return self._state