|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import seaborn as sns |
|
|
import streamlit as st |
|
|
|
|
|
|
|
|
from ssms.config import model_config |
|
|
from ssms.basic_simulators.simulator import simulator |
|
|
import pandas as pd |
|
|
import utils |
|
|
|
|
|
|
|
|
def create_param_selectors(model_name: str, model_num: int = 1):
    """Draw one slider per parameter of ``model_name`` and return the chosen values.

    Slider ranges come from ``model_config[model_name]["param_bounds"]`` and the
    initial positions from ``default_params``, except that ``v`` and ``t`` (when
    the model has them) start at 0.5 and 0.2. Every start value is clamped into
    its slider range, because ``st.slider`` rejects a ``value`` outside
    ``[min_value, max_value]``.

    Args:
        model_name: Key into ``ssms.config.model_config``.
        model_num: Panel the sliders belong to (1 or 2); part of the widget key
            so both panels can show the same model independently.

    Returns:
        dict: Maps each parameter's position in ``params`` to its slider value,
        in the same order as ``model_config[model_name]["params"]``.
    """
    d_config = model_config[model_name]
    params = d_config["params"]
    param_bounds_low = d_config["param_bounds"][0]
    param_bounds_high = d_config["param_bounds"][1]

    # Work on a copy: default_params belongs to the shared, module-level
    # model_config, and writing into it in place would change the defaults seen
    # by every other user of that config in this process.
    param_defaults = list(d_config["default_params"])

    # UI-specific starting values for v and t.
    for override_name, override_value in (("v", 0.5), ("t", 0.2)):
        if override_name in params:
            param_defaults[params.index(override_name)] = override_value

    # Widget keys embed param_version so reset_parameters() can force fresh
    # sliders (and therefore default values) by bumping the counter.
    param_version = st.session_state["param_version"]

    d_param_slider = {}
    for i, (name, low, high, default) in enumerate(
        zip(params, param_bounds_low, param_bounds_high, param_defaults)
    ):
        low, high = float(low), float(high)
        d_param_slider[i] = st.slider(
            label=name,
            min_value=low,
            max_value=high,
            value=min(max(float(default), low), high),
            key=f"param{i}_{model_name}_{model_num}_{param_version}",
        )

    return d_param_slider
