Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import plotly.graph_objects as go | |
# --- Page configuration, title banner and theme overrides -------------------
# st.set_page_config has to be the first Streamlit command of the script, so
# everything here is plain constants until that call.
# NOTE(review): the emoji in the banner looks mis-encoded (mojibake); restore
# the intended character.
_PAGE_TITLE = "Interactive Gradient Descent Visualizer"
_BANNER_HTML = "<h1 style='text-align: center; color: #FFD700;'> π Gradient Descent Visualizer</h1>"

# Custom CSS for background, headings and button colors.
# NOTE(review): Streamlit paints its own app container, so the `body`
# background rule may have no visible effect; `.custom-text` is not used by
# any markup visible in this file. Confirm both against the deployed version.
_THEME_CSS = """
<style>
body {
background-color: #121212; /* Dark gray background for modern look */
color: white; /* White text for contrast */
}
.stButton>button {
background: linear-gradient(45deg, #FF7F50, #FF4500); /* Coral to OrangeRed gradient */
color: white; /* White button text */
border: none;
border-radius: 8px;
padding: 10px 20px;
font-size: 16px;
font-weight: bold;
transition: transform 0.2s ease, box-shadow 0.3s ease, filter 0.3s ease; /* Smooth hover effects */
}
.stButton>button:hover {
transform: scale(1.1); /* Slight zoom effect on hover */
box-shadow: 0 0 20px 10px rgba(255, 69, 0, 0.8); /* Glowing shadow effect */
background: linear-gradient(45deg, #FF4500, #FF7F50); /* Reverse gradient */
filter: brightness(1.2); /* Slightly brightens the button */
}
h1, h2, h3 {
color: #00FFFF; /* Aqua for headings */
}
.custom-text {
color: #FFD700; /* Gold for highlighted text */
font-weight: bold;
}
</style>
"""

st.set_page_config(page_title=_PAGE_TITLE, layout="wide")
st.markdown(_BANNER_HTML, unsafe_allow_html=True)
st.markdown(_THEME_CSS, unsafe_allow_html=True)
# Restricted evaluation of user-entered expressions (restricted, NOT sandboxed).
# Builtins that expressions may call. Everything else (``__import__``,
# ``open``, ``exec``, ...) is hidden from ``eval``.
_SAFE_BUILTINS = {"abs": abs, "min": min, "max": max, "pow": pow, "round": round}


def evaluate_function(expression, x_value):
    """Evaluate the user expression ``expression`` with ``x`` bound to ``x_value``.

    The expression can reference only ``x``, ``np`` (NumPy) and the pure
    numeric builtins ``abs``, ``min``, ``max``, ``pow`` and ``round``.

    Args:
        expression: Python expression text, e.g. ``"x**2 + x"`` or ``"np.sin(x)"``.
        x_value: Value bound to the name ``x``.

    Returns:
        Whatever the expression evaluates to (normally a number).

    Raises:
        Any exception raised while compiling or evaluating ``expression``
        (``SyntaxError``, ``NameError``, ``TypeError``, ...) propagates.
    """
    # The globals key must be spelled exactly ``__builtins__``. If it is
    # missing (e.g. misspelled ``_builtins_``), eval() silently inserts the
    # real builtins module and ``__import__('os')`` becomes callable.
    # SECURITY NOTE: hiding builtins is not a sandbox. Dunder attribute
    # traversal and NumPy's own file I/O helpers remain reachable. Do not
    # expose this to untrusted users without a real whitelist (e.g. an ``ast``
    # walk).
    allowed_names = {"x": x_value, "np": np}
    return eval(expression, {"__builtins__": _SAFE_BUILTINS}, allowed_names)
