trohith89 commited on
Commit
607d095
verified
1 Parent(s): 0f8cbff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -6,10 +6,12 @@ import graphviz
6
  import time
7
  import tensorflow as tf
8
  from tensorflow import keras
9
- from tensorflow.keras import layers
 
10
  from sklearn.model_selection import train_test_split
11
  from sklearn.datasets import make_moons, make_circles, make_classification, make_blobs
12
  from mlxtend.plotting import plot_decision_regions
 
13
 
14
  # Set Streamlit page title
15
  st.set_page_config(page_title="Neural Network Trainer", layout="wide")
@@ -41,7 +43,7 @@ with col3:
41
  if st.button("鈴革笍 Pause"):
42
  st.session_state.running = False
43
  with col4:
44
- activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"])
45
  with col5:
46
  problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
47
  with col6:
@@ -87,19 +89,29 @@ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_
87
 
88
  # ========== TRAINING ANN ========== #
89
  def build_ann():
90
- model = keras.Sequential()
91
  model.add(layers.Input(shape=(X.shape[1],)))
92
  for units in neurons:
93
- model.add(layers.Dense(units, activation=activation.lower()))
94
- model.add(layers.Dense(1, activation="sigmoid" if problem_type == "Classification" else "linear"))
95
  model.compile(optimizer=keras.optimizers.Adam(learning_rate), loss="binary_crossentropy" if problem_type == "Classification" else "mse")
96
  return model
97
 
 
 
 
 
 
 
 
 
 
98
  if st.session_state.running:
99
  model = build_ann()
100
  history = model.fit(X_train, y_train, epochs=num_epochs, batch_size=batch_size, validation_data=(X_test, y_test), verbose=0)
101
  st.session_state.train_loss_history = history.history["loss"]
102
  st.session_state.test_loss_history = history.history["val_loss"]
 
103
 
104
  # ========== LOSS PLOT ========== #
105
  with col_plot:
@@ -116,7 +128,8 @@ with col_plot:
116
  # =================== DECISION REGION =================== #
117
  if problem_type == "Classification":
118
  fig, ax = plt.subplots()
119
- plot_decision_regions(X_train, y_train, clf=model, ax=ax)
 
120
  ax.set_title("Decision Region")
121
  st.pyplot(fig)
122
 
 
6
  import time
7
  import tensorflow as tf
8
  from tensorflow import keras
9
+ from tensorflow.keras import layers, Sequential
10
+ from tensorflow.keras.regularizers import l2
11
  from sklearn.model_selection import train_test_split
12
  from sklearn.datasets import make_moons, make_circles, make_classification, make_blobs
13
  from mlxtend.plotting import plot_decision_regions
14
+ from sklearn.base import BaseEstimator, ClassifierMixin
15
 
16
  # Set Streamlit page title
17
  st.set_page_config(page_title="Neural Network Trainer", layout="wide")
 
43
  if st.button("鈴革笍 Pause"):
44
  st.session_state.running = False
45
  with col4:
46
+ activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"], index=2)
47
  with col5:
48
  problem_type = st.selectbox("Problem Type", ["Classification", "Regression"])
49
  with col6:
 
89
 
90
  # ========== TRAINING ANN ========== #
91
  def build_ann():
92
+ model = Sequential()
93
  model.add(layers.Input(shape=(X.shape[1],)))
94
  for units in neurons:
95
+ model.add(layers.Dense(units=units, activation='tanh'))
96
+ model.add(layers.Dense(units=1, activation='sigmoid', kernel_regularizer=l2(0.1)))
97
  model.compile(optimizer=keras.optimizers.Adam(learning_rate), loss="binary_crossentropy" if problem_type == "Classification" else "mse")
98
  return model
99
 
100
+ class KerasClassifierWrapper(BaseEstimator, ClassifierMixin):
101
+ def __init__(self, model):
102
+ self.model = model
103
+ def fit(self, X, y):
104
+ self.model.fit(X, y, epochs=num_epochs, batch_size=batch_size, verbose=0)
105
+ return self
106
+ def predict(self, X):
107
+ return (self.model.predict(X) > 0.5).astype(int).flatten()
108
+
109
  if st.session_state.running:
110
  model = build_ann()
111
  history = model.fit(X_train, y_train, epochs=num_epochs, batch_size=batch_size, validation_data=(X_test, y_test), verbose=0)
112
  st.session_state.train_loss_history = history.history["loss"]
113
  st.session_state.test_loss_history = history.history["val_loss"]
114
+ wrapper_model = KerasClassifierWrapper(model)
115
 
116
  # ========== LOSS PLOT ========== #
117
  with col_plot:
 
128
  # =================== DECISION REGION =================== #
129
  if problem_type == "Classification":
130
  fig, ax = plt.subplots()
131
+ plot_decision_regions(X_train, y_train, clf=wrapper_model, ax=ax)
132
+ ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolor='k')
133
  ax.set_title("Decision Region")
134
  st.pyplot(fig)
135