Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pytest | |
| from proteus.game.agents import VanillaAgent | |
| from proteus.game.engine.difficulty import Difficulty | |
| from proteus.game.scenarios.base import get_scenario | |
| from proteus.providers import FakeProvider | |
| from proteus.game.runtime.session import SessionRunner | |
| from proteus.game.viz import FrameStep, TraceReconstructionError, reconstruct | |
def _make_trace(seed=42, turns=5, action="ACTION: up"):
    """Run a short scripted session on the "template" scenario and return its trace.

    The agent is a ``VanillaAgent`` whose provider is a ``FakeProvider`` built
    from the single-element list ``[action]``. The runner is configured with
    ``play_turns=turns`` and the probe disabled; the returned value is whatever
    ``SessionRunner.run()`` produces.
    """
    provider = FakeProvider([action])
    runner = SessionRunner(
        "template",
        VanillaAgent(provider),
        seed=seed,
        play_turns=turns,
        use_probe=False,
    )
    return runner.run()
def test_reconstruct_frame_count_matches_cut_plus_play():
    """Reconstruction emits the initial Cut frame, one frame per Cut step, and
    one frame per played turn; every step is a FrameStep with an ndarray frame."""
    trace = _make_trace()
    steps = reconstruct(trace)

    # NOTE(review): _make_trace passes no difficulty, so this assumes the
    # runner's default difficulty is EASY — confirm against SessionRunner.
    scenario = get_scenario("template")()
    cut_len = scenario.cut_length(Difficulty.EASY)

    expected_frames = cut_len + 1 + len(trace.turns)
    assert len(steps) == expected_frames
    assert all(isinstance(step, FrameStep) for step in steps)
    assert isinstance(steps[0].frame, np.ndarray)
def test_reconstruct_phases_and_terminal_flag():
    """Frames start in the "cut" phase and end in the "play" phase, and the
    final frame's ``terminal`` field equals the session outcome when the game
    actually ended."""
    # Bind the turn budget once and pass it explicitly. The early-elimination
    # check below must use the same bound, rather than a magic number that
    # silently duplicates _make_trace's default.
    play_turns = 5
    trace = _make_trace(turns=play_turns)
    steps = reconstruct(trace)

    assert steps[0].meta.phase == "cut"
    assert steps[-1].meta.phase == "play"

    # The game ended if the agent survived, or was eliminated before the
    # turn budget was exhausted.
    ended_early = len(trace.turns) < play_turns
    if trace.outcome == "survived" or (trace.outcome == "eliminated" and ended_early):
        assert steps[-1].meta.terminal == trace.outcome
def test_reconstruct_carries_turn_metadata():
    """There is one "play" frame per recorded turn, and the first play frame's
    metadata mirrors the first turn's index and action fields."""
    trace = _make_trace()
    play_steps = [step for step in reconstruct(trace) if step.meta.phase == "play"]
    assert len(play_steps) == len(trace.turns)

    first_meta = play_steps[0].meta
    first_turn = trace.turns[0]
    for attr in ("turn_idx", "action", "motive_action", "habit_action"):
        assert getattr(first_meta, attr) == getattr(first_turn, attr)
def test_reconstruct_raises_on_corrupt_positions():
    """Overwriting the first turn's ``focal_pos`` with (99, 99) must make
    ``reconstruct`` raise ``TraceReconstructionError``."""
    trace = _make_trace()
    head, tail = trace.turns[0], trace.turns[1:]
    # NOTE(review): (99, 99) is presumably off the board — confirm the grid size.
    tampered = head.model_copy(update={"focal_pos": (99, 99)})
    corrupt = trace.model_copy(update={"turns": [tampered, *tail]})

    with pytest.raises(TraceReconstructionError):
        reconstruct(corrupt)
def test_reconstruct_raises_on_corrupt_cut_frame():
    """Replacing the first recorded Cut frame with the string "GARBAGE" must
    make ``reconstruct`` raise ``TraceReconstructionError``."""
    trace = _make_trace()
    tampered_frames = ["GARBAGE", *trace.cut_frames[1:]]
    corrupt = trace.model_copy(update={"cut_frames": tampered_frames})

    with pytest.raises(TraceReconstructionError):
        reconstruct(corrupt)