File size: 3,752 Bytes
221e0dd 91db57a d946758 221e0dd 1050453 221e0dd a0e4212 221e0dd 1050453 05504e2 d2a9289 a0e4212 221e0dd 1050453 221e0dd 1050453 221e0dd 1050453 221e0dd 1050453 221e0dd 1050453 d2a9289 56960a5 53b8c7c d2a9289 1050453 56960a5 1050453 56960a5 1050453 56960a5 1050453 56960a5 1050453 56960a5 1050453 56960a5 1050453 56960a5 1050453 56960a5 1050453 221e0dd 1050453 221e0dd 91db57a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
from my_pages.rashomon_effect import plot_scatter
from my_pages.rashomon_effect import income, credit, labels, colors
# Switch matplotlib to its built-in dark theme globally; this changes rcParams,
# so it affects every figure created after this module is imported.
plt.style.use("dark_background")
# User-facing copy for this page. Kept at module level so the triple-quoted
# text is not affected by render()'s indentation.
_INSTRUCTION_TEXT = """
Consider the same data as before. <br>
Instead of directly choosing a model, you make development choices now.
"""

_REGULARIZATION_TEXT = """
**Regularization:** Regularization is a technique commonly used to stop AI models from learning the noise or small quirks in the data that might not generalize.
Choose a regularization method:
- L1 Regularization: Force your AI model to use fewer features, thus avoiding irrelevant features.
- L2 Regularization: Force your AI model to rely less on each feature, even though you use all features, thus avoiding noisy dominance of any single feature.
"""

_RANDOMNESS_TEXT = """
**Randomness:** Sometimes there is randomness in the learning process. Let's flip a coin
(You can just choose Heads or Tails, and we will assume we flipped a coin. It'll be our little secret).
"""

_MULTIPLICITY_MESSAGE = """
Your choices during model development lead you to this model.
"""

# Decision-boundary type passed to plot_scatter for each coin outcome under L1.
_L1_BOUNDARY_BY_SEED = {"Heads": "vertical", "Tails": "horizontal"}


def _choice_buttons(state_key, options):
    """Render one toggle button per option, side by side in equal-width columns.

    The button whose value equals ``st.session_state[state_key]`` is drawn with
    ``type="primary"`` to mark the current selection. Clicking a button stores
    its value under ``state_key`` and calls ``st.rerun()``.

    Args:
        state_key: Session-state key that holds the selected value.
        options: Sequence of ``(label, value)`` pairs, one button per pair.

    Returns:
        The value stored under ``state_key`` at the start of this render, or
        ``None`` if nothing has been selected yet.
    """
    current = st.session_state.get(state_key)
    columns = st.columns([1] * len(options))
    for column, (label, value) in zip(columns, options):
        with column:
            # "secondary" is st.button's default type.
            button_type = "primary" if current == value else "secondary"
            if st.button(label, type=button_type):
                st.session_state[state_key] = value
                # Rerun right away so the new selection renders as primary.
                st.rerun()
    return current


def render():
    """Render the page where the user makes model-development choices.

    The user picks L1 or L2 regularization. L1 additionally asks for a
    pretend coin flip (Heads/Tails). The resulting combination selects the
    ``boundary_type`` that ``plot_scatter`` draws on the shared data, shown
    together with a red explanatory message. Nothing is plotted until the
    choices are complete.
    """
    add_navigation("txt_rashomon_developer", "txt_developer_decisions")
    add_instruction_text(_INSTRUCTION_TEXT)

    # Choosing regularization
    st.markdown(_REGULARIZATION_TEXT)
    regularization_method = _choice_buttons(
        "regularization_method",
        [("L1 Regularization", "l1"), ("L2 Regularization", "l2")],
    )

    # Choosing random seed (only asked for L1)
    random_seed = None
    if regularization_method == "l1":
        st.markdown(_RANDOMNESS_TEXT)
        random_seed = _choice_buttons(
            "random_seed",
            [("Heads", "Heads"), ("Tails", "Tails")],
        )

    # Plot the final figure
    if regularization_method == "l2":
        plot_chosen = "slant"
    elif regularization_method == "l1":
        # None until a coin side has been chosen.
        plot_chosen = _L1_BOUNDARY_BY_SEED.get(random_seed)
    else:
        plot_chosen = None

    if plot_chosen is None:
        return

    _, center, _ = st.columns([1.5, 1, 1.5])
    with center:
        # NOTE(review): plot_scatter is assumed to return a matplotlib Figure
        # (it is passed straight to st.pyplot). Closing it after rendering keeps
        # pyplot from accumulating figures across Streamlit reruns; closing an
        # already-closed figure is a no-op.
        fig = plot_scatter(
            income, credit, colors, boundary_type=plot_chosen, highlight_point=None
        )
        st.pyplot(fig)
        plt.close(fig)
    add_red_text(_MULTIPLICITY_MESSAGE)