Spaces:
Sleeping
Sleeping
feat(discovery): source available actions from scenario.action_set (interact reaches the agent)
"""Shared session-core helpers.

The per-session logic (build + Cut replay, one-turn TurnTrace construction,
finalize -> scored SessionTrace) lives here so both SessionRunner (agent-driven,
batch) and InteractiveSession (HTTP-driven, stepwise) produce *identical* traces
from the same actions. This is a pure extraction of what used to be inline in
SessionRunner -- no behavior change (pinned by the existing suite). The one
intentional deviation: ``finalize`` raises ``SessionNotFinishedError`` where the
original used a bare ``assert`` (which vanishes under ``python -O``); the
terminal condition checked is identical.
"""
| from __future__ import annotations | |
| import random | |
| from dataclasses import dataclass | |
| import numpy as np | |
| from proteus.game.engine.ascii_view import frame_to_ascii, legend_text | |
| from proteus.game.engine.difficulty import Difficulty | |
| from proteus.game.engine.grid import MotiveGridGame | |
| from proteus.game.scenarios.base import Scenario, get_scenario | |
| from proteus.game.runtime.memory import MemoryCheckpoint, render_memory_block | |
| from proteus.game.metrics.metrics import compute_metrics | |
| from proteus.game.metrics.persona import ( | |
| PersonaWeights, | |
| pressure as persona_pressure, | |
| reference_actions, | |
| reward_rw, | |
| ) | |
| from proteus.game.metrics.rollout import optimal_rollout | |
| from proteus.game.runtime.trace import SessionTrace, TurnTrace | |
| from proteus.game.agents.parsing import extract_self_belief | |
# Move vocabulary: the four directions plus "stay".
# NOTE(review): not referenced by the code visible in this module
# (build_observation advertises ``scenario.action_set`` instead) -- confirm
# whether anything else still imports it.
_ACTIONS = ["up", "down", "left", "right", "stay"]
# Directional subset of the moves. apply_action reports a move as "blocked"
# only when the action is one of these and the focal position did not change.
_DIRECTIONS = {"up", "down", "left", "right"}
# Predator-evasion-scoped for the current slice; spec defers scenario generalization.
# A future multi-scenario runner should source this from the Scenario.
_PROBE_QUESTION = (
    "Where is the predator, and which direction should you move to stay safe?"
)
# Text announcing that control of the focal agent passes to the model.
# NOTE(review): not referenced in the visible part of this module; presumably
# appended to the prompt by the runner at handover -- confirm against callers.
_HANDOVER_FRAMING = (
    "\n\nYou have been watching the game so far. From now on YOU control the "
    "focal agent. Continue playing to survive."
)
class SessionFinishedError(RuntimeError):
    """Signal an attempt to step a session that is already in a terminal state."""
class SessionNotFinishedError(RuntimeError):
    """Signal an attempt to finalize a session that is neither terminal nor out of turn budget."""
@dataclass
class BuiltSession:
    """The product of build_session: the live scenario+game and the captured Cut.

    Declared as a dataclass so the generated ``__init__`` accepts the fields
    positionally, which is how ``build_session`` constructs it; without the
    decorator the class would only carry annotations and could not be
    instantiated with arguments.

    cut_frames are ASCII (for the trace + the agent observation); cut_grids are
    the same frames as integer palette grids (for the web color animation).
    """

    # The instantiated scenario driving this session.
    scenario: Scenario
    # The live game, already advanced through the Cut pre-roll.
    game: MotiveGridGame
    # One ASCII frame per Cut state: the initial frame plus one per replayed step.
    cut_frames: list[str]
    # The same states as JSON-serializable integer palette grids.
    cut_grids: list[list[list[int]]]
    # Whatever ``scenario.default_memory(seed, difficulty)`` returned, if anything.
    default_memory: MemoryCheckpoint | None = None
def grid_to_list(grid: np.ndarray) -> list[list[int]]:
    """Convert a (h, w) palette array into nested plain-``int`` lists.

    Every cell is passed through ``int`` so the result contains only built-in
    integers and is therefore JSON-serializable.
    """
    return [list(map(int, row)) for row in grid]
