import streamlit as st
import numpy as np
import plotly.graph_objects as go

# Full-width layout. Streamlit documents set_page_config as the first
# Streamlit command on a page (older versions enforce it), so it runs before
# any session-state setup or widget.
st.set_page_config(layout="wide")

# Defaults and plot geometry shared by the state setup and the figure.
DEFAULT_FUNCTION = "x**2 + x"
DEFAULT_START = 4.0
PLOT_X_RANGE = (-10.0, 10.0)
Y_PADDING = 5.0  # margin added below/above the plotted function values
Y_AXIS_CAP = 1000.0  # cap on the y-axis top so steep functions stay readable


# Safe function evaluation
def safe_eval(func_str, x_val):
    """Evaluate the user-supplied expression ``func_str`` at ``x_val``.

    The expression can see only ``x`` (bound to ``x_val``) and ``np``
    (NumPy); Python builtins are disabled.

    SECURITY NOTE(review): this still runs ``eval`` on text typed into the
    app. Disabling builtins is NOT a sandbox -- attribute access (through
    ``np`` or dunder attributes of literals) can reach arbitrary objects.
    Do not expose the app to untrusted users without replacing ``eval``
    with a real expression parser.

    Raises:
        ValueError: if the expression cannot be evaluated; the original
            exception is chained as ``__cause__``.
    """
    allowed_names = {"x": x_val, "np": np}
    try:
        # An empty dict (rather than None) disables builtins while giving a
        # clear "name '...' is not defined" error if one is used.
        return eval(func_str, {"__builtins__": {}}, allowed_names)
    except Exception as e:
        raise ValueError(f"Error evaluating the function: {e}") from e


# Function derivative using finite difference method
def derivative(func_str, x_val, h=1e-5):
    """Approximate f'(x_val) with a central finite difference of step ``h``.

    Raises:
        ValueError: propagated from ``safe_eval`` if f cannot be evaluated
            at ``x_val + h`` or ``x_val - h``.
    """
    return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)


# Tangent line equation
def tangent_line(func_str, x_val, x_range):
    """Return the tangent to f at ``x_val``, evaluated at ``x_range``.

    Computes ``f'(x_val) * (x_range - x_val) + f(x_val)``; in this script
    ``x_range`` is a NumPy array, so the result is an array of equal shape.
    """
    y_val = safe_eval(func_str, x_val)
    slope = derivative(func_str, x_val)
    return slope * (x_range - x_val) + y_val


def _gradient_step(func_str, x_val, learning_rate):
    """Return ``(x_next, f(x_next))`` for one gradient-descent update.

    Nothing is written to session state here, so a failure part-way through
    leaves the recorded trajectory untouched.
    """
    grad = derivative(func_str, x_val)
    x_next = x_val - learning_rate * grad
    return x_next, safe_eval(func_str, x_next)


def _y_axis_range(y_values, padding=Y_PADDING, cap=Y_AXIS_CAP):
    """Return ``[low, high]`` for the y-axis from the finite ``y_values``.

    NaN/inf samples (e.g. ``np.log(x)`` for x <= 0) are ignored. The top is
    clipped to ``cap`` only when the cap lies inside the data range, so the
    axis is never inverted.

    Raises:
        ValueError: if no sample is finite.
    """
    y = np.asarray(y_values, dtype=float)
    finite = y[np.isfinite(y)]
    if finite.size == 0:
        raise ValueError("The function has no finite values on the plotted range.")
    low = float(finite.min()) - padding
    high = float(finite.max()) + padding
    if low < cap < high:
        high = cap
    return [low, high]


# Callback to reset session state
def reset_state():
    """Restart the descent from the current starting point.

    Used as the ``on_change`` callback of the input widgets and by the
    "Set Up Function" button. If f cannot be evaluated at the starting
    point, f(x0) is recorded as NaN instead of raising, so an invalid
    expression no longer crashes the callback; the visualization section
    reports the evaluation error.
    """
    start = st.session_state.starting_point
    try:
        y_start = safe_eval(st.session_state.func_input, start)
    except ValueError:
        y_start = float("nan")
    # All four fields are written together so x_vals / y_vals stay aligned.
    st.session_state.x = start
    st.session_state.iteration = 0
    st.session_state.x_vals = [start]
    st.session_state.y_vals = [y_start]


# Initialize session state variables
if "func_input" not in st.session_state:
    st.session_state.func_input = DEFAULT_FUNCTION
if "x" not in st.session_state:
    st.session_state.x = DEFAULT_START
    st.session_state.iteration = 0
    st.session_state.x_vals = [DEFAULT_START]
    st.session_state.y_vals = [safe_eval(st.session_state.func_input, DEFAULT_START)]

# NOTE(review): this style block is empty, so it has no effect; the custom
# CSS (borders, font, padding, border colour) it was meant to inject is absent.
st.markdown(
    """
""",
    unsafe_allow_html=True,
)

# Page Layout
st.title("🌟 Gradient Descent Visualization Tool 🌟")
col1, col2 = st.columns([1, 2])

# Left Section: User Input
with col1:
    st.subheader("🔧 Define Your Function")
    # Input instructions. Expressions are code spans so that "x**n" is not
    # parsed as Markdown bold.
    st.markdown(
        """
**How to input your function** (Python/NumPy syntax):
- x^n as `x**n`
- sin(x) as `np.sin(x)`
- log(x) as `np.log(x)`
- e^x or exp(x) as `np.exp(x)`
""",
        unsafe_allow_html=True,
    )
    # No value= argument: the widget's value lives in st.session_state.func_input,
    # which is initialised above.
    func_input = st.text_input("👇", key="func_input", on_change=reset_state)

    st.subheader("⚙️ Gradient Descent Parameters")
    starting_point = st.number_input(
        "Starting Point (X₀)",
        value=DEFAULT_START,
        step=0.1,
        format="%.2f",
        key="starting_point",
        on_change=reset_state,
    )
    learning_rate = st.number_input(
        "Learning Rate (ŋ)",
        value=0.25,
        step=0.01,
        format="%.2f",
        key="learning_rate",
        on_change=reset_state,
    )

    col3, col4 = st.columns(2)
    with col3:
        if st.button("🔄 Set Up Function"):
            reset_state()
    with col4:
        if st.button("▶️ Next Iteration"):
            try:
                x_next, y_next = _gradient_step(
                    st.session_state.func_input, st.session_state.x, learning_rate
                )
            except Exception as e:
                st.error(f"⚠️ Error: {e}")
            else:
                # Commit only after every value was computed successfully.
                st.session_state.x = x_next
                st.session_state.iteration += 1
                st.session_state.x_vals.append(x_next)
                st.session_state.y_vals.append(y_next)

# Right Section: Visualization
with col2:
    st.subheader("📊 Gradient Descent Visualization")
    try:
        func_str = st.session_state.func_input
        x_min, x_max = PLOT_X_RANGE

        # Evaluate the function point by point over the plotted x range.
        x_plot = np.linspace(x_min, x_max, 400)
        y_plot = [safe_eval(func_str, xv) for xv in x_plot]
        y_range = _y_axis_range(y_plot)

        fig = go.Figure()

        # Function curve
        fig.add_trace(go.Scatter(
            x=x_plot,
            y=y_plot,
            mode="lines+markers",
            line=dict(color="blue", width=2),
            marker=dict(size=4, color="blue", symbol="circle"),
            name="Function",
        ))

        # All gradient descent points so far (no coordinate labels)
        fig.add_trace(go.Scatter(
            x=st.session_state.x_vals,
            y=st.session_state.y_vals,
            mode="markers",
            marker=dict(color="red", size=10),
            name="Gradient Descent Points",
        ))

        # Tangent line at the current point, drawn across the whole x range
        tangent_x = np.linspace(x_min, x_max, 200)
        tangent_y = tangent_line(func_str, st.session_state.x, tangent_x)
        fig.add_trace(go.Scatter(
            x=tangent_x,
            y=tangent_y,
            mode="lines",
            line=dict(color="orange", width=3),
            name="Tangent Line",
        ))

        axis_style = dict(
            showline=True,
            linecolor="white",
            tickcolor="white",
            tickfont=dict(color="white"),
            ticks="outside",
        )
        fig.update_layout(
            xaxis=dict(title="x-axis", range=[x_min, x_max], **axis_style),
            # y range follows the finite function values, capped for readability
            yaxis=dict(title="y-axis", range=y_range, **axis_style),
            plot_bgcolor="black",
            paper_bgcolor="black",
            title="",
            margin=dict(l=10, r=10, t=10, b=10),
            width=800,
            height=400,
            showlegend=True,
            legend=dict(
                x=1.1,
                y=0.5,
                xanchor="left",
                yanchor="middle",
                orientation="v",
                font=dict(size=12, color="white"),
                bgcolor="black",
                bordercolor="white",
                borderwidth=2,
            ),
        )

        # Axis lines through the origin, spanning the visible ranges
        fig.add_shape(type="line", x0=x_min, x1=x_max, y0=0, y1=0,
                      line=dict(color="white", width=2))  # x-axis
        fig.add_shape(type="line", x0=0, x1=0, y0=y_range[0], y1=y_range[1],
                      line=dict(color="white", width=2))  # y-axis

        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"⚠️ Error in visualization: {e}")

# Iteration stats
col5, col6 = st.columns(2)
col5.info(f"🧑‍💻 Iteration: {st.session_state.iteration}")
col6.success(f"✅ Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")