yxc20098 committed on
Commit
c68e036
·
1 Parent(s): 248d766

Phase 1: unified Controller interface for the eval stack

Browse files

Introduce openra_bench/controller.py — the keystone for the
human-labeling machine and the 1v1 adversarial harness. Every policy
backend (LLM agent, human labeler, scripted reference) now implements
one contract:

controller.act(observation, Command) -> [Command]

- Controller protocol + BaseController + FunctionController.
- as_controller() coerces a bare agent_fn, a ModelAgent bound method,
or an existing Controller — so all ~190 legacy test files that pass
a bare function keep working unchanged.
- EpisodeContext carries per-episode info (pack/level/seed/side) to
reset(); the 'side' field makes the interface 1v1-ready.
- run_level / run_episode drive any Controller; introspection_source()
recovers the underlying object's history/stats for playback.
- ModelAgent now satisfies the contract directly (name/reset/act).

tests/test_controller.py: 13 tests — coercion, idempotency, bound-method
source recovery, abstract act(), ModelAgent conformance, and an
end-to-end run_level smoke proving a bare fn and its Controller wrapper
produce byte-identical EpisodeResults.

openra_bench/agent.py CHANGED
@@ -486,6 +486,11 @@ class ModelAgent:
486
  )
487
  self.history: list[dict] = [{"role": "system", "content": sys_content}]
488
  self.stats = {"turns": 0, "tool_calls": 0, "empty_replies": 0}
 
 
 
 
 
489
 
490
  def _user_message(self, render_state: dict) -> dict:
491
  # Briefing = vendored training briefing_v2 (one unit/line,
@@ -614,3 +619,16 @@ class ModelAgent:
614
  {"role": "tool", "tool_call_id": f"c{i}", "content": "ok"}
615
  )
616
  return cmds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  )
487
  self.history: list[dict] = [{"role": "system", "content": sys_content}]
488
  self.stats = {"turns": 0, "tool_calls": 0, "empty_replies": 0}
489
+ # Controller contract (openra_bench/controller.py): a ModelAgent
490
+ # IS a Controller — it exposes `name`, `reset`, `act` so the
491
+ # eval loop, the 1v1 harness, and the human-labeling harness can
492
+ # all drive it interchangeably with any other policy backend.
493
+ self.name = getattr(cfg, "model", None) or "model"
494
 
495
  def _user_message(self, render_state: dict) -> dict:
496
  # Briefing = vendored training briefing_v2 (one unit/line,
 
619
  {"role": "tool", "tool_call_id": f"c{i}", "content": "ok"}
620
  )
621
  return cmds
622
+
623
+ # ── Controller contract ──────────────────────────────────────────
624
def act(self, observation: dict, Command: Any) -> list:
    """Controller contract: alias of `agent_fn`.

    Lets a ModelAgent be passed straight to `run_level` / the 1v1
    harness in place of a bare `agent_fn` callable.

    Args:
        observation: The per-turn observation dict, forwarded untouched.
        Command: The engine command factory, forwarded untouched.

    Returns:
        Whatever `self.agent_fn(observation, Command)` returns.
    """
    return self.agent_fn(observation, Command)
629
+
630
def reset(self, ctx: Any = None) -> None:
    """Controller contract per-episode hook (intentionally a no-op).

    A ModelAgent is constructed once per episode, and its bounded chat
    history starts fresh in `__init__`, so there is nothing to re-arm
    here. The method exists so the agent structurally satisfies the
    Controller protocol.

    Args:
        ctx: Per-episode context supplied by the caller; ignored.
    """
