Alexander commited on
Commit
500dcd7
·
1 Parent(s): 36bf2f6

bunch of visual improvements and first round of exposing styling options to users

Browse files
src/app.py CHANGED
@@ -9,9 +9,9 @@ from ssms.basic_simulators.simulator import simulator
9
  import pandas as pd
10
  import utils
11
 
12
-
13
  # Function to create input select widgets
14
  def create_param_selectors(model_name: str, model_num: int = 1):
 
15
  d_config = model_config[model_name]
16
  params = d_config["params"]
17
  param_bounds_low = d_config["param_bounds"][0]
@@ -35,67 +35,379 @@ def create_param_selectors(model_name: str, model_num: int = 1):
35
  key=f"param{i}"
36
  f"_{model_name}"
37
  f"_{model_num}"
38
- f'_{st.session_state["slider_version"]}',
39
  )
40
  return d_param_slider
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def add_model():
44
  pass
45
 
 
 
 
 
 
 
46
 
47
- def reset_sliders():
48
- st.session_state["slider_version"] += 1
 
49
 
 
 
 
 
 
50
 
51
  # Initialize a slider version attribute state. Is used for resetting values
52
- if "slider_version" not in st.session_state:
53
- st.session_state["slider_version"] = 1
54
 
55
 
56
  def draw_model_configurator(model_num=1):
57
  # Create widgets for the sidebar
58
- # st.markdown("<h2 style='text-align: center; color: black;'>Model Configurator</h1>",
59
- # unsafe_allow_html=True)
60
- # Dropdown selection of model name
61
- model_select = st.selectbox(
62
- "Model " + str(model_num), l_model_names, key="modelname" + str(model_num)
63
- )
64
- # Sliders for param values
65
- d_slider = create_param_selectors(model_select, model_num=model_num)
66
  # Number of data points to simulate
67
  nsamples = st.number_input("NSamples", value=5000, key="size" + str(model_num))
 
68
  # Number of trajectories to show
69
  ntrajectories = st.number_input(
70
- "NTrajectories", value=0, key="ntraj" + str(model_num)
71
  )
72
- return model_select, d_slider, nsamples, ntrajectories
73
 
 
 
 
 
74
 
75
  st.set_page_config(layout="wide")
76
 
77
  # Get list of model names
78
  l_model_names = list(model_config.keys())
79
 
 
 
 
 
 
 
 
 
80
  with st.sidebar:
