"""Tests for dashboard summary-stats helpers in utils/utils.py."""

import numpy as np
import pytest

from utils.utils import mean_rt_per_choice

# Sentinel value these fixtures use for samples that hit the deadline
# (no valid response): it is placed in both ``rts`` and ``choices``.
_TIMEOUT = -999


def _make_sim_output(rts, choices):
    """Build a sim-output-shaped dict with rts/choices arrays.

    Both inputs are cast to float32 and reshaped to ``(n, 1, 1)``.
    NOTE(review): the trailing singleton axes presumably mirror the layout
    the simulator returns — confirm against the real sim output.
    """
    rts = np.asarray(rts, dtype=np.float32).reshape(-1, 1, 1)
    choices = np.asarray(choices, dtype=np.float32).reshape(-1, 1, 1)
    return {"rts": rts, "choices": choices}


def test_mean_rt_per_choice_basic_two_choice():
    """Mean RT is computed separately for each of the two boundaries."""
    sim = _make_sim_output(
        rts=[0.4, 0.5, 0.6, 0.8, 1.0],
        choices=[1, 1, 1, -1, -1],
    )
    out = mean_rt_per_choice(sim)
    assert set(out.keys()) == {-1, 1}
    assert out[1] == pytest.approx(0.5)
    assert out[-1] == pytest.approx(0.9)


def test_mean_rt_per_choice_filters_deadline_timeouts():
    """Timed-out samples neither contribute to the means nor appear as a key."""
    sim = _make_sim_output(
        rts=[0.4, _TIMEOUT, 0.6, _TIMEOUT, 1.0],
        choices=[1, _TIMEOUT, 1, _TIMEOUT, -1],
    )
    out = mean_rt_per_choice(sim)
    # The sentinel must not leak through as a pseudo-choice key; this is
    # consistent with the all-invalid case below returning an empty dict.
    assert set(out.keys()) == {-1, 1}
    # Only valid samples count.
    assert out[1] == pytest.approx(0.5)
    assert out[-1] == pytest.approx(1.0)


def test_mean_rt_per_choice_one_sided():
    """All responses on one boundary — only that choice appears in the dict."""
    sim = _make_sim_output(
        rts=[0.3, 0.4, 0.5],
        choices=[1, 1, 1],
    )
    out = mean_rt_per_choice(sim)
    assert out == {1: pytest.approx(0.4)}


def test_mean_rt_per_choice_all_invalid():
    """When every sample timed out, the result is an empty dict."""
    sim = _make_sim_output(
        rts=[_TIMEOUT, _TIMEOUT, _TIMEOUT],
        choices=[_TIMEOUT, _TIMEOUT, _TIMEOUT],
    )
    assert mean_rt_per_choice(sim) == {}


def test_mean_rt_per_choice_n_choice():
    """More than two distinct choice values are each summarised."""
    sim = _make_sim_output(
        rts=[0.4, 0.5, 0.6, 0.7, 0.8],
        choices=[0, 1, 2, 0, 3],
    )
    out = mean_rt_per_choice(sim)
    assert set(out.keys()) == {0, 1, 2, 3}
    assert out[0] == pytest.approx(0.55)  # mean of 0.4 and 0.7
    assert out[1] == pytest.approx(0.5)
    assert out[2] == pytest.approx(0.6)
    assert out[3] == pytest.approx(0.8)


def test_mean_rt_per_choice_keys_are_int_when_integer_valued():
    """Choice keys should be int when integer-valued so DataFrame column names are clean."""
    sim = _make_sim_output(rts=[0.4, 0.5], choices=[-1, 1])
    out = mean_rt_per_choice(sim)
    # Guard against a vacuous pass: with an empty dict the loop below would
    # assert nothing at all.
    assert set(out.keys()) == {-1, 1}
    for key in out:
        # np.int64 / np.float32 keys are not ``int`` instances, so this
        # checks for genuine Python ints.
        assert isinstance(key, int), f"expected int key, got {type(key)}"