# Source: OpenRA-Bench / openra_bench/eval_core.py (24.9 kB)
# Last commit b77e43d (yxc20098):
#   "Defensive fix: strip agent.cash=0 when starting_cash>0 in tmp YAML"
"""Episode spine: Rust env + adapter + pluggable agent.

This is the Bench-side replacement for Training's `play_episodes_async`
(which is hardwired to the C# gRPC server). It reuses Training *components*
via the adapter; provider-agnostic agents plug in here (Phase 0 follow-up:
openra_bench/agent.py with vLLM/OpenRouter/Bedrock).

An `agent_fn` has the signature:

    agent_fn(render_state: dict, Command) -> list[Command]

where `Command` is `openra_train.Command` (move_units/attack_unit/observe).
"""
from __future__ import annotations
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
import yaml
from openra_rl_training.training.rust_env_pool import RustEnvPool
from .controller import (
Controller,
EpisodeContext,
as_controller,
introspection_source,
)
from .rust_adapter import EpisodeSignals, RustObsAdapter
from .scenarios.schema import CompiledLevel
from .scenarios.win_conditions import WinContext, evaluate
# A policy is either a bare `agent_fn(render_state, Command) -> [Command]`
# callable (the legacy shape, still accepted everywhere) or a Controller.
# `AgentFn` is the callable shape: (render_state dict, Command factory) -> list
# of commands.
AgentFn = Callable[[dict, Any], list]
# `Policy` is deliberately a plain string (not a runtime union object); it is
# only meaningful when used as a lazily-evaluated annotation.
Policy = "AgentFn | Controller"
def _scenario_to_tmp_yaml(compiled: CompiledLevel) -> str:
    """Serialize a compiled level's ScenarioDefinition to a temp YAML file.

    The Rust env loads actors from the given scenario path; the map
    geometry is the Rust-supported base map. Fields the training
    ScenarioDefinition does not carry (economy, scheduled events, ore
    patches, fog/water overlays) are lifted from ``compiled`` onto the
    top level of the emitted document.

    Args:
        compiled: The compiled scenario-pack level to serialize.

    Returns:
        Path of the written temp file. The file is created with
        ``delete=False``, so the caller owns it and must unlink it.

    Raises:
        Whatever ``yaml.safe_dump`` / the filesystem raises; in that case
        the partially written temp file is removed before re-raising.
    """
    data = compiled.scenario.model_dump(mode="json", exclude_none=True)

    # The Rust loader resolves base_map relative to the scenario file's
    # dir; this temp file lives in /tmp, so a relative ref would silently
    # fall back to rush-hour terrain. Pin it to the resolved absolute
    # .oramap so the *declared* map's real terrain loads.
    from .scenarios.loader import resolve_map_path

    resolved_map = resolve_map_path(str(data.get("base_map", "")))
    if resolved_map is not None:
        data["base_map"] = str(resolved_map)

    # Training's ScenarioDefinition has no economy field; inject the
    # pack's designed `starting_cash` constraint as a top-level key the
    # Rust scenario parser reads (default 5000 when unset).
    if compiled.starting_cash is not None:
        data["starting_cash"] = compiled.starting_cash

    # Per-player cash plumbing footgun: `PlayerSetup.cash` defaults to
    # `int = 0` (not Optional), so an unset `agent: {faction: ...}` in
    # the pack serializes as `{faction: ..., cash: 0}`. With the
    # engine's per-player-cash fix landed, that 0 silently OVERRIDES
    # the top-level `starting_cash:` — production stalls because the
    # agent has no money to consume. Defensive bench-side fix: if
    # `agent.cash` / `enemy.cash` is 0 AND `starting_cash` > 0, STRIP
    # the per-player cash field so the engine falls back to the
    # top-level. (A pack that genuinely wants cash=0 should also set
    # `starting_cash: 0`, in which case there's nothing to fall back
    # to and the behavior is unchanged.)
    top_cash = int(data.get("starting_cash") or 0)
    if top_cash > 0:
        for side in ("agent", "enemy"):
            player = data.get(side)
            if isinstance(player, dict) and player.get("cash") == 0:
                # `player` is the same dict object stored in `data`, so
                # popping here is enough.
                player.pop("cash", None)

    # The Rust engine defaults spawn_mcvs:true → it auto-seeds MCVs at
    # the map's built-in spawn points (e.g. (124,36)), which reveal fog
    # and pollute unit counts for scenarios that never asked for them.
    # Curated packs declare their own actors, so default it OFF; a pack
    # may still opt in via base.spawn_mcvs / a scenario-declared `mcv`.
    # (setdefault leaves an already-present value untouched.)
    data.setdefault("spawn_mcvs", False)

    # Wave-9 `scheduled_events:` — preserved on CompiledLevel because
    # the training ScenarioDefinition silently drops the field. Re-emit
    # here so the Rust engine's scenario parser sees it.
    scheduled = getattr(compiled, "scheduled_events", None) or []
    if scheduled:
        data["scheduled_events"] = scheduled

    # Resource-wave `ore_patches:` — same lifting pattern as
    # `scheduled_events`. Each entry is `{x, y, amount, radius}` and
    # the engine (`oramap.rs::read_ore_patches`) materialises it into
    # a disk of ore cells on the terrain at world-build time.
    ore_patches = getattr(compiled, "ore_patches", None) or []
    if ore_patches:
        data["ore_patches"] = ore_patches

    # No-fog perception cells (fog_mode ends with "-clear") flip the
    # engine's `reveal_map` flag: the agent observes the whole map with
    # no fog of war — the clear half of the perception ablation grid.
    if getattr(compiled, "reveal_map", False):
        data["reveal_map"] = True

    # Naval-MVP overlay: forward declared `water_cells:` and
    # `water_rect:` so the Rust engine marks the corresponding
    # terrain cells as ship-passable / ground-impassable. Without
    # this lift the engine's terrain stays all-grass and ships have
    # nowhere to move.
    water_cells = getattr(compiled, "water_cells", None) or []
    if water_cells:
        data["water_cells"] = [list(cell) for cell in water_cells]
    water_rect = getattr(compiled, "water_rect", None)
    if water_rect:
        data["water_rect"] = list(water_rect)

    tmp = tempfile.NamedTemporaryFile(
        "w", suffix=f"_{compiled.pack_id}_{compiled.level}.yaml", delete=False
    )
    try:
        # The context manager closes the handle even when the dump fails.
        with tmp:
            yaml.safe_dump(data, tmp, sort_keys=False)
    except BaseException:
        # delete=False means nobody else will clean up a half-written
        # file; remove it so failed compiles don't litter the temp dir.
        Path(tmp.name).unlink(missing_ok=True)
        raise
    return tmp.name
