File size: 3,752 Bytes
221e0dd
 
 
91db57a
d946758
 
221e0dd
 
 
 
1050453
221e0dd
 
 
a0e4212
 
221e0dd
 
 
1050453
 
05504e2
d2a9289
a0e4212
 
 
221e0dd
 
 
1050453
 
 
 
 
 
 
221e0dd
1050453
 
 
221e0dd
1050453
 
 
221e0dd
1050453
 
 
221e0dd
 
1050453
 
d2a9289
56960a5
53b8c7c
d2a9289
 
1050453
 
 
 
 
56960a5
 
1050453
56960a5
1050453
56960a5
1050453
 
56960a5
 
1050453
56960a5
1050453
56960a5
1050453
 
 
 
 
 
 
56960a5
1050453
56960a5
1050453
 
 
 
 
 
 
221e0dd
1050453
221e0dd
91db57a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
from my_pages.rashomon_effect import plot_scatter
from my_pages.rashomon_effect import income, credit, labels, colors

# Module-level side effect: plt.style.use updates matplotlib's global rcParams,
# so every figure created afterwards in this process (not just on this page)
# uses the dark style — presumably chosen to match a dark Streamlit theme; confirm.
plt.style.use('dark_background')

def _render_choice_buttons(options, state_key):
    """Render a row of mutually exclusive buttons backed by ``st.session_state``.

    Each ``(label, value)`` pair gets its own equal-width column. The button
    whose value matches the stored selection is drawn with ``type="primary"``
    so the current choice is highlighted. Clicking a button stores its value
    under ``state_key`` and calls ``st.rerun()`` so the page redraws with the
    new selection.

    Args:
        options: Sequence of ``(label, value)`` pairs, rendered left to right.
        state_key: Session-state key that holds the selected value.

    Returns:
        The value stored under ``state_key`` at the start of this run, or
        ``None`` when nothing has been selected yet.
    """
    selected = st.session_state.get(state_key)
    columns = st.columns([1] * len(options))
    for column, (label, value) in zip(columns, options):
        with column:
            # "secondary" is st.button's default type, so unselected buttons
            # render exactly like a plain st.button(label) call.
            button_type = "primary" if selected == value else "secondary"
            if st.button(label, type=button_type):
                st.session_state[state_key] = value
                st.rerun()
    return selected


def render():
    """Render the "development choices" page of the Rashomon-effect walkthrough.

    The user first picks a regularization method (L1 or L2). Choosing L1 also
    asks for a simulated coin flip (Heads/Tails). The combination selects the
    decision boundary drawn by ``plot_scatter``:

    - L2         -> "slant"
    - L1 + Heads -> "vertical"
    - L1 + Tails -> "horizontal"

    Selections persist in ``st.session_state`` under the keys
    ``"regularization_method"`` and ``"random_seed"``. No plot is shown until
    enough choices have been made to pick a boundary.
    """
    add_navigation("txt_rashomon_developer", "txt_developer_decisions")

    add_instruction_text(
        """
        Consider the same data as before. <br>
        Instead of directly choosing a model, you make development choices now.
        """
    )

    #### Choosing regularization
    st.markdown("""
        **Regularization:** Regularization is a technique commonly used to stop AI models from learning the noise or small quirks in the data that might not generalize. 
        
        Choose a regularization method:
        - L1 Regularization: Force your AI model to use less number of features, thus avoiding irrelevant features.
        - L2 Regularization: Force your AI model to rely less on each feature, even though you use all features, thus avoiding noisy dominance of any single feature.
    """
    )

    regularization_method = _render_choice_buttons(
        [("L1 Regularization", "l1"), ("L2 Regularization", "l2")],
        "regularization_method",
    )

    #### Choosing random seed (only the L1 path involves a coin flip)
    # Bound unconditionally so later code never depends on which branch ran.
    random_seed = None
    if regularization_method == "l1":
        st.markdown("""
            **Randomness:** Sometimes there is randomness in the learning process. Let's flip a coin 
            (You can just choose Heads or Tails, and we will assume we flipped a coin. It'll be our little secret).
        """
        )
        random_seed = _render_choice_buttons(
            [("Heads", "Heads"), ("Tails", "Tails")],
            "random_seed",
        )

    #### Plot the final figure
    if regularization_method == "l2":
        plot_chosen = "slant"
    elif regularization_method == "l1":
        # Stays None (no plot) until the coin has been "flipped".
        plot_chosen = {"Heads": "vertical", "Tails": "horizontal"}.get(random_seed)
    else:
        plot_chosen = None

    if plot_chosen is not None:
        _, center, _ = st.columns([1.5, 1, 1.5])
        with center:
            fig = plot_scatter(income, credit, colors, boundary_type=plot_chosen, highlight_point=None)
            st.pyplot(fig)
            # Release the figure once Streamlit has rendered it: pyplot keeps
            # open figures alive and this page re-runs on every interaction.
            # (Assumes plot_scatter builds a fresh figure per call; closing an
            # already-rendered figure is harmless either way.)
            plt.close(fig)

        multiplicity_message = """
            Your choices during model development lead you to this model.
        """
        add_red_text(multiplicity_message)