sree4411 commited on
Commit
ff402b5
Β·
verified Β·
1 Parent(s): 4145bbe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ from sklearn.datasets import make_classification, make_moons, make_circles
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.preprocessing import StandardScaler
9
+ from keras.models import Sequential
10
+ from keras.layers import Dense
11
+ from keras.optimizers import SGD
12
+ from mlxtend.plotting import plot_decision_regions
13
+
14
+ # -------------------------------
15
+ # PAGE CONFIGURATION
16
+ # -------------------------------
17
+ st.set_page_config(page_title="πŸ§ͺ Neural Network Playground", layout="centered")
18
+ st.title("🧬 Neural Network Playground")
19
+
20
+ # -------------------------------
21
+ # SESSION INITIALIZATION
22
+ # -------------------------------
23
+ def init_session():
24
+ defaults = {
25
+ 'X': None, 'y': None, 'model': None,
26
+ 'X_train': None, 'y_train': None,
27
+ 'history': None
28
+ }
29
+ for key, value in defaults.items():
30
+ if key not in st.session_state:
31
+ st.session_state[key] = value
32
+
33
+ init_session()
34
+
35
+ # -------------------------------
36
+ # DATA GENERATION FUNCTION
37
+ # -------------------------------
38
+ def generate_data(dataset, samples, noise, factor):
39
+ if dataset == "make_circles":
40
+ return make_circles(n_samples=samples, noise=noise, factor=factor, random_state=42)
41
+ elif dataset == "make_moons":
42
+ return make_moons(n_samples=samples, noise=noise, random_state=42)
43
+ else:
44
+ return make_classification(n_samples=samples, n_features=2, n_informative=2,
45
+ n_redundant=0, n_clusters_per_class=1,
46
+ flip_y=noise, random_state=42)
47
+
48
+ # -------------------------------
49
+ # MODEL TRAINING FUNCTION
50
+ # -------------------------------
51
+ def train_model(X, y, test_size, learning_rate, batch_size, epochs):
52
+ X_train, _, y_train, _ = train_test_split(X, y, test_size=test_size, random_state=1)
53
+ scaler = StandardScaler()
54
+ X_train_scaled = scaler.fit_transform(X_train)
55
+
56
+ model = Sequential([
57
+ Dense(8, activation='relu', input_shape=(2,)),
58
+ Dense(4, activation='relu'),
59
+ Dense(1, activation='sigmoid')
60
+ ])
61
+ model.compile(optimizer=SGD(learning_rate=learning_rate),
62
+ loss='binary_crossentropy',
63
+ metrics=['accuracy'])
64
+
65
+ history = model.fit(X_train_scaled, y_train,
66
+ validation_split=0.2,
67
+ epochs=epochs,
68
+ batch_size=batch_size,
69
+ verbose=0)
70
+
71
+ return model, X_train_scaled, y_train, history
72
+
73
+ # -------------------------------
74
+ # PLOT FUNCTIONS
75
+ # -------------------------------
76
+ def plot_dataset(X, y):
77
+ df = pd.DataFrame(X, columns=['x1', 'x2'])
78
+ df['label'] = y
79
+ fig, ax = plt.subplots()
80
+ sns.scatterplot(data=df, x='x1', y='x2', hue='label', palette='viridis', ax=ax)
81
+ st.pyplot(fig)
82
+
83
+ def plot_decision_boundary(model, X, y):
84
+ fig, ax = plt.subplots()
85
+ plot_decision_regions(X, y, clf=model, legend=2, ax=ax)
86
+ st.pyplot(fig)
87
+
88
+ def plot_loss_curve(history):
89
+ fig, ax = plt.subplots()
90
+ ax.plot(history.history['loss'], label='Training Loss')
91
+ ax.plot(history.history['val_loss'], label='Validation Loss')
92
+ ax.set_xlabel("Epochs")
93
+ ax.set_ylabel("Loss")
94
+ ax.set_title("Loss Curve")
95
+ ax.legend()
96
+ st.pyplot(fig)
97
+
98
+ # -------------------------------
99
+ # STREAMLIT TABS
100
+ # -------------------------------
101
+ tab1, tab2, tab3 = st.tabs([
102
+ "πŸ”Ή Step 1: Data Generator",
103
+ "πŸ”Ή Step 2: Train Neural Net",
104
+ "πŸ”Ή Step 3: Visualize Results"
105
+ ])
106
+
107
+ # -------------------------------
108
+ # TAB 1: DATA GENERATOR
109
+ # -------------------------------
110
+ with tab1:
111
+ st.header("🎲 Generate Dataset")
112
+ col1, col2 = st.columns(2)
113
+ dataset = col1.selectbox("Dataset Type", ["make_classification", "make_moons", "make_circles"])
114
+ samples = col2.slider("Samples", 100, 5000, 1000, 100)
115
+ noise = st.slider("Noise", 0.0, 1.0, 0.2)
116
+ factor = st.slider("Factor (only for Circles)", 0.1, 1.0, 0.5)
117
+
118
+ if st.button("Generate"):
119
+ X, y = generate_data(dataset, samples, noise, factor)
120
+ st.session_state.X = X
121
+ st.session_state.y = y
122
+ st.success("βœ… Data generated!")
123
+ plot_dataset(X, y)
124
+
125
+ # -------------------------------
126
+ # TAB 2: TRAINING
127
+ # -------------------------------
128
+ with tab2:
129
+ st.header("🧠 Train Model")
130
+
131
+ if st.session_state.X is None:
132
+ st.warning("⚠️ Generate data first in Step 1.")
133
+ else:
134
+ test_size = st.slider("Test Size (%)", 10, 90, 20) / 100
135
+ lr = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.1])
136
+ batch_size = st.slider("Batch Size", 1, 512, 64)
137
+ epochs = st.slider("Epochs", 10, 500, 100)
138
+
139
+ if st.button("Train Model"):
140
+ with st.spinner("Training in progress..."):
141
+ model, X_train, y_train, history = train_model(
142
+ st.session_state.X, st.session_state.y,
143
+ test_size, lr, batch_size, epochs
144
+ )
145
+ st.session_state.model = model
146
+ st.session_state.X_train = X_train
147
+ st.session_state.y_train = y_train
148
+ st.session_state.history = history
149
+ st.success("βœ… Training Complete!")
150
+
151
+ # -------------------------------
152
+ # TAB 3: VISUALIZATION
153
+ # -------------------------------
154
+ with tab3:
155
+ st.header("πŸ“Š Model Visualization")
156
+ if st.session_state.model is None:
157
+ st.warning("⚠️ Train the model in Step 2 to visualize results.")
158
+ else:
159
+ st.subheader("🌐 Decision Boundary")
160
+ plot_decision_boundary(st.session_state.model,
161
+ st.session_state.X_train,
162
+ st.session_state.y_train)
163
+
164
+ st.subheader("πŸ“‰ Training Loss Curve")
165
+ plot_loss_curve(st.session_state.history)