def render_ascii(scenario: Scenario, game: MotiveGridGame) -> str:
    """Return the current grid as text by delegating to ``scenario.render_frame``.

    Used for both the trace frames and the agent observation.
    """
    frame = scenario.render_frame(game)
    return frame
def build_session(
    scenario_name: str,
    seed: int | None,
    difficulty: Difficulty,
    play_turns: int,
) -> BuiltSession:
    """Instantiate the scenario and game, then replay the scripted Cut pre-roll.

    Behaviour-identical to SessionRunner._build_and_replay_cut, with the
    addition that each captured frame is also stored as an integer palette
    grid for the web color animation.

    Args:
        scenario_name: Registry name resolved through ``get_scenario``.
        seed: Seed for the session's ``random.Random``; also forwarded to
            ``scenario.default_memory``.
        difficulty: Difficulty passed to the scenario and the game.
        play_turns: Turn budget after the Cut; the game's ``max_steps`` is the
            Cut length plus this value.

    Returns:
        A BuiltSession holding the scenario, the game positioned at the end of
        the Cut, the captured frames/grids and the scenario's default memory.

    Raises:
        RuntimeError: If the game becomes eliminated or survived during the
            Cut replay.
    """
    scenario_cls = get_scenario(scenario_name)
    scenario = scenario_cls()
    session_rng = random.Random(seed)
    n_cut = scenario.cut_length(difficulty)
    game = MotiveGridGame(
        scenario,
        session_rng,
        difficulty,
        max_steps=n_cut + play_turns,
    )

    ascii_frames: list[str] = []
    palette_grids: list[list[list[int]]] = []

    def capture() -> None:
        # ASCII first, then the palette grid -- one entry of each per state.
        ascii_frames.append(render_ascii(scenario, game))
        palette_grids.append(grid_to_list(game.current_grid()))

    # Frame 0 is the state before any scripted move.
    capture()
    for _step in range(n_cut):
        scripted = scenario.cut_focal_policy(game)
        game.apply_motive_action(scripted)
        scenario.record_focal_move(scripted)
        capture()
        # The Cut pre-roll must not end the game; if it does, the scenario's
        # cut_focal_policy is buggy and any resulting trace would be corrupt.
        terminated = game.eliminated or game.survived
        if terminated:
            raise RuntimeError(
                f"Game terminated during Cut replay of '{scenario_name}'. "
                "cut_focal_policy must not trigger elimination or survival."
            )

    memory = scenario.default_memory(seed, difficulty)
    return BuiltSession(scenario, game, ascii_frames, palette_grids, memory)
def build_observation(
    scenario: Scenario,
    game: MotiveGridGame,
    cut_frames: list[str],
    turn_idx: int,
    memory: MemoryCheckpoint | None = None,
    prior_actions: list[str] | None = None,
) -> str:
    """Assemble the self-contained observation text shown to the agent this turn.

    The agent is called statelessly each turn, so everything it needs to keep
    playing its OWN trajectory is repeated in every observation, in this order:

    * the handover ``memory`` block followed by a "NOW" header (only when a
      memory is supplied), so the prior episode is never lost after turn 1;
    * the scripted ``cut_frames`` pre-roll, excluding its last frame, each
      labelled ``Cut i:``;
    * ``prior_actions`` -- the moves already committed THIS run, so the model
      can see and maintain its own line of play;
    * the current grid under ``"Now:"``, the legend, and the available actions
      taken from ``scenario.action_set``.

    ``turn_idx`` is not used in the body; it is kept so the call signature
    stays stable. With no memory and no prior_actions the output reduces to
    the historical turn-1 layout.
    """
    legend = scenario.legend()
    sections: list[str] = []

    if memory is not None:
        sections += [render_memory_block(memory), "NOW — this run so far:"]

    if cut_frames:
        # The final Cut frame is omitted here.
        for number, frame in enumerate(cut_frames[:-1], start=1):
            sections += [f"Cut {number}:", frame]

    if prior_actions:
        history = ", ".join(prior_actions)
        sections.append(
            "Your moves so far this run (most recent last): " + history
        )

    sections += [
        "Now:",
        render_ascii(scenario, game),
        legend_text(legend),
        f"Available actions: [{', '.join(scenario.action_set)}]",
    ]
    return "\n".join(sections)
