File size: 5,186 Bytes
f46f2d9 a733d39 804b2fa f46f2d9 e95241f da13d3a 7acc3dc bb39a0f da13d3a 693650f 7acc3dc 693650f 7acc3dc 9ca7fce dffbc09 1b5de32 9ca7fce 804b2fa a733d39 7acc3dc 6726747 6fbbcf2 979710a c337048 0c434b3 979710a c82f407 6726747 6fbbcf2 6726747 6fbbcf2 c62c72d 979710a c82f407 6726747 6fbbcf2 0ec6255 979710a c82f407 6726747 6fbbcf2 a733d39 7acc3dc 91d4827 9ca7fce c82f407 dffbc09 1b5de32 804b2fa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 | import streamlit as st
import matplotlib.pyplot as plt
import numpy as np
from utils import add_navigation, add_instruction_text, add_red_text
# Switch matplotlib to its dark theme globally; this runs at import time,
# before any figure in this module is created.
plt.style.use("dark_background")
#### Setup data to plot
# Toy loan dataset. Position i of each array describes the same applicant:
#   income -- annual income (x axis)
#   credit -- credit score (y axis)
#   labels -- 1 if the applicant repaid the loan, 0 otherwise
# The first nine applicants repaid and the next seven did not; the last four
# are mixed cases that the page's simple decision boundaries disagree on.
income = np.array(
    [80, 85, 97, 91, 78, 102, 84, 88, 81,
     40, 45, 51, 34, 47, 38, 39,
     97, 91, 38, 32]
)
credit = np.array(
    [970, 880, 1020, 910, 805, 800, 804, 708, 810,
     370, 470, 309, 450, 304, 380, 501,
     370, 301, 1080, 902]
)
labels = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1,
     0, 0, 0, 0, 0, 0, 0,
     1, 0, 0, 1]
)