openra_bench/controller.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified policy interface for the OpenRA-Bench eval stack.
2
+
3
+ Every actor that can drive a side of a scenario — an LLM agent, a human
4
+ labeler, a scripted reference policy — implements the same contract:
5
+
6
+ controller.act(observation, Command) -> list[Command]
7
+
8
+ This is the keystone of the human-labeling machine and the 1v1
9
+ adversarial harness: one harness, interchangeable policy backends.
10
+ `run_level` / `run_episode` drive a single Controller; a 1v1 match
11
+ drives two, one per side, each fed its own side-specific observation.
12
+
13
+ Back-compat is non-negotiable: the historical policy shape was a bare
14
+ callable ``agent_fn(render_state, Command) -> [Command]`` and ~190 test
15
+ files still pass one. `as_controller()` adapts any such callable (or a
16
+ `ModelAgent` bound method) into a Controller, so every existing scripted
17
+ policy and test keeps working unchanged — the eval loop simply coerces
18
+ its policy argument through `as_controller()` before stepping.
19
+
20
+ Design notes
21
+ ------------
22
+ * `act` keeps `Command` as an explicit parameter rather than binding it
23
+ at construction. `Command` is the pyo3 `openra_train.Command` factory
24
+ handle, only available once an env exists; threading it per-call keeps
25
+ Controllers constructible without an engine (cheap to unit-test) and
26
+ is byte-identical to the legacy `agent_fn` signature.
27
+ * `reset(ctx)` is the per-episode lifecycle hook. Scripted policies
28
+ ignore it; the model agent treats it as a no-op (a fresh agent is built per episode); a human controller would
29
+ reset its click queue. The 1v1 harness calls it once per side with a
30
+ `side`-stamped `EpisodeContext`.
31
+ * `history` / `stats` are the optional introspection surface the
32
+ playback writer reads. `BaseController` provides empty defaults so a
33
+ caller can read them unconditionally.
34
+ """
35
+
36
+ from __future__ import annotations
37
+
38
+ from dataclasses import dataclass, field
39
+ from typing import Any, Callable, Protocol, runtime_checkable
40
+
41
+ # A bare legacy policy: (render_state, Command) -> [Command].
42
+ PolicyFn = Callable[[dict, Any], list]
43
+
44
+
45
@dataclass
class EpisodeContext:
    """What a Controller is told once, at episode start (`reset`).

    A scenario eval populates `pack_id` / `level` / `seed` / `objective`;
    a 1v1 match additionally stamps `side` so the two Controllers know
    which colour they are driving.

    Field order is part of the interface (the generated `__init__`
    accepts these positionally), so new fields belong at the end.
    """

    # Identifier of the scenario pack being run ("" when not pack-driven).
    pack_id: str = ""
    # Level name within the pack.
    level: str = ""
    # Seed the environment was reset with.
    seed: int = 0
    side: str = "agent"  # "agent" | "enemy" — which side this drives
    # Objective text for the episode.
    objective: str = ""
    # Turn budget for the episode.
    max_turns: int = 0
    # Free-form extras; a fresh dict per instance (never shared).
    extra: dict = field(default_factory=dict)
60
+
61
+
62
@runtime_checkable
class Controller(Protocol):
    """A policy that observes the world and emits engine Commands.

    Structural — anything exposing `name`, `reset`, and `act` satisfies
    it; `ModelAgent` does so without importing this module.

    Because the protocol is `runtime_checkable`, `isinstance(obj,
    Controller)` works, but it only verifies that the members exist,
    not their signatures.
    """

    name: str

    def reset(self, ctx: "EpisodeContext") -> None: ...

    def act(self, observation: dict, Command: Any) -> list: ...
74
+
75
+
76
def is_controller(obj: Any) -> bool:
    """Return True if `obj` already satisfies the Controller contract.

    Deliberately structural: a bare function is callable but is NOT a
    Controller, so `obj` must carry callable `act` AND `reset`
    attributes, which a plain function never does.

    Note this is not the same test as `isinstance(obj, Controller)`:
    it insists the two members are callable, but it does not look at
    `name` at all.
    """
    return callable(getattr(obj, "act", None)) and callable(
        getattr(obj, "reset", None)
    )
86
+
87
+
88
class BaseController:
    """Convenience base: a no-op `reset`, a `name`, empty introspection.

    Subclass and implement `act`. Concrete eval policies (the human
    bridge, scripted reference wrappers) derive from this so they share
    one introspection surface (`history`, `stats`).
    """

    # Class-level default, used when no name is passed to __init__.
    name: str = "controller"

    def __init__(self, name: str | None = None) -> None:
        """Set up the introspection surface.

        Args:
            name: Optional display name. A falsy value (None or "")
                leaves the class-level default in place.
        """
        if name:
            self.name = name
        self.history: list[dict] = []
        self.stats: dict[str, Any] = {}

    def reset(self, ctx: EpisodeContext) -> None:  # noqa: D401
        """Per-episode lifecycle hook. Default: no-op."""

    def act(self, observation: dict, Command: Any) -> list:
        """Produce this turn's commands; subclasses must override.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement act()"
        )