def apply_action(scenario: Scenario, game: MotiveGridGame, action: str) -> bool:
    """Play ``action`` on the game and report whether a directional move was blocked.

    The action is applied to the game and recorded on the scenario. The move
    counts as blocked only when ``action`` is a direction and the focal
    position is the same before and after. Behaviour-identical to
    SessionRunner._apply.
    """

    def _xy(sprite):
        # A missing (falsy) sprite has no position.
        return (sprite.x, sprite.y) if sprite else None

    before = _xy(game.focal_sprite)
    game.apply_motive_action(action)
    scenario.record_focal_move(action)
    after = _xy(game.focal_sprite)

    if action not in _DIRECTIONS:
        return False
    return after == before
def make_turn_trace(
    scenario: Scenario,
    game: MotiveGridGame,
    *,
    turn_idx: int,
    observation: str,
    action: str,
    reasoning: str = "",
    raw_text: str = "",
    input_tokens: int = 0,
    output_tokens: int = 0,
    thinking_tokens: int = 0,
    probe_q: str = "",
    probe_a: str = "",
    probe_reasoning: str = "",
    probe_raw_text: str = "",
    probe_input_tokens: int = 0,
    probe_output_tokens: int = 0,
    probe_thinking_tokens: int = 0,
    persona: PersonaWeights | None = None,
) -> TurnTrace:
    """Snapshot the pre-move state, apply ``action``, score it, and build the TurnTrace.

    Behaviour-identical to the SessionRunner play-loop body. The sequence is:

    1. read the answer keys (optimal / habit action), positions and safety
       distance BEFORE the move;
    2. apply the action through ``apply_action``;
    3. score ``step_reward`` against the pre-move positions and read the
       post-move positions, safety distance and distance delta.

    When a CP8 ``persona`` is given, the reference action set, reference reward
    and pressure are taken from the pre-move state, and the model's own R_w and
    its regret are computed after the move. The model never sees the weights;
    only the public reference set and scalars are stored on the trace.

    The remaining keyword arguments (texts and token counts for the turn and
    its probe) are copied onto the TurnTrace unchanged.
    """
    # ---- Pre-move: answer keys + positions -------------------------------
    motive_action = scenario.optimal_action(game)
    habit_action = scenario.habit_action(game)
    focal_before = game.focal_sprite
    predator_before = game.predator_sprite
    # (-1, -1) stands in for a missing sprite on the pre-move side.
    focal_pos = (focal_before.x, focal_before.y) if focal_before else (-1, -1)
    predator_pos = (
        (predator_before.x, predator_before.y) if predator_before else (-1, -1)
    )
    pre_bfs = scenario.safety_distance(game)

    # ---- Pre-move: CP8 persona reference values --------------------------
    if persona is None:
        ref_set = ref_rw = pressure_now = None
    else:
        ref_set = reference_actions(persona, scenario, game)
        ref_rw = reward_rw(
            persona, scenario, game, focal_pos, predator_pos, ref_set[0],
        )
        pressure_now = persona_pressure(scenario, game)

    # ---- Apply the move and score it against the pre-move positions ------
    blocked = apply_action(scenario, game, action)
    reward = scenario.step_reward(
        game,
        action,
        blocked,
        focal_before=focal_pos,
        predator_before=predator_pos,
    )

    # ---- Post-move: positions, BFS distance, chase-corrected delta -------
    focal_after = game.focal_sprite
    predator_after = game.predator_sprite
    # Post-move positions use None (not (-1, -1)) for a missing sprite.
    post_focal_pos = (focal_after.x, focal_after.y) if focal_after else None
    post_predator_pos = (
        (predator_after.x, predator_after.y) if predator_after else None
    )
    post_bfs = scenario.safety_distance(game)
    distance_delta = scenario.agent_distance_delta(game, focal_pos, predator_pos)

    # ---- Post-move: the model's own R_w and regret ------------------------
    # BFS geometry is static, so the pre-move positions plus the actual blocked
    # status fully determine the model's reward.
    if persona is None:
        model_rw = regret = None
    else:
        model_rw = reward_rw(
            persona, scenario, game, focal_pos, predator_pos, action,
            blocked=blocked,
        )
        regret = ref_rw - model_rw

    # ---- Find-your-body discovery -----------------------------------------
    # Parse the optional SELF: report from raw_text and compare it with the
    # scenario's (hidden) true body index. Scenarios reporting zero discovery
    # candidates skip the parse entirely.
    candidate_count = scenario.discovery_candidates()
    self_belief = None
    if candidate_count:
        self_belief = extract_self_belief(raw_text, candidate_count)
    true_index = scenario.discovery_true_index()
    self_correct = None
    if self_belief is not None and true_index is not None:
        self_correct = self_belief == true_index

    return TurnTrace(
        # Turn identity and what the agent saw / answered.
        turn_idx=turn_idx,
        observation=observation,
        reasoning=reasoning,
        raw_text=raw_text,
        action=action,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        thinking_tokens=thinking_tokens,
        # Probe exchange.
        probe_q=probe_q,
        probe_a=probe_a,
        probe_reasoning=probe_reasoning,
        probe_raw_text=probe_raw_text,
        probe_input_tokens=probe_input_tokens,
        probe_output_tokens=probe_output_tokens,
        probe_thinking_tokens=probe_thinking_tokens,
        # Answer keys and congruence.
        motive_action=motive_action,
        habit_action=habit_action,
        is_diagnostic=(motive_action != habit_action),
        was_congruent=(action == motive_action),
        reward=reward,
        # Geometry before and after the move.
        focal_pos=focal_pos,
        predator_pos=predator_pos,
        post_focal_pos=post_focal_pos,
        post_predator_pos=post_predator_pos,
        pre_bfs_distance=pre_bfs,
        post_bfs_distance=post_bfs,
        agent_distance_delta=distance_delta,
        # CP8 persona scalars (all None without a persona).
        reference_actions=ref_set,
        reference_reward=ref_rw,
        model_reward=model_rw,
        reward_regret=regret,
        pressure=pressure_now,
        # Discovery self-report.
        self_belief=self_belief,
        self_correct=self_correct,
    )
def finalize(
    scenario_name: str,
    scenario: Scenario,
    game: MotiveGridGame,
    *,
    seed: int | None,
    difficulty: Difficulty,
    play_turns: int,
    turns: list[TurnTrace],
    cut_frames: list[str],
    motive_category: str,
    model: str,
    memory_ref: str | None = None,
    persona: PersonaWeights | None = None,
) -> SessionTrace:
    """Score the played turns against the optimal rollout and build the SessionTrace.

    Behaviour-identical to the SessionRunner finalize block. The session must
    be finished: the game is eliminated or survived, or exactly ``play_turns``
    turns have been recorded. The outcome is "eliminated" when the game says
    so and "survived" otherwise. The CP7 ``memory_ref`` and the public id of
    the CP8 ``persona`` are recorded on the trace.

    Raises:
        SessionNotFinishedError: If none of the finishing conditions holds.
    """
    played = len(turns)
    finished = game.eliminated or game.survived or played == play_turns
    if not finished:
        raise SessionNotFinishedError(
            "finalize called before a terminal state or budget exhaustion."
        )

    if game.eliminated:
        outcome = "eliminated"
    else:
        outcome = "survived"

    # The optimal rollout is computed over the number of turns actually played.
    rollout = optimal_rollout(scenario_name, seed, difficulty, played)
    final_safety = scenario.safety_distance(game)
    max_bfs = scenario.max_bfs_distance(game)
    metrics = compute_metrics(
        turns,
        played_turns=played,
        play_turns=play_turns,
        outcome=outcome,
        optimal_focal_positions=rollout.focal_positions,
        realized_final_safety=final_safety,
        optimal_final_safety=rollout.final_safety_distance,
        max_bfs_distance=max_bfs,
    )

    # Only the persona's public id is stored, never the weights themselves.
    weight_id = persona.persona_weight_id if persona else None
    return SessionTrace(
        scenario=scenario_name,
        motive_category=motive_category,
        seed=seed,
        difficulty=difficulty.value,
        model=model,
        cut_frames=list(cut_frames),  # copy so the trace owns its own list
        turns=turns,
        outcome=outcome,
        metrics=metrics,
        memory_ref=memory_ref,
        horizon=play_turns,
        persona_weight_id=weight_id,
    )