prakharg24 commited on
Commit
979710a
·
verified ·
1 Parent(s): a5e0476

Update my_pages/rashomon_effect.py

Browse files
Files changed (1) hide show
  1. my_pages/rashomon_effect.py +23 -30
my_pages/rashomon_effect.py CHANGED
@@ -18,13 +18,6 @@ def render():
18
 
19
  st.markdown("---")
20
 
21
- # Generate synthetic data
22
- # np.random.seed(42)
23
- # n_points = 100
24
- # income = np.random.normal(50, 15, n_points)
25
- # credit = np.random.normal(50, 15, n_points)
26
- # labels = (income + credit > 100).astype(int) # 1 = paid back, 0 = default
27
-
28
  income = np.array([80, 85, 97, 91, 78, 102, 84, 88, 45, 51, 34, 47, 38, 39, 97, 91, 38, 32])
29
  credit = np.array([970, 880, 1020, 910, 805, 800, 804, 708, 470, 309, 450, 304, 380, 501, 370, 301, 1080, 902])
30
  labels = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1])
@@ -63,7 +56,7 @@ def render():
63
 
64
  # Highlight specific point
65
  if highlight_point is not None:
66
- ax.scatter(*highlight_point, c='yellow', edgecolors='black', s=200, zorder=5)
67
 
68
  ax.spines['right'].set_visible(False)
69
  ax.spines['top'].set_visible(False)
@@ -72,37 +65,37 @@ def render():
72
 
73
  return fig
74
 
75
- # Top scatter plot (centered to match smaller width)
 
 
 
 
 
 
 
 
 
 
76
  col1, col2, col3 = st.columns([1.5, 1, 1.5])
77
  with col2:
78
  st.pyplot(plot_scatter(income, credit, colors, title="Original Data"))
79
 
80
  col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
81
  with col2:
82
- st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical"))
83
  vertical_selected = st.button("Choose Model 1")
84
  with col3:
85
- st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant"))
86
  slant_selected = st.button("Choose Model 2")
87
  with col4:
88
- st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal"))
89
  horizontal_selected = st.button("Choose Model 3")
90
 
91
- # col1, col2, col3 = st.columns([1.5, 1, 1.5])
92
- # Show new individual based on selection
93
- # if left_selected:
94
- # new_point = (40, 80) # High credit score, low income
95
- # with col2:
96
- # st.pyplot(plot_scatter(income, credit, colors,
97
- # title="Vertical Boundary + New Individual",
98
- # boundary_type="vertical",
99
- # highlight_point=new_point))
100
- # st.warning("This individual was rejected by your chosen model. Why not choose a model that helps them?")
101
- # elif right_selected:
102
- # new_point = (80, 40) # Low credit score, high income
103
- # with col2:
104
- # st.pyplot(plot_scatter(income, credit, colors,
105
- # title="Horizontal Boundary + New Individual",
106
- # boundary_type="horizontal",
107
- # highlight_point=new_point))
108
- # st.warning("This individual was rejected by your chosen model. Why not choose a model that helps them?")
 
18
 
19
  st.markdown("---")
20
 
 
 
 
 
 
 
 
21
  income = np.array([80, 85, 97, 91, 78, 102, 84, 88, 45, 51, 34, 47, 38, 39, 97, 91, 38, 32])
22
  credit = np.array([970, 880, 1020, 910, 805, 800, 804, 708, 470, 309, 450, 304, 380, 501, 370, 301, 1080, 902])
23
  labels = np.array([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1])
 
56
 
57
  # Highlight specific point
58
  if highlight_point is not None:
59
+ ax.scatter(*highlight_point, c='yellow', edgecolors='black', s=200, zorder=5, alpha=0)
60
 
61
  ax.spines['right'].set_visible(False)
62
  ax.spines['top'].set_visible(False)
 
65
 
66
  return fig
67
 
68
+
69
+ # Store highlight point based on selection
70
+ highlight_point = None
71
+ if vertical_selected:
72
+ highlight_point = (32, 902)
73
+ elif horizontal_selected:
74
+ highlight_point = (97, 370) # Low credit score, high income
75
+ elif slant_selected:
76
+ highlight_point = (32, 902) # Example point for slant model
77
+
78
+ # Top scatter plot (centered to match smaller width)
79
  col1, col2, col3 = st.columns([1.5, 1, 1.5])
80
  with col2:
81
  st.pyplot(plot_scatter(income, credit, colors, title="Original Data"))
82
 
83
  col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
84
  with col2:
85
+ st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical", highlight_point=highlight_point))
86
  vertical_selected = st.button("Choose Model 1")
87
  with col3:
88
+ st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant", highlight_point=highlight_point))
89
  slant_selected = st.button("Choose Model 2")
90
  with col4:
91
+ st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal", highlight_point=highlight_point))
92
  horizontal_selected = st.button("Choose Model 3")
93
 
94
+ st.markdown("---")
95
+ col1, col2, col3, col4 = st.columns([2, 1, 1, 1])
96
+ with col3:
97
+ if st.button("Go Home"):
98
+ go_to("home")
99
+ with col4:
100
+ if st.button("Next"):
101
+ go_to("developer_decisions")