prakharg24 commited on
Commit
6fbbcf2
·
verified ·
1 Parent(s): 979710a

Update my_pages/rashomon_effect.py

Browse files
Files changed (1) hide show
  1. my_pages/rashomon_effect.py +13 -12
my_pages/rashomon_effect.py CHANGED
@@ -66,16 +66,11 @@ def render():
66
  return fig
67
 
68
 
69
- # Store highlight point based on selection
70
- highlight_point = None
71
- if vertical_selected:
72
- highlight_point = (32, 902)
73
- elif horizontal_selected:
74
- highlight_point = (97, 370) # Low credit score, high income
75
- elif slant_selected:
76
- highlight_point = (32, 902) # Example point for slant model
77
 
78
- # Top scatter plot (centered to match smaller width)
79
  col1, col2, col3 = st.columns([1.5, 1, 1.5])
80
  with col2:
81
  st.pyplot(plot_scatter(income, credit, colors, title="Original Data"))
@@ -83,13 +78,19 @@ def render():
83
  col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
84
  with col2:
85
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical", highlight_point=highlight_point))
86
- vertical_selected = st.button("Choose Model 1")
 
 
87
  with col3:
88
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant", highlight_point=highlight_point))
89
- slant_selected = st.button("Choose Model 2")
 
 
90
  with col4:
91
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal", highlight_point=highlight_point))
92
- horizontal_selected = st.button("Choose Model 3")
 
 
93
 
94
  st.markdown("---")
95
  col1, col2, col3, col4 = st.columns([2, 1, 1, 1])
 
66
  return fig
67
 
68
 
69
+ higlight_point = None
70
+ if "highlight_point" in st.session_state:
71
+ highlight_point = st.session_state.highlight_point
 
 
 
 
 
72
 
73
+ # Top scatter plot (centered to match smaller width)
74
  col1, col2, col3 = st.columns([1.5, 1, 1.5])
75
  with col2:
76
  st.pyplot(plot_scatter(income, credit, colors, title="Original Data"))
 
78
  col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
79
  with col2:
80
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical", highlight_point=highlight_point))
81
+ if st.button("Choose Model 1"):
82
+ st.session_state.highlight_point = (32, 902)
83
+ st.rerun()
84
  with col3:
85
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant", highlight_point=highlight_point))
86
+ if st.button("Choose Model 2"):
87
+ st.session_state.highlight_point = (97, 370)
88
+ st.rerun()
89
  with col4:
90
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal", highlight_point=highlight_point))
91
+ if st.button("Choose Model 3"):
92
+ st.session_state.highlight_point = (32, 902)
93
+ st.rerun()
94
 
95
  st.markdown("---")
96
  col1, col2, col3, col4 = st.columns([2, 1, 1, 1])