Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import plotly.graph_objs as go | |
| import sympy as sp | |
# Custom look for the app: font/colours, teal-accented inputs and buttons,
# spacing tweaks, and a styled success banner. Injected as raw HTML below.
_PAGE_CSS = """
<style>
html, body, [class*="css"] {
font-family: 'Segoe UI', sans-serif;
background-color: #ffffff;
color: #222;
}
h1 {
font-size: 36px;
font-weight: 700;
margin-bottom: 0.5em;
}
.stTextInput > div > input {
border: 2px solid #00d0b6;
border-radius: 8px;
padding: 0.5em;
font-size: 16px;
}
.stButton > button {
background-color: #00d0b6;
color: white;
font-weight: 600;
border-radius: 8px;
padding: 0.6em 1.2em;
font-size: 16px;
}
.stButton > button:hover {
background-color: #00baa5;
transition: 0.3s;
}
.stMarkdown {
font-size: 18px;
font-weight: 500;
}
.element-container:has(.stButton) {
margin-top: 1em;
margin-bottom: 1em;
}
.stColumns {
gap: 0.5rem !important;
}
.st-c1 {
font-weight: bold;
}
.stSuccess {
font-size: 18px;
font-weight: 600;
color: white;
background-color: #00d0b6;
border-radius: 6px;
padding: 0.4em 0.8em;
}
</style>
"""

# set_page_config must stay the first Streamlit call of the script run.
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
# unsafe_allow_html is required for the <style> tag to be applied rather than escaped.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
st.title("Interactive Gradient Descent Visualizer (1D)")

# The one independent variable every parsed expression is written in.
x = sp.Symbol('x')

# First-run values for every session-state key the app reads:
#   step          -- number of descent iterations taken so far
#   points        -- x-values visited by gradient descent (start point first)
#   gradient_func -- numeric callable for f'(x), set by "Set Up"
#   func          -- numeric callable for f(x), set by "Set Up"
#   parsed_func   -- the sympy expression for f(x), set by "Set Up"
#   func_input    -- text of the expression typed by the user
defaults = {
    "step": 0,
    "points": [],
    "gradient_func": None,
    "func": None,
    "parsed_func": None,
    "func_input": "x^2+x",
}

# Only fill in keys that are absent so reruns keep the user's current state.
missing = {name: value for name, value in defaults.items() if name not in st.session_state}
for name, value in missing.items():
    st.session_state[name] = value
# One-click example expressions: button label -> expression text.
PRESET_FUNCTIONS = {
    "x²": "x^2",
    "x³": "x^3",
    "sin(x)": "sin(x)",
    "1/x": "1/x",
    "Polynomial": "x**4 - 3*x**3 + 2",
}


def _use_preset(expression):
    """Button callback: put ``expression`` into the "Function" text box.

    Streamlit runs ``on_click`` callbacks before re-executing the script, so
    the keyed text input below already shows the new expression on that rerun.
    (The original code assigned ``func_input`` *after* the text box had
    rendered, which left the box showing the stale expression.)
    """
    st.session_state.func_input = expression


# Keyed to ``func_input`` so the widget value *is* the session-state value.
# No ``value=`` is passed: the initial text comes from session state, and
# passing both would make Streamlit warn about conflicting defaults.
st.text_input("Function", key="func_input")
st.markdown("Try these functions:")
preset_columns = st.columns(len(PRESET_FUNCTIONS))
for column, (label, expression) in zip(preset_columns, PRESET_FUNCTIONS.items()):
    column.button(label, on_click=_use_preset, args=(expression,))
def _compile_function(text):
    """Parse ``text`` as an expression in ``x`` and build numeric callables.

    ``^`` is accepted as an alias for ``**``.

    Returns:
        ``(parsed, func, gradient_func)``: the sympy expression, then numpy
        callables for f(x) and its derivative f'(x).

    Raises:
        ValueError: if the expression contains a free symbol other than ``x``.
            Without this check lambdify would build a callable that only
            fails later, with a NameError, at iteration/plot time.
        Exception: whatever ``sp.sympify`` raises for unparsable input.
    """
    # NOTE(security): sympify() evaluates its string with eval(), so this
    # executes arbitrary Python typed into the text box. If this app is
    # publicly reachable, restrict or sandbox the parsing.
    parsed = sp.sympify(text.replace("^", "**"))
    unknown = parsed.free_symbols - {x}
    if unknown:
        names = ", ".join(sorted(str(symbol) for symbol in unknown))
        raise ValueError(f"only the variable x may be used (found: {names})")
    func = sp.lambdify(x, parsed, "numpy")
    gradient_func = sp.lambdify(x, sp.diff(parsed, x), "numpy")
    return parsed, func, gradient_func