110
+
111
+
112
class FunctionController(BaseController):
    """Adapt a bare ``agent_fn(render_state, Command) -> [Command]``
    callable into a Controller — the back-compat bridge for every
    scripted reference policy and the legacy `scripted_explore_agent`.

    When the callable is a bound method (e.g. ``ModelAgent.agent_fn``),
    its ``__self__`` is captured as `source` so the eval loop can still
    reach the underlying object's `history` / `stats` for playback.
    """

    def __init__(
        self, fn: PolicyFn, name: str | None = None
    ) -> None:
        """Wrap `fn`.

        Args:
            fn: The policy callable to delegate to.
            name: Optional display name; defaults to `fn.__name__`, or
                "fn" when the callable has no usable `__name__`.

        Raises:
            TypeError: if `fn` is not callable. Failing here gives a
                clear message instead of a confusing error on the
                first `act()` call.
        """
        if not callable(fn):
            raise TypeError(
                "FunctionController needs a callable policy, got "
                f"{type(fn).__name__}"
            )
        super().__init__(
            name or getattr(fn, "__name__", None) or "fn"
        )
        self._fn = fn
        # None for plain functions; the bound instance for bound methods.
        self.source: Any = getattr(fn, "__self__", None)

    def act(self, observation: dict, Command: Any) -> list:
        """Delegate verbatim to the wrapped callable."""
        return self._fn(observation, Command)
132
+
133
+
134
def as_controller(policy: Any, name: str | None = None) -> Controller:
    """Coerce anything policy-shaped into a Controller.

    Accepts, in priority order:
      * an object already satisfying the Controller contract — returned
        as-is (idempotent); `name` is not applied in this case;
      * any callable — a bare `agent_fn` or a bound method — wrapped in
        a `FunctionController` (a bound method's `__self__` is kept
        reachable via `.source`).

    Args:
        policy: The Controller or callable to coerce.
        name: Optional name for the wrapper created around a callable.

    Returns:
        `policy` itself, or a new `FunctionController` around it.

    Raises:
        TypeError: for anything that is neither a Controller nor
            callable.
    """
    if is_controller(policy):
        return policy
    if callable(policy):
        return FunctionController(policy, name)
    raise TypeError(
        f"cannot coerce {type(policy).__name__} into a Controller: "
        "expected a Controller, a ModelAgent, or an "
        "agent_fn(render_state, Command) -> [Command] callable"
    )
154
+
155
+
156
def introspection_source(controller: Controller) -> Any:
    """Return the object carrying `history` / `stats` for playback.

    For a `FunctionController` wrapping a bound method this is the bound
    instance (`.source`); otherwise (no `source` attribute, or one that
    is None) it is the Controller itself.
    """
    src = getattr(controller, "source", None)
    return src if src is not None else controller
openra_bench/eval_core.py CHANGED
@@ -20,11 +20,20 @@ from typing import Any, Callable
20
  import yaml
21
  from openra_rl_training.training.rust_env_pool import RustEnvPool
22
 
 
 
 
 
 
 
23
  from .rust_adapter import EpisodeSignals, RustObsAdapter
24
  from .scenarios.schema import CompiledLevel
25
  from .scenarios.win_conditions import WinContext, evaluate
26
 
 
 
27
  AgentFn = Callable[[dict, Any], list]
 
28
 
29
 
30
  def _scenario_to_tmp_yaml(compiled: CompiledLevel) -> str:
@@ -136,11 +145,14 @@ def scripted_explore_agent(render_state: dict, Command: Any) -> list:
136
 
137
  def run_episode(
138
  scenario_path: str,
139
- agent_fn: AgentFn = scripted_explore_agent,
140
  max_turns: int = 40,
141
  seed: int = 0,
142
  pool: RustEnvPool | None = None,
143
  ) -> EpisodeResult:
 
 
 
144
  owns_pool = pool is None
145
  if pool is None:
146
  pool = RustEnvPool(size=1, scenario_path=scenario_path)