81
- col1, col2 = st.columns(2)
82
- with col1:
83
- model_select_1, d_slider_1, nsamples_1, ntrajectories_1 = (
84
- draw_model_configurator(model_num=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  )
86
- with col2:
87
- model_select_2, d_slider_2, nsamples_2, ntrajectories_2 = (
88
- draw_model_configurator(model_num=2)
 
 
 
 
89
  )
90
 
91
- # Button to reset sliders to default values
92
- randomseed = st.number_input("RandomSeed", value=0, key="seed")
93
- st.button(
94
- "Reset",
95
- help="Reset parameters to defaults",
96
- key="reset",
97
- on_click=reset_sliders,
98
- )
99
 
100
  # st.title("SSM Model Plots", )
101
  st.markdown(
@@ -106,70 +418,67 @@ st.markdown(
106
  # Display components for main panel
107
  fig1, ax1 = plt.subplots()
108
  if model_config[model_select_1]["nchoices"] == 2 and not ("race" in model_select_1):
 
 
 
109
  ax1 = utils.utils.plot_func_model(
110
  model_name=model_select_1,
111
  theta=[list(d_slider_1.values())],
112
  axis=ax1,
113
- value_range=[-0.1, 5],
114
  n_samples=nsamples_1,
115
- ylim=5,
116
- data_color="blue",
117
- add_trajectories=True,
118
  n_trajectories=ntrajectories_1,
119
- linewidth_model=1,
120
- linewidth_histogram=1,
121
- random_state=randomseed,
122
  )
123
  else:
 
 
124
  ax1 = utils.utils.plot_func_model_n(
125
  model_name=model_select_1,
126
  theta=[list(d_slider_1.values())],
127
  axis=ax1,
128
- value_range=[-0.1, 5],
129
  n_samples=nsamples_1,
130
- data_color="blue",
131
- add_trajectories=True,
132
  n_trajectories=ntrajectories_1,
133
- linewidth_model=1,
134
- linewidth_histogram=1,
135
- random_state=randomseed,
136
  )
137
- ax1.set_title("Model 1")
138
- ax1.set_xlabel("RT in seconds")
139
 
140
  fig2, ax2 = plt.subplots()
141
  if model_config[model_select_2]["nchoices"] == 2 and not ("race" in model_select_2):
 
 
142
  ax2 = utils.utils.plot_func_model(
143
  model_name=model_select_2,
144
  theta=[list(d_slider_2.values())],
145
  axis=ax2,
146
- value_range=[-0.1, 5],
147
  n_samples=nsamples_2,
148
- ylim=5,
149
- data_color="red",
150
- add_trajectories=True,
151
  n_trajectories=ntrajectories_2,
152
- linewidth_model=1,
153
- linewidth_histogram=1,
154
- random_state=randomseed,
155
  )
156
  else:
 
 
157
  ax2 = utils.utils.plot_func_model_n(
158
  model_name=model_select_2,
159
  theta=[list(d_slider_2.values())],
160
  axis=ax2,
161
- value_range=[-0.1, 5],
162
  n_samples=nsamples_2,
163
- data_color="red",
164
- add_trajectories=True,
165
  n_trajectories=ntrajectories_2,
166
- linewidth_model=1,
167
- linewidth_histogram=1,
168
- random_state=randomseed,
169
  )
170
 
171
- ax2.set_title("Model 2")
172
- ax2.set_xlabel("RT in seconds")
173
 
174
  # Place figure in placeholder
175
  col1, col2 = st.columns(2)
@@ -185,29 +494,32 @@ sim_output_1 = simulator(
185
  model=model_select_1,
186
  theta=[list(d_slider_1.values())],
187
  n_samples=nsamples_1,
188
- random_state=randomseed,
189
  )
190
  sim_output_2 = simulator(
191
  model=model_select_2,
192
  theta=[list(d_slider_2.values())],
193
  n_samples=nsamples_2,
194
- random_state=randomseed,
195
  )
196
 
197
  # Make metadata dataframe
 
 
 
198
  metadata = pd.DataFrame(
199
  {
200
  "Model 1": [
201
- sim_output_1["metadata"]["model"],
202
- sim_output_1["choice_p"][0, 0],
203
- sim_output_1["rts"].mean(),
204
- sim_output_1["metadata"]["s"],
205
  ],
206
  "Model 2": [
207
- sim_output_2["metadata"]["model"],
208
- sim_output_2["choice_p"][0, 0],
209
- sim_output_2["rts"].mean(),
210
- sim_output_2["metadata"]["s"],
211
  ],
212
  },
213
  index=["Model", "Choice Probability", "Mean RT", "Noise SD"],
@@ -230,9 +542,9 @@ with col3:
230
  histtype="step",
231
  bins=50,
232
  density=True,
233
- color="blue",
234
  fill=None,
235
- label="Model 1",
236
  )
237
  ax3.hist(
238
  sim_output_2["rts"][np.abs(sim_output_2["rts"]) != 999]
@@ -240,12 +552,12 @@ with col3:
240
  histtype="step",
241
  bins=50,
242
  density=True,
243
- color="red",
244
  fill=None,
245
- label="Model 2",
246
  )
247
  ax3.legend()
248
- ax3.set_xlabel("RT")
249
  ax3.set_xlim(-5, 5)
250
  figure_placeholder_3.pyplot(fig3)
251
  else:
@@ -253,4 +565,4 @@ with col3:
253
  # for models with more than 2 choice options
254
  pass
255
  with col4:
256
- st.dataframe(metadata)
 
9
  import pandas as pd
10
  import utils
11
 
 
12
  # Function to create input select widgets
13
  def create_param_selectors(model_name: str, model_num: int = 1):
14
+
15
  d_config = model_config[model_name]
16
  params = d_config["params"]
17
  param_bounds_low = d_config["param_bounds"][0]
 
35
  key=f"param{i}"
36
  f"_{model_name}"
37
  f"_{model_num}"
38
+ f'_{st.session_state["param_version"]}',
39
  )
40
  return d_param_slider
41
 
42
 
43
+ def create_styling_selectors(model_num: int = 1):
44
+ """
45
+ Create styling configuration widgets for plot customization.
46
+
47
+ This function creates Streamlit widgets that allow users to customize
48
+ various visual aspects of the plots including colors, line widths,
49
+ alpha, and which model components to display.
50
+
51
+ Note: This version is designed to work in the sidebar without using st.columns()
52
+
53
+ Args:
54
+ model_num: Integer identifier for the model (1 or 2)
55
+
56
+ Returns:
57
+ dict: Dictionary containing all styling parameters with their user-selected values
58
+ """
59
+
60
+ # Color options for different plot elements
61
+ color_options = ["blue", "red", "green", "orange", "purple", "black", "gray", "brown"]
62
+
63
+ # Legend location options (matplotlib standard locations)
64
+ legend_locations = ["upper right", "upper left", "lower left", "lower right",
65
+ "center", "upper center", "lower center", "center left", "center right"]
66
+
67
+ # Marker type options for trajectories
68
+ marker_options = { "Diamond": "D", "Square": "s", "Line": 0, "Circle": "o", "Star": "*", "Triangle": "^",
69
+ "Plus": "+", "X": "x"}
70
+
71
+ styling_config = {}
72
+
73
+ # Create an expander for styling options to keep the interface clean
74
+ with st.expander(f"🎨 Styling", expanded=False):
75
+
76
+ # Color Settings Section
77
+ st.markdown("**Colors**")
78
+ styling_config["data_color"] = st.selectbox(
79
+ "Data Color",
80
+ color_options,
81
+ index=color_options.index("blue" if model_num == 1 else "red"),
82
+ key=f"data_color_{model_num}_{st.session_state['styling_version']}"
83
+ )
84
+
85
+ styling_config["posterior_uncertainty_color"] = st.selectbox(
86
+ "Model Color",
87
+ color_options,
88
+ index=color_options.index("black"),
89
+ key=f"model_color_{model_num}_{st.session_state['styling_version']}"
90
+ )
91
+
92
+ # Line Width Settings Section
93
+ st.markdown("**Lines**")
94
+ styling_config["linewidth_histogram"] = st.slider(
95
+ "Histogram Line Width",
96
+ min_value=0.1,
97
+ max_value=3.0,
98
+ value=1.0,
99
+ step=0.1,
100
+ key=f"hist_lw_{model_num}_{st.session_state['styling_version']}"
101
+ )
102
+
103
+ styling_config["linewidth_model"] = st.slider(
104
+ "Model Line Width",
105
+ min_value=0.1,
106
+ max_value=3.0,
107
+ value=1.0,
108
+ step=0.1,
109
+ key=f"model_lw_{model_num}_{st.session_state['styling_version']}"
110
+ )
111
+
112
+ # Histogram Settings Section
113
+ st.markdown("**Histograms**")
114
+ styling_config["bin_size"] = st.slider(
115
+ "Bin Size",
116
+ min_value=0.01,
117
+ max_value=0.2,
118
+ value=0.05,
119
+ step=0.01,
120
+ key=f"bin_size_{model_num}_{st.session_state['styling_version']}"
121
+ )
122
+
123
+ styling_config["alpha"] = st.slider(
124
+ "alpha",
125
+ min_value=0.0,
126
+ max_value=1.0,
127
+ value=1.0,
128
+ step=0.05,
129
+ key=f"alpha_{model_num}_{st.session_state['styling_version']}"
130
+ )
131
+
132
+ # Model Components Section - Toggle which parts of the model to show
133
+ st.markdown("**Model Components**")
134
+ styling_config["add_data_model_keep_boundary"] = st.checkbox(
135
+ "Show Boundaries",
136
+ value=True,
137
+ key=f"show_boundary_{model_num}_{st.session_state['styling_version']}"
138
+ )
139
+ styling_config["add_data_model_keep_slope"] = st.checkbox(
140
+ "Show Slope/Trajectory",
141
+ value=True,
142
+ key=f"show_slope_{model_num}_{st.session_state['styling_version']}"
143
+ )
144
+
145
+ styling_config["add_data_model_keep_ndt"] = st.checkbox(
146
+ "Show Non-Decision Time",
147
+ value=True,
148
+ key=f"show_ndt_{model_num}_{st.session_state['styling_version']}"
149
+ )
150
+ styling_config["add_data_model_keep_starting_point"] = st.checkbox(
151
+ "Show Starting Point",
152
+ value=True,
153
+ key=f"show_start_{model_num}_{st.session_state['styling_version']}"
154
+ )
155
+
156
+ # Axis Limits Section
157
+ st.markdown("**Axis Limits**")
158
+ styling_config["xlim_min"] = st.number_input(
159
+ "x-axis min",
160
+ value=-0.1,
161
+ step=0.1,
162
+ key=f"xlim_min_{model_num}_{st.session_state['styling_version']}"
163
+ )
164
+ styling_config["xlim_max"] = st.number_input(
165
+ "x-axis max",
166
+ value=5.0,
167
+ step=0.1,
168
+ key=f"xlim_max_{model_num}_{st.session_state['styling_version']}"
169
+ )
170
+ styling_config["ylim_max"] = st.number_input(
171
+ "y-axis max",
172
+ value=3.75,
173
+ step=0.25,
174
+ key=f"ylim_max_{model_num}_{st.session_state['styling_version']}"
175
+ )
176
+
177
+ # Starting Point Marker Settings (only show if starting point is enabled)
178
+ if styling_config["add_data_model_keep_starting_point"]:
179
+ st.markdown("**Starting Point**")
180
+ styling_config["add_data_model_markersize_starting_point"] = st.slider(
181
+ "Marker Size",
182
+ min_value=10,
183
+ max_value=100,
184
+ value=35,
185
+ key=f"marker_size_{model_num}_{st.session_state['styling_version']}"
186
+ )
187
+
188
+ styling_config["add_data_model_markertype_starting_point"] = st.selectbox(
189
+ "Marker Type",
190
+ list(marker_options.keys()),
191
+ index=0, # defaulting to first entry in marker options dictionary
192
+ key=f"marker_type_{model_num}_{st.session_state['styling_version']}"
193
+ )
194
+ else:
195
+ # Set defaults when starting point is not shown
196
+ styling_config["add_data_model_markersize_starting_point"] = 35
197
+ styling_config["add_data_model_markertype_starting_point"] = "Diamond"
198
+
199
+ # AF-TODO: Legend settings weren't working yet
200
+
201
+ # # Legend Settings Section
202
+ # st.markdown("**Legend Settings**")
203
+ # styling_config["add_legend"] = st.checkbox(
204
+ # "Show Legend",
205
+ # value=True,
206
+ # key=f"show_legend_{model_num}_{st.session_state['styling_version']}"
207
+ # )
208
+
209
+ # if styling_config["add_legend"]:
210
+ # styling_config["legend_fontsize"] = st.slider(
211
+ # "Legend Font Size",
212
+ # min_value=6,
213
+ # max_value=20,
214
+ # value=12,
215
+ # key=f"legend_font_{model_num}_{st.session_state['styling_version']}"
216
+ # )
217
+
218
+ # styling_config["legend_location"] = st.selectbox(
219
+ # "Legend Location",
220
+ # legend_locations,
221
+ # index=0, # "upper right"
222
+ # key=f"legend_loc_{model_num}_{st.session_state['styling_version']}"
223
+ # )
224
+
225
+ # styling_config["legend_shadow"] = st.checkbox(
226
+ # "Legend Shadow",
227
+ # value=True,
228
+ # key=f"legend_shadow_{model_num}_{st.session_state['styling_version']}"
229
+ # )
230
+ # else:
231
+ # # Set defaults when legend is not shown
232
+ # styling_config["legend_fontsize"] = 12
233
+ # styling_config["legend_location"] = "upper right"
234
+ # styling_config["legend_shadow"] = True
235
+
236
+ # Convert marker type from display name to matplotlib code
237
+ marker_type_map = {k: v for k, v in marker_options.items()}
238
+ styling_config["add_data_model_markertype_starting_point"] = marker_type_map.get(
239
+ styling_config["add_data_model_markertype_starting_point"], "D"
240
+ )
241
+
242
+ return styling_config
243
+
244
+
245
+ def get_filtered_styling_config(styling_config, plot_type="plot_func_model"):
246
+ """
247
+ Filter styling configuration based on plot type compatibility.
248
+
249
+ Different plotting functions accept different parameters, so this function
250
+ filters the styling configuration to only include parameters that are
251
+ relevant for the specific plot type.
252
+
253
+ Args:
254
+ styling_config: Dictionary of styling parameters
255
+ plot_type: String indicating which plot function will be used
256
+ ("plot_func_model" or "plot_func_model_n")
257
+
258
+ Returns:
259
+ dict: Filtered styling configuration appropriate for the plot type
260
+ """
261
+
262
+ if plot_type == "plot_func_model":
263
+ # plot_func_model accepts all styling parameters
264
+ return styling_config
265
+
266
+ elif plot_type == "plot_func_model_n":
267
+ # plot_func_model_n only accepts a subset of parameters
268
+ allowed_params = {
269
+ 'linewidth_histogram', 'linewidth_model', 'bin_size',
270
+ 'alpha', 'legend_fontsize', 'legend_location', 'legend_shadow',
271
+ 'add_legend', 'add_data_model_markersize_starting_point',
272
+ 'add_data_model_markertype_starting_point',
273
+ 'add_data_model_keep_starting_point',
274
+ 'add_data_model_keep_boundary',
275
+ 'add_data_model_keep_slope',
276
+ 'add_data_model_keep_ndt'
277
+ }
278
+ return {k: v for k, v in styling_config.items() if k in allowed_params}
279
+
280
+ else:
281
+ # Default: return all parameters
282
+ return styling_config
283
+
284
  def add_model():
285
  pass
286
 
287
+ # def reset_sliders():
288
+ # st.session_state["slider_version"] += 1
289
+
290
+ def reset_parameters():
291
+ """Reset only model parameters to defaults"""
292
+ st.session_state["param_version"] += 1
293
 
294
+ def reset_styling():
295
+ """Reset only styling options to defaults"""
296
+ st.session_state["styling_version"] += 1
297
 
298
+ def reset_all():
299
+ """Reset both parameters and styling to defaults"""
300
+ st.session_state["param_version"] += 1
301
+ st.session_state["styling_version"] += 1
302
+ st.session_state["slider_version"] += 1 # Keep for any remaining widgets
303
 
304
  # Initialize a slider version attribute state. Is used for resetting values
305
+ # if "slider_version" not in st.session_state:
306
+ # st.session_state["slider_version"] = 1
307
 
308
 
309
  def draw_model_configurator(model_num=1):
310
  # Create widgets for the sidebar
311
+
312
+ # 1. Dropdown selection of model name
313
+ model_select = st.selectbox("Model " + str(model_num), l_model_names, key="model_selector_" + str(model_num))
314
+
315
+ return model_select
316
+
317
+ def draw_simulation_settings(model_num=1):
 
318
  # Number of data points to simulate
319
  nsamples = st.number_input("NSamples", value=5000, key="size" + str(model_num))
320
+
321
  # Number of trajectories to show
322
  ntrajectories = st.number_input(
323
+ "NTrajectories", value=5, key="ntraj" + str(model_num)
324
  )
 
325
 
326
+ # Random seed setting
327
+ randomseed = st.number_input("RandomSeed", value=41 + model_num, key="seed_" + str(model_num))
328
+
329
+ return nsamples, ntrajectories, randomseed
330
 
331
  st.set_page_config(layout="wide")
332
 
333
  # Get list of model names
334
  l_model_names = list(model_config.keys())
335
 
336
+ # Initialize separate version attributes for parameters and styling
337
+ if "param_version" not in st.session_state:
338
+ st.session_state["param_version"] = 1
339
+
340
+ if "styling_version" not in st.session_state:
341
+ st.session_state["styling_version"] = 1
342
+
343
+
344
  with st.sidebar:
345
+ st.empty()
346
+ st.markdown("**Model Selection**")
347
+ with st.container():
348
+ st.markdown('<div style="margin-top: -1rem;">', unsafe_allow_html=True)
349
+
350
+ col1, col2 = st.columns(2)
351
+ with col1:
352
+ model_select_1 = draw_model_configurator(model_num=1)
353
+
354
+ # Styling configuration
355
+ styling_config_1 = create_styling_selectors(model_num=1)
356
+
357
+ with col2:
358
+ model_select_2 = draw_model_configurator(model_num=2)
359
+
360
+ # Styling configuration
361
+ styling_config_2 = create_styling_selectors(model_num=2)
362
+
363
+ st.markdown('</div>', unsafe_allow_html=True)
364
+
365
+ st.markdown("---")
366
+ st.markdown("**Parameters**")
367
+ col1_2, col2_2 = st.columns(2)
368
+ with col1_2:
369
+ d_slider_1 = create_param_selectors(model_select_1, model_num=1)
370
+ with col2_2:
371
+ d_slider_2 = create_param_selectors(model_select_2, model_num=2)
372
+
373
+ st.markdown("---")
374
+ st.markdown("**Simulation Settings**")
375
+ col1_3, col2_3 = st.columns(2)
376
+ with col1_3:
377
+ nsamples_1, ntrajectories_1, randomseed_1 = draw_simulation_settings(model_num=1)
378
+ with col2_3:
379
+ nsamples_2, ntrajectories_2, randomseed_2 = draw_simulation_settings(model_num=2)
380
+
381
+ # Button to reset sliders to default values
382
+ st.markdown("---")
383
+ st.markdown("**Reset Options**")
384
+
385
+ # Create three columns for the reset buttons
386
+ reset_col1, reset_col2, reset_col3 = st.columns(3)
387
+
388
+ with reset_col1:
389
+ st.button(
390
+ "Reset Params",
391
+ help="Reset model parameters to defaults",
392
+ key="reset_params",
393
+ on_click=reset_parameters,
394
  )
395
+
396
+ with reset_col2:
397
+ st.button(
398
+ "Reset Styling",
399
+ help="Reset styling options to defaults",
400
+ key="reset_styling",
401
+ on_click=reset_styling,
402
  )
403
 
404
+ with reset_col3:
405
+ st.button(
406
+ "Reset Full",
407
+ help="Reset both parameters and styling to defaults",
408
+ key="reset_all",
409
+ on_click=reset_all,
410
+ )
 
411
 
412
  # st.title("SSM Model Plots", )
413
  st.markdown(
 
418
  # Display components for main panel
419
  fig1, ax1 = plt.subplots()
420
  if model_config[model_select_1]["nchoices"] == 2 and not ("race" in model_select_1):
421
+ # Use filtered styling parameters for plot_func_model
422
+
423
+ filtered_styling_1 = get_filtered_styling_config(styling_config_1, "plot_func_model")
424
  ax1 = utils.utils.plot_func_model(
425
  model_name=model_select_1,
426
  theta=[list(d_slider_1.values())],
427
  axis=ax1,
428
+ value_range=[styling_config_1["xlim_min"], styling_config_1["xlim_max"]],
429
  n_samples=nsamples_1,
430
+ ylim=styling_config_1["ylim_max"],
 
 
431
  n_trajectories=ntrajectories_1,
432
+ random_state=randomseed_1,
433
+ **filtered_styling_1
 
434
  )
435
  else:
436
+ # Use filtered styling parameters for plot_func_model_n
437
+ filtered_styling_1 = get_filtered_styling_config(styling_config_1, "plot_func_model_n")
438
  ax1 = utils.utils.plot_func_model_n(
439
  model_name=model_select_1,
440
  theta=[list(d_slider_1.values())],
441
  axis=ax1,
442
+ value_range=[styling_config_1["xlim_min"], styling_config_1["xlim_max"]],
443
  n_samples=nsamples_1,
 
 
444
  n_trajectories=ntrajectories_1,
445
+ random_state=randomseed_1,
446
+ **filtered_styling_1
 
447
  )
448
+ ax1.set_title(model_select_1.upper())
449
+ ax1.set_xlabel("rt in seconds")
450
 
451
  fig2, ax2 = plt.subplots()
452
  if model_config[model_select_2]["nchoices"] == 2 and not ("race" in model_select_2):
453
+ # Use filtered styling parameters for plot_func_model
454
+ filtered_styling_2 = get_filtered_styling_config(styling_config_2, "plot_func_model")
455
  ax2 = utils.utils.plot_func_model(
456
  model_name=model_select_2,
457
  theta=[list(d_slider_2.values())],
458
  axis=ax2,
459
+ value_range=[styling_config_2["xlim_min"], styling_config_2["xlim_max"]],
460
  n_samples=nsamples_2,
461
+ ylim=styling_config_2["ylim_max"],
 
 
462
  n_trajectories=ntrajectories_2,
463
+ random_state=randomseed_2,
464
+ **filtered_styling_2
 
465
  )
466
  else:
467
+ # Use filtered styling parameters for plot_func_model_n
468
+ filtered_styling_2 = get_filtered_styling_config(styling_config_2, "plot_func_model_n")
469
  ax2 = utils.utils.plot_func_model_n(
470
  model_name=model_select_2,
471
  theta=[list(d_slider_2.values())],
472
  axis=ax2,
473
+ value_range=[styling_config_2["xlim_min"], styling_config_2["xlim_max"]],
474
  n_samples=nsamples_2,
 
 
475
  n_trajectories=ntrajectories_2,
476
+ random_state=randomseed_2,
477
+ **filtered_styling_2
 
478
  )
479
 
480
+ ax2.set_title(model_select_2.upper())
481
+ ax2.set_xlabel("rt in seconds")
482
 
483
  # Place figure in placeholder
484
  col1, col2 = st.columns(2)
 
494
  model=model_select_1,
495
  theta=[list(d_slider_1.values())],
496
  n_samples=nsamples_1,
497
+ random_state=randomseed_1,
498
  )
499
  sim_output_2 = simulator(
500
  model=model_select_2,
501
  theta=[list(d_slider_2.values())],
502
  n_samples=nsamples_2,
503
+ random_state=randomseed_2,
504
  )
505
 
506
  # Make metadata dataframe
507
+ # AF-TODO: Should be transposed and then resolve consequences
508
+ # because right now the dataframe has mixed data types in each columns
509
+ # which leads streamlit to complain about arrow incompatibility
510
  metadata = pd.DataFrame(
511
  {
512
  "Model 1": [
513
+ str(sim_output_1["metadata"]["model"]),
514
+ float(sim_output_1["choice_p"][0, 0]),
515
+ float(sim_output_1["rts"].mean()),
516
+ float(sim_output_1["metadata"]["s"]),
517
  ],
518
  "Model 2": [
519
+ str(sim_output_2["metadata"]["model"]),
520
+ float(sim_output_2["choice_p"][0, 0]),
521
+ float(sim_output_2["rts"].mean()),
522
+ float(sim_output_2["metadata"]["s"]),
523
  ],
524
  },
525
  index=["Model", "Choice Probability", "Mean RT", "Noise SD"],
 
542
  histtype="step",
543
  bins=50,
544
  density=True,
545
+ color=styling_config_1["data_color"], # Use user-selected color
546
  fill=None,
547
+ label=model_select_1.upper(),
548
  )
549
  ax3.hist(
550
  sim_output_2["rts"][np.abs(sim_output_2["rts"]) != 999]
 
552
  histtype="step",
553
  bins=50,
554
  density=True,
555
+ color=styling_config_2["data_color"], # Use user-selected color
556
  fill=None,
557
+ label=model_select_2.upper(),
558
  )
559
  ax3.legend()
560
+ ax3.set_xlabel("rt")
561
  ax3.set_xlim(-5, 5)
562
  figure_placeholder_3.pyplot(fig3)
563
  else:
 
565
  # for models with more than 2 choice options
566
  pass
567
  with col4:
568
+ st.dataframe(metadata)
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/src/utils/__pycache__/__init__.cpython-311.pyc and b/src/utils/__pycache__/__init__.cpython-311.pyc differ
 
src/utils/__pycache__/utils.cpython-311.pyc CHANGED
Binary files a/src/utils/__pycache__/utils.cpython-311.pyc and b/src/utils/__pycache__/utils.cpython-311.pyc differ
 
src/utils/old_plots.py DELETED
@@ -1,1157 +0,0 @@
1
- def _plot_func_model(
2
- bottom_node,
3
- axis,
4
- value_range=None,
5
- samples=10,
6
- bin_size=0.05,
7
- add_data_rts=True,
8
- add_data_model=True,
9
- add_data_model_keep_slope=True,
10
- add_data_model_keep_boundary=True,
11
- add_data_model_keep_ndt=True,
12
- add_data_model_keep_starting_point=True,
13
- add_data_model_markersize_starting_point=50,
14
- add_data_model_markertype_starting_point=0,
15
- add_data_model_markershift_starting_point=0,
16
- add_posterior_uncertainty_model=False,
17
- add_posterior_uncertainty_rts=False,
18
- add_posterior_mean_model=True,
19
- add_posterior_mean_rts=True,
20
- add_trajectories=False,
21
- data_label="Data",
22
- secondary_data=None,
23
- secondary_data_label=None,
24
- secondary_data_color="blue",
25
- linewidth_histogram=0.5,
26
- linewidth_model=0.5,
27
- legend_fontsize=12,
28
- legend_shadow=True,
29
- legend_location="upper right",
30
- data_color="blue",
31
- posterior_mean_color="red",
32
- posterior_uncertainty_color="black",
33
- alpha=0.05,
34
- delta_t_model=0.01,
35
- add_legend=True, # keep_frame=False,
36
- **kwargs,
37
- ):
38
- """Calculate posterior predictive for a certain bottom node.
39
-
40
- Arguments:
41
- bottom_node: pymc.stochastic
42
- Bottom node to compute posterior over.
43
-
44
- axis: matplotlib.axis
45
- Axis to plot into.
46
-
47
- value_range: numpy.ndarray
48
- Range over which to evaluate the likelihood.
49
-
50
- Optional:
51
- samples: int <default=10>
52
- Number of posterior samples to use.
53
-
54
- bin_size: float <default=0.05>
55
- Size of bins used for histograms
56
-
57
- alpha: float <default=0.05>
58
- alpha (transparency) level for the sample-wise elements of the plot
59
-
60
- add_data_rts: bool <default=True>
61
- Add data histogram of rts ?
62
-
63
- add_data_model: bool <default=True>
64
- Add model cartoon for data
65
-
66
- add_posterior_uncertainty_rts: bool <default=True>
67
- Add sample by sample histograms?
68
-
69
- add_posterior_mean_rts: bool <default=True>
70
- Add a mean posterior?
71
-
72
- add_model: bool <default=True>
73
- Whether to add model cartoons to the plot.
74
-
75
- linewidth_histogram: float <default=0.5>
76
- linewdith of histrogram plot elements.
77
-
78
- linewidth_model: float <default=0.5>
79
- linewidth of plot elements concerning the model cartoons.
80
-
81
- legend_location: str <default='upper right'>
82
- string defining legend position. Find the rest of the options in the matplotlib documentation.
83
-
84
- legend_shadow: bool <default=True>
85
- Add shadow to legend box?
86
-
87
- legend_fontsize: float <default=12>
88
- Fontsize of legend.
89
-
90
- data_color : str <default="blue">
91
- Color for the data part of the plot.
92
-
93
- posterior_mean_color : str <default="red">
94
- Color for the posterior mean part of the plot.
95
-
96
- posterior_uncertainty_color : str <default="black">
97
- Color for the posterior uncertainty part of the plot.
98
-
99
- delta_t_model:
100
- specifies plotting intervals for model cartoon elements of the graphs.
101
- """
102
-
103
- # AF-TODO: Add a mean version of this!
104
- if value_range is None:
105
- # Infer from data by finding the min and max from the nodes
106
- raise NotImplementedError("value_range keyword argument must be supplied.")
107
-
108
- if len(value_range) > 2:
109
- value_range = (value_range[0], value_range[-1])
110
-
111
- # Extract some parameters from kwargs
112
- bins = np.arange(value_range[0], value_range[-1], bin_size)
113
-
114
- # If bottom_node is a DataFrame we know that we are just plotting real data
115
- if type(bottom_node) == pd.DataFrame:
116
- samples_tmp = [bottom_node]
117
- data_tmp = None
118
- else:
119
- samples_tmp = _post_pred_generate(
120
- bottom_node,
121
- samples=samples,
122
- data=None,
123
- append_data=False,
124
- add_model_parameters=True,
125
- )
126
- data_tmp = bottom_node.value.copy()
127
-
128
- # Relevant for recovery mode
129
- node_data_full = kwargs.pop("node_data", None)
130
-
131
- tmp_model = kwargs.pop("model_", "angle")
132
- if len(model_config[tmp_model]["choices"]) > 2:
133
- raise ValueError("The model plot works only for 2 choice models at the moment")
134
-
135
- # ---------------------------
136
-
137
- ylim = kwargs.pop("ylim", 3)
138
- hist_bottom = kwargs.pop("hist_bottom", 2)
139
- hist_histtype = kwargs.pop("hist_histtype", "step")
140
-
141
- if ("ylim_high" in kwargs) and ("ylim_low" in kwargs):
142
- ylim_high = kwargs["ylim_high"]
143
- ylim_low = kwargs["ylim_low"]
144
- else:
145
- ylim_high = ylim
146
- ylim_low = -ylim
147
-
148
- if ("hist_bottom_high" in kwargs) and ("hist_bottom_low" in kwargs):
149
- hist_bottom_high = kwargs["hist_bottom_high"]
150
- hist_bottom_low = kwargs["hist_bottom_low"]
151
- else:
152
- hist_bottom_high = hist_bottom
153
- hist_bottom_low = hist_bottom
154
-
155
- axis.set_xlim(value_range[0], value_range[-1])
156
- axis.set_ylim(ylim_low, ylim_high)
157
- axis_twin_up = axis.twinx()
158
- axis_twin_down = axis.twinx()
159
- axis_twin_up.set_ylim(ylim_low, ylim_high)
160
- axis_twin_up.set_yticks([])
161
- axis_twin_down.set_ylim(ylim_high, ylim_low)
162
- axis_twin_down.set_yticks([])
163
- axis_twin_down.set_axis_off()
164
- axis_twin_up.set_axis_off()
165
-
166
- # ADD HISTOGRAMS
167
- # -------------------------------
168
- # POSTERIOR BASED HISTOGRAM
169
- if add_posterior_uncertainty_rts: # add_uc_rts:
170
- j = 0
171
- for sample in samples_tmp:
172
- tmp_label = None
173
-
174
- if add_legend and j == 0:
175
- tmp_label = "PostPred"
176
-
177
- weights_up = np.tile(
178
- (1 / bin_size) / sample.shape[0],
179
- reps=sample.loc[sample.response == 1, :].shape[0],
180
- )
181
- weights_down = np.tile(
182
- (1 / bin_size) / sample.shape[0],
183
- reps=sample.loc[(sample.response != 1), :].shape[0],
184
- )
185
-
186
- axis_twin_up.hist(
187
- np.abs(sample.rt[sample.response == 1]),
188
- bins=bins,
189
- weights=weights_up,
190
- histtype=hist_histtype,
191
- bottom=hist_bottom_high,
192
- alpha=alpha,
193
- color=posterior_uncertainty_color,
194
- edgecolor=posterior_uncertainty_color,
195
- zorder=-1,
196
- label=tmp_label,
197
- linewidth=linewidth_histogram,
198
- )
199
-
200
- axis_twin_down.hist(
201
- np.abs(sample.loc[(sample.response != 1), :].rt),
202
- bins=bins,
203
- weights=weights_down,
204
- histtype=hist_histtype,
205
- bottom=hist_bottom_low,
206
- alpha=alpha,
207
- color=posterior_uncertainty_color,
208
- edgecolor=posterior_uncertainty_color,
209
- linewidth=linewidth_histogram,
210
- zorder=-1,
211
- )
212
- j += 1
213
-
214
- if add_posterior_mean_rts: # add_mean_rts:
215
- concat_data = pd.concat(samples_tmp)
216
- tmp_label = None
217
-
218
- if add_legend:
219
- tmp_label = "PostPred Mean"
220
-
221
- weights_up = np.tile(
222
- (1 / bin_size) / concat_data.shape[0],
223
- reps=concat_data.loc[concat_data.response == 1, :].shape[0],
224
- )
225
- weights_down = np.tile(
226
- (1 / bin_size) / concat_data.shape[0],
227
- reps=concat_data.loc[(concat_data.response != 1), :].shape[0],
228
- )
229
-
230
- axis_twin_up.hist(
231
- np.abs(concat_data.rt[concat_data.response == 1]),
232
- bins=bins,
233
- weights=weights_up,
234
- histtype=hist_histtype,
235
- bottom=hist_bottom_high,
236
- alpha=1.0,
237
- color=posterior_mean_color,
238
- edgecolor=posterior_mean_color,
239
- zorder=-1,
240
- label=tmp_label,
241
- linewidth=linewidth_histogram,
242
- )
243
-
244
- axis_twin_down.hist(
245
- np.abs(concat_data.loc[(concat_data.response != 1), :].rt),
246
- bins=bins,
247
- weights=weights_down,
248
- histtype=hist_histtype,
249
- bottom=hist_bottom_low,
250
- alpha=1.0,
251
- color=posterior_mean_color,
252
- edgecolor=posterior_mean_color,
253
- linewidth=linewidth_histogram,
254
- zorder=-1,
255
- )
256
-
257
- # DATA HISTOGRAM
258
- if (data_tmp is not None) and add_data_rts:
259
- tmp_label = None
260
- if add_legend:
261
- tmp_label = data_label
262
-
263
- weights_up = np.tile(
264
- (1 / bin_size) / data_tmp.shape[0],
265
- reps=data_tmp[data_tmp.response == 1].shape[0],
266
- )
267
- weights_down = np.tile(
268
- (1 / bin_size) / data_tmp.shape[0],
269
- reps=data_tmp[(data_tmp.response != 1)].shape[0],
270
- )
271
-
272
- axis_twin_up.hist(
273
- np.abs(data_tmp[data_tmp.response == 1].rt),
274
- bins=bins,
275
- weights=weights_up,
276
- histtype=hist_histtype,
277
- bottom=hist_bottom_high,
278
- alpha=1,
279
- color=data_color,
280
- edgecolor=data_color,
281
- label=tmp_label,
282
- zorder=-1,
283
- linewidth=linewidth_histogram,
284
- )
285
-
286
- axis_twin_down.hist(
287
- np.abs(data_tmp[(data_tmp.response != 1)].rt),
288
- bins=bins,
289
- weights=weights_down,
290
- histtype=hist_histtype,
291
- bottom=hist_bottom_low,
292
- alpha=1,
293
- color=data_color,
294
- edgecolor=data_color,
295
- linewidth=linewidth_histogram,
296
- zorder=-1,
297
- )
298
-
299
- # SECONDARY DATA HISTOGRAM
300
- if secondary_data is not None:
301
- tmp_label = None
302
- if add_legend:
303
- if secondary_data_label is not None:
304
- tmp_label = secondary_data_label
305
-
306
- weights_up = np.tile(
307
- (1 / bin_size) / secondary_data.shape[0],
308
- reps=secondary_data[secondary_data.response == 1].shape[0],
309
- )
310
- weights_down = np.tile(
311
- (1 / bin_size) / secondary_data.shape[0],
312
- reps=secondary_data[(secondary_data.response != 1)].shape[0],
313
- )
314
-
315
- axis_twin_up.hist(
316
- np.abs(secondary_data[secondary_data.response == 1].rt),
317
- bins=bins,
318
- weights=weights_up,
319
- histtype=hist_histtype,
320
- bottom=hist_bottom_high,
321
- alpha=1,
322
- color=secondary_data_color,
323
- edgecolor=secondary_data_color,
324
- label=tmp_label,
325
- zorder=-100,
326
- linewidth=linewidth_histogram,
327
- )
328
-
329
- axis_twin_down.hist(
330
- np.abs(secondary_data[(secondary_data.response != 1)].rt),
331
- bins=bins,
332
- weights=weights_down,
333
- histtype=hist_histtype,
334
- bottom=hist_bottom_low,
335
- alpha=1,
336
- color=secondary_data_color,
337
- edgecolor=secondary_data_color,
338
- linewidth=linewidth_histogram,
339
- zorder=-100,
340
- )
341
- # -------------------------------
342
-
343
- if add_legend:
344
- if data_tmp is not None:
345
- axis_twin_up.legend(
346
- fontsize=legend_fontsize, shadow=legend_shadow, loc=legend_location
347
- )
348
-
349
- # ADD MODEL:
350
- j = 0
351
- t_s = np.arange(0, value_range[-1], delta_t_model)
352
-
353
- # MAKE BOUNDS (FROM MODEL CONFIG) !
354
- if add_posterior_uncertainty_model: # add_uc_model:
355
- for sample in samples_tmp:
356
- _add_model_cartoon_to_ax(
357
- sample=sample,
358
- axis=axis,
359
- tmp_model=tmp_model,
360
- keep_slope=add_data_model_keep_slope,
361
- keep_boundary=add_data_model_keep_boundary,
362
- keep_ndt=add_data_model_keep_ndt,
363
- keep_starting_point=add_data_model_keep_starting_point,
364
- markersize_starting_point=add_data_model_markersize_starting_point,
365
- markertype_starting_point=add_data_model_markertype_starting_point,
366
- markershift_starting_point=add_data_model_markershift_starting_point,
367
- delta_t_graph=delta_t_model,
368
- sample_hist_alpha=alpha,
369
- lw_m=linewidth_model,
370
- tmp_label=tmp_label,
371
- ylim_low=ylim_low,
372
- ylim_high=ylim_high,
373
- t_s=t_s,
374
- color=posterior_uncertainty_color,
375
- zorder_cnt=j,
376
- )
377
-
378
- if (node_data_full is not None) and add_data_model:
379
- _add_model_cartoon_to_ax(
380
- sample=node_data_full,
381
- axis=axis,
382
- tmp_model=tmp_model,
383
- keep_slope=add_data_model_keep_slope,
384
- keep_boundary=add_data_model_keep_boundary,
385
- keep_ndt=add_data_model_keep_ndt,
386
- keep_starting_point=add_data_model_keep_starting_point,
387
- markersize_starting_point=add_data_model_markersize_starting_point,
388
- markertype_starting_point=add_data_model_markertype_starting_point,
389
- markershift_starting_point=add_data_model_markershift_starting_point,
390
- delta_t_graph=delta_t_model,
391
- sample_hist_alpha=1.0,
392
- lw_m=linewidth_model + 0.5,
393
- tmp_label=None,
394
- ylim_low=ylim_low,
395
- ylim_high=ylim_high,
396
- t_s=t_s,
397
- color=data_color,
398
- zorder_cnt=j + 1,
399
- )
400
-
401
- if add_posterior_mean_model: # add_mean_model:
402
- tmp_label = None
403
- if add_legend:
404
- tmp_label = "PostPred Mean"
405
-
406
- _add_model_cartoon_to_ax(
407
- sample=pd.DataFrame(pd.concat(samples_tmp).mean().astype(np.float32)).T,
408
- axis=axis,
409
- tmp_model=tmp_model,
410
- keep_slope=add_data_model_keep_slope,
411
- keep_boundary=add_data_model_keep_boundary,
412
- keep_ndt=add_data_model_keep_ndt,
413
- keep_starting_point=add_data_model_keep_starting_point,
414
- markersize_starting_point=add_data_model_markersize_starting_point,
415
- markertype_starting_point=add_data_model_markertype_starting_point,
416
- markershift_starting_point=add_data_model_markershift_starting_point,
417
- delta_t_graph=delta_t_model,
418
- sample_hist_alpha=1.0,
419
- lw_m=linewidth_model + 0.5,
420
- tmp_label=None,
421
- ylim_low=ylim_low,
422
- ylim_high=ylim_high,
423
- t_s=t_s,
424
- color=posterior_mean_color,
425
- zorder_cnt=j + 2,
426
- )
427
-
428
- if add_trajectories:
429
- _add_trajectories(
430
- axis=axis,
431
- sample=samples_tmp[0],
432
- tmp_model=tmp_model,
433
- t_s=t_s,
434
- delta_t_graph=delta_t_model,
435
- **kwargs,
436
- )
437
-
438
-
439
- # AF-TODO: Add documentation for this function
440
- def _add_trajectories(
441
- axis=None,
442
- sample=None,
443
- t_s=None,
444
- delta_t_graph=0.01,
445
- tmp_model=None,
446
- n_trajectories=10,
447
- supplied_trajectory=None,
448
- maxid_supplied_trajectory=1, # useful for gifs
449
- highlight_trajectory_rt_choice=True,
450
- markersize_trajectory_rt_choice=50,
451
- markertype_trajectory_rt_choice="*",
452
- markercolor_trajectory_rt_choice="red",
453
- linewidth_trajectories=1,
454
- alpha_trajectories=0.5,
455
- color_trajectories="black",
456
- **kwargs,
457
- ):
458
- # Check markercolor type
459
- if type(markercolor_trajectory_rt_choice) == str:
460
- markercolor_trajectory_rt_choice_dict = {}
461
- for value_ in model_config[tmp_model]["choices"]:
462
- markercolor_trajectory_rt_choice_dict[
463
- value_
464
- ] = markercolor_trajectory_rt_choice
465
- elif type(markercolor_trajectory_rt_choice) == list:
466
- cnt = 0
467
- for value_ in model_config[tmp_model]["choices"]:
468
- markercolor_trajectory_rt_choice_dict[
469
- value_
470
- ] = markercolor_trajectory_rt_choice[cnt]
471
- cnt += 1
472
- elif type(markercolor_trajectory_rt_choice) == dict:
473
- markercolor_trajectory_rt_choice_dict = markercolor_trajectory_rt_choice
474
- else:
475
- pass
476
-
477
- # Check trajectory color type
478
- if type(color_trajectories) == str:
479
- color_trajectories_dict = {}
480
- for value_ in model_config[tmp_model]["choices"]:
481
- color_trajectories_dict[value_] = color_trajectories
482
- elif type(color_trajectories) == list:
483
- cnt = 0
484
- for value_ in model_config[tmp_model]["choices"]:
485
- color_trajectories_dict[value_] = color_trajectories[cnt]
486
- cnt += 1
487
- elif type(color_trajectories) == dict:
488
- color_trajectories_dict = color_trajectories
489
- else:
490
- pass
491
-
492
- # Make bounds
493
- (b_low, b_high) = _make_bounds(
494
- tmp_model=tmp_model,
495
- sample=sample,
496
- delta_t_graph=delta_t_graph,
497
- t_s=t_s,
498
- return_shifted_by_ndt=False,
499
- )
500
-
501
- # Trajectories
502
- if supplied_trajectory is None:
503
- for i in range(n_trajectories):
504
- rand_int = np.random.choice(400000000)
505
- out_traj = simulator(
506
- theta=sample[model_config[tmp_model]["params"]].values[0],
507
- model=tmp_model,
508
- n_samples=1,
509
- no_noise=False,
510
- delta_t=delta_t_graph,
511
- bin_dim=None,
512
- random_state=rand_int,
513
- )
514
-
515
- tmp_traj = out_traj[2]["trajectory"]
516
- tmp_traj_choice = float(out_traj[1].flatten())
517
- maxid = np.minimum(np.argmax(np.where(tmp_traj > -999)), t_s.shape[0])
518
-
519
- # Identify boundary value at timepoint of crossing
520
- b_tmp = b_high[maxid] if tmp_traj_choice > 0 else b_low[maxid]
521
-
522
- axis.plot(
523
- t_s[:maxid] + sample.t.values[0],
524
- tmp_traj[:maxid],
525
- color=color_trajectories_dict[tmp_traj_choice],
526
- alpha=alpha_trajectories,
527
- linewidth=linewidth_trajectories,
528
- zorder=2000 + i,
529
- )
530
-
531
- if highlight_trajectory_rt_choice:
532
- axis.scatter(
533
- t_s[maxid] + sample.t.values[0],
534
- b_tmp,
535
- # tmp_traj[maxid],
536
- markersize_trajectory_rt_choice,
537
- color=markercolor_trajectory_rt_choice_dict[tmp_traj_choice],
538
- alpha=1,
539
- marker=markertype_trajectory_rt_choice,
540
- zorder=2000 + i,
541
- )
542
-
543
- else:
544
- if len(supplied_trajectory["trajectories"].shape) == 1:
545
- supplied_trajectory["trajectories"] = np.expand_dims(
546
- supplied_trajectory["trajectories"], axis=0
547
- )
548
-
549
- for j in range(supplied_trajectory["trajectories"].shape[0]):
550
- maxid = np.minimum(
551
- np.argmax(np.where(supplied_trajectory["trajectories"][j, :] > -999)),
552
- t_s.shape[0],
553
- )
554
- if j == (supplied_trajectory["trajectories"].shape[0] - 1):
555
- maxid_traj = min(maxid, maxid_supplied_trajectory)
556
- else:
557
- maxid_traj = maxid
558
-
559
- axis.plot(
560
- t_s[:maxid_traj] + sample.t.values[0],
561
- supplied_trajectory["trajectories"][j, :maxid_traj],
562
- color=color_trajectories_dict[
563
- supplied_trajectory["trajectory_choices"][j]
564
- ], # color_trajectories,
565
- alpha=alpha_trajectories,
566
- linewidth=linewidth_trajectories,
567
- zorder=2000 + j,
568
- )
569
-
570
- # Identify boundary value at timepoint of crossing
571
- b_tmp = (
572
- b_high[maxid_traj]
573
- if supplied_trajectory["trajectory_choices"][j] > 0
574
- else b_low[maxid_traj]
575
- )
576
-
577
- if maxid_traj == maxid:
578
- if highlight_trajectory_rt_choice:
579
- axis.scatter(
580
- t_s[maxid_traj] + sample.t.values[0],
581
- b_tmp,
582
- # supplied_trajectory['trajectories'][j, maxid_traj],
583
- markersize_trajectory_rt_choice,
584
- color=markercolor_trajectory_rt_choice_dict[
585
- supplied_trajectory["trajectory_choices"][j]
586
- ], # markercolor_trajectory_rt_choice,
587
- alpha=1,
588
- marker=markertype_trajectory_rt_choice,
589
- zorder=2000 + j,
590
- )
591
-
592
-
593
- # AF-TODO: Add documentation to this function
594
- def _add_model_cartoon_to_ax(
595
- sample=None,
596
- axis=None,
597
- tmp_model=None,
598
- keep_slope=True,
599
- keep_boundary=True,
600
- keep_ndt=True,
601
- keep_starting_point=True,
602
- markersize_starting_point=80,
603
- markertype_starting_point=1,
604
- markershift_starting_point=-0.05,
605
- delta_t_graph=None,
606
- sample_hist_alpha=None,
607
- lw_m=None,
608
- tmp_label=None,
609
- ylim_low=None,
610
- ylim_high=None,
611
- t_s=None,
612
- zorder_cnt=1,
613
- color="black",
614
- ):
615
- # Make bounds
616
- b_low, b_high = _make_bounds(
617
- tmp_model=tmp_model,
618
- sample=sample,
619
- delta_t_graph=delta_t_graph,
620
- t_s=t_s,
621
- return_shifted_by_ndt=True,
622
- )
623
-
624
- # MAKE SLOPES (VIA TRAJECTORIES HERE --> RUN NOISE FREE SIMULATIONS)!
625
- out = simulator(
626
- theta=sample[model_config[tmp_model]["params"]].values[0],
627
- model=tmp_model,
628
- n_samples=1,
629
- no_noise=True,
630
- delta_t=delta_t_graph,
631
- bin_dim=None,
632
- )
633
-
634
- tmp_traj = out[2]["trajectory"]
635
- maxid = np.minimum(np.argmax(np.where(tmp_traj > -999)), t_s.shape[0])
636
-
637
- if "hddm_base" in tmp_model:
638
- a_tmp = sample.a.values[0] / 2
639
- tmp_traj = tmp_traj - a_tmp
640
-
641
- if keep_boundary:
642
- # Upper bound
643
- axis.plot(
644
- t_s, # + sample.t.values[0],
645
- b_high,
646
- color=color,
647
- alpha=sample_hist_alpha,
648
- zorder=1000 + zorder_cnt,
649
- linewidth=lw_m,
650
- label=tmp_label,
651
- )
652
-
653
- # Lower bound
654
- axis.plot(
655
- t_s, # + sample.t.values[0],
656
- b_low,
657
- color=color,
658
- alpha=sample_hist_alpha,
659
- zorder=1000 + zorder_cnt,
660
- linewidth=lw_m,
661
- )
662
-
663
- # Slope
664
- if keep_slope:
665
- axis.plot(
666
- t_s[:maxid] + sample.t.values[0],
667
- tmp_traj[:maxid],
668
- color=color,
669
- alpha=sample_hist_alpha,
670
- zorder=1000 + zorder_cnt,
671
- linewidth=lw_m,
672
- ) # TOOK AWAY LABEL
673
-
674
- # Non-decision time
675
- if keep_ndt:
676
- axis.axvline(
677
- x=sample.t.values[0],
678
- ymin=ylim_low,
679
- ymax=ylim_high,
680
- color=color,
681
- linestyle="--",
682
- linewidth=lw_m,
683
- zorder=1000 + zorder_cnt,
684
- alpha=sample_hist_alpha,
685
- )
686
- # Starting point
687
- if keep_starting_point:
688
- axis.scatter(
689
- sample.t.values[0] + markershift_starting_point,
690
- b_low[0] + (sample.z.values[0] * (b_high[0] - b_low[0])),
691
- markersize_starting_point,
692
- marker=markertype_starting_point,
693
- color=color,
694
- alpha=sample_hist_alpha,
695
- zorder=1000 + zorder_cnt,
696
- )
697
-
698
-
699
- def _make_bounds(
700
- tmp_model=None,
701
- sample=None,
702
- delta_t_graph=None,
703
- t_s=None,
704
- return_shifted_by_ndt=True,
705
- ):
706
- # MULTIPLICATIVE BOUND
707
- if tmp_model == "weibull" or tmp_model == "weibull_cdf":
708
- b = np.maximum(
709
- sample.a.values[0]
710
- * model_config[tmp_model]["boundary"](
711
- t=t_s, alpha=sample.alpha.values[0], beta=sample.beta.values[0]
712
- ),
713
- 0,
714
- )
715
-
716
- # Move boundary forward by the non-decision time
717
- b_raw_high = deepcopy(b)
718
- b_raw_low = deepcopy(-b)
719
- b_init_val = b[0]
720
- t_shift = np.arange(0, sample.t.values[0], delta_t_graph).shape[0]
721
- b = np.roll(b, t_shift)
722
- b[:t_shift] = b_init_val
723
-
724
- # ADDITIVE BOUND
725
- elif tmp_model == "angle":
726
- b = np.maximum(
727
- sample.a.values[0]
728
- + model_config[tmp_model]["boundary"](t=t_s, theta=sample.theta.values[0]),
729
- 0,
730
- )
731
-
732
- b_raw_high = deepcopy(b)
733
- b_raw_low = deepcopy(-b)
734
- # Move boundary forward by the non-decision time
735
- b_init_val = b[0]
736
- t_shift = np.arange(0, sample.t.values[0], delta_t_graph).shape[0]
737
- b = np.roll(b, t_shift)
738
- b[:t_shift] = b_init_val
739
-
740
- # CONSTANT BOUND
741
- elif (
742
- tmp_model == "ddm"
743
- or tmp_model == "ornstein"
744
- or tmp_model == "levy"
745
- or tmp_model == "full_ddm"
746
- or tmp_model == "ddm_hddm_base"
747
- or tmp_model == "full_ddm_hddm_base"
748
- ):
749
- b = sample.a.values[0] * np.ones(t_s.shape[0])
750
-
751
- if "hddm_base" in tmp_model:
752
- b = (sample.a.values[0] / 2) * np.ones(t_s.shape[0])
753
-
754
- b_raw_high = b
755
- b_raw_low = -b
756
-
757
- # Separate out upper and lower bound:
758
- b_high = b
759
- b_low = -b
760
-
761
- if return_shifted_by_ndt:
762
- return (b_low, b_high)
763
- else:
764
- return (b_raw_low, b_raw_high)
765
-
766
-
767
- def _plot_func_model_n(
768
- bottom_node,
769
- axis,
770
- value_range=None,
771
- samples=10,
772
- bin_size=0.05,
773
- add_posterior_uncertainty_model=False,
774
- add_posterior_uncertainty_rts=False,
775
- add_posterior_mean_model=True,
776
- add_posterior_mean_rts=True,
777
- linewidth_histogram=0.5,
778
- linewidth_model=0.5,
779
- legend_fontsize=7,
780
- legend_shadow=True,
781
- legend_location="upper right",
782
- delta_t_model=0.01,
783
- add_legend=True,
784
- alpha=0.01,
785
- keep_frame=False,
786
- **kwargs,
787
- ):
788
- """Calculate posterior predictive for a certain bottom node.
789
-
790
- Arguments:
791
- bottom_node: pymc.stochastic
792
- Bottom node to compute posterior over.
793
-
794
- axis: matplotlib.axis
795
- Axis to plot into.
796
-
797
- value_range: numpy.ndarray
798
- Range over which to evaluate the likelihood.
799
-
800
- Optional:
801
- samples: int <default=10>
802
- Number of posterior samples to use.
803
-
804
- bin_size: float <default=0.05>
805
- Size of bins used for histograms
806
-
807
- alpha: float <default=0.05>
808
- alpha (transparency) level for the sample-wise elements of the plot
809
-
810
- add_posterior_uncertainty_rts: bool <default=True>
811
- Add sample by sample histograms?
812
-
813
- add_posterior_mean_rts: bool <default=True>
814
- Add a mean posterior?
815
-
816
- add_model: bool <default=True>
817
- Whether to add model cartoons to the plot.
818
-
819
- linewidth_histogram: float <default=0.5>
820
- linewdith of histrogram plot elements.
821
-
822
- linewidth_model: float <default=0.5>
823
- linewidth of plot elements concerning the model cartoons.
824
-
825
- legend_loc: str <default='upper right'>
826
- string defining legend position. Find the rest of the options in the matplotlib documentation.
827
-
828
- legend_shadow: bool <default=True>
829
- Add shadow to legend box?
830
-
831
- legend_fontsize: float <default=12>
832
- Fontsize of legend.
833
-
834
- data_color : str <default="blue">
835
- Color for the data part of the plot.
836
-
837
- posterior_mean_color : str <default="red">
838
- Color for the posterior mean part of the plot.
839
-
840
- posterior_uncertainty_color : str <default="black">
841
- Color for the posterior uncertainty part of the plot.
842
-
843
-
844
- delta_t_model:
845
- specifies plotting intervals for model cartoon elements of the graphs.
846
- """
847
-
848
- color_dict = {
849
- -1: "black",
850
- 0: "black",
851
- 1: "green",
852
- 2: "blue",
853
- 3: "red",
854
- 4: "orange",
855
- 5: "purple",
856
- 6: "brown",
857
- }
858
-
859
- # AF-TODO: Add a mean version of this !
860
- if value_range is None:
861
- # Infer from data by finding the min and max from the nodes
862
- raise NotImplementedError("value_range keyword argument must be supplied.")
863
-
864
- if len(value_range) > 2:
865
- value_range = (value_range[0], value_range[-1])
866
-
867
- # Extract some parameters from kwargs
868
- bins = np.arange(value_range[0], value_range[-1], bin_size)
869
-
870
- # Relevant for recovery mode
871
- node_data_full = kwargs.pop("node_data", None)
872
- tmp_model = kwargs.pop("model_", "angle")
873
-
874
- bottom = 0
875
- # ------------
876
- ylim = kwargs.pop("ylim", 3)
877
-
878
- choices = model_config[tmp_model]["choices"]
879
-
880
- # If bottom_node is a DataFrame we know that we are just plotting real data
881
- if type(bottom_node) == pd.DataFrame:
882
- samples_tmp = [bottom_node]
883
- data_tmp = None
884
- else:
885
- samples_tmp = _post_pred_generate(
886
- bottom_node,
887
- samples=samples,
888
- data=None,
889
- append_data=False,
890
- add_model_parameters=True,
891
- )
892
- data_tmp = bottom_node.value.copy()
893
-
894
- axis.set_xlim(value_range[0], value_range[-1])
895
- axis.set_ylim(0, ylim)
896
-
897
- # ADD MODEL:
898
- j = 0
899
- t_s = np.arange(0, value_range[-1], delta_t_model)
900
-
901
- # # MAKE BOUNDS (FROM MODEL CONFIG) !
902
- if add_posterior_uncertainty_model: # add_uc_model:
903
- for sample in samples_tmp:
904
- tmp_label = None
905
-
906
- if add_legend and (j == 0):
907
- tmp_label = "PostPred"
908
-
909
- _add_model_n_cartoon_to_ax(
910
- sample=sample,
911
- axis=axis,
912
- tmp_model=tmp_model,
913
- delta_t_graph=delta_t_model,
914
- sample_hist_alpha=alpha,
915
- lw_m=linewidth_model,
916
- tmp_label=tmp_label,
917
- linestyle="-",
918
- ylim=ylim,
919
- t_s=t_s,
920
- color_dict=color_dict,
921
- zorder_cnt=j,
922
- )
923
-
924
- j += 1
925
-
926
- if add_posterior_mean_model: # add_mean_model:
927
- tmp_label = None
928
- if add_legend:
929
- tmp_label = "PostPred Mean"
930
-
931
- bottom = _add_model_n_cartoon_to_ax(
932
- sample=pd.DataFrame(pd.concat(samples_tmp).mean().astype(np.float32)).T,
933
- axis=axis,
934
- tmp_model=tmp_model,
935
- delta_t_graph=delta_t_model,
936
- sample_hist_alpha=1.0,
937
- lw_m=linewidth_model + 0.5,
938
- linestyle="-",
939
- tmp_label=None,
940
- ylim=ylim,
941
- t_s=t_s,
942
- color_dict=color_dict,
943
- zorder_cnt=j + 2,
944
- )
945
-
946
- if node_data_full is not None:
947
- _add_model_n_cartoon_to_ax(
948
- sample=node_data_full,
949
- axis=axis,
950
- tmp_model=tmp_model,
951
- delta_t_graph=delta_t_model,
952
- sample_hist_alpha=1.0,
953
- lw_m=linewidth_model + 0.5,
954
- linestyle="dashed",
955
- tmp_label=None,
956
- ylim=ylim,
957
- t_s=t_s,
958
- color_dict=color_dict,
959
- zorder_cnt=j + 1,
960
- )
961
-
962
- # ADD HISTOGRAMS
963
- # -------------------------------
964
-
965
- # POSTERIOR BASED HISTOGRAM
966
- if add_posterior_uncertainty_rts: # add_uc_rts:
967
- j = 0
968
- for sample in samples_tmp:
969
- for choice in choices:
970
- tmp_label = None
971
-
972
- if add_legend and j == 0:
973
- tmp_label = "PostPred"
974
-
975
- weights = np.tile(
976
- (1 / bin_size) / sample.shape[0],
977
- reps=sample.loc[sample.response == choice, :].shape[0],
978
- )
979
-
980
- axis.hist(
981
- np.abs(sample.rt[sample.response == choice]),
982
- bins=bins,
983
- bottom=bottom,
984
- weights=weights,
985
- histtype="step",
986
- alpha=alpha,
987
- color=color_dict[choice],
988
- zorder=-1,
989
- label=tmp_label,
990
- linewidth=linewidth_histogram,
991
- )
992
- j += 1
993
-
994
- if add_posterior_mean_rts:
995
- concat_data = pd.concat(samples_tmp)
996
- for choice in choices:
997
- tmp_label = None
998
- if add_legend and (choice == choices[0]):
999
- tmp_label = "PostPred Mean"
1000
-
1001
- weights = np.tile(
1002
- (1 / bin_size) / concat_data.shape[0],
1003
- reps=concat_data.loc[concat_data.response == choice, :].shape[0],
1004
- )
1005
-
1006
- axis.hist(
1007
- np.abs(concat_data.rt[concat_data.response == choice]),
1008
- bins=bins,
1009
- bottom=bottom,
1010
- weights=weights,
1011
- histtype="step",
1012
- alpha=1.0,
1013
- color=color_dict[choice],
1014
- zorder=-1,
1015
- label=tmp_label,
1016
- linewidth=linewidth_histogram,
1017
- )
1018
-
1019
- # DATA HISTOGRAM
1020
- if data_tmp is not None:
1021
- for choice in choices:
1022
- tmp_label = None
1023
- if add_legend and (choice == choices[0]):
1024
- tmp_label = "Data"
1025
-
1026
- weights = np.tile(
1027
- (1 / bin_size) / data_tmp.shape[0],
1028
- reps=data_tmp.loc[data_tmp.response == choice, :].shape[0],
1029
- )
1030
-
1031
- axis.hist(
1032
- np.abs(data_tmp.rt[data_tmp.response == choice]),
1033
- bins=bins,
1034
- bottom=bottom,
1035
- weights=weights,
1036
- histtype="step",
1037
- linestyle="dashed",
1038
- alpha=1.0,
1039
- color=color_dict[choice],
1040
- edgecolor=color_dict[choice],
1041
- zorder=-1,
1042
- label=tmp_label,
1043
- linewidth=linewidth_histogram,
1044
- )
1045
- # -------------------------------
1046
-
1047
- if add_legend:
1048
- if data_tmp is not None:
1049
- custom_elems = [
1050
- Line2D([0], [0], color=color_dict[choice], lw=1) for choice in choices
1051
- ]
1052
- custom_titles = ["response: " + str(choice) for choice in choices]
1053
-
1054
- custom_elems.append(
1055
- Line2D([0], [0], color="black", lw=1.0, linestyle="dashed")
1056
- )
1057
- custom_elems.append(Line2D([0], [0], color="black", lw=1.0, linestyle="-"))
1058
- custom_titles.append("Data")
1059
- custom_titles.append("Posterior")
1060
-
1061
- axis.legend(
1062
- custom_elems,
1063
- custom_titles,
1064
- fontsize=legend_fontsize,
1065
- shadow=legend_shadow,
1066
- loc=legend_location,
1067
- )
1068
-
1069
- # FRAME
1070
- if not keep_frame:
1071
- axis.set_frame_on(False)
1072
-
1073
-
1074
- def _add_model_n_cartoon_to_ax(
1075
- sample=None,
1076
- axis=None,
1077
- tmp_model=None,
1078
- delta_t_graph=None,
1079
- sample_hist_alpha=None,
1080
- lw_m=None,
1081
- linestyle="-",
1082
- tmp_label=None,
1083
- ylim=None,
1084
- t_s=None,
1085
- zorder_cnt=1,
1086
- color_dict=None,
1087
- ):
1088
- if "weibull" in tmp_model:
1089
- b = np.maximum(
1090
- sample.a.values[0]
1091
- * model_config[tmp_model]["boundary"](
1092
- t=t_s, alpha=sample.alpha.values[0], beta=sample.beta.values[0]
1093
- ),
1094
- 0,
1095
- )
1096
-
1097
- elif "angle" in tmp_model:
1098
- b = np.maximum(
1099
- sample.a.values[0]
1100
- + model_config[tmp_model]["boundary"](t=t_s, theta=sample.theta.values[0]),
1101
- 0,
1102
- )
1103
-
1104
- else:
1105
- b = sample.a.values[0] * np.ones(t_s.shape[0])
1106
-
1107
- # Upper bound
1108
- axis.plot(
1109
- t_s + sample.t.values[0],
1110
- b,
1111
- color="black",
1112
- alpha=sample_hist_alpha,
1113
- zorder=1000 + zorder_cnt,
1114
- linewidth=lw_m,
1115
- linestyle=linestyle,
1116
- label=tmp_label,
1117
- )
1118
-
1119
- # Starting point
1120
- axis.axvline(
1121
- x=sample.t.values[0],
1122
- ymin=-ylim,
1123
- ymax=ylim,
1124
- color="black",
1125
- linestyle=linestyle,
1126
- linewidth=lw_m,
1127
- alpha=sample_hist_alpha,
1128
- )
1129
-
1130
- # # MAKE SLOPES (VIA TRAJECTORIES HERE --> RUN NOISE FREE SIMULATIONS)!
1131
- out = simulator(
1132
- theta=sample[model_config[tmp_model]["params"]].values[0],
1133
- model=tmp_model,
1134
- n_samples=1,
1135
- no_noise=True,
1136
- delta_t=delta_t_graph,
1137
- bin_dim=None,
1138
- )
1139
-
1140
- # # AF-TODO: Add trajectories
1141
- tmp_traj = out[2]["trajectory"]
1142
-
1143
- for i in range(len(model_config[tmp_model]["choices"])):
1144
- tmp_maxid = np.minimum(np.argmax(np.where(tmp_traj[:, i] > -999)), t_s.shape[0])
1145
-
1146
- # Slope
1147
- axis.plot(
1148
- t_s[:tmp_maxid] + sample.t.values[0],
1149
- tmp_traj[:tmp_maxid, i],
1150
- color=color_dict[i],
1151
- linestyle=linestyle,
1152
- alpha=sample_hist_alpha,
1153
- zorder=1000 + zorder_cnt,
1154
- linewidth=lw_m,
1155
- ) # TOOK AWAY LABEL
1156
-
1157
- return b[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/utils.py CHANGED
@@ -187,7 +187,7 @@ def plot_func_model(
187
  weights=weights_up,
188
  histtype=hist_histtype,
189
  bottom=hist_bottom_high,
190
- alpha=1,
191
  color=data_color,
192
  edgecolor=data_color,
193
  linewidth=linewidth_histogram,
@@ -200,7 +200,7 @@ def plot_func_model(
200
  weights=weights_down,
201
  histtype=hist_histtype,
202
  bottom=hist_bottom_low,
203
- alpha=1,
204
  color=data_color,
205
  edgecolor=data_color,
206
  linewidth=linewidth_histogram,
@@ -308,7 +308,6 @@ def _add_trajectories(
308
  b_low = np.roll(b_low, n_roll)
309
  b_low[:n_roll] = b_l_init
310
 
311
- print(n_trajectories)
312
  # Trajectories
313
  for i in range(n_trajectories):
314
  tmp_traj = sample[i]['metadata']['trajectory']
@@ -426,7 +425,7 @@ def _add_model_cartoon_to_ax(
426
  axis.scatter(
427
  sample['metadata']['t'][0] + markershift_starting_point,
428
  b_low[0] + (sample['metadata']['z'][0] * (b_high[0] - b_low[0])),
429
- markersize_starting_point,
430
  marker=markertype_starting_point,
431
  color=color,
432
  alpha=1,
@@ -448,7 +447,7 @@ def plot_func_model_n(
448
  legend_location="upper right",
449
  delta_t_model=0.001,
450
  add_legend=True,
451
- alpha=1.0,
452
  keep_frame=False,
453
  random_state=None,
454
  **kwargs,
@@ -472,7 +471,7 @@ def plot_func_model_n(
472
  bin_size: float <default=0.05>
473
  Size of bins used for histograms
474
 
475
- alpha: float <default=0.05>
476
  alpha (transparency) level for the sample-wise elements of the plot
477
 
478
  add_posterior_uncertainty_rts: bool <default=True>
@@ -560,7 +559,7 @@ def plot_func_model_n(
560
  for i in range(n_trajectories):
561
  rand_int = np.random.choice(400000000)
562
  sim_out_traj[i] = simulator(model = model_name, theta = theta, n_samples = 1,
563
- no_noise = False, delta_t = 0.001,
564
  bin_dim = None, random_state = rand_int, smooth_unif = False)
565
 
566
  sim_out_no_noise = simulator(model = model_name, theta = theta, n_samples = 1,
 
187
  weights=weights_up,
188
  histtype=hist_histtype,
189
  bottom=hist_bottom_high,
190
+ alpha=alpha,
191
  color=data_color,
192
  edgecolor=data_color,
193
  linewidth=linewidth_histogram,
 
200
  weights=weights_down,
201
  histtype=hist_histtype,
202
  bottom=hist_bottom_low,
203
+ alpha=alpha,
204
  color=data_color,
205
  edgecolor=data_color,
206
  linewidth=linewidth_histogram,
 
308
  b_low = np.roll(b_low, n_roll)
309
  b_low[:n_roll] = b_l_init
310
 
 
311
  # Trajectories
312
  for i in range(n_trajectories):
313
  tmp_traj = sample[i]['metadata']['trajectory']
 
425
  axis.scatter(
426
  sample['metadata']['t'][0] + markershift_starting_point,
427
  b_low[0] + (sample['metadata']['z'][0] * (b_high[0] - b_low[0])),
428
+ s=markersize_starting_point,
429
  marker=markertype_starting_point,
430
  color=color,
431
  alpha=1,
 
447
  legend_location="upper right",
448
  delta_t_model=0.001,
449
  add_legend=True,
450
+ alpha=1,
451
  keep_frame=False,
452
  random_state=None,
453
  **kwargs,
 
471
  bin_size: float <default=0.05>
472
  Size of bins used for histograms
473
 
474
+ alpha: float <default=1.0>
475
  alpha (transparency) level for the sample-wise elements of the plot
476
 
477
  add_posterior_uncertainty_rts: bool <default=True>
 
559
  for i in range(n_trajectories):
560
  rand_int = np.random.choice(400000000)
561
  sim_out_traj[i] = simulator(model = model_name, theta = theta, n_samples = 1,
562
+ no_noise = False, delta_t = 0.001,
563
  bin_dim = None, random_state = rand_int, smooth_unif = False)
564
 
565
  sim_out_no_noise = simulator(model = model_name, theta = theta, n_samples = 1,