"""Streamlit app that visualizes gradient descent on a user-entered 1-D function.

The function is typed as a Python expression in ``x`` (for example ``x**2`` or
``torch.sin(x)``); PyTorch autograd computes the derivative used by each step.
"""

import math

import numpy as np
import plotly.graph_objs as go
import streamlit as st
import torch

# Default horizontal plotting window; it is widened automatically so the whole
# descent path stays visible when an iterate leaves this range.
PLOT_X_MIN = -10.0
PLOT_X_MAX = 10.0
PLOT_SAMPLES = 500


def make_function(expr: str):
    """Dynamically create a function in torch.

    Returns a callable ``func(x)`` that evaluates *expr* with ``x`` bound to the
    argument and the ``torch`` module available by name.

    Raises:
        SyntaxError: if *expr* is not a valid Python expression (raised here,
            at set-up time, rather than on the first evaluation).

    SECURITY: ``eval`` executes arbitrary Python typed into the UI. Only run
    this app for trusted users (e.g. locally); do not deploy it publicly as-is.
    """
    # Compile once: syntax errors surface immediately and calls skip re-parsing.
    code = compile(expr, "<function of x>", "eval")

    def func(x):
        return eval(code, {"x": x, "torch": torch})

    return func


def gradient_step(x_val, func, lr):
    """Take one gradient-descent step from ``x_val``.

    Returns ``(new_x, grad)`` where ``grad`` is d func/dx at ``x_val`` and
    ``new_x = x_val - lr * grad``. An expression that does not depend on ``x``
    (e.g. a constant) has gradient 0.

    Raises:
        ValueError: if ``func`` returns more than one value for a single x.
        FloatingPointError: if the gradient or the new point is not finite
            (the descent diverged or the function is undefined there).
    """
    x = torch.tensor([x_val], dtype=torch.float32, requires_grad=True)
    # as_tensor accepts plain Python numbers too (constant expressions).
    y = torch.as_tensor(func(x), dtype=torch.float32)
    if y.numel() != 1:
        raise ValueError(
            f"function must return one value per x, got shape {tuple(y.shape)}"
        )
    if y.requires_grad:
        y.backward()  # single-element output, so the implicit gradient is valid
        grad = x.grad.item() if x.grad is not None else 0.0
    else:
        grad = 0.0  # result is not connected to x in the autograd graph
    new_x = x_val - lr * grad
    if not (math.isfinite(grad) and math.isfinite(new_x)):
        raise FloatingPointError(
            f"non-finite step at x={x_val} (grad={grad}, new x={new_x})"
        )
    return new_x, grad


def _evaluate_on(func, xs):
    """Evaluate ``func`` on a 1-D NumPy array; return an array shaped like ``xs``.

    Scalar results (constant expressions) are broadcast across ``xs``.
    """
    with torch.no_grad():
        ys = torch.as_tensor(
            func(torch.tensor(xs, dtype=torch.float32)), dtype=torch.float32
        )
    return np.broadcast_to(ys.numpy(), xs.shape)


def _advance(count, lr):
    """Run ``count`` gradient steps from the latest recorded point.

    Each new point is appended to ``st.session_state.points`` and the step
    counter is incremented as the step is taken, so steps completed before an
    error are kept. Returns the gradient of the last step (None if count <= 0).
    """
    grad_val = None
    for _ in range(count):
        x_old = float(st.session_state.points[-1])
        x_new, grad_val = gradient_step(x_old, st.session_state.func, lr)
        st.session_state.points.append(x_new)
        st.session_state.step += 1
    return grad_val


# NOTE(review): empty HTML block — presumably a placeholder for custom CSS; confirm or remove.
st.markdown(""" """, unsafe_allow_html=True)
st.title("Gradient Descent Visualizer")

func_input = st.text_input("Enter Function of x", "x**2")
start_point_text = st.text_input("Starting Point", "2")
learning_rate_text = st.text_input("Learning Rate", "0.01")
num_iterations_text = st.text_input("Number of Iterations", "10")

# Parse only after every field is rendered, so a typo shows an error message
# (and the fields stay editable) instead of crashing the app with a traceback.
try:
    start_point = float(start_point_text)
    learning_rate = float(learning_rate_text)
    num_iterations = int(num_iterations_text)
except ValueError as e:
    st.error(f"Invalid numeric input: {e}")
    st.stop()
if not (math.isfinite(start_point) and math.isfinite(learning_rate)):
    st.error("Starting Point and Learning Rate must be finite numbers")
    st.stop()

if st.button("Set Up") or "func" not in st.session_state or "points" not in st.session_state:
    try:
        func = make_function(func_input)
        # Evaluate once so unknown names / bad operations are reported now,
        # before the previous (working) function is replaced.
        _evaluate_on(func, np.array([start_point]))
        st.session_state.func = func
        st.session_state.points = [start_point]
        st.session_state.step = 0
        st.success("Function Set Up Successfully with PyTorch!")
    except Exception as e:
        st.error(f"Error setting up function: {e}")

if "func" in st.session_state:
    if st.button("Next Iteration"):
        try:
            grad_val = _advance(1, learning_rate)
            st.success(f"Iteration {st.session_state.step} Complete! (grad={grad_val:.6f})")
        except Exception as e:
            st.error(f"Error in iteration: {e}")

    if st.button("Run Iterations"):
        try:
            _advance(num_iterations, learning_rate)
            st.success(f"Ran {st.session_state.step} Iterations in total")
        except Exception as e:
            st.error(f"Error in multiple iterations: {e}")

if "func" in st.session_state and st.session_state.get("points"):
    try:
        iter_points = np.array(st.session_state.points, dtype=float)
        iter_y = _evaluate_on(st.session_state.func, iter_points)

        # Widen the default window when the descent path has moved outside it.
        x_lo = min(PLOT_X_MIN, float(iter_points.min()))
        x_hi = max(PLOT_X_MAX, float(iter_points.max()))
        x_val = np.linspace(x_lo, x_hi, PLOT_SAMPLES)
        y_val = _evaluate_on(st.session_state.func, x_val)

        trace1 = go.Scatter(
            x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue")
        )
        trace2 = go.Scatter(
            x=iter_points,
            y=iter_y,
            mode="markers+lines",
            name="Gradient Descent Path",
            marker=dict(color="red"),
        )
        trace3 = go.Scatter(
            x=[iter_points[-1]],
            y=[iter_y[-1]],
            mode="markers+text",
            marker=dict(color="green", size=15),
            text=[f"{iter_points[-1]:.6f}"],
            textposition="top center",
            name="Current Point",
        )
        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis"),
            yaxis=dict(title="y - axis"),
            width=1000,
            height=600,
        )
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        st.error(f"Plot error: {e}")