Alexander Claude Opus 4.7 (1M context) commited on
Commit
ecf42d4
·
1 Parent(s): 84e2ec8

Render `_s*` model plots at expected variability + per-sample NDT

Browse files

The cartoon and trajectory plots for DDM variants with random across-trial
parameters (ddm_sdv, ddm_st, ddm_truncnormt, ddm_rayleight) were showing
misleading information:

1. The noiseless "model cartoon" line rendered at slope `v + epsilon` for
one random epsilon ~ Normal(0, sv) instead of the deterministic skeleton
at slope `v`. Reason: ssm-simulators `no_noise=True` only zeros Brownian
noise, not across-trial variability draws.

2. Sample trajectories all started at the input `t` even when each actual
per-sample NDT was drawn from a random distribution (Uniform/TruncNorm/
Rayleigh). Reason: simulator metadata records the input `t`, not the
per-sample NDT actually drawn from `t_dist`.

GUI-side fixes:
- `_VARIABILITY_COLLAPSE` overrides each `_s*` model's variability params
before the cartoon-only simulator call.
- `_EXPECTED_T_POST_SHIFT` post-adjusts `metadata['t']` for distributions
whose param=0 collapse != expectation (ddm_rayleight: st * sqrt(pi/2)).
- `_patch_trajectory_t_with_actual_ndt` back-derives per-sample NDT from
`rt - decision_time` and overrides `metadata['t']` per trajectory.
- `expected_random_params` flag (default True) on plot_func_model{,_n},
surfaced as a sidebar checkbox in app.py.

Both fixes are GUI-scoped, no ssm-simulators changes needed. TODO
comments mark the lift points for eventual upstream relocation.

Build infra: migrate to uv (pyproject.toml + uv.lock), update Dockerfile
to install via `uv sync --frozen --no-dev` and entrypoint via `uv run`,
drop requirements.txt. Local workflow: `uv sync && uv run pytest`.

Tests: 17 new tests covering helpers and end-to-end behavior on each
affected model.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ .pytest_cache/
5
+ .DS_Store
Dockerfile CHANGED
@@ -8,13 +8,17 @@ RUN apt-get update && apt-get install -y \
8
  git \
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
- COPY requirements.txt ./
12
- COPY src/ ./src/
 
 
 
 
13
 
14
- RUN pip3 install -r requirements.txt
15
 
16
  EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
8
  git \
