selva1909 commited on
Commit
aaf1da0
·
verified ·
1 Parent(s): b469be3

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +9 -16
src/streamlit_app.py CHANGED
@@ -7,19 +7,23 @@ from sklearn.metrics import mean_squared_error
7
 
8
  st.set_page_config(page_title="Linear Regression Playground", layout="centered")
9
 
10
- # FIX: make equation always fully visible + box formatting
11
  st.markdown("""
12
  <style>
13
  .eq-box {
14
- border: 2px solid #333;
15
  border-radius: 8px;
16
- background: #ffffff;
17
  padding: 14px;
18
  width: 100%;
19
  font-size: 22px;
 
20
  text-align: center;
21
  margin-top: 14px;
22
  }
 
 
 
23
  </style>
24
  """, unsafe_allow_html=True)
25
 
@@ -49,7 +53,6 @@ if mode != st.session_state.current_mode:
49
  st.session_state.trained = False
50
  st.session_state.current_mode = mode
51
 
52
-
53
  # ------------------------------------
54
  # Generate dataset
55
  # ------------------------------------
@@ -87,7 +90,6 @@ if train_btn:
87
 
88
  st.session_state.trained = True
89
 
90
-
91
  # ------------------------------------
92
  # Visualization
93
  # ------------------------------------
@@ -112,11 +114,9 @@ if st.session_state.trained:
112
  with col2:
113
  st.metric("MSE", f"{mse:.4f}")
114
 
115
- # FIXED FULL EQUATION BOX
116
  equation = rf"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}"
117
  st.markdown(f"<div class='eq-box'>${equation}$</div>", unsafe_allow_html=True)
118
 
119
-
120
  # ----------------- 3D Regression -----------------
121
  else:
122
  X1, X2, Z, Z_pred, mse, model = st.session_state.data
@@ -124,24 +124,19 @@ if st.session_state.trained:
124
  col1, col2 = st.columns([2, 1])
125
 
126
  with col1:
127
-
128
  if not rotate_3d:
129
  fig = plt.figure(figsize=(4.5, 4))
130
  ax = fig.add_subplot(111, projection="3d")
131
 
132
  idx = np.random.choice(len(Z.ravel()), min(350, len(Z.ravel())), replace=False)
133
- ax.scatter(
134
- X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
135
- color="orange", alpha=0.25, s=8
136
- )
137
 
138
  ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
139
  ax.set_title("3D Linear Regression")
140
  st.pyplot(fig, clear_figure=True)
141
-
142
  else:
143
  placeholder = st.empty()
144
-
145
  for angle in range(0, 360, 5):
146
  fig = plt.figure(figsize=(4.5, 4))
147
  ax = fig.add_subplot(111, projection="3d")
@@ -165,8 +160,6 @@ if st.session_state.trained:
165
  c = model.intercept_
166
 
167
  equation3d = rf"z = {a:.3f}x_1 + {b:.3f}x_2 + {c:.3f}"
168
-
169
- # FIXED FULL EQUATION BOX
170
  st.markdown(f"<div class='eq-box'>${equation3d}$</div>", unsafe_allow_html=True)
171
 
172
  else:
 
7
 
8
  st.set_page_config(page_title="Linear Regression Playground", layout="centered")
9
 
10
+ # === FIX: fully visible equation with dark box ===
11
  st.markdown("""
12
  <style>
13
  .eq-box {
14
+ border: 2px solid #444;
15
  border-radius: 8px;
16
+ background: #222; /* DARK background */
17
  padding: 14px;
18
  width: 100%;
19
  font-size: 22px;
20
+ color: white !important; /* WHITE text */
21
  text-align: center;
22
  margin-top: 14px;
23
  }
24
+ .mathjax-chtml, .MathJax {
25
+ color: white !important; /* Force formula text white */
26
+ }
27
  </style>
28
  """, unsafe_allow_html=True)
29
 
 
53
  st.session_state.trained = False
54
  st.session_state.current_mode = mode
55
 
 
56
  # ------------------------------------
57
  # Generate dataset
58
  # ------------------------------------
 
90
 
91
  st.session_state.trained = True
92
 
 
93
  # ------------------------------------
94
  # Visualization
95
  # ------------------------------------
 
114
  with col2:
115
  st.metric("MSE", f"{mse:.4f}")
116
 
 
117
  equation = rf"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}"
118
  st.markdown(f"<div class='eq-box'>${equation}$</div>", unsafe_allow_html=True)
119
 
 
120
  # ----------------- 3D Regression -----------------
121
  else:
122
  X1, X2, Z, Z_pred, mse, model = st.session_state.data
 
124
  col1, col2 = st.columns([2, 1])
125
 
126
  with col1:
 
127
  if not rotate_3d:
128
  fig = plt.figure(figsize=(4.5, 4))
129
  ax = fig.add_subplot(111, projection="3d")
130
 
131
  idx = np.random.choice(len(Z.ravel()), min(350, len(Z.ravel())), replace=False)
132
+ ax.scatter(X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
133
+ color="orange", alpha=0.25, s=8)
 
 
134
 
135
  ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
136
  ax.set_title("3D Linear Regression")
137
  st.pyplot(fig, clear_figure=True)
 
138
  else:
139
  placeholder = st.empty()
 
140
  for angle in range(0, 360, 5):
141
  fig = plt.figure(figsize=(4.5, 4))
142
  ax = fig.add_subplot(111, projection="3d")
 
160
  c = model.intercept_
161
 
162
  equation3d = rf"z = {a:.3f}x_1 + {b:.3f}x_2 + {c:.3f}"
 
 
163
  st.markdown(f"<div class='eq-box'>${equation3d}$</div>", unsafe_allow_html=True)
164
 
165
  else: