prakharg24 commited on
Commit
6726747
·
verified ·
1 Parent(s): b0b4d09

Update my_pages/rashomon_effect.py

Browse files
Files changed (1) hide show
  1. my_pages/rashomon_effect.py +34 -7
my_pages/rashomon_effect.py CHANGED
@@ -6,6 +6,17 @@ from utils import go_to
6
  plt.style.use('dark_background')
7
 
8
  def render():
 
 
 
 
 
 
 
 
 
 
 
9
  st.markdown(
10
  """
11
  <div style='text-align: center; font-size:18px; color:gray;'>
@@ -66,8 +77,9 @@ def render():
66
  return fig
67
 
68
 
69
- highlight_point = None
70
- if "highlight_point" in st.session_state:
 
71
  highlight_point = st.session_state.highlight_point
72
 
73
  # Top scatter plot (centered to match smaller width)
@@ -78,18 +90,33 @@ def render():
78
  col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
79
  with col2:
80
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical", highlight_point=highlight_point))
81
- if st.button("Choose Model 1"):
 
 
 
 
82
  st.session_state.highlight_point = (32, 902)
 
83
  st.rerun()
84
  with col3:
85
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant", highlight_point=highlight_point))
86
- if st.button("Choose Model 2"):
87
- st.session_state.highlight_point = (97, 370)
 
 
 
 
 
88
  st.rerun()
89
  with col4:
90
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal", highlight_point=highlight_point))
91
- if st.button("Choose Model 3"):
92
- st.session_state.highlight_point = (32, 902)
 
 
 
 
 
93
  st.rerun()
94
 
95
  st.markdown("---")
 
6
  plt.style.use('dark_background')
7
 
8
  def render():
9
+ st.markdown(
10
+ """
11
+ <style>
12
+ button[kind="primary"] {
13
+ background: green!important;
14
+ }
15
+ </style>
16
+ """,
17
+ unsafe_allow_html=True,
18
+ )
19
+
20
  st.markdown(
21
  """
22
  <div style='text-align: center; font-size:18px; color:gray;'>
 
77
  return fig
78
 
79
 
80
+ graph_selected, highlight_point = None, None
81
+ if "graph_selected" in st.session_state:
82
+ graph_selected = st.session_state.graph_selected
83
  highlight_point = st.session_state.highlight_point
84
 
85
  # Top scatter plot (centered to match smaller width)
 
90
  col1, col2, col3, col4, col5 = st.columns([0.5, 1, 1, 1, 0.5])
91
  with col2:
92
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="vertical", highlight_point=highlight_point))
93
+ if graph_selected=="vertical":
94
+ button_click_v = st.button("Choose Model 1", type="primary")
95
+ else:
96
+ button_click_v = st.button("Choose Model 1")
97
+ if button_click_v:
98
  st.session_state.highlight_point = (32, 902)
99
+ st.session_state.graph_selected = "vertical"
100
  st.rerun()
101
  with col3:
102
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="slant", highlight_point=highlight_point))
103
+ if graph_selected=="slant":
104
+ button_click_s = st.button("Choose Model 2", type="primary")
105
+ else:
106
+ button_click_s = st.button("Choose Model 2")
107
+ if button_click_s:
108
+ st.session_state.highlight_point = (32, 902)
109
+ st.session_state.graph_selected = "slant"
110
  st.rerun()
111
  with col4:
112
  st.pyplot(plot_scatter(income, credit, colors, boundary_type="horizontal", highlight_point=highlight_point))
113
+ if graph_selected=="horizontal":
114
+ button_click_h = st.button("Choose Model 3", type="primary")
115
+ else:
116
+ button_click_h = st.button("Choose Model 3")
117
+ if button_click_h:
118
+ st.session_state.highlight_point = (97, 370)
119
+ st.session_state.graph_selected = "horizontal"
120
  st.rerun()
121
 
122
  st.markdown("---")