# verifiability / my_pages/rashomon_developer.py
# Last commit by prakharg24: "Update my_pages/rashomon_developer.py" (53b8c7c, verified)
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
from my_pages.rashomon_effect import plot_scatter
from my_pages.rashomon_effect import income, credit, labels, colors
# Dark matplotlib theme for the figures on this page. This runs once, at import
# time, and updates matplotlib's global style settings, so it may also affect
# figures drawn elsewhere in the same process.
_PAGE_PLOT_STYLE = 'dark_background'
plt.style.use(_PAGE_PLOT_STYLE)
# Page copy kept at module level so the text does not depend on code indentation.
_INSTRUCTION_TEXT = """
Consider the same data as before. <br>
Instead of directly choosing a model, you make development choices now.
"""

_REGULARIZATION_TEXT = """
**Regularization:** Regularization is a technique commonly used to stop AI models from learning the noise or small quirks in the data that might not generalize.
Choose a regularization method:
- L1 Regularization: Force your AI model to use fewer features, thus avoiding irrelevant features.
- L2 Regularization: Force your AI model to rely less on each feature, even though you use all features, thus avoiding noisy dominance of any single feature.
"""

_RANDOMNESS_TEXT = """
**Randomness:** Sometimes there is randomness in the learning process. Let's flip a coin
(You can just choose Heads or Tails, and we will assume we flipped a coin. It'll be our little secret).
"""

_MULTIPLICITY_MESSAGE = """
Your choices during model development lead you to this model.
"""


def _choice_button(label, state_key, value):
    """Draw a button that stores `value` in ``st.session_state[state_key]``.

    The button is drawn as ``type="primary"`` when `value` is already the
    stored choice, so the active option is highlighted. Otherwise it uses
    ``"secondary"``, which is Streamlit's default button type.

    Clicking it stores the choice and calls ``st.rerun()``. That stops the
    current script run immediately, and the page redraws with the new
    highlight.
    """
    is_selected = st.session_state.get(state_key) == value
    if st.button(label, type="primary" if is_selected else "secondary"):
        st.session_state[state_key] = value
        st.rerun()


def _boundary_for(regularization_method, random_seed):
    """Map the user's development choices to a ``plot_scatter`` boundary type.

    Returns:
        "slant" for L2, regardless of the coin flip.
        "vertical" for L1 with Heads.
        "horizontal" for L1 with Tails.
        None while the choices are incomplete: no regularization chosen yet,
        or L1 chosen but no coin flip yet.
    """
    if regularization_method == "l2":
        return "slant"
    if regularization_method == "l1":
        return {"Heads": "vertical", "Tails": "horizontal"}.get(random_seed)
    return None


def render():
    """Render the "Rashomon developer" page.

    Instead of picking a model directly, the user makes development choices:
    a regularization method and, for L1 only, a coin flip standing in for
    randomness in training. The page then plots the model those choices lead
    to.

    Choices persist across reruns in ``st.session_state`` under the keys
    ``"regularization_method"`` ("l1" or "l2") and ``"random_seed"``
    ("Heads" or "Tails"). A previously chosen coin flip is therefore kept if
    the user switches from L1 to L2 and back.
    """
    add_navigation("txt_rashomon_developer", "txt_developer_decisions")
    add_instruction_text(_INSTRUCTION_TEXT)

    # Step 1: choose a regularization method.
    st.markdown(_REGULARIZATION_TEXT)
    col1, col2 = st.columns([1, 1])
    with col1:
        _choice_button("L1 Regularization", "regularization_method", "l1")
    with col2:
        _choice_button("L2 Regularization", "regularization_method", "l2")

    # If a button above was clicked, st.rerun() has already stopped this run.
    # So these reads always see the state that was current when the run began.
    regularization_method = st.session_state.get("regularization_method")
    random_seed = st.session_state.get("random_seed")

    # Step 2: only the L1 path asks for a coin flip. _boundary_for always
    # maps L2 to the same "slant" boundary, so the seed only matters for L1.
    if regularization_method == "l1":
        st.markdown(_RANDOMNESS_TEXT)
        col1, col2 = st.columns([1, 1])
        with col1:
            _choice_button("Heads", "random_seed", "Heads")
        with col2:
            _choice_button("Tails", "random_seed", "Tails")

    # Step 3: plot the resulting model once the choices are complete.
    boundary_type = _boundary_for(regularization_method, random_seed)
    if boundary_type is None:
        return

    # The narrow middle column keeps the figure centred and modestly sized.
    _, plot_col, _ = st.columns([1.5, 1, 1.5])
    with plot_col:
        st.pyplot(
            plot_scatter(
                income,
                credit,
                colors,
                boundary_type=boundary_type,
                highlight_point=None,
            )
        )
    add_red_text(_MULTIPLICITY_MESSAGE)