irregular6612's picture
feat(discovery): source available actions from scenario.action_set (interact reaches the agent)
11cd1de
Raw
History Blame Contribute Delete
5.38 kB
"""SessionRunner — the PROTEUS arena orchestrator.
One session: build the scenario+game from (scenario, seed, difficulty), replay
the scripted Cut pre-roll, hand control to the agent for up to ``play_turns``
turns (observation -> optional probe -> act -> apply -> record), and emit a
scored SessionTrace. The per-session mechanics are shared with InteractiveSession
via ``proteus.game.runtime._session_core`` (single source of truth) — including the
CP7 memory pre-roll (prepended to the turn-1 observation) and the CP8 persona
reference/regret bookkeeping, both threaded through the shared helpers.
"""
from __future__ import annotations
from proteus.game.agents.base import Agent
from proteus.game.engine.difficulty import Difficulty
from proteus.game.runtime import _session_core as core
from proteus.game.runtime.memory import MemoryCheckpoint
from proteus.game.metrics.persona import PersonaWeights
from proteus.game.runtime.trace import SessionTrace, TurnTrace
class SessionRunner:
    """Drive a single motive_grid session from build to scored trace.

    The runner only sequences the steps; the per-session mechanics
    (building the session, composing observations, recording turns and
    finalizing the trace) are delegated to the shared ``core`` helpers.

    Args:
        scenario_name: Name of a registered scenario (e.g. ``"template"``).
        agent: The :class:`Agent` that takes over after the handover.
        difficulty: Difficulty band (controls Cut length).
        seed: Seed for the deterministic world/Cut.
        play_turns: Survival budget, counted in played (post-Cut) turns.
        use_probe: Whether the side-channel probe is asked on every turn.
        motive_category: Category label stored on the trace.
        memory: CP7 memory checkpoint shown at the handover (turn-1
            observation). When ``None`` the built session's default
            memory is used instead.
        memory_ref: Path/ref of that checkpoint, recorded on the trace.
        persona: CP8 hidden-weight persona scoring the run (public id only
            on the trace; the model never sees the weights).
    """

    def __init__(
        self,
        scenario_name: str,
        agent: Agent,
        *,
        difficulty: Difficulty = Difficulty.EASY,
        seed: int | None = None,
        play_turns: int = 15,
        use_probe: bool = True,
        motive_category: str = "survival",
        memory: MemoryCheckpoint | None = None,
        memory_ref: str | None = None,
        persona: PersonaWeights | None = None,
    ) -> None:
        # Who plays, and what they play.
        self._agent = agent
        self._scenario_name = scenario_name

        # World construction parameters.
        self._seed = seed
        self._difficulty = difficulty
        self._play_turns = play_turns

        # Per-turn behaviour and trace labelling.
        self._use_probe = use_probe
        self._motive_category = motive_category

        # CP7 memory checkpoint and CP8 persona bookkeeping.
        self._memory = memory
        self._memory_ref = memory_ref
        self._persona = persona

    def run(self) -> SessionTrace:
        """Play the session and return its finalized trace.

        The session is built through ``core.build_session``; the agent is
        then asked to act for at most ``play_turns`` turns (observation ->
        optional probe -> act -> record). The loop stops early as soon as
        the game reports ``eliminated`` or ``survived``.

        Returns:
            The :class:`SessionTrace` produced by ``core.finalize``.
        """
        session = core.build_session(
            self._scenario_name,
            self._seed,
            self._difficulty,
            self._play_turns,
        )
        scenario = session.scenario
        game = session.game
        system_prompt = scenario.rules_text + core._HANDOVER_FRAMING
        history: list[TurnTrace] = []

        # An explicitly supplied checkpoint wins over the session default.
        if self._memory is not None:
            memory = self._memory
        else:
            memory = session.default_memory

        for turn_idx in range(1, self._play_turns + 1):
            observation = core.build_observation(
                scenario,
                game,
                session.cut_frames,
                turn_idx,
                memory=memory,
                prior_actions=[past.action for past in history],
            )

            # The probe (when enabled) is asked before the agent acts.
            probe_fields = self._collect_probe_fields(observation, system_prompt)

            # A fresh list per turn, so the agent never holds the scenario's
            # own action_set container.
            available_actions = list(scenario.action_set)
            decision = self._agent.act(observation, available_actions, system_prompt)

            turn_record = core.make_turn_trace(
                scenario,
                game,
                turn_idx=turn_idx,
                observation=observation,
                action=decision.action,
                reasoning=decision.reasoning,
                raw_text=decision.raw_text,
                input_tokens=decision.input_tokens,
                output_tokens=decision.output_tokens,
                thinking_tokens=decision.thinking_tokens,
                persona=self._persona,
                **probe_fields,
            )
            history.append(turn_record)

            session_over = game.eliminated or game.survived
            if session_over:
                break

        return core.finalize(
            self._scenario_name,
            scenario,
            game,
            seed=self._seed,
            difficulty=self._difficulty,
            play_turns=self._play_turns,
            turns=history,
            cut_frames=session.cut_frames,
            motive_category=self._motive_category,
            model=self._provider_model_name(),
            memory_ref=self._memory_ref,
            persona=self._persona,
        )

    def _collect_probe_fields(
        self, observation, system_prompt: str,
    ) -> dict[str, object]:
        """Return the probe-related keyword arguments for one turn's trace.

        Yields an empty mapping when probing is disabled, so the caller can
        always splat the result into ``core.make_turn_trace``.
        """
        if not self._use_probe:
            return {}

        question = core._PROBE_QUESTION
        reply = self._agent.probe(observation, question, system_prompt)
        return {
            "probe_q": core._PROBE_QUESTION,
            "probe_a": reply.answer,
            "probe_reasoning": reply.reasoning,
            "probe_raw_text": reply.raw_text,
            "probe_input_tokens": reply.input_tokens,
            "probe_output_tokens": reply.output_tokens,
            "probe_thinking_tokens": reply.thinking_tokens,
        }

    def _provider_model_name(self) -> str:
        """Return the model name to record on the trace.

        Looks for a ``_provider`` attribute on the agent (VanillaAgent's
        convention) and reads its ``model_name``; agents without such a
        provider, or whose provider has no ``model_name``, are reported
        under ``agent.name``.
        """
        provider = getattr(self._agent, "_provider", None)
        # The fallback is read unconditionally, mirroring an eagerly
        # evaluated getattr default.
        fallback_name = self._agent.name
        return getattr(provider, "model_name", fallback_name)