prakharg24 commited on
Commit
e95241f
·
verified ·
1 Parent(s): c62c72d

Update my_pages/rashomon_effect.py

Browse files
Files changed (1) hide show
  1. my_pages/rashomon_effect.py +14 -8
my_pages/rashomon_effect.py CHANGED
@@ -3,12 +3,14 @@ import matplotlib.pyplot as plt
3
  import numpy as np
4
  from utils import go_to
5
 
 
 
6
  def render():
7
  st.markdown(
8
  """
9
  <div style='text-align: center; font-size:18px; color:gray;'>
10
- Consider data about individuals who paid back their loans (green) and those who defaulted (red). <br>
11
- Which model out of the two will you choose as your final model to give loan applications? <br><br>
12
  </div>
13
  """,
14
  unsafe_allow_html=True
@@ -17,17 +19,21 @@ def render():
17
  st.markdown("---")
18
 
19
  # Generate synthetic data
20
- np.random.seed(42)
21
- n_points = 100
22
- income = np.random.normal(50, 15, n_points)
23
- credit = np.random.normal(50, 15, n_points)
24
- labels = (income + credit > 100).astype(int) # 1 = paid back, 0 = default
 
 
 
 
25
 
26
  colors = ['green' if label == 1 else 'red' for label in labels]
27
 
28
  # Function to plot scatter
29
  def plot_scatter(x, y, colors, title="", decision_boundary=None, boundary_type=None, highlight_point=None):
30
- fig, ax = plt.subplots(figsize=(3, 3))
31
  ax.scatter(x, y, c=colors, alpha=0.6)
32
  ax.set_xlabel("Annual Income")
33
  ax.set_ylabel("Credit Score")
 
3
  import numpy as np
4
  from utils import go_to
5
 
6
+ plt.style.use('dark_background')
7
+
8
  def render():
9
  st.markdown(
10
  """
11
  <div style='text-align: center; font-size:18px; color:gray;'>
12
+ Consider data about individuals who either paid their loans (green) or defaulted (red). <br>
13
+ Which model out of the two will you choose to give loan applications? <br><br>
14
  </div>
15
  """,
16
  unsafe_allow_html=True
 
19
  st.markdown("---")
20
 
21
  # Generate synthetic data
22
+ # np.random.seed(42)
23
+ # n_points = 100
24
+ # income = np.random.normal(50, 15, n_points)
25
+ # credit = np.random.normal(50, 15, n_points)
26
+ # labels = (income + credit > 100).astype(int) # 1 = paid back, 0 = default
27
+
28
+ income = [80, 85, 97, 91, 78, 102, 84, 88, 45, 51, 34, 47, 38, 39, 97, 91, 38, 32]
29
+ credit = [800, 805, 970, 910, 708, 1020, 804, 880, 450, 501, 304, 470, 380, 309, 370, 301, 1080, 902]
30
+ labels = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1]
31
 
32
  colors = ['green' if label == 1 else 'red' for label in labels]
33
 
34
  # Function to plot scatter
35
  def plot_scatter(x, y, colors, title="", decision_boundary=None, boundary_type=None, highlight_point=None):
36
+ fig, ax = plt.subplots(figsize=(2, 2))
37
  ax.scatter(x, y, c=colors, alpha=0.6)
38
  ax.set_xlabel("Annual Income")
39
  ax.set_ylabel("Credit Score")