# Author: selva1909
# Commit aaf1da0 (verified): "Update src/streamlit_app.py"
import streamlit as st
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
st.set_page_config(page_title="Linear Regression Playground", layout="centered")
# Styling for the `.eq-box` container that shows the fitted equation:
# white text on a dark, rounded panel.
# Streamlit typesets `$...$` math with KaTeX, so `.eq-box .katex` is the
# selector that actually matches the rendered formula; the MathJax selectors
# from the earlier version are kept because they are harmless.
st.markdown("""
<style>
.eq-box {
border: 2px solid #444;
border-radius: 8px;
background: #222; /* DARK background */
padding: 14px;
width: 100%;
font-size: 22px;
color: white !important; /* WHITE text */
text-align: center;
margin-top: 14px;
}
.eq-box .katex, .mathjax-chtml, .MathJax {
color: white !important; /* Force formula text white */
}
</style>
""", unsafe_allow_html=True)
# Page header and one-line intro. These emoji and em-dash literals must be
# real Unicode characters, not their mis-decoded UTF-8 byte sequences.
st.title("📉 Linear Regression Playground (2D & 3D)")
st.write("Experiment with regression, noise, slope, intercept — and visualize the model!")
# ------------------------------------
# Sidebar Controls
# ------------------------------------
st.sidebar.header("⚙️ Controls")
mode = st.sidebar.radio("Choose Mode", ["2D Regression", "3D Regression"])
num_points = st.sidebar.slider("Number of Data Points", 20, 500, 100)
noise_level = st.sidebar.slider("Noise Level", 0.0, 5.0, 1.0)
# The rotation toggle is only shown in 3D mode; in 2D mode it stays False.
rotate_3d = False
if mode == "3D Regression":
    rotate_3d = st.sidebar.toggle("🔄 Rotate 3D Model", value=False)
train_btn = st.sidebar.button("Generate & Train Model")
# Session-state bookkeeping:
#   trained      -> whether `st.session_state.data` holds a model to display
#   current_mode -> the mode selected on the previous run
# The stored data tuple has a different layout per mode, so switching modes
# forces the user to retrain.
if "trained" not in st.session_state:
    st.session_state.update(trained=False, current_mode=None)
mode_switched = st.session_state.current_mode != mode
if mode_switched:
    st.session_state.update(trained=False, current_mode=mode)
# ------------------------------------
# Generate dataset
# ------------------------------------
def _train_2d(n, noise_scale):
    """Fit a line to ``n`` noisy samples of ``y = 2.5x + 5`` on ``[0, 10]``.

    The noise is standard-normal scaled by ``noise_scale``.
    Returns ``(X, y, y_pred, mse, model)``, where ``X`` has shape ``(n, 1)``.
    """
    features = np.linspace(0, 10, n).reshape(-1, 1)
    target = 2.5 * features.flatten() + 5 + np.random.randn(n) * noise_scale
    reg = LinearRegression().fit(features, target)
    fitted = reg.predict(features)
    return features, target, fitted, mean_squared_error(target, fitted), reg


def _train_3d(n, noise_scale):
    """Fit a plane to noisy samples of ``z = 3x1 + 2x2 + 10`` on an n-by-n grid.

    Both grid axes span ``[0, 10]``. The noise is standard-normal scaled by
    ``noise_scale``.
    Returns ``(X1, X2, Z, Z_pred, mse, model)``; the four grids are ``(n, n)``.
    """
    axis = np.linspace(0, 10, n)
    grid1, grid2 = np.meshgrid(axis, axis)
    surface = 3 * grid1 + 2 * grid2 + 10 + np.random.randn(n, n) * noise_scale
    features = np.column_stack((grid1.ravel(), grid2.ravel()))
    targets = surface.ravel()
    reg = LinearRegression().fit(features, targets)
    fitted = reg.predict(features).reshape(n, n)
    error = mean_squared_error(targets, fitted.ravel())
    return grid1, grid2, surface, fitted, error, reg


if train_btn:
    with st.spinner("⏳ Training model..."):
        # Artificial pause, presumably so the spinner is noticeable.
        time.sleep(0.5)
        trainer = _train_2d if mode == "2D Regression" else _train_3d
        st.session_state.data = trainer(num_points, noise_level)
        st.session_state.trained = True
