Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import plotly.graph_objects as go | |
# Title of the app: configure the browser tab title and a wide layout, then
# render the centered page heading.
st.set_page_config(layout="wide", page_title="Interactive Gradient Descent Visualizer")
st.markdown(
    "<h1 style='text-align: center; color: #00FA9A;'>โจ Gradient Descent Visualizer โจ</h1>",
    unsafe_allow_html=True,
)
# Custom CSS for enhanced UI: page gradient, gradient buttons with a hover
# swap, and zero padding on the main block container.
# NOTE(review): no element in this file carries the "iteration-controls"
# class, so that rule appears to have no effect.
_CUSTOM_CSS = """
<style>
body {
background: linear-gradient(to right, #141E30, #243B55);
color: #E0FFFF;
}
.stButton>button {
background: linear-gradient(to right, #00C6FF, #0072FF);
color: white;
border: none;
border-radius: 10px;
padding: 10px 15px;
font-size: 16px;
font-weight: bold;
}
.stButton>button:hover {
background: linear-gradient(to right, #0072FF, #00C6FF);
}
.iteration-controls button {
width: 100%;
margin: 5px 0;
}
.block-container {
padding: 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Restricted function evaluation
def evaluate_function(expression, x_value):
    """Evaluate a user-supplied expression in the single variable ``x``.

    Only ``x`` (bound to ``x_value``) and ``np`` (NumPy) are in scope; Python
    builtins are removed, so write ``np.abs(x)`` rather than ``abs(x)``.

    Args:
        expression: Expression text such as ``"x**2 + x"`` or ``"np.sin(x)"``.
        x_value: Value substituted for ``x``.

    Returns:
        Whatever the expression evaluates to (normally a number).

    Raises:
        ValueError: if the expression contains a double underscore.
        Any exception raised by evaluating the expression itself
        (SyntaxError, NameError, ZeroDivisionError, ...) propagates.
    """
    # SECURITY: ``expression`` comes straight from a text box, so this is eval
    # on untrusted input. Rejecting "__" blocks the usual
    # ``().__class__.__subclasses__()`` escape, but ``np`` still exposes file
    # I/O (e.g. np.save / np.loadtxt), so this is NOT a sandbox. A real fix
    # would parse with ``ast`` and allow only arithmetic nodes plus a
    # whitelist of np functions.
    if "__" in expression:
        raise ValueError("Double underscores are not allowed in the function expression.")
    allowed_names = {"x": x_value, "np": np}  # Allow only x and numpy
    # The key must be exactly "__builtins__": if it is absent (e.g. misspelled),
    # eval inserts the real builtins module into the globals.
    return eval(expression, {"__builtins__": {}}, allowed_names)
# Numerical derivative via a central finite difference
def compute_derivative(expression, x_value, h=1e-5):
    """Approximate f'(x_value) with a symmetric difference quotient.

    Evaluates ``expression`` at ``x_value + h`` and then at ``x_value - h``
    and returns ``(f(x + h) - f(x - h)) / (2 * h)``.
    """
    ahead = evaluate_function(expression, x_value + h)
    behind = evaluate_function(expression, x_value - h)
    span = 2 * h
    return (ahead - behind) / span
# Tangent line at a point, sampled over a range
def calculate_tangent(expression, x_value, x_range):
    """Return the tangent to f at ``x_value``, evaluated over ``x_range``.

    Uses the point-slope form ``y = f'(x0) * (x - x0) + f(x0)``. ``x_range``
    must support elementwise subtraction (e.g. a NumPy array).
    """
    anchor_y = evaluate_function(expression, x_value)
    gradient_at_anchor = compute_derivative(expression, x_value)
    offsets = x_range - x_value
    return gradient_at_anchor * offsets + anchor_y
# Reset state
def reset_session_state():
    """Restart the descent from the currently configured initial point.

    Registered as the on_change callback of the function and initial-point
    widgets, and called by the Reset button. Reads ``math_function`` and
    ``initial_point`` from session state and rewrites ``x_current``,
    ``iter_count``, ``history`` and ``current_index``.

    f(initial_point) is evaluated before any state is modified. An invalid
    expression therefore raises without leaving a half-reset session (x and
    counter rewound while the old history and selected index remain).
    """
    start = st.session_state.initial_point
    start_value = evaluate_function(st.session_state.math_function, start)
    st.session_state.x_current = start
    st.session_state.iter_count = 0
    st.session_state.history = [(start, start_value)]
    st.session_state.current_index = 0
# Initialize session state variables on the first run of a session.
# The seed values mirror the left-column widget defaults: "x**2 + x" for the
# function and 4.0 for the initial point. Seeding at 0.0 made the first plot
# and the first descent step start from x = 0 while the input box showed 4.00.
_DEFAULT_FUNCTION = "x**2 + x"  # keep in sync with the text_input default
_DEFAULT_INITIAL_POINT = 4.0  # keep in sync with the "Initial Value of x" default
if "x_current" not in st.session_state:
    st.session_state.x_current = _DEFAULT_INITIAL_POINT
if "iter_count" not in st.session_state:
    st.session_state.iter_count = 0
if "history" not in st.session_state:
    # history holds (x, f(x)) pairs; entry 0 is the starting point.
    st.session_state.history = [
        (_DEFAULT_INITIAL_POINT, evaluate_function(_DEFAULT_FUNCTION, _DEFAULT_INITIAL_POINT))
    ]
if "current_index" not in st.session_state:
    st.session_state.current_index = 0
if "learning_rate" not in st.session_state:
    st.session_state.learning_rate = 0.1
# Create a two-column layout with equal widths
left_col, right_col = st.columns(2)

# Left side content: function input, parameters, and descent controls.
with left_col:
    st.markdown("### ๐งฎ Input Your Function")
    # Editing the expression restarts the descent.
    function_input = st.text_input(
        "Enter Function: Example: `x**2`, `np.sin(x)`",
        "x**2 + x",
        key="math_function",
        on_change=reset_session_state,
    )

    st.markdown("### โ๏ธ Set Parameters")
    # Editing the starting point also restarts the descent.
    initial_point = st.number_input(
        "๐ข Initial Value of x",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="initial_point",
        on_change=reset_session_state,
    )
    # The learning rate is seeded in st.session_state["learning_rate"] and the
    # widget is bound to that key, so no value= is passed. Supplying both a
    # default value and a Session State value for the same key makes Streamlit
    # display a duplicate-value warning. Changing it does not reset the descent.
    st.number_input(
        "๐ Learning Rate",
        step=0.01,
        format="%.2f",
        key="learning_rate",
    )

    st.markdown("### ๐๏ธ Controls")
    if st.button("๐ Run Descent Step"):
        try:
            # Compute everything first; session state is only touched on success.
            gradient = compute_derivative(function_input, st.session_state.x_current)
            next_x = st.session_state.x_current - st.session_state.learning_rate * gradient
            next_y = evaluate_function(function_input, next_x)
            is_finite = bool(np.isfinite(next_x) and np.isfinite(next_y))
        except Exception as e:
            st.error(f"Error: {str(e)}")
        else:
            if is_finite:
                # Commit all state together so iter_count can never get ahead
                # of history (which the iteration browser indexes into).
                st.session_state.x_current = next_x
                st.session_state.iter_count += 1
                st.session_state.history.append((next_x, next_y))
                st.session_state.current_index = st.session_state.iter_count
            else:
                # A nan/inf x would make every later step nan as well.
                st.error(
                    "Error: the step produced a non-finite value; "
                    "try a smaller learning rate or press Reset."
                )
    if st.button("๐ Reset"):
        reset_session_state()
# Right side content: iteration browser and plot.
with right_col:
    st.markdown("### ๐ Gradient Descent Visualization")

    history = st.session_state.history
    last_index = len(history) - 1

    # Iteration control buttons. Both buttons are read before the iteration
    # label is drawn. col3 ("Next") comes after col2 in script order, so
    # updating the index inside col3 left the label one step behind the plot.
    col1, col2, col3 = st.columns([1, 1, 1])
    with col1:
        prev_clicked = st.button("โฎ๏ธ Previous")
    with col3:
        next_clicked = st.button("โญ๏ธ Next")
    if prev_clicked and st.session_state.current_index > 0:
        st.session_state.current_index -= 1
    # Bounded by the last stored point so the index cannot run past history.
    if next_clicked and st.session_state.current_index < last_index:
        st.session_state.current_index += 1

    if not history:
        st.warning("No iteration data available. Please run a descent step first.")
        st.stop()
    # Keep the index valid even if earlier state left it out of range.
    st.session_state.current_index = min(max(st.session_state.current_index, 0), last_index)

    with col2:
        st.markdown(
            f"<p style='text-align: center;'>Iteration: <strong>{st.session_state.current_index}</strong></p>",
            unsafe_allow_html=True,
        )

    # The selected point drives both the text readout and the plot.
    x_current, y_current = history[st.session_state.current_index]
    st.markdown(f"๐งพ **x Value:** `{x_current:.4f}`")
    st.markdown(f"๐ **f(x):** `{y_current:.4f}`")

    # Prepare data for visualization. An invalid expression is reported in
    # the UI instead of surfacing as an uncaught traceback.
    x_range = np.linspace(-10, 10, 500)
    try:
        y_range = [evaluate_function(function_input, x) for x in x_range]
        tangent_y = calculate_tangent(function_input, x_current, x_range)
    except Exception as e:
        st.error(f"Error: {str(e)}")
        st.stop()

    # Plot function curve
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=x_range,
        y=y_range,
        mode='lines',
        name='Function',
        line=dict(color='blue'),  # Blue color for curve
    ))
    # Add current point
    fig.add_trace(go.Scatter(
        x=[x_current],
        y=[y_current],
        mode='markers',
        name='Current Point',
        marker=dict(size=12, color='red'),  # Red for current point
    ))
    # Add tangent line
    fig.add_trace(go.Scatter(
        x=x_range,
        y=tangent_y,
        mode='lines',
        name='Tangent Line',
        line=dict(dash='dash', color='yellow'),  # Yellow dashed line for tangent
    ))
    # Layout adjustments
    fig.update_layout(
        title="Gradient Descent Progress ๐",
        xaxis_title="x",
        yaxis_title="f(x)",
        template="plotly_dark",
        height=600,
    )
    st.plotly_chart(fig, use_container_width=True)