| import streamlit as st |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from utils import add_navigation, add_instruction_text, add_red_text |
| from my_pages.rashomon_effect import plot_scatter |
| from my_pages.rashomon_effect import income, credit, labels, colors |
|
|
# All matplotlib figures produced while this page is active use the dark theme.
plt.style.use("dark_background")
|
|
# Decision-boundary type passed to plot_scatter for each coin outcome when
# L1 regularization is chosen.
_L1_BOUNDARY_BY_COIN = {"Heads": "vertical", "Tails": "horizontal"}


def _choice_buttons(options, state_key):
    """Render a row of mutually exclusive choice buttons backed by session state.

    Each ``(label, value)`` pair in *options* gets its own equal-width column.
    The button whose value matches the stored choice is drawn with
    ``type="primary"`` so the active selection stands out; all others use the
    default ``"secondary"`` type. Clicking a button stores its value in
    ``st.session_state[state_key]`` and calls ``st.rerun()`` so the page is
    redrawn with the new selection highlighted.

    Returns the value stored under *state_key* at the start of this run, or
    ``None`` if nothing has been chosen yet.
    """
    current = st.session_state.get(state_key)
    columns = st.columns(len(options))
    for column, (label, value) in zip(columns, options):
        with column:
            button_type = "primary" if current == value else "secondary"
            if st.button(label, type=button_type):
                st.session_state[state_key] = value
                st.rerun()
    return current


def render():
    """Render the page where development choices determine the displayed model.

    The user first picks L1 or L2 regularization. L2 leads directly to the
    "slant" decision boundary. L1 additionally asks for a coin outcome
    (standing in for training randomness), which selects the "vertical"
    (Heads) or "horizontal" (Tails) boundary. Choices persist across reruns in
    ``st.session_state["regularization_method"]`` and
    ``st.session_state["random_seed"]``. Once a boundary is determined, the
    scatter plot with that boundary and a short message are shown.
    """
    add_navigation("txt_rashomon_developer", "txt_developer_decisions")

    add_instruction_text(
        """
        Consider the same data as before. <br>
        Instead of directly choosing a model, you make development choices now.
        """
    )

    st.markdown("""
    **Regularization:** Regularization is a technique commonly used to stop AI models from learning the noise or small quirks in the data that might not generalize.

    Choose a regularization method:
    - L1 Regularization: Force your AI model to use fewer features, thus avoiding irrelevant features.
    - L2 Regularization: Force your AI model to rely less on each feature, even though you use all features, thus avoiding noisy dominance of any single feature.
    """)

    regularization_method = _choice_buttons(
        [("L1 Regularization", "l1"), ("L2 Regularization", "l2")],
        "regularization_method",
    )

    if regularization_method == "l2":
        plot_chosen = "slant"
    elif regularization_method == "l1":
        # Only the L1 path asks the user to "flip a coin".
        st.markdown("""
        **Randomness:** Sometimes there is randomness in the learning process. Let's flip a coin
        (You can just choose Heads or Tails, and we will assume we flipped a coin. It'll be our little secret).
        """)
        random_seed = _choice_buttons(
            [("Heads", "Heads"), ("Tails", "Tails")],
            "random_seed",
        )
        # None until the user has picked Heads or Tails.
        plot_chosen = _L1_BOUNDARY_BY_COIN.get(random_seed)
    else:
        plot_chosen = None

    if plot_chosen is None:
        # Choices are incomplete, so there is no model to show yet.
        return

    _, plot_column, _ = st.columns([1.5, 1, 1.5])
    with plot_column:
        fig = plot_scatter(income, credit, colors, boundary_type=plot_chosen, highlight_point=None)
        st.pyplot(fig)
        # st.pyplot does not close the figure. Closing it stops pyplot from
        # keeping one open figure per rerun. This assumes plot_scatter returns
        # a matplotlib Figure, which is what st.pyplot expects; plt.close is a
        # no-op for figures pyplot does not track.
        plt.close(fig)

    multiplicity_message = """
    Your choices during model development lead you to this model.
    """
    add_red_text(multiplicity_message)