# Hugging Face Spaces page header ("Spaces: Sleeping") — scrape artifact, not code.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
from sklearn.datasets import make_moons, make_circles, make_classification, make_regression

# Page chrome: wide layout plus a tighter top padding for the main container.
st.set_page_config(page_title="🔬 Neural Net Playground", layout="wide")
st.markdown("<style>.block-container {padding-top: 1rem;}</style>", unsafe_allow_html=True)

# ========== Initialize Session State ==========
# Seed defaults once per browser session; later reruns keep existing values.
for _key, _default in (("epoch", 0), ("running", False), ("loss_history", [])):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# ========== Title ==========
st.title("🧠 Neural Network Trainer")
st.markdown("Interactive trainer for basic neural network concepts.")

# ========== 3-COLUMN LAYOUT ==========
# Left: data controls · middle: model/training controls · right: live metrics.
left, mid, right = st.columns([2, 3, 2])
# ========= Left: Dataset & Feature Controls =========
with left:
    st.header("📊 Dataset & Features")
    data_type = st.radio("Data Type", ["Classification", "Regression"])
    noise = st.slider("Noise", 0.0, 1.0, 0.2, 0.05)
    samples = st.slider("Samples", 100, 1000, 500, 50)
    # One checkbox per candidate input feature; only the raw inputs
    # X₁ and X₂ start enabled, the engineered features start off.
    _feature_labels = ["X₁", "X₂", "X₁²", "X₂²", "X₁X₂", "sin(X₁)", "sin(X₂)"]
    feature_dict = {
        label: st.checkbox(label, value=(label in ("X₁", "X₂")))
        for label in _feature_labels
    }
    selected_features = [name for name, checked in feature_dict.items() if checked]
# ========= Middle: Training Controls =========
with mid:
    st.header("⚙️ Model Settings")
    # Three side-by-side selectors; calling widget methods on the column
    # objects is equivalent to nesting them in `with` blocks.
    c1, c2, c3 = st.columns(3)
    activation = c1.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh"])
    regularization = c2.selectbox("Regularization", ["None", "L1", "L2"])
    learning_rate = c3.select_slider("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1], value=0.01)
    # The regularization-rate slider only appears when a penalty is chosen.
    if regularization == "None":
        reg_rate = 0
    else:
        reg_rate = st.slider("Reg. Rate", 0.0001, 0.1, 0.01)
    hidden_layers = st.slider("Hidden Layers", 1, 5, 2)
    neurons = []
    for layer_idx in range(hidden_layers):
        neurons.append(st.slider(f"Neurons in Layer {layer_idx + 1}", 2, 20, 4))
    st.subheader("Training Controls")
    col_a, col_b, col_c = st.columns(3)
    # Buttons toggle the run flag in session state; the simulated
    # training loop further down reacts to it on this same rerun.
    if col_a.button("🔄 Reset"):
        st.session_state.epoch = 0
        st.session_state.running = False
        st.session_state.loss_history = []
    if col_b.button("▶️ Train"):
        st.session_state.running = True
    if col_c.button("⏸️ Pause"):
        st.session_state.running = False
# ========= Right: Metrics & Plot =========
with right:
    st.header("📈 Live Metrics")
    history = st.session_state.loss_history
    # NOTE(review): this panel renders before the training loop below runs,
    # so the numbers shown lag the current rerun's training by one pass.
    if not history:
        st.info("No training yet.")
    else:
        st.metric("Epoch", st.session_state.epoch)
        st.metric("Current Loss", f"{history[-1]:.4f}")
    st.subheader("Training Loss")
    fig, ax = plt.subplots(figsize=(4, 2))
    ax.plot(history, color="royalblue", marker="o")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True, linestyle="--", linewidth=0.5)
    st.pyplot(fig)
# ========== Dataset Generation ==========
def get_data(kind=None, n_samples=None, noise_level=None):
    """Generate a toy dataset for the playground.

    Each parameter defaults to the corresponding sidebar widget value
    (module-level ``data_type``, ``samples``, ``noise``), so the existing
    zero-argument call keeps its behavior while the function becomes
    reusable and testable without the Streamlit globals.

    Args:
        kind: "Classification" for two-moons data; anything else yields
            1-feature regression data. Defaults to the Data Type radio.
        n_samples: number of generated samples. Defaults to the Samples slider.
        noise_level: noise amplitude in [0, 1]. Defaults to the Noise slider.
            For regression it is scaled by 10 to match sklearn's target scale.

    Returns:
        (X, y): feature matrix and target array from sklearn's generators.
            NOTE: no random_state is fixed, so data differs on every call/rerun.
    """
    kind = data_type if kind is None else kind
    n_samples = samples if n_samples is None else n_samples
    noise_level = noise if noise_level is None else noise_level
    if kind == "Classification":
        X, y = make_moons(n_samples=n_samples, noise=noise_level)
    else:
        X, y = make_regression(n_samples=n_samples, n_features=1, noise=noise_level * 10)
    return X, y

X, y = get_data()
# ========== Training Loop Simulation ==========
# While the run flag is set, simulate a burst of 10 epochs per rerun:
# the "loss" is an exponential decay plus small gaussian noise.
if st.session_state.running:
    progress = st.progress(0, text="Training in progress...")
    n_steps = 10
    for step in range(1, n_steps + 1):
        time.sleep(0.1)
        st.session_state.epoch += 1
        epoch = st.session_state.epoch
        simulated_loss = np.exp(-0.05 * epoch) + np.random.normal(0, 0.02)
        st.session_state.loss_history.append(simulated_loss)
        progress.progress(step / n_steps, text=f"Training... Epoch {epoch}")
    progress.empty()
# ========== Dataset Plot ==========
st.subheader("🧪 Dataset Visualization")
fig, ax = plt.subplots()
if data_type != "Classification":
    # Regression: scatter target vs the single feature, with a KDE overlay.
    ax.scatter(X[:, 0], y, c=y, cmap="plasma", edgecolor="k")
    sns.kdeplot(x=X[:, 0], y=y, fill=True, cmap="plasma", ax=ax, alpha=0.3)
else:
    # Classification: two-moons scatter colored by class label.
    scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap="coolwarm", edgecolor="k")
ax.set_title(f"{data_type} Dataset")
ax.grid(True)
st.pyplot(fig)
| # import streamlit as st | |
| # import numpy as np | |
| # import matplotlib.pyplot as plt | |
| # import seaborn as sns | |
| # import graphviz | |
| # import time | |
| # from sklearn.datasets import make_moons, make_circles, make_classification | |
| # from sklearn.datasets import make_regression | |
| # # Set Streamlit page title | |
| # st.set_page_config(page_title="Neural Network Trainer", layout="wide") | |
| # # ================= Session State for Training Controls ================= | |
| # if "epoch" not in st.session_state: | |
| # st.session_state.epoch = 0 | |
| # if "running" not in st.session_state: | |
| # st.session_state.running = False | |
| # # ================= TRAINING CONTROL PANEL (Top) ================= | |
| # st.markdown("### Training Controls") | |
| # col1, col2, col3, col4, col5, col6, col7, col8, col9 = st.columns(9) | |
| # with col1: | |
| # if st.button("↩️ Reset"): | |
| # st.session_state.epoch = 0 | |
| # st.session_state.running = False | |
| # with col2: | |
| # if st.button("▶️ Train"): | |
| # st.session_state.running = True | |
| # with col3: | |
| # if st.button("⏸️ Pause"): | |
| # st.session_state.running = False | |
| # with col4: | |
| # activation = st.selectbox("Activation", ["ReLU", "Sigmoid", "Tanh", "LeakyReLU"]) | |
| # with col5: | |
| # regularization = st.selectbox("Regularization", ["None", "L1", "L2"]) | |
| # with col6: | |
| # reg_rate = st.selectbox("Regularization Rate", [0.0001, 0.001, 0.01, 0.1]) if regularization in ["L1", "L2"] else 0 | |
| # with col7: | |
| # problem_type = st.selectbox("Problem Type", ["Classification", "Regression"]) | |
| # with col8: | |
| # learning_rate = st.selectbox("Learning Rate", [0.0001, 0.001, 0.01, 0.03, 0.1]) | |
| # with col9: | |
| # st.write(f"Epoch: {st.session_state.epoch}") | |
| # # 🚀 Fix: Run training loop without breaking Streamlit | |
| # if st.session_state.running: | |
| # time.sleep(1) # Simulating training | |
| # st.session_state.epoch += 1 | |
| # # ================= MAIN LAYOUT ================= | |
| # col_features, col_hidden, col_output = st.columns([2, 2, 2]) | |
| # # ========== FEATURE SELECTION MOVED TO MIDDLE ========== | |
| # with col_features: | |
| # st.header("FEATURE SELECTION") | |
| # feature_dict = { | |
| # "X₁": st.checkbox("X₁", value=True), | |
| # "X₂": st.checkbox("X₂", value=True), | |
| # "X₁²": st.checkbox("X₁²"), | |
| # "X₂²": st.checkbox("X₂²"), | |
| # "X₁X₂": st.checkbox("X₁X₂"), | |
| # "sin(X₁)": st.checkbox("sin(X₁)"), | |
| # "sin(X₂)": st.checkbox("sin(X₂)"), | |
| # } | |
| # selected_features = [f for f, v in feature_dict.items() if v] | |
| # # ========== HIDDEN LAYERS PANEL (Middle) ========== # | |
| # with col_hidden: | |
| # st.header("HIDDEN LAYERS") | |
| # hidden_layers = st.slider("Number of Hidden Layers", 1, 7, 2) | |
| # neurons = [] | |
| # for i in range(hidden_layers): | |
| # neurons.append(st.slider(f"Neurons in Layer {i+1}", 1, 20, 4)) | |
| # # ========== OUTPUT PANEL (Right) ========== # | |
| # with col_output: | |
| # st.header("OUTPUT") | |
| # st.write("Test Loss: 0.501") | |
| # st.write("Training Loss: 0.507") | |
| # # Spiral Plot with Updated Color Palette | |
| # x = np.linspace(-6, 6, 300) | |
| # y = np.sin(x) + np.random.normal(0, 0.1, x.shape) | |
| # fig, ax = plt.subplots() | |
| # sns.scatterplot(x=x, y=y, hue=x, palette="plasma", ax=ax) | |
| # st.pyplot(fig) | |
| # show_test_data = st.checkbox("Show test data") | |
| # discretize_output = st.checkbox("Discretize output") | |
| # # Sidebar for dataset selection | |
| # st.sidebar.header("Dataset Selection") | |
| # data_type = st.sidebar.radio("Choose Data Type", ["Classification", "Regression"]) | |
| # # Generate classification data | |
| # def generate_classification_data(): | |
| # st.sidebar.subheader("Classification Settings") | |
| # dataset_type = st.sidebar.selectbox("Dataset Type", ["Moons", "Circles", "Classification"]) | |
| # noise = st.sidebar.slider("Noise Level", 0.0, 1.0, 0.2, step=0.05) | |
| # samples = st.sidebar.slider("Number of Samples", 100, 1000, 500, step=50) | |
| # if dataset_type == "Moons": | |
| # X, y = make_moons(n_samples=samples, noise=noise) | |
| # elif dataset_type == "Circles": | |
| # X, y = make_circles(n_samples=samples, noise=noise, factor=0.5) | |
| # else: | |
| # X, y = make_classification(n_samples=samples, n_features=2, n_classes=2, n_clusters_per_class=1, flip_y=noise) | |
| # return X, y | |
| # # Generate regression data | |
| # def generate_regression_data(): | |
| # st.sidebar.subheader("Regression Settings") | |
| # samples = st.sidebar.slider("Number of Samples", 100, 1000, 500, step=50) | |
| # noise = st.sidebar.slider("Noise Level", 0.0, 10.0, 2.0, step=0.5) | |
| # X, y = make_regression(n_samples=samples, n_features=1, noise=noise) | |
| # return X, y | |
| # # Select dataset type | |
| # if data_type == "Classification": | |
| # X, y = generate_classification_data() | |
| # cmap = "coolwarm" | |
| # title = "Classification Data" | |
| # is_classification = True | |
| # else: | |
| # X, y = generate_regression_data() | |
| # cmap = "plasma" | |
| # title = "Regression Data" | |
| # is_classification = False | |
| # # 🎯 Reduced Size of the Plot | |
| # fig, ax = plt.subplots(figsize=(4, 2)) # Smaller size (width=4, height=2) | |
| # if is_classification: | |
| # scatter = ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap, edgecolors="white", alpha=0.8) | |
| # ax.set_xlabel("Feature 1", fontsize=8) | |
| # ax.set_ylabel("Feature 2", fontsize=8) | |
| # else: | |
| # scatter = ax.scatter(X[:, 0], y, c=y, cmap=cmap, edgecolors="white", alpha=0.8) | |
| # sns.kdeplot(x=X[:, 0], y=y, fill=True, cmap=cmap, alpha=0.3, ax=ax) | |
| # ax.set_xlabel("Feature 1", fontsize=8) | |
| # ax.set_ylabel("Target", fontsize=8) | |
| # ax.set_title(title, fontsize=10) | |
| # ax.tick_params(axis='both', labelsize=7) | |
| # ax.grid(True, linewidth=0.5) | |
| # # Display in Streamlit | |
| # st.pyplot(fig) | |
| # # ================= NEURAL NETWORK VISUALIZATION ================= | |
| # def draw_neural_network(): | |
| # graph = graphviz.Digraph(engine="dot") | |
| # # Input Layer (Features) | |
| # input_nodes = [] | |
| # for feature in selected_features: | |
| # graph.node(feature, feature, shape="circle", style="filled", fillcolor="lightblue", width="0.6", height="0.6") | |
| # input_nodes.append(feature) | |
| # # Hidden Layers | |
| # prev_layer = input_nodes | |
| # hidden_layers_nodes = [] | |
| # for i, num_neurons in enumerate(neurons): | |
| # layer_nodes = [f"H{i+1}_{j+1}" for j in range(num_neurons)] | |
| # hidden_layers_nodes.append(layer_nodes) | |
| # for node in layer_nodes: | |
| # graph.node(node, node, shape="circle", style="filled", fillcolor="lightyellow", width="0.6", height="0.6") | |
| # # Connect previous layer to this hidden layer | |
| # for prev in prev_layer: | |
| # for curr in layer_nodes: | |
| # graph.edge(prev, curr) | |
| # prev_layer = layer_nodes # Update previous layer for next iteration | |
| # # Output Layer | |
| # graph.node("Output", "Output", shape="circle", style="filled", fillcolor="lightgreen", width="0.6", height="0.6") | |
| # # Connect last hidden layer to output | |
| # for last_hidden in prev_layer: | |
| # graph.edge(last_hidden, "Output") | |
| # graph.attr(rankdir="LR") # Make it horizontal (Left to Right) | |
| # return graph | |
| # # =================== DISPLAY NEURAL NETWORK =================== | |
| # st.graphviz_chart(draw_neural_network()) | |
| # # =================== DISPLAY DATA PLOT =================== | |
| # st.sidebar.subheader("Dataset Visualization") | |
| # fig, ax = plt.subplots() | |
| # ax.scatter(X[:, 0], X[:, 1], c=y, cmap="plasma", edgecolors="k") | |
| # st.sidebar.pyplot(fig) | |
| # import streamlit as st | |
| # import numpy as np | |
| # import matplotlib.pyplot as plt | |
| # import time | |
| # # Initialize session state | |
| # if "epoch" not in st.session_state: | |
| # st.session_state.epoch = 0 | |
| # if "running" not in st.session_state: | |
| # st.session_state.running = False | |
| # if "loss_history" not in st.session_state: | |
| # st.session_state.loss_history = [] | |
| # # Training controls | |
| # col1, col2, col3 = st.columns(3) | |
| # with col1: | |
| # if st.button("Reset"): | |
| # st.session_state.epoch = 0 | |
| # st.session_state.running = False | |
| # st.session_state.loss_history = [] | |
| # with col2: | |
| # if st.button("Train"): | |
| # st.session_state.running = True | |
| # with col3: | |
| # if st.button("Pause"): | |
| # st.session_state.running = False | |
| # # Training loop simulation | |
| # if st.session_state.running: | |
| # for _ in range(10): | |
| # time.sleep(0.5) | |
| # st.session_state.epoch += 1 | |
| # simulated_loss = np.exp(-0.1 * st.session_state.epoch) + np.random.normal(0, 0.02) | |
| # st.session_state.loss_history.append(simulated_loss) | |
| # # Epoch vs Training Loss Plot (Smaller Size) | |
| # st.header("Epoch vs Training Loss") | |
| # fig, ax = plt.subplots(figsize=(4, 2)) # Reduce plot size (width=4, height=2) | |
| # ax.plot(range(1, len(st.session_state.loss_history) + 1), st.session_state.loss_history, marker="o", linestyle="-", color="blue") | |
| # ax.set_xlabel("Epoch") | |
| # ax.set_ylabel("Training Loss") | |
| # ax.set_title("Training Loss Over Epochs", fontsize=10) | |
| # ax.tick_params(axis='both', labelsize=8) | |
| # ax.grid(True, linewidth=0.5) | |
| # st.pyplot(fig) | |
| # # Display current epoch and training loss below the plot | |
| # if st.session_state.loss_history: | |
| # st.write(f"Epoch: {st.session_state.epoch}") | |
| # st.write(f"Training Loss: {st.session_state.loss_history[-1]:.4f}") | |
| # # Display current epoch and training loss below the plot | |
| # if st.session_state.loss_history: | |
| # st.write(f"Epoch: {st.session_state.epoch}") | |
| # st.write(f"Training Loss: {st.session_state.loss_history[-1]:.4f}") | |
| # # =================== TRAINING STATUS =================== | |
| # if st.session_state.running: | |
| # st.write("🚀 Training started...") | |
| # elif not st.session_state.running and st.session_state.epoch > 0: | |
| # st.write("⏸️ Training paused.") |