File size: 5,186 Bytes
f46f2d9
 
a733d39
804b2fa
f46f2d9
e95241f
 
da13d3a
 
 
 
 
 
7acc3dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb39a0f
da13d3a
693650f
 
 
 
7acc3dc
693650f
 
7acc3dc
 
9ca7fce
dffbc09
1b5de32
9ca7fce
804b2fa
a733d39
7acc3dc
6726747
 
 
6fbbcf2
979710a
c337048
0c434b3
979710a
c82f407
6726747
 
 
 
 
6fbbcf2
6726747
6fbbcf2
c62c72d
979710a
c82f407
6726747
 
 
 
 
 
 
6fbbcf2
0ec6255
979710a
c82f407
6726747
 
 
 
 
 
 
6fbbcf2
a733d39
7acc3dc
91d4827
9ca7fce
c82f407
 
dffbc09
1b5de32
804b2fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text

# Global matplotlib style: every figure created after this call uses the dark theme.
plt.style.use(style='dark_background')

#### Setup data to plot
# One record per applicant: (annual income, credit score, label).
# label 1 = repaid the loan (plotted green), 0 = did not repay (plotted red).
_applicants = [
    (80, 970, 1),
    (85, 880, 1),
    (97, 1020, 1),
    (91, 910, 1),
    (78, 805, 1),
    (102, 800, 1),
    (84, 804, 1),
    (88, 708, 1),
    (81, 810, 1),
    (40, 370, 0),
    (45, 470, 0),
    (51, 309, 0),
    (34, 450, 0),
    (47, 304, 0),
    (38, 380, 0),
    (39, 501, 0),
    (97, 370, 1),
    (91, 301, 0),
    (38, 1080, 0),
    (32, 902, 1),
]
# Transpose the records into one integer array per column.
income, credit, labels = (np.array(column) for column in zip(*_applicants))
# Plain Python list of color names, one per applicant.
colors = np.where(labels == 1, 'green', 'red').tolist()

def plot_scatter(x, y, colors, title="", boundary_type=None, highlight_point=None):
    """Draw a small scatter plot of applicants, optionally shaded by a model's decision.

    Args:
        x: Annual incomes, plotted on the horizontal axis.
        y: Credit scores, plotted on the vertical axis.
        colors: Per-point colors passed to ``ax.scatter``.
        title: Accepted for compatibility but not drawn on the figure.
        boundary_type: Which decision boundary to draw:
            ``"vertical"``   -- income threshold at 65;
            ``"horizontal"`` -- credit-score threshold at 650;
            ``"slant"``      -- the line through (94, 350) and (35, 980);
            ``None`` or any other value -- no boundary.
            The green band covers the right/upper side of the boundary and
            the red band the other side, both limited to the data extents.
        highlight_point: Optional ``(x, y)`` pair drawn as a large green marker
            with a yellow outline on top of the other points.

    Returns:
        The new matplotlib ``Figure``. It is created via ``plt.subplots``, so
        pyplot keeps it alive until the caller closes it (``plt.close(fig)``).
    """
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.scatter(x, y, c=colors, alpha=0.6)
    ax.set_xlabel("Annual Income")
    ax.set_ylabel("Credit Score")

    # Transparent figure and axes backgrounds.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    # Decision boundary. Bands span the exact data endpoints: np.arange(min, max)
    # excludes max, which left the shading one unit short of the last points.
    if boundary_type == "vertical":
        y_span = [min(y), max(y)]
        ax.axvline(65, color='blue')
        ax.fill_betweenx(y_span, 65, max(x), alpha=0.1, color='green')
        ax.fill_betweenx(y_span, min(x), 65, alpha=0.1, color='red')
    elif boundary_type == "horizontal":
        x_span = [min(x), max(x)]
        ax.axhline(650, color='blue')
        ax.fill_between(x_span, 650, max(y), alpha=0.1, color='green')
        ax.fill_between(x_span, min(y), 650, alpha=0.1, color='red')
    elif boundary_type == "slant":
        slope = -10.677966       # From (94, 350) and (35, 980)
        intercept = 1353.7288
        y_min, y_max = min(y), max(y)
        # Dense samples over the data's income range so the clipped band edge
        # below follows the straight line closely.
        x_line = np.linspace(min(x), max(x), 200)
        y_line = slope * x_line + intercept
        ax.plot(x_line, y_line, color='blue')
        # Clip the band edge to the data's y-range. Where the line leaves that
        # range (it drops below min(y) at high incomes), an unclipped edge makes
        # the red band extend below min(y) and overlap the green one.
        y_edge = np.clip(y_line, y_min, y_max)
        ax.fill_between(x_line, y_edge, y_max, alpha=0.1, color='green')
        ax.fill_between(x_line, y_min, y_edge, alpha=0.1, color='red')

    # Highlight specific point, drawn above the regular markers.
    if highlight_point is not None:
        ax.scatter(*highlight_point, c='green', edgecolors='yellow', s=200, zorder=5, linewidths=4)

    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    return fig

