"""gRPC bridge client for communicating with the OpenRA ExternalBotBridge. This client connects to the gRPC server running inside the OpenRA process and handles bidirectional streaming of observations and actions. Protocol: - Bidirectional streaming RPC (GameSession): game sends observations, agent sends actions - Unary RPC (GetState): query current game state on demand - Real-time: game runs at normal speed, observations stream continuously, actions are sent whenever the agent is ready """ import asyncio import base64 import logging from typing import AsyncIterator, Optional import grpc from openra_env.generated import rl_bridge_pb2, rl_bridge_pb2_grpc logger = logging.getLogger(__name__) class BridgeClient: """Async gRPC client for the OpenRA RL Bridge. Uses bidirectional streaming: the game sends observations continuously at its natural tick rate, and the agent sends actions when ready. A background reader task keeps the latest observation cached. """ def __init__(self, host: str = "localhost", port: int = 9999, timeout_s: float = 30.0): self.host = host self.port = port self.timeout_s = timeout_s self._channel: Optional[grpc.aio.Channel] = None self._stub: Optional[rl_bridge_pb2_grpc.RLBridgeStub] = None self._session_call = None self._action_queue: asyncio.Queue[rl_bridge_pb2.AgentAction] = asyncio.Queue() self._connected = False # Background observation reader state self._latest_obs: Optional[rl_bridge_pb2.GameObservation] = None self._obs_event: asyncio.Event = asyncio.Event() self._obs_tick: int = 0 self._obs_reader_task: Optional[asyncio.Task] = None async def connect(self) -> None: """Establish gRPC channel.""" target = f"{self.host}:{self.port}" self._channel = grpc.aio.insecure_channel( target, options=[ ("grpc.max_receive_message_length", 64 * 1024 * 1024), ("grpc.max_send_message_length", 16 * 1024 * 1024), ("grpc.keepalive_time_ms", 10000), ("grpc.keepalive_timeout_ms", 5000), ], ) self._stub = rl_bridge_pb2_grpc.RLBridgeStub(self._channel) self._connected = True logger.info(f"Connected to OpenRA bridge at {target}") async def wait_for_ready(self, max_retries: int = 30, retry_interval: float = 1.0) -> bool: """Wait for the gRPC server to become available.""" for attempt in range(max_retries): try: await self.connect() state = await self.get_state() logger.info(f"Bridge ready after {attempt + 1} attempts, phase={state.phase}") return True except grpc.aio.AioRpcError as e: if attempt < max_retries - 1: logger.debug(f"Bridge not ready (attempt {attempt + 1}): {e.code()}") await asyncio.sleep(retry_interval) else: logger.error(f"Bridge failed to become ready after {max_retries} attempts") return False except Exception as e: if attempt < max_retries - 1: logger.debug(f"Connection attempt {attempt + 1} failed: {e}") await asyncio.sleep(retry_interval) else: return False return False @property def session_started(self) -> bool: """Whether the streaming session has been started.""" return self._session_call is not None async def start_session(self) -> rl_bridge_pb2.GameObservation: """Start a bidirectional streaming session and return the first observation. The game sends observations continuously; a background reader task keeps the latest observation cached. Actions are sent via step(). Idempotent: if the session is already started, returns the latest observation. """ if self._session_call is not None: # Already started — return latest cached observation return self._latest_obs if not self._connected: await self.connect() self._action_queue = asyncio.Queue() self._session_call = self._stub.GameSession(self._action_request_iterator()) first_obs = await self._session_call.read() if first_obs is None: raise ConnectionError("Bridge stream closed before sending initial observation") # Initialize observation state and start background reader self._latest_obs = first_obs self._obs_tick = first_obs.tick self._obs_event = asyncio.Event() self._obs_event.set() self._obs_reader_task = asyncio.create_task(self._bg_obs_reader()) logger.info(f"Session started, initial tick={first_obs.tick}") return first_obs async def _action_request_iterator(self) -> AsyncIterator[rl_bridge_pb2.AgentAction]: """Yield actions from the queue as the gRPC stream requests them.""" while True: action = await self._action_queue.get() yield action async def _bg_obs_reader(self): """Background task: continuously read observations from the gRPC stream. Updates _latest_obs and signals _obs_event each time a new observation arrives. The game sends observations at its natural tick rate regardless of agent actions. """ try: while True: obs = await self._session_call.read() if obs is None: logger.info("gRPC observation stream ended") break self._latest_obs = obs self._obs_tick = obs.tick self._obs_event.set() if obs.done: logger.info(f"Game over at tick {obs.tick}: {obs.result}") break except grpc.aio.AioRpcError as e: logger.error(f"Background observation reader error: {e.code()}") except asyncio.CancelledError: logger.debug("Background observation reader cancelled") def _check_reader_alive(self): """Raise if the background observation reader has exited (game died).""" if self._obs_reader_task is not None and self._obs_reader_task.done(): exc = self._obs_reader_task.exception() if exc: raise ConnectionError(f"Game connection lost: {exc}") from exc raise ConnectionError("Game connection lost (observation stream ended)") async def step(self, action: rl_bridge_pb2.AgentAction) -> rl_bridge_pb2.GameObservation: """Send an action and wait for the next observation. The action is queued immediately. Then we wait for an observation with a tick newer than the current one (confirming the game has processed at least one more tick since the action was sent). """ if self._session_call is None: raise RuntimeError("Session not started. Call start_session() first.") current_tick = self._obs_tick await self._action_queue.put(action) # Wait for an observation newer than when we sent the action while self._obs_tick <= current_tick: self._check_reader_alive() self._obs_event.clear() await asyncio.wait_for(self._obs_event.wait(), timeout=self.timeout_s) return self._latest_obs async def wait_ticks(self, n: int) -> rl_bridge_pb2.GameObservation: """Wait for approximately N game ticks to pass. The game runs at its natural speed (~25 ticks/sec at default). Returns the observation at or after the target tick. """ target_tick = self._obs_tick + n while self._obs_tick < target_tick: self._check_reader_alive() self._obs_event.clear() await asyncio.wait_for(self._obs_event.wait(), timeout=self.timeout_s) if self._latest_obs and self._latest_obs.done: break return self._latest_obs async def observe(self) -> Optional[rl_bridge_pb2.GameObservation]: """Return the latest cached observation without sending any action.""" return self._latest_obs async def get_state(self) -> rl_bridge_pb2.GameState: """Query current game state via unary RPC.""" if not self._connected or self._stub is None: raise RuntimeError("Not connected. Call connect() first.") request = rl_bridge_pb2.StateRequest() return await self._stub.GetState(request, timeout=self.timeout_s) async def close(self) -> None: """Close the gRPC channel and clean up.""" # Cancel background observation reader if self._obs_reader_task is not None: self._obs_reader_task.cancel() try: await self._obs_reader_task except asyncio.CancelledError: pass self._obs_reader_task = None if self._session_call is not None: self._session_call.cancel() self._session_call = None if self._channel is not None: await self._channel.close() self._channel = None self._stub = None self._connected = False self._latest_obs = None logger.info("Bridge connection closed") @property def is_connected(self) -> bool: return self._connected def observation_to_dict(obs: rl_bridge_pb2.GameObservation) -> dict: """Convert a protobuf GameObservation to a plain dict for the OpenEnv layer.""" return { "tick": obs.tick, "economy": { "cash": obs.economy.cash, "ore": obs.economy.ore, "power_provided": obs.economy.power_provided, "power_drained": obs.economy.power_drained, "resource_capacity": obs.economy.resource_capacity, "harvester_count": obs.economy.harvester_count, }, "military": { "units_killed": obs.military.units_killed, "units_lost": obs.military.units_lost, "buildings_killed": obs.military.buildings_killed, "buildings_lost": obs.military.buildings_lost, "army_value": obs.military.army_value, "active_unit_count": obs.military.active_unit_count, "kills_cost": obs.military.kills_cost, "deaths_cost": obs.military.deaths_cost, "assets_value": obs.military.assets_value, "experience": obs.military.experience, "order_count": obs.military.order_count, }, "units": [ { "actor_id": u.actor_id, "type": u.type, "pos_x": u.pos_x, "pos_y": u.pos_y, "cell_x": u.cell_x, "cell_y": u.cell_y, "hp_percent": u.hp_percent, "is_idle": u.is_idle, "current_activity": u.current_activity, "owner": u.owner, "can_attack": u.can_attack, "facing": u.facing, "experience_level": u.experience_level, "stance": u.stance, "speed": u.speed, "attack_range": u.attack_range, "passenger_count": u.passenger_count, "is_building": u.is_building, } for u in obs.units ], "buildings": [ { "actor_id": b.actor_id, "type": b.type, "pos_x": b.pos_x, "pos_y": b.pos_y, "hp_percent": b.hp_percent, "owner": b.owner, "is_producing": b.is_producing, "production_progress": b.production_progress, "producing_item": b.producing_item, "is_powered": b.is_powered, "is_repairing": b.is_repairing, "sell_value": b.sell_value, "rally_x": b.rally_x, "rally_y": b.rally_y, "power_amount": b.power_amount, "can_produce": list(b.can_produce), "cell_x": b.cell_x, "cell_y": b.cell_y, } for b in obs.buildings ], "production": [ { "queue_type": p.queue_type, "item": p.item, "progress": p.progress, "remaining_ticks": p.remaining_ticks, "remaining_cost": p.remaining_cost, "paused": p.paused, } for p in obs.production ], "visible_enemies": [ { "actor_id": u.actor_id, "type": u.type, "pos_x": u.pos_x, "pos_y": u.pos_y, "cell_x": u.cell_x, "cell_y": u.cell_y, "hp_percent": u.hp_percent, "is_idle": u.is_idle, "current_activity": u.current_activity, "owner": u.owner, "can_attack": u.can_attack, "facing": u.facing, "experience_level": u.experience_level, "stance": u.stance, "speed": u.speed, "attack_range": u.attack_range, "passenger_count": u.passenger_count, "is_building": u.is_building, } for u in obs.visible_enemies ], "visible_enemy_buildings": [ { "actor_id": b.actor_id, "type": b.type, "pos_x": b.pos_x, "pos_y": b.pos_y, "hp_percent": b.hp_percent, "owner": b.owner, "is_producing": b.is_producing, "production_progress": b.production_progress, "producing_item": b.producing_item, "is_powered": b.is_powered, "is_repairing": b.is_repairing, "sell_value": b.sell_value, "rally_x": b.rally_x, "rally_y": b.rally_y, "power_amount": b.power_amount, "can_produce": list(b.can_produce), "cell_x": b.cell_x, "cell_y": b.cell_y, } for b in obs.visible_enemy_buildings ], "map_info": { "width": obs.map_info.width, "height": obs.map_info.height, "map_name": obs.map_info.map_name, }, "available_production": list(obs.available_production), "done": obs.done, "reward": obs.reward, "result": obs.result, "spatial_map": base64.b64encode(bytes(obs.spatial_map)).decode("ascii"), "spatial_channels": obs.spatial_channels, } def commands_to_proto(commands: list[dict]) -> rl_bridge_pb2.AgentAction: """Convert a list of command dicts to a protobuf AgentAction.""" action_type_map = { "no_op": rl_bridge_pb2.NO_OP, "move": rl_bridge_pb2.MOVE, "attack_move": rl_bridge_pb2.ATTACK_MOVE, "attack": rl_bridge_pb2.ATTACK, "stop": rl_bridge_pb2.STOP, "harvest": rl_bridge_pb2.HARVEST, "build": rl_bridge_pb2.BUILD, "train": rl_bridge_pb2.TRAIN, "deploy": rl_bridge_pb2.DEPLOY, "sell": rl_bridge_pb2.SELL, "repair": rl_bridge_pb2.REPAIR, "place_building": rl_bridge_pb2.PLACE_BUILDING, "cancel_production": rl_bridge_pb2.CANCEL_PRODUCTION, "set_rally_point": rl_bridge_pb2.SET_RALLY_POINT, "guard": rl_bridge_pb2.GUARD, "set_stance": rl_bridge_pb2.SET_STANCE, "enter_transport": rl_bridge_pb2.ENTER_TRANSPORT, "unload": rl_bridge_pb2.UNLOAD, "power_down": rl_bridge_pb2.POWER_DOWN, "set_primary": rl_bridge_pb2.SET_PRIMARY, "surrender": rl_bridge_pb2.SURRENDER, } proto_commands = [] for cmd in commands: action_str = cmd.get("action", "no_op") proto_cmd = rl_bridge_pb2.Command( action=action_type_map.get(action_str, rl_bridge_pb2.NO_OP), actor_id=cmd.get("actor_id", 0), target_actor_id=cmd.get("target_actor_id", 0), target_x=cmd.get("target_x", 0), target_y=cmd.get("target_y", 0), item_type=cmd.get("item_type", ""), queued=cmd.get("queued", False), ) proto_commands.append(proto_cmd) return rl_bridge_pb2.AgentAction(commands=proto_commands)