"""Tests for ``aggregate_traces``: grouping, mean computation, empty input."""

from proteus.game.agents import VanillaAgent
from proteus.providers import FakeProvider
from proteus.game.metrics.aggregate import aggregate_traces
from proteus.game.runtime.session import SessionRunner


def _trace(seed):
    """Run one short session with a scripted agent and return its result.

    The agent is a ``VanillaAgent`` backed by a ``FakeProvider`` that is
    given the single canned reply ``"ACTION: up"``. The session uses the
    ``"template"`` game, four play turns, and no probe.

    Args:
        seed: Seed forwarded to ``SessionRunner``.

    Returns:
        Whatever ``SessionRunner.run()`` returns (the tests below read its
        ``metrics`` mapping).
    """
    agent = VanillaAgent(FakeProvider(["ACTION: up"]))
    runner = SessionRunner(
        "template",
        agent,
        seed=seed,
        play_turns=4,
        use_probe=False,
    )
    return runner.run()


def test_aggregate_groups_by_model_and_difficulty():
    """Two traces from the same setup land in a single (model, difficulty) group."""
    traces = [_trace(1), _trace(2)]
    groups = aggregate_traces(traces)

    # One group: (model="fake", difficulty="easy").
    key = ("fake", "easy")
    assert key in groups
    group = groups[key]
    assert group["n"] == 2

    # Metrics are expected to be aggregated to a mean (float); spot-check one key.
    assert "motive_reading_accuracy" in group["metrics"]
    assert isinstance(group["metrics"]["motive_reading_accuracy"], float)


def test_aggregate_means_are_correct():
    """The aggregated ``mean_step_reward`` equals the arithmetic mean over traces."""
    traces = [_trace(1), _trace(2)]
    groups = aggregate_traces(traces)
    group = groups[("fake", "easy")]

    # Divide by the actual trace count so the expectation stays correct if the
    # list above is ever extended.
    expected = sum(t.metrics["mean_step_reward"] for t in traces) / len(traces)
    assert abs(group["metrics"]["mean_step_reward"] - expected) < 1e-9


def test_aggregate_empty_returns_empty():
    """Aggregating no traces yields an empty mapping."""
    assert aggregate_traces([]) == {}