import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
from my_pages.rashomon_effect import plot_scatter
from my_pages.rashomon_effect import income, credit, labels, colors

plt.style.use('dark_background')


def _choice_buttons(options, state_key):
    """Render a row of mutually exclusive choice buttons backed by session state.

    Args:
        options: Sequence of ``(label, value)`` pairs. One button is drawn per
            pair, left to right, each in its own equal-width column.
        state_key: Key in ``st.session_state`` under which the chosen value
            is stored.

    Returns:
        The value currently stored under ``state_key``, or ``None`` if the
        user has not picked anything yet.

    The button whose value matches the stored choice is drawn with
    ``type="primary"`` so the current selection stands out. Clicking a button
    stores its value and calls ``st.rerun()`` so the highlight (and anything
    that depends on the choice) updates within the same interaction.
    """
    current = st.session_state.get(state_key)
    columns = st.columns([1] * len(options))
    for column, (label, value) in zip(columns, options):
        with column:
            # Keep the exact same st.button calls for both states so the
            # widget identity matches the original per-button code.
            if current == value:
                clicked = st.button(label, type="primary")
            else:
                clicked = st.button(label)
            if clicked:
                st.session_state[state_key] = value
                st.rerun()
    return current


def render():
    """Render the page where the user arrives at a model via development choices.

    The user first picks L1 or L2 regularization. Picking L1 also asks for a
    (pretend) coin flip. Once enough choices are made, the scatter plot from
    the Rashomon-effect page is drawn with the decision boundary those
    choices lead to:

    - L2               -> ``"slant"``
    - L1 + ``"Heads"`` -> ``"vertical"``
    - L1 + ``"Tails"`` -> ``"horizontal"``

    Choices persist in ``st.session_state`` under ``"regularization_method"``
    and ``"random_seed"``. A coin-flip choice is kept even if the user
    switches to L2, so switching back to L1 reuses it.
    """
    add_navigation("txt_rashomon_developer", "txt_developer_decisions")

    add_instruction_text(
        """
        Consider the same data as before.
        Instead of directly choosing a model, you make development choices now.
        """
    )

    #### Choosing regularization
    st.markdown(
        """
        **Regularization:** Regularization is a technique commonly used to stop AI models from learning the noise or small quirks in the data that might not generalize.

        Choose a regularization method:
        - L1 Regularization: Force your AI model to use fewer features, thus avoiding irrelevant features.
        - L2 Regularization: Force your AI model to rely less on each feature, even though you use all features, thus avoiding noisy dominance of any single feature.
        """
    )
    regularization_method = _choice_buttons(
        [("L1 Regularization", "l1"), ("L2 Regularization", "l2")],
        "regularization_method",
    )

    #### Choosing random seed (only asked for when L1 is selected)
    random_seed = None
    if regularization_method == "l1":
        st.markdown(
            """
            **Randomness:** Sometimes there is randomness in the learning process. Let's flip a coin. (You can just choose Heads or Tails, and we will assume we flipped a coin. It'll be our little secret.)
            """
        )
        random_seed = _choice_buttons(
            [("Heads", "Heads"), ("Tails", "Tails")],
            "random_seed",
        )

    #### Plot the final figure
    # Map the choices onto one of the boundary types accepted by
    # plot_scatter; None means the user has not finished choosing yet.
    if regularization_method == "l2":
        plot_chosen = "slant"
    elif regularization_method == "l1":
        plot_chosen = {"Heads": "vertical", "Tails": "horizontal"}.get(random_seed)
    else:
        plot_chosen = None

    if plot_chosen is not None:
        _, center_col, _ = st.columns([1.5, 1, 1.5])
        with center_col:
            fig = plot_scatter(
                income,
                credit,
                colors,
                boundary_type=plot_chosen,
                highlight_point=None,
            )
            st.pyplot(fig)
            # st.pyplot has already rendered the figure to an image. Closing
            # it stops figures from piling up across reruns, in case
            # plot_scatter creates them through pyplot (NOTE(review): confirm).
            plt.close(fig)

        multiplicity_message = """
        Your choices during model development lead you to this model.
        """
        add_red_text(multiplicity_message)