sree4411 committed on
Commit
efd1d7c
·
verified ·
1 Parent(s): 374a992

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -146
app.py CHANGED
@@ -11,155 +11,108 @@ from keras.layers import Dense
11
  from keras.optimizers import SGD
12
  from mlxtend.plotting import plot_decision_regions
13
 
14
- # -------------------------------
15
- # PAGE CONFIGURATION
16
- # -------------------------------
17
- st.set_page_config(page_title="πŸ§ͺ Neural Network Playground", layout="centered")
18
- st.title("🧬 Neural Network Playground")
19
-
20
- # -------------------------------
21
- # SESSION INITIALIZATION
22
- # -------------------------------
23
- def init_session():
24
- defaults = {
25
- 'X': None, 'y': None, 'model': None,
26
- 'X_train': None, 'y_train': None,
27
- 'history': None
28
- }
29
- for key, value in defaults.items():
30
- if key not in st.session_state:
31
- st.session_state[key] = value
32
-
33
- init_session()
34
-
35
- # -------------------------------
36
- # DATA GENERATION FUNCTION
37
- # -------------------------------
38
- def generate_data(dataset, samples, noise, factor):
39
- if dataset == "make_circles":
40
- return make_circles(n_samples=samples, noise=noise, factor=factor, random_state=42)
41
- elif dataset == "make_moons":
42
- return make_moons(n_samples=samples, noise=noise, random_state=42)
43
- else:
44
- return make_classification(n_samples=samples, n_features=2, n_informative=2,
45
- n_redundant=0, n_clusters_per_class=1,
46
- flip_y=noise, random_state=42)
47
-
48
- # -------------------------------
49
- # MODEL TRAINING FUNCTION
50
- # -------------------------------
51
- def train_model(X, y, test_size, learning_rate, batch_size, epochs):
52
- X_train, _, y_train, _ = train_test_split(X, y, test_size=test_size, random_state=1)
53
- scaler = StandardScaler()
54
- X_train_scaled = scaler.fit_transform(X_train)
55
-
56
- model = Sequential([
57
- Dense(8, activation='relu', input_shape=(2,)),
58
- Dense(4, activation='relu'),
59
- Dense(1, activation='sigmoid')
60
- ])
61
- model.compile(optimizer=SGD(learning_rate=learning_rate),
62
- loss='binary_crossentropy',
63
- metrics=['accuracy'])
64
-
65
- history = model.fit(X_train_scaled, y_train,
66
- validation_split=0.2,
67
- epochs=epochs,
68
- batch_size=batch_size,
69
- verbose=0)
70
-
71
- return model, X_train_scaled, y_train, history
72
-
73
- # -------------------------------
74
- # PLOT FUNCTIONS
75
- # -------------------------------
76
- def plot_dataset(X, y):
77
- df = pd.DataFrame(X, columns=['x1', 'x2'])
78
- df['label'] = y
79
- fig, ax = plt.subplots()
80
- sns.scatterplot(data=df, x='x1', y='x2', hue='label', palette='viridis', ax=ax)
81
- st.pyplot(fig)
82
-
83
- def plot_decision_boundary(model, X, y):
84
- fig, ax = plt.subplots()
85
- plot_decision_regions(X, y, clf=model, legend=2, ax=ax)
86
- st.pyplot(fig)
87
-
88
- def plot_loss_curve(history):
89
- fig, ax = plt.subplots()
90
- ax.plot(history.history['loss'], label='Training Loss')
91
- ax.plot(history.history['val_loss'], label='Validation Loss')
92
- ax.set_xlabel("Epochs")
93
- ax.set_ylabel("Loss")
94
- ax.set_title("Loss Curve")
95
- ax.legend()
96
- st.pyplot(fig)
97
-
98
- # -------------------------------
99
- # STREAMLIT TABS
100
- # -------------------------------
101
- tab1, tab2, tab3 = st.tabs([
102
- "πŸ”Ή Step 1: Data Generator",
103
- "πŸ”Ή Step 2: Train Neural Net",
104
- "πŸ”Ή Step 3: Visualize Results"
105
- ])
106
-
107
- # -------------------------------
108
- # TAB 1: DATA GENERATOR
109
- # -------------------------------
110
- with tab1:
111
- st.header("🎲 Generate Dataset")
112
- col1, col2 = st.columns(2)
113
- dataset = col1.selectbox("Dataset Type", ["make_classification", "make_moons", "make_circles"])
114
- samples = col2.slider("Samples", 100, 5000, 1000, 100)
115
- noise = st.slider("Noise", 0.0, 1.0, 0.2)
116
- factor = st.slider("Factor (only for Circles)", 0.1, 1.0, 0.5)
117
-
118
- if st.button("Generate"):
119
- X, y = generate_data(dataset, samples, noise, factor)
120
- st.session_state.X = X
121
- st.session_state.y = y
122
- st.success("βœ… Data generated!")
123
- plot_dataset(X, y)
124
-
125
- # -------------------------------
126
- # TAB 2: TRAINING
127
- # -------------------------------
128
- with tab2:
129
- st.header("🧠 Train Model")
130
-
131
  if st.session_state.X is None:
