File size: 3,550 Bytes
426093b
d53ff08
 
 
 
93cd78f
d53ff08
 
 
 
 
 
 
93cd78f
d53ff08
 
 
 
1195808
 
426093b
 
1195808
426093b
 
1195808
 
 
 
93cd78f
1195808
 
 
 
 
 
 
 
 
 
 
 
 
 
93cd78f
1195808
 
 
 
 
93cd78f
 
 
 
1195808
 
 
 
93cd78f
1195808
 
 
93cd78f
 
1195808
 
 
 
 
 
 
93cd78f
1195808
 
 
 
90b2ce8
1195808
93cd78f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from proteus.game.runtime.trace import SessionTrace


def test_session_trace_memory_ref_defaults_none_and_round_trips():
    """SessionTrace.memory_ref defaults to None and survives a JSON round trip.

    Exercises both the unset case (None) and an explicit string reference.
    """
    common = dict(
        scenario="template", motive_category="survival", seed=42,
        difficulty="easy", model="demo", outcome="survived",
    )

    # Unset: the field defaults to None and is still None after re-parsing.
    unset = SessionTrace(**common)
    assert unset.memory_ref is None
    restored_unset = SessionTrace.model_validate_json(unset.model_dump_json())
    assert restored_unset.memory_ref is None

    # Explicit: a string reference is preserved verbatim through JSON.
    explicit = SessionTrace(**common, memory_ref="demo@FIXED")
    restored_explicit = SessionTrace.model_validate_json(explicit.model_dump_json())
    assert restored_explicit.memory_ref == "demo@FIXED"


from proteus.game.agents import VanillaAgent
from proteus.game.engine.difficulty import Difficulty
from proteus.providers import FakeProvider
from proteus.game.runtime import SessionRunner
from proteus.game.runtime.memory import MemoryCheckpoint, MemoryTurn


def _memory() -> MemoryCheckpoint:
    """Build a fixed two-turn MemoryCheckpoint fixture for the "template" scenario.

    The frame strings "MEMFRAME-1" / "MEMFRAME-2" are sentinels the tests look
    for in rendered observations to detect that this memory was injected.
    """
    first_turn = MemoryTurn(
        turn_idx=1, frame_ascii="MEMFRAME-1", action="up", reasoning="",
        focal_pos=(5, 3), predator_pos=(7, 3),
    )
    second_turn = MemoryTurn(
        turn_idx=2, frame_ascii="MEMFRAME-2", action="down", reasoning="",
        focal_pos=(5, 2), predator_pos=(6, 3),
    )
    return MemoryCheckpoint(
        model="demo",
        scenario="template",
        difficulty="easy",
        seed=42,
        created_at="FIXED",
        memory_turns=[first_turn, second_turn],
        outcome="survived",
        transparent_prompt="brief",
    )


def _runner(memory=None, memory_ref=None) -> SessionRunner:
    """Return a seeded 3-play-turn SessionRunner on the "template" scenario.

    The agent is a VanillaAgent over a FakeProvider that is given 20 canned
    "ACTION: stay" replies (presumably ample headroom for 3 turns), and the
    probe is disabled. ``memory`` and ``memory_ref`` are forwarded unchanged so
    tests can compare an explicit memory against the scenario default.
    """
    canned_replies = ["ACTION: stay" for _ in range(20)]
    agent = VanillaAgent(FakeProvider(canned_replies, model_name="demo"))
    return SessionRunner(
        "template",
        agent,
        difficulty=Difficulty.EASY,
        seed=42,
        play_turns=3,
        use_probe=False,
        memory=memory,
        memory_ref=memory_ref,
    )


def test_memory_injection_shows_explicit_memory_and_preserves_measurement():
    """Explicit memory changes what the agent sees, but not how the game is scored.

    Runs the same seeded session twice (default memory vs. the ``_memory()``
    fixture) and checks that:
      * the fixture's frames appear in the turn-1 observation only when injected;
      * answer keys, diagnostic flags and metrics are identical across both runs;
      * ``memory_ref`` is recorded only when an explicit memory is supplied.
    """
    # NOTE: template provides a default (persona) memory, so the no-explicit-memory
    # baseline already carries a MEMORY block. The contrast here is therefore
    # "explicit fixture memory" vs "default memory", not "memory" vs "no memory".
    base = _runner().run()
    withmem = _runner(memory=_memory(), memory_ref="demo@FIXED").run()

    obs1 = withmem.turns[0].observation
    base_obs1 = base.turns[0].observation
    # the explicit memory block is present at turn 1
    assert "MEMORY" in obs1
    assert "MEMFRAME-1" in obs1 and "MEMFRAME-2" in obs1
    assert "you chose: up" in obs1
    # The fixture sentinels must come from the explicit memory, not from the
    # scenario default; a bare inequality check would pass for any difference.
    assert "MEMFRAME-1" not in base_obs1 and "MEMFRAME-2" not in base_obs1
    # the explicit memory overrides the scenario default, so the observation differs
    assert obs1 != base_obs1

    # the scored game is identical: same answer keys, diagnostic, metrics
    assert [t.motive_action for t in withmem.turns] == [t.motive_action for t in base.turns]
    assert [t.habit_action for t in withmem.turns] == [t.habit_action for t in base.turns]
    assert [t.is_diagnostic for t in withmem.turns] == [t.is_diagnostic for t in base.turns]
    assert withmem.metrics == base.metrics

    # memory_ref recorded with explicit memory, None when falling back to default
    assert withmem.memory_ref == "demo@FIXED"
    assert base.memory_ref is None


def test_memory_is_shown_every_turn_for_auto_regressive_play():
    """The explicit handover memory appears in every turn's observation.

    Auto-regressive play: the handover memory is carried on turn 2+ as well
    (not only turn 1), so a stateless agent never loses it mid-episode.
    """
    withmem = _runner(memory=_memory(), memory_ref="x").run()
    # At least two turns are needed, otherwise "turn 2+" is never exercised.
    assert len(withmem.turns) >= 2
    # Check every turn (not just the first two) and both fixture frames.
    for turn_no, turn in enumerate(withmem.turns, start=1):
        assert "MEMFRAME-1" in turn.observation, f"memory frame 1 missing on turn {turn_no}"
        assert "MEMFRAME-2" in turn.observation, f"memory frame 2 missing on turn {turn_no}"