File size: 3,219 Bytes
598e3bd
 
 
 
da16623
 
93cd78f
426093b
 
598e3bd
 
 
93cd78f
 
598e3bd
 
 
 
 
 
 
 
 
 
 
93cd78f
598e3bd
 
 
 
 
 
 
 
 
 
 
 
 
 
93cd78f
598e3bd
 
 
 
 
 
 
 
 
 
93cd78f
598e3bd
 
 
 
93cd78f
598e3bd
 
 
 
 
 
da16623
 
 
 
 
93cd78f
da16623
 
 
93cd78f
da16623
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""Unit tests for the extracted session-core helpers (shared by SessionRunner
and InteractiveSession)."""
from __future__ import annotations

import pytest

import proteus.game.scenarios  # noqa: F401  (registers template)
from proteus.game.engine.difficulty import Difficulty
from proteus.game.runtime import _session_core as core


def test_build_session_is_deterministic_and_non_terminal():
    """build_session is seed-deterministic, leaves the game live after the
    cut, and emits integer grids alongside the cut frames."""
    a = core.build_session("template", 42, Difficulty.EASY, 10)
    b = core.build_session("template", 42, Difficulty.EASY, 10)
    # Same seed -> identical cut frames in both representations.
    assert a.cut_frames == b.cut_frames
    assert a.cut_grids == b.cut_grids
    # Cut must not end the game.
    assert not a.game.eliminated and not a.game.survived
    # Guard emptiness first so a regression that yields no cut fails with a
    # readable message instead of an IndexError in the cell checks below.
    assert a.cut_grids, "build_session produced no cut grids"
    # One grid per cut frame.
    assert len(a.cut_grids) == len(a.cut_frames)
    # cut_grids must be plain python ints (JSON-serializable). Check every
    # cell, not just the first, and report where any offenders are.
    bad_cells = [
        (frame_idx, row_idx, col_idx, type(cell).__name__)
        for frame_idx, grid in enumerate(a.cut_grids)
        for row_idx, row in enumerate(grid)
        for col_idx, cell in enumerate(row)
        if not isinstance(cell, int)
    ]
    assert not bad_cells, f"non-int cells in cut_grids: {bad_cells[:5]}"


def test_make_turn_trace_records_premove_keys_and_reward():
    """make_turn_trace records the submitted action, the motive/habit keys,
    the diagnostic flag and a float reward for the turn.

    NOTE(review): the test name says the keys are "premove" (computed before
    the action is applied); this test does not itself verify that timing.
    """
    session = core.build_session("template", 42, Difficulty.EASY, 10)
    first_turn = 1
    observation = core.build_observation(
        session.scenario, session.game, session.cut_frames, first_turn
    )
    record = core.make_turn_trace(
        session.scenario,
        session.game,
        turn_idx=first_turn,
        observation=observation,
        action="up",
        raw_text="up",
    )

    assert record.turn_idx == first_turn
    assert record.action == "up"
    # Both reference keys must come from the legal action vocabulary.
    for reference_key in (record.motive_action, record.habit_action):
        assert reference_key in core._ACTIONS
    # A turn is diagnostic exactly when the two reference keys disagree.
    keys_disagree = record.motive_action != record.habit_action
    assert record.is_diagnostic == keys_disagree
    assert isinstance(record.reward, float)


def test_finalize_produces_scored_trace():
    """Play up to the turn budget (stopping early if the game ends) and check
    that finalize returns a trace carrying the expected metric keys."""
    play_turns = 3
    session = core.build_session("template", 42, Difficulty.EASY, play_turns)
    played = []
    for turn_idx in range(1, play_turns + 1):
        observation = core.build_observation(
            session.scenario, session.game, session.cut_frames, turn_idx
        )
        played.append(
            core.make_turn_trace(
                session.scenario,
                session.game,
                turn_idx=turn_idx,
                observation=observation,
                action="up",
                raw_text="up",
            )
        )
        # Stop as soon as the game reaches a terminal state.
        if session.game.eliminated or session.game.survived:
            break

    result = core.finalize(
        "template",
        session.scenario,
        session.game,
        seed=42,
        difficulty=Difficulty.EASY,
        play_turns=play_turns,
        turns=played,
        cut_frames=session.cut_frames,
        motive_category="survival",
        model="human",
    )

    required_metrics = {
        "motive_reading_accuracy",
        "survival_fraction",
        "away_move_fraction",
        "trajectory_agreement",
        "final_distance_gap",
    }
    assert result.scenario == "template"
    assert result.model == "human"
    assert result.outcome in ("survived", "eliminated")
    assert required_metrics <= set(result.metrics)


def test_finalize_before_terminal_or_budget_raises():
    """finalize on a live game with unspent turn budget must raise
    SessionNotFinishedError."""
    # This is the sole intentional divergence from the original SessionRunner:
    # a non-terminal, under-budget finalize raises (the old code used a bare
    # assert instead).
    budget = 5
    session = core.build_session("template", 42, Difficulty.EASY, budget)
    game_over = session.game.eliminated or session.game.survived
    assert not game_over
    with pytest.raises(core.SessionNotFinishedError):
        # No turns played, so the budget is untouched.
        core.finalize(
            "template",
            session.scenario,
            session.game,
            seed=42,
            difficulty=Difficulty.EASY,
            play_turns=budget,
            turns=[],
            cut_frames=session.cut_frames,
            motive_category="survival",
            model="human",
        )