132
- st.warning("⚠️ Generate data first in Step 1.")
133
  else:
134
- test_size = st.slider("Test Size (%)", 10, 90, 20) / 100
135
- lr = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.1])
136
- batch_size = st.slider("Batch Size", 1, 512, 64)
 
 
137
  epochs = st.slider("Epochs", 10, 500, 100)
138
 
139
- if st.button("Train Model"):
140
- with st.spinner("Training in progress..."):
141
- model, X_train, y_train, history = train_model(
142
- st.session_state.X, st.session_state.y,
143
- test_size, lr, batch_size, epochs
144
- )
145
- st.session_state.model = model
146
- st.session_state.X_train = X_train
147
- st.session_state.y_train = y_train
148
- st.session_state.history = history
149
- st.success("βœ… Training Complete!")
150
-
151
- # -------------------------------
152
- # TAB 3: VISUALIZATION
153
- # -------------------------------
154
- with tab3:
155
- st.header("πŸ“Š Model Visualization")
 
 
 
 
 
 
 
 
 
 
156
  if st.session_state.model is None:
157
- st.warning("⚠️ Train the model in Step 2 to visualize results.")
158
  else:
159
- st.subheader("🌐 Decision Boundary")
160
- plot_decision_boundary(st.session_state.model,
161
- st.session_state.X_train,
162
- st.session_state.y_train)
163
-
164
- st.subheader("πŸ“‰ Training Loss Curve")
165
- plot_loss_curve(st.session_state.history)
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from keras.optimizers import SGD
12
  from mlxtend.plotting import plot_decision_regions
13
 
