selva1909 commited on
Commit
057dd64
Β·
verified Β·
1 Parent(s): c4b59e4

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -25
src/streamlit_app.py CHANGED
@@ -55,21 +55,18 @@ if train_btn:
55
  st.session_state.data = (X, y, y_pred, mse, model)
56
 
57
  else:
58
- # Lower resolution grid β†’ MUCH faster rendering
59
- grid_n = int(num_points / 3)
60
-
61
- x1 = np.linspace(0, 10, grid_n)
62
- x2 = np.linspace(0, 10, grid_n)
63
  X1, X2 = np.meshgrid(x1, x2)
64
 
65
- noise = np.random.randn(grid_n, grid_n) * noise_level
66
  Z = 3 * X1 + 2 * X2 + 10 + noise
67
 
68
  X_flat = np.column_stack((X1.ravel(), X2.ravel()))
69
  Z_flat = Z.ravel()
70
 
71
  model = LinearRegression().fit(X_flat, Z_flat)
72
- Z_pred = model.predict(X_flat).reshape(grid_n, grid_n)
73
  mse = mean_squared_error(Z_flat, Z_pred.ravel())
74
 
75
  st.session_state.data = (X1, X2, Z, Z_pred, mse, model)
@@ -84,14 +81,14 @@ if st.session_state.trained:
84
 
85
  st.success("πŸŽ‰ Model trained successfully!")
86
 
87
- # ----------------- 2D -----------------
88
  if mode == "2D Regression":
89
  X, y, y_pred, mse, model = st.session_state.data
90
 
91
  col1, col2 = st.columns([2, 1])
92
 
93
  with col1:
94
- fig, ax = plt.subplots(figsize=(4, 3.5))
95
  ax.scatter(X, y, color="orange", label="Data", s=18)
96
  ax.plot(X, y_pred, color="blue", linewidth=2, label="Regression Line")
97
  ax.set_title("2D Linear Regression")
@@ -102,7 +99,7 @@ if st.session_state.trained:
102
  st.metric("MSE", f"{mse:.4f}")
103
  st.code(f"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}")
104
 
105
- # ----------------- 3D -----------------
106
  else:
107
  X1, X2, Z, Z_pred, mse, model = st.session_state.data
108
 
@@ -112,36 +109,40 @@ if st.session_state.trained:
112
 
113
  # Static 3D plot
114
  if not rotate_3d:
115
- fig = plt.figure(figsize=(4, 3.5))
116
  ax = fig.add_subplot(111, projection="3d")
117
 
118
- idx = np.random.choice(len(Z.ravel()), min(200, len(Z.ravel())), replace=False)
119
- ax.scatter(X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
120
- color="orange", alpha=0.22, s=6)
 
 
121
 
122
- ax.plot_surface(X1, X2, Z_pred, alpha=0.7, color="blue")
123
  ax.set_title("3D Linear Regression")
124
  st.pyplot(fig, clear_figure=True)
125
 
126
- # Smooth rotation animation (PC + HF optimized)
127
  else:
128
  placeholder = st.empty()
129
 
130
- for angle in range(0, 360, 6): # fewer steps = smoother on PC
131
- fig = plt.figure(figsize=(4, 3.5))
132
  ax = fig.add_subplot(111, projection="3d")
133
 
134
- idx = np.random.choice(len(Z.ravel()), min(150, len(Z.ravel())), replace=False)
135
- ax.scatter(X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
136
- alpha=0.2, color="orange", s=5)
137
-
138
- ax.plot_surface(X1, X2, Z_pred, alpha=0.70, color="blue")
139
 
 
140
  ax.view_init(elev=25, azim=angle)
141
- ax.set_title("Rotating 3D Regression Model")
 
142
 
143
  placeholder.pyplot(fig, clear_figure=True)
144
- time.sleep(0.05) # PC-friendly speed
145
 
146
  with col2:
147
  st.metric("MSE", f"{mse:.4f}")
 
55
  st.session_state.data = (X, y, y_pred, mse, model)
56
 
57
  else:
58
+ x1 = np.linspace(0, 10, num_points)
59
+ x2 = np.linspace(0, 10, num_points)
 
 
 
60
  X1, X2 = np.meshgrid(x1, x2)
61
 
62
+ noise = np.random.randn(num_points, num_points) * noise_level
63
  Z = 3 * X1 + 2 * X2 + 10 + noise
64
 
65
  X_flat = np.column_stack((X1.ravel(), X2.ravel()))
66
  Z_flat = Z.ravel()
67
 
68
  model = LinearRegression().fit(X_flat, Z_flat)
69
+ Z_pred = model.predict(X_flat).reshape(num_points, num_points)
70
  mse = mean_squared_error(Z_flat, Z_pred.ravel())
71
 
72
  st.session_state.data = (X1, X2, Z, Z_pred, mse, model)
 
81
 
82
  st.success("πŸŽ‰ Model trained successfully!")
83
 
84
+ # ----------------- 2D Regression -----------------
85
  if mode == "2D Regression":
86
  X, y, y_pred, mse, model = st.session_state.data
87
 
88
  col1, col2 = st.columns([2, 1])
89
 
90
  with col1:
91
+ fig, ax = plt.subplots(figsize=(4.5, 4))
92
  ax.scatter(X, y, color="orange", label="Data", s=18)
93
  ax.plot(X, y_pred, color="blue", linewidth=2, label="Regression Line")
94
  ax.set_title("2D Linear Regression")
 
99
  st.metric("MSE", f"{mse:.4f}")
100
  st.code(f"y = {model.coef_[0]:.3f}x + {model.intercept_:.3f}")
101
 
102
+ # ----------------- 3D Regression -----------------
103
  else:
104
  X1, X2, Z, Z_pred, mse, model = st.session_state.data
105
 
 
109
 
110
  # Static 3D plot
111
  if not rotate_3d:
112
+ fig = plt.figure(figsize=(4.5, 4))
113
  ax = fig.add_subplot(111, projection="3d")
114
 
115
+ idx = np.random.choice(len(Z.ravel()), min(350, len(Z.ravel())), replace=False)
116
+ ax.scatter(
117
+ X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
118
+ color="orange", alpha=0.25, s=8
119
+ )
120
 
121
+ ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
122
  ax.set_title("3D Linear Regression")
123
  st.pyplot(fig, clear_figure=True)
124
 
125
+ # Rotating 3D animation (HuggingFace-friendly)
126
  else:
127
  placeholder = st.empty()
128
 
129
+ for angle in range(0, 360, 5):
130
+ fig = plt.figure(figsize=(4.5, 4))
131
  ax = fig.add_subplot(111, projection="3d")
132
 
133
+ idx = np.random.choice(len(Z.ravel()), min(300, len(Z.ravel())), replace=False)
134
+ ax.scatter(
135
+ X1.ravel()[idx], X2.ravel()[idx], Z.ravel()[idx],
136
+ alpha=0.2, color="orange", s=6
137
+ )
138
 
139
+ ax.plot_surface(X1, X2, Z_pred, alpha=0.75, color="blue")
140
  ax.view_init(elev=25, azim=angle)
141
+
142
+ ax.set_title("πŸ”„ Rotating 3D Regression Model")
143
 
144
  placeholder.pyplot(fig, clear_figure=True)
145
+ time.sleep(0.07)
146
 
147
  with col2:
148
  st.metric("MSE", f"{mse:.4f}")