"""SessionRunner — the PROTEUS arena orchestrator.

One session: build the scenario+game from (scenario, seed, difficulty), replay
the scripted Cut pre-roll, hand control to the agent for up to ``play_turns``
turns (observation -> optional probe -> act -> apply -> record), and emit a
scored SessionTrace.

The per-session mechanics are shared with InteractiveSession via
``proteus.game.runtime._session_core`` (single source of truth) — including
the CP7 memory pre-roll (prepended to the turn-1 observation) and the CP8
persona reference/regret bookkeeping, both threaded through the shared
helpers.
"""

from __future__ import annotations

from proteus.game.agents.base import Agent
from proteus.game.engine.difficulty import Difficulty
from proteus.game.runtime import _session_core as core
from proteus.game.runtime.memory import MemoryCheckpoint
from proteus.game.metrics.persona import PersonaWeights
from proteus.game.runtime.trace import SessionTrace, TurnTrace


class SessionRunner:
    """Drive a single motive_grid session and hand back its SessionTrace.

    The runner only stores its configuration in ``__init__``; calling
    :meth:`run` does all of the work (build, Cut replay, agent turns, scoring).

    Args:
        scenario_name: Name of a registered scenario (e.g. ``"template"``).
        agent: The :class:`Agent` that takes over after the handover.
        difficulty: Difficulty band; it controls how long the Cut is.
        seed: Seed for the deterministic world and Cut.
        play_turns: Maximum number of played (post-Cut) turns; the session
            may end sooner if the game is decided.
        use_probe: When true, the side-channel probe is asked every turn.
        motive_category: Category label written onto the trace.
        memory: CP7 memory checkpoint shown at the handover (turn-1
            observation). When omitted, the ``default_memory`` returned by
            ``core.build_session`` is used instead.
        memory_ref: Path/ref of that checkpoint, recorded on the trace.
        persona: CP8 hidden-weight persona that scores the run (only its
            public id lands on the trace; the model never sees the weights).
    """

    def __init__(
        self,
        scenario_name: str,
        agent: Agent,
        *,
        difficulty: Difficulty = Difficulty.EASY,
        seed: int | None = None,
        play_turns: int = 15,
        use_probe: bool = True,
        motive_category: str = "survival",
        memory: MemoryCheckpoint | None = None,
        memory_ref: str | None = None,
        persona: PersonaWeights | None = None,
    ) -> None:
        # What to build.
        self._scenario_name = scenario_name
        self._difficulty = difficulty
        self._seed = seed
        self._play_turns = play_turns
        # Who plays, and whether they are probed.
        self._agent = agent
        self._use_probe = use_probe
        # What is fed into, and recorded on, the trace.
        self._motive_category = motive_category
        self._memory = memory
        self._memory_ref = memory_ref
        self._persona = persona

    def run(self) -> SessionTrace:
        """Build the session, replay the Cut, play up to N turns, and score it.

        Turns are numbered from 1. The loop stops early as soon as the game
        reports ``eliminated`` or ``survived``.

        Returns:
            The trace produced by ``core.finalize`` for this session.
        """
        session = core.build_session(
            self._scenario_name,
            self._seed,
            self._difficulty,
            self._play_turns,
        )
        scenario = session.scenario
        game = session.game
        prompt = scenario.rules_text + core._HANDOVER_FRAMING

        played: list[TurnTrace] = []
        # An explicitly supplied checkpoint wins over the built-in default.
        memory = session.default_memory if self._memory is None else self._memory

        for turn_idx in range(1, self._play_turns + 1):
            observation = core.build_observation(
                scenario,
                game,
                session.cut_frames,
                turn_idx,
                memory=memory,
                prior_actions=[turn.action for turn in played],
            )
            # The probe (when enabled) is always asked before the agent acts.
            probe_kwargs = self._probe_fields(observation, prompt)
            # A fresh list each turn: no action-set list object is shared
            # across turns.
            decision = self._agent.act(
                observation, list(scenario.action_set), prompt,
            )
            played.append(
                core.make_turn_trace(
                    scenario,
                    game,
                    turn_idx=turn_idx,
                    observation=observation,
                    action=decision.action,
                    reasoning=decision.reasoning,
                    raw_text=decision.raw_text,
                    input_tokens=decision.input_tokens,
                    output_tokens=decision.output_tokens,
                    thinking_tokens=decision.thinking_tokens,
                    persona=self._persona,
                    **probe_kwargs,
                )
            )
            session_over = game.eliminated or game.survived
            if session_over:
                break

        return core.finalize(
            self._scenario_name,
            scenario,
            game,
            seed=self._seed,
            difficulty=self._difficulty,
            play_turns=self._play_turns,
            turns=played,
            cut_frames=session.cut_frames,
            motive_category=self._motive_category,
            model=self._provider_model_name(),
            memory_ref=self._memory_ref,
            persona=self._persona,
        )

    def _probe_fields(self, observation, system_prompt: str) -> dict[str, object]:
        """Return the probe-related ``make_turn_trace`` kwargs for one turn.

        Returns an empty dict when probing is disabled. Otherwise the agent is
        asked ``core._PROBE_QUESTION``. The result holds that question plus
        the probe's answer, reasoning, raw text and token counts, under
        ``probe_*`` keys.
        """
        if not self._use_probe:
            return {}
        probe = self._agent.probe(
            observation, core._PROBE_QUESTION, system_prompt,
        )
        return {
            "probe_q": core._PROBE_QUESTION,
            "probe_a": probe.answer,
            "probe_reasoning": probe.reasoning,
            "probe_raw_text": probe.raw_text,
            "probe_input_tokens": probe.input_tokens,
            "probe_output_tokens": probe.output_tokens,
            "probe_thinking_tokens": probe.thinking_tokens,
        }

    def _provider_model_name(self) -> str:
        """Best-effort model identifier to record on the trace.

        Looks for ``model_name`` on the agent's ``_provider`` attribute
        (VanillaAgent's convention). Any agent without one (no provider, or a
        provider lacking ``model_name``) falls back to ``agent.name``.
        """
        provider = getattr(self._agent, "_provider", None)
        fallback = self._agent.name
        # getattr on None with a default simply yields the default.
        return getattr(provider, "model_name", fallback)