Update src/streamlit_app.py
Browse files- src/streamlit_app.py +26 -25
src/streamlit_app.py
CHANGED
|
@@ -55,21 +55,18 @@ if train_btn:
|
|
| 55 |
st.session_state.data = (X, y, y_pred, mse, model)
|
| 56 |
|
| 57 |
else:
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
x1 = np.linspace(0, 10, grid_n)
|
| 62 |
-
x2 = np.linspace(0, 10, grid_n)
|
| 63 |
X1, X2 = np.meshgrid(x1, x2)
|
| 64 |
|
| 65 |
-
noise = np.random.randn(
|
| 66 |
Z = 3 * X1 + 2 * X2 + 10 + noise
|
| 67 |
|
| 68 |
X_flat = np.column_stack((X1.ravel(), X2.ravel()))
|
| 69 |
Z_flat = Z.ravel()
|
| 70 |
|
| 71 |
model = LinearRegression().fit(X_flat, Z_flat)
|
| 72 |
-
Z_pred = model.predict(X_flat).reshape(
|
| 73 |
mse = mean_squared_error(Z_flat, Z_pred.ravel())
|
| 74 |
|
| 75 |
st.session_state.data = (X1, X2, Z, Z_pred, mse, model)
|
|
@@ -84,14 +81,14 @@ if st.session_state.trained:
|
|
| 84 |
|
| 85 |
st.success("π Model trained successfully!")
|
| 86 |
|
| 87 |
-
# ----------------- 2D -----------------
|
| 88 |
if mode == "2D Regression":
|
| 89 |
X, y, y_pred, mse, model = st.session_state.data
|
| 90 |
|
| 91 |
col1, col2 = st.columns([2, 1])
|
| 92 |
|
| 93 |
with col1:
|
| 94 |
-
fig, ax = plt.subplots(figsize=(4,
|
| 95 |
ax.scatter(X, y, color="orange", label="Data", s=18)
|
| 96 |
ax.plot(X, y_pred, color="blue", linewidth=2, label="Regression Line")
|
| 97 |
ax.set_title("2D Linear Regression")
|
|
@@ -102,7 +99,7 @@ if st.session_state.trained:
|
|
| 102 |
st.metric("MSE", f"{mse:.4f}")
|
| 103 |
st.code(f"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}")
|
| 104 |
|
| 105 |
-
# ----------------- 3D -----------------
|
| 106 |
else:
|
| 107 |
X1, X2, Z, Z_pred, mse, model = st.session_state.data
|
| 108 |
|
|
@@ -112,36 +109,40 @@ if st.session_state.trained:
|
|
| 112 |
|
| 113 |
# Static 3D plot
|
| 114 |
if not rotate_3d:
|
| 115 |
-
fig = plt.figure(figsize=(4,
|
| 116 |
ax = fig.add_subplot(111, projection="3d")
|
| 117 |
|
| 118 |
-
idx = np.random.choice(len(Z.ravel()), min(
|
| 119 |
-
ax.scatter(
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
|
| 122 |
-
ax.plot_surface(X1, X2, Z_pred, alpha=0.
|
| 123 |
ax.set_title("3D Linear Regression")
|
| 124 |
st.pyplot(fig, clear_figure=True)
|
| 125 |
|
| 126 |
-
#
|
| 127 |
else:
|
| 128 |
placeholder = st.empty()
|
| 129 |
|
| 130 |
-
for angle in range(0, 360,
|
| 131 |
-
fig = plt.figure(figsize=(4,
|
| 132 |
ax = fig.add_subplot(111, projection="3d")
|
| 133 |
|
| 134 |
-
idx = np.random.choice(len(Z.ravel()), min(
|
| 135 |
-
ax.scatter(
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
|
|
|
|
| 140 |
ax.view_init(elev=25, azim=angle)
|
| 141 |
-
|
|
|
|
| 142 |
|
| 143 |
placeholder.pyplot(fig, clear_figure=True)
|
| 144 |
-
time.sleep(0.
|
| 145 |
|
| 146 |
with col2:
|
| 147 |
st.metric("MSE", f"{mse:.4f}")
|
|
|
|
| 55 |
st.session_state.data = (X, y, y_pred, mse, model)
|
| 56 |
|
| 57 |
else:
|
| 58 |
+
x1 = np.linspace(0, 10, num_points)
|
| 59 |
+
x2 = np.linspace(0, 10, num_points)
|
|
|
|
|
|
|
|
|
|
| 60 |
X1, X2 = np.meshgrid(x1, x2)
|
| 61 |
|
| 62 |
+
noise = np.random.randn(num_points, num_points) * noise_level
|
| 63 |
Z = 3 * X1 + 2 * X2 + 10 + noise
|
| 64 |
|
| 65 |
X_flat = np.column_stack((X1.ravel(), X2.ravel()))
|
| 66 |
Z_flat = Z.ravel()
|
| 67 |
|
| 68 |
model = LinearRegression().fit(X_flat, Z_flat)
|
| 69 |
+
Z_pred = model.predict(X_flat).reshape(num_points, num_points)
|
| 70 |
mse = mean_squared_error(Z_flat, Z_pred.ravel())
|
| 71 |
|
| 72 |
st.session_state.data = (X1, X2, Z, Z_pred, mse, model)
|
|
|
|
| 81 |
|
| 82 |
st.success("π Model trained successfully!")
|
| 83 |
|
| 84 |
+
# ----------------- 2D Regression -----------------
|
| 85 |
if mode == "2D Regression":
|
| 86 |
X, y, y_pred, mse, model = st.session_state.data
|
| 87 |
|
| 88 |
col1, col2 = st.columns([2, 1])
|
| 89 |
|
| 90 |
with col1:
|
| 91 |
+
fig, ax = plt.subplots(figsize=(4.5, 4))
|
| 92 |
ax.scatter(X, y, color="orange", label="Data", s=18)
|
| 93 |
ax.plot(X, y_pred, color="blue", linewidth=2, label="Regression Line")
|
| 94 |
ax.set_title("2D Linear Regression")
|
|
|
|
| 99 |
st.metric("MSE", f"{mse:.4f}")
|
| 100 |
st.code(f"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}")
|
| 101 |
|
| 102 |
+
# ----------------- 3D Regression -----------------
|
| 103 |
else:
|
| 104 |
X1, X2, Z, Z_pred, mse, model = st.session_state.data
|
| 105 |
|
|
|
|
| 109 |
|
| 110 |
# Static 3D plot
|
| 111 |
if not rotate_3d:
|
| 112 |
+
fig = plt.figure(figsize=(4.5, 4))
|
| 113 |
ax = fig.add_subplot(111, projection="3d")
|
| 114 |
|
| 115 |
+
idx = np.random.choice(len(Z.ravel()), min(350, len(Z.ravel())), replace=False)
|
| 116 |
+
ax.scatter(
|
| 117 |
+
X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
|
| 118 |
+
color="orange", alpha=0.25, s=8
|
| 119 |
+
)
|
| 120 |
|
| 121 |
+
ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
|
| 122 |
ax.set_title("3D Linear Regression")
|
| 123 |
st.pyplot(fig, clear_figure=True)
|
| 124 |
|
| 125 |
+
# Rotating 3D animation (HuggingFace-friendly)
|
| 126 |
else:
|
| 127 |
placeholder = st.empty()
|
| 128 |
|
| 129 |
+
for angle in range(0, 360, 5):
|
| 130 |
+
fig = plt.figure(figsize=(4.5, 4))
|
| 131 |
ax = fig.add_subplot(111, projection="3d")
|
| 132 |
|
| 133 |
+
idx = np.random.choice(len(Z.ravel()), min(300, len(Z.ravel())), replace=False)
|
| 134 |
+
ax.scatter(
|
| 135 |
+
X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
|
| 136 |
+
alpha=0.2, color="orange", s=6
|
| 137 |
+
)
|
| 138 |
|
| 139 |
+
ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
|
| 140 |
ax.view_init(elev=25, azim=angle)
|
| 141 |
+
|
| 142 |
+
ax.set_title("π Rotating 3D Regression Model")
|
| 143 |
|
| 144 |
placeholder.pyplot(fig, clear_figure=True)
|
| 145 |
+
time.sleep(0.07)
|
| 146 |
|
| 147 |
with col2:
|
| 148 |
st.metric("MSE", f"{mse:.4f}")
|