openra-rl / openra_env /client.py
github-actions[bot]
Sync from GitHub ac82c3e
02f4a63
"""OpenRA-RL environment client.
Provides the EnvClient subclass for connecting to the OpenRA-RL
environment server over WebSocket.
"""
import os
from typing import Any, Dict
from urllib.parse import urlsplit

from openenv.core.client_types import StepResult
from openenv.core.env_client import EnvClient
from websockets.asyncio.client import connect as ws_connect

from openra_env.models import (
    BuildingInfoModel,
    EconomyInfo,
    MapInfoModel,
    MilitaryInfo,
    OpenRAAction,
    OpenRAObservation,
    OpenRAState,
    ProductionInfoModel,
    UnitInfoModel,
)
class OpenRAEnv(EnvClient[OpenRAAction, OpenRAObservation, OpenRAState]):
    """WebSocket client for the OpenRA-RL environment.

    Usage:
        async with OpenRAEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
            while not result.done:
                action = OpenRAAction(commands=[...])
                result = await env.step(action)
    """

    # Hosts that get a NO_PROXY bypass entry while connecting to a local server.
    _LOOPBACK_HOSTS = ("localhost", "127.0.0.1")

    @classmethod
    def _is_loopback_url(cls, url: str) -> bool:
        """Return True when the host part of *url* names the local machine.

        Only the parsed hostname is inspected, so a remote host such as
        ``notlocalhost.example.com`` or a path that merely contains
        "localhost" is not treated as local.
        """
        try:
            host = urlsplit(url).hostname  # already lower-cased by urlsplit
        except ValueError:
            # Malformed URL (e.g. unbalanced IPv6 brackets): not local; the
            # connection attempt itself will report the error.
            return False
        if not host:
            return False
        return host in cls._LOOPBACK_HOSTS or host.endswith(".localhost")

    @classmethod
    def _with_loopback_entries(cls, no_proxy: str) -> str:
        """Return *no_proxy* with any missing loopback hosts appended.

        Entries are compared as whole comma-separated items, so
        ``NO_PROXY=localhost`` still gains ``127.0.0.1`` and an entry like
        ``notlocalhost.com`` does not count as "localhost".
        """
        present = {entry.strip().lower() for entry in no_proxy.split(",")}
        missing = [host for host in cls._LOOPBACK_HOSTS if host not in present]
        if not missing:
            return no_proxy
        additions = ",".join(missing)
        return f"{no_proxy},{additions}" if no_proxy else additions

    async def connect(self) -> "OpenRAEnv":
        """Connect with ping keepalive disabled.

        OpenRA operations (especially reset) can take 60-120+ seconds
        with software rendering. The default websockets ping_interval=20s
        would kill the connection before the server responds.

        When the target host is local, ``NO_PROXY`` is temporarily extended
        with the loopback hosts for the duration of the connection attempt
        and restored afterwards.

        Returns:
            This client, for chaining. If a connection is already open the
            call is a no-op.

        Raises:
            ConnectionError: If the WebSocket connection cannot be opened.
        """
        if self._ws is not None:
            return self

        previous_no_proxy = os.environ.get("NO_PROXY")
        env_changed = False
        if self._is_loopback_url(self._ws_url):
            current = previous_no_proxy or ""
            updated = self._with_loopback_entries(current)
            if updated != current:
                # NOTE(review): os.environ is process-global and stays modified
                # across the await below; concurrent connect() calls may
                # interleave their save/restore.
                os.environ["NO_PROXY"] = updated
                env_changed = True

        try:
            self._ws = await ws_connect(
                self._ws_url,
                open_timeout=self._connect_timeout,
                max_size=self._max_message_size,
                ping_interval=None,
            )
        except Exception as e:
            raise ConnectionError(f"Failed to connect to {self._ws_url}: {e}") from e
        finally:
            # Only undo what this call changed, so an unrelated NO_PROXY value
            # is never overwritten.
            if env_changed:
                if previous_no_proxy is None:
                    os.environ.pop("NO_PROXY", None)
                else:
                    os.environ["NO_PROXY"] = previous_no_proxy
        return self

    def _step_payload(self, action: OpenRAAction) -> Dict[str, Any]:
        """Convert action to JSON for WebSocket transport."""
        return action.model_dump()

    def _parse_result(self, data: Dict[str, Any]) -> StepResult[OpenRAObservation]:
        """Parse server response into StepResult.

        The observation is read from ``data["observation"]`` when present,
        otherwise from ``data`` itself. Nested sections that are missing or
        explicitly ``null`` are treated as empty. ``reward`` and ``done`` are
        taken from the top level of ``data`` first, falling back to the
        observation payload.
        """
        obs_data = data.get("observation")
        if obs_data is None:
            # Key absent or JSON null: the payload itself is the observation.
            obs_data = data

        def section(key: str) -> Dict[str, Any]:
            # `or {}` also covers an explicit null, which **-unpacking rejects.
            return obs_data.get(key) or {}

        def entries(key: str) -> list:
            # `or []` also covers an explicit null, which cannot be iterated.
            return obs_data.get(key) or []

        observation = OpenRAObservation(
            tick=obs_data.get("tick", 0),
            economy=EconomyInfo(**section("economy")),
            military=MilitaryInfo(**section("military")),
            units=[UnitInfoModel(**u) for u in entries("units")],
            buildings=[BuildingInfoModel(**b) for b in entries("buildings")],
            production=[ProductionInfoModel(**p) for p in entries("production")],
            visible_enemies=[UnitInfoModel(**u) for u in entries("visible_enemies")],
            visible_enemy_buildings=[
                BuildingInfoModel(**b) for b in entries("visible_enemy_buildings")
            ],
            map_info=MapInfoModel(**section("map_info")),
            available_production=obs_data.get("available_production", []),
            done=obs_data.get("done", False),
            reward=obs_data.get("reward"),
            result=obs_data.get("result", ""),
            spatial_map=obs_data.get("spatial_map", ""),
            spatial_channels=obs_data.get("spatial_channels", 0),
        )
        return StepResult(
            observation=observation,
            reward=data.get("reward", obs_data.get("reward")),
            done=data.get("done", obs_data.get("done", False)),
        )

    def _parse_state(self, data: Dict[str, Any]) -> OpenRAState:
        """Parse state response into OpenRAState."""
        return OpenRAState(**data)