File size: 1,207 Bytes
426093b
fbc6f0f
426093b
 
fbc6f0f
 
 
 
 
93cd78f
fbc6f0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from proteus.game.agents import VanillaAgent
from proteus.providers import FakeProvider
from proteus.game.metrics.aggregate import aggregate_traces
from proteus.game.runtime.session import SessionRunner


def _trace(seed):
    """Run one short scripted session and return whatever ``run()`` produces.

    The agent is a ``VanillaAgent`` wrapping a ``FakeProvider`` that is
    constructed with the single canned reply ``"ACTION: up"``. The session
    is created with ``"template"`` as its first argument, the given
    ``seed``, four play turns, and the probe disabled.
    """
    provider = FakeProvider(["ACTION: up"])
    runner = SessionRunner(
        "template",
        VanillaAgent(provider),
        seed=seed,
        play_turns=4,
        use_probe=False,
    )
    return runner.run()


def test_aggregate_groups_by_model_and_difficulty():
    """Two traces from the same setup land in one ("fake", "easy") group."""
    group_key = ("fake", "easy")
    groups = aggregate_traces([_trace(seed) for seed in (1, 2)])

    # Expect a single bucket keyed by (model="fake", difficulty="easy").
    assert group_key in groups
    group = groups[group_key]
    assert group["n"] == 2

    # Each metric should have been reduced to a float mean.
    metrics = group["metrics"]
    assert "motive_reading_accuracy" in metrics
    assert isinstance(metrics["motive_reading_accuracy"], float)


def test_aggregate_means_are_correct():
    """The aggregated ``mean_step_reward`` equals the mean over the traces."""
    traces = [_trace(seed) for seed in (1, 2)]
    group = aggregate_traces(traces)[("fake", "easy")]

    # Recompute the mean by hand from the per-trace metric values.
    rewards = [trace.metrics["mean_step_reward"] for trace in traces]
    expected = sum(rewards) / len(rewards)

    actual = group["metrics"]["mean_step_reward"]
    assert abs(actual - expected) < 1e-9


def test_aggregate_empty_returns_empty():
    """Aggregating an empty list of traces gives back an empty dict."""
    no_traces = []
    result = aggregate_traces(no_traces)
    assert result == {}