@@ -149,12 +161,16 @@ def run_episode(
149
  adapter = RustObsAdapter()
150
  obs = env.reset(seed=seed)
151
  adapter.observe(obs)
 
 
 
 
152
  trace: list[dict] = []
153
  turns = 0
154
  issued = warned = 0
155
  for turns in range(1, max_turns + 1):
156
  rs = adapter.render_state()
157
- cmds = agent_fn(rs, env.Command) or [env.Command.observe()]
158
  obs, _reward, done, info = env.step(cmds)
159
  adapter.observe(obs, done=done)
160
  issued += len(cmds)
@@ -188,13 +204,17 @@ def run_episode(
188
 
189
  def run_level(
190
  compiled: CompiledLevel,
191
- agent_fn: AgentFn = scripted_explore_agent,
192
  seed: int = 0,
193
  playback=None,
194
  ) -> EpisodeResult:
195
  """Run one scenario-pack level, scoring against its declarative
196
  win/fail conditions (checked every turn). Outcome maps to the
197
  `reward_outcome` convention: win=1.0, draw=0.5, loss=0.0.
 
 
 
 
198
  """
199
  if not compiled.map_supported:
200
  raise RuntimeError(
@@ -207,6 +227,19 @@ def run_level(
207
  try:
208
  adapter = RustObsAdapter()
209
  adapter.observe(env.reset(seed=seed))
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  trace: list[dict] = []
211
  outcome = "draw"
212
  turns = 0
@@ -242,7 +275,7 @@ def run_level(
242
  forbidden = {str(t).lower() for t in (compiled.forbidden_tools or [])}
243
  for turns in range(1, compiled.max_turns + 1):
244
  rs = adapter.render_state()
245
- cmds = agent_fn(rs, env.Command) or [env.Command.observe()]
246
  for _cmd in cmds:
247
  _tn = _cmd_tool_name(_cmd)
248
  if _tn:
@@ -364,8 +397,9 @@ def run_level(
364
  )
365
  if playback is not None:
366
  # Dump the full model⇄env transcript when the agent is a
367
- # ModelAgent (bound-method closure exposes the instance).
368
- agent_obj = getattr(agent_fn, "__self__", None)
 
369
  hist = getattr(agent_obj, "history", None)
370
  if isinstance(hist, list):
371
  playback.write_messages(hist)
 
20
  import yaml
21
  from openra_rl_training.training.rust_env_pool import RustEnvPool
22
 
23
+ from .controller import (
24
+ Controller,
25
+ EpisodeContext,
26
+ as_controller,
27
+ introspection_source,
28
+ )
29
  from .rust_adapter import EpisodeSignals, RustObsAdapter
30
  from .scenarios.schema import CompiledLevel
31
  from .scenarios.win_conditions import WinContext, evaluate
32
 
33
+ # A policy is either a bare `agent_fn(render_state, Command) -> [Command]`
34
+ # callable (the legacy shape, still accepted everywhere) or a Controller.
35
  AgentFn = Callable[[dict, Any], list]
36
+ Policy = "AgentFn | Controller"
37
 
38
 
39
  def _scenario_to_tmp_yaml(compiled: CompiledLevel) -> str:
 
145
 
146
  def run_episode(
147
  scenario_path: str,
148
+ agent_fn: "AgentFn | Controller" = scripted_explore_agent,
149
  max_turns: int = 40,
150
  seed: int = 0,
151
  pool: RustEnvPool | None = None,
152
  ) -> EpisodeResult:
153
+ """Run a scenario for a fixed number of turns. `agent_fn` may be a
154
+ bare `agent_fn(render_state, Command) -> [Command]` callable or any
155
+ `Controller`; it is coerced through `as_controller()`."""
156
  owns_pool = pool is None
157
  if pool is None:
158
  pool = RustEnvPool(size=1, scenario_path=scenario_path)
 
161
  adapter = RustObsAdapter()
162
  obs = env.reset(seed=seed)
163
  adapter.observe(obs)
164
+ controller = as_controller(agent_fn)
165
+ controller.reset(
166
+ EpisodeContext(seed=seed, max_turns=max_turns)
167
+ )
168
  trace: list[dict] = []
169
  turns = 0
170
  issued = warned = 0
171
  for turns in range(1, max_turns + 1):
172
  rs = adapter.render_state()
173
+ cmds = controller.act(rs, env.Command) or [env.Command.observe()]
174
  obs, _reward, done, info = env.step(cmds)
175
  adapter.observe(obs, done=done)
176
  issued += len(cmds)
 
204
 
205
  def run_level(
206
  compiled: CompiledLevel,
207
+ agent_fn: "AgentFn | Controller" = scripted_explore_agent,
208
  seed: int = 0,
209
  playback=None,
210
  ) -> EpisodeResult:
211
  """Run one scenario-pack level, scoring against its declarative
212
  win/fail conditions (checked every turn). Outcome maps to the
213
  `reward_outcome` convention: win=1.0, draw=0.5, loss=0.0.
214
+
215
+ `agent_fn` may be a bare `agent_fn(render_state, Command) ->
216
+ [Command]` callable, a `ModelAgent` bound method, or any
217
+ `Controller`; it is coerced through `as_controller()`.
218
  """
219
  if not compiled.map_supported:
220
  raise RuntimeError(
 
227
  try:
228
  adapter = RustObsAdapter()
229
  adapter.observe(env.reset(seed=seed))
230
+ # Coerce the policy through the unified Controller contract:
231
+ # a bare agent_fn, a ModelAgent bound method, or a Controller
232
+ # all resolve to a Controller the loop drives identically.
233
+ controller = as_controller(agent_fn)
234
+ controller.reset(
235
+ EpisodeContext(
236
+ pack_id=compiled.pack_id,
237
+ level=compiled.level,
238
+ seed=seed,
239
+ objective=compiled.scenario.description or "",
240
+ max_turns=compiled.max_turns,
241
+ )
242
+ )
243
  trace: list[dict] = []
244
  outcome = "draw"
245
  turns = 0
 
275
  forbidden = {str(t).lower() for t in (compiled.forbidden_tools or [])}
276
  for turns in range(1, compiled.max_turns + 1):
277
  rs = adapter.render_state()
278
+ cmds = controller.act(rs, env.Command) or [env.Command.observe()]
279
  for _cmd in cmds:
280
  _tn = _cmd_tool_name(_cmd)
281
  if _tn:
 
397
  )
398
  if playback is not None:
399
  # Dump the full model⇄env transcript when the agent is a
400
+ # ModelAgent — the Controller layer surfaces the underlying
401
+ # instance (bound-method __self__ or the Controller itself).
402
+ agent_obj = introspection_source(controller)
403
  hist = getattr(agent_obj, "history", None)
404
  if isinstance(hist, list):
405
  playback.write_messages(hist)
tests/test_controller.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Phase 1 — the unified Controller contract.
2
+
3
+ `openra_bench/controller.py` is the keystone of the human-labeling
4
+ machine and the 1v1 adversarial harness: LLM agents, human labelers and
5
+ scripted reference policies all implement one interface,
6
+
7
+ controller.act(observation, Command) -> [Command]
8
+
9
+ and `run_level` / `run_episode` drive any of them. This file pins:
10
+
11
+ * the coercion layer (`as_controller`) — a bare `agent_fn`, a bound
12
+ method, and an existing Controller all resolve correctly, so the ~190
13
+ legacy test files that pass a bare function keep working;
14
+ * the introspection surface (`history` / `stats`) the playback writer
15
+ reads survives the coercion;
16
+ * `ModelAgent` structurally satisfies the contract;
17
+ * an end-to-end `run_level` smoke: the SAME scripted policy produces a
18
+ byte-identical outcome whether passed as a bare function or wrapped
19
+ in a Controller.
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import pytest
25
+
26
+ from openra_bench.controller import (
27
+ BaseController,
28
+ Controller,
29
+ EpisodeContext,
30
+ FunctionController,
31
+ as_controller,
32
+ introspection_source,
33
+ is_controller,
34
+ )
35
+
36
+
37
+ # ── Coercion (no engine needed) ─────────────────────────────────────
38
+
39
+
40
def _bare_policy(render_state, Command):
    """A legacy-shape agent_fn: ignore the world, just observe."""
    return [("OBSERVE", id(render_state))]
43
+
44
+
45
def test_as_controller_wraps_a_bare_function():
    """A bare function is wrapped in a FunctionController that delegates."""
    c = as_controller(_bare_policy)
    assert is_controller(c)
    assert isinstance(c, FunctionController)
    # The wrapper delegates verbatim.
    out = c.act({"k": 1}, Command=None)
    assert out[0][0] == "OBSERVE"
    # Name defaults to the function's __name__.
    assert c.name == "_bare_policy"
54
+
55
+
56
def test_as_controller_is_idempotent_on_a_controller():
    """Coercing an existing Controller hands back the same object."""
    c1 = as_controller(_bare_policy)
    c2 = as_controller(c1)
    assert c2 is c1, "coercing a Controller must return it unchanged"
60
+
61
+
62
def test_as_controller_rejects_non_callable():
    """Values that are neither Controllers nor callables raise TypeError."""
    with pytest.raises(TypeError):
        as_controller(42)
    with pytest.raises(TypeError):
        as_controller("not a policy")
67
+
68
+
69
def test_as_controller_named_override():
    """An explicit `name` wins over the function's __name__."""
    c = as_controller(_bare_policy, name="custom")
    assert c.name == "custom"
72
+
73
+
74
def test_is_controller_discriminates():
    """`is_controller` tells bare callables apart from Controllers."""
    # A bare function is callable but NOT a Controller.
    assert not is_controller(_bare_policy)
    assert not is_controller(lambda rs, C: [])
    # A FunctionController is.
    assert is_controller(as_controller(_bare_policy))
80
+
81
+
82
def test_base_controller_act_is_abstract():
    """BaseController supplies defaults but refuses to act."""
    b = BaseController(name="x")
    assert b.name == "x"
    assert b.history == [] and b.stats == {}
    with pytest.raises(NotImplementedError):
        b.act({}, Command=None)
88
+
89
+
90
def test_episode_context_defaults():
    """EpisodeContext defaults and keyword construction."""
    ctx = EpisodeContext()
    assert ctx.side == "agent"
    assert ctx.seed == 0 and ctx.max_turns == 0
    assert ctx.extra == {}
    ctx2 = EpisodeContext(pack_id="p", level="hard", side="enemy", seed=3)
    assert (ctx2.pack_id, ctx2.level, ctx2.side, ctx2.seed) == (
        "p", "hard", "enemy", 3
    )
99
+
100
+
101
+ # ── Bound-method source recovery (the playback path) ────────────────
102
+
103
+
104
class _FakeAgent:
    """Stand-in for ModelAgent: a bound `agent_fn` plus history/stats."""

    def __init__(self):
        self.history = [{"role": "system", "content": "hi"}]
        self.stats = {"turns": 0}

    def agent_fn(self, render_state, Command):
        # Count the call so tests can see the wrapper reached this object.
        self.stats["turns"] += 1
        return [("ACT", self.stats["turns"])]
114
+
115
+
116
def test_bound_method_source_is_recovered_for_playback():
    """Wrapping a bound method keeps its instance reachable."""
    agent = _FakeAgent()
    c = as_controller(agent.agent_fn)
    assert is_controller(c)
    # The bound instance is reachable so playback can dump history/stats.
    assert c.source is agent
    assert introspection_source(c) is agent
    c.act({}, None)
    assert agent.stats["turns"] == 1
125
+
126
+
127
def test_introspection_source_falls_back_to_controller():
    """With no bound instance, the Controller itself is the source."""
    c = as_controller(_bare_policy)  # plain function, no __self__
    assert c.source is None
    assert introspection_source(c) is c
131
+
132
+
133
+ # ── Subclassing BaseController ──────────────────────────────────────
134
+
135
+
136
class _CountingController(BaseController):
    """BaseController subclass that counts `reset` and `act` calls."""

    def __init__(self):
        super().__init__(name="counter")
        self.turns = 0
        self.reset_calls = 0

    def reset(self, ctx: EpisodeContext) -> None:
        self.reset_calls += 1
        self.last_ctx = ctx

    def act(self, observation, Command):
        self.turns += 1
        return [("TURN", self.turns)]
149
+
150
+
151
def test_subclassed_controller_satisfies_contract():
    """A BaseController subclass passes both structural checks."""
    c = _CountingController()
    assert is_controller(c)
    assert isinstance(c, Controller)  # runtime_checkable structural
    c.reset(EpisodeContext(pack_id="p", side="enemy"))
    assert c.reset_calls == 1 and c.last_ctx.side == "enemy"
    assert c.act({}, None) == [("TURN", 1)]
    assert as_controller(c) is c
159
+
160
+
161
+ # ── ModelAgent structurally conforms ────────────────────────────────
162
+
163
+
164
def test_model_agent_class_exposes_controller_contract():
    """ModelAgent exposes the Controller methods at class level."""
    # Structural check on the class — constructing a ModelAgent needs a
    # provider, but the contract is method presence.
    from openra_bench.agent import ModelAgent

    for member in ("act", "reset", "agent_fn"):
        assert callable(getattr(ModelAgent, member, None)), (
            f"ModelAgent must expose {member}() for the Controller contract"
        )
173
+
174
+
175
+ # ── End-to-end: bare fn vs Controller produce identical runs ────────
176
+
177
def _require_engine():
    """Skip the calling test when the Rust env wheel is not installed.

    This is called from inside each engine-backed test instead of at
    module level. A module-level `pytest.importorskip` raises during
    collection and skips the WHOLE module, which would also skip the
    engine-free coercion tests defined earlier in this file.
    """
    pytest.importorskip("openra_train", reason="Rust env wheel not installed")
    pytest.importorskip(
        "openra_rl_training", reason="Rust env wheel not installed"
    )


def _stall(render_state, Command):
    """Scripted policy that only ever observes."""
    return [Command.observe()]


def _smallest_easy_pack():
    """The active pack with the fewest easy-level turns — keeps the
    end-to-end smoke fast and deterministic.

    Returns the compiled "easy" level, or None when no pack qualifies.
    """
    from openra_bench.scenarios import load_pack
    from openra_bench.scenarios.loader import PACKS_DIR, compile_level

    best = None
    for f in sorted(PACKS_DIR.glob("*.yaml")):
        if f.name.startswith(("_", "TEMPLATE")):
            continue
        try:
            pack = load_pack(f)
            if pack.meta.status != "active" or "easy" not in pack.levels:
                continue
            c = compile_level(pack, "easy")
        except Exception:  # noqa: BLE001
            # Best-effort scan: a pack that fails to load or compile is
            # simply not a candidate.
            continue
        if not c.map_supported:
            continue
        if best is None or c.max_turns < best.max_turns:
            best = c
    return best


def test_run_level_identical_for_bare_fn_and_controller():
    """The SAME scripted policy must yield the same outcome, turn count,
    action count and game tick whether passed as a bare agent_fn or
    wrapped in a Controller — proof the coercion layer is transparent."""
    _require_engine()
    from openra_bench.eval_core import run_level

    compiled = _smallest_easy_pack()
    assert compiled is not None, "no runnable active pack found"

    r_fn = run_level(compiled, _stall, seed=1)
    r_ctrl = run_level(compiled, as_controller(_stall, name="stall"), seed=1)

    assert r_fn.outcome == r_ctrl.outcome
    assert r_fn.turns == r_ctrl.turns
    assert r_fn.actions_issued == r_ctrl.actions_issued
    assert r_fn.signals.game_tick == r_ctrl.signals.game_tick


def test_run_level_drives_a_subclassed_controller():
    """A BaseController subclass — the shape HumanController and the
    scripted-bot wrapper will take — runs end-to-end and its per-episode
    `reset` hook fires with a populated EpisodeContext."""
    _require_engine()
    from openra_bench.eval_core import run_level

    compiled = _smallest_easy_pack()
    assert compiled is not None

    class _StallController(BaseController):
        def __init__(self):
            super().__init__(name="stall-ctrl")
            self.acts = 0
            self.ctx = None

        def reset(self, ctx: EpisodeContext) -> None:
            self.ctx = ctx

        def act(self, observation, Command):
            self.acts += 1
            return [Command.observe()]

    ctrl = _StallController()
    res = run_level(compiled, ctrl, seed=1)
    assert res.outcome in ("win", "loss", "draw")
    assert ctrl.acts >= 1, "act() must have been called"
    assert ctrl.ctx is not None and ctrl.ctx.pack_id == compiled.pack_id
    assert ctrl.ctx.level == "easy" and ctrl.ctx.seed == 1