# ------------------------------------
# Visualization
# ------------------------------------
def _signed_term(value, suffix=""):
    """Format ``value`` as an explicitly signed term, e.g. ``- 1.250x_2``.

    Prevents equations such as ``y = 2.500x + -1.200`` when a fitted
    coefficient or intercept is negative.
    """
    sign = "-" if value < 0 else "+"
    return f"{sign} {abs(value):.3f}{suffix}"


def _show_equation(latex):
    """Render ``latex`` as inline math inside the styled ``eq-box`` panel."""
    # In CommonMark a raw-HTML block runs until the next blank line and its
    # contents are passed through verbatim, so `$...$` written directly inside
    # the <div> is never typeset. The blank lines let the formula be parsed
    # as Markdown math.
    # NOTE(review): confirm that the box still wraps the formula in the
    # deployed Streamlit version.
    st.markdown(
        f"<div class='eq-box'>\n\n${latex}$\n\n</div>",
        unsafe_allow_html=True,
    )


def _draw_3d(X1, X2, Z, Z_pred, idx, alpha, size, title):
    """Build a 3D figure with the ``idx`` subsample of the data plus the fitted plane.

    Returns ``(fig, ax)``. The caller must close ``fig``.
    """
    fig = plt.figure(figsize=(4.5, 4))
    ax = fig.add_subplot(111, projection="3d")
    ax.scatter(X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
               color="orange", alpha=alpha, s=size)
    ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
    ax.set_title(title)
    return fig, ax


if st.session_state.trained:
    st.success("🎉 Model trained successfully!")
    # ----------------- 2D Regression -----------------
    if mode == "2D Regression":
        X, y, y_pred, mse, model = st.session_state.data
        col1, col2 = st.columns([2, 1])
        with col1:
            fig, ax = plt.subplots(figsize=(4.5, 4))
            ax.scatter(X, y, color="orange", label="Data", s=18)
            ax.plot(X, y_pred, color="blue", linewidth=2, label="Regression Line")
            ax.set_title("2D Linear Regression")
            ax.legend()
            st.pyplot(fig)
            # pyplot keeps every figure it creates alive until closed, so
            # without this each rerun would leak one figure.
            plt.close(fig)
        with col2:
            st.metric("MSE", f"{mse:.4f}")
            _show_equation(
                rf"y = {model.coef_[0]:.3f}x {_signed_term(model.intercept_)}"
            )
    # ----------------- 3D Regression -----------------
    else:
        X1, X2, Z, Z_pred, mse, model = st.session_state.data
        col1, col2 = st.columns([2, 1])
        # Fill the side column first. The rotation loop below blocks the
        # script, so anything rendered after it would only appear once the
        # animation has finished.
        with col2:
            st.metric("MSE", f"{mse:.4f}")
            a, b = model.coef_[0], model.coef_[1]
            c = model.intercept_
            _show_equation(
                rf"z = {a:.3f}x_1 {_signed_term(b, 'x_2')} {_signed_term(c)}"
            )
        with col1:
            n_total = Z.size
            # A seeded local generator keeps the plotted subsample stable
            # across reruns and animation frames, and leaves the global NumPy
            # RNG used for data generation untouched.
            rng = np.random.default_rng(0)
            if not rotate_3d:
                idx = rng.choice(n_total, min(350, n_total), replace=False)
                fig, _ = _draw_3d(X1, X2, Z, Z_pred, idx, alpha=0.25, size=8,
                                  title="3D Linear Regression")
                st.pyplot(fig)
                plt.close(fig)
            else:
                idx = rng.choice(n_total, min(300, n_total), replace=False)
                # Build the figure once and move only the camera per frame.
                # Re-creating the surface every frame is slow, and each new
                # figure was never closed.
                fig, ax = _draw_3d(X1, X2, Z, Z_pred, idx, alpha=0.2, size=6,
                                   title="🔄 Rotating 3D Regression Model")
                placeholder = st.empty()
                for angle in range(0, 360, 5):
                    ax.view_init(elev=25, azim=angle)
                    placeholder.pyplot(fig, clear_figure=False)
                    time.sleep(0.07)
                plt.close(fig)
else:
    st.info("Click **Generate & Train Model** to begin.")