9
  && rm -rf /var/lib/apt/lists/*
10
 
11
+ # Install uv (multi-stage copy from the official image — no curl/install scripts).
12
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
13
+
14
+ # Resolve and install dependencies from the lockfile (production: no dev group).
15
+ COPY pyproject.toml uv.lock ./
16
+ RUN uv sync --frozen --no-dev
17
 
18
+ COPY src/ ./src/
19
 
20
  EXPOSE 8501
21
 
22
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
23
 
24
+ ENTRYPOINT ["uv", "run", "--no-sync", "streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
pyproject.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "ssms-gui"
3
+ version = "0.1.0"
4
+ description = "Streamlit dashboard for visualizing sequential sampling models from ssm-simulators."
5
+ requires-python = ">=3.13"
6
+ dependencies = [
7
+ "altair",
8
+ "cloudpickle",
9
+ "matplotlib",
10
+ "pandas",
11
+ "seaborn",
12
+ "ssm-simulators>=0.12.2",
13
+ "streamlit>1.30.0",
14
+ ]
15
+
16
+ [dependency-groups]
17
+ dev = [
18
+ "pytest",
19
+ ]
20
+
21
+ [tool.pytest.ini_options]
22
+ testpaths = ["tests"]
requirements.txt DELETED
@@ -1,7 +0,0 @@
1
- altair
2
- pandas
3
- ssm-simulators>=0.12.2
4
- cloudpickle
5
- streamlit>1.30.0
6
- matplotlib
7
- seaborn
 
 
 
 
 
 
 
 
src/app.py CHANGED
@@ -170,11 +170,22 @@ def create_styling_selectors(model_num: int = 1):
170
  key=f"show_ndt_{model_num}_{st.session_state['styling_version']}"
171
  )
172
  styling_config["add_data_model_keep_starting_point"] = st.checkbox(
173
- "Show Starting Point",
174
  value=True,
175
  key=f"show_start_{model_num}_{st.session_state['styling_version']}"
176
  )
177
 
 
 
 
 
 
 
 
 
 
 
 
178
  # Axis Limits Section
179
  st.markdown("**Axis Limits**")
180
  styling_config["xlim_min"] = st.number_input(
@@ -288,14 +299,15 @@ def get_filtered_styling_config(styling_config, plot_type="plot_func_model"):
288
  elif plot_type == "plot_func_model_n":
289
  # plot_func_model_n only accepts a subset of parameters
290
  allowed_params = {
291
- 'linewidth_histogram', 'linewidth_model', 'bin_size',
292
- 'alpha', 'legend_fontsize', 'legend_location', 'legend_shadow',
293
  'add_legend', 'add_data_model_markersize_starting_point',
294
  'add_data_model_markertype_starting_point',
295
  'add_data_model_keep_starting_point',
296
  'add_data_model_keep_boundary',
297
- 'add_data_model_keep_slope',
298
- 'add_data_model_keep_ndt'
 
299
  }
300
  return {k: v for k, v in styling_config.items() if k in allowed_params}
301
 
 
170
  key=f"show_ndt_{model_num}_{st.session_state['styling_version']}"
171
  )
172
  styling_config["add_data_model_keep_starting_point"] = st.checkbox(
173
+ "Show Starting Point",
174
  value=True,
175
  key=f"show_start_{model_num}_{st.session_state['styling_version']}"
176
  )
177
 
178
+ styling_config["expected_random_params"] = st.checkbox(
179
+ "Cartoon at expected variability params",
180
+ value=True,
181
+ help=(
182
+ "When on, the noiseless model cartoon reflects the expected slope/NDT "
183
+ "instead of one random realization of across-trial variability "
184
+ "(e.g. sv, st in ddm_sdv, ddm_st, ddm_truncnormt, ddm_rayleight)."
185
+ ),
186
+ key=f"expected_random_params_{model_num}_{st.session_state['styling_version']}",
187
+ )
188
+
189
  # Axis Limits Section
190
  st.markdown("**Axis Limits**")
191
  styling_config["xlim_min"] = st.number_input(
 
299
  elif plot_type == "plot_func_model_n":
300
  # plot_func_model_n only accepts a subset of parameters
301
  allowed_params = {
302
+ 'linewidth_histogram', 'linewidth_model', 'bin_size',
303
+ 'alpha', 'legend_fontsize', 'legend_location', 'legend_shadow',
304
  'add_legend', 'add_data_model_markersize_starting_point',
305
  'add_data_model_markertype_starting_point',
306
  'add_data_model_keep_starting_point',
307
  'add_data_model_keep_boundary',
308
+ 'add_data_model_keep_slope',
309
+ 'add_data_model_keep_ndt',
310
+ 'expected_random_params',
311
  }
312
  return {k: v for k, v in styling_config.items() if k in allowed_params}
313
 
src/utils/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (276 Bytes)
 
src/utils/__pycache__/utils.cpython-311.pyc DELETED
Binary file (26 kB)
 
src/utils/utils.py CHANGED
@@ -8,6 +8,81 @@ _model_config = get_model_config()
8
  from matplotlib.lines import Line2D
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def plot_func_model(
12
  model_name,
13
  theta,
@@ -35,6 +110,7 @@ def plot_func_model(
35
  delta_t_model=0.001,
36
  random_state=None,
37
  add_legend=True, # keep_frame=False,
 
38
  **kwargs,
39
  ):
40
  """Calculate posterior predictive for a certain bottom node.
@@ -134,10 +210,16 @@ def plot_func_model(
134
  sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
135
  no_noise=False, delta_t=0.001,
136
  random_state=rand_int, smooth_unif=False)
 
137
 
138
- sim_out_no_noise = sim.simulate(theta=theta, n_samples=1,
 
139
  no_noise=True, delta_t=0.001,
140
  smooth_unif=False)
 
 
 
 
141
 
142
  # ADD DATA HISTOGRAMS
143
  weights_up = np.tile(
@@ -453,6 +535,7 @@ def plot_func_model_n(
453
  alpha=1,
454
  keep_frame=False,
455
  random_state=None,
 
456
  **kwargs,
457
  ):
458
  """Calculate posterior predictive for a certain bottom node.
@@ -565,10 +648,16 @@ def plot_func_model_n(
565
  sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
566
  no_noise=False, delta_t=0.001,
567
  random_state=rand_int, smooth_unif=False)
 
568
 
569
- sim_out_no_noise = sim.simulate(theta=theta, n_samples=1,
 
570
  no_noise=True, delta_t=0.001,
571
  smooth_unif=False)
 
 
 
 
572
 
573
  # ADD HISTOGRAMS
574
  # -------------------------------
 
8
  from matplotlib.lines import Line2D
9
 
10
 
11
+ # TODO(ssm-simulators): upstream this as Simulator.simulate(variability_at_expectation=True).
12
+ # For models with across-trial random parameters (e.g. ddm_sdv: v ~ N(v, sv)), `no_noise=True`
13
+ # only zeros Brownian noise — the variability draw still happens, so the cartoon shows one
14
+ # random realization of the variability term instead of the model's deterministic skeleton.
15
+ # These tables let us collapse those draws to their expected value for the cartoon-only path.
16
+ #
17
+ # Maps model_name -> {variability_param_name: collapsed_value}.
18
+ _VARIABILITY_COLLAPSE = {
19
+ "ddm_sdv": {"sv": 0.0}, # N(0, 0) -> 0 = E[N(0, sv)]
20
+ "ddm_st": {"st": 0.0}, # U(0, 0) -> 0 = E[U(-st, st)]
21
+ "ddm_truncnormt": {"st": 1e-9}, # TruncN(mt, ~0) -> mt ≈ E[TruncN(mt, st)]; eps avoids div-by-zero in the config lambda (a=-mt/st)
22
+ "ddm_rayleight": {"st": 0.0}, # Rayleigh(0) -> 0; metadata['t'] is then post-shifted to st·sqrt(pi/2)
23
+ }
24
+
25
+ # Models where collapsing variability does NOT yield the distribution's expectation.
26
+ # Shift `metadata['t']` after simulation by the expected NDT contribution.
27
+ _EXPECTED_T_POST_SHIFT = {
28
+ "ddm_rayleight": lambda theta_d: theta_d["st"] * np.sqrt(np.pi / 2.0),
29
+ }
30
+
31
+
32
+ def _collapse_theta_for_cartoon(model_name, theta):
33
+ """Return a copy of `theta` with random-variability params replaced for the cartoon-only path.
34
+
35
+ `theta` is the list-of-lists shape that `app.py` passes to `Simulator.simulate`: positional
36
+ in `_model_config[model_name]['params']` order. Models without an entry in the table are
37
+ returned unchanged.
38
+ """
39
+ overrides = _VARIABILITY_COLLAPSE.get(model_name)
40
+ if not overrides:
41
+ return theta
42
+ params = _model_config[model_name]["params"]
43
+ inner = list(theta[0])
44
+ for name, value in overrides.items():
45
+ inner[params.index(name)] = value
46
+ return [inner]
47
+
48
+
49
+ def _apply_expected_t_shift(model_name, theta_dict, sim_out):
50
+ """Mutate `sim_out['metadata']['t']` in place to reflect E[NDT] for models that need it.
51
+
52
+ Required for distributions whose param=0 collapse does not equal the expectation
53
+ (currently only ddm_rayleight: Rayleigh(0) = 0 but E[Rayleigh(scale=st)] = st·sqrt(pi/2)).
54
+ """
55
+ shift_fn = _EXPECTED_T_POST_SHIFT.get(model_name)
56
+ if shift_fn is None:
57
+ return
58
+ shift = shift_fn(theta_dict)
59
+ t_arr = np.asarray(sim_out["metadata"]["t"])
60
+ sim_out["metadata"]["t"] = (t_arr + shift).astype(t_arr.dtype, copy=False)
61
+
62
+
63
+ # TODO(ssm-simulators): expose t_samplewise (and v_samplewise, z_samplewise) in metadata
64
+ # so consumers don't have to back-derive per-sample NDT from RT and the trajectory.
65
+ def _patch_trajectory_t_with_actual_ndt(sim_out, delta_t=0.001):
66
+ """Override `sim_out['metadata']['t']` with the per-sample NDT actually used for this trajectory.
67
+
68
+ The simulator records the *input* `t` in metadata, not the per-sample draw from `t_dist`
69
+ (e.g. Uniform/TruncNormal/Rayleigh for ddm_st / ddm_truncnormt / ddm_rayleight).
70
+ Without this patch, every trajectory plotter starts at the input `t` and ignores NDT
71
+ variability. We back-derive the per-sample NDT from `RT - decision_time`, where
72
+ `decision_time = (last valid trajectory index) * delta_t`. Trajectory calls in this
73
+ module use `smooth_unif=False`, so `RT = decision_time + NDT` exactly.
74
+
75
+ No-op for models without NDT variability (actual_ndt == input t).
76
+ """
77
+ traj = np.asarray(sim_out["metadata"]["trajectory"]).flatten()
78
+ valid_idx = np.where(traj > -999)[0]
79
+ decision_time = float(valid_idx[-1] * delta_t) if len(valid_idx) else 0.0
80
+ rt = float(np.asarray(sim_out["rts"]).flat[0])
81
+ actual_ndt = max(0.0, rt - decision_time)
82
+ t_arr = np.asarray(sim_out["metadata"]["t"])
83
+ sim_out["metadata"]["t"] = np.full_like(t_arr, actual_ndt, dtype=t_arr.dtype)
84
+
85
+
86
  def plot_func_model(
87
  model_name,
88
  theta,
 
110
  delta_t_model=0.001,
111
  random_state=None,
112
  add_legend=True, # keep_frame=False,
113
+ expected_random_params=True,
114
  **kwargs,
115
  ):
116
  """Calculate posterior predictive for a certain bottom node.
 
210
  sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
211
  no_noise=False, delta_t=0.001,
212
  random_state=rand_int, smooth_unif=False)
213
+ _patch_trajectory_t_with_actual_ndt(sim_out_traj[i], delta_t=0.001)
214
 
215
+ theta_cartoon = _collapse_theta_for_cartoon(model_name, theta) if expected_random_params else theta
216
+ sim_out_no_noise = sim.simulate(theta=theta_cartoon, n_samples=1,
217
  no_noise=True, delta_t=0.001,
218
  smooth_unif=False)
219
+ if expected_random_params:
220
+ params = _model_config[model_name]["params"]
221
+ theta_dict = dict(zip(params, theta[0]))
222
+ _apply_expected_t_shift(model_name, theta_dict, sim_out_no_noise)
223
 
224
  # ADD DATA HISTOGRAMS
225
  weights_up = np.tile(
 
535
  alpha=1,
536
  keep_frame=False,
537
  random_state=None,
538
+ expected_random_params=True,
539
  **kwargs,
540
  ):
541
  """Calculate posterior predictive for a certain bottom node.
 
648
  sim_out_traj[i] = sim.simulate(theta=theta, n_samples=1,
649
  no_noise=False, delta_t=0.001,
650
  random_state=rand_int, smooth_unif=False)
651
+ _patch_trajectory_t_with_actual_ndt(sim_out_traj[i], delta_t=0.001)
652
 
653
+ theta_cartoon = _collapse_theta_for_cartoon(model_name, theta) if expected_random_params else theta
654
+ sim_out_no_noise = sim.simulate(theta=theta_cartoon, n_samples=1,
655
  no_noise=True, delta_t=0.001,
656
  smooth_unif=False)
657
+ if expected_random_params:
658
+ params = _model_config[model_name]["params"]
659
+ theta_dict = dict(zip(params, theta[0]))
660
+ _apply_expected_t_shift(model_name, theta_dict, sim_out_no_noise)
661
 
662
  # ADD HISTOGRAMS
663
  # -------------------------------
tests/conftest.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src"))
tests/test_variability_collapse.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Regression tests for the cartoon-variability collapse helpers in utils/utils.py.
2
+
3
+ These guard against the misleading behavior where `no_noise=True` cartoon plots in
4
+ ssms-gui rendered one random realization of across-trial variability (e.g. a slope
5
+ of `v + epsilon` for ddm_sdv) instead of the model's deterministic skeleton.
6
+ """
7
+
8
+ import numpy as np
9
+ import pytest
10
+
11
+ from ssms.basic_simulators import Simulator
12
+
13
+ from utils.utils import (
14
+ _VARIABILITY_COLLAPSE,
15
+ _EXPECTED_T_POST_SHIFT,
16
+ _apply_expected_t_shift,
17
+ _collapse_theta_for_cartoon,
18
+ _patch_trajectory_t_with_actual_ndt,
19
+ )
20
+
21
+
22
+ # (model_name, theta_inner, sv_or_st_index, collapsed_value)
23
+ COLLAPSE_CASES = [
24
+ ("ddm_sdv", [0.5, 1.0, 0.5, 0.3, 1.5], 4, 0.0),
25
+ ("ddm_st", [0.5, 1.0, 0.5, 0.3, 0.2], 4, 0.0),
26
+ ("ddm_truncnormt", [0.5, 1.0, 0.5, 0.4, 0.1], 4, 1e-9),
27
+ ("ddm_rayleight", [0.5, 1.0, 0.5, 0.3], 3, 0.0),
28
+ ]
29
+
30
+
31
+ @pytest.mark.parametrize("model_name,theta_inner,idx,collapsed", COLLAPSE_CASES)
32
+ def test_collapse_replaces_only_variability_slot(model_name, theta_inner, idx, collapsed):
33
+ out = _collapse_theta_for_cartoon(model_name, [list(theta_inner)])
34
+ assert len(out) == 1
35
+ assert len(out[0]) == len(theta_inner)
36
+ for i, val in enumerate(theta_inner):
37
+ if i == idx:
38
+ assert out[0][i] == collapsed
39
+ else:
40
+ assert out[0][i] == val
41
+
42
+
43
+ def test_collapse_passthrough_for_unmapped_model():
44
+ theta = [[0.5, 1.0, 0.5, 0.3]]
45
+ assert _collapse_theta_for_cartoon("ddm", theta) is theta
46
+
47
+
48
+ def test_collapse_does_not_mutate_input():
49
+ theta = [[0.5, 1.0, 0.5, 0.3, 1.5]]
50
+ original = [list(theta[0])]
51
+ _collapse_theta_for_cartoon("ddm_sdv", theta)
52
+ assert theta[0] == original[0]
53
+
54
+
55
+ def test_apply_expected_t_shift_rayleigh():
56
+ st_val = 0.3
57
+ sim_out = {"metadata": {"t": np.array([0.0], dtype=np.float32)}}
58
+ _apply_expected_t_shift("ddm_rayleight", {"st": st_val}, sim_out)
59
+ expected = st_val * np.sqrt(np.pi / 2.0)
60
+ assert sim_out["metadata"]["t"][0] == pytest.approx(expected, rel=1e-5)
61
+
62
+
63
+ def test_apply_expected_t_shift_noop_for_unmapped_model():
64
+ sim_out = {"metadata": {"t": np.array([0.123], dtype=np.float32)}}
65
+ _apply_expected_t_shift("ddm_sdv", {"sv": 1.0}, sim_out)
66
+ assert sim_out["metadata"]["t"][0] == pytest.approx(0.123)
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # End-to-end: simulate with collapsed theta and `no_noise=True`. The trajectory's
71
+ # slope must equal `v` (no spurious tilt from a residual variability draw).
72
+ # Smoke test: with collapsed `sv`/`st`, a `no_noise=True` simulation should be
73
+ # fully deterministic up to integer truncation. We sanity-check the slope of
74
+ # `metadata['trajectory']` against `v`.
75
+ # ---------------------------------------------------------------------------
76
+
77
+ E2E_CASES = [
78
+ ("ddm_sdv", {"v": 0.8, "a": 1.2, "z": 0.5, "t": 0.1, "sv": 1.5}),
79
+ ("ddm_st", {"v": 0.8, "a": 1.2, "z": 0.5, "t": 0.1, "st": 0.3}),
80
+ ("ddm_truncnormt", {"v": 0.8, "a": 1.2, "z": 0.5, "mt": 0.4, "st": 0.1}),
81
+ ("ddm_rayleight", {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3}),
82
+ ]
83
+
84
+
85
+ def _theta_list(model_name, theta_dict):
86
+ from ssms.config import get_model_config
87
+ params = get_model_config()[model_name]["params"]
88
+ return [[theta_dict[p] for p in params]]
89
+
90
+
91
+ @pytest.mark.parametrize("model_name,theta_dict", E2E_CASES)
92
+ def test_collapsed_cartoon_slope_matches_v(model_name, theta_dict):
93
+ theta = _theta_list(model_name, theta_dict)
94
+ theta_cartoon = _collapse_theta_for_cartoon(model_name, theta)
95
+
96
+ sim = Simulator(model=model_name)
97
+ out = sim.simulate(
98
+ theta=theta_cartoon, n_samples=1,
99
+ no_noise=True, delta_t=0.001, smooth_unif=False,
100
+ random_state=42,
101
+ )
102
+
103
+ traj = np.asarray(out["metadata"]["trajectory"]).flatten()
104
+ valid = traj[traj > -999]
105
+ # Need at least a handful of points to fit a line; constant-boundary DDMs reach this easily.
106
+ assert len(valid) >= 10, f"too few trajectory points for {model_name}: {len(valid)}"
107
+
108
+ # Drop the last point (boundary crossing — slightly past the bound), fit a line.
109
+ delta_t = 0.001
110
+ t_axis = np.arange(len(valid)) * delta_t
111
+ slope, _ = np.polyfit(t_axis[:-1], valid[:-1], 1)
112
+ assert slope == pytest.approx(theta_dict["v"], abs=0.05), (
113
+ f"{model_name}: cartoon slope {slope:.4f} != v {theta_dict['v']}"
114
+ )
115
+
116
+
117
+ def test_rayleight_metadata_t_post_shift_e2e():
118
+ theta_dict = {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3}
119
+ theta = _theta_list("ddm_rayleight", theta_dict)
120
+ theta_cartoon = _collapse_theta_for_cartoon("ddm_rayleight", theta)
121
+
122
+ sim = Simulator(model="ddm_rayleight")
123
+ out = sim.simulate(
124
+ theta=theta_cartoon, n_samples=1,
125
+ no_noise=True, delta_t=0.001, smooth_unif=False,
126
+ random_state=42,
127
+ )
128
+ _apply_expected_t_shift("ddm_rayleight", theta_dict, out)
129
+
130
+ expected = theta_dict["st"] * np.sqrt(np.pi / 2.0)
131
+ assert out["metadata"]["t"].flat[0] == pytest.approx(expected, rel=1e-4)
132
+
133
+
134
+ def test_collapse_table_keys_match_post_shift_keys_subset():
135
+ # Every model in the post-shift table must also be in the collapse table.
136
+ assert set(_EXPECTED_T_POST_SHIFT.keys()).issubset(_VARIABILITY_COLLAPSE.keys())
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Trajectory NDT patch: metadata['t'] is the *input* t, not the per-sample NDT
141
+ # actually drawn. The patch back-derives the actual NDT from rt - decision_time
142
+ # so the trajectory plotter starts each line at its own NDT.
143
+ # ---------------------------------------------------------------------------
144
+
145
+
146
+ def test_patch_trajectory_t_recovers_known_ndt():
147
+ delta_t = 0.001
148
+ # Trajectory: 50 valid steps then -999. decision_time = 50 * delta_t = 0.05s.
149
+ traj = np.full(200, -999.0, dtype=np.float32)
150
+ traj[:51] = np.linspace(0.0, 1.0, 51) # indices 0..50 valid
151
+ rt = 0.05 + 0.42 # decision_time + arbitrary NDT
152
+ sim_out = {
153
+ "rts": np.array([[[rt]]], dtype=np.float32),
154
+ "metadata": {
155
+ "t": np.array([0.0], dtype=np.float32),
156
+ "trajectory": traj,
157
+ },
158
+ }
159
+ _patch_trajectory_t_with_actual_ndt(sim_out, delta_t=delta_t)
160
+ assert sim_out["metadata"]["t"][0] == pytest.approx(0.42, abs=1e-5)
161
+
162
+
163
+ def test_patch_trajectory_t_e2e_rayleigh_varies_across_calls():
164
+ """Across n trajectory calls with different seeds, the patched NDTs must vary
165
+ (Rayleigh-distributed) — not all be identical to the input t=0."""
166
+ sim = Simulator(model="ddm_rayleight")
167
+ theta = _theta_list("ddm_rayleight", {"v": 0.8, "a": 1.2, "z": 0.5, "st": 0.3})
168
+
169
+ ndts = []
170
+ for seed in range(20):
171
+ out = sim.simulate(
172
+ theta=theta, n_samples=1,
173
+ no_noise=False, delta_t=0.001,
174
+ random_state=seed, smooth_unif=False,
175
+ )
176
+ _patch_trajectory_t_with_actual_ndt(out, delta_t=0.001)
177
+ ndts.append(float(out["metadata"]["t"].flat[0]))
178
+
179
+ ndts = np.array(ndts)
180
+ assert (ndts > 0).all(), "all patched NDTs should be strictly positive for Rayleigh"
181
+ assert ndts.std() > 0.05, f"patched NDTs should vary across seeds, got std={ndts.std():.4f}"
182
+ # Mean should be in the right ballpark of E[Rayleigh(0.3)] = 0.3 * sqrt(pi/2) ≈ 0.376.
183
+ # 20 samples is noisy, so a wide tolerance.
184
+ assert 0.15 < ndts.mean() < 0.65, f"unexpected mean NDT {ndts.mean():.4f}"
185
+
186
+
187
+ def test_patch_trajectory_t_noop_for_constant_ndt_model():
188
+ """For plain ddm (no NDT variability), patched NDT should match input t."""
189
+ sim = Simulator(model="ddm")
190
+ input_t = 0.3
191
+ theta = _theta_list("ddm", {"v": 0.8, "a": 1.2, "z": 0.5, "t": input_t})
192
+
193
+ out = sim.simulate(
194
+ theta=theta, n_samples=1,
195
+ no_noise=False, delta_t=0.001,
196
+ random_state=42, smooth_unif=False,
197
+ )
198
+ _patch_trajectory_t_with_actual_ndt(out, delta_t=0.001)
199
+ # Allow a single delta_t of slop from integer truncation of decision time.
200
+ assert out["metadata"]["t"][0] == pytest.approx(input_t, abs=2e-3)
uv.lock ADDED
The diff for this file is too large to render. See raw diff