"""Streamlit app that visualises one-dimensional gradient descent.

The user enters f(x), a starting point, a learning rate and an iteration
count. "Set Up" compiles f and its derivative with SymPy; "Next Iteration"
and "Run Iterations" apply x <- x - learning_rate * f'(x); the visited points
are plotted over the curve of f with Plotly.
"""

import numpy as np
import plotly.graph_objs as go
import streamlit as st
import sympy as sp

# Default plotting window; widened at plot time to include every finite iterate.
PLOT_X_MIN, PLOT_X_MAX = -10.0, 10.0
PLOT_SAMPLES = 500

DIVERGED_MSG = (
    "The current point is no longer finite (the descent diverged). "
    "Try a smaller learning rate and press Set Up."
)

x = sp.Symbol("x")


def _parse_number(label, raw, cast):
    """Convert *raw* with *cast* (``float`` or ``int``).

    On failure an error naming the field is shown and ``None`` is returned,
    so the caller can report every bad field before halting the run.
    """
    try:
        return cast(raw)
    except ValueError:
        st.error(f"{label}: {raw!r} is not a valid {cast.__name__}")
        return None


def _build_functions(text):
    """Parse *text* as a function of ``x`` and return ``(f, f_prime)`` callables.

    ``^`` is accepted as exponentiation. Raises ``ValueError`` when the
    expression uses symbols other than ``x``; parse errors from
    ``sp.sympify`` propagate to the caller.
    """
    # NOTE(review): sp.sympify evaluates the string with eval(); fine for a
    # local tool, but do not expose this app to untrusted users as-is.
    expr = sp.sympify(text.replace("^", "**"))
    extra = expr.free_symbols - {x}
    if extra:
        names = ", ".join(sorted(str(s) for s in extra))
        raise ValueError(f"only the variable x is supported (found: {names})")
    func = sp.lambdify(x, expr, "numpy")
    gradient_func = sp.lambdify(x, sp.diff(expr, x), "numpy")
    return func, gradient_func


def _evaluate(fn, values):
    """Evaluate *fn* on *values* and return a float array of the same shape.

    A lambdified constant expression (e.g. ``f(x) = 3``) can return a bare
    scalar, so the result is broadcast to the input's shape to keep plotting
    and ``[-1]`` indexing uniform.
    """
    arr = np.asarray(values, dtype=float)
    return np.broadcast_to(np.asarray(fn(arr), dtype=float), arr.shape).copy()


def _descent_step(rate):
    """Apply x <- x - rate * f'(x) to the last stored point.

    Appends the new point to ``st.session_state.points``, increments
    ``st.session_state.step`` and returns the new point.
    """
    x_old = float(st.session_state.points[-1])
    x_new = x_old - rate * float(st.session_state.gradient_func(x_old))
    st.session_state.points.append(x_new)
    st.session_state.step += 1
    return x_new


# Placeholder for custom HTML/CSS; the string currently holds only whitespace.
st.markdown(""" """, unsafe_allow_html=True)
st.title("Gradient Descent Visualizer")

func_input = st.text_input("Enter Function", "x^2")
# Render every input before validating, so one bad field doesn't hide the rest.
start_text = st.text_input("Starting Point", "2")
rate_text = st.text_input("Learning Rate", "0.01")
iters_text = st.text_input("Number of Iterations", "10")

start_point = _parse_number("Starting Point", start_text, float)
learning_rate = _parse_number("Learning Rate", rate_text, float)
num_iterations = _parse_number("Number of Iterations", iters_text, int)
if any(v is None for v in (start_point, learning_rate, num_iterations)):
    st.stop()  # errors were already shown; halt this run instead of crashing

# (Re)build f and f' on request, and automatically on a fresh session.
# st.button is evaluated first so the button is always rendered.
if (
    st.button("Set Up")
    or "func" not in st.session_state
    or "points" not in st.session_state
):
    try:
        func, gradient_func = _build_functions(func_input)
    except Exception as e:
        st.error(f"Error setting up function: {e}")
    else:
        st.session_state.func = func
        st.session_state.gradient_func = gradient_func
        st.session_state.points = [start_point]
        st.session_state.step = 0
        st.success("Function and Gradient Set Up Successfully!")

if "func" in st.session_state and "gradient_func" in st.session_state:
    if st.button("Next Iteration"):
        try:
            x_new = _descent_step(learning_rate)
        except Exception as e:
            st.error(f"Error in iteration: {e}")
        else:
            st.success(f"Iteration {st.session_state.step} Complete!")
            if not np.isfinite(x_new):
                st.warning(DIVERGED_MSG)

    if st.button("Run Iterations"):
        try:
            for _ in range(num_iterations):
                # Further steps from nan/inf only append more nan/inf.
                if not np.isfinite(_descent_step(learning_rate)):
                    st.warning(DIVERGED_MSG)
                    break
            st.success(f"Ran {st.session_state.step} Iterations in total")
        except Exception as e:
            st.error(f"Error in multiple iterations: {e}")

if "func" in st.session_state and st.session_state.get("points"):
    try:
        iter_points = np.array(st.session_state.points, dtype=float)
        finite_points = iter_points[np.isfinite(iter_points)]
        # Widen the default window so the curve spans every visited point.
        x_val = np.linspace(
            finite_points.min(initial=PLOT_X_MIN),
            finite_points.max(initial=PLOT_X_MAX),
            PLOT_SAMPLES,
        )
        y_val = _evaluate(st.session_state.func, x_val)
        iter_y = _evaluate(st.session_state.func, iter_points)

        trace1 = go.Scatter(
            x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue")
        )
        trace2 = go.Scatter(
            x=iter_points,
            y=iter_y,
            mode="markers+lines",
            name="Gradient Descent Path",
            marker=dict(color="red"),
        )
        trace3 = go.Scatter(
            x=[iter_points[-1]],
            y=[iter_y[-1]],
            mode='markers+text',
            marker=dict(color='green', size=15),
            text=[f"{iter_points[-1]:.6f}"],
            textposition="top center",
            name="Current Point",
        )
        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis"),
            yaxis=dict(title="y - axis"),
            width=1000,
            height=600,
        )
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        st.error(f"Plot error: {e}")