# ANN_PLAYGROUND / src/streamlit_app.py
# satya11's picture
# Update src/streamlit_app.py
# 8dbf941 verified
# Import packages
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import make_moons, make_circles, make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
import numpy as np
import matplotlib.pyplot as plt
import io
import time
# Streamlit config
st.set_page_config(page_title="ANN Visualizer", layout="wide")

# ---------- UI Components ----------
st.title("🧠 Interactive ANN Visualizer")

# Every value chosen below is left as a module-level name; the data
# generator, the model classes and the training section read them directly.
with st.sidebar:
    st.header("⚙️ Model Configuration")

    # Dataset and optimiser settings.
    dataset_type = st.selectbox("Dataset", ["moons", "circles", "classification"])
    n_samples = st.slider("Samples", 100, 5000, 500, 100)
    noise = st.slider("Noise", 0.0, 1.0, 0.2, 0.05)
    epochs = st.slider("Epochs", 100, 5000, 500, 100)
    lr = st.number_input("Learning Rate", 0.0001, 1.0, 0.01, format="%f")

    # The early-stopping widgets are only created when the feature is on;
    # otherwise both values are None.
    early_stop = st.checkbox("Early Stopping", value=True)
    patience = st.slider("Patience", 1, 20, 5) if early_stop else None
    min_delta = (
        st.number_input("Min Delta", 0.0001, 0.1, 0.001, format="%f")
        if early_stop
        else None
    )

    st.subheader("🧱 Hidden Layers")
    num_hidden = st.number_input("Number of Layers", 1, 10, 2)

    # One (units, activation, dropout, reg_type, reg_strength) tuple per layer.
    layer_configs = []
    activation_map = {"ReLU": nn.ReLU, "Tanh": nn.Tanh, "Sigmoid": nn.Sigmoid}
    for i in range(int(num_hidden)):
        st.markdown(f"**Layer {i + 1}**")
        # Labels repeat for every layer, so each widget gets a per-layer key.
        units = st.number_input("Units", 1, 512, 8, key=f"units_{i}")
        act = st.selectbox("Activation", list(activation_map.keys()), key=f"act_{i}")
        dropout = st.slider("Dropout", 0.0, 0.9, 0.0, 0.05, key=f"drop_{i}")
        reg_type = st.selectbox(
            "Regularization", ["None", "L1", "L2", "L1_L2"], key=f"reg_{i}"
        )
        # The strength input is only shown when a penalty type is selected.
        reg_strength = (
            st.number_input(
                "Reg Strength", 0.0, 1.0, 0.001, format="%f", key=f"reg_strength_{i}"
            )
            if reg_type != "None"
            else 0.0
        )
        layer_configs.append((units, act, dropout, reg_type, reg_strength))

    start_training = st.button("🚀 Train Model")
# ---------- Data Generation ----------
def generate_data():
    """Build the 2-D toy dataset selected in the sidebar.

    Reads the module-level ``dataset_type``, ``n_samples`` and ``noise``
    values set by the sidebar widgets.

    Returns:
        The ``(X, y)`` pair produced by the chosen scikit-learn generator.
    """
    if dataset_type == "moons":
        return make_moons(n_samples=n_samples, noise=noise, random_state=42)
    if dataset_type == "circles":
        return make_circles(
            n_samples=n_samples, noise=noise, factor=0.5, random_state=42
        )
    # Any other selection falls through to make_classification. ``noise`` is
    # not passed to this generator. The seed is fixed like the other two
    # branches so that reruns of the script show the same points.
    return make_classification(
        n_samples=n_samples,
        n_features=2,
        n_informative=2,
        n_redundant=0,
        n_clusters_per_class=1,
        random_state=42,
    )
# ---------- Model ----------
class CustomLayer(nn.Module):
    """Hidden layer: linear map, then activation, then dropout.

    The layer can also report a penalty on its linear weights.

    Args:
        in_f: Number of input features of the linear map.
        out_f: Number of output units of the linear map.
        activation: Key looked up in the module-level ``activation_map``.
        dropout: Probability handed to ``nn.Dropout``.
        reg_type: "L1", "L2" or "L1_L2"; any other value means no penalty.
        reg_strength: Multiplier applied to the penalty.
    """

    def __init__(self, in_f, out_f, activation, dropout, reg_type, reg_strength):
        super().__init__()
        self.linear = nn.Linear(in_f, out_f)
        activation_cls = activation_map[activation]
        self.activation = activation_cls()
        self.dropout = nn.Dropout(dropout)
        self.reg_type = reg_type
        self.reg_strength = reg_strength

    def forward(self, x):
        """Return ``dropout(activation(linear(x)))``."""
        pre_activation = self.linear(x)
        activated = self.activation(pre_activation)
        return self.dropout(activated)

    def reg_loss(self):
        """Return the weight penalty of this layer.

        Only ``self.linear.weight`` is penalised. The result is a tensor for
        the three recognised penalty types and the float ``0.0`` otherwise.
        """
        kind = self.reg_type
        if kind not in ("L1", "L2", "L1_L2"):
            return 0.0

        weights = self.linear.weight
        if kind == "L1":
            penalty = weights.abs().sum()
        elif kind == "L2":
            penalty = weights.pow(2).sum()
        else:
            # "L1_L2": both terms share the same strength.
            penalty = weights.abs().sum() + weights.pow(2).sum()
        return self.reg_strength * penalty
