"""OpenRA-RL environment client. Provides the EnvClient subclass for connecting to the OpenRA-RL environment server over WebSocket. """ import os from typing import Any, Dict from openenv.core.client_types import StepResult from openenv.core.env_client import EnvClient from websockets.asyncio.client import connect as ws_connect from openra_env.models import ( BuildingInfoModel, EconomyInfo, MapInfoModel, MilitaryInfo, OpenRAAction, OpenRAObservation, OpenRAState, ProductionInfoModel, UnitInfoModel, ) class OpenRAEnv(EnvClient[OpenRAAction, OpenRAObservation, OpenRAState]): """WebSocket client for the OpenRA-RL environment. Usage: async with OpenRAEnv(base_url="http://localhost:8000") as env: result = await env.reset() while not result.done: action = OpenRAAction(commands=[...]) result = await env.step(action) """ async def connect(self) -> "OpenRAEnv": """Connect with ping keepalive disabled. OpenRA operations (especially reset) can take 60-120+ seconds with software rendering. The default websockets ping_interval=20s would kill the connection before the server responds. """ if self._ws is not None: return self ws_url_lower = self._ws_url.lower() is_localhost = "localhost" in ws_url_lower or "127.0.0.1" in ws_url_lower old_no_proxy = os.environ.get("NO_PROXY") if is_localhost: current_no_proxy = old_no_proxy or "" if "localhost" not in current_no_proxy.lower(): os.environ["NO_PROXY"] = ( f"{current_no_proxy},localhost,127.0.0.1" if current_no_proxy else "localhost,127.0.0.1" ) try: self._ws = await ws_connect( self._ws_url, open_timeout=self._connect_timeout, max_size=self._max_message_size, ping_interval=None, ) except Exception as e: raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e finally: if is_localhost: if old_no_proxy is None: os.environ.pop("NO_PROXY", None) else: os.environ["NO_PROXY"] = old_no_proxy return self def _step_payload(self, action: OpenRAAction) -> Dict[str, Any]: """Convert action to JSON for WebSocket transport.""" return action.model_dump() def _parse_result(self, data: Dict[str, Any]) -> StepResult[OpenRAObservation]: """Parse server response into StepResult.""" obs_data = data.get("observation", data) observation = OpenRAObservation( tick=obs_data.get("tick", 0), economy=EconomyInfo(**obs_data.get("economy", {})), military=MilitaryInfo(**obs_data.get("military", {})), units=[UnitInfoModel(**u) for u in obs_data.get("units", [])], buildings=[BuildingInfoModel(**b) for b in obs_data.get("buildings", [])], production=[ProductionInfoModel(**p) for p in obs_data.get("production", [])], visible_enemies=[UnitInfoModel(**u) for u in obs_data.get("visible_enemies", [])], visible_enemy_buildings=[BuildingInfoModel(**b) for b in obs_data.get("visible_enemy_buildings", [])], map_info=MapInfoModel(**obs_data.get("map_info", {})), available_production=obs_data.get("available_production", []), done=obs_data.get("done", False), reward=obs_data.get("reward"), result=obs_data.get("result", ""), spatial_map=obs_data.get("spatial_map", ""), spatial_channels=obs_data.get("spatial_channels", 0), ) return StepResult( observation=observation, reward=data.get("reward", obs_data.get("reward")), done=data.get("done", obs_data.get("done", False)), ) def _parse_state(self, data: Dict[str, Any]) -> OpenRAState: """Parse state response into OpenRAState.""" return OpenRAState(**data)