|
|
|
|
|
|
|
|
def create_styling_selectors(model_num: int = 1):
    """
    Create styling configuration widgets for plot customization.

    All widgets are placed inside a collapsed "🎨 Styling" expander and let the
    user adjust colours, line widths, histogram bin size and alpha, which model
    components are drawn, the axis limits, and the starting-point marker.

    Note: This version is designed to work in the sidebar without using st.columns()

    Every widget key ends with ``st.session_state["styling_version"]`` so that
    bumping that counter (see ``reset_styling``) recreates the widgets at their
    default values.

    Args:
        model_num: Integer identifier for the model (1 or 2). Selects the default
            data colour (blue for 1, red otherwise) and keeps widget keys unique
            per model.

    Returns:
        dict: Dictionary containing all styling parameters with their
        user-selected values. ``add_data_model_markertype_starting_point`` is
        already translated from its display label to a marker spec.
    """
    color_options = ["blue", "red", "green", "orange", "purple", "black", "gray", "brown"]

    # Display label -> marker spec handed to the plotting helpers (presumably
    # used as a matplotlib marker).
    # NOTE(review): matplotlib marker 0 is TICKLEFT, a short tick rather than a
    # line -- confirm this is the intended glyph for "Line".
    marker_options = {
        "Diamond": "D",
        "Square": "s",
        "Line": 0,
        "Circle": "o",
        "Star": "*",
        "Triangle": "^",
        "Plus": "+",
        "X": "x",
    }

    # Suffix shared by every widget key (see docstring).
    version = st.session_state["styling_version"]
    styling_config = {}

    with st.expander("🎨 Styling", expanded=False):
        # --- Colors ------------------------------------------------------------
        st.markdown("**Colors**")
        styling_config["data_color"] = st.selectbox(
            "Data Color",
            color_options,
            index=color_options.index("blue" if model_num == 1 else "red"),
            key=f"data_color_{model_num}_{version}",
        )
        styling_config["posterior_uncertainty_color"] = st.selectbox(
            "Model Color",
            color_options,
            index=color_options.index("black"),
            key=f"model_color_{model_num}_{version}",
        )

        # --- Lines -------------------------------------------------------------
        st.markdown("**Lines**")
        styling_config["linewidth_histogram"] = st.slider(
            "Histogram Line Width",
            min_value=0.1,
            max_value=3.0,
            value=1.0,
            step=0.1,
            key=f"hist_lw_{model_num}_{version}",
        )
        styling_config["linewidth_model"] = st.slider(
            "Model Line Width",
            min_value=0.1,
            max_value=3.0,
            value=1.0,
            step=0.1,
            key=f"model_lw_{model_num}_{version}",
        )

        # --- Histograms --------------------------------------------------------
        st.markdown("**Histograms**")
        styling_config["bin_size"] = st.slider(
            "Bin Size",
            min_value=0.01,
            max_value=0.2,
            value=0.05,
            step=0.01,
            key=f"bin_size_{model_num}_{version}",
        )
        styling_config["alpha"] = st.slider(
            "alpha",
            min_value=0.0,
            max_value=1.0,
            value=1.0,
            step=0.05,
            key=f"alpha_{model_num}_{version}",
        )

        # --- Model components --------------------------------------------------
        st.markdown("**Model Components**")
        styling_config["add_data_model_keep_boundary"] = st.checkbox(
            "Show Boundaries",
            value=True,
            key=f"show_boundary_{model_num}_{version}",
        )
        styling_config["add_data_model_keep_slope"] = st.checkbox(
            "Show Slope/Trajectory",
            value=True,
            key=f"show_slope_{model_num}_{version}",
        )
        styling_config["add_data_model_keep_ndt"] = st.checkbox(
            "Show Non-Decision Time",
            value=True,
            key=f"show_ndt_{model_num}_{version}",
        )
        styling_config["add_data_model_keep_starting_point"] = st.checkbox(
            "Show Starting Point",
            value=True,
            key=f"show_start_{model_num}_{version}",
        )

        # --- Axis limits -------------------------------------------------------
        st.markdown("**Axis Limits**")
        styling_config["xlim_min"] = st.number_input(
            "x-axis min",
            value=-0.1,
            step=0.1,
            key=f"xlim_min_{model_num}_{version}",
        )
        styling_config["xlim_max"] = st.number_input(
            "x-axis max",
            value=5.0,
            step=0.1,
            key=f"xlim_max_{model_num}_{version}",
        )
        styling_config["ylim_max"] = st.number_input(
            "y-axis max",
            value=3.75,
            step=0.25,
            key=f"ylim_max_{model_num}_{version}",
        )

        # --- Starting point marker ---------------------------------------------
        # Rendered inside the expander together with the other styling controls
        # (it used to sit outside it). Only offered while the starting point is
        # shown; otherwise fixed defaults are used.
        if styling_config["add_data_model_keep_starting_point"]:
            st.markdown("**Starting Point**")
            styling_config["add_data_model_markersize_starting_point"] = st.slider(
                "Marker Size",
                min_value=10,
                max_value=100,
                value=35,
                key=f"marker_size_{model_num}_{version}",
            )
            styling_config["add_data_model_markertype_starting_point"] = st.selectbox(
                "Marker Type",
                list(marker_options),
                index=0,
                key=f"marker_type_{model_num}_{version}",
            )
        else:
            styling_config["add_data_model_markersize_starting_point"] = 35
            styling_config["add_data_model_markertype_starting_point"] = "Diamond"

    # Translate the chosen marker label into its spec; "D" is the fallback.
    styling_config["add_data_model_markertype_starting_point"] = marker_options.get(
        styling_config["add_data_model_markertype_starting_point"], "D"
    )

    return styling_config
|
|
|
|
|
|
|
|
def get_filtered_styling_config(styling_config, plot_type="plot_func_model"):
    """
    Return the part of ``styling_config`` that a given plot function accepts.

    ``plot_func_model_n`` only understands a fixed subset of styling keys, so
    for that plot type a new dict containing just those keys is returned, in
    their original order. For ``plot_func_model`` and any other plot type the
    input dict itself is returned unchanged (same object, not a copy).

    Args:
        styling_config: Dictionary of styling parameters
        plot_type: Name of the plotting function that will consume the config
            ("plot_func_model" or "plot_func_model_n")

    Returns:
        dict: Styling configuration appropriate for ``plot_type``
    """
    if plot_type != "plot_func_model_n":
        return styling_config

    # Keyword arguments plot_func_model_n is known to accept.
    accepted_keys = frozenset(
        (
            "linewidth_histogram",
            "linewidth_model",
            "bin_size",
            "alpha",
            "legend_fontsize",
            "legend_location",
            "legend_shadow",
            "add_legend",
            "add_data_model_markersize_starting_point",
            "add_data_model_markertype_starting_point",
            "add_data_model_keep_starting_point",
            "add_data_model_keep_boundary",
            "add_data_model_keep_slope",
            "add_data_model_keep_ndt",
        )
    )
    return dict(
        (key, value) for key, value in styling_config.items() if key in accepted_keys
    )