class ANN(nn.Module):
    """Feed-forward network: a stack of ``CustomLayer`` blocks plus a linear head.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs of the final linear layer.
        configs: Iterable of ``(units, activation, dropout, reg_type,
            reg_strength)`` tuples, one per hidden layer, in order.
    """

    def __init__(self, input_dim, output_dim, configs):
        super().__init__()
        self.layers = nn.ModuleList()
        # Plain-list view of the same hidden layers, used for the penalty sum.
        self.reg_layers = []

        width = input_dim
        for units, act, drop, reg, reg_strength in configs:
            hidden = CustomLayer(width, units, act, drop, reg, reg_strength)
            self.layers.append(hidden)
            self.reg_layers.append(hidden)
            # The next layer consumes this layer's output width.
            width = units

        self.output = nn.Linear(width, output_dim)

    def forward(self, x):
        """Run ``x`` through every hidden layer in order, then the output layer."""
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        return self.output(hidden)

    def regularization_loss(self):
        """Return the sum of ``reg_loss()`` over all hidden layers."""
        total = 0
        for layer in self.reg_layers:
            total = total + layer.reg_loss()
        return total
# ---------- Training Logic ----------
if start_training:
    # ----- Data: generate, standardise, 70/30 split, convert to tensors -----
    X, y = generate_data()
    X = StandardScaler().fit_transform(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long)

    # ----- Model, loss, optimiser -----
    model = ANN(2, 2, layer_configs)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # ----- Early-stopping bookkeeping and loss history -----
    best_loss = float("inf")
    best_weights = None
    patience_counter = 0
    train_losses, test_losses = [], []

    # 400x400 grid covering the data with a 0.5 margin; the model is
    # evaluated on it to draw the decision regions.
    grid_x, grid_y = np.meshgrid(
        np.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, 400),
        np.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, 400),
    )
    grid_tensor = torch.tensor(
        np.c_[grid_x.ravel(), grid_y.ravel()], dtype=torch.float32
    )

    progress = st.progress(0)
    # Plot every tenth of the run; max() keeps the modulus non-zero.
    plot_every = max(1, epochs // 10)

    for epoch in range(1, epochs + 1):
        # Full-batch training step.
        model.train()
        optimizer.zero_grad()
        out = model(X_train_tensor)
        loss = criterion(out, y_train_tensor) + model.regularization_loss()
        loss.backward()
        optimizer.step()

        # Held-out loss, computed in eval mode without gradients.
        model.eval()
        with torch.no_grad():
            val_out = model(X_test_tensor)
            val_loss = criterion(val_out, y_test_tensor) + model.regularization_loss()
        train_losses.append(loss.item())
        test_losses.append(val_loss.item())

        if early_stop:
            if val_loss.item() < best_loss - min_delta:
                best_loss = val_loss.item()
                # state_dict() hands back the live parameter tensors, which
                # the optimiser keeps updating in place; clone them so the
                # snapshot really holds the best epoch's weights.
                best_weights = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    st.warning(f"Stopped early at epoch {epoch}")
                    break

        if epoch % plot_every == 0 or epoch == epochs:
            with torch.no_grad():
                preds = model(grid_tensor).argmax(dim=1).numpy().reshape(grid_x.shape)
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.contourf(grid_x, grid_y, preds, cmap='Spectral', alpha=0.8)
            ax.scatter(X[:, 0], X[:, 1], c=y, cmap='Spectral', edgecolor='k', s=15)
            ax.set_title(f"Decision Boundary at Epoch {epoch}")
            ax.axis("off")
            st.pyplot(fig)
            # Release the figure so they do not pile up across epochs/reruns.
            plt.close(fig)

        progress.progress(epoch / epochs)

    # Roll back to the best snapshot recorded by early stopping, if any.
    if best_weights is not None:
        model.load_state_dict(best_weights)

    # Final results
    st.success("✅ Training complete!")

    st.subheader("📈 Loss Curve")
    fig1, ax1 = plt.subplots()
    ax1.plot(train_losses, label="Train", color="navy")
    ax1.plot(test_losses, label="Test", color="orange")
    ax1.legend()
    ax1.grid(True)
    st.pyplot(fig1)
    buf1 = io.BytesIO()
    fig1.savefig(buf1, format="png")
    plt.close(fig1)
    st.download_button("Download Loss Plot", buf1.getvalue(), "loss.png", "image/png")

    st.subheader("🧭 Final Decision Boundary")
    model.eval()
    with torch.no_grad():
        final_preds = model(grid_tensor).argmax(dim=1).numpy().reshape(grid_x.shape)
    fig2, ax2 = plt.subplots(figsize=(5, 5))
    ax2.contourf(grid_x, grid_y, final_preds, cmap='Spectral', alpha=0.8)
    ax2.scatter(X[:, 0], X[:, 1], c=y, cmap='Spectral', edgecolor='k', s=15)
    ax2.set_title("Final Decision Boundary")
    ax2.axis("off")
    st.pyplot(fig2)
    buf2 = io.BytesIO()
    fig2.savefig(buf2, format="png")
    plt.close(fig2)
    st.download_button(
        "Download Decision Boundary", buf2.getvalue(), "boundary.png", "image/png"
    )

    # Accuracy of the (possibly restored) model; no gradients are needed.
    with torch.no_grad():
        y_train_pred = model(X_train_tensor).argmax(dim=1).numpy()
        y_test_pred = model(X_test_tensor).argmax(dim=1).numpy()
    st.metric("Train Accuracy", f"{accuracy_score(y_train, y_train_pred):.2%}")
    st.metric("Test Accuracy", f"{accuracy_score(y_test, y_test_pred):.2%}")