Spaces:
Sleeping
Sleeping
feat(discovery): source available actions from scenario.action_set (interact reaches the agent)
"""SessionRunner — the PROTEUS arena orchestrator.

One session: build the scenario+game from (scenario, seed, difficulty), replay
the scripted Cut pre-roll, hand control to the agent for up to ``play_turns``
turns (observation -> optional probe -> act -> apply -> record), and emit a
scored SessionTrace. The per-session mechanics are shared with InteractiveSession
via ``proteus.game.runtime._session_core`` (the single source of truth). This
includes the CP7 memory pre-roll (prepended to the turn-1 observation) and the
CP8 persona reference/regret bookkeeping, both threaded through the shared helpers.
"""
| from __future__ import annotations | |
| from proteus.game.agents.base import Agent | |
| from proteus.game.engine.difficulty import Difficulty | |
| from proteus.game.runtime import _session_core as core | |
| from proteus.game.runtime.memory import MemoryCheckpoint | |
| from proteus.game.metrics.persona import PersonaWeights | |
| from proteus.game.runtime.trace import SessionTrace, TurnTrace | |
class SessionRunner:
    """Play one motive_grid session from build to scored trace.

    ``run()`` builds the scenario and game, replays the scripted Cut, lets
    the agent play up to ``play_turns`` turns, and returns the resulting
    :class:`SessionTrace`.

    Args:
        scenario_name: Name of a registered scenario (e.g. ``"template"``).
        agent: The :class:`Agent` that takes over after the Cut.
        difficulty: Difficulty band (controls Cut length).
        seed: Seed for the deterministic world/Cut.
        play_turns: Survival budget, counted in played (post-Cut) turns.
        use_probe: If true, ask the agent the side-channel probe question
            before every action.
        motive_category: Category label stored on the trace.
        memory: CP7 memory checkpoint shown at the handover (turn-1
            observation). When omitted, the session's default memory is used.
        memory_ref: Path/ref of that checkpoint, recorded on the trace.
        persona: CP8 hidden-weight persona scoring the run (public id only on
            the trace; the model never sees the weights).
    """

    def __init__(
        self,
        scenario_name: str,
        agent: Agent,
        *,
        difficulty: Difficulty = Difficulty.EASY,
        seed: int | None = None,
        play_turns: int = 15,
        use_probe: bool = True,
        motive_category: str = "survival",
        memory: MemoryCheckpoint | None = None,
        memory_ref: str | None = None,
        persona: PersonaWeights | None = None,
    ) -> None:
        # Inputs that determine the world that gets built.
        self._scenario_name = scenario_name
        self._seed = seed
        self._difficulty = difficulty
        # Who plays, for how long, and whether they are probed.
        self._agent = agent
        self._play_turns = play_turns
        self._use_probe = use_probe
        # Metadata and scoring inputs carried through to the trace.
        self._motive_category = motive_category
        self._memory = memory
        self._memory_ref = memory_ref
        self._persona = persona

    def run(self) -> SessionTrace:
        """Build the session, replay the Cut, play up to N turns, return the trace.

        After each turn is recorded, play stops early if ``game.eliminated``
        or ``game.survived`` is set.
        """
        built = core.build_session(
            self._scenario_name, self._seed, self._difficulty, self._play_turns,
        )
        scenario = built.scenario
        game = built.game
        system_prompt = scenario.rules_text + core._HANDOVER_FRAMING
        history: list[TurnTrace] = []
        # An explicitly supplied checkpoint wins; otherwise use the session default.
        memory = built.default_memory if self._memory is None else self._memory

        for turn_idx in range(1, self._play_turns + 1):
            observation = core.build_observation(
                scenario,
                game,
                built.cut_frames,
                turn_idx,
                memory=memory,
                prior_actions=[past.action for past in history],
            )
            probe_kwargs = self._ask_probe(observation, system_prompt)
            # The action list is copied each turn, so an agent that mutates
            # it cannot alter scenario.action_set.
            decision = self._agent.act(
                observation, list(scenario.action_set), system_prompt,
            )
            history.append(
                core.make_turn_trace(
                    scenario,
                    game,
                    turn_idx=turn_idx,
                    observation=observation,
                    action=decision.action,
                    reasoning=decision.reasoning,
                    raw_text=decision.raw_text,
                    input_tokens=decision.input_tokens,
                    output_tokens=decision.output_tokens,
                    thinking_tokens=decision.thinking_tokens,
                    persona=self._persona,
                    **probe_kwargs,
                )
            )
            if game.eliminated or game.survived:
                break

        return core.finalize(
            self._scenario_name,
            scenario,
            game,
            seed=self._seed,
            difficulty=self._difficulty,
            play_turns=self._play_turns,
            turns=history,
            cut_frames=built.cut_frames,
            motive_category=self._motive_category,
            model=self._provider_model_name(),
            memory_ref=self._memory_ref,
            persona=self._persona,
        )

    def _ask_probe(self, observation: object, system_prompt: object) -> dict[str, object]:
        """Return the ``probe_*`` keyword arguments for ``core.make_turn_trace``.

        The result is empty when probing is disabled. Otherwise the agent is
        asked ``core._PROBE_QUESTION``, and the question plus the agent's
        reply (answer, reasoning, raw text and token counts) are returned.
        """
        if not self._use_probe:
            return {}
        reply = self._agent.probe(
            observation, core._PROBE_QUESTION, system_prompt,
        )
        return {
            "probe_q": core._PROBE_QUESTION,
            "probe_a": reply.answer,
            "probe_reasoning": reply.reasoning,
            "probe_raw_text": reply.raw_text,
            "probe_input_tokens": reply.input_tokens,
            "probe_output_tokens": reply.output_tokens,
            "probe_thinking_tokens": reply.thinking_tokens,
        }

    def _provider_model_name(self) -> str:
        """Return the provider's ``model_name``, falling back to ``agent.name``.

        This relies on the agent exposing its provider as ``_provider``
        (VanillaAgent's convention). If that attribute is missing (``None``),
        or the provider has no ``model_name``, the agent's ``name`` is used.
        """
        provider = getattr(self._agent, "_provider", None)
        fallback_name = self._agent.name
        return getattr(provider, "model_name", fallback_name)