AgentnessBench / tests /viz /test_reconstruct.py
irregular6612's picture
refactor(scenario): delete predator_evade; template is the canonical scenario
93cd78f
Raw
History Blame Contribute Delete
2.49 kB
import numpy as np
import pytest
from proteus.game.agents import VanillaAgent
from proteus.game.engine.difficulty import Difficulty
from proteus.game.scenarios.base import get_scenario
from proteus.providers import FakeProvider
from proteus.game.runtime.session import SessionRunner
from proteus.game.viz import FrameStep, TraceReconstructionError, reconstruct
def _make_trace(seed=42, turns=5, action="ACTION: up"):
    """Run a short "template" session and return the resulting trace.

    Args:
        seed: Seed forwarded to ``SessionRunner``.
        turns: Value passed as ``play_turns`` to ``SessionRunner``.
        action: The single scripted response handed to ``FakeProvider``
            (presumably replayed every turn -- confirm against FakeProvider).

    Returns:
        Whatever ``SessionRunner.run()`` produces; the tests below treat it
        as a trace with ``turns``, ``cut_frames`` and ``outcome``.
    """
    scripted_provider = FakeProvider([action])
    vanilla = VanillaAgent(scripted_provider)
    runner = SessionRunner(
        "template",
        vanilla,
        seed=seed,
        play_turns=turns,
        use_probe=False,
    )
    return runner.run()
def test_reconstruct_frame_count_matches_cut_plus_play():
    """Reconstruction yields one FrameStep per cut frame plus one per turn."""
    trace = _make_trace()
    steps = reconstruct(trace)
    # NOTE(review): assumes SessionRunner's default difficulty is EASY,
    # since the trace is generated without an explicit difficulty -- confirm.
    cut_steps = get_scenario("template")().cut_length(Difficulty.EASY)
    # Initial cut frame, then cut_steps cut transitions, then the play turns.
    expected_total = 1 + cut_steps + len(trace.turns)
    assert len(steps) == expected_total
    for step in steps:
        assert isinstance(step, FrameStep)
    assert isinstance(steps[0].frame, np.ndarray)
def test_reconstruct_phases_and_terminal_flag():
    """Frames start in the cut phase, end in play, and flag the outcome.

    The turn budget is passed explicitly to ``_make_trace`` and reused in the
    terminal check. Previously that check hard-coded ``5``, silently relying on
    ``_make_trace``'s default ``turns`` value.
    """
    turn_budget = 5
    trace = _make_trace(turns=turn_budget)
    steps = reconstruct(trace)
    assert steps[0].meta.phase == "cut"
    assert steps[-1].meta.phase == "play"
    # The final frame carries the session outcome iff the game actually ended:
    # a "survived" outcome always counts, while "eliminated" counts only when
    # fewer turns than the budget were recorded. This mirrors the original
    # condition exactly.
    game_ended = trace.outcome == "survived" or (
        trace.outcome == "eliminated" and len(trace.turns) < turn_budget
    )
    if game_ended:
        assert steps[-1].meta.terminal == trace.outcome
def test_reconstruct_carries_turn_metadata():
    """Every play-phase frame mirrors the metadata of its recorded turn.

    The previous version checked only the first play frame. This one walks
    all of them, pairing frames and turns by position.
    """
    trace = _make_trace()
    steps = reconstruct(trace)
    play = [s for s in steps if s.meta.phase == "play"]
    assert len(play) == len(trace.turns)
    # Guard against a vacuous pass: the loop below checks nothing when empty
    # (the original indexed play[0] and would have failed in that case).
    assert play, "expected at least one play-phase frame"
    # Assumes play frames appear in the same order as trace.turns -- confirm.
    for step, turn in zip(play, trace.turns):
        meta = step.meta
        assert meta.turn_idx == turn.turn_idx
        assert meta.action == turn.action
        assert meta.motive_action == turn.motive_action
        assert meta.habit_action == turn.habit_action
def test_reconstruct_raises_on_corrupt_positions():
    """A tampered focal position in the first turn makes reconstruction fail.

    NOTE(review): (99, 99) is presumably outside the template board; this also
    assumes model_copy(update=...) bypasses validation so the bad value
    survives -- confirm.
    """
    trace = _make_trace()
    tampered_turn = trace.turns[0].model_copy(update={"focal_pos": (99, 99)})
    tampered_turns = [tampered_turn, *trace.turns[1:]]
    corrupt = trace.model_copy(update={"turns": tampered_turns})
    with pytest.raises(TraceReconstructionError):
        reconstruct(corrupt)
def test_reconstruct_raises_on_corrupt_cut_frame():
    """Replacing the first cut frame with junk makes reconstruction fail.

    NOTE(review): relies on model_copy(update=...) not re-validating the
    injected "GARBAGE" entry -- confirm.
    """
    trace = _make_trace()
    tampered_frames = ["GARBAGE", *trace.cut_frames[1:]]
    corrupt = trace.model_copy(update={"cut_frames": tampered_frames})
    with pytest.raises(TraceReconstructionError):
        reconstruct(corrupt)