| import streamlit as st |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from utils import add_navigation, add_instruction_text, add_red_text |
|
|
# Apply matplotlib's built-in dark theme globally; figures created after this
# call inherit it.
plt.style.use(style='dark_background')
|
|
| |
# Toy loan dataset, one tuple per applicant:
#   (annual income, credit score, label)
# label 1 = repaid the loan (plotted green), 0 = did not repay (plotted red).
# The rows are unzipped into three parallel integer arrays.
income, credit, labels = (
    np.array(column)
    for column in zip(*[
        (80, 970, 1),
        (85, 880, 1),
        (97, 1020, 1),
        (91, 910, 1),
        (78, 805, 1),
        (102, 800, 1),
        (84, 804, 1),
        (88, 708, 1),
        (81, 810, 1),
        (40, 370, 0),
        (45, 470, 0),
        (51, 309, 0),
        (34, 450, 0),
        (47, 304, 0),
        (38, 380, 0),
        (39, 501, 0),
        (97, 370, 1),
        (91, 301, 0),
        (38, 1080, 0),
        (32, 902, 1),
    ])
)

# Per-point marker colors as a plain list of str, aligned with `labels`.
colors = np.where(labels == 1, 'green', 'red').tolist()
|
|
def plot_scatter(x, y, colors, title="", boundary_type=None, highlight_point=None):
    """Return a small scatter plot of credit score against income, optionally with a boundary.

    Args:
        x: Annual income of each point (x-axis).
        y: Credit score of each point (y-axis).
        colors: Marker color for each point, passed to ``Axes.scatter(c=...)``.
        title: Axes title; drawn only when non-empty (the default draws none).
        boundary_type: Which fixed decision boundary to draw:
            ``"vertical"`` -- line at income 65, right side shaded green, left red;
            ``"horizontal"`` -- line at credit score 650, upper side green, lower red;
            ``"slant"`` -- fixed straight line, upper side green, lower red;
            ``None`` (or any other value) -- no boundary.
            Shading covers only the rectangle spanned by the data.
        highlight_point: Optional ``(x, y)`` pair drawn as a large green marker
            with a yellow edge on top of the other points.

    Returns:
        The new ``matplotlib.figure.Figure``. pyplot keeps a reference to it,
        so the caller should ``plt.close(fig)`` once it has been displayed.
    """
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.scatter(x, y, c=colors, alpha=0.6)
    ax.set_xlabel("Annual Income")
    ax.set_ylabel("Credit Score")
    if title:
        ax.set_title(title)

    # Transparent figure and axes backgrounds.
    fig.patch.set_alpha(0)
    ax.patch.set_alpha(0)

    if boundary_type in ("vertical", "horizontal", "slant"):
        x_min, x_max = min(x), max(x)
        y_min, y_max = min(y), max(y)
        # Two-point spans describe the shaded rectangles exactly and reach the
        # data maxima (np.arange(lo, hi) excluded `hi`, leaving an unshaded edge).
        x_span = [x_min, x_max]
        y_span = [y_min, y_max]

        if boundary_type == "vertical":
            ax.axvline(65, color='blue')
            ax.fill_betweenx(y_span, 65, x_max, alpha=0.1, color='green')
            ax.fill_betweenx(y_span, x_min, 65, alpha=0.1, color='red')
        elif boundary_type == "horizontal":
            ax.axhline(650, color='blue')
            ax.fill_between(x_span, 650, y_max, alpha=0.1, color='green')
            ax.fill_between(x_span, y_min, 650, alpha=0.1, color='red')
        else:  # "slant"
            slope = -10.677966
            intercept = 1353.7288
            x_sorted = np.sort(x)
            y_line = slope * x_sorted + intercept
            ax.plot(x_sorted, y_line, color='blue')
            # Clamp the fill edge to the data's y-range: where the line leaves
            # that range, the green and red bands would otherwise overlap.
            y_edge = np.clip(y_line, y_min, y_max)
            ax.fill_between(x_sorted, y_edge, y_max, alpha=0.1, color='green')
            ax.fill_between(x_sorted, y_min, y_edge, alpha=0.1, color='red')

    if highlight_point is not None:
        # zorder=5 keeps the highlighted marker above the base scatter.
        ax.scatter(*highlight_point, c='green', edgecolors='yellow', s=200, zorder=5, linewidths=4)

    # Minimal look: no top/right spines and no tick marks.
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xticks([])
    ax.set_yticks([])

    return fig
|
|
# One entry per selectable model, in display order:
#   (button label, boundary_type passed to plot_scatter, point highlighted once chosen)
# With the thresholds used in plot_scatter, each highlighted point lies on the
# red side of its own model's boundary but on the green side of another model's.
_MODEL_CHOICES = (
    ("Choose Model 1", "vertical", (32, 902)),
    ("Choose Model 2", "slant", (32, 902)),
    ("Choose Model 3", "horizontal", (97, 370)),
)


def _render_model_column(label, boundary_type, point_to_highlight, graph_selected, highlight_point):
    """Draw one model's plot, accuracy line and selection button in the active container."""
    fig = plot_scatter(income, credit, colors, boundary_type=boundary_type, highlight_point=highlight_point)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed explicitly; without this
    # each rerun leaks three figures (matplotlib warns once more than 20 are open).
    plt.close(fig)
    st.markdown("Accuracy: 90%")
    # "secondary" is st.button's default look; the chosen model is shown as primary.
    button_type = "primary" if graph_selected == boundary_type else "secondary"
    if st.button(label, type=button_type):
        st.session_state.highlight_point = point_to_highlight
        st.session_state.graph_selected = boundary_type
        st.rerun()


def render():
    """Render the Rashomon-set page.

    Shows the loan data under three models labelled with equal (90%) accuracy,
    each with a "Choose Model N" button. A click stores ``graph_selected`` and
    ``highlight_point`` in ``st.session_state`` and reruns the script; on the
    rerun the chosen button is drawn as primary, every plot highlights the
    stored point, and the multiplicity explanation is shown.
    """
    add_navigation("txt_rashomon_effect", "txt_rashomon_developer")

    add_instruction_text(
        """
Consider the following data about individuals who did (green) or didn't (red) repay their loans. <br>
Which model out of these three will you choose to judge loan applications?
"""
    )

    rashomon_set_message = """
Multiple models achieving similar accuracy, i.e., multiple interpretations of the data, is known as the Rashomon effect.
We call the models below part of a 'Rashomon set'.
"""
    add_red_text(rashomon_set_message)

    # Both keys are written together by the model buttons; absent on first visit.
    graph_selected = st.session_state.get("graph_selected")
    highlight_point = st.session_state.get("highlight_point")

    # The narrow outer columns are spacers that centre the three model columns.
    _, *model_columns, _ = st.columns([0.5, 1, 1, 1, 0.5])
    for column, (label, boundary_type, point) in zip(model_columns, _MODEL_CHOICES):
        with column:
            _render_model_column(label, boundary_type, point, graph_selected, highlight_point)

    if graph_selected is not None:
        multiplicity_message = """
Because of your choice, the highlighted individual was rejected, but would have gotten a loan under a different model.
These conflicting predictions are called multiplicity.<br><br>
<b>Clearly, the choice of model directly impacts individuals!</b>
"""
        add_red_text(multiplicity_message)