# Numerical derivative via a symmetric finite difference
def compute_derivative(expression, x_value, h=1e-5):
    """Estimate the slope of ``expression`` at ``x_value``.

    Uses the central difference (f(x + h) - f(x - h)) / (2h). For smooth
    functions its truncation error shrinks like h**2.

    Args:
        expression: Expression text accepted by ``evaluate_function``.
        x_value: Point at which to differentiate.
        h: Half-width of the symmetric step.

    Returns:
        The finite-difference estimate of f'(x_value).
    """
    f_plus = evaluate_function(expression, x_value + h)
    f_minus = evaluate_function(expression, x_value - h)
    return (f_plus - f_minus) / (2 * h)
# Tangent line through the current point
def calculate_tangent(expression, x_value, x_range):
    """Return the tangent line at ``x_value`` evaluated over ``x_range``.

    The line passes through (x_value, f(x_value)) and uses the numerical
    slope from ``compute_derivative``. ``x_range`` is expected to support
    element-wise arithmetic (e.g. a NumPy array).
    """
    anchor_y = evaluate_function(expression, x_value)
    slope = compute_derivative(expression, x_value)
    run = x_range - x_value
    return anchor_y + slope * run
# Restart the descent from scratch
def reset_session_state():
    """Rewind the descent to the user's chosen initial point.

    Wired as the ``on_change`` callback of the function and initial-point
    widgets and called by the Reset button. Sets ``x_current`` back to the
    initial point, zeroes the step counter, replaces the history with the
    single pair ``(x0, f(x0))`` and points the iteration browser at step 0.
    """
    state = st.session_state
    start = state.initial_point
    state.x_current = start
    state.iter_count = 0
    state.history = [(start, evaluate_function(state.math_function, start))]
    state.current_index = 0
# Start-up values. These must agree with the defaults given to the
# "math_function", "initial_point" and "learning_rate" widgets. Otherwise the
# first descent step starts from a point the UI does not display.
DEFAULT_FUNCTION = "x**2 + x"
DEFAULT_INITIAL_POINT = 4.0
DEFAULT_LEARNING_RATE = 0.1

# Initialize session state variables (only when missing, i.e. on a session's
# first run).
if "x_current" not in st.session_state:
    st.session_state.x_current = DEFAULT_INITIAL_POINT
if "iter_count" not in st.session_state:
    st.session_state.iter_count = 0
if "history" not in st.session_state:
    # history[i] holds the (x, f(x)) pair reached after i descent steps.
    st.session_state.history = [
        (DEFAULT_INITIAL_POINT, evaluate_function(DEFAULT_FUNCTION, DEFAULT_INITIAL_POINT))
    ]
if "current_index" not in st.session_state:
    st.session_state.current_index = 0
if "learning_rate" not in st.session_state:
    st.session_state.learning_rate = DEFAULT_LEARNING_RATE
# Two-column layout: controls on the left, a wider plot area on the right.
left_col, right_col = st.columns([1, 2])  # 1:2 width ratio

# Left side: function input, gradient-descent parameters and controls.
with left_col:
    st.markdown("<h3 style='color: #7FFF00;'>Input Your Function</h3>", unsafe_allow_html=True)
    # Changing the function restarts the descent from the initial point.
    function_input = st.text_input(
        "Enter Function: Ex: `x**2`, `np.sin(x)`",
        "x**2 + x",
        key="math_function",
        on_change=reset_session_state,
    )

    st.markdown("<h3 style='color: #FF69B4;'>Set Parameters</h3>", unsafe_allow_html=True)
    # Changing the starting point also restarts the descent.
    initial_point = st.number_input(
        "Initial Value of x",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="initial_point",
        on_change=reset_session_state,
    )
    # The learning rate is already seeded in st.session_state at start-up, so
    # no `value=` is passed. Supplying both a default and a Session State value
    # for the same key makes Streamlit emit a "created with a default value but
    # also had its value set via the Session State API" warning.
    # Changing it does not reset the descent.
    st.number_input(
        "Learning Rate",
        step=0.01,
        format="%.2f",
        key="learning_rate",
    )

    st.markdown("<h3 style='color: #1E90FF;'>Controls</h3>", unsafe_allow_html=True)
    # NOTE(review): the emoji in these button labels look mis-encoded
    # (mojibake); restore the intended characters.
    if st.button("π Run Descent Step", type="primary"):
        try:
            x_old = st.session_state.x_current
            gradient = compute_derivative(function_input, x_old)
            x_new = x_old - st.session_state.learning_rate * gradient
            y_new = evaluate_function(function_input, x_new)
        except Exception as e:  # UI boundary: report any expression error
            st.error(f"Error: {str(e)}")
        else:
            # Commit only after both evaluations succeeded, so the history
            # always holds iter_count + 1 entries and stays indexable.
            st.session_state.x_current = x_new
            st.session_state.iter_count += 1
            st.session_state.history.append((x_new, y_new))
            st.session_state.current_index = st.session_state.iter_count
    if st.button("π Reset"):
        reset_session_state()