def _parse_start_point(text):
    """Return ``text`` as a finite float; raise ValueError otherwise."""
    value = float(text)
    if not np.isfinite(value):
        raise ValueError(f"starting point must be a finite number, got {text!r}")
    return value


start_point = st.text_input("Starting Point", "5")
setup = st.button("Set Up")
if setup:
    compiled = None
    start = None
    try:
        compiled = _compile_function(st.session_state.func_input)
    except Exception as e:  # UI boundary: show any parse/differentiation failure
        st.error(f"Error parsing function: {e}")
    try:
        start = _parse_start_point(start_point)
    except ValueError as e:
        st.error(f"Invalid starting point: {e}")
    # Commit only when both inputs are valid. A typo must never leave a new
    # function paired with an empty or stale descent path.
    if compiled is not None and start is not None:
        parsed, func, gradient_func = compiled
        st.session_state.step = 0
        st.session_state.points = [start]
        st.session_state.parsed_func = parsed
        st.session_state.func = func
        st.session_state.gradient_func = gradient_func
def _descent_step(gradient_func, x_curr, lr):
    """Return one gradient-descent update ``x_curr - lr * f'(x_curr)`` as a float.

    Raises:
        ValueError: if the update is not finite (divergence, a nan/inf
            learning rate, or a derivative that evaluates to nan/inf).
            Appending such a value would poison every later step and the plot.
        Exception: anything ``gradient_func`` itself raises, e.g.
            ZeroDivisionError for ``1/x`` at ``x = 0.0``.
    """
    x_next = float(x_curr - lr * gradient_func(x_curr))
    if not np.isfinite(x_next):
        raise ValueError(
            f"step from x = {x_curr} produced {x_next}; "
            "try a smaller learning rate or another starting point"
        )
    return x_next


learning_rate = st.text_input("Learning Rate", "0.01")
if st.button("Next Iteration"):
    state = st.session_state
    if state.func is None or state.gradient_func is None or not state.points:
        st.warning("Please set up the function first.")
    else:
        try:
            lr = float(learning_rate)
            x_next = _descent_step(state.gradient_func, state.points[-1], lr)
        except Exception as e:  # UI boundary: bad learning rate or failing gradient
            st.error(f"Iteration error: {e}")
        else:
            state.points.append(x_next)
            state.step += 1
def _broadcast(values, shape):
    """Return ``values`` as an array of ``shape``.

    For expressions that do not depend on ``x`` (e.g. ``"5"``), ``lambdify``
    returns a plain scalar. go.Scatter cannot use a scalar as the ``y``
    column for an x-array.
    """
    return np.broadcast_to(np.asarray(values), shape)


if st.session_state.func is not None and st.session_state.points:
    try:
        iter_points = np.array(st.session_state.points, dtype=float)
        finite = iter_points[np.isfinite(iter_points)]
        # Always show at least [-6, 6], widened by a 1-unit margin so the curve
        # covers every visited point. A start of 10 would otherwise be drawn
        # with no function curve under it. `initial` keeps min/max well-defined
        # when no point is finite.
        x_lo = min(-6.0, np.min(finite, initial=np.inf) - 1.0)
        x_hi = max(6.0, np.max(finite, initial=-np.inf) + 1.0)
        x_vals = np.linspace(x_lo, x_hi, 400)
        y_vals = _broadcast(st.session_state.func(x_vals), x_vals.shape)
        iter_y = _broadcast(st.session_state.func(iter_points), iter_points.shape)

        trace1 = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function', line=dict(color='blue'))
        trace2 = go.Scatter(x=iter_points, y=iter_y, mode='markers+lines', name='Gradient Descent Path', marker=dict(color='red'))
        # Highlight and label the most recent iterate.
        trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
                            marker=dict(color='green', size=12),
                            text=[f"{iter_points[-1]:.6f}"],
                            textposition="top center",
                            name="Current Point")
        layout = go.Layout(title=f"Iteration {st.session_state.step}",
                           xaxis=dict(title="x - axis"),
                           yaxis=dict(title="y - axis"))
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:  # UI boundary: report evaluation/plotting failures
        st.error(f"Plot error: {e}")