Spaces:
Running
Running
File size: 6,155 Bytes
"""Unified policy interface for the OpenRA-Bench eval stack.

Every actor that can drive a side of a scenario — an LLM agent, a human
labeler, a scripted reference policy — implements the same contract:

    controller.act(observation, Command) -> list[Command]

This is the keystone of the human-labeling machine and the 1v1
adversarial harness: one harness, interchangeable policy backends.
`run_level` / `run_episode` drive a single Controller; a 1v1 match
drives two, one per side, each fed its own side-specific observation.
Back-compat is non-negotiable: the historical policy shape was a bare
callable ``agent_fn(render_state, Command) -> [Command]`` and ~190 test
files still pass one. `as_controller()` adapts any such callable (or a
`ModelAgent` bound method) into a Controller, so every existing scripted
policy and test keeps working unchanged — the eval loop simply coerces
its policy argument through `as_controller()` before stepping.
Design notes
------------

* `act` keeps `Command` as an explicit parameter rather than binding it
  at construction. `Command` is the pyo3 `openra_train.Command` factory
  handle, only available once an env exists; threading it per-call keeps
  Controllers constructible without an engine (cheap to unit-test) and
  matches the legacy `agent_fn` signature exactly.
* `reset(ctx)` is the per-episode lifecycle hook. Scripted policies
  ignore it; the model agent re-arms history; a human controller would
  reset its click queue. The 1v1 harness calls it once per side with a
  `side`-stamped `EpisodeContext`.
* `history` / `stats` are the optional introspection surface the
  playback writer reads. `BaseController` provides empty defaults so a
  caller can read them unconditionally.
"""
from __future__ import annotations

import types
from dataclasses import dataclass, field
from typing import Any, Callable, Protocol, runtime_checkable
# A bare legacy policy: a callable invoked as (render_state, Command) that
# returns a list of engine Commands. `Command` is typed `Any` because it is
# the engine-provided factory handle, passed through on every call.
PolicyFn = Callable[[dict, Any], list]
@dataclass
class EpisodeContext:
    """Per-episode briefing handed to a Controller through `reset`.

    Scenario evals fill in `pack_id`, `level`, `seed` and `objective`;
    the 1v1 harness additionally sets `side` so each of the two
    Controllers knows which side of the match it is driving. Every field
    has a default, so `EpisodeContext()` is a valid, empty context.
    """

    # Which scenario (pack + level) and seed this episode runs.
    pack_id: str = ""
    level: str = ""
    seed: int = 0
    # Which side this Controller drives: "agent" or "enemy".
    side: str = "agent"
    objective: str = ""
    max_turns: int = 0
    # Free-form extras; each instance gets its own fresh dict.
    extra: dict[str, Any] = field(default_factory=dict)
@runtime_checkable
class Controller(Protocol):
    """Structural contract for anything that drives one side of a scenario.

    A conforming object exposes a `name`, a per-episode `reset(ctx)` hook,
    and `act(observation, Command)`, which returns the engine Commands to
    issue. Being a `Protocol`, no inheritance is required: any object with
    these members matches, which is how `ModelAgent` satisfies it without
    importing this module.
    """

    name: str

    def act(self, observation: dict, Command: Any) -> list: ...

    def reset(self, ctx: "EpisodeContext") -> None: ...
def is_controller(obj: Any) -> bool:
    """Return True when `obj` already honours the Controller contract.

    The check is purely structural: `obj` must expose a callable `act`
    and a callable `reset`. A plain function is callable itself but has
    neither attribute, so it is rejected here and is left for
    `as_controller` to wrap. Unlike `isinstance(obj, Controller)`, this
    does not look for a `name` attribute, but it does require both
    methods to be callable rather than merely present.
    """
    required = ("act", "reset")
    # Checked in order with short-circuiting: `reset` is only looked up
    # once `act` has been found callable.
    return all(callable(getattr(obj, attr, None)) for attr in required)
class BaseController:
    """Ready-made Controller skeleton: subclass it and implement `act`.

    It supplies a default `name`, a `reset` that does nothing, and the
    introspection attributes `history` (a list of dicts) and `stats`
    (a dict). Both start empty, so callers can read them without checking
    first. The human bridge and the scripted reference wrappers derive
    from this class so they share that introspection surface.
    """

    name: str = "controller"

    def __init__(self, name: str | None = None) -> None:
        # A falsy name (None or "") keeps the class-level default.
        if name:
            self.name = name
        self.history: list[dict] = []
        self.stats: dict[str, Any] = {}

    def reset(self, ctx: EpisodeContext) -> None:
        """Prepare for a new episode; the base implementation does nothing."""

    def act(self, observation: dict, Command: Any) -> list:
        """Return the Commands for `observation`; subclasses must override."""
        cls_name = type(self).__name__
        raise NotImplementedError(f"{cls_name} must implement act()")
class FunctionController(BaseController):
    """Wrap a bare ``agent_fn(render_state, Command) -> [Command]``
    callable as a Controller.

    This is the back-compat bridge for every scripted reference policy and
    the legacy `scripted_explore_agent`. When the callable is a bound
    method (e.g. ``ModelAgent.agent_fn``), its ``__self__`` is kept as
    `source` so the eval loop can still reach the underlying object's
    `history` / `stats` for playback. Otherwise `source` is None.
    """

    def __init__(
        self, fn: PolicyFn, name: str | None = None
    ) -> None:
        # fn:   the legacy policy, called as fn(observation, Command).
        # name: display name; falls back to fn.__name__, then "fn".
        super().__init__(
            name or getattr(fn, "__name__", None) or "fn"
        )
        self._fn = fn
        src = getattr(fn, "__self__", None)
        # Module-level built-in / C-extension functions expose their
        # defining *module* as `__self__` (e.g. `len.__self__` is the
        # `builtins` module). A module is not a policy instance and carries
        # no `history` / `stats`, so it must not be treated as the source.
        if isinstance(src, types.ModuleType):
            src = None
        self.source: Any = src

    def act(self, observation: dict, Command: Any) -> list:
        """Delegate to the wrapped callable and return its result unchanged."""
        return self._fn(observation, Command)
def as_controller(policy: Any, name: str | None = None) -> Controller:
    """Normalise anything policy-shaped into a Controller.

    Resolution order:

    * `policy` already satisfies the Controller contract (per
      `is_controller`): it is returned unchanged, which makes the call
      idempotent. `name` is ignored in this case.
    * `policy` is any other callable (a bare `agent_fn` or a bound
      method): it is wrapped in a `FunctionController`, optionally named
      `name`, which keeps a bound method's instance reachable via
      `.source`.

    Raises:
        TypeError: if `policy` is neither a Controller nor callable.
    """
    if is_controller(policy):
        coerced = policy
    elif callable(policy):
        coerced = FunctionController(policy, name)
    else:
        raise TypeError(
            f"cannot coerce {type(policy).__name__} into a Controller: "
            "expected a Controller, a ModelAgent, or an "
            "agent_fn(render_state, Command) -> [Command] callable"
        )
    return coerced
def introspection_source(controller: Controller) -> Any:
    """Return the object whose `history` / `stats` playback should read.

    A controller carrying a non-None `source` attribute (such as a
    `FunctionController` wrapping a bound method) defers to that object.
    Any other controller is its own source.
    """
    try:
        src = controller.source
    except AttributeError:
        return controller
    return controller if src is None else src
|