@dataclass
class EpisodeResult:
    """Summary of one finished episode, as returned by `run_episode` and
    `run_level`."""

    # Scenario identifier: the scenario path for `run_episode`, or
    # "pack_id:level" for `run_level`.
    scenario: str
    # Seed the env was reset with.
    seed: int
    # Index of the last turn executed (0 if the loop never ran).
    turns: int
    # The adapter's accumulated signals at episode end.
    signals: EpisodeSignals
    outcome: str = "draw"  # "win" | "loss" | "draw"
    # Total number of commands sent to the env across all turns.
    actions_issued: int = 0
    actions_warned: int = 0  # commands the engine rejected/warned on
    # One small dict per turn (turn, tick, explored, kills, ...).
    trace: list[dict] = field(default_factory=list)
    # Final goal-tracker snapshot (always computed, playback or not).
    # objective_progress is continuous partial credit toward the
    # scenario win condition; reward_vector is the normalized
    # cumulative, scenario-agnostic vector (see goal_tracker).
    objective_progress: float = 0.0
    reward_vector: dict = field(default_factory=dict)
_CMD_NAME_RE = __import__("re").compile(r"Command::([A-Z][A-Za-z0-9]*)")
def _cmd_tool_name(cmd: Any) -> str | None:
"""Decode the snake_case tool name from a Command repr.
The Rust pyo3 Command enum stringifies as ``Command::VariantName { … }``
or ``Command::VariantName``. We extract the variant and convert to
snake_case to match the bench's tool-name vocabulary (and the
`tools:` / `forbidden_tools:` allowlist keys in YAML). Returns
``None`` for anything that doesn't match — defensive; never raises.
"""
m = _CMD_NAME_RE.search(repr(cmd))
if not m:
return None
variant = m.group(1)
# CamelCase → snake_case (MoveUnits → move_units, AttackUnit → attack_unit)
out: list[str] = []
for i, ch in enumerate(variant):
if i and ch.isupper():
out.append("_")
out.append(ch.lower())
return "".join(out)
def scripted_explore_agent(render_state: dict, Command: Any) -> list:
    """Baseline reference agent: walk every unit toward the nearest
    unexplored frontier cell.

    Exercises the move path; a useful lower-bound control for the
    perception/exploration scenarios.

    Args:
        render_state: Observation dict. Reads ``"minimap"`` (newline-separated
            text grid where ``"#"`` marks a frontier cell) and
            ``"units_summary"`` (list of dicts with ``id``, ``cell_x``,
            ``cell_y``).
        Command: Command factory exposing ``observe()`` and
            ``move_units(ids, target_x=..., target_y=...)``.

    Returns:
        One ``move_units`` command per unit, or ``[Command.observe()]`` when
        there are no units, no minimap, or no frontier cell.
    """
    units = render_state.get("units_summary") or []
    # A missing or None minimap is treated like an empty one instead of
    # raising, matching the defensive read of `units_summary`.
    rows = (render_state.get("minimap") or "").splitlines()
    if not units or not rows:
        return [Command.observe()]

    # Columns beyond the first row's width are ignored; shorter rows are
    # scanned only as far as they go.
    width = len(rows[0])
    frontier = [
        (x, y)
        for y, row in enumerate(rows)
        for x, cell in enumerate(row[:width])
        if cell == "#"
    ]
    if not frontier:
        return [Command.observe()]

    cmds = []
    for unit in units:
        ux, uy = unit["cell_x"], unit["cell_y"]
        # Squared Euclidean distance; ties resolve to the first frontier
        # cell in row-major scan order.
        tx, ty = min(frontier, key=lambda c: (c[0] - ux) ** 2 + (c[1] - uy) ** 2)
        cmds.append(Command.move_units([str(unit["id"])], target_x=tx, target_y=ty))
    return cmds