# Right side: iteration browser, iteration details and the plot.
with right_col:
    st.markdown("<h3 style='color: #FF6347;'>Gradient Descent Visualization</h3>", unsafe_allow_html=True)

    # Iteration browser. Both buttons are handled BEFORE the middle-column
    # label is written. If the label were written between the two buttons, a
    # "Next" click would show the old index while the details and plot below
    # already show the next iteration.
    # NOTE(review): the emoji in these labels look mis-encoded (mojibake);
    # restore the intended characters.
    col1, col2, col3 = st.columns(3)
    prev_clicked = col1.button("β¬ οΈ Previous Iteration")
    next_clicked = col3.button("β‘οΈ Next Iteration")
    if prev_clicked and st.session_state.current_index > 0:
        st.session_state.current_index -= 1
    if next_clicked and st.session_state.current_index < st.session_state.iter_count:
        st.session_state.current_index += 1
    col2.markdown(f"**Iteration:** {st.session_state.current_index}", unsafe_allow_html=True)

    # Details of the selected iteration. Without a history entry there is
    # nothing to plot either, so end this run instead of failing further down.
    try:
        selected_x, selected_y = st.session_state.history[st.session_state.current_index]
    except IndexError:
        st.warning("No iteration data available. Please run a descent step first.")
        st.stop()
    st.markdown(f"x Value: <span style='color: #FFD700;'>{selected_x:.4f}</span>", unsafe_allow_html=True)
    st.markdown(f"f(x): <span style='color: #FFD700;'>{selected_y:.4f}</span>", unsafe_allow_html=True)

    # Sample the curve and the tangent on a fixed x window. The curve is
    # evaluated point by point; a single vectorized call would fail for
    # expressions that only accept scalars. An invalid expression is reported
    # instead of crashing the whole page.
    x_range = np.linspace(-10, 10, 500)
    try:
        y_range = [evaluate_function(function_input, x) for x in x_range]
        tangent_y = calculate_tangent(function_input, selected_x, x_range)
    except Exception as e:  # UI boundary: report any expression error
        st.error(f"Error: {str(e)}")
        st.stop()

    fig = go.Figure()
    # Function curve (orange).
    fig.add_trace(go.Scatter(
        x=x_range,
        y=y_range,
        mode='lines',
        name='Function',
        line=dict(color='orange'),
    ))
    # Point of the selected iteration (red marker).
    fig.add_trace(go.Scatter(
        x=[selected_x],
        y=[selected_y],
        mode='markers',
        name='Current Point',
        marker=dict(size=10, color='red'),
    ))
    # Tangent at the selected point (dashed blue).
    fig.add_trace(go.Scatter(
        x=x_range,
        y=tangent_y,
        mode='lines',
        name='Tangent Line',
        line=dict(dash='dash', color='blue'),
    ))
    fig.update_layout(
        title="Gradient Descent Progress",
        xaxis_title="x",
        yaxis_title="f(x)",
        template="plotly_white",
        height=600,
    )
    st.plotly_chart(fig, use_container_width=True)