|
|
|
|
|
def add_model():
    """Placeholder for adding a further model panel; currently does nothing."""
    return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_parameters():
    """Reset only model parameters to defaults.

    Parameter sliders embed ``param_version`` in their widget keys, so bumping
    the counter makes Streamlit create new sliders at their default values.
    """
    state = st.session_state
    state["param_version"] = state["param_version"] + 1
|
|
|
|
|
def reset_styling():
    """Reset only styling options to defaults.

    Styling widgets embed ``styling_version`` in their widget keys, so bumping
    the counter makes Streamlit create new widgets at their default values.
    """
    state = st.session_state
    state["styling_version"] = state["styling_version"] + 1
|
|
|
|
|
def reset_all():
    """Reset both parameters and styling to defaults.

    Bumps every widget-key version counter so all keyed widgets are recreated
    at their default values on the next rerun. ``slider_version`` is read with
    a default, so this callback cannot raise KeyError if that counter was never
    initialised in session state.
    """
    state = st.session_state
    state["param_version"] += 1
    state["styling_version"] += 1
    state["slider_version"] = state.get("slider_version", 0) + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def draw_model_configurator(model_num=1):
    """Render the model-picker selectbox for panel ``model_num``.

    The options come from the module-level ``l_model_names`` list, which must
    exist before this function is first called.

    Returns:
        str: The selected model name.
    """
    label = f"Model {model_num}"
    widget_key = f"model_selector_{model_num}"
    return st.selectbox(label, l_model_names, key=widget_key)
|
|
|
|
|
def draw_simulation_settings(model_num=1):
    """Render the simulation inputs for panel ``model_num``.

    Lower bounds keep unusable values from reaching the simulator: at least one
    sample, no negative trajectory count, and a non-negative seed. The
    ``random_state`` is presumably fed to a NumPy RNG, which rejects negative
    seeds -- confirm against ssms.

    Returns:
        tuple: ``(nsamples, ntrajectories, randomseed)`` as ints.
    """
    nsamples = st.number_input(
        "NSamples", min_value=1, value=5000, key="size" + str(model_num)
    )
    ntrajectories = st.number_input(
        "NTrajectories", min_value=0, value=5, key="ntraj" + str(model_num)
    )
    # Distinct default seeds per panel (42 for model 1, 43 for model 2).
    randomseed = st.number_input(
        "RandomSeed", min_value=0, value=41 + model_num, key="seed_" + str(model_num)
    )
    return nsamples, ntrajectories, randomseed
|
|
|
|
|
st.set_page_config(layout="wide")

# Every model name known to ssms; offered by both model selectors.
l_model_names = list(model_config.keys())

# Version counters that are baked into widget keys. The reset_* callbacks bump
# them so Streamlit recreates the affected widgets at their defaults.
# "slider_version" is incremented by reset_all(), so it is initialised here too.
for _version_key in ("param_version", "styling_version", "slider_version"):
    if _version_key not in st.session_state:
        st.session_state[_version_key] = 1
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Sidebar: model pickers with per-model styling, parameter sliders, simulation
# settings and reset buttons. Streamlit renders widgets in creation order, so
# the statement order below is the layout.
# ---------------------------------------------------------------------------
with st.sidebar:
    # NOTE(review): this placeholder is never written to and appears to have no
    # visible effect.
    st.empty()
    st.markdown("**Model Selection**")
    with st.container():
        # NOTE(review): each st.markdown call becomes its own element, so this
        # opening <div> and the closing one below probably do not wrap the
        # columns in between -- confirm the negative margin has any effect.
        st.markdown('<div style="margin-top: -1rem;">', unsafe_allow_html=True)

        # One column per model: the model selector, then its styling expander.
        col1, col2 = st.columns(2)
        with col1:
            model_select_1 = draw_model_configurator(model_num=1)

            styling_config_1 = create_styling_selectors(model_num=1)

        with col2:
            model_select_2 = draw_model_configurator(model_num=2)

            styling_config_2 = create_styling_selectors(model_num=2)

        st.markdown('</div>', unsafe_allow_html=True)

    # Parameter sliders for the selected models; each result is a dict keyed by
    # the parameter's position in the model's parameter list.
    st.markdown("---")
    st.markdown("**Parameters**")
    col1_2, col2_2 = st.columns(2)
    with col1_2:
        d_slider_1 = create_param_selectors(model_select_1, model_num=1)
    with col2_2:
        d_slider_2 = create_param_selectors(model_select_2, model_num=2)

    # Sample count, trajectory count and random seed per model.
    st.markdown("---")
    st.markdown("**Simulation Settings**")
    col1_3, col2_3 = st.columns(2)
    with col1_3:
        nsamples_1, ntrajectories_1, randomseed_1 = draw_simulation_settings(model_num=1)
    with col2_3:
        nsamples_2, ntrajectories_2, randomseed_2 = draw_simulation_settings(model_num=2)

    # Reset buttons. Each on_click callback bumps version counters that are
    # embedded in widget keys, so the affected widgets come back at defaults.
    st.markdown("---")
    st.markdown("**Reset Options**")

    reset_col1, reset_col2, reset_col3 = st.columns(3)

    with reset_col1:
        st.button(
            "Reset Params",
            help="Reset model parameters to defaults",
            key="reset_params",
            on_click=reset_parameters,
        )

    with reset_col2:
        st.button(
            "Reset Styling",
            help="Reset styling options to defaults",
            key="reset_styling",
            on_click=reset_styling,
        )

    with reset_col3:
        st.button(
            "Reset Full",
            help="Reset both parameters and styling to defaults",
            key="reset_all",
            on_click=reset_all,
        )
