Update my_pages/rashomon_effect.py
Browse files- my_pages/rashomon_effect.py +23 -30
my_pages/rashomon_effect.py
CHANGED
|
@@ -18,13 +18,6 @@ def render():
|
|
| 18 |
|
| 19 |
st.markdown("---")
|
| 20 |
|
| 21 |
-
# Generate synthetic data
|
| 22 |
-
# np.random.seed(42)
|
| 23 |
-
# n_points = 100
|
| 24 |
-
# income = np.random.normal(50, 15, n_points)
|
| 25 |
-
# credit = np.random.normal(50, 15, n_points)
|
| 26 |
-
# labels = (income + credit > 100).astype(int) # 1 = paid back, 0 = default
|
| 27 |
-
|
| 28 |
income = np.array([80, 85, 97, 91, 78, 102, 84, 88, 45, 51, 34, 47, 38, 39, 97, 91, 38, 32])
|
| 29 |
credit = np.array([970, 880, 1020, 910, 805, 800, 804, 708, 470, 309, 450, 304, 380, 501, 370, 301, 1080, 902])
|
| 30 |
labels = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1])
|
|
@@ -63,7 +56,7 @@ def render():
|
|
| 63 |
|
| 64 |
# Highlight specific point
|
| 65 |
if highlight_point is not None:
|
| 66 |
-
ax.scatter(*highlight_point, c='yellow', edgecolors='black', s=200, zorder=5)
|
| 67 |
|
| 68 |
ax.spines['right'].set_visible(False)
|
| 69 |
ax.spines['top'].set_visible(False)
|
|
@@ -72,37 +65,37 @@ def render():
|
|
| 72 |
|
| 73 |
return fig
|
| 74 |
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
col1, col2, col3 = st.columns([1.5, 1, 1.5])
|
| 77 |
with col2:
|
| 78 |
st.pyplot(plot_scatter(income, credit, colors, title="Original Data"))
|
| 79 |
|
| 80 |
col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
|
| 81 |
with col2:
|
| 82 |
-
st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical"))
|
| 83 |
vertical_selected = st.button("Choose Model 1")
|
| 84 |
with col3:
|
| 85 |
-
st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant"))
|
| 86 |
slant_selected = st.button("Choose Model 2")
|
| 87 |
with col4:
|
| 88 |
-
st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal"))
|
| 89 |
horizontal_selected = st.button("Choose Model 3")
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
# highlight_point=new_point))
|
| 100 |
-
# st.warning("This individual was rejected by your chosen model. Why not choose a model that helps them?")
|
| 101 |
-
# elif right_selected:
|
| 102 |
-
# new_point = (80, 40) # Low credit score, high income
|
| 103 |
-
# with col2:
|
| 104 |
-
# st.pyplot(plot_scatter(income, credit, colors,
|
| 105 |
-
# title="Horizontal Boundary + New Individual",
|
| 106 |
-
# boundary_type="horizontal",
|
| 107 |
-
# highlight_point=new_point))
|
| 108 |
-
# st.warning("This individual was rejected by your chosen model. Why not choose a model that helps them?")
|
|
|
|
| 18 |
|
| 19 |
st.markdown("---")
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
income = np.array([80, 85, 97, 91, 78, 102, 84, 88, 45, 51, 34, 47, 38, 39, 97, 91, 38, 32])
|
| 22 |
credit = np.array([970, 880, 1020, 910, 805, 800, 804, 708, 470, 309, 450, 304, 380, 501, 370, 301, 1080, 902])
|
| 23 |
labels = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1])
|
|
|
|
| 56 |
|
| 57 |
# Highlight specific point
|
| 58 |
if highlight_point is not None:
|
| 59 |
+
ax.scatter(*highlight_point, c='yellow', edgecolors='black', s=200, zorder=5, alpha=0)
|
| 60 |
|
| 61 |
ax.spines['right'].set_visible(False)
|
| 62 |
ax.spines['top'].set_visible(False)
|
|
|
|
| 65 |
|
| 66 |
return fig
|
| 67 |
|
| 68 |
+
|
| 69 |
+
# Store highlight point based on selection
|
| 70 |
+
highlight_point = None
|
| 71 |
+
if vertical_selected:
|
| 72 |
+
highlight_point = (32, 902)
|
| 73 |
+
elif horizontal_selected:
|
| 74 |
+
highlight_point = (97, 370) # Low credit score, high income
|
| 75 |
+
elif slant_selected:
|
| 76 |
+
highlight_point = (32, 902) # Example point for slant model
|
| 77 |
+
|
| 78 |
+
# Top scatter plot (centered to match smaller width)
|
| 79 |
col1, col2, col3 = st.columns([1.5, 1, 1.5])
|
| 80 |
with col2:
|
| 81 |
st.pyplot(plot_scatter(income, credit, colors, title="Original Data"))
|
| 82 |
|
| 83 |
col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
|
| 84 |
with col2:
|
| 85 |
+
st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical", highlight_point=highlight_point))
|
| 86 |
vertical_selected = st.button("Choose Model 1")
|
| 87 |
with col3:
|
| 88 |
+
st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant", highlight_point=highlight_point))
|
| 89 |
slant_selected = st.button("Choose Model 2")
|
| 90 |
with col4:
|
| 91 |
+
st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal", highlight_point=highlight_point))
|
| 92 |
horizontal_selected = st.button("Choose Model 3")
|
| 93 |
|
| 94 |
+
st.markdown("---")
|
| 95 |
+
col1, col2, col3, col4 = st.columns([2, 1, 1, 1])
|
| 96 |
+
with col3:
|
| 97 |
+
if st.button("Go Home"):
|
| 98 |
+
go_to("home")
|
| 99 |
+
with col4:
|
| 100 |
+
if st.button("Next"):
|
| 101 |
+
go_to("developer_decisions")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|