import numpy as np import pytest from proteus.game.agents import VanillaAgent from proteus.game.engine.difficulty import Difficulty from proteus.game.scenarios.base import get_scenario from proteus.providers import FakeProvider from proteus.game.runtime.session import SessionRunner from proteus.game.viz import FrameStep, TraceReconstructionError, reconstruct def _make_trace(seed=42, turns=5, action="ACTION: up"): agent = VanillaAgent(FakeProvider([action])) return SessionRunner( "template", agent, seed=seed, play_turns=turns, use_probe=False, ).run() def test_reconstruct_frame_count_matches_cut_plus_play(): trace = _make_trace() steps = reconstruct(trace) cut_len = get_scenario("template")().cut_length(Difficulty.EASY) # initial Cut frame + cut_len Cut steps + one frame per played turn. assert len(steps) == (cut_len + 1) + len(trace.turns) assert all(isinstance(s, FrameStep) for s in steps) assert isinstance(steps[0].frame, np.ndarray) def test_reconstruct_phases_and_terminal_flag(): trace = _make_trace() steps = reconstruct(trace) assert steps[0].meta.phase == "cut" assert steps[-1].meta.phase == "play" # The final frame carries the session outcome iff the game actually ended. if trace.outcome in ("eliminated", "survived") and ( len(trace.turns) < 5 or trace.outcome == "survived" ): assert steps[-1].meta.terminal == trace.outcome def test_reconstruct_carries_turn_metadata(): trace = _make_trace() steps = reconstruct(trace) play = [s for s in steps if s.meta.phase == "play"] assert len(play) == len(trace.turns) first = play[0].meta assert first.turn_idx == trace.turns[0].turn_idx assert first.action == trace.turns[0].action assert first.motive_action == trace.turns[0].motive_action assert first.habit_action == trace.turns[0].habit_action def test_reconstruct_raises_on_corrupt_positions(): trace = _make_trace() bad_turn = trace.turns[0].model_copy(update={"focal_pos": (99, 99)}) corrupt = trace.model_copy(update={"turns": [bad_turn] + list(trace.turns[1:])}) with pytest.raises(TraceReconstructionError): reconstruct(corrupt) def test_reconstruct_raises_on_corrupt_cut_frame(): trace = _make_trace() corrupt = trace.model_copy( update={"cut_frames": ["GARBAGE"] + list(trace.cut_frames[1:])} ) with pytest.raises(TraceReconstructionError): reconstruct(corrupt)