ssms_gui / tests /test_variability_collapse.py
Alexander
Render `_s*` model plots at expected variability + per-sample NDT
ecf42d4
"""Regression tests for the cartoon-variability collapse helpers in utils/utils.py.
These guard against the misleading behavior where `no_noise=True` cartoon plots in
ssms-gui rendered one random realization of across-trial variability (e.g. a slope
of `v + epsilon` for ddm_sdv) instead of the model's deterministic skeleton.
"""
import numpy as np
import pytest
from ssms.basic_simulators import Simulator
from utils.utils import (
_VARIABILITY_COLLAPSE,
_EXPECTED_T_POST_SHIFT,
_apply_expected_t_shift,
_collapse_theta_for_cartoon,
_patch_trajectory_t_with_actual_ndt,
)
# Each case is (model_name, theta_inner, idx, collapsed):
#   theta_inner - one row of model parameters, in the model's parameter order
#   idx         - position of the across-trial variability parameter (sv or st)
#   collapsed   - value the cartoon helper is expected to write at that position
# Note ddm_truncnormt expects 1e-9 rather than 0.0 (presumably to keep the
# truncated-normal scale strictly positive — confirm against utils.utils).
COLLAPSE_CASES = [
    # model            theta_inner                  idx  collapsed
    ("ddm_sdv",        [0.5, 1.0, 0.5, 0.3, 1.5],   4,   0.0),
    ("ddm_st",         [0.5, 1.0, 0.5, 0.3, 0.2],   4,   0.0),
    ("ddm_truncnormt", [0.5, 1.0, 0.5, 0.4, 0.1],   4,   1e-9),
    ("ddm_rayleight",  [0.5, 1.0, 0.5, 0.3],        3,   0.0),
]
@pytest.mark.parametrize("model_name,theta_inner,idx,collapsed", COLLAPSE_CASES)
def test_collapse_replaces_only_variability_slot(model_name, theta_inner, idx, collapsed):
    """Only the variability slot at ``idx`` may change; every other entry is kept."""
    result = _collapse_theta_for_cartoon(model_name, [list(theta_inner)])
    assert len(result) == 1
    (row,) = result
    assert len(row) == len(theta_inner)
    expected = [
        collapsed if position == idx else original
        for position, original in enumerate(theta_inner)
    ]
    for got, want in zip(row, expected):
        assert got == want
def test_collapse_passthrough_for_unmapped_model():
    """For an unmapped model (plain ``ddm``), the very same theta object comes back."""
    theta_in = [[0.5, 1.0, 0.5, 0.3]]
    theta_out = _collapse_theta_for_cartoon("ddm", theta_in)
    assert theta_out is theta_in
def test_collapse_does_not_mutate_input():
    """Collapsing must not write into the caller's inner theta row."""
    theta = [[0.5, 1.0, 0.5, 0.3, 1.5]]
    snapshot = theta[0][:]
    _collapse_theta_for_cartoon("ddm_sdv", theta)
    assert theta[0] == snapshot
def test_apply_expected_t_shift_rayleigh():
    """Starting from t=0, the ddm_rayleight shift must land on st * sqrt(pi/2)."""
    st = 0.3
    rayleigh_mean = st * np.sqrt(np.pi / 2.0)
    sim_out = {"metadata": {"t": np.array([0.0], dtype=np.float32)}}
    _apply_expected_t_shift("ddm_rayleight", {"st": st}, sim_out)
    shifted_t = sim_out["metadata"]["t"][0]
    assert shifted_t == pytest.approx(rayleigh_mean, rel=1e-5)
def test_apply_expected_t_shift_noop_for_unmapped_model():
    """Models with no expected-t post-shift (here ddm_sdv) leave metadata['t'] alone."""
    original_t = 0.123
    sim_out = {"metadata": {"t": np.array([original_t], dtype=np.float32)}}
    _apply_expected_t_shift("ddm_sdv", {"sv": 1.0}, sim_out)
    assert sim_out["metadata"]["t"][0] == pytest.approx(original_t)
