"""Episode spine: Rust env + adapter + pluggable agent. This is the Bench-side replacement for Training's `play_episodes_async` (which is hardwired to the C# gRPC server). It reuses Training *components* via the adapter; provider-agnostic agents plug in here (Phase 0 follow-up: openra_bench/agent.py with vLLM/OpenRouter/Bedrock). An `agent_fn` has signature: agent_fn(render_state: dict, Command) -> list[Command] where `Command` is `openra_train.Command` (move_units/attack_unit/observe). """ from __future__ import annotations import tempfile from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable import yaml from openra_rl_training.training.rust_env_pool import RustEnvPool from .controller import ( Controller, EpisodeContext, as_controller, introspection_source, ) from .rust_adapter import EpisodeSignals, RustObsAdapter from .scenarios.schema import CompiledLevel from .scenarios.win_conditions import WinContext, evaluate # A policy is either a bare `agent_fn(render_state, Command) -> [Command]` # callable (the legacy shape, still accepted everywhere) or a Controller. AgentFn = Callable[[dict, Any], list] Policy = "AgentFn | Controller" def _scenario_to_tmp_yaml(compiled: CompiledLevel) -> str: """Serialize a compiled level's ScenarioDefinition to a temp YAML the Rust env can load (it reads actors from the given scenario path; the map geometry is the Rust-supported base map).""" data = compiled.scenario.model_dump(mode="json", exclude_none=True) # The Rust loader resolves base_map relative to the scenario file's # dir; this temp file lives in /tmp, so a relative ref would silently # fall back to rush-hour terrain. Pin it to the resolved absolute # .oramap so the *declared* map's real terrain loads. from .scenarios.loader import resolve_map_path _mp = resolve_map_path(str(data.get("base_map", ""))) if _mp is not None: data["base_map"] = str(_mp) # Training's ScenarioDefinition has no economy field; inject the # pack's designed `starting_cash` constraint as a top-level key the # Rust scenario parser reads (default 5000 when unset). if compiled.starting_cash is not None: data["starting_cash"] = compiled.starting_cash # Per-player cash plumbing footgun: `PlayerSetup.cash` defaults to # `int = 0` (not Optional), so an unset `agent: {faction: ...}` in # the pack serializes as `{faction: ..., cash: 0}`. With the # engine's per-player-cash fix landed, that 0 silently OVERRIDES # the top-level `starting_cash:` — production stalls because the # agent has no money to consume. Defensive bench-side fix: if # `agent.cash` / `enemy.cash` is 0 AND `starting_cash` > 0, STRIP # the per-player cash field so the engine falls back to the # top-level. (A pack that genuinely wants cash=0 should also set # `starting_cash: 0`, in which case there's nothing to fall back # to and the behavior is unchanged.) _top_cash = int(data.get("starting_cash") or 0) for _side in ("agent", "enemy"): _block = data.get(_side) if isinstance(_block, dict) and _block.get("cash", None) == 0 and _top_cash > 0: _block.pop("cash", None) data[_side] = _block # The Rust engine defaults spawn_mcvs:true → it auto-seeds MCVs at # the map's built-in spawn points (e.g. (124,36)), which reveal fog # and pollute unit counts for scenarios that never asked for them. # Curated packs declare their own actors, so default it OFF; a pack # may still opt in via base.spawn_mcvs / a scenario-declared `mcv`. data.setdefault("spawn_mcvs", bool(data.get("spawn_mcvs", False))) # Wave-9 `scheduled_events:` — preserved on CompiledLevel because # the training ScenarioDefinition silently drops the field. Re-emit # here so the Rust engine's scenario parser sees it. sched = getattr(compiled, "scheduled_events", None) or [] if sched: data["scheduled_events"] = sched # Resource-wave `ore_patches:` — same lifting pattern as # `scheduled_events`. Each entry is `{x, y, amount, radius}` and # the engine (`oramap.rs::read_ore_patches`) materialises it into # a disk of ore cells on the terrain at world-build time. ore = getattr(compiled, "ore_patches", None) or [] if ore: data["ore_patches"] = ore # No-fog perception cells (fog_mode ends with "-clear") flip the # engine's `reveal_map` flag: the agent observes the whole map with # no fog of war — the clear half of the perception ablation grid. if getattr(compiled, "reveal_map", False): data["reveal_map"] = True # Naval-MVP overlay: forward declared `water_cells:` and # `water_rect:` so the Rust engine marks the corresponding # terrain cells as ship-passable / ground-impassable. Without # this lift the engine's terrain stays all-grass and ships have # nowhere to move. wc = getattr(compiled, "water_cells", None) or [] if wc: data["water_cells"] = [list(c) for c in wc] wr = getattr(compiled, "water_rect", None) if wr: data["water_rect"] = list(wr) fd = tempfile.NamedTemporaryFile( "w", suffix=f"_{compiled.pack_id}_{compiled.level}.yaml", delete=False ) yaml.safe_dump(data, fd, sort_keys=False) fd.close() return fd.name @dataclass class EpisodeResult: scenario: str seed: int turns: int signals: EpisodeSignals outcome: str = "draw" # "win" | "loss" | "draw" actions_issued: int = 0 actions_warned: int = 0 # commands the engine rejected/warned on trace: list[dict] = field(default_factory=list) # Final goal-tracker snapshot (always computed, playback or not). # objective_progress is continuous partial credit toward the # scenario win condition; reward_vector is the normalized # cumulative, scenario-agnostic vector (see goal_tracker). objective_progress: float = 0.0 reward_vector: dict = field(default_factory=dict) _CMD_NAME_RE = __import__("re").compile(r"Command::([A-Z][A-Za-z0-9]*)") def _cmd_tool_name(cmd: Any) -> str | None: """Decode the snake_case tool name from a Command repr. The Rust pyo3 Command enum stringifies as ``Command::VariantName { … }`` or ``Command::VariantName``. We extract the variant and convert to snake_case to match the bench's tool-name vocabulary (and the `tools:` / `forbidden_tools:` allowlist keys in YAML). Returns ``None`` for anything that doesn't match — defensive; never raises. """ m = _CMD_NAME_RE.search(repr(cmd)) if not m: return None variant = m.group(1) # CamelCase → snake_case (MoveUnits → move_units, AttackUnit → attack_unit) out: list[str] = [] for i, ch in enumerate(variant): if i and ch.isupper(): out.append("_") out.append(ch.lower()) return "".join(out) def scripted_explore_agent(render_state: dict, Command: Any) -> list: """Baseline reference agent: walk every unit toward the nearest unexplored frontier cell. Exercises the move path; a useful lower-bound control for the perception/exploration scenarios. """ grid = render_state["minimap"].splitlines() h = len(grid) w = len(grid[0]) if grid else 0 frontier = [ (x, y) for y in range(h) for x in range(min(w, len(grid[y]))) if grid[y][x] == "#" ] units = render_state.get("units_summary", []) if not units or not frontier: return [Command.observe()] cmds = [] for u in units: ux, uy = u["cell_x"], u["cell_y"] tx, ty = min(frontier, key=lambda c: (c[0] - ux) ** 2 + (c[1] - uy) ** 2) cmds.append(Command.move_units([str(u["id"])], target_x=tx, target_y=ty)) return cmds def run_episode( scenario_path: str, agent_fn: "AgentFn | Controller" = scripted_explore_agent, max_turns: int = 40, seed: int = 0, pool: RustEnvPool | None = None, ) -> EpisodeResult: """Run a scenario for a fixed number of turns. `agent_fn` may be a bare `agent_fn(render_state, Command) -> [Command]` callable or any `Controller`; it is coerced through `as_controller()`.""" owns_pool = pool is None if pool is None: pool = RustEnvPool(size=1, scenario_path=scenario_path) env = pool.acquire() try: adapter = RustObsAdapter() obs = env.reset(seed=seed) adapter.observe(obs) controller = as_controller(agent_fn) controller.reset( EpisodeContext(seed=seed, max_turns=max_turns) ) trace: list[dict] = [] turns = 0 issued = warned = 0 for turns in range(1, max_turns + 1): rs = adapter.render_state() cmds = controller.act(rs, env.Command) or [env.Command.observe()] obs, _reward, done, info = env.step(cmds) adapter.observe(obs, done=done) issued += len(cmds) warned += len(info.get("warnings", []) if isinstance(info, dict) else []) trace.append( { "turn": turns, "tick": adapter.signals.game_tick, "explored": round(adapter.signals.explored_percent, 2), "kills": adapter.signals.units_killed, "enemies_seen": len(adapter.signals.enemies_seen_ids), "n_cmds": len(cmds), } ) if done: break return EpisodeResult( scenario=scenario_path, seed=seed, turns=turns, signals=adapter.signals, actions_issued=issued, actions_warned=warned, trace=trace, ) finally: pool.release(env) if owns_pool: pool.shutdown() def run_level( compiled: CompiledLevel, agent_fn: "AgentFn | Controller" = scripted_explore_agent, seed: int = 0, playback=None, full_playback=None, ) -> EpisodeResult: """Run one scenario-pack level, scoring against its declarative win/fail conditions (checked every turn). Outcome maps to the `reward_outcome` convention: win=1.0, draw=0.5, loss=0.0. `agent_fn` may be a bare `agent_fn(render_state, Command) -> [Command]` callable, a `ModelAgent` bound method, or any `Controller`; it is coerced through `as_controller()`. """ if not compiled.map_supported: raise RuntimeError( f"{compiled.pack_id}: base map not Rust-loadable yet (Phase 3). " f"Validate-only; cannot execute." ) tmp_path = _scenario_to_tmp_yaml(compiled) pool = RustEnvPool(size=1, scenario_path=tmp_path) env = pool.acquire() try: adapter = RustObsAdapter() adapter.observe(env.reset(seed=seed)) # Coerce the policy through the unified Controller contract: # a bare agent_fn, a ModelAgent bound method, or a Controller # all resolve to a Controller the loop drives identically. controller = as_controller(agent_fn) controller.reset( EpisodeContext( pack_id=compiled.pack_id, level=compiled.level, seed=seed, objective=compiled.scenario.description or "", max_turns=compiled.max_turns, ) ) trace: list[dict] = [] outcome = "draw" turns = 0 issued = warned = 0 conceded = False # Persistent fog history so the SAVED minimap == the image the # model actually saw (same vendored _minimap_v2, accumulating). _pb_explored: set = set() _pb_terrain = None # Audit-capture wiring: when a FullPlayback is attached, surface # the underlying ModelAgent (if any) and flip on `audit_capture` # so per-turn briefing / wire request+response are stashed for # the audit JSONL. _audit_agent = ( introspection_source(controller) if full_playback is not None else None ) if _audit_agent is not None and hasattr(_audit_agent, "audit_capture"): _audit_agent.audit_capture = True # Interrupt-driven mode (step 4): if the scenario enabled any # interrupt signals, advance with step_until_event so the agent # is re-prompted (debriefed) the moment an event fires # (enemy spotted, unit lost, production complete, …) instead of # only on fixed tick boundaries. Falls back to fixed step() # when no signals are enabled or the env lacks the API. _KNOWN_SIGNALS = { "enemy_unit_spotted", "enemy_building_spotted", "engage_start", "own_unit_destroyed", "production_complete", } enabled_sig = sorted( s for s, on in (compiled.scenario.interrupts or {}).items() if on and s in _KNOWN_SIGNALS ) raw_env = getattr(env, "_env", None) interrupt_mode = bool(enabled_sig) and raw_env is not None and hasattr( raw_env, "step_until_event" ) # Strict-toolban / procedural-compliance accounting: any cmd whose # tool name is in compiled.forbidden_tools increments # signals.tool_violations (read by the tool_violations_gte # predicate). Tracked here so scripted and live-model policies # are graded by the exact same rule. forbidden = {str(t).lower() for t in (compiled.forbidden_tools or [])} for turns in range(1, compiled.max_turns + 1): rs = adapter.render_state() cmds = controller.act(rs, env.Command) or [env.Command.observe()] for _cmd in cmds: _tn = _cmd_tool_name(_cmd) if _tn: adapter.signals.tools_called[_tn] = ( adapter.signals.tools_called.get(_tn, 0) + 1 ) if _tn in forbidden: adapter.signals.tool_violations += 1 if not conceded: conceded = any("Surrender" in repr(c) for c in cmds) interrupt = None if interrupt_mode: obs, _r, done, info, was_int, reason, _tk = ( raw_env.step_until_event(cmds, None, 5, enabled_sig) ) if was_int: interrupt = reason else: obs, _r, done, info = env.step(cmds) adapter.observe(obs, done=done) issued += len(cmds) warned += len(info.get("warnings", []) if isinstance(info, dict) else []) ctx = WinContext(signals=adapter.signals, render_state=adapter.render_state()) if evaluate(compiled.win_condition, ctx): outcome = "win" elif evaluate(compiled.fail_condition, ctx): outcome = "loss" if playback is not None: _png = None try: from .minimap import terrain_png_for if _pb_terrain is None: _pb_terrain = terrain_png_for( compiled.scenario.base_map ) # Same vendored _minimap_v2 the model is sent, with # accumulating fog → saved image == model's image. from .prompt_v2 import minimap_b64 as _v2_mm _png = _v2_mm( rs, _pb_terrain, _pb_explored, constant_colors=compiled.level in ("easy", "medium"), ) if _png is None: from .agent import _render_minimap_b64 _png = _render_minimap_b64(rs, _pb_terrain) except Exception: # noqa: BLE001 — playback never breaks a run pass from .goal_tracker import turn_goal playback.record_turn( turns, rs, cmds, adapter.signals, _png, interrupt=interrupt, goal=turn_goal(compiled.win_condition, ctx), ) if full_playback is not None: # Mirror the same PNG (when the legacy playback rendered # one). Otherwise render on-demand for the audit format. _fp_png = locals().get("_png") if playback is not None else None if _fp_png is None: try: from .minimap import terrain_png_for if _pb_terrain is None: _pb_terrain = terrain_png_for( compiled.scenario.base_map ) from .prompt_v2 import minimap_b64 as _v2_mm _fp_png = _v2_mm( rs, _pb_terrain, _pb_explored, constant_colors=compiled.level in ("easy", "medium"), ) if _fp_png is None: from .agent import _render_minimap_b64 _fp_png = _render_minimap_b64(rs, _pb_terrain) except Exception: # noqa: BLE001 — audit never breaks a run _fp_png = None try: full_playback.record_turn( turn=turns, tick=adapter.signals.game_tick, obs=rs, briefing=getattr(_audit_agent, "last_briefing", "") if _audit_agent is not None else "", system_prompt=getattr(_audit_agent, "system_prompt", "") if _audit_agent is not None else "", model_request=getattr(_audit_agent, "last_request", None) if _audit_agent is not None else None, model_response=getattr(_audit_agent, "last_response", None) if _audit_agent is not None else None, commands_issued=cmds, engine_warnings=( info.get("warnings", []) if isinstance(info, dict) else [] ), signals=adapter.signals, minimap_png_b64=_fp_png, done=bool(done), interrupt=interrupt, ) except Exception: # noqa: BLE001 pass trace.append( { "turn": turns, "tick": adapter.signals.game_tick, "explored": round(adapter.signals.explored_percent, 2), "kills": adapter.signals.units_killed, "enemies_seen": len(adapter.signals.enemies_seen_ids), "interrupt": interrupt, } ) if outcome != "draw" or done: break if conceded: outcome = "loss" # the agent chose to concede adapter.signals.outcome = {"win": 1.0, "draw": 0.5, "loss": 0.0}[outcome] from .goal_tracker import turn_goal final_rs = adapter.render_state() final_goal = turn_goal( compiled.win_condition, WinContext(signals=adapter.signals, render_state=final_rs), ) # Terminal frame: the RESOLVED post-action board the moment the # episode ends (win/loss). The per-turn record uses the # pre-step state, so without this the viewer never shows the # winning/losing position. No model action on this frame. if playback is not None: _fpng = None try: from .minimap import terrain_png_for if _pb_terrain is None: _pb_terrain = terrain_png_for(compiled.scenario.base_map) from .prompt_v2 import minimap_b64 as _v2_mm _fpng = _v2_mm( final_rs, _pb_terrain, _pb_explored, constant_colors=compiled.level in ("easy", "medium"), ) if _fpng is None: from .agent import _render_minimap_b64 _fpng = _render_minimap_b64(final_rs, _pb_terrain) except Exception: # noqa: BLE001 — playback never breaks a run pass playback.record_turn( turns + 1, final_rs, [f"(episode end: {('loss' if conceded else outcome)})"], adapter.signals, _fpng, interrupt=None, goal=final_goal, ) result = EpisodeResult( scenario=f"{compiled.pack_id}:{compiled.level}", seed=seed, turns=turns, signals=adapter.signals, outcome=outcome, actions_issued=issued, actions_warned=warned, trace=trace, objective_progress=final_goal["objective_progress"], reward_vector=final_goal["reward_vector"], ) if playback is not None: # Dump the full model⇄env transcript when the agent is a # ModelAgent — the Controller layer surfaces the underlying # instance (bound-method __self__ or the Controller itself). agent_obj = introspection_source(controller) hist = getattr(agent_obj, "history", None) if isinstance(hist, list): playback.write_messages(hist) playback.finalize( { "scenario": result.scenario, "pack_id": compiled.pack_id, "level": compiled.level, "capability": compiled.meta.capability, "run_id": getattr(playback, "run_id", None), "model": getattr(playback, "model", None), "seed": seed, "outcome": outcome, "turns": turns, "max_turns": compiled.max_turns, "actions_issued": issued, "actions_warned": warned, "agent_stats": getattr(agent_obj, "stats", None), "objective_progress": result.objective_progress, "reward_vector": result.reward_vector, "signals": { "economy_value": adapter.signals.cash + adapter.signals.resources, "explored_percent": round( adapter.signals.explored_percent, 2 ), "units_killed": adapter.signals.units_killed, "units_lost": adapter.signals.units_lost, }, } ) if full_playback is not None: try: full_playback.finalize( outcome=outcome, final_obs=final_rs, manifest_extra={ "scenario": result.scenario, "pack_id": compiled.pack_id, "level": compiled.level, "capability": compiled.meta.capability, "seed": seed, "outcome": outcome, "turns": turns, "max_turns": compiled.max_turns, "actions_issued": issued, "actions_warned": warned, "agent_stats": getattr( introspection_source(controller), "stats", None ), "objective_progress": result.objective_progress, "reward_vector": result.reward_vector, }, ) except Exception: # noqa: BLE001 — never break a run on I/O pass return result finally: # Abort the audit recorder if the loop crashed before finalize — # leaves a `.partial` on disk for forensics; the resume scanner # correctly sees no `.jsonl` and retries the cell. try: if full_playback is not None and Path( full_playback.jsonl_path ).exists() is False: full_playback.abort() except Exception: # noqa: BLE001 pass pool.release(env) pool.shutdown() Path(tmp_path).unlink(missing_ok=True)