Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -186,11 +186,11 @@ if problem_type == "Classification":
|
|
| 186 |
# Main layout
|
| 187 |
col_left, col_center, col_right = st.columns([1, 2, 1])
|
| 188 |
|
| 189 |
-
# Left panel: Dataset with Seaborn
|
| 190 |
with col_left:
|
| 191 |
st.markdown('<div class="panel">', unsafe_allow_html=True)
|
| 192 |
st.subheader("Data")
|
| 193 |
-
fig, ax = plt.subplots(figsize=(3, 3))
|
| 194 |
if problem_type == "Classification":
|
| 195 |
sns.scatterplot(x=fv[:, 0], y=fv[:, 1], hue=cv, palette="coolwarm", edgecolor="k", alpha=0.7, ax=ax, legend=False)
|
| 196 |
else:
|
|
@@ -291,7 +291,7 @@ with col_center:
|
|
| 291 |
st.button("Remove Layer", on_click=remove_layer)
|
| 292 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 293 |
|
| 294 |
-
# Right panel: Output and Training
|
| 295 |
with col_right:
|
| 296 |
st.markdown('<div class="panel">', unsafe_allow_html=True)
|
| 297 |
st.subheader("Output")
|
|
@@ -321,7 +321,6 @@ with col_right:
|
|
| 321 |
self.X, self.y = X, y
|
| 322 |
self.losses = {"Epoch": [], "Train Loss": [], "Val Loss": []}
|
| 323 |
self.placeholder = st.empty()
|
| 324 |
-
self.model = None
|
| 325 |
|
| 326 |
def on_train_begin(self, logs=None):
|
| 327 |
self.model = self.model # Use the model passed implicitly by Keras
|
|
@@ -332,47 +331,49 @@ with col_right:
|
|
| 332 |
self.losses["Train Loss"].append(logs["loss"])
|
| 333 |
self.losses["Val Loss"].append(logs.get("val_loss", logs["loss"]))
|
| 334 |
with self.placeholder.container():
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
plt.
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
|
|
|
|
|
|
| 376 |
except Exception as e:
|
| 377 |
st.error(f"Error in epoch end: {e}")
|
| 378 |
|
|
|
|
| 186 |
# Main layout
|
| 187 |
col_left, col_center, col_right = st.columns([1, 2, 1])
|
| 188 |
|
| 189 |
+
# Left panel: Dataset with Seaborn (3x3 size)
|
| 190 |
with col_left:
|
| 191 |
st.markdown('<div class="panel">', unsafe_allow_html=True)
|
| 192 |
st.subheader("Data")
|
| 193 |
+
fig, ax = plt.subplots(figsize=(3, 3)) # Fixed size for consistency
|
| 194 |
if problem_type == "Classification":
|
| 195 |
sns.scatterplot(x=fv[:, 0], y=fv[:, 1], hue=cv, palette="coolwarm", edgecolor="k", alpha=0.7, ax=ax, legend=False)
|
| 196 |
else:
|
|
|
|
| 291 |
st.button("Remove Layer", on_click=remove_layer)
|
| 292 |
st.markdown('</div>', unsafe_allow_html=True)
|
| 293 |
|
| 294 |
+
# Right panel: Output and Training (decision region and loss plots stacked vertically, same size as dataset scatterplot)
|
| 295 |
with col_right:
|
| 296 |
st.markdown('<div class="panel">', unsafe_allow_html=True)
|
| 297 |
st.subheader("Output")
|
|
|
|
| 321 |
self.X, self.y = X, y
|
| 322 |
self.losses = {"Epoch": [], "Train Loss": [], "Val Loss": []}
|
| 323 |
self.placeholder = st.empty()
|
|
|
|
| 324 |
|
| 325 |
def on_train_begin(self, logs=None):
|
| 326 |
self.model = self.model # Use the model passed implicitly by Keras
|
|
|
|
| 331 |
self.losses["Train Loss"].append(logs["loss"])
|
| 332 |
self.losses["Val Loss"].append(logs.get("val_loss", logs["loss"]))
|
| 333 |
with self.placeholder.container():
|
| 334 |
+
# Single column for vertical stacking
|
| 335 |
+
st.subheader("Decision Region & Loss")
|
| 336 |
+
# Decision region plot (3x3 size)
|
| 337 |
+
fig1, ax1 = plt.subplots(figsize=(3, 3)) # Match dataset scatterplot size
|
| 338 |
+
if problem_type == "Classification":
|
| 339 |
+
X_2d = self.X[:, :2] # Use only first two features for 2D
|
| 340 |
+
y_pred_proba = self.model.predict(X_2d, verbose=0) if self.model else np.zeros((len(X_2d), 1))
|
| 341 |
+
y_pred = (y_pred_proba > 0.5).astype(int).ravel()
|
| 342 |
+
try:
|
| 343 |
+
plot_decision_regions(X_2d, self.y, clf=self.model, legend=2, colors='blue,red')
|
| 344 |
+
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=self.y, cmap='coolwarm', edgecolors='k', alpha=0.7)
|
| 345 |
+
xx, yy = np.meshgrid(np.linspace(X_2d[:, 0].min(), X_2d[:, 0].max(), 100),
|
| 346 |
+
np.linspace(X_2d[:, 1].min(), X_2d[:, 1].max(), 100))
|
| 347 |
+
grid = np.c_[xx.ravel(), yy.ravel()]
|
| 348 |
+
Z = self.model.predict(grid, verbose=0) if self.model else np.zeros((len(grid), 1))
|
| 349 |
+
Z = (Z > 0.5).astype(int).reshape(xx.shape)
|
| 350 |
+
plt.contour(xx, yy, Z, levels=[0.5], colors='black', linewidths=2)
|
| 351 |
+
except Exception as e:
|
| 352 |
+
st.warning(f"Decision region plot failed: {e}")
|
| 353 |
+
xx, yy = np.meshgrid(np.linspace(X_2d[:, 0].min(), X_2d[:, 0].max(), 100),
|
| 354 |
+
np.linspace(X_2d[:, 1].min(), X_2d[:, 1].max(), 100))
|
| 355 |
+
grid = np.c_[xx.ravel(), yy.ravel()]
|
| 356 |
+
Z = self.model.predict(grid, verbose=0) if self.model else np.zeros((len(grid), 1))
|
| 357 |
+
Z = (Z > 0.5).astype(int).reshape(xx.shape)
|
| 358 |
+
plt.contour(xx, yy, Z, levels=[0.5], colors='black', linewidths=2)
|
| 359 |
+
plt.contourf(xx, yy, Z, alpha=0.3, cmap="coolwarm")
|
| 360 |
+
plt.scatter(X_2d[:, 0], X_2d[:, 1], c=self.y, cmap="coolwarm", edgecolors="k", alpha=0.7)
|
| 361 |
+
else:
|
| 362 |
+
y_pred = self.model.predict(self.X, verbose=0) if self.model else np.zeros_like(self.X[:, 0])
|
| 363 |
+
plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
|
| 364 |
+
plt.plot(self.X[:, 0], y_pred, "r-", linewidth=2)
|
| 365 |
+
ax1.set_facecolor("#333")
|
| 366 |
+
ax1.set_xticks([])
|
| 367 |
+
ax1.set_yticks([])
|
| 368 |
+
st.pyplot(fig1)
|
| 369 |
+
|
| 370 |
+
# Train-Val-Loss plot (3x3 size)
|
| 371 |
+
fig2, ax2 = plt.subplots(figsize=(3, 3)) # Match dataset scatterplot size
|
| 372 |
+
ax2.plot(self.losses["Epoch"], self.losses["Train Loss"], "b-", label="Train")
|
| 373 |
+
ax2.plot(self.losses["Epoch"], self.losses["Val Loss"], "r--", label="Val")
|
| 374 |
+
ax2.legend()
|
| 375 |
+
ax2.set_facecolor("#333")
|
| 376 |
+
st.pyplot(fig2)
|
| 377 |
except Exception as e:
|
| 378 |
st.error(f"Error in epoch end: {e}")
|
| 379 |
|