# ---------------------------------------------------------------------------
# End-to-end smoke test: simulate with collapsed theta and `no_noise=True`.
# With `sv`/`st` collapsed, the simulation should be fully deterministic up to
# integer truncation, so the slope of `metadata['trajectory']` must equal `v`
# (no spurious tilt from a residual variability draw).
# ---------------------------------------------------------------------------
# (model_name, theta_dict) pairs for the end-to-end cartoon checks. All cases
# share v=0.8, a=1.2, z=0.5; the remaining keys are model-specific parameters.
E2E_CASES = [
    ("ddm_sdv", dict(v=0.8, a=1.2, z=0.5, t=0.1, sv=1.5)),
    ("ddm_st", dict(v=0.8, a=1.2, z=0.5, t=0.1, st=0.3)),
    ("ddm_truncnormt", dict(v=0.8, a=1.2, z=0.5, mt=0.4, st=0.1)),
    ("ddm_rayleight", dict(v=0.8, a=1.2, z=0.5, st=0.3)),
]
def _theta_list(model_name, theta_dict):
    """Return ``theta_dict`` as a single-row theta ``[[p1, p2, ...]]``.

    Values are ordered by the model's ``params`` list from ssms' model config.
    A missing parameter raises ``KeyError``.
    """
    from ssms.config import get_model_config

    param_names = get_model_config()[model_name]["params"]
    row = [theta_dict[name] for name in param_names]
    return [row]
@pytest.mark.parametrize("model_name,theta_dict", E2E_CASES)
def test_collapsed_cartoon_slope_matches_v(model_name, theta_dict):
    """A collapsed, ``no_noise=True`` cartoon trajectory must rise with slope ``v``.

    The trajectory is fitted against ``index * delta_t``. Any leftover
    across-trial variability draw (e.g. ``v + epsilon`` for ddm_sdv) would show
    up as a fitted slope that differs from the configured drift.
    """
    # One step size for both the simulator and the fitted time axis. If the two
    # values drift apart, the fitted slope is silently rescaled.
    delta_t = 0.001
    theta = _theta_list(model_name, theta_dict)
    theta_cartoon = _collapse_theta_for_cartoon(model_name, theta)
    sim = Simulator(model=model_name)
    out = sim.simulate(
        theta=theta_cartoon, n_samples=1,
        no_noise=True, delta_t=delta_t, smooth_unif=False,
        random_state=42,
    )
    traj = np.asarray(out["metadata"]["trajectory"]).flatten()
    # -999 fills the trajectory buffer past the last simulated step; drop it.
    valid = traj[traj > -999]
    # Need at least a handful of points to fit a line; constant-boundary DDMs reach this easily.
    assert len(valid) >= 10, f"too few trajectory points for {model_name}: {len(valid)}"
    # Drop the last point (boundary crossing — slightly past the bound), fit a line.
    t_axis = np.arange(len(valid)) * delta_t
    slope, _ = np.polyfit(t_axis[:-1], valid[:-1], 1)
    assert slope == pytest.approx(theta_dict["v"], abs=0.05), (
        f"{model_name}: cartoon slope {slope:.4f} != v {theta_dict['v']}"
    )
def test_rayleight_metadata_t_post_shift_e2e():
    """After a collapsed no-noise ddm_rayleight run, the post-shift sets metadata['t'].

    The expected value is the Rayleigh mean ``st * sqrt(pi/2)`` of the original
    (uncollapsed) ``st``.
    """
    model = "ddm_rayleight"
    params = {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3}
    collapsed = _collapse_theta_for_cartoon(model, _theta_list(model, params))
    out = Simulator(model=model).simulate(
        theta=collapsed,
        n_samples=1,
        no_noise=True,
        delta_t=0.001,
        smooth_unif=False,
        random_state=42,
    )
    # The shift is given the uncollapsed parameter dict, so st is still 0.3 here.
    _apply_expected_t_shift(model, params, out)
    rayleigh_mean = params["st"] * np.sqrt(np.pi / 2.0)
    assert out["metadata"]["t"].flat[0] == pytest.approx(rayleigh_mean, rel=1e-4)