|
|
|
|
|
# Centred page heading.
_TITLE_HTML = "<h1 style='text-align: center; color: black;'>HSSM Model Plots</h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Intro banner linking to the HSSM repository. The blank lines inside this
# literal are part of the markup handed to st.markdown; keep them as they are.
_INTRO_HTML = """


<div style='background-color: #f0f2f6; padding: 15px; border-radius: 10px; margin: 20px 0; border-left: 5px solid #1f77b4;'>


<p style='margin: 0; font-size: 16px;'>


<strong>📊 Fit to your own data</strong><br>


This dashboard provides interactive visualization of several Sequential Sampling Models available for fitting to data in the <a href='https://github.com/lnccbrown/HSSM' target='_blank' style='color: #1f77b4; text-decoration: none;'><strong>HSSM</strong></a> toolbox.


</p>


</div>


"""
st.markdown(_INTRO_HTML, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
# --- Model 1 figure ----------------------------------------------------------
fig1, ax1 = plt.subplots()

# Keyword arguments shared by both plotting helpers for model 1.
_plot_kwargs_1 = dict(
    model_name=model_select_1,
    theta=[list(d_slider_1.values())],
    axis=ax1,
    value_range=[styling_config_1["xlim_min"], styling_config_1["xlim_max"]],
    n_samples=nsamples_1,
    n_trajectories=ntrajectories_1,
    random_state=randomseed_1,
)

if model_config[model_select_1]["nchoices"] != 2 or "race" in model_select_1:
    # Race models and models without exactly two choices use the n-choice
    # helper, which receives a reduced styling dict and no y-limit.
    filtered_styling_1 = get_filtered_styling_config(styling_config_1, "plot_func_model_n")
    ax1 = utils.utils.plot_func_model_n(**_plot_kwargs_1, **filtered_styling_1)
else:
    # Two-choice, non-race models: full styling dict plus an explicit y-limit.
    filtered_styling_1 = get_filtered_styling_config(styling_config_1, "plot_func_model")
    ax1 = utils.utils.plot_func_model(
        ylim=styling_config_1["ylim_max"],
        **_plot_kwargs_1,
        **filtered_styling_1,
    )

ax1.set_title(model_select_1.upper())
ax1.set_xlabel("rt in seconds")
|
|
|
|
|
# --- Model 2 figure ----------------------------------------------------------
fig2, ax2 = plt.subplots()

# Keyword arguments shared by both plotting helpers for model 2.
_plot_kwargs_2 = dict(
    model_name=model_select_2,
    theta=[list(d_slider_2.values())],
    axis=ax2,
    value_range=[styling_config_2["xlim_min"], styling_config_2["xlim_max"]],
    n_samples=nsamples_2,
    n_trajectories=ntrajectories_2,
    random_state=randomseed_2,
)

if model_config[model_select_2]["nchoices"] != 2 or "race" in model_select_2:
    # Race models and models without exactly two choices use the n-choice
    # helper, which receives a reduced styling dict and no y-limit.
    filtered_styling_2 = get_filtered_styling_config(styling_config_2, "plot_func_model_n")
    ax2 = utils.utils.plot_func_model_n(**_plot_kwargs_2, **filtered_styling_2)
else:
    # Two-choice, non-race models: full styling dict plus an explicit y-limit.
    filtered_styling_2 = get_filtered_styling_config(styling_config_2, "plot_func_model")
    ax2 = utils.utils.plot_func_model(
        ylim=styling_config_2["ylim_max"],
        **_plot_kwargs_2,
        **filtered_styling_2,
    )

ax2.set_title(model_select_2.upper())
ax2.set_xlabel("rt in seconds")
|
|
|
|
|
|
|
|
# Show the two model figures side by side in the main area.
col1, col2 = st.columns(2)
with col1:
    figure_placeholder_1 = st.empty()
    figure_placeholder_1.pyplot(fig1)
with col2:
    figure_placeholder_2 = st.empty()
    figure_placeholder_2.pyplot(fig2)

# st.pyplot has already rendered the figures, so release them. Streamlit
# re-runs this script on every interaction and pyplot keeps every open figure
# alive, so leaving them open accumulates figures (and memory) across reruns.
plt.close(fig1)
plt.close(fig2)
|
|
|
|
|
|
|
|
def _simulate_panel(model_name, sliders, n_samples, seed):
    """Run the ssms simulator once for one panel's current slider values."""
    return simulator(
        model=model_name,
        theta=[list(sliders.values())],
        n_samples=n_samples,
        random_state=seed,
    )


# Simulate each model with the same theta, sample count and seed that were
# passed to the plotting helpers above; feeds the summary table and histogram.
sim_output_1 = _simulate_panel(model_select_1, d_slider_1, nsamples_1, randomseed_1)
sim_output_2 = _simulate_panel(model_select_2, d_slider_2, nsamples_2, randomseed_2)
|
|
|
|
|
|
|
|
def _mean_valid_rt(sim_output):
    """Mean rt over trials whose |rt| is not 999 (NaN if no such trial exists).

    Trials with rt == +/-999 are excluded from the signed-RT histogram as
    well. They are presumably ssms's no-response marker, and averaging them in
    would swamp the mean.
    """
    rts = np.asarray(sim_output["rts"])
    valid = rts[np.abs(rts) != 999]
    return float(valid.mean()) if valid.size else float("nan")


# Summary table comparing the two simulations, one row per model panel.
_summary_sources = (sim_output_1, sim_output_2)
metadata = pd.DataFrame(
    {
        "Model": [str(out["metadata"]["model"]) for out in _summary_sources],
        # NOTE(review): choice_p[0, 0] is presumably the probability of the
        # first choice option for the single theta row -- confirm against ssms.
        "Choice Probability": [float(out["choice_p"][0, 0]) for out in _summary_sources],
        "Mean RT": [_mean_valid_rt(out) for out in _summary_sources],
        "Noise SD": [float(out["metadata"]["s"]) for out in _summary_sources],
    },
    index=["Model 1", "Model 2"],
)
|
|
|
|
|
def _signed_valid_rts(sim_output):
    """Return rt * choice for trials whose |rt| is not 999.

    One boolean mask is computed and applied to both ``rts`` and ``choices``, so
    the two selections always have the same length. The previous code masked
    ``choices`` with ``np.abs(rts != 999)``, which keeps -999 trials and breaks
    broadcasting whenever one is present.
    """
    rts = sim_output["rts"]
    valid = np.abs(rts) != 999
    return rts[valid] * sim_output["choices"][valid]


def _has_signed_binary_choices(sim_output):
    """True when the model's choices are exactly -1 and 1.

    Only then does rt * choice encode the choice in the sign of the rt.
    """
    choices = sim_output["metadata"]["possible_choices"]
    return sorted(int(c) for c in choices) == [-1, 1]


col3, col4 = st.columns(2)

with col3:
    # Overlaid signed-RT histograms, drawn only when both models code their
    # choices as -1/+1. (The old test `len(a) == 2 | len(b) == 2` was parsed
    # as a chained comparison around `2 | len(b)`, because `|` binds tighter
    # than `==`.)
    if _has_signed_binary_choices(sim_output_1) and _has_signed_binary_choices(sim_output_2):
        figure_placeholder_3 = st.empty()

        fig3, ax3 = plt.subplots()
        for sim_output, model_name, styling in (
            (sim_output_1, model_select_1, styling_config_1),
            (sim_output_2, model_select_2, styling_config_2),
        ):
            ax3.hist(
                _signed_valid_rts(sim_output),
                histtype="step",
                bins=50,
                density=True,
                color=styling["data_color"],
                fill=None,
                label=model_name.upper(),
            )
        ax3.legend()
        ax3.set_xlabel("rt")
        ax3.set_xlim(-5, 5)
        figure_placeholder_3.pyplot(fig3)
        # Already rendered by st.pyplot; close it so figures do not pile up
        # across Streamlit reruns.
        plt.close(fig3)

with col4:
    st.dataframe(metadata)