def run_episode(
    scenario_path: str,
    agent_fn: "AgentFn | Controller" = scripted_explore_agent,
    max_turns: int = 40,
    seed: int = 0,
    pool: RustEnvPool | None = None,
) -> EpisodeResult:
    """Run a scenario for a fixed number of turns.

    No win/fail condition is evaluated here, so the returned result keeps
    the default ``outcome`` of ``"draw"``.

    Args:
        scenario_path: Scenario file handed to the Rust env pool.
        agent_fn: A bare `agent_fn(render_state, Command) -> [Command]`
            callable or any `Controller`; it is coerced through
            `as_controller()`.
        max_turns: Upper bound on turns; the loop stops earlier when the
            env reports ``done``.
        seed: Seed passed to ``env.reset``.
        pool: Optional existing pool to borrow an env from. When omitted,
            a single-env pool is created and shut down by this call.

    Returns:
        An `EpisodeResult` with the adapter's signals and a per-turn trace.
    """
    owns_pool = pool is None
    if owns_pool:
        pool = RustEnvPool(size=1, scenario_path=scenario_path)

    env = None
    try:
        # Acquire inside the try so a pool we created is still shut down
        # if acquire() itself fails.
        env = pool.acquire()
        adapter = RustObsAdapter()
        obs = env.reset(seed=seed)
        adapter.observe(obs)
        controller = as_controller(agent_fn)
        controller.reset(EpisodeContext(seed=seed, max_turns=max_turns))

        trace: list[dict] = []
        turns = 0
        issued = warned = 0
        for turns in range(1, max_turns + 1):
            rs = adapter.render_state()
            # An empty/None action list degrades to a single observe().
            cmds = controller.act(rs, env.Command) or [env.Command.observe()]
            obs, _reward, done, info = env.step(cmds)
            adapter.observe(obs, done=done)
            issued += len(cmds)
            warned += len(info.get("warnings", []) if isinstance(info, dict) else [])
            trace.append(
                {
                    "turn": turns,
                    "tick": adapter.signals.game_tick,
                    "explored": round(adapter.signals.explored_percent, 2),
                    "kills": adapter.signals.units_killed,
                    "enemies_seen": len(adapter.signals.enemies_seen_ids),
                    "n_cmds": len(cmds),
                }
            )
            if done:
                break

        return EpisodeResult(
            scenario=scenario_path,
            seed=seed,
            turns=turns,
            signals=adapter.signals,
            actions_issued=issued,
            actions_warned=warned,
            trace=trace,
        )
    finally:
        # Each cleanup step runs even if the previous one raised.
        try:
            if env is not None:
                pool.release(env)
        finally:
            if owns_pool:
                pool.shutdown()
def _render_turn_minimap(state, compiled, terrain_cache: dict, explored: set):
    """Best-effort render of the minimap image saved with a playback frame.

    Uses the same vendored `_minimap_v2` renderer the model is sent, with
    the accumulating ``explored`` fog set, so the saved image matches the
    model's image; falls back to the legacy renderer when v2 returns None.

    Args:
        state: Render-state dict to draw.
        compiled: The level being run (reads ``scenario.base_map``, ``level``).
        terrain_cache: Mutable dict; the terrain image is stored under
            ``"terrain"`` so it is computed once per episode.
        explored: Persistent fog-history set shared across turns.

    Returns:
        The rendered image (presumably a base64 PNG, per the renderer
        names), or ``None`` if anything fails — rendering never breaks a run.
    """
    try:
        from .minimap import terrain_png_for

        if terrain_cache.get("terrain") is None:
            terrain_cache["terrain"] = terrain_png_for(compiled.scenario.base_map)
        terrain = terrain_cache["terrain"]

        from .prompt_v2 import minimap_b64 as _v2_mm

        png = _v2_mm(
            state, terrain, explored,
            constant_colors=compiled.level in ("easy", "medium"),
        )
        if png is None:
            from .agent import _render_minimap_b64

            png = _render_minimap_b64(state, terrain)
        return png
    except Exception:  # noqa: BLE001 — playback/audit never breaks a run
        return None


def run_level(
    compiled: CompiledLevel,
    agent_fn: "AgentFn | Controller" = scripted_explore_agent,
    seed: int = 0,
    playback=None,
    full_playback=None,
) -> EpisodeResult:
    """Run one scenario-pack level, scoring against its declarative
    win/fail conditions (checked every turn).

    Outcome maps to the `reward_outcome` convention: win=1.0, draw=0.5,
    loss=0.0. A Surrender command from the agent forces a loss.

    Args:
        compiled: The compiled level to execute.
        agent_fn: A bare `agent_fn(render_state, Command) -> [Command]`
            callable, a `ModelAgent` bound method, or any `Controller`;
            it is coerced through `as_controller()`.
        seed: Seed passed to ``env.reset``.
        playback: Optional legacy playback recorder (``record_turn``,
            ``write_messages``, ``finalize``).
        full_playback: Optional audit recorder (``record_turn``,
            ``finalize``, ``abort``, ``jsonl_path``); its failures are
            swallowed so they never break a run.

    Returns:
        The `EpisodeResult` for the episode.

    Raises:
        RuntimeError: If the level's base map is not Rust-loadable.
    """
    if not compiled.map_supported:
        raise RuntimeError(
            f"{compiled.pack_id}: base map not Rust-loadable yet (Phase 3). "
            f"Validate-only; cannot execute."
        )

    tmp_path = _scenario_to_tmp_yaml(compiled)
    pool = None
    env = None
    try:
        # Pool creation and acquire live inside the try so the temp YAML
        # (and the pool) are cleaned up even if either of them fails.
        pool = RustEnvPool(size=1, scenario_path=tmp_path)
        env = pool.acquire()

        adapter = RustObsAdapter()
        adapter.observe(env.reset(seed=seed))

        # Coerce the policy through the unified Controller contract:
        # a bare agent_fn, a ModelAgent bound method, or a Controller
        # all resolve to a Controller the loop drives identically.
        controller = as_controller(agent_fn)
        controller.reset(
            EpisodeContext(
                pack_id=compiled.pack_id,
                level=compiled.level,
                seed=seed,
                objective=compiled.scenario.description or "",
                max_turns=compiled.max_turns,
            )
        )

        trace: list[dict] = []
        outcome = "draw"
        turns = 0
        issued = warned = 0
        conceded = False

        # Persistent fog history so the SAVED minimap == the image the
        # model actually saw (same vendored _minimap_v2, accumulating).
        explored_cells: set = set()
        terrain_cache: dict = {}

        # Audit-capture wiring: when a FullPlayback is attached, surface
        # the underlying ModelAgent (if any) and flip on `audit_capture`
        # so per-turn briefing / wire request+response are stashed for
        # the audit JSONL.
        audit_agent = (
            introspection_source(controller) if full_playback is not None else None
        )
        if audit_agent is not None and hasattr(audit_agent, "audit_capture"):
            audit_agent.audit_capture = True

        # Interrupt-driven mode (step 4): if the scenario enabled any
        # interrupt signals, advance with step_until_event so the agent
        # is re-prompted (debriefed) the moment an event fires
        # (enemy spotted, unit lost, production complete, …) instead of
        # only on fixed tick boundaries. Falls back to fixed step()
        # when no signals are enabled or the env lacks the API.
        known_signals = {
            "enemy_unit_spotted", "enemy_building_spotted", "engage_start",
            "own_unit_destroyed", "production_complete",
        }
        enabled_sig = sorted(
            s for s, on in (compiled.scenario.interrupts or {}).items()
            if on and s in known_signals
        )
        raw_env = getattr(env, "_env", None)
        interrupt_mode = (
            bool(enabled_sig)
            and raw_env is not None
            and hasattr(raw_env, "step_until_event")
        )

        # Strict-toolban / procedural-compliance accounting: any cmd whose
        # tool name is in compiled.forbidden_tools increments
        # signals.tool_violations (read by the tool_violations_gte
        # predicate). Tracked here so scripted and live-model policies
        # are graded by the exact same rule.
        forbidden = {str(t).lower() for t in (compiled.forbidden_tools or [])}

        # Imported lazily (function scope), once per episode.
        from .goal_tracker import turn_goal

        for turns in range(1, compiled.max_turns + 1):
            rs = adapter.render_state()
            # An empty/None action list degrades to a single observe().
            cmds = controller.act(rs, env.Command) or [env.Command.observe()]

            for cmd in cmds:
                tool = _cmd_tool_name(cmd)
                if tool:
                    adapter.signals.tools_called[tool] = (
                        adapter.signals.tools_called.get(tool, 0) + 1
                    )
                    if tool in forbidden:
                        adapter.signals.tool_violations += 1
            # Concession is sticky: once seen it is never cleared.
            if not conceded:
                conceded = any("Surrender" in repr(c) for c in cmds)

            interrupt = None
            if interrupt_mode:
                # NOTE(review): the positional `None` / `5` are passed
                # through unchanged; their meaning is defined by the Rust
                # env's step_until_event API — confirm there.
                obs, _r, done, info, was_int, reason, _tk = (
                    raw_env.step_until_event(cmds, None, 5, enabled_sig)
                )
                if was_int:
                    interrupt = reason
            else:
                obs, _r, done, info = env.step(cmds)
            adapter.observe(obs, done=done)

            issued += len(cmds)
            turn_warnings = info.get("warnings", []) if isinstance(info, dict) else []
            warned += len(turn_warnings)

            # Win is checked before fail, so a simultaneous win+fail is a win.
            ctx = WinContext(signals=adapter.signals, render_state=adapter.render_state())
            if evaluate(compiled.win_condition, ctx):
                outcome = "win"
            elif evaluate(compiled.fail_condition, ctx):
                outcome = "loss"

            # Rendered at most once per turn and shared by both recorders;
            # both render the PRE-step state `rs` the model acted on.
            turn_png = None
            if playback is not None:
                turn_png = _render_turn_minimap(
                    rs, compiled, terrain_cache, explored_cells
                )
                playback.record_turn(
                    turns, rs, cmds, adapter.signals, turn_png,
                    interrupt=interrupt,
                    goal=turn_goal(compiled.win_condition, ctx),
                )

            if full_playback is not None:
                # Mirror the same PNG (when the legacy playback rendered
                # one). Otherwise render on-demand for the audit format.
                audit_png = turn_png
                if audit_png is None:
                    audit_png = _render_turn_minimap(
                        rs, compiled, terrain_cache, explored_cells
                    )
                try:
                    # getattr on a None audit agent simply yields the
                    # default, so no separate None check is needed.
                    full_playback.record_turn(
                        turn=turns,
                        tick=adapter.signals.game_tick,
                        obs=rs,
                        briefing=getattr(audit_agent, "last_briefing", ""),
                        system_prompt=getattr(audit_agent, "system_prompt", ""),
                        model_request=getattr(audit_agent, "last_request", None),
                        model_response=getattr(audit_agent, "last_response", None),
                        commands_issued=cmds,
                        engine_warnings=turn_warnings,
                        signals=adapter.signals,
                        minimap_png_b64=audit_png,
                        done=bool(done),
                        interrupt=interrupt,
                    )
                except Exception:  # noqa: BLE001 — audit never breaks a run
                    pass

            trace.append(
                {
                    "turn": turns,
                    "tick": adapter.signals.game_tick,
                    "explored": round(adapter.signals.explored_percent, 2),
                    "kills": adapter.signals.units_killed,
                    "enemies_seen": len(adapter.signals.enemies_seen_ids),
                    "interrupt": interrupt,
                }
            )
            if outcome != "draw" or done:
                break

        if conceded:
            outcome = "loss"  # the agent chose to concede
        adapter.signals.outcome = {"win": 1.0, "draw": 0.5, "loss": 0.0}[outcome]

        final_rs = adapter.render_state()
        final_goal = turn_goal(
            compiled.win_condition,
            WinContext(signals=adapter.signals, render_state=final_rs),
        )

        # Terminal frame: the RESOLVED post-action board the moment the
        # episode ends (win/loss). The per-turn record uses the
        # pre-step state, so without this the viewer never shows the
        # winning/losing position. No model action on this frame.
        if playback is not None:
            final_png = _render_turn_minimap(
                final_rs, compiled, terrain_cache, explored_cells
            )
            playback.record_turn(
                turns + 1, final_rs,
                [f"(episode end: {('loss' if conceded else outcome)})"],
                adapter.signals, final_png, interrupt=None, goal=final_goal,
            )

        result = EpisodeResult(
            scenario=f"{compiled.pack_id}:{compiled.level}",
            seed=seed,
            turns=turns,
            signals=adapter.signals,
            outcome=outcome,
            actions_issued=issued,
            actions_warned=warned,
            trace=trace,
            objective_progress=final_goal["objective_progress"],
            reward_vector=final_goal["reward_vector"],
        )

        if playback is not None:
            # Dump the full model⇄env transcript when the agent is a
            # ModelAgent — the Controller layer surfaces the underlying
            # instance (bound-method __self__ or the Controller itself).
            agent_obj = introspection_source(controller)
            hist = getattr(agent_obj, "history", None)
            if isinstance(hist, list):
                playback.write_messages(hist)
            playback.finalize(
                {
                    "scenario": result.scenario,
                    "pack_id": compiled.pack_id,
                    "level": compiled.level,
                    "capability": compiled.meta.capability,
                    "run_id": getattr(playback, "run_id", None),
                    "model": getattr(playback, "model", None),
                    "seed": seed,
                    "outcome": outcome,
                    "turns": turns,
                    "max_turns": compiled.max_turns,
                    "actions_issued": issued,
                    "actions_warned": warned,
                    "agent_stats": getattr(agent_obj, "stats", None),
                    "objective_progress": result.objective_progress,
                    "reward_vector": result.reward_vector,
                    "signals": {
                        "economy_value": adapter.signals.cash
                        + adapter.signals.resources,
                        "explored_percent": round(
                            adapter.signals.explored_percent, 2
                        ),
                        "units_killed": adapter.signals.units_killed,
                        "units_lost": adapter.signals.units_lost,
                    },
                }
            )

        if full_playback is not None:
            try:
                full_playback.finalize(
                    outcome=outcome,
                    final_obs=final_rs,
                    manifest_extra={
                        "scenario": result.scenario,
                        "pack_id": compiled.pack_id,
                        "level": compiled.level,
                        "capability": compiled.meta.capability,
                        "seed": seed,
                        "outcome": outcome,
                        "turns": turns,
                        "max_turns": compiled.max_turns,
                        "actions_issued": issued,
                        "actions_warned": warned,
                        "agent_stats": getattr(
                            introspection_source(controller), "stats", None
                        ),
                        "objective_progress": result.objective_progress,
                        "reward_vector": result.reward_vector,
                    },
                )
            except Exception:  # noqa: BLE001 — never break a run on I/O
                pass

        return result
    finally:
        # Abort the audit recorder if the loop crashed before finalize —
        # leaves a `.partial` on disk for forensics; the resume scanner
        # correctly sees no `.jsonl` and retries the cell.
        try:
            if full_playback is not None and not Path(
                full_playback.jsonl_path
            ).exists():
                full_playback.abort()
        except Exception:  # noqa: BLE001
            pass
        # Each remaining cleanup step runs even if the previous one
        # raised, so the pool and the temp YAML are never left behind.
        try:
            if env is not None:
                pool.release(env)
        finally:
            try:
                if pool is not None:
                    pool.shutdown()
            finally:
                Path(tmp_path).unlink(missing_ok=True)