14
+ # --- Config ---
15
+ st.set_page_config(page_title="Neural Net Lab", layout="wide")
16
+ st.title("πŸ”¬ Interactive Neural Network Lab")
17
+
18
+ # --- Session State Setup ---
19
+ if 'X' not in st.session_state: st.session_state.X = None
20
+ if 'y' not in st.session_state: st.session_state.y = None
21
+ if 'model' not in st.session_state: st.session_state.model = None
22
+ if 'history' not in st.session_state: st.session_state.history = None
23
+ if 'X_train' not in st.session_state: st.session_state.X_train = None
24
+ if 'y_train' not in st.session_state: st.session_state.y_train = None
25
+
26
+ # --- Step 1: Dataset Generator ---
27
+ with st.expander("πŸ“Œ STEP 1: Generate Dataset", expanded=True):
28
+ st.markdown("Start by creating a synthetic 2D classification dataset.")
29
+
30
+ col1, col2, col3 = st.columns([2, 1, 1])
31
+ dataset_type = col1.selectbox("Choose a dataset type", ["make_classification", "make_moons", "make_circles"])
32
+ n_samples = col2.slider("Number of Samples", 100, 5000, 1000, step=100)
33
+ noise_level = col3.slider("Noise", 0.0, 1.0, 0.2)
34
+ circle_factor = st.slider("Factor (Circles only)", 0.1, 1.0, 0.5)
35
+
36
+ if st.button("πŸš€ Generate Dataset"):
37
+ if dataset_type == "make_moons":
38
+ X, y = make_moons(n_samples=n_samples, noise=noise_level, random_state=42)
39
+ elif dataset_type == "make_circles":
40
+ X, y = make_circles(n_samples=n_samples, noise=noise_level, factor=circle_factor, random_state=42)
41
+ else:
42
+ X, y = make_classification(n_samples=n_samples, n_features=2, n_informative=2,
43
+ n_redundant=0, n_clusters_per_class=1, flip_y=noise_level, random_state=42)
44
+ st.session_state.X, st.session_state.y = X, y
45
+ st.success("Dataset generated successfully! πŸŽ‰")
46
+
47
+ if st.session_state.X is not None:
48
+ df = pd.DataFrame(st.session_state.X, columns=["x1", "x2"])
49
+ df["label"] = st.session_state.y
50
+ st.markdown("### πŸ” Dataset Preview")
51
+ st.dataframe(df.head())
52
+
53
+ st.markdown("### πŸ“Š Visualize")
54
+ fig, ax = plt.subplots()
55
+ sns.scatterplot(data=df, x="x1", y="x2", hue="label", palette="coolwarm", ax=ax)
56
+ st.pyplot(fig)
57
+
58
+ # --- Step 2: Train Neural Network ---
59
+ with st.expander("πŸ€– STEP 2: Train Neural Network"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  if st.session_state.X is None:
61
+ st.warning("⚠️ Please generate the dataset in Step 1.")
62
  else:
63
+ st.subheader("πŸ› οΈ Model Configuration")
64
+ c1, c2, c3 = st.columns(3)
65
+ test_size = c1.slider("Test Split %", 10, 90, 20) / 100
66
+ learning_rate = c2.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.1])
67
+ batch_size = c3.slider("Batch Size", 1, 512, 64)
68
  epochs = st.slider("Epochs", 10, 500, 100)
69
 
70
+ if st.button("🧠 Train the Model"):
71
+ st.info("πŸ“‘ Preprocessing and Training in progress...")
72
+ X_train, _, y_train, _ = train_test_split(st.session_state.X, st.session_state.y, test_size=test_size, random_state=1)
73
+ scaler = StandardScaler()
74
+ X_scaled = scaler.fit_transform(X_train)
75
+
76
+ model = Sequential([
77
+ Dense(8, activation='relu', input_shape=(2,)),
78
+ Dense(4, activation='relu'),
79
+ Dense(1, activation='sigmoid')
80
+ ])
81
+ model.compile(optimizer=SGD(learning_rate=learning_rate), loss='binary_crossentropy', metrics=['accuracy'])
82
+
83
+ history = model.fit(X_scaled, y_train, validation_split=0.2,
84
+ batch_size=batch_size, epochs=epochs, verbose=0)
85
+
86
+ st.session_state.model = model
87
+ st.session_state.history = history
88
+ st.session_state.X_train = X_scaled
89
+ st.session_state.y_train = y_train
90
+ st.success("βœ… Model training complete!")
91
+
92
+ st.metric("Final Accuracy", f"{history.history['val_loss'][-1]:.4f}")
93
+ st.progress(100)
94
+
95
+ # --- Step 3: Visualize Model Output ---
96
+ with st.expander("πŸ“ˆ STEP 3: Visualize Model Output"):
97
  if st.session_state.model is None:
98
+ st.warning("⚠️ Train the model first in Step 2.")
99
  else:
100
+ col1, col2 = st.columns(2)
101
+
102
+ with col1:
103
+ st.subheader("🌐 Decision Boundary")
104
+ fig1, ax1 = plt.subplots()
105
+ plot_decision_regions(st.session_state.X_train, st.session_state.y_train,
106
+ clf=st.session_state.model, ax=ax1, legend=2)
107
+ st.pyplot(fig1)
108
+
109
+ with col2:
110
+ st.subheader("πŸ“‰ Training Loss Curve")
111
+ fig2, ax2 = plt.subplots()
112
+ ax2.plot(st.session_state.history.history['loss'], label='Train Loss')
113
+ ax2.plot(st.session_state.history.history['val_loss'], label='Val Loss')
114
+ ax2.set_title("Loss over Epochs")
115
+ ax2.legend()
116
+ st.pyplot(fig2)
117
+
118
+ st.success("πŸ§ͺ Visualization ready!")