Alexander commited on
Commit
36bf2f6
·
1 Parent(s): 5e1b1b7

initial test

Browse files
Dockerfile CHANGED
@@ -17,4 +17,4 @@ EXPOSE 8501
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
17
 
18
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
19
 
20
+ ENTRYPOINT ["streamlit", "run", "src/app.py", "--server.port=8501", "--server.address=0.0.0.0"]
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  altair
2
  pandas
3
- streamlit
 
 
 
 
1
  altair
2
  pandas
3
+ ssm-simulators
4
+ streamlit>1.30.0
5
+ matplotlib
6
+ seaborn
src/app.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ import seaborn as sns
4
+ import streamlit as st
5
+
6
+ # from hssm import simulate_data
7
+ from ssms.config import model_config
8
+ from ssms.basic_simulators.simulator import simulator
9
+ import pandas as pd
10
+ import utils
11
+
12
+
13
+ # Function to create input select widgets
14
+ def create_param_selectors(model_name: str, model_num: int = 1):
15
+ d_config = model_config[model_name]
16
+ params = d_config["params"]
17
+ param_bounds_low = d_config["param_bounds"][0]
18
+ param_bounds_high = d_config["param_bounds"][1]
19
+ param_defaults = d_config["default_params"]
20
+
21
+ d_param_slider = {}
22
+ for i, (name, low, high, default) in enumerate(
23
+ zip(
24
+ params,
25
+ param_bounds_low,
26
+ param_bounds_high,
27
+ param_defaults,
28
+ )
29
+ ):
30
+ d_param_slider[i] = st.slider(
31
+ label=name,
32
+ min_value=float(low),
33
+ max_value=float(high),
34
+ value=float(default),
35
+ key=f"param{i}"
36
+ f"_{model_name}"
37
+ f"_{model_num}"
38
+ f'_{st.session_state["slider_version"]}',
39
+ )
40
+ return d_param_slider
41
+
42
+
43
+ def add_model():
44
+ pass
45
+
46
+
47
+ def reset_sliders():
48
+ st.session_state["slider_version"] += 1
49
+
50
+
51
+ # Initialize a slider version attribute state. Is used for resetting values
52
+ if "slider_version" not in st.session_state:
53
+ st.session_state["slider_version"] = 1
54
+
55
+
56
+ def draw_model_configurator(model_num=1):
57
+ # Create widgets for the sidebar
58
+ # st.markdown("<h2 style='text-align: center; color: black;'>Model Configurator</h1>",
59
+ # unsafe_allow_html=True)
60
+ # Dropdown selection of model name
61
+ model_select = st.selectbox(
62
+ "Model " + str(model_num), l_model_names, key="modelname" + str(model_num)
63
+ )
64
+ # Sliders for param values
65
+ d_slider = create_param_selectors(model_select, model_num=model_num)
66
+ # Number of data points to simulate
67
+ nsamples = st.number_input("NSamples", value=5000, key="size" + str(model_num))
68
+ # Number of trajectories to show
69
+ ntrajectories = st.number_input(
70
+ "NTrajectories", value=0, key="ntraj" + str(model_num)
71
+ )
72
+ return model_select, d_slider, nsamples, ntrajectories
73
+
74
+
75
+ st.set_page_config(layout="wide")
76
+
77
+ # Get list of model names
78
+ l_model_names = list(model_config.keys())
79
+
80
+ with st.sidebar:
81
+ col1, col2 = st.columns(2)
82
+ with col1:
83
+ model_select_1, d_slider_1, nsamples_1, ntrajectories_1 = (
84
+ draw_model_configurator(model_num=1)
85
+ )
86
+ with col2:
87
+ model_select_2, d_slider_2, nsamples_2, ntrajectories_2 = (
88
+ draw_model_configurator(model_num=2)
89
+ )
90
+
91
+ # Button to reset sliders to default values
92
+ randomseed = st.number_input("RandomSeed", value=0, key="seed")
93
+ st.button(
94
+ "Reset",
95
+ help="Reset parameters to defaults",
96
+ key="reset",
97
+ on_click=reset_sliders,
98
+ )
99
+
100
+ # st.title("SSM Model Plots", )
101
+ st.markdown(
102
+ "<h1 style='text-align: center; color: black;'>SSM Model Plots</h1>",
103
+ unsafe_allow_html=True,
104
+ )
105
+
106
+ # Display components for main panel
107
+ fig1, ax1 = plt.subplots()
108
+ if model_config[model_select_1]["nchoices"] == 2 and not ("race" in model_select_1):
109
+ ax1 = utils.utils.plot_func_model(
110
+ model_name=model_select_1,
111
+ theta=[list(d_slider_1.values())],
112
+ axis=ax1,
113
+ value_range=[-0.1, 5],
114
+ n_samples=nsamples_1,
115
+ ylim=5,
116
+ data_color="blue",
117
+ add_trajectories=True,
118
+ n_trajectories=ntrajectories_1,
119
+ linewidth_model=1,
120
+ linewidth_histogram=1,
121
+ random_state=randomseed,
122
+ )
123
+ else:
124
+ ax1 = utils.utils.plot_func_model_n(
125
+ model_name=model_select_1,
126
+ theta=[list(d_slider_1.values())],
127
+ axis=ax1,
128
+ value_range=[-0.1, 5],
129
+ n_samples=nsamples_1,
130
+ data_color="blue",
131
+ add_trajectories=True,
132
+ n_trajectories=ntrajectories_1,
133
+ linewidth_model=1,
134
+ linewidth_histogram=1,
135
+ random_state=randomseed,
136
+ )
137
+ ax1.set_title("Model 1")
138
+ ax1.set_xlabel("RT in seconds")
139
+
140
+ fig2, ax2 = plt.subplots()
141
+ if model_config[model_select_2]["nchoices"] == 2 and not ("race" in model_select_2):
142
+ ax2 = utils.utils.plot_func_model(
143
+ model_name=model_select_2,
144
+ theta=[list(d_slider_2.values())],
145
+ axis=ax2,
146
+ value_range=[-0.1, 5],
147
+ n_samples=nsamples_2,
148
+ ylim=5,
149
+ data_color="red",
150
+ add_trajectories=True,
151
+ n_trajectories=ntrajectories_2,
152
+ linewidth_model=1,
153
+ linewidth_histogram=1,
154
+ random_state=randomseed,
155
+ )
156
+ else:
157
+ ax2 = utils.utils.plot_func_model_n(
158
+ model_name=model_select_2,
159
+ theta=[list(d_slider_2.values())],
160
+ axis=ax2,
161
+ value_range=[-0.1, 5],
162
+ n_samples=nsamples_2,
163
+ data_color="red",
164
+ add_trajectories=True,
165
+ n_trajectories=ntrajectories_2,
166
+ linewidth_model=1,
167
+ linewidth_histogram=1,
168
+ random_state=randomseed,
169
+ )
170
+
171
+ ax2.set_title("Model 2")
172
+ ax2.set_xlabel("RT in seconds")
173
+
174
+ # Place figure in placeholder
175
+ col1, col2 = st.columns(2)
176
+ with col1:
177
+ figure_placeholder_1 = st.empty() # Placeholder for figure render
178
+ figure_placeholder_1.pyplot(fig1)
179
+ with col2:
180
+ figure_placeholder_2 = st.empty() # Placeholder for figure render
181
+ figure_placeholder_2.pyplot(fig2)
182
+
183
+ # Simulate two datasets:
184
+ sim_output_1 = simulator(
185
+ model=model_select_1,
186
+ theta=[list(d_slider_1.values())],
187
+ n_samples=nsamples_1,
188
+ random_state=randomseed,
189
+ )
190
+ sim_output_2 = simulator(
191
+ model=model_select_2,
192
+ theta=[list(d_slider_2.values())],
193
+ n_samples=nsamples_2,
194
+ random_state=randomseed,
195
+ )
196
+
197
+ # Make metadata dataframe
198
+ metadata = pd.DataFrame(
199
+ {
200
+ "Model 1": [
201
+ sim_output_1["metadata"]["model"],
202
+ sim_output_1["choice_p"][0, 0],
203
+ sim_output_1["rts"].mean(),
204
+ sim_output_1["metadata"]["s"],
205
+ ],
206
+ "Model 2": [
207
+ sim_output_2["metadata"]["model"],
208
+ sim_output_2["choice_p"][0, 0],
209
+ sim_output_2["rts"].mean(),
210
+ sim_output_2["metadata"]["s"],
211
+ ],
212
+ },
213
+ index=["Model", "Choice Probability", "Mean RT", "Noise SD"],
214
+ )
215
+
216
+ col3, col4 = st.columns(2)
217
+ with col3:
218
+ if (
219
+ len(sim_output_1["metadata"]["possible_choices"])
220
+ == 2 | len(sim_output_2["metadata"]["possible_choices"])
221
+ == 2
222
+ ):
223
+ figure_placeholder_3 = st.empty()
224
+
225
+ # Plot the simulated data
226
+ fig3, ax3 = plt.subplots()
227
+ ax3.hist(
228
+ sim_output_1["rts"][np.abs(sim_output_1["rts"]) != 999]
229
+ * sim_output_1["choices"][np.abs(sim_output_1["rts"] != 999)],
230
+ histtype="step",
231
+ bins=50,
232
+ density=True,
233
+ color="blue",
234
+ fill=None,
235
+ label="Model 1",
236
+ )
237
+ ax3.hist(
238
+ sim_output_2["rts"][np.abs(sim_output_2["rts"]) != 999]
239
+ * sim_output_2["choices"][np.abs(sim_output_2["rts"] != 999)],
240
+ histtype="step",
241
+ bins=50,
242
+ density=True,
243
+ color="red",
244
+ fill=None,
245
+ label="Model 2",
246
+ )
247
+ ax3.legend()
248
+ ax3.set_xlabel("RT")
249
+ ax3.set_xlim(-5, 5)
250
+ figure_placeholder_3.pyplot(fig3)
251
+ else:
252
+ # TODO: Implement better comparison plot
253
+ # for models with more than 2 choice options
254
+ pass
255
+ with col4:
256
+ st.dataframe(metadata)
src/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import utils
2
+
3
+ __all__ = ["utils"]
src/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (276 Bytes). View file
 
src/utils/__pycache__/utils.cpython-311.pyc ADDED
Binary file (26 kB). View file
 
src/utils/old_plots.py ADDED
@@ -0,0 +1,1157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def _plot_func_model(
2
+ bottom_node,
3
+ axis,
4
+ value_range=None,
5
+ samples=10,
6
+ bin_size=0.05,
7
+ add_data_rts=True,
8
+ add_data_model=True,
9
+ add_data_model_keep_slope=True,
10
+ add_data_model_keep_boundary=True,
11
+ add_data_model_keep_ndt=True,
12
+ add_data_model_keep_starting_point=True,
13
+ add_data_model_markersize_starting_point=50,
14
+ add_data_model_markertype_starting_point=0,
15
+ add_data_model_markershift_starting_point=0,
16
+ add_posterior_uncertainty_model=False,
17
+ add_posterior_uncertainty_rts=False,
18
+ add_posterior_mean_model=True,
19
+ add_posterior_mean_rts=True,
20
+ add_trajectories=False,
21
+ data_label="Data",
22
+ secondary_data=None,
23
+ secondary_data_label=None,
24
+ secondary_data_color="blue",
25
+ linewidth_histogram=0.5,
26
+ linewidth_model=0.5,
27
+ legend_fontsize=12,
28
+ legend_shadow=True,
29
+ legend_location="upper right",
30
+ data_color="blue",
31
+ posterior_mean_color="red",
32
+ posterior_uncertainty_color="black",
33
+ alpha=0.05,
34
+ delta_t_model=0.01,
35
+ add_legend=True, # keep_frame=False,
36
+ **kwargs,
37
+ ):
38
+ """Calculate posterior predictive for a certain bottom node.
39
+
40
+ Arguments:
41
+ bottom_node: pymc.stochastic
42
+ Bottom node to compute posterior over.
43
+
44
+ axis: matplotlib.axis
45
+ Axis to plot into.
46
+
47
+ value_range: numpy.ndarray
48
+ Range over which to evaluate the likelihood.
49
+
50
+ Optional:
51
+ samples: int <default=10>
52
+ Number of posterior samples to use.
53
+
54
+ bin_size: float <default=0.05>
55
+ Size of bins used for histograms
56
+
57
+ alpha: float <default=0.05>
58
+ alpha (transparency) level for the sample-wise elements of the plot
59
+
60
+ add_data_rts: bool <default=True>
61
+ Add data histogram of rts ?
62
+
63
+ add_data_model: bool <default=True>
64
+ Add model cartoon for data
65
+
66
+ add_posterior_uncertainty_rts: bool <default=True>
67
+ Add sample by sample histograms?
68
+
69
+ add_posterior_mean_rts: bool <default=True>
70
+ Add a mean posterior?
71
+
72
+ add_model: bool <default=True>
73
+ Whether to add model cartoons to the plot.
74
+
75
+ linewidth_histogram: float <default=0.5>
76
+ linewdith of histrogram plot elements.
77
+
78
+ linewidth_model: float <default=0.5>
79
+ linewidth of plot elements concerning the model cartoons.
80
+
81
+ legend_location: str <default='upper right'>
82
+ string defining legend position. Find the rest of the options in the matplotlib documentation.
83
+
84
+ legend_shadow: bool <default=True>
85
+ Add shadow to legend box?
86
+
87
+ legend_fontsize: float <default=12>
88
+ Fontsize of legend.
89
+
90
+ data_color : str <default="blue">
91
+ Color for the data part of the plot.
92
+
93
+ posterior_mean_color : str <default="red">
94
+ Color for the posterior mean part of the plot.
95
+
96
+ posterior_uncertainty_color : str <default="black">
97
+ Color for the posterior uncertainty part of the plot.
98
+
99
+ delta_t_model:
100
+ specifies plotting intervals for model cartoon elements of the graphs.
101
+ """
102
+
103
+ # AF-TODO: Add a mean version of this!
104
+ if value_range is None:
105
+ # Infer from data by finding the min and max from the nodes
106
+ raise NotImplementedError("value_range keyword argument must be supplied.")
107
+
108
+ if len(value_range) > 2:
109
+ value_range = (value_range[0], value_range[-1])
110
+
111
+ # Extract some parameters from kwargs
112
+ bins = np.arange(value_range[0], value_range[-1], bin_size)
113
+
114
+ # If bottom_node is a DataFrame we know that we are just plotting real data
115
+ if type(bottom_node) == pd.DataFrame:
116
+ samples_tmp = [bottom_node]
117
+ data_tmp = None
118
+ else:
119
+ samples_tmp = _post_pred_generate(
120
+ bottom_node,
121
+ samples=samples,
122
+ data=None,
123
+ append_data=False,
124
+ add_model_parameters=True,
125
+ )
126
+ data_tmp = bottom_node.value.copy()
127
+
128
+ # Relevant for recovery mode
129
+ node_data_full = kwargs.pop("node_data", None)
130
+
131
+ tmp_model = kwargs.pop("model_", "angle")
132
+ if len(model_config[tmp_model]["choices"]) > 2:
133
+ raise ValueError("The model plot works only for 2 choice models at the moment")
134
+
135
+ # ---------------------------
136
+
137
+ ylim = kwargs.pop("ylim", 3)
138
+ hist_bottom = kwargs.pop("hist_bottom", 2)
139
+ hist_histtype = kwargs.pop("hist_histtype", "step")
140
+
141
+ if ("ylim_high" in kwargs) and ("ylim_low" in kwargs):
142
+ ylim_high = kwargs["ylim_high"]
143
+ ylim_low = kwargs["ylim_low"]
144
+ else:
145
+ ylim_high = ylim
146
+ ylim_low = -ylim
147
+
148
+ if ("hist_bottom_high" in kwargs) and ("hist_bottom_low" in kwargs):
149
+ hist_bottom_high = kwargs["hist_bottom_high"]
150
+ hist_bottom_low = kwargs["hist_bottom_low"]
151
+ else:
152
+ hist_bottom_high = hist_bottom
153
+ hist_bottom_low = hist_bottom
154
+
155
+ axis.set_xlim(value_range[0], value_range[-1])
156
+ axis.set_ylim(ylim_low, ylim_high)
157
+ axis_twin_up = axis.twinx()
158
+ axis_twin_down = axis.twinx()
159
+ axis_twin_up.set_ylim(ylim_low, ylim_high)
160
+ axis_twin_up.set_yticks([])
161
+ axis_twin_down.set_ylim(ylim_high, ylim_low)
162
+ axis_twin_down.set_yticks([])
163
+ axis_twin_down.set_axis_off()
164
+ axis_twin_up.set_axis_off()
165
+
166
+ # ADD HISTOGRAMS
167
+ # -------------------------------
168
+ # POSTERIOR BASED HISTOGRAM
169
+ if add_posterior_uncertainty_rts: # add_uc_rts:
170
+ j = 0
171
+ for sample in samples_tmp:
172
+ tmp_label = None
173
+
174
+ if add_legend and j == 0:
175
+ tmp_label = "PostPred"
176
+
177
+ weights_up = np.tile(
178
+ (1 / bin_size) / sample.shape[0],
179
+ reps=sample.loc[sample.response == 1, :].shape[0],
180
+ )
181
+ weights_down = np.tile(
182
+ (1 / bin_size) / sample.shape[0],
183
+ reps=sample.loc[(sample.response != 1), :].shape[0],
184
+ )
185
+
186
+ axis_twin_up.hist(
187
+ np.abs(sample.rt[sample.response == 1]),
188
+ bins=bins,
189
+ weights=weights_up,
190
+ histtype=hist_histtype,
191
+ bottom=hist_bottom_high,
192
+ alpha=alpha,
193
+ color=posterior_uncertainty_color,
194
+ edgecolor=posterior_uncertainty_color,
195
+ zorder=-1,
196
+ label=tmp_label,
197
+ linewidth=linewidth_histogram,
198
+ )
199
+
200
+ axis_twin_down.hist(
201
+ np.abs(sample.loc[(sample.response != 1), :].rt),
202
+ bins=bins,
203
+ weights=weights_down,
204
+ histtype=hist_histtype,
205
+ bottom=hist_bottom_low,
206
+ alpha=alpha,
207
+ color=posterior_uncertainty_color,
208
+ edgecolor=posterior_uncertainty_color,
209
+ linewidth=linewidth_histogram,
210
+ zorder=-1,
211
+ )
212
+ j += 1
213
+
214
+ if add_posterior_mean_rts: # add_mean_rts:
215
+ concat_data = pd.concat(samples_tmp)
216
+ tmp_label = None
217
+
218
+ if add_legend:
219
+ tmp_label = "PostPred Mean"
220
+
221
+ weights_up = np.tile(
222
+ (1 / bin_size) / concat_data.shape[0],
223
+ reps=concat_data.loc[concat_data.response == 1, :].shape[0],
224
+ )
225
+ weights_down = np.tile(
226
+ (1 / bin_size) / concat_data.shape[0],
227
+ reps=concat_data.loc[(concat_data.response != 1), :].shape[0],
228
+ )
229
+
230
+ axis_twin_up.hist(
231
+ np.abs(concat_data.rt[concat_data.response == 1]),
232
+ bins=bins,
233
+ weights=weights_up,
234
+ histtype=hist_histtype,
235
+ bottom=hist_bottom_high,
236
+ alpha=1.0,
237
+ color=posterior_mean_color,
238
+ edgecolor=posterior_mean_color,
239
+ zorder=-1,
240
+ label=tmp_label,
241
+ linewidth=linewidth_histogram,
242
+ )
243
+
244
+ axis_twin_down.hist(
245
+ np.abs(concat_data.loc[(concat_data.response != 1), :].rt),
246
+ bins=bins,
247
+ weights=weights_down,
248
+ histtype=hist_histtype,
249
+ bottom=hist_bottom_low,
250
+ alpha=1.0,
251
+ color=posterior_mean_color,
252
+ edgecolor=posterior_mean_color,
253
+ linewidth=linewidth_histogram,
254
+ zorder=-1,
255
+ )
256
+
257
+ # DATA HISTOGRAM
258
+ if (data_tmp is not None) and add_data_rts:
259
+ tmp_label = None
260
+ if add_legend:
261
+ tmp_label = data_label
262
+
263
+ weights_up = np.tile(
264
+ (1 / bin_size) / data_tmp.shape[0],
265
+ reps=data_tmp[data_tmp.response == 1].shape[0],
266
+ )
267
+ weights_down = np.tile(
268
+ (1 / bin_size) / data_tmp.shape[0],
269
+ reps=data_tmp[(data_tmp.response != 1)].shape[0],
270
+ )
271
+
272
+ axis_twin_up.hist(
273
+ np.abs(data_tmp[data_tmp.response == 1].rt),
274
+ bins=bins,
275
+ weights=weights_up,
276
+ histtype=hist_histtype,
277
+ bottom=hist_bottom_high,
278
+ alpha=1,
279
+ color=data_color,
280
+ edgecolor=data_color,
281
+ label=tmp_label,
282
+ zorder=-1,
283
+ linewidth=linewidth_histogram,
284
+ )
285
+
286
+ axis_twin_down.hist(
287
+ np.abs(data_tmp[(data_tmp.response != 1)].rt),
288
+ bins=bins,
289
+ weights=weights_down,
290
+ histtype=hist_histtype,
291
+ bottom=hist_bottom_low,
292
+ alpha=1,
293
+ color=data_color,
294
+ edgecolor=data_color,
295
+ linewidth=linewidth_histogram,
296
+ zorder=-1,
297
+ )
298
+
299
+ # SECONDARY DATA HISTOGRAM
300
+ if secondary_data is not None:
301
+ tmp_label = None
302
+ if add_legend:
303
+ if secondary_data_label is not None:
304
+ tmp_label = secondary_data_label
305
+
306
+ weights_up = np.tile(
307
+ (1 / bin_size) / secondary_data.shape[0],
308
+ reps=secondary_data[secondary_data.response == 1].shape[0],
309
+ )
310
+ weights_down = np.tile(
311
+ (1 / bin_size) / secondary_data.shape[0],
312
+ reps=secondary_data[(secondary_data.response != 1)].shape[0],
313
+ )
314
+
315
+ axis_twin_up.hist(
316
+ np.abs(secondary_data[secondary_data.response == 1].rt),
317
+ bins=bins,
318
+ weights=weights_up,
319
+ histtype=hist_histtype,
320
+ bottom=hist_bottom_high,
321
+ alpha=1,
322
+ color=secondary_data_color,
323
+ edgecolor=secondary_data_color,
324
+ label=tmp_label,
325
+ zorder=-100,
326
+ linewidth=linewidth_histogram,
327
+ )
328
+
329
+ axis_twin_down.hist(
330
+ np.abs(secondary_data[(secondary_data.response != 1)].rt),
331
+ bins=bins,
332
+ weights=weights_down,
333
+ histtype=hist_histtype,
334
+ bottom=hist_bottom_low,
335
+ alpha=1,
336
+ color=secondary_data_color,
337
+ edgecolor=secondary_data_color,
338
+ linewidth=linewidth_histogram,
339
+ zorder=-100,
340
+ )
341
+ # -------------------------------
342
+
343
+ if add_legend:
344
+ if data_tmp is not None:
345
+ axis_twin_up.legend(
346
+ fontsize=legend_fontsize, shadow=legend_shadow, loc=legend_location
347
+ )
348
+
349
+ # ADD MODEL:
350
+ j = 0
351
+ t_s = np.arange(0, value_range[-1], delta_t_model)
352
+
353
+ # MAKE BOUNDS (FROM MODEL CONFIG) !
354
+ if add_posterior_uncertainty_model: # add_uc_model:
355
+ for sample in samples_tmp:
356
+ _add_model_cartoon_to_ax(
357
+ sample=sample,
358
+ axis=axis,
359
+ tmp_model=tmp_model,
360
+ keep_slope=add_data_model_keep_slope,
361
+ keep_boundary=add_data_model_keep_boundary,
362
+ keep_ndt=add_data_model_keep_ndt,
363
+ keep_starting_point=add_data_model_keep_starting_point,
364
+ markersize_starting_point=add_data_model_markersize_starting_point,
365
+ markertype_starting_point=add_data_model_markertype_starting_point,
366
+ markershift_starting_point=add_data_model_markershift_starting_point,
367
+ delta_t_graph=delta_t_model,
368
+ sample_hist_alpha=alpha,
369
+ lw_m=linewidth_model,
370
+ tmp_label=tmp_label,
371
+ ylim_low=ylim_low,
372
+ ylim_high=ylim_high,
373
+ t_s=t_s,
374
+ color=posterior_uncertainty_color,
375
+ zorder_cnt=j,
376
+ )
377
+
378
+ if (node_data_full is not None) and add_data_model:
379
+ _add_model_cartoon_to_ax(
380
+ sample=node_data_full,
381
+ axis=axis,
382
+ tmp_model=tmp_model,
383
+ keep_slope=add_data_model_keep_slope,
384
+ keep_boundary=add_data_model_keep_boundary,
385
+ keep_ndt=add_data_model_keep_ndt,
386
+ keep_starting_point=add_data_model_keep_starting_point,
387
+ markersize_starting_point=add_data_model_markersize_starting_point,
388
+ markertype_starting_point=add_data_model_markertype_starting_point,
389
+ markershift_starting_point=add_data_model_markershift_starting_point,
390
+ delta_t_graph=delta_t_model,
391
+ sample_hist_alpha=1.0,
392
+ lw_m=linewidth_model + 0.5,
393
+ tmp_label=None,
394
+ ylim_low=ylim_low,
395
+ ylim_high=ylim_high,
396
+ t_s=t_s,
397
+ color=data_color,
398
+ zorder_cnt=j + 1,
399
+ )
400
+
401
+ if add_posterior_mean_model: # add_mean_model:
402
+ tmp_label = None
403
+ if add_legend:
404
+ tmp_label = "PostPred Mean"
405
+
406
+ _add_model_cartoon_to_ax(
407
+ sample=pd.DataFrame(pd.concat(samples_tmp).mean().astype(np.float32)).T,
408
+ axis=axis,
409
+ tmp_model=tmp_model,
410
+ keep_slope=add_data_model_keep_slope,
411
+ keep_boundary=add_data_model_keep_boundary,
412
+ keep_ndt=add_data_model_keep_ndt,
413
+ keep_starting_point=add_data_model_keep_starting_point,
414
+ markersize_starting_point=add_data_model_markersize_starting_point,
415
+ markertype_starting_point=add_data_model_markertype_starting_point,
416
+ markershift_starting_point=add_data_model_markershift_starting_point,
417
+ delta_t_graph=delta_t_model,
418
+ sample_hist_alpha=1.0,
419
+ lw_m=linewidth_model + 0.5,
420
+ tmp_label=None,
421
+ ylim_low=ylim_low,
422
+ ylim_high=ylim_high,
423
+ t_s=t_s,
424
+ color=posterior_mean_color,
425
+ zorder_cnt=j + 2,
426
+ )
427
+
428
+ if add_trajectories:
429
+ _add_trajectories(
430
+ axis=axis,
431
+ sample=samples_tmp[0],
432
+ tmp_model=tmp_model,
433
+ t_s=t_s,
434
+ delta_t_graph=delta_t_model,
435
+ **kwargs,
436
+ )
437
+
438
+
439
+ # AF-TODO: Add documentation for this function
440
+ def _add_trajectories(
441
+ axis=None,
442
+ sample=None,
443
+ t_s=None,
444
+ delta_t_graph=0.01,
445
+ tmp_model=None,
446
+ n_trajectories=10,
447
+ supplied_trajectory=None,
448
+ maxid_supplied_trajectory=1, # useful for gifs
449
+ highlight_trajectory_rt_choice=True,
450
+ markersize_trajectory_rt_choice=50,
451
+ markertype_trajectory_rt_choice="*",
452
+ markercolor_trajectory_rt_choice="red",
453
+ linewidth_trajectories=1,
454
+ alpha_trajectories=0.5,
455
+ color_trajectories="black",
456
+ **kwargs,
457
+ ):
458
+ # Check markercolor type
459
+ if type(markercolor_trajectory_rt_choice) == str:
460
+ markercolor_trajectory_rt_choice_dict = {}
461
+ for value_ in model_config[tmp_model]["choices"]:
462
+ markercolor_trajectory_rt_choice_dict[
463
+ value_
464
+ ] = markercolor_trajectory_rt_choice
465
+ elif type(markercolor_trajectory_rt_choice) == list:
466
+ cnt = 0
467
+ for value_ in model_config[tmp_model]["choices"]:
468
+ markercolor_trajectory_rt_choice_dict[
469
+ value_
470
+ ] = markercolor_trajectory_rt_choice[cnt]
471
+ cnt += 1
472
+ elif type(markercolor_trajectory_rt_choice) == dict:
473
+ markercolor_trajectory_rt_choice_dict = markercolor_trajectory_rt_choice
474
+ else:
475
+ pass
476
+
477
+ # Check trajectory color type
478
+ if type(color_trajectories) == str:
479
+ color_trajectories_dict = {}
480
+ for value_ in model_config[tmp_model]["choices"]:
481
+ color_trajectories_dict[value_] = color_trajectories
482
+ elif type(color_trajectories) == list:
483
+ cnt = 0
484
+ for value_ in model_config[tmp_model]["choices"]:
485
+ color_trajectories_dict[value_] = color_trajectories[cnt]
486
+ cnt += 1
487
+ elif type(color_trajectories) == dict:
488
+ color_trajectories_dict = color_trajectories
489
+ else:
490
+ pass
491
+
492
+ # Make bounds
493
+ (b_low, b_high) = _make_bounds(
494
+ tmp_model=tmp_model,
495
+ sample=sample,
496
+ delta_t_graph=delta_t_graph,
497
+ t_s=t_s,
498
+ return_shifted_by_ndt=False,
499
+ )
500
+
501
+ # Trajectories
502
+ if supplied_trajectory is None:
503
+ for i in range(n_trajectories):
504
+ rand_int = np.random.choice(400000000)
505
+ out_traj = simulator(
506
+ theta=sample[model_config[tmp_model]["params"]].values[0],
507
+ model=tmp_model,
508
+ n_samples=1,
509
+ no_noise=False,
510
+ delta_t=delta_t_graph,
511
+ bin_dim=None,
512
+ random_state=rand_int,
513
+ )
514
+
515
+ tmp_traj = out_traj[2]["trajectory"]
516
+ tmp_traj_choice = float(out_traj[1].flatten())
517
+ maxid = np.minimum(np.argmax(np.where(tmp_traj > -999)), t_s.shape[0])
518
+
519
+ # Identify boundary value at timepoint of crossing
520
+ b_tmp = b_high[maxid] if tmp_traj_choice > 0 else b_low[maxid]
521
+
522
+ axis.plot(
523
+ t_s[:maxid] + sample.t.values[0],
524
+ tmp_traj[:maxid],
525
+ color=color_trajectories_dict[tmp_traj_choice],
526
+ alpha=alpha_trajectories,
527
+ linewidth=linewidth_trajectories,
528
+ zorder=2000 + i,
529
+ )
530
+
531
+ if highlight_trajectory_rt_choice:
532
+ axis.scatter(
533
+ t_s[maxid] + sample.t.values[0],
534
+ b_tmp,
535
+ # tmp_traj[maxid],
536
+ markersize_trajectory_rt_choice,
537
+ color=markercolor_trajectory_rt_choice_dict[tmp_traj_choice],
538
+ alpha=1,
539
+ marker=markertype_trajectory_rt_choice,
540
+ zorder=2000 + i,
541
+ )
542
+
543
+ else:
544
+ if len(supplied_trajectory["trajectories"].shape) == 1:
545
+ supplied_trajectory["trajectories"] = np.expand_dims(
546
+ supplied_trajectory["trajectories"], axis=0
547
+ )
548
+
549
+ for j in range(supplied_trajectory["trajectories"].shape[0]):
550
+ maxid = np.minimum(
551
+ np.argmax(np.where(supplied_trajectory["trajectories"][j, :] > -999)),
552
+ t_s.shape[0],
553
+ )
554
+ if j == (supplied_trajectory["trajectories"].shape[0] - 1):
555
+ maxid_traj = min(maxid, maxid_supplied_trajectory)
556
+ else:
557
+ maxid_traj = maxid
558
+
559
+ axis.plot(
560
+ t_s[:maxid_traj] + sample.t.values[0],
561
+ supplied_trajectory["trajectories"][j, :maxid_traj],
562
+ color=color_trajectories_dict[
563
+ supplied_trajectory["trajectory_choices"][j]
564
+ ], # color_trajectories,
565
+ alpha=alpha_trajectories,
566
+ linewidth=linewidth_trajectories,
567
+ zorder=2000 + j,
568
+ )
569
+
570
+ # Identify boundary value at timepoint of crossing
571
+ b_tmp = (
572
+ b_high[maxid_traj]
573
+ if supplied_trajectory["trajectory_choices"][j] > 0
574
+ else b_low[maxid_traj]
575
+ )
576
+
577
+ if maxid_traj == maxid:
578
+ if highlight_trajectory_rt_choice:
579
+ axis.scatter(
580
+ t_s[maxid_traj] + sample.t.values[0],
581
+ b_tmp,
582
+ # supplied_trajectory['trajectories'][j, maxid_traj],
583
+ markersize_trajectory_rt_choice,
584
+ color=markercolor_trajectory_rt_choice_dict[
585
+ supplied_trajectory["trajectory_choices"][j]
586
+ ], # markercolor_trajectory_rt_choice,
587
+ alpha=1,
588
+ marker=markertype_trajectory_rt_choice,
589
+ zorder=2000 + j,
590
+ )
591
+
592
+
593
+ # AF-TODO: Add documentation to this function
594
+ def _add_model_cartoon_to_ax(
595
+ sample=None,
596
+ axis=None,
597
+ tmp_model=None,
598
+ keep_slope=True,
599
+ keep_boundary=True,
600
+ keep_ndt=True,
601
+ keep_starting_point=True,
602
+ markersize_starting_point=80,
603
+ markertype_starting_point=1,
604
+ markershift_starting_point=-0.05,
605
+ delta_t_graph=None,
606
+ sample_hist_alpha=None,
607
+ lw_m=None,
608
+ tmp_label=None,
609
+ ylim_low=None,
610
+ ylim_high=None,
611
+ t_s=None,
612
+ zorder_cnt=1,
613
+ color="black",
614
+ ):
615
+ # Make bounds
616
+ b_low, b_high = _make_bounds(
617
+ tmp_model=tmp_model,
618
+ sample=sample,
619
+ delta_t_graph=delta_t_graph,
620
+ t_s=t_s,
621
+ return_shifted_by_ndt=True,
622
+ )
623
+
624
+ # MAKE SLOPES (VIA TRAJECTORIES HERE --> RUN NOISE FREE SIMULATIONS)!
625
+ out = simulator(
626
+ theta=sample[model_config[tmp_model]["params"]].values[0],
627
+ model=tmp_model,
628
+ n_samples=1,
629
+ no_noise=True,
630
+ delta_t=delta_t_graph,
631
+ bin_dim=None,
632
+ )
633
+
634
+ tmp_traj = out[2]["trajectory"]
635
+ maxid = np.minimum(np.argmax(np.where(tmp_traj > -999)), t_s.shape[0])
636
+
637
+ if "hddm_base" in tmp_model:
638
+ a_tmp = sample.a.values[0] / 2
639
+ tmp_traj = tmp_traj - a_tmp
640
+
641
+ if keep_boundary:
642
+ # Upper bound
643
+ axis.plot(
644
+ t_s, # + sample.t.values[0],
645
+ b_high,
646
+ color=color,
647
+ alpha=sample_hist_alpha,
648
+ zorder=1000 + zorder_cnt,
649
+ linewidth=lw_m,
650
+ label=tmp_label,
651
+ )
652
+
653
+ # Lower bound
654
+ axis.plot(
655
+ t_s, # + sample.t.values[0],
656
+ b_low,
657
+ color=color,
658
+ alpha=sample_hist_alpha,
659
+ zorder=1000 + zorder_cnt,
660
+ linewidth=lw_m,
661
+ )
662
+
663
+ # Slope
664
+ if keep_slope:
665
+ axis.plot(
666
+ t_s[:maxid] + sample.t.values[0],
667
+ tmp_traj[:maxid],
668
+ color=color,
669
+ alpha=sample_hist_alpha,
670
+ zorder=1000 + zorder_cnt,
671
+ linewidth=lw_m,
672
+ ) # TOOK AWAY LABEL
673
+
674
+ # Non-decision time
675
+ if keep_ndt:
676
+ axis.axvline(
677
+ x=sample.t.values[0],
678
+ ymin=ylim_low,
679
+ ymax=ylim_high,
680
+ color=color,
681
+ linestyle="--",
682
+ linewidth=lw_m,
683
+ zorder=1000 + zorder_cnt,
684
+ alpha=sample_hist_alpha,
685
+ )
686
+ # Starting point
687
+ if keep_starting_point:
688
+ axis.scatter(
689
+ sample.t.values[0] + markershift_starting_point,
690
+ b_low[0] + (sample.z.values[0] * (b_high[0] - b_low[0])),
691
+ markersize_starting_point,
692
+ marker=markertype_starting_point,
693
+ color=color,
694
+ alpha=sample_hist_alpha,
695
+ zorder=1000 + zorder_cnt,
696
+ )
697
+
698
+
699
+ def _make_bounds(
700
+ tmp_model=None,
701
+ sample=None,
702
+ delta_t_graph=None,
703
+ t_s=None,
704
+ return_shifted_by_ndt=True,
705
+ ):
706
+ # MULTIPLICATIVE BOUND
707
+ if tmp_model == "weibull" or tmp_model == "weibull_cdf":
708
+ b = np.maximum(
709
+ sample.a.values[0]
710
+ * model_config[tmp_model]["boundary"](
711
+ t=t_s, alpha=sample.alpha.values[0], beta=sample.beta.values[0]
712
+ ),
713
+ 0,
714
+ )
715
+
716
+ # Move boundary forward by the non-decision time
717
+ b_raw_high = deepcopy(b)
718
+ b_raw_low = deepcopy(-b)
719
+ b_init_val = b[0]
720
+ t_shift = np.arange(0, sample.t.values[0], delta_t_graph).shape[0]
721
+ b = np.roll(b, t_shift)
722
+ b[:t_shift] = b_init_val
723
+
724
+ # ADDITIVE BOUND
725
+ elif tmp_model == "angle":
726
+ b = np.maximum(
727
+ sample.a.values[0]
728
+ + model_config[tmp_model]["boundary"](t=t_s, theta=sample.theta.values[0]),
729
+ 0,
730
+ )
731
+
732
+ b_raw_high = deepcopy(b)
733
+ b_raw_low = deepcopy(-b)
734
+ # Move boundary forward by the non-decision time
735
+ b_init_val = b[0]
736
+ t_shift = np.arange(0, sample.t.values[0], delta_t_graph).shape[0]
737
+ b = np.roll(b, t_shift)
738
+ b[:t_shift] = b_init_val
739
+
740
+ # CONSTANT BOUND
741
+ elif (
742
+ tmp_model == "ddm"
743
+ or tmp_model == "ornstein"
744
+ or tmp_model == "levy"
745
+ or tmp_model == "full_ddm"
746
+ or tmp_model == "ddm_hddm_base"
747
+ or tmp_model == "full_ddm_hddm_base"
748
+ ):
749
+ b = sample.a.values[0] * np.ones(t_s.shape[0])
750
+
751
+ if "hddm_base" in tmp_model:
752
+ b = (sample.a.values[0] / 2) * np.ones(t_s.shape[0])
753
+
754
+ b_raw_high = b
755
+ b_raw_low = -b
756
+
757
+ # Separate out upper and lower bound:
758
+ b_high = b
759
+ b_low = -b
760
+
761
+ if return_shifted_by_ndt:
762
+ return (b_low, b_high)
763
+ else:
764
+ return (b_raw_low, b_raw_high)
765
+
766
+
767
+ def _plot_func_model_n(
768
+ bottom_node,
769
+ axis,
770
+ value_range=None,
771
+ samples=10,
772
+ bin_size=0.05,
773
+ add_posterior_uncertainty_model=False,
774
+ add_posterior_uncertainty_rts=False,
775
+ add_posterior_mean_model=True,
776
+ add_posterior_mean_rts=True,
777
+ linewidth_histogram=0.5,
778
+ linewidth_model=0.5,
779
+ legend_fontsize=7,
780
+ legend_shadow=True,
781
+ legend_location="upper right",
782
+ delta_t_model=0.01,
783
+ add_legend=True,
784
+ alpha=0.01,
785
+ keep_frame=False,
786
+ **kwargs,
787
+ ):
788
+ """Calculate posterior predictive for a certain bottom node.
789
+
790
+ Arguments:
791
+ bottom_node: pymc.stochastic
792
+ Bottom node to compute posterior over.
793
+
794
+ axis: matplotlib.axis
795
+ Axis to plot into.
796
+
797
+ value_range: numpy.ndarray
798
+ Range over which to evaluate the likelihood.
799
+
800
+ Optional:
801
+ samples: int <default=10>
802
+ Number of posterior samples to use.
803
+
804
+ bin_size: float <default=0.05>
805
+ Size of bins used for histograms
806
+
807
+ alpha: float <default=0.05>
808
+ alpha (transparency) level for the sample-wise elements of the plot
809
+
810
+ add_posterior_uncertainty_rts: bool <default=True>
811
+ Add sample by sample histograms?
812
+
813
+ add_posterior_mean_rts: bool <default=True>
814
+ Add a mean posterior?
815
+
816
+ add_model: bool <default=True>
817
+ Whether to add model cartoons to the plot.
818
+
819
+ linewidth_histogram: float <default=0.5>
820
+ linewdith of histrogram plot elements.
821
+
822
+ linewidth_model: float <default=0.5>
823
+ linewidth of plot elements concerning the model cartoons.
824
+
825
+ legend_loc: str <default='upper right'>
826
+ string defining legend position. Find the rest of the options in the matplotlib documentation.
827
+
828
+ legend_shadow: bool <default=True>
829
+ Add shadow to legend box?
830
+
831
+ legend_fontsize: float <default=12>
832
+ Fontsize of legend.
833
+
834
+ data_color : str <default="blue">
835
+ Color for the data part of the plot.
836
+
837
+ posterior_mean_color : str <default="red">
838
+ Color for the posterior mean part of the plot.
839
+
840
+ posterior_uncertainty_color : str <default="black">
841
+ Color for the posterior uncertainty part of the plot.
842
+
843
+
844
+ delta_t_model:
845
+ specifies plotting intervals for model cartoon elements of the graphs.
846
+ """
847
+
848
+ color_dict = {
849
+ -1: "black",
850
+ 0: "black",
851
+ 1: "green",
852
+ 2: "blue",
853
+ 3: "red",
854
+ 4: "orange",
855
+ 5: "purple",
856
+ 6: "brown",
857
+ }
858
+
859
+ # AF-TODO: Add a mean version of this !
860
+ if value_range is None:
861
+ # Infer from data by finding the min and max from the nodes
862
+ raise NotImplementedError("value_range keyword argument must be supplied.")
863
+
864
+ if len(value_range) > 2:
865
+ value_range = (value_range[0], value_range[-1])
866
+
867
+ # Extract some parameters from kwargs
868
+ bins = np.arange(value_range[0], value_range[-1], bin_size)
869
+
870
+ # Relevant for recovery mode
871
+ node_data_full = kwargs.pop("node_data", None)
872
+ tmp_model = kwargs.pop("model_", "angle")
873
+
874
+ bottom = 0
875
+ # ------------
876
+ ylim = kwargs.pop("ylim", 3)
877
+
878
+ choices = model_config[tmp_model]["choices"]
879
+
880
+ # If bottom_node is a DataFrame we know that we are just plotting real data
881
+ if type(bottom_node) == pd.DataFrame:
882
+ samples_tmp = [bottom_node]
883
+ data_tmp = None
884
+ else:
885
+ samples_tmp = _post_pred_generate(
886
+ bottom_node,
887
+ samples=samples,
888
+ data=None,
889
+ append_data=False,
890
+ add_model_parameters=True,
891
+ )
892
+ data_tmp = bottom_node.value.copy()
893
+
894
+ axis.set_xlim(value_range[0], value_range[-1])
895
+ axis.set_ylim(0, ylim)
896
+
897
+ # ADD MODEL:
898
+ j = 0
899
+ t_s = np.arange(0, value_range[-1], delta_t_model)
900
+
901
+ # # MAKE BOUNDS (FROM MODEL CONFIG) !
902
+ if add_posterior_uncertainty_model: # add_uc_model:
903
+ for sample in samples_tmp:
904
+ tmp_label = None
905
+
906
+ if add_legend and (j == 0):
907
+ tmp_label = "PostPred"
908
+
909
+ _add_model_n_cartoon_to_ax(
910
+ sample=sample,
911
+ axis=axis,
912
+ tmp_model=tmp_model,
913
+ delta_t_graph=delta_t_model,
914
+ sample_hist_alpha=alpha,
915
+ lw_m=linewidth_model,
916
+ tmp_label=tmp_label,
917
+ linestyle="-",
918
+ ylim=ylim,
919
+ t_s=t_s,
920
+ color_dict=color_dict,
921
+ zorder_cnt=j,
922
+ )
923
+
924
+ j += 1
925
+
926
+ if add_posterior_mean_model: # add_mean_model:
927
+ tmp_label = None
928
+ if add_legend:
929
+ tmp_label = "PostPred Mean"
930
+
931
+ bottom = _add_model_n_cartoon_to_ax(
932
+ sample=pd.DataFrame(pd.concat(samples_tmp).mean().astype(np.float32)).T,
933
+ axis=axis,
934
+ tmp_model=tmp_model,
935
+ delta_t_graph=delta_t_model,
936
+ sample_hist_alpha=1.0,
937
+ lw_m=linewidth_model + 0.5,
938
+ linestyle="-",
939
+ tmp_label=None,
940
+ ylim=ylim,
941
+ t_s=t_s,
942
+ color_dict=color_dict,
943
+ zorder_cnt=j + 2,
944
+ )
945
+
946
+ if node_data_full is not None:
947
+ _add_model_n_cartoon_to_ax(
948
+ sample=node_data_full,
949
+ axis=axis,
950
+ tmp_model=tmp_model,
951
+ delta_t_graph=delta_t_model,
952
+ sample_hist_alpha=1.0,
953
+ lw_m=linewidth_model + 0.5,
954
+ linestyle="dashed",
955
+ tmp_label=None,
956
+ ylim=ylim,
957
+ t_s=t_s,
958
+ color_dict=color_dict,
959
+ zorder_cnt=j + 1,
960
+ )
961
+
962
+ # ADD HISTOGRAMS
963
+ # -------------------------------
964
+
965
+ # POSTERIOR BASED HISTOGRAM
966
+ if add_posterior_uncertainty_rts: # add_uc_rts:
967
+ j = 0
968
+ for sample in samples_tmp:
969
+ for choice in choices:
970
+ tmp_label = None
971
+
972
+ if add_legend and j == 0:
973
+ tmp_label = "PostPred"
974
+
975
+ weights = np.tile(
976
+ (1 / bin_size) / sample.shape[0],
977
+ reps=sample.loc[sample.response == choice, :].shape[0],
978
+ )
979
+
980
+ axis.hist(
981
+ np.abs(sample.rt[sample.response == choice]),
982
+ bins=bins,
983
+ bottom=bottom,
984
+ weights=weights,
985
+ histtype="step",
986
+ alpha=alpha,
987
+ color=color_dict[choice],
988
+ zorder=-1,
989
+ label=tmp_label,
990
+ linewidth=linewidth_histogram,
991
+ )
992
+ j += 1
993
+
994
+ if add_posterior_mean_rts:
995
+ concat_data = pd.concat(samples_tmp)
996
+ for choice in choices:
997
+ tmp_label = None
998
+ if add_legend and (choice == choices[0]):
999
+ tmp_label = "PostPred Mean"
1000
+
1001
+ weights = np.tile(
1002
+ (1 / bin_size) / concat_data.shape[0],
1003
+ reps=concat_data.loc[concat_data.response == choice, :].shape[0],
1004
+ )
1005
+
1006
+ axis.hist(
1007
+ np.abs(concat_data.rt[concat_data.response == choice]),
1008
+ bins=bins,
1009
+ bottom=bottom,
1010
+ weights=weights,
1011
+ histtype="step",
1012
+ alpha=1.0,
1013
+ color=color_dict[choice],
1014
+ zorder=-1,
1015
+ label=tmp_label,
1016
+ linewidth=linewidth_histogram,
1017
+ )
1018
+
1019
+ # DATA HISTOGRAM
1020
+ if data_tmp is not None:
1021
+ for choice in choices:
1022
+ tmp_label = None
1023
+ if add_legend and (choice == choices[0]):
1024
+ tmp_label = "Data"
1025
+
1026
+ weights = np.tile(
1027
+ (1 / bin_size) / data_tmp.shape[0],
1028
+ reps=data_tmp.loc[data_tmp.response == choice, :].shape[0],
1029
+ )
1030
+
1031
+ axis.hist(
1032
+ np.abs(data_tmp.rt[data_tmp.response == choice]),
1033
+ bins=bins,
1034
+ bottom=bottom,
1035
+ weights=weights,
1036
+ histtype="step",
1037
+ linestyle="dashed",
1038
+ alpha=1.0,
1039
+ color=color_dict[choice],
1040
+ edgecolor=color_dict[choice],
1041
+ zorder=-1,
1042
+ label=tmp_label,
1043
+ linewidth=linewidth_histogram,
1044
+ )
1045
+ # -------------------------------
1046
+
1047
+ if add_legend:
1048
+ if data_tmp is not None:
1049
+ custom_elems = [
1050
+ Line2D([0], [0], color=color_dict[choice], lw=1) for choice in choices
1051
+ ]
1052
+ custom_titles = ["response: " + str(choice) for choice in choices]
1053
+
1054
+ custom_elems.append(
1055
+ Line2D([0], [0], color="black", lw=1.0, linestyle="dashed")
1056
+ )
1057
+ custom_elems.append(Line2D([0], [0], color="black", lw=1.0, linestyle="-"))
1058
+ custom_titles.append("Data")
1059
+ custom_titles.append("Posterior")
1060
+
1061
+ axis.legend(
1062
+ custom_elems,
1063
+ custom_titles,
1064
+ fontsize=legend_fontsize,
1065
+ shadow=legend_shadow,
1066
+ loc=legend_location,
1067
+ )
1068
+
1069
+ # FRAME
1070
+ if not keep_frame:
1071
+ axis.set_frame_on(False)
1072
+
1073
+
1074
+ def _add_model_n_cartoon_to_ax(
1075
+ sample=None,
1076
+ axis=None,
1077
+ tmp_model=None,
1078
+ delta_t_graph=None,
1079
+ sample_hist_alpha=None,
1080
+ lw_m=None,
1081
+ linestyle="-",
1082
+ tmp_label=None,
1083
+ ylim=None,
1084
+ t_s=None,
1085
+ zorder_cnt=1,
1086
+ color_dict=None,
1087
+ ):
1088
+ if "weibull" in tmp_model:
1089
+ b = np.maximum(
1090
+ sample.a.values[0]
1091
+ * model_config[tmp_model]["boundary"](
1092
+ t=t_s, alpha=sample.alpha.values[0], beta=sample.beta.values[0]
1093
+ ),
1094
+ 0,
1095
+ )
1096
+
1097
+ elif "angle" in tmp_model:
1098
+ b = np.maximum(
1099
+ sample.a.values[0]
1100
+ + model_config[tmp_model]["boundary"](t=t_s, theta=sample.theta.values[0]),
1101
+ 0,
1102
+ )
1103
+
1104
+ else:
1105
+ b = sample.a.values[0] * np.ones(t_s.shape[0])
1106
+
1107
+ # Upper bound
1108
+ axis.plot(
1109
+ t_s + sample.t.values[0],
1110
+ b,
1111
+ color="black",
1112
+ alpha=sample_hist_alpha,
1113
+ zorder=1000 + zorder_cnt,
1114
+ linewidth=lw_m,
1115
+ linestyle=linestyle,
1116
+ label=tmp_label,
1117
+ )
1118
+
1119
+ # Starting point
1120
+ axis.axvline(
1121
+ x=sample.t.values[0],
1122
+ ymin=-ylim,
1123
+ ymax=ylim,
1124
+ color="black",
1125
+ linestyle=linestyle,
1126
+ linewidth=lw_m,
1127
+ alpha=sample_hist_alpha,
1128
+ )
1129
+
1130
+ # # MAKE SLOPES (VIA TRAJECTORIES HERE --> RUN NOISE FREE SIMULATIONS)!
1131
+ out = simulator(
1132
+ theta=sample[model_config[tmp_model]["params"]].values[0],
1133
+ model=tmp_model,
1134
+ n_samples=1,
1135
+ no_noise=True,
1136
+ delta_t=delta_t_graph,
1137
+ bin_dim=None,
1138
+ )
1139
+
1140
+ # # AF-TODO: Add trajectories
1141
+ tmp_traj = out[2]["trajectory"]
1142
+
1143
+ for i in range(len(model_config[tmp_model]["choices"])):
1144
+ tmp_maxid = np.minimum(np.argmax(np.where(tmp_traj[:, i] > -999)), t_s.shape[0])
1145
+
1146
+ # Slope
1147
+ axis.plot(
1148
+ t_s[:tmp_maxid] + sample.t.values[0],
1149
+ tmp_traj[:tmp_maxid, i],
1150
+ color=color_dict[i],
1151
+ linestyle=linestyle,
1152
+ alpha=sample_hist_alpha,
1153
+ zorder=1000 + zorder_cnt,
1154
+ linewidth=lw_m,
1155
+ ) # TOOK AWAY LABEL
1156
+
1157
+ return b[0]
src/utils/utils.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from ssms.config import model_config
5
+ from ssms.basic_simulators.simulator import simulator
6
+ from matplotlib.lines import Line2D
7
+
8
+
9
+ def plot_func_model(
10
+ model_name,
11
+ theta,
12
+ axis,
13
+ value_range=None,
14
+ n_samples=10,
15
+ bin_size=0.05,
16
+ add_data_rts=True,
17
+ add_data_model_keep_slope=True,
18
+ add_data_model_keep_boundary=True,
19
+ add_data_model_keep_ndt=True,
20
+ add_data_model_keep_starting_point=True,
21
+ add_data_model_markersize_starting_point=50,
22
+ add_data_model_markertype_starting_point=0,
23
+ add_data_model_markershift_starting_point=0,
24
+ n_trajectories = 0,
25
+ linewidth_histogram=0.5,
26
+ linewidth_model=0.5,
27
+ legend_fontsize=12,
28
+ legend_shadow=True,
29
+ legend_location="upper right",
30
+ data_color="blue",
31
+ posterior_uncertainty_color="black",
32
+ alpha=0.05,
33
+ delta_t_model=0.001,
34
+ random_state=None,
35
+ add_legend=True, # keep_frame=False,
36
+ **kwargs,
37
+ ):
38
+ """Calculate posterior predictive for a certain bottom node.
39
+
40
+ Arguments:
41
+ bottom_node: pymc.stochastic
42
+ Bottom node to compute posterior over.
43
+
44
+ axis: matplotlib.axis
45
+ Axis to plot into.
46
+
47
+ value_range: numpy.ndarray
48
+ Range over which to evaluate the likelihood.
49
+
50
+ Optional:
51
+ samples: int <default=10>
52
+ Number of posterior samples to use.
53
+
54
+ bin_size: float <default=0.05>
55
+ Size of bins used for histograms
56
+
57
+ alpha: float <default=0.05>
58
+ alpha (transparency) level for the sample-wise elements of the plot
59
+
60
+ add_data_rts: bool <default=True>
61
+ Add data histogram of rts ?
62
+
63
+ add_data_model: bool <default=True>
64
+ Add model cartoon for data
65
+
66
+ add_posterior_uncertainty_rts: bool <default=True>
67
+ Add sample by sample histograms?
68
+
69
+ add_posterior_mean_rts: bool <default=True>
70
+ Add a mean posterior?
71
+
72
+ add_model: bool <default=True>
73
+ Whether to add model cartoons to the plot.
74
+
75
+ linewidth_histogram: float <default=0.5>
76
+ linewdith of histrogram plot elements.
77
+
78
+ linewidth_model: float <default=0.5>
79
+ linewidth of plot elements concerning the model cartoons.
80
+
81
+ legend_location: str <default='upper right'>
82
+ string defining legend position. Find the rest of the options in the matplotlib documentation.
83
+
84
+ legend_shadow: bool <default=True>
85
+ Add shadow to legend box?
86
+
87
+ legend_fontsize: float <default=12>
88
+ Fontsize of legend.
89
+
90
+ data_color : str <default="blue">
91
+ Color for the data part of the plot.
92
+
93
+ posterior_mean_color : str <default="red">
94
+ Color for the posterior mean part of the plot.
95
+
96
+ posterior_uncertainty_color : str <default="black">
97
+ Color for the posterior uncertainty part of the plot.
98
+
99
+ delta_t_model:
100
+ specifies plotting intervals for model cartoon elements of the graphs.
101
+ """
102
+
103
+ if value_range is None:
104
+ # Infer from data by finding the min and max from the nodes
105
+ raise NotImplementedError("value_range keyword argument must be supplied.")
106
+
107
+ if len(value_range) > 2:
108
+ value_range = (value_range[0], value_range[-1])
109
+
110
+ # Extract some parameters from kwargs
111
+ bins = np.arange(value_range[0], value_range[-1], bin_size)
112
+
113
+ if model_config[model_name]["nchoices"] > 2:
114
+ raise ValueError("The model plot works only for 2 choice models at the moment")
115
+
116
+ # RUN SIMULATIONS
117
+ # -------------------------------
118
+
119
+ # Simulator Data
120
+ if random_state is not None:
121
+ np.random.seed(random_state)
122
+
123
+ rand_int = np.random.choice(400000000)
124
+ sim_out = simulator(model = model_name, theta = theta, n_samples = n_samples,
125
+ no_noise = False, delta_t = 0.001,
126
+ bin_dim = None, random_state = rand_int)
127
+
128
+ sim_out_traj = {}
129
+ for i in range(n_trajectories):
130
+ rand_int = np.random.choice(400000000)
131
+ sim_out_traj[i] = simulator(model = model_name, theta = theta, n_samples = 1,
132
+ no_noise = False, delta_t = 0.001,
133
+ bin_dim = None, random_state = rand_int, smooth_unif = False)
134
+
135
+ sim_out_no_noise = simulator(model = model_name, theta = theta, n_samples = 1,
136
+ no_noise = True, delta_t = 0.001,
137
+ bin_dim = None, smooth_unif = False)
138
+
139
+ # ADD DATA HISTOGRAMS
140
+ weights_up = np.tile(
141
+ (1 / bin_size) / sim_out['rts'][(sim_out['rts'] != -999)].shape[0],
142
+ reps=sim_out['rts'][(sim_out['rts'] != -999) & (sim_out['choices'] == 1)].shape[0],
143
+ )
144
+ weights_down = np.tile(
145
+ (1 / bin_size) / sim_out['rts'][(sim_out['rts'] != -999)].shape[0],
146
+ reps=sim_out['rts'][(sim_out['rts'] != -999) & (sim_out['choices'] != 1)].shape[0],
147
+ )
148
+
149
+ (b_high, b_low) = (np.maximum(sim_out['metadata']['boundary'], 0),
150
+ np.minimum((-1) * sim_out['metadata']['boundary'], 0))
151
+
152
+ # ADD HISTOGRAMS
153
+ # -------------------------------
154
+
155
+ ylim = kwargs.pop("ylim", 3)
156
+ #hist_bottom = kwargs.pop("hist_bottom", 2)
157
+ hist_histtype = kwargs.pop("hist_histtype", "step")
158
+
159
+ if ("ylim_high" in kwargs) and ("ylim_low" in kwargs):
160
+ ylim_high = kwargs["ylim_high"]
161
+ ylim_low = kwargs["ylim_low"]
162
+ else:
163
+ ylim_high = ylim
164
+ ylim_low = -ylim
165
+
166
+ if ("hist_bottom_high" in kwargs) and ("hist_bottom_low" in kwargs):
167
+ hist_bottom_high = kwargs["hist_bottom_high"]
168
+ hist_bottom_low = kwargs["hist_bottom_low"]
169
+ else:
170
+ hist_bottom_high = b_high[0] #hist_bottom
171
+ hist_bottom_low = -b_low[0] #hist_bottom
172
+
173
+ axis.set_xlim(value_range[0], value_range[-1])
174
+ axis.set_ylim(ylim_low, ylim_high)
175
+ axis_twin_up = axis.twinx()
176
+ axis_twin_down = axis.twinx()
177
+ axis_twin_up.set_ylim(ylim_low, ylim_high)
178
+ axis_twin_up.set_yticks([])
179
+ axis_twin_down.set_ylim(ylim_high, ylim_low)
180
+ axis_twin_down.set_yticks([])
181
+ axis_twin_down.set_axis_off()
182
+ axis_twin_up.set_axis_off()
183
+
184
+ axis_twin_up.hist(
185
+ np.abs(sim_out['rts'][(sim_out['rts'] != -999) & (sim_out['choices'] == 1)]),
186
+ bins=bins,
187
+ weights=weights_up,
188
+ histtype=hist_histtype,
189
+ bottom=hist_bottom_high,
190
+ alpha=1,
191
+ color=data_color,
192
+ edgecolor=data_color,
193
+ linewidth=linewidth_histogram,
194
+ zorder=-1,
195
+ )
196
+
197
+ axis_twin_down.hist(
198
+ np.abs(sim_out['rts'][(sim_out['rts'] != -999) & (sim_out['choices'] != 1)]),
199
+ bins=bins,
200
+ weights=weights_down,
201
+ histtype=hist_histtype,
202
+ bottom=hist_bottom_low,
203
+ alpha=1,
204
+ color=data_color,
205
+ edgecolor=data_color,
206
+ linewidth=linewidth_histogram,
207
+ zorder=-1,
208
+ )
209
+
210
+ # ADD MODEL:
211
+ j = 0
212
+ t_s = np.arange(0, sim_out['metadata']['max_t'], delta_t_model) #value_range[-1], delta_t_model)
213
+
214
+ _add_model_cartoon_to_ax(
215
+ sample=sim_out_no_noise,
216
+ axis=axis,
217
+ keep_slope=add_data_model_keep_slope,
218
+ keep_boundary=add_data_model_keep_boundary,
219
+ keep_ndt=add_data_model_keep_ndt,
220
+ keep_starting_point=add_data_model_keep_starting_point,
221
+ markersize_starting_point=add_data_model_markersize_starting_point,
222
+ markertype_starting_point=add_data_model_markertype_starting_point,
223
+ markershift_starting_point=add_data_model_markershift_starting_point,
224
+ delta_t_graph=delta_t_model,
225
+ sample_hist_alpha=alpha,
226
+ lw_m=linewidth_model,
227
+ ylim_low=ylim_low,
228
+ ylim_high=ylim_high,
229
+ t_s=t_s,
230
+ color=posterior_uncertainty_color,
231
+ zorder_cnt=j,
232
+ )
233
+
234
+ if n_trajectories > 0:
235
+ _add_trajectories(
236
+ axis=axis,
237
+ sample=sim_out_traj,
238
+ t_s=t_s,
239
+ delta_t_graph=delta_t_model,
240
+ n_trajectories=n_trajectories,
241
+ **kwargs,
242
+ )
243
+
244
+ return axis
245
+
246
+ # AF-TODO: Add documentation for this function
247
+ def _add_trajectories(
248
+ axis=None,
249
+ sample=None,
250
+ t_s=None,
251
+ delta_t_graph=0.01,
252
+ n_trajectories=10,
253
+ supplied_trajectory=None,
254
+ maxid_supplied_trajectory=1, # useful for gifs
255
+ highlight_trajectory_rt_choice=True,
256
+ markersize_trajectory_rt_choice=50,
257
+ markertype_trajectory_rt_choice="*",
258
+ markercolor_trajectory_rt_choice="red",
259
+ linewidth_trajectories=1,
260
+ alpha_trajectories=0.5,
261
+ color_trajectories="black",
262
+ **kwargs,
263
+ ):
264
+ """Add trajectories to a given axis."""
265
+ # Check markercolor type
266
+ if isinstance(markercolor_trajectory_rt_choice, str):
267
+ markercolor_trajectory_rt_choice_dict = {}
268
+ for value_ in sample[0]['metadata']['possible_choices']:
269
+ markercolor_trajectory_rt_choice_dict[
270
+ value_
271
+ ] = markercolor_trajectory_rt_choice
272
+ elif isinstance(markercolor_trajectory_rt_choice, list):
273
+ cnt = 0
274
+ for value_ in sample[0]['metadata']['possible_choices']:
275
+ markercolor_trajectory_rt_choice_dict[
276
+ value_
277
+ ] = markercolor_trajectory_rt_choice[cnt]
278
+ cnt += 1
279
+ elif isinstance(markercolor_trajectory_rt_choice, dict):
280
+ markercolor_trajectory_rt_choice_dict = markercolor_trajectory_rt_choice
281
+ else:
282
+ pass
283
+
284
+ # Check trajectory color type
285
+ if isinstance(color_trajectories, str):
286
+ color_trajectories_dict = {}
287
+ for value_ in sample[0]['metadata']['possible_choices']:
288
+ color_trajectories_dict[value_] = color_trajectories
289
+ elif isinstance(color_trajectories, list):
290
+ cnt = 0
291
+ for value_ in sample[0]['metadata']['possible_choices']:
292
+ color_trajectories_dict[value_] = color_trajectories[cnt]
293
+ cnt += 1
294
+ elif isinstance(color_trajectories, dict):
295
+ color_trajectories_dict = color_trajectories
296
+ else:
297
+ pass
298
+
299
+ # Make bounds
300
+ (b_high, b_low) = (np.maximum(sample[0]['metadata']['boundary'], 0),
301
+ np.minimum((-1) * sample[0]['metadata']['boundary'], 0))
302
+
303
+ b_h_init = b_high[0]
304
+ b_l_init = b_low[0]
305
+ n_roll = int((sample[0]['metadata']['t'][0] / delta_t_graph) + 1)
306
+ b_high = np.roll(b_high, n_roll)
307
+ b_high[:n_roll] = b_h_init
308
+ b_low = np.roll(b_low, n_roll)
309
+ b_low[:n_roll] = b_l_init
310
+
311
+ print(n_trajectories)
312
+ # Trajectories
313
+ for i in range(n_trajectories):
314
+ tmp_traj = sample[i]['metadata']['trajectory']
315
+ tmp_traj_choice = float(sample[i]['choices'].flatten())
316
+ maxid = np.minimum(np.argmax(np.where(tmp_traj > -999)), t_s.shape[0])
317
+
318
+ # Identify boundary value at timepoint of crossing
319
+ b_tmp = b_high[maxid + n_roll] if tmp_traj_choice > 0 else b_low[maxid + n_roll]
320
+
321
+ axis.plot(
322
+ t_s[:maxid] + sample[i]['metadata']['t'][0], #sample.t.values[0],
323
+ tmp_traj[:maxid],
324
+ color=color_trajectories_dict[tmp_traj_choice],
325
+ alpha=alpha_trajectories,
326
+ linewidth=linewidth_trajectories,
327
+ zorder=2000 + i,
328
+ )
329
+
330
+ if highlight_trajectory_rt_choice:
331
+ axis.scatter(
332
+ t_s[maxid] + sample[i]['metadata']['t'][0], #sample.t.values[0],
333
+ b_tmp,
334
+ # tmp_traj[maxid],
335
+ markersize_trajectory_rt_choice,
336
+ color=markercolor_trajectory_rt_choice_dict[tmp_traj_choice],
337
+ alpha=1,
338
+ marker=markertype_trajectory_rt_choice,
339
+ zorder=2000 + i,
340
+ )
341
+
342
+ # AF-TODO: Add documentation to this function
343
+ def _add_model_cartoon_to_ax(
344
+ sample=None,
345
+ axis=None,
346
+ keep_slope=True,
347
+ keep_boundary=True,
348
+ keep_ndt=True,
349
+ keep_starting_point=True,
350
+ markersize_starting_point=80,
351
+ markertype_starting_point=1,
352
+ markershift_starting_point=-0.05,
353
+ delta_t_graph=None,
354
+ sample_hist_alpha=None,
355
+ lw_m=None,
356
+ tmp_label=None,
357
+ ylim_low=None,
358
+ ylim_high=None,
359
+ t_s=None,
360
+ zorder_cnt=1,
361
+ color="black",
362
+ ):
363
+ # Make bounds
364
+ (b_high, b_low) = (np.maximum(sample['metadata']['boundary'], 0),
365
+ np.minimum((-1) * sample['metadata']['boundary'], 0))
366
+
367
+ b_h_init = b_high[0]
368
+ b_l_init = b_low[0]
369
+ n_roll = int((sample['metadata']['t'][0] / delta_t_graph) + 1)
370
+ b_high = np.roll(b_high, n_roll)
371
+ b_high[:n_roll] = b_h_init
372
+ b_low = np.roll(b_low, n_roll)
373
+ b_low[:n_roll] = b_l_init
374
+
375
+ tmp_traj = sample["metadata"]["trajectory"]
376
+ maxid = np.minimum(np.argmax(np.where(tmp_traj > -999)),
377
+ t_s.shape[0])
378
+
379
+ if keep_boundary:
380
+ # Upper bound
381
+ axis.plot(
382
+ t_s, # + sample.t.values[0],
383
+ b_high[:t_s.shape[0]],
384
+ color=color,
385
+ alpha=1,
386
+ zorder=1000 + zorder_cnt,
387
+ linewidth=lw_m,
388
+ label=tmp_label,
389
+ )
390
+
391
+ # Lower bound
392
+ axis.plot(
393
+ t_s, # + sample.t.values[0],
394
+ b_low[:t_s.shape[0]],
395
+ color=color,
396
+ alpha=1,
397
+ zorder=1000 + zorder_cnt,
398
+ linewidth=lw_m,
399
+ )
400
+
401
+ # Slope
402
+ if keep_slope:
403
+ axis.plot(
404
+ t_s[:maxid] + sample['metadata']['t'][0],
405
+ tmp_traj[:maxid],
406
+ color=color,
407
+ alpha=1,
408
+ zorder=1000 + zorder_cnt,
409
+ linewidth=lw_m,
410
+ ) # TOOK AWAY LABEL
411
+
412
+ # Non-decision time
413
+ if keep_ndt:
414
+ axis.axvline(
415
+ x=sample['metadata']['t'][0],
416
+ ymin=ylim_low,
417
+ ymax=ylim_high,
418
+ color=color,
419
+ linestyle="--",
420
+ linewidth=lw_m,
421
+ zorder=1000 + zorder_cnt,
422
+ alpha=1,
423
+ )
424
+ # Starting point
425
+ if keep_starting_point:
426
+ axis.scatter(
427
+ sample['metadata']['t'][0] + markershift_starting_point,
428
+ b_low[0] + (sample['metadata']['z'][0] * (b_high[0] - b_low[0])),
429
+ markersize_starting_point,
430
+ marker=markertype_starting_point,
431
+ color=color,
432
+ alpha=1,
433
+ zorder=1000 + zorder_cnt,
434
+ )
435
+
436
+ def plot_func_model_n(
437
+ model_name,
438
+ theta,
439
+ axis,
440
+ n_trajectories=10,
441
+ value_range=None,
442
+ bin_size=0.05,
443
+ n_samples=10,
444
+ linewidth_histogram=0.5,
445
+ linewidth_model=0.5,
446
+ legend_fontsize=7,
447
+ legend_shadow=True,
448
+ legend_location="upper right",
449
+ delta_t_model=0.001,
450
+ add_legend=True,
451
+ alpha=1.0,
452
+ keep_frame=False,
453
+ random_state=None,
454
+ **kwargs,
455
+ ):
456
+ """Calculate posterior predictive for a certain bottom node.
457
+
458
+ Arguments:
459
+ bottom_node: pymc.stochastic
460
+ Bottom node to compute posterior over.
461
+
462
+ axis: matplotlib.axis
463
+ Axis to plot into.
464
+
465
+ value_range: numpy.ndarray
466
+ Range over which to evaluate the likelihood.
467
+
468
+ Optional:
469
+ samples: int <default=10>
470
+ Number of posterior samples to use.
471
+
472
+ bin_size: float <default=0.05>
473
+ Size of bins used for histograms
474
+
475
+ alpha: float <default=0.05>
476
+ alpha (transparency) level for the sample-wise elements of the plot
477
+
478
+ add_posterior_uncertainty_rts: bool <default=True>
479
+ Add sample by sample histograms?
480
+
481
+ add_posterior_mean_rts: bool <default=True>
482
+ Add a mean posterior?
483
+
484
+ add_model: bool <default=True>
485
+ Whether to add model cartoons to the plot.
486
+
487
+ linewidth_histogram: float <default=0.5>
488
+ linewdith of histrogram plot elements.
489
+
490
+ linewidth_model: float <default=0.5>
491
+ linewidth of plot elements concerning the model cartoons.
492
+
493
+ legend_loc: str <default='upper right'>
494
+ string defining legend position. Find the rest of the options in the matplotlib documentation.
495
+
496
+ legend_shadow: bool <default=True>
497
+ Add shadow to legend box?
498
+
499
+ legend_fontsize: float <default=12>
500
+ Fontsize of legend.
501
+
502
+ data_color : str <default="blue">
503
+ Color for the data part of the plot.
504
+
505
+ posterior_mean_color : str <default="red">
506
+ Color for the posterior mean part of the plot.
507
+
508
+ posterior_uncertainty_color : str <default="black">
509
+ Color for the posterior uncertainty part of the plot.
510
+
511
+
512
+ delta_t_model:
513
+ specifies plotting intervals for model cartoon elements of the graphs.
514
+ """
515
+
516
+ color_dict = {
517
+ -1: "black",
518
+ 0: "black",
519
+ 1: "green",
520
+ 2: "blue",
521
+ 3: "red",
522
+ 4: "orange",
523
+ 5: "purple",
524
+ 6: "brown",
525
+ }
526
+
527
+ # AF-TODO: Add a mean version of this !
528
+ if value_range is None:
529
+ # Infer from data by finding the min and max from the nodes
530
+ raise NotImplementedError("value_range keyword argument must be supplied.")
531
+
532
+ if len(value_range) > 2:
533
+ value_range = (value_range[0], value_range[-1])
534
+
535
+ # Extract some parameters from kwargs
536
+ bins = np.arange(value_range[0], value_range[-1], bin_size)
537
+ # ------------
538
+ ylim = kwargs.pop("ylim", 4)
539
+
540
+ axis.set_xlim(value_range[0], value_range[-1])
541
+ axis.set_ylim(0, ylim)
542
+
543
+ # ADD MODEL:
544
+
545
+ # RUN SIMULATIONS
546
+ # -------------------------------
547
+
548
+ # Simulator Data
549
+ if random_state is not None:
550
+ np.random.seed(random_state)
551
+
552
+ rand_int = np.random.choice(400000000)
553
+ sim_out = simulator(model = model_name, theta = theta, n_samples = n_samples,
554
+ no_noise = False, delta_t = 0.001,
555
+ bin_dim = None, random_state = rand_int)
556
+
557
+ choices = sim_out['metadata']['possible_choices']
558
+
559
+ sim_out_traj = {}
560
+ for i in range(n_trajectories):
561
+ rand_int = np.random.choice(400000000)
562
+ sim_out_traj[i] = simulator(model = model_name, theta = theta, n_samples = 1,
563
+ no_noise = False, delta_t = 0.001,
564
+ bin_dim = None, random_state = rand_int, smooth_unif = False)
565
+
566
+ sim_out_no_noise = simulator(model = model_name, theta = theta, n_samples = 1,
567
+ no_noise = True, delta_t = 0.001,
568
+ bin_dim = None, smooth_unif = False)
569
+
570
+ # ADD HISTOGRAMS
571
+ # -------------------------------
572
+
573
+ # POSTERIOR BASED HISTOGRAM
574
+ j = 0
575
+ b = np.maximum(sim_out['metadata']['boundary'], 0)
576
+ bottom = b[0]
577
+ for choice in choices:
578
+ tmp_label = None
579
+
580
+ if add_legend and j == 0:
581
+ tmp_label = "PostPred"
582
+
583
+ weights = np.tile(
584
+ (1 / bin_size) / sim_out['rts'].shape[0],
585
+ reps=sim_out['rts'][(sim_out['choices'] == choice) & (sim_out['rts'] != -999)].shape[0],
586
+ )
587
+
588
+ axis.hist(
589
+ np.abs(sim_out['rts'][(sim_out['choices'] == choice) & (sim_out['rts'] != -999)]),
590
+ bins=bins,
591
+ bottom=bottom,
592
+ weights=weights,
593
+ histtype="step",
594
+ alpha=alpha,
595
+ color=color_dict[choice],
596
+ zorder=-1,
597
+ label=tmp_label,
598
+ linewidth=linewidth_histogram,
599
+ )
600
+ j += 1
601
+
602
+ # ADD MODEL:
603
+ tmp_label = None
604
+ j = 0
605
+ t_s = np.arange(0, sim_out['metadata']['max_t'], delta_t_model)
606
+
607
+ if add_legend and (j == 0):
608
+ tmp_label = "PostPred"
609
+
610
+ _add_model_n_cartoon_to_ax(
611
+ sample=sim_out_no_noise,
612
+ axis=axis,
613
+ delta_t_graph=delta_t_model,
614
+ sample_hist_alpha=alpha,
615
+ lw_m=linewidth_model,
616
+ tmp_label=tmp_label,
617
+ linestyle="-",
618
+ ylim=ylim,
619
+ t_s=t_s,
620
+ color_dict=color_dict,
621
+ zorder_cnt=j,
622
+ )
623
+
624
+ if n_trajectories > 0:
625
+ _add_trajectories_n(
626
+ axis=axis,
627
+ sample=sim_out_traj,
628
+ t_s=t_s,
629
+ delta_t_graph=delta_t_model,
630
+ n_trajectories=n_trajectories,
631
+ **kwargs,
632
+ )
633
+
634
+ if add_legend:
635
+ custom_elems = [
636
+ Line2D([0], [0], color=color_dict[choice], lw=1) for choice in choices
637
+ ]
638
+ custom_titles = ["response: " + str(choice) for choice in choices]
639
+
640
+ custom_elems.append(
641
+ Line2D([0], [0], color="black", lw=1.0, linestyle="dashed")
642
+ )
643
+ # custom_elems.append(Line2D([0], [0], color="black", lw=1.0, linestyle="-"))
644
+ # custom_titles.append("Data")
645
+ # custom_titles.append("Posterior")
646
+
647
+ axis.legend(
648
+ custom_elems,
649
+ custom_titles,
650
+ fontsize=legend_fontsize,
651
+ shadow=legend_shadow,
652
+ loc=legend_location,
653
+ )
654
+
655
+ # FRAME
656
+ if not keep_frame:
657
+ axis.set_frame_on(False)
658
+
659
+ return axis
660
+
661
+ def _add_trajectories_n(axis=None,
662
+ sample=None,
663
+ t_s=None,
664
+ delta_t_graph=0.01,
665
+ n_trajectories=10,
666
+ highlight_trajectory_rt_choice=True,
667
+ markersize_trajectory_rt_choice=50,
668
+ markertype_trajectory_rt_choice="*",
669
+ markercolor_trajectory_rt_choice="black",
670
+ linewidth_trajectories=1,
671
+ alpha_trajectories=0.5,
672
+ color_trajectories="black",
673
+ **kwargs,
674
+ ):
675
+
676
+ """Add trajectories to a given axis."""
677
+ color_dict = {
678
+ -1: "black",
679
+ 0: "black",
680
+ 1: "green",
681
+ 2: "blue",
682
+ 3: "red",
683
+ 4: "orange",
684
+ 5: "purple",
685
+ 6: "brown",
686
+ }
687
+
688
+ # Check trajectory color type
689
+ if isinstance(color_trajectories, str):
690
+ color_trajectories_dict = {}
691
+ for value_ in sample[0]['metadata']['possible_choices']:
692
+ color_trajectories_dict[value_] = color_trajectories
693
+ elif isinstance(color_trajectories, list):
694
+ cnt = 0
695
+ for value_ in sample[0]['metadata']['possible_choices']:
696
+ color_trajectories_dict[value_] = color_trajectories[cnt]
697
+ cnt += 1
698
+ elif isinstance(color_trajectories, dict):
699
+ color_trajectories_dict = color_trajectories
700
+ else:
701
+ pass
702
+
703
+ # Make bounds
704
+ b = np.maximum(sample[0]['metadata']['boundary'], 0)
705
+ b_init = b[0]
706
+ n_roll = int((sample[0]['metadata']['t'][0] / delta_t_graph) + 1)
707
+ b = np.roll(b, n_roll)
708
+ b[:n_roll] = b_init
709
+
710
+ # Trajectories
711
+ for i in range(n_trajectories):
712
+ tmp_traj = sample[i]['metadata']['trajectory']
713
+ tmp_traj_choice = float(sample[i]['choices'].flatten())
714
+
715
+ for j in range(len(sample[i]['metadata']['possible_choices'])):
716
+ tmp_maxid = np.minimum(np.argmax(np.where(tmp_traj[:, j] > -999)), t_s.shape[0])
717
+
718
+ # Identify boundary value at timepoint of crossing
719
+ b_tmp = b[tmp_maxid + n_roll]
720
+
721
+ axis.plot(
722
+ t_s[:tmp_maxid] + sample[i]['metadata']['t'][0], #sample.t.values[0],
723
+ tmp_traj[:tmp_maxid, j],
724
+ color=color_dict[j],
725
+ alpha=alpha_trajectories,
726
+ linewidth=linewidth_trajectories,
727
+ zorder=2000 + i,
728
+ )
729
+
730
+ if highlight_trajectory_rt_choice and tmp_traj_choice == j:
731
+ axis.scatter(
732
+ t_s[tmp_maxid] + sample[i]['metadata']['t'][0], #sample.t.values[0],
733
+ b_tmp,
734
+ # tmp_traj[maxid],
735
+ markersize_trajectory_rt_choice,
736
+ color=color_dict[tmp_traj_choice],
737
+ alpha=1,
738
+ marker=markertype_trajectory_rt_choice,
739
+ zorder=2000 + i,
740
+ )
741
+ elif highlight_trajectory_rt_choice and tmp_traj_choice != j:
742
+ axis.scatter(
743
+ t_s[tmp_maxid] + sample[i]['metadata']['t'][0] + 0.05, #sample.t.values[0],
744
+ tmp_traj[tmp_maxid, j],
745
+ # tmp_traj[maxid],
746
+ markersize_trajectory_rt_choice,
747
+ color=color_dict[j],
748
+ alpha=1,
749
+ marker=5,
750
+ zorder=2000 + i,
751
+ )
752
+
753
+ def _add_model_n_cartoon_to_ax(
754
+ sample=None,
755
+ axis=None,
756
+ delta_t_graph=None,
757
+ sample_hist_alpha=None,
758
+ keep_boundary=True,
759
+ keep_ndt=True,
760
+ keep_slope=True,
761
+ keep_starting_point=True,
762
+ lw_m=None,
763
+ linestyle="-",
764
+ tmp_label=None,
765
+ ylim=None,
766
+ t_s=None,
767
+ zorder_cnt=1,
768
+ color_dict=None,
769
+ ):
770
+
771
+ b = np.maximum(sample['metadata']['boundary'], 0)
772
+ b_init = b[0]
773
+ n_roll = int((sample['metadata']['t'][0] / delta_t_graph) + 1)
774
+ b = np.roll(b, n_roll)
775
+ b[:n_roll] = b_init
776
+
777
+ # Upper bound
778
+ if keep_boundary:
779
+ axis.plot(
780
+ t_s,
781
+ b[:t_s.shape[0]],
782
+ color="black",
783
+ alpha=sample_hist_alpha,
784
+ zorder=1000 + zorder_cnt,
785
+ linewidth=lw_m,
786
+ linestyle=linestyle,
787
+ label=tmp_label,
788
+ )
789
+
790
+ # Starting point
791
+ if keep_starting_point:
792
+ axis.axvline(
793
+ x=sample['metadata']['t'][0],
794
+ ymin=-ylim,
795
+ ymax=ylim,
796
+ color="black",
797
+ linestyle=linestyle,
798
+ linewidth=lw_m,
799
+ alpha=sample_hist_alpha,
800
+ )
801
+
802
+ # # MAKE SLOPES (VIA TRAJECTORIES HERE --> RUN NOISE FREE SIMULATIONS)!
803
+ if keep_slope:
804
+ tmp_traj = sample["metadata"]["trajectory"]
805
+
806
+ for i in range(len(sample["metadata"]["possible_choices"])):
807
+ tmp_maxid = np.minimum(np.argmax(np.where(tmp_traj[:, i] > -999)), t_s.shape[0])
808
+
809
+ # Slope
810
+ axis.plot(
811
+ t_s[:tmp_maxid] + sample['metadata']['t'][0],
812
+ tmp_traj[:tmp_maxid, i],
813
+ color=color_dict[i],
814
+ linestyle=linestyle,
815
+ alpha=sample_hist_alpha,
816
+ zorder=1000 + zorder_cnt,
817
+ linewidth=lw_m,
818
+ ) # TOOK AWAY LABEL
819
+
820
+ return b[0]