| """Tests for dashboard summary-stats helpers in utils/utils.py.""" |
|
|
| import numpy as np |
| import pytest |
|
|
| from utils.utils import mean_rt_per_choice |
|
|
|
|
def _make_sim_output(rts, choices):
    """Return a minimal simulator-output dict accepted by ``mean_rt_per_choice``.

    Both sequences are cast to float32 and reshaped to ``(n, 1, 1)``, then
    stored under the ``"rts"`` and ``"choices"`` keys.
    """

    def _as_float32_column(values):
        # Same cast + reshape for both fields so they always line up.
        return np.asarray(values, dtype=np.float32).reshape(-1, 1, 1)

    return {
        "rts": _as_float32_column(rts),
        "choices": _as_float32_column(choices),
    }
|
|
|
|
def test_mean_rt_per_choice_basic_two_choice():
    """Two-boundary data: RTs are averaged separately for choices +1 and -1."""
    rts = [0.4, 0.5, 0.6, 0.8, 1.0]
    choices = [1, 1, 1, -1, -1]

    result = mean_rt_per_choice(_make_sim_output(rts=rts, choices=choices))

    assert set(result) == {-1, 1}
    assert result[1] == pytest.approx(0.5)   # mean of 0.4, 0.5, 0.6
    assert result[-1] == pytest.approx(0.9)  # mean of 0.8, 1.0
|
|
|
|
def test_mean_rt_per_choice_filters_deadline_timeouts():
    """Trials carrying the -999 sentinel must be dropped from the result.

    Every -999 rt here is paired with a -999 choice. The means for choices
    1 and -1 would therefore be the same whether or not the sentinel trials
    are filtered. The key-set assertion is what actually verifies that they
    were excluded.
    """
    sim = _make_sim_output(
        rts=[0.4, -999, 0.6, -999, 1.0],
        choices=[1, -999, 1, -999, -1],
    )
    out = mean_rt_per_choice(sim)

    # Fails if -999 leaks through as its own "choice" key.
    assert set(out.keys()) == {-1, 1}
    assert out[1] == pytest.approx(0.5)
    assert out[-1] == pytest.approx(1.0)
|
|
|
|
def test_mean_rt_per_choice_one_sided():
    """All responses on one boundary: only that choice appears as a key."""
    out = mean_rt_per_choice(
        _make_sim_output(rts=[0.3, 0.4, 0.5], choices=[1, 1, 1])
    )
    # approx over a mapping checks the exact key set and each value.
    assert out == pytest.approx({1: 0.4})
|
|
|
|
def test_mean_rt_per_choice_all_invalid():
    """When every trial carries the -999 sentinel, the result is an empty dict."""
    all_sentinel = [-999] * 3
    sim = _make_sim_output(rts=all_sentinel, choices=all_sentinel)

    result = mean_rt_per_choice(sim)

    assert result == {}
|
|
|
|
def test_mean_rt_per_choice_n_choice():
    """Choice codes beyond +/-1 (here 0..3) are grouped the same way."""
    sim = _make_sim_output(
        rts=[0.4, 0.5, 0.6, 0.7, 0.8],
        choices=[0, 1, 2, 0, 3],
    )
    # Choice 0 appears twice (0.4 and 0.7); the rest appear once.
    expected = {0: 0.55, 1: 0.5, 2: 0.6, 3: 0.8}

    out = mean_rt_per_choice(sim)

    assert set(out) == set(expected)
    for choice, mean_rt in expected.items():
        assert out[choice] == pytest.approx(mean_rt)
|
|
|
|
def test_mean_rt_per_choice_keys_are_int_when_integer_valued():
    """Choice keys should be int when integer-valued so DataFrame column names are clean."""
    sim = _make_sim_output(rts=[0.4, 0.5], choices=[-1, 1])
    out = mean_rt_per_choice(sim)
    # Pin the key set first. Without this, an empty or partial result would
    # let the type-check loop below pass vacuously.
    assert set(out) == {-1, 1}
    for key in out:
        assert isinstance(key, int), f"expected int key, got {type(key)}"
|
|