# File: rl_code_fix_env/client.py
# (Hugging Face Hub page residue: "Viraj0112's picture";
#  "Upload folder using huggingface_hub"; commit 03a907a, verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Code Fixer Environment Client."""
import asyncio
import inspect
import logging
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation
# Module-level logger; the client uses it to report reconnect attempts.
log = logging.getLogger(__name__)
class CodeFixerEnv(
    EnvClient[CodeFixerAction, CodeFixerObservation, State]
):
    """
    Client for the Code Fixer Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the
    server.

    Results returned by the base-class calls are driven to completion on a
    private event loop owned by this instance, so ``reset()``, ``step()`` and
    ``close()`` can be called from plain synchronous code. If a call fails
    with an error that looks like a dropped WebSocket, the client re-initialises
    itself and retries that call exactly once.

    Example:
        >>> # Connect to a running server
        >>> with CodeFixerEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.code)
        ...
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ...     print(result.observation.test_score)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = CodeFixerEnv.from_docker_image("code_fixer-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ... finally:
        ...     client.close()
    """

    def __init__(self, *args, **kwargs):
        """
        Initialise the base client and create this instance's event loop.

        Args:
            *args: Positional arguments forwarded unchanged to the base class.
            **kwargs: Keyword arguments forwarded unchanged to the base class.
        """
        super().__init__(*args, **kwargs)
        self._loop = asyncio.new_event_loop()
        # Store init args so _reconnect() can rebuild the base-class state.
        self._init_args = args
        self._init_kwargs = kwargs

    def _run_sync(self, result):
        """
        Run coroutine results on this client's dedicated event loop.

        Args:
            result: Either a coroutine object or an already-computed value.

        Returns:
            The coroutine's result if ``result`` is a coroutine, otherwise
            ``result`` itself.
        """
        if inspect.iscoroutine(result):
            return self._loop.run_until_complete(result)
        return result

    def _reconnect(self) -> None:
        """
        Tear down the dead event loop and WebSocket connection, then
        re-initialise so the next call works cleanly.

        Called automatically by reset() and step() when a 1011 / timeout
        error is detected after an idle period.
        """
        log.warning("[CodeFixerEnv] WebSocket timed out reconnecting...")
        # Best-effort close of the old connection: it is expected to be dead
        # already, so a failure here must not prevent the re-initialisation.
        try:
            self._run_sync(super().close())
        except Exception:
            log.debug(
                "[CodeFixerEnv] Ignoring error while closing stale connection",
                exc_info=True,
            )
        if not self._loop.is_closed():
            self._loop.close()
        # Re-initialise: fresh loop + fresh base-class state
        self._loop = asyncio.new_event_loop()
        super().__init__(*self._init_args, **self._init_kwargs)
        log.warning("[CodeFixerEnv] Reconnected successfully.")

    @staticmethod
    def _is_reconnectable_ws_error(exc: Exception) -> bool:
        """
        Decide whether ``exc`` looks like a dropped/timed-out WebSocket.

        The check is a case-insensitive substring match against both the
        exception's class name and its message. The class name is included
        because markers such as ``"connectionclosed"`` identify an exception
        type rather than message text.

        Args:
            exc: The exception raised by a reset/step call.

        Returns:
            True if any known marker occurs in the class name or message.
        """
        haystack = f"{type(exc).__name__} {exc}".lower()
        reconnect_markers = (
            "1011",
            "1006",
            "keepalive",
            "timed out",
            "closed",
            "close frame",
            "connection closed",
            "connectionclosed",
            "websocket",
        )
        return any(marker in haystack for marker in reconnect_markers)

    def _call_with_reconnect(self, make_call):
        """
        Run a base-class call, reconnecting and retrying once if it fails.

        Args:
            make_call: Zero-argument callable that starts the base-class
                operation and returns its result (possibly a coroutine). It
                is invoked again for the retry, so it must produce a fresh
                result each time.

        Returns:
            Whatever the operation returns once run to completion.

        Raises:
            Exception: The original error if it is not reconnectable, or the
                retry's error if the second attempt also fails.
        """
        try:
            return self._run_sync(make_call())
        except Exception as exc:
            if not self._is_reconnectable_ws_error(exc):
                raise
            self._reconnect()
            return self._run_sync(make_call())  # one retry

    def reset(self, **kwargs) -> StepResult[CodeFixerObservation]:
        """
        Reset the environment; auto-reconnects if the WebSocket died.

        Args:
            **kwargs: Forwarded unchanged to the base-class ``reset``.

        Returns:
            The result produced by the base-class ``reset``.
        """
        # Bind outside the lambda: zero-argument super() is unavailable there.
        base_reset = super().reset
        return self._call_with_reconnect(lambda: base_reset(**kwargs))

    def step(
        self, action: CodeFixerAction, **kwargs
    ) -> StepResult[CodeFixerObservation]:
        """
        Execute a step; auto-reconnects if the WebSocket died.

        Args:
            action: The action to send to the environment.
            **kwargs: Forwarded unchanged to the base-class ``step``.

        Returns:
            The result produced by the base-class ``step``.
        """
        # Bind outside the lambda: zero-argument super() is unavailable there.
        base_step = super().step
        return self._call_with_reconnect(lambda: base_step(action, **kwargs))

    def close(self) -> None:
        """
        Close client resources and the dedicated event loop safely.

        Calling this again after the loop has been closed is a no-op, so
        repeated ``close()`` calls do not raise.
        """
        if self._loop.is_closed():
            # Already closed: a coroutine could no longer be run on this loop.
            return
        try:
            self._run_sync(super().close())
        finally:
            if not self._loop.is_closed():
                self._loop.close()

    def _step_payload(self, action: CodeFixerAction) -> Dict:
        """
        Convert CodeFixerAction to JSON payload for step message.

        Args:
            action: CodeFixerAction instance

        Returns:
            Dictionary with the action's ``type`` and ``payload`` fields,
            suitable for JSON encoding
        """
        return {
            "type": action.type,
            "payload": action.payload,
        }

    def _parse_result(self, payload: Dict) -> StepResult[CodeFixerObservation]:
        """
        Parse server response into StepResult[CodeFixerObservation].

        Missing fields fall back to defaults. A missing or null
        ``observation`` is treated as an empty one, and a null ``test_score``
        is treated as ``0.0``.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with CodeFixerObservation
        """
        # `or {}` also covers an explicit null observation from the server.
        obs_data = payload.get("observation") or {}
        raw_score = obs_data.get("test_score")
        observation = CodeFixerObservation(
            code=obs_data.get("code", ""),
            logs=obs_data.get("logs"),
            test_score=0.0 if raw_score is None else float(raw_score),
            total_tests=obs_data.get("total_tests", 1),
            steps=obs_data.get("steps", 0),
            # Observation-level values win; top-level ones are the fallback.
            done=obs_data.get("done", payload.get("done", False)),
            reward=obs_data.get("reward", payload.get("reward")),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id and step_count (step_count
            defaults to 0 when absent)
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )