Spaces:
Running
Phase 1: unified Controller interface for the eval stack
Browse files
Introduce openra_bench/controller.py — the keystone for the
human-labeling machine and the 1v1 adversarial harness. Every policy
backend (LLM agent, human labeler, scripted reference) now implements
one contract:
controller.act(observation, Command) -> [Command]
- Controller protocol + BaseController + FunctionController.
- as_controller() coerces a bare agent_fn, a ModelAgent bound method,
or an existing Controller — so all ~190 legacy test files that pass
a bare function keep working unchanged.
- EpisodeContext carries per-episode info (pack/level/seed/side) to
reset(); the 'side' field makes the interface 1v1-ready.
- run_level / run_episode drive any Controller; introspection_source()
recovers the underlying object's history/stats for playback.
- ModelAgent now satisfies the contract directly (name/reset/act).
tests/test_controller.py: 13 tests — coercion, idempotency, bound-method
source recovery, abstract act(), ModelAgent conformance, and an
end-to-end run_level smoke proving a bare fn and its Controller wrapper
produce byte-identical EpisodeResults.
- openra_bench/agent.py +18 -0
- openra_bench/controller.py +162 -0
- openra_bench/eval_core.py +40 -6
- tests/test_controller.py +256 -0
|
@@ -486,6 +486,11 @@ class ModelAgent:
|
|
| 486 |
)
|
| 487 |
self.history: list[dict] = [{"role": "system", "content": sys_content}]
|
| 488 |
self.stats = {"turns": 0, "tool_calls": 0, "empty_replies": 0}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
def _user_message(self, render_state: dict) -> dict:
|
| 491 |
# Briefing = vendored training briefing_v2 (one unit/line,
|
|
@@ -614,3 +619,16 @@ class ModelAgent:
|
|
| 614 |
{"role": "tool", "tool_call_id": f"c{i}", "content": "ok"}
|
| 615 |
)
|
| 616 |
return cmds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 486 |
)
|
| 487 |
self.history: list[dict] = [{"role": "system", "content": sys_content}]
|
| 488 |
self.stats = {"turns": 0, "tool_calls": 0, "empty_replies": 0}
|
| 489 |
+
# Controller contract (openra_bench/controller.py): a ModelAgent
|
| 490 |
+
# IS a Controller — it exposes `name`, `reset`, `act` so the
|
| 491 |
+
# eval loop, the 1v1 harness, and the human-labeling harness can
|
| 492 |
+
# all drive it interchangeably with any other policy backend.
|
| 493 |
+
self.name = getattr(cfg, "model", None) or "model"
|
| 494 |
|
| 495 |
def _user_message(self, render_state: dict) -> dict:
|
| 496 |
# Briefing = vendored training briefing_v2 (one unit/line,
|
|
|
|
| 619 |
{"role": "tool", "tool_call_id": f"c{i}", "content": "ok"}
|
| 620 |
)
|
| 621 |
return cmds
|
| 622 |
+
|
| 623 |
+
# ── Controller contract ──────────────────────────────────────────
|
| 624 |
+
def act(self, observation: dict, Command: Any) -> list:
    """Satisfy the Controller contract by delegating to `agent_fn`.

    Having `act` lets a ModelAgent be handed directly to `run_level`
    or the 1v1 harness wherever a bare `agent_fn` callable was
    accepted before.

    Args:
        observation: the per-turn state passed through to `agent_fn`.
        Command: the command factory passed through to `agent_fn`.

    Returns:
        Whatever `agent_fn` returns for this turn, unmodified.
    """
    commands = self.agent_fn(observation, Command)
    return commands
|
| 629 |
+
|
| 630 |
+
def reset(self, ctx: Any = None) -> None:
    """Per-episode hook required by the Controller contract.

    Intentionally does nothing. The chat history is initialised in
    `__init__`, and callers are expected to build one ModelAgent per
    episode, so there is no state to clear here. The method exists
    only so the agent structurally satisfies the Controller protocol.

    NOTE: because nothing is cleared, reusing one instance across
    episodes would carry its history over.

    Args:
        ctx: the episode context supplied by the harness; ignored.
    """
    return None
|
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unified policy interface for the OpenRA-Bench eval stack.
|
| 2 |
+
|
| 3 |
+
Every actor that can drive a side of a scenario — an LLM agent, a human
|
| 4 |
+
labeler, a scripted reference policy — implements the same contract:
|
| 5 |
+
|
| 6 |
+
controller.act(observation, Command) -> list[Command]
|
| 7 |
+
|
| 8 |
+
This is the keystone of the human-labeling machine and the 1v1
|
| 9 |
+
adversarial harness: one harness, interchangeable policy backends.
|
| 10 |
+
`run_level` / `run_episode` drive a single Controller; a 1v1 match
|
| 11 |
+
drives two, one per side, each fed its own side-specific observation.
|
| 12 |
+
|
| 13 |
+
Back-compat is non-negotiable: the historical policy shape was a bare
|
| 14 |
+
callable ``agent_fn(render_state, Command) -> [Command]`` and ~190 test
|
| 15 |
+
files still pass one. `as_controller()` adapts any such callable (or a
|
| 16 |
+
`ModelAgent` bound method) into a Controller, so every existing scripted
|
| 17 |
+
policy and test keeps working unchanged — the eval loop simply coerces
|
| 18 |
+
its policy argument through `as_controller()` before stepping.
|
| 19 |
+
|
| 20 |
+
Design notes
|
| 21 |
+
------------
|
| 22 |
+
* `act` keeps `Command` as an explicit parameter rather than binding it
|
| 23 |
+
at construction. `Command` is the pyo3 `openra_train.Command` factory
|
| 24 |
+
handle, only available once an env exists; threading it per-call keeps
|
| 25 |
+
Controllers constructible without an engine (cheap to unit-test) and
|
| 26 |
+
is byte-identical to the legacy `agent_fn` signature.
|
| 27 |
+
* `reset(ctx)` is the per-episode lifecycle hook. Scripted policies
|
| 28 |
+
ignore it; `ModelAgent.reset` is a no-op (one agent per episode); a human controller would
|
| 29 |
+
reset its click queue. The 1v1 harness calls it once per side with a
|
| 30 |
+
`side`-stamped `EpisodeContext`.
|
| 31 |
+
* `history` / `stats` are the optional introspection surface the
|
| 32 |
+
playback writer reads. `BaseController` provides empty defaults so a
|
| 33 |
+
caller can read them unconditionally.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
from __future__ import annotations
|
| 37 |
+
|
| 38 |
+
from dataclasses import dataclass, field
|
| 39 |
+
from typing import Any, Callable, Protocol, runtime_checkable
|
| 40 |
+
|
| 41 |
+
# A bare legacy policy: (render_state, Command) -> [Command].
|
| 42 |
+
PolicyFn = Callable[[dict, Any], list]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@dataclass
class EpisodeContext:
    """Per-episode facts handed to a Controller's `reset` at episode start.

    Every field has a default, so callers fill in only what they know.
    A scenario eval sets `pack_id`, `level`, `seed` and `objective`;
    a 1v1 match also sets `side` so each of the two Controllers knows
    which colour it is driving.
    """

    # Identifier of the scenario pack being run.
    pack_id: str = ""
    # Level name within the pack.
    level: str = ""
    # Seed the environment was reset with.
    seed: int = 0
    # Which side this Controller drives: "agent" or "enemy".
    side: str = "agent"  # "agent" | "enemy" — which side this drives
    # Human-readable objective text for the episode.
    objective: str = ""
    # Turn budget for the episode; 0 when the caller did not set one.
    max_turns: int = 0
    # Free-form extras; a fresh dict per instance (never shared).
    extra: dict = field(default_factory=dict)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@runtime_checkable
class Controller(Protocol):
    """Structural interface for a policy that turns observations into Commands.

    No inheritance is required: any object exposing a `name`, a
    `reset` method and an `act` method conforms, which is how
    `ModelAgent` satisfies it without importing this module.
    """

    # Human-readable identifier of the policy.
    name: str

    def reset(self, ctx: "EpisodeContext") -> None:
        """Receive the per-episode context once, at episode start."""
        ...

    def act(self, observation: dict, Command: Any) -> list:
        """Return the list of Commands to issue for this observation."""
        ...
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def is_controller(obj: Any) -> bool:
    """Report whether `obj` can be driven as a Controller as-is.

    The check is duck-typed: `obj` must expose both an `act` and a
    `reset` attribute, and both must be callable. A plain function or
    lambda is callable itself but carries neither attribute, so it is
    rejected and left for the caller to wrap. The `name` attribute is
    not checked here.

    Args:
        obj: any candidate policy object.

    Returns:
        True when both `act` and `reset` are present and callable.
    """
    # Same order and short-circuit as a chained `and`: `act` first.
    for required in ("act", "reset"):
        if not callable(getattr(obj, required, None)):
            return False
    return True
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class BaseController:
    """Starter class for concrete policies: subclass it and implement `act`.

    It supplies a default `name`, a do-nothing `reset`, and empty
    `history` / `stats` containers, so every subclass (the human
    bridge, scripted reference wrappers) exposes the same
    introspection attributes.
    """

    # Class-level fallback used when no name is given to __init__.
    name: str = "controller"

    def __init__(self, name: str | None = None) -> None:
        """Create the empty introspection containers and set the name.

        Args:
            name: optional display name; an empty string or None keeps
                the class-level default.
        """
        self.history: list[dict] = []
        self.stats: dict[str, Any] = {}
        if name:
            self.name = name

    def reset(self, ctx: EpisodeContext) -> None:  # noqa: D401
        """Per-episode lifecycle hook; the default does nothing."""

    def act(self, observation: dict, Command: Any) -> list:
        """Abstract: subclasses must override.

        Raises:
            NotImplementedError: always, naming the concrete class.
        """
        message = f"{type(self).__name__} must implement act()"
        raise NotImplementedError(message)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class FunctionController(BaseController):
    """Controller wrapper around a bare ``agent_fn(render_state, Command)``.

    This is the back-compat bridge: scripted reference policies and
    the legacy `scripted_explore_agent` stay plain callables, and the
    wrapper gives them the Controller shape.

    If the callable is a bound method (e.g. ``ModelAgent.agent_fn``),
    the instance it is bound to is stored in `source`, so the eval
    loop can still read that object's `history` / `stats` for
    playback. For a callable without ``__self__``, `source` is None.
    """

    def __init__(self, fn: PolicyFn, name: str | None = None) -> None:
        """Wrap `fn`.

        Args:
            fn: the policy callable to delegate to.
            name: optional display name; falls back to ``fn.__name__``
                and finally to ``"fn"``.
        """
        resolved_name = name or getattr(fn, "__name__", None) or "fn"
        super().__init__(resolved_name)
        self._fn = fn
        # Bound methods expose their instance as __self__.
        self.source: Any = getattr(fn, "__self__", None)

    def act(self, observation: dict, Command: Any) -> list:
        """Call the wrapped function and return its result unchanged."""
        commands = self._fn(observation, Command)
        return commands
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def as_controller(policy: Any, name: str | None = None) -> Controller:
    """Turn any policy-shaped value into a Controller.

    Resolution order:

    1. A value that already passes `is_controller` is returned
       untouched, which makes the function idempotent. `name` is
       ignored in this case.
    2. Any other callable (a bare `agent_fn` or a bound method) is
       wrapped in a `FunctionController`; a bound method's instance
       stays reachable through the wrapper's `.source`.

    Args:
        policy: a Controller, or a callable taking
            ``(render_state, Command)``.
        name: optional display name, used only when wrapping.

    Returns:
        The original object, or a new `FunctionController` around it.

    Raises:
        TypeError: if `policy` is neither a Controller nor callable.
    """
    if is_controller(policy):
        return policy
    if not callable(policy):
        raise TypeError(
            f"cannot coerce {type(policy).__name__} into a Controller: "
            "expected a Controller, a ModelAgent, or an "
            "agent_fn(render_state, Command) -> [Command] callable"
        )
    return FunctionController(policy, name)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def introspection_source(controller: Controller) -> Any:
    """Return the object whose `history` / `stats` playback should read.

    A `FunctionController` built from a bound method records the bound
    instance in `.source`; that instance is returned. When `.source`
    is missing or None, the Controller itself is returned.
    """
    bound_instance = getattr(controller, "source", None)
    if bound_instance is None:
        return controller
    return bound_instance
|
|
@@ -20,11 +20,20 @@ from typing import Any, Callable
|
|
| 20 |
import yaml
|
| 21 |
from openra_rl_training.training.rust_env_pool import RustEnvPool
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
from .rust_adapter import EpisodeSignals, RustObsAdapter
|
| 24 |
from .scenarios.schema import CompiledLevel
|
| 25 |
from .scenarios.win_conditions import WinContext, evaluate
|
| 26 |
|
|
|
|
|
|
|
| 27 |
AgentFn = Callable[[dict, Any], list]
|
|
|
|
| 28 |
|
| 29 |
|
| 30 |
def _scenario_to_tmp_yaml(compiled: CompiledLevel) -> str:
|
|
@@ -136,11 +145,14 @@ def scripted_explore_agent(render_state: dict, Command: Any) -> list:
|
|
| 136 |
|
| 137 |
def run_episode(
|
| 138 |
scenario_path: str,
|
| 139 |
-
agent_fn: AgentFn = scripted_explore_agent,
|
| 140 |
max_turns: int = 40,
|
| 141 |
seed: int = 0,
|
| 142 |
pool: RustEnvPool | None = None,
|
| 143 |
) -> EpisodeResult:
|
|
|
|
|
|
|
|
|
|
| 144 |
owns_pool = pool is None
|
| 145 |
if pool is None:
|
| 146 |
pool = RustEnvPool(size=1, scenario_path=scenario_path)
|
|
@@ -149,12 +161,16 @@ def run_episode(
|
|
| 149 |
adapter = RustObsAdapter()
|
| 150 |
obs = env.reset(seed=seed)
|
| 151 |
adapter.observe(obs)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
trace: list[dict] = []
|
| 153 |
turns = 0
|
| 154 |
issued = warned = 0
|
| 155 |
for turns in range(1, max_turns + 1):
|
| 156 |
rs = adapter.render_state()
|
| 157 |
-
cmds =
|
| 158 |
obs, _reward, done, info = env.step(cmds)
|
| 159 |
adapter.observe(obs, done=done)
|
| 160 |
issued += len(cmds)
|
|
@@ -188,13 +204,17 @@ def run_episode(
|
|
| 188 |
|
| 189 |
def run_level(
|
| 190 |
compiled: CompiledLevel,
|
| 191 |
-
agent_fn: AgentFn = scripted_explore_agent,
|
| 192 |
seed: int = 0,
|
| 193 |
playback=None,
|
| 194 |
) -> EpisodeResult:
|
| 195 |
"""Run one scenario-pack level, scoring against its declarative
|
| 196 |
win/fail conditions (checked every turn). Outcome maps to the
|
| 197 |
`reward_outcome` convention: win=1.0, draw=0.5, loss=0.0.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
"""
|
| 199 |
if not compiled.map_supported:
|
| 200 |
raise RuntimeError(
|
|
@@ -207,6 +227,19 @@ def run_level(
|
|
| 207 |
try:
|
| 208 |
adapter = RustObsAdapter()
|
| 209 |
adapter.observe(env.reset(seed=seed))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
trace: list[dict] = []
|
| 211 |
outcome = "draw"
|
| 212 |
turns = 0
|
|
@@ -242,7 +275,7 @@ def run_level(
|
|
| 242 |
forbidden = {str(t).lower() for t in (compiled.forbidden_tools or [])}
|
| 243 |
for turns in range(1, compiled.max_turns + 1):
|
| 244 |
rs = adapter.render_state()
|
| 245 |
-
cmds =
|
| 246 |
for _cmd in cmds:
|
| 247 |
_tn = _cmd_tool_name(_cmd)
|
| 248 |
if _tn:
|
|
@@ -364,8 +397,9 @@ def run_level(
|
|
| 364 |
)
|
| 365 |
if playback is not None:
|
| 366 |
# Dump the full modelโenv transcript when the agent is a
|
| 367 |
-
# ModelAgent
|
| 368 |
-
|
|
|
|
| 369 |
hist = getattr(agent_obj, "history", None)
|
| 370 |
if isinstance(hist, list):
|
| 371 |
playback.write_messages(hist)
|
|
|
|
| 20 |
import yaml
|
| 21 |
from openra_rl_training.training.rust_env_pool import RustEnvPool
|
| 22 |
|
| 23 |
+
from .controller import (
|
| 24 |
+
Controller,
|
| 25 |
+
EpisodeContext,
|
| 26 |
+
as_controller,
|
| 27 |
+
introspection_source,
|
| 28 |
+
)
|
| 29 |
from .rust_adapter import EpisodeSignals, RustObsAdapter
|
| 30 |
from .scenarios.schema import CompiledLevel
|
| 31 |
from .scenarios.win_conditions import WinContext, evaluate
|
| 32 |
|
| 33 |
+
# A policy is either a bare `agent_fn(render_state, Command) -> [Command]`
|
| 34 |
+
# callable (the legacy shape, still accepted everywhere) or a Controller.
|
| 35 |
AgentFn = Callable[[dict, Any], list]
|
| 36 |
+
Policy = "AgentFn | Controller"
|
| 37 |
|
| 38 |
|
| 39 |
def _scenario_to_tmp_yaml(compiled: CompiledLevel) -> str:
|
|
|
|
| 145 |
|
| 146 |
def run_episode(
|
| 147 |
scenario_path: str,
|
| 148 |
+
agent_fn: "AgentFn | Controller" = scripted_explore_agent,
|
| 149 |
max_turns: int = 40,
|
| 150 |
seed: int = 0,
|
| 151 |
pool: RustEnvPool | None = None,
|
| 152 |
) -> EpisodeResult:
|
| 153 |
+
"""Run a scenario for a fixed number of turns. `agent_fn` may be a
|
| 154 |
+
bare `agent_fn(render_state, Command) -> [Command]` callable or any
|
| 155 |
+
`Controller`; it is coerced through `as_controller()`."""
|
| 156 |
owns_pool = pool is None
|
| 157 |
if pool is None:
|
| 158 |
pool = RustEnvPool(size=1, scenario_path=scenario_path)
|
|
|
|
| 161 |
adapter = RustObsAdapter()
|
| 162 |
obs = env.reset(seed=seed)
|
| 163 |
adapter.observe(obs)
|
| 164 |
+
controller = as_controller(agent_fn)
|
| 165 |
+
controller.reset(
|
| 166 |
+
EpisodeContext(seed=seed, max_turns=max_turns)
|
| 167 |
+
)
|
| 168 |
trace: list[dict] = []
|
| 169 |
turns = 0
|
| 170 |
issued = warned = 0
|
| 171 |
for turns in range(1, max_turns + 1):
|
| 172 |
rs = adapter.render_state()
|
| 173 |
+
cmds = controller.act(rs, env.Command) or [env.Command.observe()]
|
| 174 |
obs, _reward, done, info = env.step(cmds)
|
| 175 |
adapter.observe(obs, done=done)
|
| 176 |
issued += len(cmds)
|
|
|
|
| 204 |
|
| 205 |
def run_level(
|
| 206 |
compiled: CompiledLevel,
|
| 207 |
+
agent_fn: "AgentFn | Controller" = scripted_explore_agent,
|
| 208 |
seed: int = 0,
|
| 209 |
playback=None,
|
| 210 |
) -> EpisodeResult:
|
| 211 |
"""Run one scenario-pack level, scoring against its declarative
|
| 212 |
win/fail conditions (checked every turn). Outcome maps to the
|
| 213 |
`reward_outcome` convention: win=1.0, draw=0.5, loss=0.0.
|
| 214 |
+
|
| 215 |
+
`agent_fn` may be a bare `agent_fn(render_state, Command) ->
|
| 216 |
+
[Command]` callable, a `ModelAgent` bound method, or any
|
| 217 |
+
`Controller`; it is coerced through `as_controller()`.
|
| 218 |
"""
|
| 219 |
if not compiled.map_supported:
|
| 220 |
raise RuntimeError(
|
|
|
|
| 227 |
try:
|
| 228 |
adapter = RustObsAdapter()
|
| 229 |
adapter.observe(env.reset(seed=seed))
|
| 230 |
+
# Coerce the policy through the unified Controller contract:
|
| 231 |
+
# a bare agent_fn, a ModelAgent bound method, or a Controller
|
| 232 |
+
# all resolve to a Controller the loop drives identically.
|
| 233 |
+
controller = as_controller(agent_fn)
|
| 234 |
+
controller.reset(
|
| 235 |
+
EpisodeContext(
|
| 236 |
+
pack_id=compiled.pack_id,
|
| 237 |
+
level=compiled.level,
|
| 238 |
+
seed=seed,
|
| 239 |
+
objective=compiled.scenario.description or "",
|
| 240 |
+
max_turns=compiled.max_turns,
|
| 241 |
+
)
|
| 242 |
+
)
|
| 243 |
trace: list[dict] = []
|
| 244 |
outcome = "draw"
|
| 245 |
turns = 0
|
|
|
|
| 275 |
forbidden = {str(t).lower() for t in (compiled.forbidden_tools or [])}
|
| 276 |
for turns in range(1, compiled.max_turns + 1):
|
| 277 |
rs = adapter.render_state()
|
| 278 |
+
cmds = controller.act(rs, env.Command) or [env.Command.observe()]
|
| 279 |
for _cmd in cmds:
|
| 280 |
_tn = _cmd_tool_name(_cmd)
|
| 281 |
if _tn:
|
|
|
|
| 397 |
)
|
| 398 |
if playback is not None:
|
| 399 |
# Dump the full model↔env transcript when the agent is a
|
| 400 |
+
# ModelAgent — the Controller layer surfaces the underlying
|
| 401 |
+
# instance (bound-method __self__ or the Controller itself).
|
| 402 |
+
agent_obj = introspection_source(controller)
|
| 403 |
hist = getattr(agent_obj, "history", None)
|
| 404 |
if isinstance(hist, list):
|
| 405 |
playback.write_messages(hist)
|
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Phase 1 — the unified Controller contract.
|
| 2 |
+
|
| 3 |
+
`openra_bench/controller.py` is the keystone of the human-labeling
|
| 4 |
+
machine and the 1v1 adversarial harness: LLM agents, human labelers and
|
| 5 |
+
scripted reference policies all implement one interface,
|
| 6 |
+
|
| 7 |
+
controller.act(observation, Command) -> [Command]
|
| 8 |
+
|
| 9 |
+
and `run_level` / `run_episode` drive any of them. This file pins:
|
| 10 |
+
|
| 11 |
+
* the coercion layer (`as_controller`) — a bare `agent_fn`, a bound
|
| 12 |
+
method, and an existing Controller all resolve correctly, so the ~190
|
| 13 |
+
legacy test files that pass a bare function keep working;
|
| 14 |
+
* the introspection surface (`history` / `stats`) the playback writer
|
| 15 |
+
reads survives the coercion;
|
| 16 |
+
* `ModelAgent` structurally satisfies the contract;
|
| 17 |
+
* an end-to-end `run_level` smoke: the SAME scripted policy produces a
|
| 18 |
+
byte-identical outcome whether passed as a bare function or wrapped
|
| 19 |
+
in a Controller.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import pytest
|
| 25 |
+
|
| 26 |
+
from openra_bench.controller import (
|
| 27 |
+
BaseController,
|
| 28 |
+
Controller,
|
| 29 |
+
EpisodeContext,
|
| 30 |
+
FunctionController,
|
| 31 |
+
as_controller,
|
| 32 |
+
introspection_source,
|
| 33 |
+
is_controller,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ── Coercion (no engine needed) ─────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _bare_policy(render_state, Command):
|
| 41 |
+
"""A legacy-shape agent_fn: ignore the world, just observe."""
|
| 42 |
+
return [("OBSERVE", id(render_state))]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def test_as_controller_wraps_a_bare_function():
    """A plain function is wrapped in a FunctionController that delegates."""
    wrapped = as_controller(_bare_policy)
    assert is_controller(wrapped)
    assert isinstance(wrapped, FunctionController)
    # Calling act() must forward straight to the wrapped function.
    commands = wrapped.act({"k": 1}, Command=None)
    first_command = commands[0]
    assert first_command[0] == "OBSERVE"
    # With no explicit name, the function's __name__ is used.
    assert wrapped.name == "_bare_policy"
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_as_controller_is_idempotent_on_a_controller():
    """Coercing an existing Controller hands back the very same object."""
    original = as_controller(_bare_policy)
    coerced_again = as_controller(original)
    assert coerced_again is original, "coercing a Controller must return it unchanged"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def test_as_controller_rejects_non_callable():
    """Values that are neither Controllers nor callables raise TypeError."""
    for not_a_policy in (42, "not a policy"):
        with pytest.raises(TypeError):
            as_controller(not_a_policy)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def test_as_controller_named_override():
    """An explicit `name` wins over the wrapped function's __name__."""
    wrapped = as_controller(_bare_policy, name="custom")
    assert wrapped.name == "custom"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_is_controller_discriminates():
    """Bare callables are rejected; a FunctionController is accepted."""
    # Callable, but no act/reset attributes: not a Controller.
    named_function_result = is_controller(_bare_policy)
    assert not named_function_result
    lambda_result = is_controller(lambda rs, C: [])
    assert not lambda_result
    # The wrapper produced by as_controller is one.
    wrapped = as_controller(_bare_policy)
    assert is_controller(wrapped)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_base_controller_act_is_abstract():
    """BaseController keeps its name, starts empty, and refuses to act."""
    base = BaseController(name="x")
    assert base.name == "x"
    assert base.history == [] and base.stats == {}
    # act() is left for subclasses to implement.
    with pytest.raises(NotImplementedError):
        base.act({}, Command=None)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_episode_context_defaults():
    """EpisodeContext defaults are as documented and keywords override them."""
    default_ctx = EpisodeContext()
    assert default_ctx.side == "agent"
    assert default_ctx.seed == 0 and default_ctx.max_turns == 0
    assert default_ctx.extra == {}

    enemy_ctx = EpisodeContext(pack_id="p", level="hard", side="enemy", seed=3)
    observed = (enemy_ctx.pack_id, enemy_ctx.level, enemy_ctx.side, enemy_ctx.seed)
    assert observed == ("p", "hard", "enemy", 3)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ── Bound-method source recovery (the playback path) ────────────────
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class _FakeAgent:
|
| 105 |
+
"""Stand-in for ModelAgent: a bound `agent_fn` plus history/stats."""
|
| 106 |
+
|
| 107 |
+
def __init__(self):
|
| 108 |
+
self.history = [{"role": "system", "content": "hi"}]
|
| 109 |
+
self.stats = {"turns": 0}
|
| 110 |
+
|
| 111 |
+
def agent_fn(self, render_state, Command):
|
| 112 |
+
self.stats["turns"] += 1
|
| 113 |
+
return [("ACT", self.stats["turns"])]
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def test_bound_method_source_is_recovered_for_playback():
    """Wrapping a bound method keeps the owning instance reachable."""
    fake = _FakeAgent()
    wrapped = as_controller(fake.agent_fn)
    assert is_controller(wrapped)
    # Playback needs the instance behind the method for history/stats.
    assert wrapped.source is fake
    assert introspection_source(wrapped) is fake
    # Acting through the wrapper reaches the instance's own method.
    wrapped.act({}, None)
    assert fake.stats["turns"] == 1
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_introspection_source_falls_back_to_controller():
    """With no bound instance, the Controller itself is the source."""
    wrapped = as_controller(_bare_policy)  # plain function, no __self__
    assert wrapped.source is None
    assert introspection_source(wrapped) is wrapped
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ── Subclassing BaseController ──────────────────────────────────────
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class _CountingController(BaseController):
    """Fixture Controller that counts `reset` and `act` calls."""

    def __init__(self):
        super().__init__(name="counter")
        self.reset_calls = 0
        self.turns = 0

    def reset(self, ctx: EpisodeContext) -> None:
        """Remember the context and count the call."""
        self.last_ctx = ctx
        self.reset_calls += 1

    def act(self, observation, Command):
        """Return one ("TURN", n) tuple, n being the running call count."""
        self.turns += 1
        return [("TURN", self.turns)]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_subclassed_controller_satisfies_contract():
    """A BaseController subclass passes both contract checks and works."""
    counter = _CountingController()
    assert is_controller(counter)
    assert isinstance(counter, Controller)  # runtime_checkable structural

    enemy_ctx = EpisodeContext(pack_id="p", side="enemy")
    counter.reset(enemy_ctx)
    assert counter.reset_calls == 1 and counter.last_ctx.side == "enemy"

    assert counter.act({}, None) == [("TURN", 1)]
    # Already a Controller, so coercion must not wrap it.
    assert as_controller(counter) is counter
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# ── ModelAgent structurally conforms ────────────────────────────────
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def test_model_agent_class_exposes_controller_contract():
    """ModelAgent, as a class, carries the contract's methods.

    The check is on the class rather than an instance: building a
    ModelAgent needs a provider, and the contract here is only that
    the methods exist.
    """
    from openra_bench.agent import ModelAgent

    for member in ("act", "reset", "agent_fn"):
        attribute = getattr(ModelAgent, member, None)
        assert callable(attribute), (
            f"ModelAgent must expose {member}() for the Controller contract"
        )
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# ── End-to-end: bare fn vs Controller produce identical runs ────────
|
| 176 |
+
|
| 177 |
+
def _require_engine():
    """Skip the calling test when the Rust engine wheel is unavailable.

    This is called inside each engine-backed test rather than at module
    level. A module-level `pytest.importorskip` skips the whole module,
    which would also skip the engine-free coercion tests above.
    """
    pytest.importorskip("openra_train", reason="Rust env wheel not installed")
    pytest.importorskip(
        "openra_rl_training", reason="Rust env wheel not installed"
    )


def _stall(render_state, Command):
    """Scripted policy that issues a single observe command every turn."""
    return [Command.observe()]


def _smallest_easy_pack():
    """Return the compiled easy level with the fewest turns, or None.

    Only active packs with an "easy" level and a supported map are
    considered. Packs that fail to load or compile are skipped on
    purpose (best effort). Picking the shortest level keeps the
    end-to-end smoke fast and deterministic.
    """
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import PACKS_DIR, compile_level

    best = None
    for f in sorted(PACKS_DIR.glob("*.yaml")):
        # Underscore-prefixed and TEMPLATE files are not real packs.
        if f.name.startswith(("_", "TEMPLATE")):
            continue
        try:
            pack = load_pack(f)
            if pack.meta.status != "active" or "easy" not in pack.levels:
                continue
            c = compile_level(pack, "easy")
        except Exception:  # noqa: BLE001
            continue
        if not c.map_supported:
            continue
        if best is None or c.max_turns < best.max_turns:
            best = c
    return best


def test_run_level_identical_for_bare_fn_and_controller():
    """The same scripted policy gives the same result either way.

    The policy is passed once as a bare agent_fn and once wrapped in a
    Controller. Outcome, turn count, actions issued and final game tick
    must all match, which shows the coercion layer is transparent.
    """
    _require_engine()
    from openra_bench.eval_core import run_level

    compiled = _smallest_easy_pack()
    assert compiled is not None, "no runnable active pack found"

    r_fn = run_level(compiled, _stall, seed=1)
    r_ctrl = run_level(compiled, as_controller(_stall, name="stall"), seed=1)

    assert r_fn.outcome == r_ctrl.outcome
    assert r_fn.turns == r_ctrl.turns
    assert r_fn.actions_issued == r_ctrl.actions_issued
    assert r_fn.signals.game_tick == r_ctrl.signals.game_tick


def test_run_level_drives_a_subclassed_controller():
    """A BaseController subclass runs end-to-end through `run_level`.

    This is the shape HumanController and the scripted-bot wrapper will
    take. Its per-episode `reset` hook must fire with a populated
    EpisodeContext.
    """
    _require_engine()
    from openra_bench.eval_core import run_level

    compiled = _smallest_easy_pack()
    assert compiled is not None

    class _StallController(BaseController):
        """Observe-only Controller that records its context and act count."""

        def __init__(self):
            super().__init__(name="stall-ctrl")
            self.acts = 0
            self.ctx = None

        def reset(self, ctx: EpisodeContext) -> None:
            self.ctx = ctx

        def act(self, observation, Command):
            self.acts += 1
            return [Command.observe()]

    ctrl = _StallController()
    res = run_level(compiled, ctrl, seed=1)
    assert res.outcome in ("win", "loss", "draw")
    assert ctrl.acts >= 1, "act() must have been called"
    assert ctrl.ctx is not None and ctrl.ctx.pack_id == compiled.pack_id
    assert ctrl.ctx.level == "easy" and ctrl.ctx.seed == 1
|