"""Unified policy interface for the OpenRA-Bench eval stack.

Every actor that can drive a side of a scenario (an LLM agent, a human
labeler, or a scripted reference policy) implements the same contract:

    controller.act(observation, Command) -> list[Command]

This is the keystone of the human-labeling machine and the 1v1
adversarial harness: one harness, interchangeable policy backends.
`run_level` / `run_episode` drive a single Controller; a 1v1 match
drives two, one per side, each fed its own side-specific observation.
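A minimal sketch of one tick of such a 1v1 loop follows. It assumes a
hypothetical `fake_command` stand-in for the pyo3 `Command` factory (the
real factory's call shape is not shown in this module), a hypothetical
`HoldPosition` scripted policy, and a hypothetical observation schema with
a `units` list; `SimpleNamespace` stands in for `EpisodeContext`. It
illustrates the contract only, not how `run_level` / `run_episode` are
actually implemented.

```python
from types import SimpleNamespace


def fake_command(kind, **kwargs):
    # Hypothetical stand-in for the pyo3 `Command` factory handle.
    return {"kind": kind, **kwargs}


class HoldPosition:
    # Hypothetical scripted policy: one "stop" order per visible own unit.
    name = "hold"

    def reset(self, ctx):
        # Remember which side this instance drives for the episode.
        self.side = ctx.side

    def act(self, observation, Command):
        return [Command("stop", unit=u) for u in observation["units"]]


def step_match(controllers, observations, Command):
    # One tick of a hypothetical 1v1 harness: each side's Controller sees
    # only its own observation and returns only its own Commands.
    return {
        side: ctrl.act(observations[side], Command)
        for side, ctrl in controllers.items()
    }


agent, enemy = HoldPosition(), HoldPosition()
controllers = {"agent": agent, "enemy": enemy}
for side, ctrl in controllers.items():
    # The harness resets each side with a side-stamped context.
    ctrl.reset(SimpleNamespace(side=side, seed=7))

observations = {"agent": {"units": [1, 2]}, "enemy": {"units": [9]}}
orders = step_match(controllers, observations, fake_command)
print(orders["enemy"])  # [{'kind': 'stop', 'unit': 9}]
```

Because each Controller only ever receives `observations[side]`, swapping
the `enemy` backend for a human or model policy leaves the loop unchanged.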

Back-compat is non-negotiable: the historical policy shape was a bare
callable ``agent_fn(render_state, Command) -> [Command]`` and ~190 test
files still pass one. `as_controller()` adapts any such callable (or a
`ModelAgent` bound method) into a Controller, so every existing scripted
policy and test keeps working unchanged: the eval loop simply coerces
its policy argument through `as_controller()` before stepping.
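The adaptation rests on two plain Python facts: a bare function carries
no `act` / `reset` attributes, and a bound method exposes its instance as
`__self__`. The self-contained sketch below demonstrates both. The names
`ToyModelAgent`, `scripted_idle`, and `looks_like_controller` are
hypothetical; the last one restates the structural check that
`is_controller` performs rather than importing it.

```python
class ToyModelAgent:
    # Hypothetical stand-in for ModelAgent: a policy method plus history.
    def __init__(self):
        self.history = []

    def agent_fn(self, render_state, Command):
        self.history.append(render_state)
        return []


def scripted_idle(render_state, Command):
    # Hypothetical legacy policy in the bare agent_fn shape.
    return []


def looks_like_controller(obj):
    # Same structural test the module applies: callable act AND reset.
    return callable(getattr(obj, "act", None)) and callable(
        getattr(obj, "reset", None)
    )


agent = ToyModelAgent()
bound = agent.agent_fn

# Neither legacy shape is a Controller, so both would get wrapped.
print(looks_like_controller(scripted_idle), looks_like_controller(bound))  # False False

# A plain function has no __self__; a bound method keeps its instance there,
# which is how playback can still reach the agent's history after wrapping.
print(getattr(scripted_idle, "__self__", None))  # None
bound({"tick": 1}, None)
print(getattr(bound, "__self__", None).history)  # [{'tick': 1}]
```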

Design notes
------------
* `act` keeps `Command` as an explicit parameter rather than binding it
  at construction. `Command` is the pyo3 `openra_train.Command` factory
  handle, only available once an env exists; threading it per call keeps
  Controllers constructible without an engine (cheap to unit-test) and
  keeps `act` signature-identical to the legacy `agent_fn`.
* `reset(ctx)` is the per-episode lifecycle hook. Scripted policies
  ignore it; the model agent re-arms history; a human controller would
  reset its click queue. The 1v1 harness calls it once per side with a
  `side`-stamped `EpisodeContext`.
* `history` / `stats` are the optional introspection surface the
  playback writer reads. `BaseController` provides empty defaults so a
  caller can read them unconditionally.
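Because `Command` arrives per call, a Controller can be exercised with no
engine at all. A minimal sketch, assuming a hypothetical
`RecordingCommand` stub in place of the pyo3 factory and a hypothetical
`CountingController` policy; it mirrors the `BaseController` surface
(`name`, `history`, `stats`, `reset`) instead of importing it, and
`SimpleNamespace` stands in for `EpisodeContext`.

```python
from types import SimpleNamespace


class RecordingCommand:
    # Hypothetical engine-free stand-in for openra_train.Command: each call
    # just records its arguments so a test can inspect them afterwards.
    def __init__(self):
        self.calls = []

    def __call__(self, *args, **kwargs):
        self.calls.append((args, kwargs))
        return (args, kwargs)


class CountingController:
    # Hypothetical policy exposing BaseController-style introspection.
    name = "counting"

    def __init__(self):
        self.history = []
        self.stats = {}

    def reset(self, ctx):
        # Per-episode lifecycle hook: forget the previous episode entirely.
        self.history = []
        self.stats = {"side": ctx.side, "turns": 0}

    def act(self, observation, Command):
        self.stats["turns"] += 1
        self.history.append(observation)
        return [Command("noop", turn=self.stats["turns"])]


ctrl, Command = CountingController(), RecordingCommand()
ctrl.reset(SimpleNamespace(side="enemy"))
ctrl.act({"tick": 0}, Command)
ctrl.act({"tick": 1}, Command)
print(ctrl.stats, len(Command.calls))  # {'side': 'enemy', 'turns': 2} 2

ctrl.reset(SimpleNamespace(side="agent"))  # new episode: history re-armed
print(ctrl.history)  # []
```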
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Protocol, runtime_checkable

# A bare legacy policy: (render_state, Command) -> [Command].
PolicyFn = Callable[[dict, Any], list]


@dataclass
class EpisodeContext:
    """What a Controller is told once, at episode start (`reset`).

    A scenario eval populates `pack_id` / `level` / `seed` / `objective`;
    a 1v1 match additionally stamps `side` so the two Controllers know
    which colour they are driving."""

    pack_id: str = ""
    level: str = ""
    seed: int = 0
    side: str = "agent"  # "agent" | "enemy": which side this drives
    objective: str = ""
    max_turns: int = 0
    extra: dict = field(default_factory=dict)


@runtime_checkable
class Controller(Protocol):
    """A policy that observes the world and emits engine Commands.

    Structural: anything exposing `name`, `reset`, and `act` satisfies
    it; `ModelAgent` does so without importing this module."""

    name: str

    def reset(self, ctx: "EpisodeContext") -> None: ...

    def act(self, observation: dict, Command: Any) -> list: ...


def is_controller(obj: Any) -> bool:
    """True if `obj` already satisfies the Controller contract.

    Deliberately structural, and different from `isinstance(obj,
    Controller)` in two ways: it requires `act` and `reset` to be
    callable rather than merely present, and it does not require `name`.
    A bare function is callable but is NOT a Controller, because a plain
    function never carries `act` or `reset` attributes."""
    return callable(getattr(obj, "act", None)) and callable(
        getattr(obj, "reset", None)
    )


class BaseController:
    """Convenience base: a no-op `reset`, a `name`, empty introspection.

    Subclass and implement `act`. Concrete eval policies (the human
    bridge, scripted reference wrappers) derive from this so they share
    one introspection surface (`history`, `stats`)."""

    name: str = "controller"

    def __init__(self, name: str | None = None) -> None:
        if name:
            self.name = name
        self.history: list[dict] = []
        self.stats: dict[str, Any] = {}

    def reset(self, ctx: EpisodeContext) -> None:  # noqa: D401
        """Per-episode lifecycle hook. Default: no-op."""

    def act(self, observation: dict, Command: Any) -> list:
        raise NotImplementedError(
            f"{type(self).__name__} must implement act()"
        )


class FunctionController(BaseController):
    """Adapt a bare ``agent_fn(render_state, Command) -> [Command]``
    callable into a Controller. This is the back-compat bridge for every
    scripted reference policy and the legacy `scripted_explore_agent`.

    When the callable is a bound method (e.g. ``ModelAgent.agent_fn``),
    its ``__self__`` is captured as `source` so the eval loop can still
    reach the underlying object's `history` / `stats` for playback."""

    def __init__(
        self, fn: PolicyFn, name: str | None = None
    ) -> None:
        super().__init__(
            name or getattr(fn, "__name__", None) or "fn"
        )
        self._fn = fn
        self.source: Any = getattr(fn, "__self__", None)

    def act(self, observation: dict, Command: Any) -> list:
        return self._fn(observation, Command)


def as_controller(policy: Any, name: str | None = None) -> Controller:
    """Coerce anything policy-shaped into a Controller.

    Accepts, in priority order:
      * an object already satisfying the Controller contract, returned
        as-is (idempotent);
      * any callable โ€” a bare `agent_fn` or a bound method โ€” wrapped in
        a `FunctionController` (a bound method's `__self__` is kept
        reachable via `.source`).

    Raises `TypeError` for anything else."""
    if is_controller(policy):
        return policy
    if callable(policy):
        return FunctionController(policy, name)
    raise TypeError(
        f"cannot coerce {type(policy).__name__} into a Controller: "
        "expected a Controller, a ModelAgent, or an "
        "agent_fn(render_state, Command) -> [Command] callable"
    )


def introspection_source(controller: Controller) -> Any:
    """The object carrying `history` / `stats` for playback.

    For a `FunctionController` wrapping a bound method this is the bound
    instance (`.source`); otherwise it is the Controller itself."""
    src = getattr(controller, "source", None)
    return src if src is not None else controller