def test_collapse_table_keys_match_post_shift_keys_subset():
    """Every model that gets an expected-t post-shift must also have a collapse entry."""
    missing = set(_EXPECTED_T_POST_SHIFT) - set(_VARIABILITY_COLLAPSE)
    assert not missing
# ---------------------------------------------------------------------------
# Trajectory NDT patch: metadata['t'] is the *input* t, not the per-sample NDT
# actually drawn. The patch back-derives the actual NDT from rt - decision_time
# so the trajectory plotter starts each line at its own NDT.
# ---------------------------------------------------------------------------
def test_patch_trajectory_t_recovers_known_ndt():
    """Synthetic trajectory with a known decision time.

    The patch must report NDT = rt - decision_time.
    """
    delta_t = 0.001
    true_ndt = 0.42
    n_steps = 50
    # 51 valid samples (indices 0..50, i.e. 50 steps), then -999 padding.
    # decision_time = 50 * delta_t = 0.05 s.
    traj = np.full(200, -999.0, dtype=np.float32)
    traj[: n_steps + 1] = np.linspace(0.0, 1.0, n_steps + 1)
    rt = 0.05 + true_ndt
    sim_out = {
        "rts": np.array([[[rt]]], dtype=np.float32),
        "metadata": {
            "t": np.array([0.0], dtype=np.float32),
            "trajectory": traj,
        },
    }
    _patch_trajectory_t_with_actual_ndt(sim_out, delta_t=delta_t)
    assert sim_out["metadata"]["t"][0] == pytest.approx(true_ndt, abs=1e-5)
def test_patch_trajectory_t_e2e_rayleigh_varies_across_calls():
    """Patched NDTs from seeded ddm_rayleight runs must look like Rayleigh draws.

    Over 20 seeded ``no_noise=False`` runs, the patched NDTs must be positive
    and spread out, rather than all collapsing to a single value.
    """
    sim = Simulator(model="ddm_rayleight")
    theta = _theta_list("ddm_rayleight", {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3})

    def patched_ndt(seed):
        # One simulation + patch; returns the back-derived NDT for that seed.
        result = sim.simulate(
            theta=theta,
            n_samples=1,
            no_noise=False,
            delta_t=0.001,
            random_state=seed,
            smooth_unif=False,
        )
        _patch_trajectory_t_with_actual_ndt(result, delta_t=0.001)
        return float(result["metadata"]["t"].flat[0])

    ndts = np.array([patched_ndt(seed) for seed in range(20)])
    assert (ndts > 0).all(), "all patched NDTs should be strictly positive for Rayleigh"
    spread = ndts.std()
    assert spread > 0.05, f"patched NDTs should vary across seeds, got std={spread:.4f}"
    # E[Rayleigh(0.3)] = 0.3 * sqrt(pi/2) ≈ 0.376; 20 samples is noisy, hence the wide band.
    mean_ndt = ndts.mean()
    assert 0.15 < mean_ndt < 0.65, f"unexpected mean NDT {mean_ndt:.4f}"
def test_patch_trajectory_t_noop_for_constant_ndt_model():
    """For plain ddm (no NDT variability), patched NDT should match input t.

    Plain ddm uses a fixed non-decision time. Back-deriving NDT from
    ``rt - decision_time`` must therefore reproduce the ``t`` passed in theta.
    """
    delta_t = 0.001
    input_t = 0.3
    sim = Simulator(model="ddm")
    theta = _theta_list("ddm", {"v": 0.8, "a": 1.2, "z": 0.5, "t": input_t})
    out = sim.simulate(
        theta=theta, n_samples=1,
        no_noise=False, delta_t=delta_t,
        random_state=42, smooth_unif=False,
    )
    _patch_trajectory_t_with_actual_ndt(out, delta_t=delta_t)
    # Allow up to two delta_t steps of slop (abs = 2 * delta_t = 2e-3) from
    # integer truncation of the decision time.
    assert out["metadata"]["t"].flat[0] == pytest.approx(input_t, abs=2 * delta_t)