def render():
    """Render the Rashomon-set page: three equally accurate models to choose from.

    Shows three scatter plots of the loan data, each with a different decision
    boundary, with a "Choose Model N" button under each. Clicking a button
    stores the choice in ``st.session_state`` (``graph_selected`` plus the
    ``highlight_point`` of an applicant that model rejects) and reruns the
    script. On the rerun, the chosen button is drawn as primary and the point is
    highlighted in every plot. Once a choice exists, the multiplicity
    explanation is shown.
    """
    add_navigation("txt_rashomon_effect", "txt_rashomon_developer")

    add_instruction_text(
        """
        Consider the following data about individuals who did (green) or didn't (red) repay their loans. <br>
        Which model out of these three will you choose to judge loan applications?
        """
    )

    #### Rashomon Set Definition
    rashomon_set_message = """
        Multiple models achieving similar accuracy, i.e., multiple interpretations of the data, is known as the Rashomon effect.
        We call the models below part of a 'Rashomon set'.
    """
    add_red_text(rashomon_set_message)

    #### Plot three graphs to represent three models
    # (boundary type, button label, a repaying applicant this model rejects).
    # Each model misclassifies 2 of the 20 points, hence the shared 90% accuracy.
    models = (
        ("vertical", "Choose Model 1", (32, 902)),
        ("slant", "Choose Model 2", (32, 902)),
        ("horizontal", "Choose Model 3", (97, 370)),
    )

    graph_selected, highlight_point = None, None
    if "graph_selected" in st.session_state:
        graph_selected = st.session_state.graph_selected
        highlight_point = st.session_state.highlight_point

    # The narrow outer columns are spacers that center the three plots.
    _, *model_cols, _ = st.columns([0.5, 1, 1, 1, 0.5])
    for col, (boundary, label, rejected_point) in zip(model_cols, models):
        with col:
            fig = plot_scatter(income, credit, colors, boundary_type=boundary, highlight_point=highlight_point)
            st.pyplot(fig)
            # plt.subplots registers the figure with pyplot; close it once
            # rendered, or every rerun leaks three more figures.
            plt.close(fig)
            st.markdown("Accuracy: 90%")
            # "secondary" is st.button's default type.
            button_type = "primary" if graph_selected == boundary else "secondary"
            if st.button(label, type=button_type):
                st.session_state.highlight_point = rejected_point
                st.session_state.graph_selected = boundary
                st.rerun()

    #### Multiplicity Definition
    if "graph_selected" in st.session_state:
        multiplicity_message = """
            Because of your choice, the highlighted individual was rejected, but would have gotten a loan under a different model.
            These conflicting predictions are called multiplicity.<br><br>
            <b>Clearly, the choice of model directly impacts individuals!</b>
        """
        add_red_text(multiplicity_message)