# Marker colour per applicant: green = repaid, red = anything else.
colors = [("green" if flag == 1 else "red") for flag in labels]
def plot_scatter(x, y, colors, title="", boundary_type=None, highlight_point=None,
                 *, vertical_threshold=65, horizontal_threshold=650,
                 slant_points=((94, 350), (35, 980))):
    """Draw the income/credit scatter plot, optionally with a model's decision boundary.

    Parameters
    ----------
    x, y : array-like
        Annual income (x axis) and credit score (y axis) of each point.
    colors : sequence
        Per-point marker colours, passed straight to ``Axes.scatter``.
    title : str
        Axes title; drawn only when non-empty.
    boundary_type : {"vertical", "horizontal", "slant"} or None
        Which decision boundary to draw. ``None`` (or any other value) draws none.
        The side shaded green is the approved region, the red side the rejected one.
    highlight_point : (float, float) or None
        If given, this ``(x, y)`` point is over-drawn as a large green marker
        with a thick yellow outline.
    vertical_threshold : float
        Income cut-off used by the "vertical" boundary (approve to the right).
    horizontal_threshold : float
        Credit-score cut-off used by the "horizontal" boundary (approve above).
    slant_points : ((x0, y0), (x1, y1))
        Two points with distinct x that the "slant" boundary passes through
        (approve above the line).

    Returns
    -------
    matplotlib.figure.Figure
        A new figure created via ``plt.subplots``. pyplot keeps it registered
        until ``plt.close(fig)`` is called, so callers should close it once rendered.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()

    fig, ax = plt.subplots(figsize=(2, 2))
    ax.scatter(x, y, c=colors, alpha=0.6)
    ax.set_xlabel("Annual Income")
    ax.set_ylabel("Credit Score")
    if title:
        ax.set_title(title)
    # Transparent figure/axes backgrounds so the plot blends into the page.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    # Decision boundary. The straight-boundary shading uses two-point spans over
    # the data extents so the bands reach exactly min..max; the previous
    # np.arange(min, max) stopped one unit short of the max.
    if boundary_type == "vertical":
        y_span = [y_min, y_max]
        ax.axvline(vertical_threshold, color='blue')
        ax.fill_betweenx(y_span, vertical_threshold, x_max, alpha=0.1, color='green')
        ax.fill_betweenx(y_span, x_min, vertical_threshold, alpha=0.1, color='red')
    elif boundary_type == "horizontal":
        x_span = [x_min, x_max]
        ax.axhline(horizontal_threshold, color='blue')
        ax.fill_between(x_span, horizontal_threshold, y_max, alpha=0.1, color='green')
        ax.fill_between(x_span, y_min, horizontal_threshold, alpha=0.1, color='red')
    elif boundary_type == "slant":
        (x0, y0), (x1, y1) = slant_points
        slope = (y1 - y0) / (x1 - x0)
        intercept = y0 - slope * x0
        xs = np.linspace(x_min, x_max, 200)
        y_line = slope * xs + intercept
        ax.plot(xs, y_line, color='blue')
        # Clip the shading edge to the data's y-range: where the line leaves
        # [y_min, y_max], unclipped fills are drawn inverted (e.g. a red band
        # below the data on the approved side).
        y_edge = np.clip(y_line, y_min, y_max)
        ax.fill_between(xs, y_edge, y_max, alpha=0.1, color='green')
        ax.fill_between(xs, y_min, y_edge, alpha=0.1, color='red')

    # Highlight specific point (drawn above the regular markers via zorder).
    if highlight_point is not None:
        ax.scatter(*highlight_point, c='green', edgecolors='yellow', s=200, zorder=5, linewidths=4)

    # Minimal look: no top/right spines and no tick marks.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return fig
# Models shown on the page, left to right:
# (boundary_type passed to plot_scatter, button label, accuracy caption,
#  (income, credit) point to highlight once this model is chosen).
# Each highlighted point is an applicant the chosen model rejects although
# another model in the set would approve them.
_RASHOMON_MODELS = (
    ("vertical", "Choose Model 1", "Accuracy: 90%", (32, 902)),
    ("slant", "Choose Model 2", "Accuracy: 90%", (32, 902)),
    ("horizontal", "Choose Model 3", "Accuracy: 90%", (97, 370)),
)


def _render_model_column(boundary_type, label, accuracy_text, point,
                         graph_selected, highlight_point):
    """Draw one model's plot, accuracy caption and selection button in the active column.

    Clicking the button stores the model and the point to highlight in
    ``st.session_state`` and reruns the script so every plot picks them up.
    """
    fig = plot_scatter(income, credit, colors, boundary_type=boundary_type,
                       highlight_point=highlight_point)
    st.pyplot(fig)
    # plt.subplots registers each figure with pyplot, which keeps it alive until
    # it is closed; without this every rerun would leak three figures.
    plt.close(fig)
    st.markdown(accuracy_text)
    # Emphasise the currently chosen model ("secondary" is st.button's default type).
    button_type = "primary" if graph_selected == boundary_type else "secondary"
    if st.button(label, type=button_type):
        st.session_state.highlight_point = point
        st.session_state.graph_selected = boundary_type
        st.rerun()


def render():
    """Render the Rashomon-set page: three equally accurate models to choose from.

    After the user picks a model, the page highlights an applicant that model
    rejects and explains multiplicity.
    """
    add_navigation("txt_rashomon_effect", "txt_rashomon_developer")
    add_instruction_text(
        "\n"
        "Consider the following data about individuals who did (green) or didn't (red) repay their loans. <br>\n"
        "Which model out of these three will you choose to judge loan applications?\n"
    )

    #### Rashomon Set Definition
    rashomon_set_message = (
        "\n"
        "Multiple models achieving similar accuracy, i.e., multiple interpretations of the data, is known as the Rashomon effect.\n"
        "We call the models below part of a 'Rashomon set'.\n"
    )
    add_red_text(rashomon_set_message)

    #### Plot three graphs to represent three models
    # Both keys are written together by _render_model_column, so they are
    # either both present or both absent.
    graph_selected = st.session_state.get("graph_selected")
    highlight_point = st.session_state.get("highlight_point")

    # Narrow outer columns act as side margins around the three model columns.
    _, *model_columns, _ = st.columns([0.5, 1, 1, 1, 0.5])
    for column, (boundary_type, label, accuracy_text, point) in zip(model_columns, _RASHOMON_MODELS):
        with column:
            _render_model_column(boundary_type, label, accuracy_text, point,
                                 graph_selected, highlight_point)

    #### Multiplicity Definition
    if "graph_selected" in st.session_state:
        multiplicity_message = (
            "\n"
            "Because of your choice, the highlighted individual was rejected, but would have gotten a loan under a different model.\n"
            "These conflicting predictions are known as multiplicity.<br><br>\n"
            "<b>Clearly, the choice of model directly impacts individuals!</b>\n"
        )
        add_red_text(multiplicity_message)