Anshini commited on
Commit
d21f501
Β·
verified Β·
1 Parent(s): 23931f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -24
app.py CHANGED
@@ -8,9 +8,31 @@ import tensorflow as tf
8
  from tensorflow.keras import layers, models
9
 
10
  # -------------------------------
11
- # Helper Functions
 
 
 
 
 
12
  # -------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
14
  def generate_data(dataset, test_size):
15
  if dataset == "moons":
16
  X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
@@ -18,9 +40,9 @@ def generate_data(dataset, test_size):
18
  X, y = make_circles(n_samples=1000, noise=0.2, factor=0.5, random_state=42)
19
  else:
20
  X, y = make_blobs(n_samples=1000, centers=2, cluster_std=1.5, random_state=42)
21
-
22
  X = StandardScaler().fit_transform(X)
23
- return train_test_split(X, y, test_size=1-test_size, random_state=42)
24
 
25
  def build_model(activation, learning_rate):
26
  model = models.Sequential([
@@ -38,8 +60,7 @@ def plot_decision_boundary(model, X, y):
38
  xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
39
  np.linspace(y_min, y_max, 200))
40
  grid = np.c_[xx.ravel(), yy.ravel()]
41
- preds = model.predict(grid)
42
- preds = preds.reshape(xx.shape)
43
 
44
  plt.contourf(xx, yy, preds, cmap="RdBu", alpha=0.6)
45
  plt.scatter(X[:, 0], X[:, 1], c=y, cmap="RdBu", edgecolors='white')
@@ -57,28 +78,50 @@ def plot_loss(history):
57
  st.pyplot(plt.gcf())
58
  plt.clf()
59
 
60
- def train_and_visualize(dataset, lr, act, split, batch):
61
- X_train, X_test, y_train, y_test = generate_data(dataset, split)
62
- model = build_model(act, lr)
63
- history = model.fit(X_train, y_train, epochs=50, batch_size=batch,
64
- validation_data=(X_test, y_test), verbose=0)
65
- X_combined = np.vstack((X_train, X_test))
66
- y_combined = np.concatenate((y_train, y_test))
67
- plot_decision_boundary(model, X_combined, y_combined)
68
- plot_loss(history)
69
-
70
  # -------------------------------
71
- # Streamlit UI
72
  # -------------------------------
 
 
 
 
 
 
 
 
73
 
 
 
 
74
  st.title("🧠 Neural Network Playground")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- dataset = st.selectbox("Choose Dataset", ["moons", "circles", "blobs"])
77
- learning_rate = st.number_input("Learning Rate", value=0.01, format="%.4f")
78
- activation = st.selectbox("Activation Function", ["relu", "sigmoid", "tanh"])
79
- split_ratio = st.slider("Train-Test Split Ratio", 0.5, 0.9, 0.7)
80
- batch_size = st.number_input("Batch Size", value=32, step=16)
81
 
82
- if st.button("Train Model"):
83
- with st.spinner("Training in progress..."):
84
- train_and_visualize(dataset, learning_rate, activation, split_ratio, batch_size)
 
8
  from tensorflow.keras import layers, models
9
 
10
  # -------------------------------
11
+ # Page Config
12
+ # -------------------------------
13
+ st.set_page_config(page_title="Neural Network Playground", layout="wide")
14
+
15
+ # -------------------------------
16
+ # Styling
17
  # -------------------------------
18
+ st.markdown(
19
+ """
20
+ <style>
21
+ .stApp {
22
+ background-color: #f5f7fa;
23
+ }
24
+ h1, h2 {
25
+ color: #333333;
26
+ font-family: 'Segoe UI', sans-serif;
27
+ }
28
+ </style>
29
+ """,
30
+ unsafe_allow_html=True
31
+ )
32
 
33
+ # -------------------------------
34
+ # Helper Functions
35
+ # -------------------------------
36
  def generate_data(dataset, test_size):
37
  if dataset == "moons":
38
  X, y = make_moons(n_samples=1000, noise=0.2, random_state=42)
 
40
  X, y = make_circles(n_samples=1000, noise=0.2, factor=0.5, random_state=42)
41
  else:
42
  X, y = make_blobs(n_samples=1000, centers=2, cluster_std=1.5, random_state=42)
43
+
44
  X = StandardScaler().fit_transform(X)
45
+ return train_test_split(X, y, test_size=1 - test_size, random_state=42)
46
 
47
  def build_model(activation, learning_rate):
48
  model = models.Sequential([
 
60
  xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
61
  np.linspace(y_min, y_max, 200))
62
  grid = np.c_[xx.ravel(), yy.ravel()]
63
+ preds = model.predict(grid, verbose=0).reshape(xx.shape)
 
64
 
65
  plt.contourf(xx, yy, preds, cmap="RdBu", alpha=0.6)
66
  plt.scatter(X[:, 0], X[:, 1], c=y, cmap="RdBu", edgecolors='white')
 
78
  st.pyplot(plt.gcf())
79
  plt.clf()
80
 
 
 
 
 
 
 
 
 
 
 
81
  # -------------------------------
82
+ # Sidebar Inputs
83
  # -------------------------------
84
+ with st.sidebar:
85
+ st.header("πŸ”§ Hyperparameters")
86
+ dataset = st.selectbox("Select Dataset", ["moons", "circles", "blobs"])
87
+ learning_rate = st.number_input("Learning Rate", value=0.01, format="%.4f")
88
+ activation = st.selectbox("Activation Function", ["relu", "sigmoid", "tanh"])
89
+ split_ratio = st.slider("Train-Test Split", 0.5, 0.9, 0.7)
90
+ batch_size = st.number_input("Batch Size", value=32, step=16)
91
+ train_button = st.button("πŸš€ Train Model")
92
 
93
+ # -------------------------------
94
+ # Main App
95
+ # -------------------------------
96
  st.title("🧠 Neural Network Playground")
97
+ st.write("Interactively explore how neural networks learn decision boundaries with different hyperparameters and synthetic datasets.")
98
+
99
+ if train_button:
100
+ with st.spinner("Training the neural network..."):
101
+ # Generate data
102
+ X_train, X_test, y_train, y_test = generate_data(dataset, split_ratio)
103
+
104
+ # Build and train model
105
+ model = build_model(activation, learning_rate)
106
+ history = model.fit(X_train, y_train, epochs=50, batch_size=batch_size,
107
+ validation_data=(X_test, y_test), verbose=0)
108
+
109
+ # Evaluation
110
+ loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
111
+
112
+ # Display accuracy
113
+ st.metric("πŸ“Š Test Accuracy", f"{accuracy * 100:.2f}%")
114
+
115
+ # Tabs for output
116
+ tab1, tab2 = st.tabs(["🧭 Decision Boundary", "πŸ“‰ Training vs Testing Loss"])
117
+ with tab1:
118
+ X_all = np.vstack((X_train, X_test))
119
+ y_all = np.concatenate((y_train, y_test))
120
+ plot_decision_boundary(model, X_all, y_all)
121
 
122
+ with tab2:
123
+ plot_loss(history)
 
 
 
124
 
125
+ # Expandable section for model summary
126
+ with st.expander("πŸ“œ View Model Summary"):
127
+ model.summary(print_fn=lambda x: st.text(x))