Render `_s*` model plots at expected variability + per-sample NDT
Browse files

The cartoon and trajectory plots for DDM variants with random across-trial
parameters (ddm_sdv, ddm_st, ddm_truncnormt, ddm_rayleight) were showing
misleading information:
1. The noiseless "model cartoon" line rendered at slope `v + epsilon` for
one random epsilon ~ Normal(0, sv) instead of the deterministic skeleton
at slope `v`. Reason: ssm-simulators `no_noise=True` only zeros Brownian
noise, not across-trial variability draws.
2. Sample trajectories all started at the input `t` even when each actual
per-sample NDT was drawn from a random distribution (Uniform/TruncNorm/
Rayleigh). Reason: simulator metadata records the input `t`, not the
per-sample NDT actually drawn from `t_dist`.
GUI-side fixes:
- `_VARIABILITY_COLLAPSE` overrides each `_s*` model's variability params
before the cartoon-only simulator call.
- `_EXPECTED_T_POST_SHIFT` post-adjusts `metadata['t']` for distributions
whose param=0 collapse != expectation (ddm_rayleight: st * sqrt(pi/2)).
- `_patch_trajectory_t_with_actual_ndt` back-derives per-sample NDT from
`rt - decision_time` and overrides `metadata['t']` per trajectory.
- `expected_random_params` flag (default True) on plot_func_model{,_n},
surfaced as a sidebar checkbox in app.py.
Both fixes are GUI-scoped, no ssm-simulators changes needed. TODO
comments mark the lift points for eventual upstream relocation.
Build infra: migrate to uv (pyproject.toml + uv.lock), update Dockerfile
to install via `uv sync --frozen --no-dev` and entrypoint via `uv run`,
drop requirements.txt. Local workflow: `uv sync && uv run pytest`.
Tests: 17 new tests covering helpers and end-to-end behavior on each
affected model.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- .gitignore +5 -0
- Dockerfile +8 -4
- pyproject.toml +22 -0
- requirements.txt +0 -7
- src/app.py +17 -5
- src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- src/utils/__pycache__/utils.cpython-311.pyc +0 -0
- src/utils/utils.py +91 -2
- tests/conftest.py +4 -0
- tests/test_variability_collapse.py +200 -0
- uv.lock +0 -0
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
.pytest_cache/
|
| 5 |
+
.DS_Store
|
|
@@ -8,13 +8,17 @@ RUN apt-get update && apt-get install -y \
|
|
| 8 |
git \
|
| 9 |
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
|
| 11 |
-
|
| 12 |
-
COPY
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
EXPOSE 8501
|
| 17 |
|
| 18 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 19 |
|
| 20 |
-
ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
|
|
| 8 |
git \
|
| 9 |
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
|
| 11 |
+
# Install uv (multi-stage copy from the official image — no curl/install scripts).
|
| 12 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
| 13 |
+
|
| 14 |
+
# Resolve and install dependencies from the lockfile (production: no dev group).
|
| 15 |
+
COPY pyproject.toml uv.lock ./
|
| 16 |
+
RUN uv sync --frozen --no-dev
|
| 17 |
|
| 18 |
+
COPY src/ ./src/
|
| 19 |
|
| 20 |
EXPOSE 8501
|
| 21 |
|
| 22 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 23 |
|
| 24 |
+
ENTRYPOINT ["uv", "run", "--no-sync", "streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "ssms-gui"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Streamlit dashboard for visualizing sequential sampling models from ssm-simulators."
|
| 5 |
+
requires-python = ">=3.13"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"altair",
|
| 8 |
+
"cloudpickle",
|
| 9 |
+
"matplotlib",
|
| 10 |
+
"pandas",
|
| 11 |
+
"seaborn",
|
| 12 |
+
"ssm-simulators>=0.12.2",
|
| 13 |
+
"streamlit>1.30.0",
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
[dependency-groups]
|
| 17 |
+
dev = [
|
| 18 |
+
"pytest",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[tool.pytest.ini_options]
|
| 22 |
+
testpaths = ["tests"]
|
|
@@ -1,7 +0,0 @@
|
|
| 1 |
-
altair
|
| 2 |
-
pandas
|
| 3 |
-
ssm-simulators>=0.12.2
|
| 4 |
-
cloudpickle
|
| 5 |
-
streamlit>1.30.0
|
| 6 |
-
matplotlib
|
| 7 |
-
seaborn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -170,11 +170,22 @@ def create_styling_selectors(model_num: int = 1):
|
|
| 170 |
key=f"show_ndt_{model_num}_{st.session_state['styling_version']}"
|
| 171 |
)
|
| 172 |
styling_config["add_data_model_keep_starting_point"] = st.checkbox(
|
| 173 |
-
"Show Starting Point",
|
| 174 |
value=True,
|
| 175 |
key=f"show_start_{model_num}_{st.session_state['styling_version']}"
|
| 176 |
)
|
| 177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
# Axis Limits Section
|
| 179 |
st.markdown("**Axis Limits**")
|
| 180 |
styling_config["xlim_min"] = st.number_input(
|
|
@@ -288,14 +299,15 @@ def get_filtered_styling_config(styling_config, plot_type="plot_func_model"):
|
|
| 288 |
elif plot_type == "plot_func_model_n":
|
| 289 |
# plot_func_model_n only accepts a subset of parameters
|
| 290 |
allowed_params = {
|
| 291 |
-
'linewidth_histogram', 'linewidth_model', 'bin_size',
|
| 292 |
-
'alpha', 'legend_fontsize', 'legend_location', 'legend_shadow',
|
| 293 |
'add_legend', 'add_data_model_markersize_starting_point',
|
| 294 |
'add_data_model_markertype_starting_point',
|
| 295 |
'add_data_model_keep_starting_point',
|
| 296 |
'add_data_model_keep_boundary',
|
| 297 |
-
'add_data_model_keep_slope',
|
| 298 |
-
'add_data_model_keep_ndt'
|
|
|
|
| 299 |
}
|
| 300 |
return {k: v for k, v in styling_config.items() if k in allowed_params}
|
| 301 |
|
|
|
|
| 170 |
key=f"show_ndt_{model_num}_{st.session_state['styling_version']}"
|
| 171 |
)
|
| 172 |
styling_config["add_data_model_keep_starting_point"] = st.checkbox(
|
| 173 |
+
"Show Starting Point",
|
| 174 |
value=True,
|
| 175 |
key=f"show_start_{model_num}_{st.session_state['styling_version']}"
|
| 176 |
)
|
| 177 |
|
| 178 |
+
styling_config["expected_random_params"] = st.checkbox(
|
| 179 |
+
"Cartoon at expected variability params",
|
| 180 |
+
value=True,
|
| 181 |
+
help=(
|
| 182 |
+
"When on, the noiseless model cartoon reflects the expected slope/NDT "
|
| 183 |
+
"instead of one random realization of across-trial variability "
|
| 184 |
+
"(e.g. sv, st in ddm_sdv, ddm_st, ddm_truncnormt, ddm_rayleight)."
|
| 185 |
+
),
|
| 186 |
+
key=f"expected_random_params_{model_num}_{st.session_state['styling_version']}",
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
# Axis Limits Section
|
| 190 |
st.markdown("**Axis Limits**")
|
| 191 |
styling_config["xlim_min"] = st.number_input(
|
|
|
|
| 299 |
elif plot_type == "plot_func_model_n":
|
| 300 |
# plot_func_model_n only accepts a subset of parameters
|
| 301 |
allowed_params = {
|
| 302 |
+
'linewidth_histogram', 'linewidth_model', 'bin_size',
|
| 303 |
+
'alpha', 'legend_fontsize', 'legend_location', 'legend_shadow',
|
| 304 |
'add_legend', 'add_data_model_markersize_starting_point',
|
| 305 |
'add_data_model_markertype_starting_point',
|
| 306 |
'add_data_model_keep_starting_point',
|
| 307 |
'add_data_model_keep_boundary',
|
| 308 |
+
'add_data_model_keep_slope',
|
| 309 |
+
'add_data_model_keep_ndt',
|
| 310 |
+
'expected_random_params',
|
| 311 |
}
|
| 312 |
return {k: v for k, v in styling_config.items() if k in allowed_params}
|
| 313 |
|
|
Binary file (276 Bytes)
|
|
|
|
Binary file (26 kB)
|
|
|
|
@@ -8,6 +8,81 @@ _model_config = get_model_config()
|
|
| 8 |
from matplotlib.lines import Line2D
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def plot_func_model(
|
| 12 |
model_name,
|
| 13 |
theta,
|
|
@@ -35,6 +110,7 @@ def plot_func_model(
|
|
| 35 |
delta_t_model=0.001,
|
| 36 |
random_state=None,
|
| 37 |
add_legend=True, # keep_frame=False,
|
|
|
|
| 38 |
**kwargs,
|
| 39 |
):
|
| 40 |
"""Calculate posterior predictive for a certain bottom node.
|
|
@@ -134,10 +210,16 @@ def plot_func_model(
|
|
| 134 |
sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
|
| 135 |
no_noise=False, delta_t=0.001,
|
| 136 |
random_state=rand_int, smooth_unif=False)
|
|
|
|
| 137 |
|
| 138 |
-
|
|
|
|
| 139 |
no_noise=True, delta_t=0.001,
|
| 140 |
smooth_unif=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
# ADD DATA HISTOGRAMS
|
| 143 |
weights_up = np.tile(
|
|
@@ -453,6 +535,7 @@ def plot_func_model_n(
|
|
| 453 |
alpha=1,
|
| 454 |
keep_frame=False,
|
| 455 |
random_state=None,
|
|
|
|
| 456 |
**kwargs,
|
| 457 |
):
|
| 458 |
"""Calculate posterior predictive for a certain bottom node.
|
|
@@ -565,10 +648,16 @@ def plot_func_model_n(
|
|
| 565 |
sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
|
| 566 |
no_noise=False, delta_t=0.001,
|
| 567 |
random_state=rand_int, smooth_unif=False)
|
|
|
|
| 568 |
|
| 569 |
-
|
|
|
|
| 570 |
no_noise=True, delta_t=0.001,
|
| 571 |
smooth_unif=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
|
| 573 |
# ADD HISTOGRAMS
|
| 574 |
# -------------------------------
|
|
|
|
| 8 |
from matplotlib.lines import Line2D
|
| 9 |
|
| 10 |
|
| 11 |
+
# TODO(ssm-simulators): upstream this as Simulator.simulate(variability_at_expectation=True).
|
| 12 |
+
# For models with across-trial random parameters (e.g. ddm_sdv: v ~ N(v, sv)), `no_noise=True`
|
| 13 |
+
# only zeros Brownian noise — the variability draw still happens, so the cartoon shows one
|
| 14 |
+
# random realization of the variability term instead of the model's deterministic skeleton.
|
| 15 |
+
# These tables let us collapse those draws to their expected value for the cartoon-only path.
|
| 16 |
+
#
|
| 17 |
+
# Maps model_name -> {variability_param_name: collapsed_value}.
|
| 18 |
+
_VARIABILITY_COLLAPSE = {
|
| 19 |
+
"ddm_sdv": {"sv": 0.0}, # N(0, 0) -> 0 = E[N(0, sv)]
|
| 20 |
+
"ddm_st": {"st": 0.0}, # U(0, 0) -> 0 = E[U(-st, st)]
|
| 21 |
+
"ddm_truncnormt": {"st": 1e-9}, # TruncN(mt, ~0) -> mt ≈ E[TruncN(mt, st)]; eps avoids div-by-zero in the config lambda (a=-mt/st)
|
| 22 |
+
"ddm_rayleight": {"st": 0.0}, # Rayleigh(0) -> 0; metadata['t'] is then post-shifted to st·sqrt(pi/2)
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
# Models where collapsing variability does NOT yield the distribution's expectation.
|
| 26 |
+
# Shift `metadata['t']` after simulation by the expected NDT contribution.
|
| 27 |
+
_EXPECTED_T_POST_SHIFT = {
|
| 28 |
+
"ddm_rayleight": lambda theta_d: theta_d["st"] * np.sqrt(np.pi / 2.0),
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _collapse_theta_for_cartoon(model_name, theta):
    """Return `theta` with random-variability params replaced for the cartoon-only path.

    `theta` is a sequence of parameter rows (the list-of-lists shape passed to
    `Simulator.simulate`), each positional in `_model_config[model_name]['params']`
    order. Every row is copied and has the overrides from `_VARIABILITY_COLLAPSE`
    applied, so no row is dropped when more than one is supplied.

    Models without an entry in `_VARIABILITY_COLLAPSE` get the *same* `theta`
    object back (not a copy). For mapped models the input is never mutated.

    Raises:
        ValueError: if an override names a parameter that is not listed in the
            model's `params` (propagated from `list.index`).
    """
    overrides = _VARIABILITY_COLLAPSE.get(model_name)
    if not overrides:
        return theta

    params = _model_config[model_name]["params"]
    # Resolve each override's column once instead of once per row.
    slots = [(params.index(name), value) for name, value in overrides.items()]

    collapsed = []
    for row in theta:
        inner = list(row)  # copy so the caller's theta stays untouched
        for idx, value in slots:
            inner[idx] = value
        collapsed.append(inner)
    return collapsed
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _apply_expected_t_shift(model_name, theta_dict, sim_out):
    """Add the model's expected-NDT offset to `sim_out['metadata']['t']`, in place.

    Only models listed in `_EXPECTED_T_POST_SHIFT` are affected; for every other
    model this is a no-op. The offset is computed by the table's callable from
    `theta_dict` (parameter name -> value), and the shifted result is cast back
    to the dtype of the original `t` array before being stored.
    """
    compute_offset = _EXPECTED_T_POST_SHIFT.get(model_name)
    if compute_offset is not None:
        metadata = sim_out["metadata"]
        base_t = np.asarray(metadata["t"])
        shifted = base_t + compute_offset(theta_dict)
        # Preserve the simulator's dtype so downstream consumers see the same type.
        metadata["t"] = shifted.astype(base_t.dtype, copy=False)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# TODO(ssm-simulators): expose t_samplewise (and v_samplewise, z_samplewise) in metadata
|
| 64 |
+
# so consumers don't have to back-derive per-sample NDT from RT and the trajectory.
|
| 65 |
+
def _patch_trajectory_t_with_actual_ndt(sim_out, delta_t=0.001):
    """Override `sim_out['metadata']['t']` with the NDT back-derived for this trajectory.

    The per-sample NDT is reconstructed as `RT - decision_time`, where
    `decision_time = (last trajectory index whose value is > -999) * delta_t`
    and RT is the first element of `sim_out['rts']`. The result is clamped at 0
    and broadcast into an array shaped and typed like the existing
    `metadata['t']`.

    Args:
        sim_out: simulator output dict with `rts` and a `metadata` dict holding
            `t` and `trajectory`. Mutated in place.
        delta_t: time step used for the trajectory; must match the value passed
            to the simulator.

    If the RT is negative or non-finite, the metadata is left untouched:
    back-deriving from such a value would otherwise clamp to 0 and replace the
    recorded `t` with a meaningless 0.
    NOTE(review): presumably a negative RT is the simulator's -999 omission
    sentinel (mirroring the trajectory padding value) — confirm upstream.
    """
    rt = float(np.asarray(sim_out["rts"]).flat[0])
    if not np.isfinite(rt) or rt < 0.0:
        return

    traj = np.asarray(sim_out["metadata"]["trajectory"]).flatten()
    # Entries at or below -999 are treated as padding after the decision.
    valid_idx = np.where(traj > -999)[0]
    decision_time = float(valid_idx[-1] * delta_t) if len(valid_idx) else 0.0

    # Clamp: rounding of the decision time can push the difference slightly below 0.
    actual_ndt = max(0.0, rt - decision_time)
    t_arr = np.asarray(sim_out["metadata"]["t"])
    sim_out["metadata"]["t"] = np.full_like(t_arr, actual_ndt, dtype=t_arr.dtype)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
def plot_func_model(
|
| 87 |
model_name,
|
| 88 |
theta,
|
|
|
|
| 110 |
delta_t_model=0.001,
|
| 111 |
random_state=None,
|
| 112 |
add_legend=True, # keep_frame=False,
|
| 113 |
+
expected_random_params=True,
|
| 114 |
**kwargs,
|
| 115 |
):
|
| 116 |
"""Calculate posterior predictive for a certain bottom node.
|
|
|
|
| 210 |
sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
|
| 211 |
no_noise=False, delta_t=0.001,
|
| 212 |
random_state=rand_int, smooth_unif=False)
|
| 213 |
+
_patch_trajectory_t_with_actual_ndt(sim_out_traj[i], delta_t=0.001)
|
| 214 |
|
| 215 |
+
theta_cartoon = _collapse_theta_for_cartoon(model_name, theta) if expected_random_params else theta
|
| 216 |
+
sim_out_no_noise = sim.simulate(theta=theta_cartoon, n_samples=1,
|
| 217 |
no_noise=True, delta_t=0.001,
|
| 218 |
smooth_unif=False)
|
| 219 |
+
if expected_random_params:
|
| 220 |
+
params = _model_config[model_name]["params"]
|
| 221 |
+
theta_dict = dict(zip(params, theta[0]))
|
| 222 |
+
_apply_expected_t_shift(model_name, theta_dict, sim_out_no_noise)
|
| 223 |
|
| 224 |
# ADD DATA HISTOGRAMS
|
| 225 |
weights_up = np.tile(
|
|
|
|
| 535 |
alpha=1,
|
| 536 |
keep_frame=False,
|
| 537 |
random_state=None,
|
| 538 |
+
expected_random_params=True,
|
| 539 |
**kwargs,
|
| 540 |
):
|
| 541 |
"""Calculate posterior predictive for a certain bottom node.
|
|
|
|
| 648 |
sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
|
| 649 |
no_noise=False, delta_t=0.001,
|
| 650 |
random_state=rand_int, smooth_unif=False)
|
| 651 |
+
_patch_trajectory_t_with_actual_ndt(sim_out_traj[i], delta_t=0.001)
|
| 652 |
|
| 653 |
+
theta_cartoon = _collapse_theta_for_cartoon(model_name, theta) if expected_random_params else theta
|
| 654 |
+
sim_out_no_noise = sim.simulate(theta=theta_cartoon, n_samples=1,
|
| 655 |
no_noise=True, delta_t=0.001,
|
| 656 |
smooth_unif=False)
|
| 657 |
+
if expected_random_params:
|
| 658 |
+
params = _model_config[model_name]["params"]
|
| 659 |
+
theta_dict = dict(zip(params, theta[0]))
|
| 660 |
+
_apply_expected_t_shift(model_name, theta_dict, sim_out_no_noise)
|
| 661 |
|
| 662 |
# ADD HISTOGRAMS
|
| 663 |
# -------------------------------
|
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
|
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Regression tests for the cartoon-variability collapse helpers in utils/utils.py.
|
| 2 |
+
|
| 3 |
+
These guard against the misleading behavior where `no_noise=True` cartoon plots in
|
| 4 |
+
ssms-gui rendered one random realization of across-trial variability (e.g. a slope
|
| 5 |
+
of `v + epsilon` for ddm_sdv) instead of the model's deterministic skeleton.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
from ssms.basic_simulators import Simulator
|
| 12 |
+
|
| 13 |
+
from utils.utils import (
|
| 14 |
+
_VARIABILITY_COLLAPSE,
|
| 15 |
+
_EXPECTED_T_POST_SHIFT,
|
| 16 |
+
_apply_expected_t_shift,
|
| 17 |
+
_collapse_theta_for_cartoon,
|
| 18 |
+
_patch_trajectory_t_with_actual_ndt,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# (model_name, theta_inner, sv_or_st_index, collapsed_value)
|
| 23 |
+
COLLAPSE_CASES = [
|
| 24 |
+
("ddm_sdv", [0.5, 1.0, 0.5, 0.3, 1.5], 4, 0.0),
|
| 25 |
+
("ddm_st", [0.5, 1.0, 0.5, 0.3, 0.2], 4, 0.0),
|
| 26 |
+
("ddm_truncnormt", [0.5, 1.0, 0.5, 0.4, 0.1], 4, 1e-9),
|
| 27 |
+
("ddm_rayleight", [0.5, 1.0, 0.5, 0.3], 3, 0.0),
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@pytest.mark.parametrize("model_name,theta_inner,idx,collapsed", COLLAPSE_CASES)
|
| 32 |
+
def test_collapse_replaces_only_variability_slot(model_name, theta_inner, idx, collapsed):
|
| 33 |
+
out = _collapse_theta_for_cartoon(model_name, [list(theta_inner)])
|
| 34 |
+
assert len(out) == 1
|
| 35 |
+
assert len(out[0]) == len(theta_inner)
|
| 36 |
+
for i, val in enumerate(theta_inner):
|
| 37 |
+
if i == idx:
|
| 38 |
+
assert out[0][i] == collapsed
|
| 39 |
+
else:
|
| 40 |
+
assert out[0][i] == val
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def test_collapse_passthrough_for_unmapped_model():
|
| 44 |
+
theta = [[0.5, 1.0, 0.5, 0.3]]
|
| 45 |
+
assert _collapse_theta_for_cartoon("ddm", theta) is theta
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def test_collapse_does_not_mutate_input():
|
| 49 |
+
theta = [[0.5, 1.0, 0.5, 0.3, 1.5]]
|
| 50 |
+
original = [list(theta[0])]
|
| 51 |
+
_collapse_theta_for_cartoon("ddm_sdv", theta)
|
| 52 |
+
assert theta[0] == original[0]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_apply_expected_t_shift_rayleigh():
|
| 56 |
+
st_val = 0.3
|
| 57 |
+
sim_out = {"metadata": {"t": np.array([0.0], dtype=np.float32)}}
|
| 58 |
+
_apply_expected_t_shift("ddm_rayleight", {"st": st_val}, sim_out)
|
| 59 |
+
expected = st_val * np.sqrt(np.pi / 2.0)
|
| 60 |
+
assert sim_out["metadata"]["t"][0] == pytest.approx(expected, rel=1e-5)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def test_apply_expected_t_shift_noop_for_unmapped_model():
|
| 64 |
+
sim_out = {"metadata": {"t": np.array([0.123], dtype=np.float32)}}
|
| 65 |
+
_apply_expected_t_shift("ddm_sdv", {"sv": 1.0}, sim_out)
|
| 66 |
+
assert sim_out["metadata"]["t"][0] == pytest.approx(0.123)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ---------------------------------------------------------------------------
|
| 70 |
+
# End-to-end: simulate with collapsed theta and `no_noise=True`. The trajectory's
|
| 71 |
+
# slope must equal `v` (no spurious tilt from a residual variability draw).
|
| 72 |
+
# Smoke test: with collapsed `sv`/`st`, a `no_noise=True` simulation should be
|
| 73 |
+
# fully deterministic up to integer truncation. We sanity-check the slope of
|
| 74 |
+
# `metadata['trajectory']` against `v`.
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
|
| 77 |
+
E2E_CASES = [
|
| 78 |
+
("ddm_sdv", {"v": 0.8, "a": 1.2, "z": 0.5, "t": 0.1, "sv": 1.5}),
|
| 79 |
+
("ddm_st", {"v": 0.8, "a": 1.2, "z": 0.5, "t": 0.1, "st": 0.3}),
|
| 80 |
+
("ddm_truncnormt", {"v": 0.8, "a": 1.2, "z": 0.5, "mt": 0.4, "st": 0.1}),
|
| 81 |
+
("ddm_rayleight", {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3}),
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _theta_list(model_name, theta_dict):
|
| 86 |
+
from ssms.config import get_model_config
|
| 87 |
+
params = get_model_config()[model_name]["params"]
|
| 88 |
+
return [[theta_dict[p] for p in params]]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@pytest.mark.parametrize("model_name,theta_dict", E2E_CASES)
|
| 92 |
+
def test_collapsed_cartoon_slope_matches_v(model_name, theta_dict):
|
| 93 |
+
theta = _theta_list(model_name, theta_dict)
|
| 94 |
+
theta_cartoon = _collapse_theta_for_cartoon(model_name, theta)
|
| 95 |
+
|
| 96 |
+
sim = Simulator(model=model_name)
|
| 97 |
+
out = sim.simulate(
|
| 98 |
+
theta=theta_cartoon, n_samples=1,
|
| 99 |
+
no_noise=True, delta_t=0.001, smooth_unif=False,
|
| 100 |
+
random_state=42,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
traj = np.asarray(out["metadata"]["trajectory"]).flatten()
|
| 104 |
+
valid = traj[traj > -999]
|
| 105 |
+
# Need at least a handful of points to fit a line; constant-boundary DDMs reach this easily.
|
| 106 |
+
assert len(valid) >= 10, f"too few trajectory points for {model_name}: {len(valid)}"
|
| 107 |
+
|
| 108 |
+
# Drop the last point (boundary crossing — slightly past the bound), fit a line.
|
| 109 |
+
delta_t = 0.001
|
| 110 |
+
t_axis = np.arange(len(valid)) * delta_t
|
| 111 |
+
slope, _ = np.polyfit(t_axis[:-1], valid[:-1], 1)
|
| 112 |
+
assert slope == pytest.approx(theta_dict["v"], abs=0.05), (
|
| 113 |
+
f"{model_name}: cartoon slope {slope:.4f} != v {theta_dict['v']}"
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def test_rayleight_metadata_t_post_shift_e2e():
|
| 118 |
+
theta_dict = {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3}
|
| 119 |
+
theta = _theta_list("ddm_rayleight", theta_dict)
|
| 120 |
+
theta_cartoon = _collapse_theta_for_cartoon("ddm_rayleight", theta)
|
| 121 |
+
|
| 122 |
+
sim = Simulator(model="ddm_rayleight")
|
| 123 |
+
out = sim.simulate(
|
| 124 |
+
theta=theta_cartoon, n_samples=1,
|
| 125 |
+
no_noise=True, delta_t=0.001, smooth_unif=False,
|
| 126 |
+
random_state=42,
|
| 127 |
+
)
|
| 128 |
+
_apply_expected_t_shift("ddm_rayleight", theta_dict, out)
|
| 129 |
+
|
| 130 |
+
expected = theta_dict["st"] * np.sqrt(np.pi / 2.0)
|
| 131 |
+
assert out["metadata"]["t"].flat[0] == pytest.approx(expected, rel=1e-4)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def test_collapse_table_keys_match_post_shift_keys_subset():
|
| 135 |
+
# Every model in the post-shift table must also be in the collapse table.
|
| 136 |
+
assert set(_EXPECTED_T_POST_SHIFT.keys()).issubset(_VARIABILITY_COLLAPSE.keys())
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# Trajectory NDT patch: metadata['t'] is the *input* t, not the per-sample NDT
|
| 141 |
+
# actually drawn. The patch back-derives the actual NDT from rt - decision_time
|
| 142 |
+
# so the trajectory plotter starts each line at its own NDT.
|
| 143 |
+
# ---------------------------------------------------------------------------
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_patch_trajectory_t_recovers_known_ndt():
|
| 147 |
+
delta_t = 0.001
|
| 148 |
+
# Trajectory: 50 valid steps then -999. decision_time = 50 * delta_t = 0.05s.
|
| 149 |
+
traj = np.full(200, -999.0, dtype=np.float32)
|
| 150 |
+
traj[:51] = np.linspace(0.0, 1.0, 51) # indices 0..50 valid
|
| 151 |
+
rt = 0.05 + 0.42 # decision_time + arbitrary NDT
|
| 152 |
+
sim_out = {
|
| 153 |
+
"rts": np.array([[[rt]]], dtype=np.float32),
|
| 154 |
+
"metadata": {
|
| 155 |
+
"t": np.array([0.0], dtype=np.float32),
|
| 156 |
+
"trajectory": traj,
|
| 157 |
+
},
|
| 158 |
+
}
|
| 159 |
+
_patch_trajectory_t_with_actual_ndt(sim_out, delta_t=delta_t)
|
| 160 |
+
assert sim_out["metadata"]["t"][0] == pytest.approx(0.42, abs=1e-5)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def test_patch_trajectory_t_e2e_rayleigh_varies_across_calls():
|
| 164 |
+
"""Across n trajectory calls with different seeds, the patched NDTs must vary
|
| 165 |
+
(Rayleigh-distributed) — not all be identical to the input t=0."""
|
| 166 |
+
sim = Simulator(model="ddm_rayleight")
|
| 167 |
+
theta = _theta_list("ddm_rayleight", {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3})
|
| 168 |
+
|
| 169 |
+
ndts = []
|
| 170 |
+
for seed in range(20):
|
| 171 |
+
out = sim.simulate(
|
| 172 |
+
theta=theta, n_samples=1,
|
| 173 |
+
no_noise=False, delta_t=0.001,
|
| 174 |
+
random_state=seed, smooth_unif=False,
|
| 175 |
+
)
|
| 176 |
+
_patch_trajectory_t_with_actual_ndt(out, delta_t=0.001)
|
| 177 |
+
ndts.append(float(out["metadata"]["t"].flat[0]))
|
| 178 |
+
|
| 179 |
+
ndts = np.array(ndts)
|
| 180 |
+
assert (ndts > 0).all(), "all patched NDTs should be strictly positive for Rayleigh"
|
| 181 |
+
assert ndts.std() > 0.05, f"patched NDTs should vary across seeds, got std={ndts.std():.4f}"
|
| 182 |
+
# Mean should be in the right ballpark of E[Rayleigh(0.3)] = 0.3 * sqrt(pi/2) ≈ 0.376.
|
| 183 |
+
# 20 samples is noisy, so a wide tolerance.
|
| 184 |
+
assert 0.15 < ndts.mean() < 0.65, f"unexpected mean NDT {ndts.mean():.4f}"
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def test_patch_trajectory_t_noop_for_constant_ndt_model():
|
| 188 |
+
"""For plain ddm (no NDT variability), patched NDT should match input t."""
|
| 189 |
+
sim = Simulator(model="ddm")
|
| 190 |
+
input_t = 0.3
|
| 191 |
+
theta = _theta_list("ddm", {"v": 0.8, "a": 1.2, "z": 0.5, "t": input_t})
|
| 192 |
+
|
| 193 |
+
out = sim.simulate(
|
| 194 |
+
theta=theta, n_samples=1,
|
| 195 |
+
no_noise=False, delta_t=0.001,
|
| 196 |
+
random_state=42, smooth_unif=False,
|
| 197 |
+
)
|
| 198 |
+
_patch_trajectory_t_with_actual_ndt(out, delta_t=0.001)
|
| 199 |
+
# Allow up to two delta_t steps (2e-3) of slop from integer truncation of decision time.
|
| 200 |
+
assert out["metadata"]["t"][0] == pytest.approx(input_t, abs=2e-3)
|
|
The diff for this file is too large to render.
See raw diff
|
|
|