# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Code Fixer Environment Client."""

import asyncio
import inspect
import logging
from typing import Dict

from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State

from rl_code_fix_env.models import CodeFixerAction, CodeFixerObservation

log = logging.getLogger(__name__)


class CodeFixerEnv(
    EnvClient[CodeFixerAction, CodeFixerObservation, State]
):
    """
    Client for the Code Fixer Environment.

    This client maintains a persistent WebSocket connection to the environment
    server, enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Async results from the base client are driven to completion on a private
    event loop owned by this instance, so ``reset``/``step``/``close`` can be
    called synchronously. If a call fails with an error that looks like a dead
    WebSocket, the client rebuilds its connection and retries that call once.

    Example:
        >>> # Connect to a running server
        >>> with CodeFixerEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.code)
        ...
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ...     print(result.observation.test_score)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = CodeFixerEnv.from_docker_image("code_fixer-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(CodeFixerAction(type="run_tests"))
        ... finally:
        ...     client.close()
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base client and a dedicated event loop.

        All arguments are forwarded unchanged to the base ``EnvClient``.
        They are also remembered so ``_reconnect`` can rebuild an identical
        client.
        """
        super().__init__(*args, **kwargs)
        self._loop = asyncio.new_event_loop()
        # Store init args for reconnection
        self._init_args = args
        self._init_kwargs = kwargs

    def _run_sync(self, result):
        """Run coroutine results on this client's dedicated event loop.

        Non-coroutine values (i.e. the base method was synchronous) are
        returned unchanged.
        """
        if inspect.iscoroutine(result):
            return self._loop.run_until_complete(result)
        return result

    def _reconnect(self) -> None:
        """
        Tear down the dead event loop and WebSocket connection, then
        re-initialise so the next call works cleanly.

        Called automatically by reset() and step() when an error matching
        ``_is_reconnectable_ws_error`` (e.g. a 1011 / timeout close) occurs,
        typically after an idle period.
        """
        log.warning("[CodeFixerEnv] WebSocket timed out reconnecting...")

        # Best-effort teardown: the old connection is presumed dead, so an
        # error while closing it is expected and must not block the reconnect.
        # It is still logged (at debug level) so it is not lost entirely.
        try:
            self._run_sync(super().close())
        except Exception:
            log.debug(
                "[CodeFixerEnv] Ignoring error while closing old connection.",
                exc_info=True,
            )
        if not self._loop.is_closed():
            self._loop.close()

        # Re-initialise: fresh loop + fresh base-class state
        self._loop = asyncio.new_event_loop()
        super().__init__(*self._init_args, **self._init_kwargs)
        log.warning("[CodeFixerEnv] Reconnected successfully.")

    @staticmethod
    def _is_reconnectable_ws_error(exc: Exception) -> bool:
        """Return True if *exc*'s message suggests the WebSocket is dead.

        This is a case-insensitive substring match on ``str(exc)`` against
        known close codes (1011, 1006) and connection-closed / timeout wording.
        """
        err = str(exc).lower()
        reconnect_markers = (
            "1011",
            "1006",
            "keepalive",
            "timed out",
            "closed",
            "close frame",
            "connection closed",
            "connectionclosed",
            "websocket",
        )
        return any(marker in err for marker in reconnect_markers)

    def reset(self, **kwargs):
        """Reset the environment; auto-reconnects once if the WebSocket died.

        Extra keyword arguments are forwarded to the base client's ``reset``.
        If the retry after reconnecting also fails, that error propagates.
        """
        try:
            return self._run_sync(super().reset(**kwargs))
        except Exception as exc:
            if self._is_reconnectable_ws_error(exc):
                self._reconnect()
                return self._run_sync(super().reset(**kwargs))  # one retry
            raise

    def step(self, action: CodeFixerAction, **kwargs):
        """Execute a step; auto-reconnects once if the WebSocket died.

        Extra keyword arguments are forwarded to the base client's ``step``.
        NOTE(review): a retry re-sends *action*. If the server already applied
        it before the connection dropped, it may be applied twice.
        """
        try:
            return self._run_sync(super().step(action, **kwargs))
        except Exception as exc:
            if self._is_reconnectable_ws_error(exc):
                self._reconnect()
                return self._run_sync(super().step(action, **kwargs))  # one retry
            raise

    def close(self):
        """Close client resources and the dedicated event loop safely.

        Idempotent: once the loop has been closed, further calls are no-ops.
        This avoids running a coroutine on a closed loop, which would raise
        ``RuntimeError`` and leak an un-awaited coroutine.
        """
        if self._loop.is_closed():
            return
        try:
            self._run_sync(super().close())
        finally:
            self._loop.close()

    def _step_payload(self, action: CodeFixerAction) -> Dict:
        """
        Convert CodeFixerAction to JSON payload for step message.

        Args:
            action: CodeFixerAction instance

        Returns:
            Dictionary with the action's ``type`` and ``payload`` fields,
            suitable for JSON encoding
        """
        return {
            "type": action.type,
            "payload": action.payload,
        }

    def _parse_result(self, payload: Dict) -> StepResult[CodeFixerObservation]:
        """
        Parse server response into StepResult[CodeFixerObservation].

        Missing or ``null`` fields fall back to defaults rather than raising.
        ``done``/``reward`` on the observation prefer the observation's own
        values, then the top-level payload's.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with CodeFixerObservation
        """
        # `or {}` also covers an explicit "observation": null from the server.
        obs_data = payload.get("observation") or {}
        # float(None) would raise, so treat a null score like a missing one.
        raw_score = obs_data.get("test_score")
        observation = CodeFixerObservation(
            code=obs_data.get("code", ""),
            logs=obs_data.get("logs"),
            test_score=float(raw_score) if raw_score is not None else 0.0,
            total_tests=obs_data.get("total_tests", 1),
            steps=obs_data.get("steps", 0),
            done=obs_data.get("done", payload.get("done", False)),
            reward=obs_data.get("reward", payload.get("reward")),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id (None if absent) and step_count
            (0 if absent)
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )