File size: 8,028 Bytes
ecf42d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""Regression tests for the cartoon-variability collapse helpers in utils/utils.py.

These guard against the misleading behavior where `no_noise=True` cartoon plots in
ssms-gui rendered one random realization of across-trial variability (e.g. a slope
of `v + epsilon` for ddm_sdv) instead of the model's deterministic skeleton.
"""

import numpy as np
import pytest

from ssms.basic_simulators import Simulator

from utils.utils import (
    _VARIABILITY_COLLAPSE,
    _EXPECTED_T_POST_SHIFT,
    _apply_expected_t_shift,
    _collapse_theta_for_cartoon,
    _patch_trajectory_t_with_actual_ndt,
)


# Each case is (model_name, theta_inner, sv_or_st_index, collapsed_value):
# `theta_inner` is a single parameter row, and the slot at `sv_or_st_index`
# is the only one the collapse helper is expected to rewrite (to
# `collapsed_value`); every other slot must pass through unchanged.
COLLAPSE_CASES = [
    ("ddm_sdv", [0.5, 1.0, 0.5, 0.3, 1.5], 4, 0.0),  # sv
    ("ddm_st", [0.5, 1.0, 0.5, 0.3, 0.2], 4, 0.0),  # st
    ("ddm_truncnormt", [0.5, 1.0, 0.5, 0.4, 0.1], 4, 1e-9),  # st (tiny, not zero)
    ("ddm_rayleight", [0.5, 1.0, 0.5, 0.3], 3, 0.0),  # st
]


@pytest.mark.parametrize("model_name,theta_inner,idx,collapsed", COLLAPSE_CASES)
def test_collapse_replaces_only_variability_slot(model_name, theta_inner, idx, collapsed):
    """Only the variability slot is rewritten; all other parameters pass through verbatim."""
    # Hand the helper a fresh copy so the shared parametrize data is never at risk.
    result = _collapse_theta_for_cartoon(model_name, [list(theta_inner)])
    assert len(result) == 1

    row = result[0]
    assert len(row) == len(theta_inner)
    for pos, (got, original) in enumerate(zip(row, theta_inner)):
        want = collapsed if pos == idx else original
        assert got == want


def test_collapse_passthrough_for_unmapped_model():
    """A model with nothing to collapse (plain ddm) gets its theta back as the same object."""
    plain_theta = [[0.5, 1.0, 0.5, 0.3]]
    returned = _collapse_theta_for_cartoon("ddm", plain_theta)
    assert returned is plain_theta


def test_collapse_does_not_mutate_input():
    """Collapsing must leave the caller's theta (outer list and every row) untouched."""
    theta = [[0.5, 1.0, 0.5, 0.3, 1.5]]
    # Snapshot the whole nested structure, not just row 0, so that mutation of
    # the outer list (rows appended or removed) is caught as well as in-place
    # edits to a row.
    snapshot = [list(row) for row in theta]
    _collapse_theta_for_cartoon("ddm_sdv", theta)
    assert theta == snapshot


def test_apply_expected_t_shift_rayleigh():
    """Starting from t=0, the ddm_rayleight shift lands on E[Rayleigh(st)] = st * sqrt(pi / 2)."""
    st = 0.3
    sim_out = {"metadata": {"t": np.array([0.0], dtype=np.float32)}}

    _apply_expected_t_shift("ddm_rayleight", {"st": st}, sim_out)

    rayleigh_mean = st * np.sqrt(np.pi / 2.0)
    assert sim_out["metadata"]["t"][0] == pytest.approx(rayleigh_mean, rel=1e-5)


def test_apply_expected_t_shift_noop_for_unmapped_model():
    """A model not mapped for a post-shift (here ddm_sdv) leaves metadata['t'] as it was."""
    before = 0.123
    sim_out = {"metadata": {"t": np.array([before], dtype=np.float32)}}

    _apply_expected_t_shift("ddm_sdv", {"sv": 1.0}, sim_out)

    assert sim_out["metadata"]["t"][0] == pytest.approx(before)


# ---------------------------------------------------------------------------
# End-to-end: simulate with the collapsed theta and `no_noise=True`. With
# `sv`/`st` collapsed, the run should be fully deterministic up to integer
# truncation, so the slope of `metadata['trajectory']` must equal `v` (no
# spurious tilt from a residual variability draw).
# ---------------------------------------------------------------------------

# Each case is (model_name, theta_dict). The dict keys must cover every
# parameter the model's config lists, because `_theta_list` looks them up by
# name to build the ordered theta row.
E2E_CASES = [
    ("ddm_sdv", dict(v=0.8, a=1.2, z=0.5, t=0.1, sv=1.5)),
    ("ddm_st", dict(v=0.8, a=1.2, z=0.5, t=0.1, st=0.3)),
    ("ddm_truncnormt", dict(v=0.8, a=1.2, z=0.5, mt=0.4, st=0.1)),
    ("ddm_rayleight", dict(v=0.8, a=1.2, z=0.5, st=0.3)),
]


def _theta_list(model_name, theta_dict):
    """Build a one-row theta for `model_name` from a name -> value mapping.

    The row is ordered by the model's configured ``params`` (from
    ``ssms.config.get_model_config``); a missing name raises ``KeyError``.
    Returns a list containing that single row.
    """
    from ssms.config import get_model_config

    param_names = get_model_config()[model_name]["params"]
    row = [theta_dict[name] for name in param_names]
    return [row]


@pytest.mark.parametrize("model_name,theta_dict", E2E_CASES)
def test_collapsed_cartoon_slope_matches_v(model_name, theta_dict):
    """A collapsed, noise-free cartoon trajectory must rise with slope `v`.

    Any residual across-trial variability draw left in theta would tilt the
    trajectory away from `v`.
    """
    # Single source of truth for the step size: the simulator's delta_t and the
    # time axis used for the fit must agree, otherwise the recovered slope is
    # silently rescaled.
    delta_t = 0.001

    theta = _theta_list(model_name, theta_dict)
    theta_cartoon = _collapse_theta_for_cartoon(model_name, theta)

    sim = Simulator(model=model_name)
    out = sim.simulate(
        theta=theta_cartoon, n_samples=1,
        no_noise=True, delta_t=delta_t, smooth_unif=False,
        random_state=42,
    )

    traj = np.asarray(out["metadata"]["trajectory"]).flatten()
    # Entries at -999 are padding after the process stops; keep only real steps.
    valid = traj[traj > -999]
    # Need at least a handful of points to fit a line; constant-boundary DDMs reach this easily.
    assert len(valid) >= 10, f"too few trajectory points for {model_name}: {len(valid)}"

    # Drop the last point (boundary crossing — slightly past the bound), fit a line.
    t_axis = np.arange(len(valid)) * delta_t
    slope, _ = np.polyfit(t_axis[:-1], valid[:-1], 1)
    assert slope == pytest.approx(theta_dict["v"], abs=0.05), (
        f"{model_name}: cartoon slope {slope:.4f} != v {theta_dict['v']}"
    )


def test_rayleight_metadata_t_post_shift_e2e():
    """End-to-end: a collapsed ddm_rayleight cartoon ends up with metadata['t'] at E[NDT]."""
    step = 0.001
    rayleigh_params = {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3}
    cartoon_theta = _collapse_theta_for_cartoon(
        "ddm_rayleight", _theta_list("ddm_rayleight", rayleigh_params)
    )

    result = Simulator(model="ddm_rayleight").simulate(
        theta=cartoon_theta,
        n_samples=1,
        no_noise=True,
        delta_t=step,
        smooth_unif=False,
        random_state=42,
    )
    # The shift must read `st` from the *uncollapsed* params: the cartoon theta
    # has its `st` slot collapsed, which would zero out the expected shift.
    _apply_expected_t_shift("ddm_rayleight", rayleigh_params, result)

    rayleigh_mean = rayleigh_params["st"] * np.sqrt(np.pi / 2.0)
    assert result["metadata"]["t"].flat[0] == pytest.approx(rayleigh_mean, rel=1e-4)


def test_collapse_table_keys_match_post_shift_keys_subset():
    """Any model that gets an expected-t post-shift must also have a collapse entry."""
    orphaned = set(_EXPECTED_T_POST_SHIFT) - set(_VARIABILITY_COLLAPSE)
    assert not orphaned


# ---------------------------------------------------------------------------
# Trajectory NDT patch: metadata['t'] is the *input* t, not the per-sample NDT
# actually drawn. The patch back-derives the actual NDT from rt - decision_time
# so the trajectory plotter starts each line at its own NDT.
# ---------------------------------------------------------------------------


def test_patch_trajectory_t_recovers_known_ndt():
    """The patch recovers the NDT as rt minus the decision time implied by the trajectory."""
    delta_t = 0.001
    decision_time = 0.05  # 50 steps of delta_t
    true_ndt = 0.42  # arbitrary

    # 51 valid samples (indices 0..50, i.e. 50 steps), then -999 padding.
    trajectory = np.full(200, -999.0, dtype=np.float32)
    trajectory[:51] = np.linspace(0.0, 1.0, 51)

    rt = decision_time + true_ndt
    metadata = {
        "t": np.array([0.0], dtype=np.float32),
        "trajectory": trajectory,
    }
    sim_out = {"rts": np.array([[[rt]]], dtype=np.float32), "metadata": metadata}

    _patch_trajectory_t_with_actual_ndt(sim_out, delta_t=delta_t)

    assert sim_out["metadata"]["t"][0] == pytest.approx(true_ndt, abs=1e-5)


def test_patch_trajectory_t_e2e_rayleigh_varies_across_calls():
    """Across n trajectory calls with different seeds, the patched NDTs must vary
    (Rayleigh-distributed) — not all be identical to the input t=0."""
    delta_t = 0.001
    sim = Simulator(model="ddm_rayleight")
    theta = _theta_list("ddm_rayleight", {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3})

    def patched_ndt(seed):
        # One noisy single-sample run, then back-derive the NDT it actually used.
        run = sim.simulate(
            theta=theta, n_samples=1,
            no_noise=False, delta_t=delta_t,
            random_state=seed, smooth_unif=False,
        )
        _patch_trajectory_t_with_actual_ndt(run, delta_t=delta_t)
        return float(run["metadata"]["t"].flat[0])

    ndts = np.array([patched_ndt(seed) for seed in range(20)])

    assert (ndts > 0).all(), "all patched NDTs should be strictly positive for Rayleigh"
    assert ndts.std() > 0.05, f"patched NDTs should vary across seeds, got std={ndts.std():.4f}"
    # Mean should be in the right ballpark of E[Rayleigh(0.3)] = 0.3 * sqrt(pi/2) ≈ 0.376.
    # 20 samples is noisy, so a wide tolerance.
    assert 0.15 < ndts.mean() < 0.65, f"unexpected mean NDT {ndts.mean():.4f}"


def test_patch_trajectory_t_noop_for_constant_ndt_model():
    """For plain ddm (no NDT variability), patched NDT should match input t."""
    delta_t = 0.001
    sim = Simulator(model="ddm")
    input_t = 0.3
    theta = _theta_list("ddm", {"v": 0.8, "a": 1.2, "z": 0.5, "t": input_t})

    out = sim.simulate(
        theta=theta, n_samples=1,
        no_noise=False, delta_t=delta_t,
        random_state=42, smooth_unif=False,
    )
    _patch_trajectory_t_with_actual_ndt(out, delta_t=delta_t)
    # The back-derived NDT (rt - decision_time) is only accurate to the delta_t
    # grid because the decision time is integer-truncated. The tolerance is two
    # delta_t (the long-standing 2e-3), not one.
    # NOTE(review): confirm whether a single step of slop would suffice.
    # `.flat[0]` matches how the other simulator-output tests read metadata['t'].
    assert out["metadata"]["t"].flat[0] == pytest.approx(input_t, abs=2 * delta_t)