trohith89 commited on
Commit
4beaaa3
·
verified ·
1 Parent(s): f8cfb29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -246,7 +246,7 @@ with col_center:
246
  st.session_state.hidden_layer_neurons.pop()
247
 
248
  def increase_neurons(i):
249
- if st.session_state.hidden_layer_neurons[i] < 8:
250
  st.session_state.hidden_layer_neurons[i] += 1
251
 
252
  def decrease_neurons(i):
@@ -258,7 +258,7 @@ with col_center:
258
  with col1:
259
  st.button("−", key=f"dec_{i}", on_click=decrease_neurons, args=(i,))
260
  with col2:
261
- st.write(f"Layer {i+1}: {st.session_state.hidden_layer_neurons[i]}")
262
  with col3:
263
  st.button("+", key=f"inc_{i}", on_click=increase_neurons, args=(i,))
264
  col_btn1, col_btn2 = st.columns(2)
@@ -298,6 +298,10 @@ with col_right:
298
  self.X, self.y = X, y
299
  self.losses = {"Epoch": [], "Train Loss": [], "Val Loss": []}
300
  self.placeholder = st.empty()
 
 
 
 
301
 
302
  def on_epoch_end(self, epoch, logs=None):
303
  self.losses["Epoch"].append(epoch + 1)
@@ -308,13 +312,14 @@ with col_right:
308
  with col1:
309
  plt.figure(figsize=(3, 3))
310
  if problem_type == "Classification":
311
- # Ensure we use only the first two features for 2D plotting
312
- X_2d = self.X[:, :2] # Use only X1 and X2 for decision boundary
313
  y_pred_proba = self.model.predict(X_2d, verbose=0)
314
  y_pred = (y_pred_proba > 0.5).astype(int).ravel()
315
  try:
316
  plot_decision_regions(X_2d, self.y, clf=self.model, legend=2, colors='blue,red')
317
  plt.scatter(X_2d[:, 0], X_2d[:, 1], c=self.y, cmap='coolwarm', edgecolors='k', alpha=0.7)
 
318
  except Exception as e:
319
  st.warning(f"Decision region plot failed: {e}")
320
  xx, yy = np.meshgrid(np.linspace(X_2d[:, 0].min(), X_2d[:, 0].max(), 100),
@@ -328,7 +333,7 @@ with col_right:
328
  else:
329
  y_pred = self.model.predict(self.X, verbose=0)
330
  plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
331
- plt.plot(self.X[:, 0], y_pred, "r-", linewidth=2)
332
  plt.gca().set_facecolor("#333")
333
  plt.xticks([])
334
  plt.yticks([])
 
246
  st.session_state.hidden_layer_neurons.pop()
247
 
248
  def increase_neurons(i):
249
+ if st.session_state.num_hidden_layers[i] < 8:
250
  st.session_state.hidden_layer_neurons[i] += 1
251
 
252
  def decrease_neurons(i):
 
258
  with col1:
259
  st.button("−", key=f"dec_{i}", on_click=decrease_neurons, args=(i,))
260
  with col2:
261
+ st.write(f"Layer {i+1}: {st.session_state.hidden_layer_neurons[i]} neurons")
262
  with col3:
263
  st.button("+", key=f"inc_{i}", on_click=increase_neurons, args=(i,))
264
  col_btn1, col_btn2 = st.columns(2)
 
298
  self.X, self.y = X, y
299
  self.losses = {"Epoch": [], "Train Loss": [], "Val Loss": []}
300
  self.placeholder = st.empty()
301
+ self.model = None # Store the model instance
302
+
303
+ def on_train_begin(self, logs=None):
304
+ self.model = self.model # Ensure model is stored
305
 
306
  def on_epoch_end(self, epoch, logs=None):
307
  self.losses["Epoch"].append(epoch + 1)
 
312
  with col1:
313
  plt.figure(figsize=(3, 3))
314
  if problem_type == "Classification":
315
+ # Use only the first two features for 2D plotting
316
+ X_2d = self.X[:, :2] # Ensure 2D input
317
  y_pred_proba = self.model.predict(X_2d, verbose=0)
318
  y_pred = (y_pred_proba > 0.5).astype(int).ravel()
319
  try:
320
  plot_decision_regions(X_2d, self.y, clf=self.model, legend=2, colors='blue,red')
321
  plt.scatter(X_2d[:, 0], X_2d[:, 1], c=self.y, cmap='coolwarm', edgecolors='k', alpha=0.7)
322
+ plt.contour(X_2d[:, 0], X_2d[:, 1], y_pred.reshape(X_2d.shape[0], X_2d.shape[1]), levels=[0.5], colors='black', linewidths=2)
323
  except Exception as e:
324
  st.warning(f"Decision region plot failed: {e}")
325
  xx, yy = np.meshgrid(np.linspace(X_2d[:, 0].min(), X_2d[:, 0].max(), 100),
 
333
  else:
334
  y_pred = self.model.predict(self.X, verbose=0)
335
  plt.scatter(self.X[:, 0], self.y, c="blue", alpha=0.5)
336
+ plt.plot(self.X[:, 0], y_pred, "r-", linewidths=2)
337
  plt.gca().set_facecolor("#333")
338
  plt.xticks([])
339
  plt.yticks([])