Spaces:
Build error
Build error
| import streamlit as st | |
| import numpy as np | |
| import plotly.graph_objects as go | |
# Restricted evaluation of the user-supplied expression
def safe_eval(func_str, x_val):
    """Evaluate the expression ``func_str`` with ``x`` bound to ``x_val``.

    Only the names ``x`` and ``np`` (NumPy) are visible to the expression and
    Python builtins are disabled.

    SECURITY NOTE: ``eval`` is not a sandbox. Removing builtins alone does not
    stop attribute-walking escapes such as ``().__class__.__bases__``, so any
    double-underscore name is rejected up front. The exposed ``np`` module
    also reaches file I/O helpers (e.g. ``np.save`` / ``np.fromfile``).
    ``func_str`` comes straight from a public text box, so treat this as a
    mitigation only; a parser-based evaluator would be the robust fix.

    Args:
        func_str: Python expression in ``x``, e.g. ``"x**2 + np.sin(x)"``.
        x_val: value substituted for ``x`` (a float, NumPy scalar or array).

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: if the expression is rejected or fails to evaluate; the
            underlying exception is chained as ``__cause__``.
    """
    if isinstance(func_str, str) and "__" in func_str:
        raise ValueError(
            "Error evaluating the function: double-underscore names are not allowed"
        )
    allowed_names = {"x": x_val, "np": np}
    try:
        # An empty dict (rather than None) cleanly disables builtins: using
        # one raises NameError instead of an obscure TypeError.
        return eval(func_str, {"__builtins__": {}}, allowed_names)
    except Exception as e:
        raise ValueError(f"Error evaluating the function: {e}") from e
# Central finite-difference derivative
def derivative(func_str, x_val, h=1e-5):
    """Approximate f'(x_val) for the expression ``func_str``.

    Uses the symmetric difference quotient (f(x + h) - f(x - h)) / (2h).
    Any evaluation failure propagates as the ``ValueError`` raised by
    ``safe_eval``.
    """
    ahead = safe_eval(func_str, x_val + h)
    behind = safe_eval(func_str, x_val - h)
    rise = ahead - behind
    run = 2 * h
    return rise / run
# Tangent line at a point
def tangent_line(func_str, x_val, x_range):
    """Return the tangent to ``func_str`` at ``x_val``, sampled on ``x_range``.

    Evaluates the point-slope form y = f(a) + f'(a) * (x - a). In this app
    ``x_range`` is a NumPy array, so the result is an array of the same shape.
    """
    anchor_y = safe_eval(func_str, x_val)
    gradient = derivative(func_str, x_val)
    offsets = x_range - x_val
    return anchor_y + gradient * offsets
# Callback to reset session state
def reset_state():
    """Restart gradient descent from the current starting point.

    Used as the ``on_change`` callback of the input widgets and by the
    "Set Up Function" button. Resets ``x``, the iteration counter and the
    recorded path (``x_vals`` / ``y_vals``) in ``st.session_state``.

    f(x0) is evaluated *before* any state is written. If it cannot be
    evaluated (empty or half-typed expression), NaN is recorded instead of
    raising. Raising would escape the callback as an uncaught error after
    ``x``/``x_vals`` had already been overwritten, leaving ``y_vals`` stale.
    The plotting and iteration code still report the bad expression to the
    user through their own error handling.
    """
    start = st.session_state.starting_point
    try:
        start_y = safe_eval(st.session_state.func_input, start)
    except ValueError:
        start_y = float("nan")  # placeholder until a valid function is entered
    st.session_state.x = start
    st.session_state.iteration = 0
    st.session_state.x_vals = [start]
    st.session_state.y_vals = [start_y]
# Full-width layout. Configured before anything is drawn, since Streamlit has
# historically required set_page_config to be the first Streamlit command.
st.set_page_config(layout="wide")

# First run of a browser session: seed the default function. The text input
# keyed "func_input" takes its initial value from here (it has no value=).
if "func_input" not in st.session_state:
    st.session_state.func_input = "x**2 + x"

# First run: start the descent at 4.0, the Starting Point widget's default.
if "x" not in st.session_state:
    initial_x = 4.0
    st.session_state.x = initial_x
    st.session_state.iteration = 0
    st.session_state.x_vals = [initial_x]
    st.session_state.y_vals = [safe_eval(st.session_state.func_input, initial_x)]
# Global CSS injected once per run: app-wide font, bordered inputs/buttons,
# reduced padding, and the hover tooltip used by the function-input label.
# NOTE(review): ".css-1d391kg" and ".css-1lcbvhc" look like auto-generated
# class names that change between Streamlit releases, so those rules may
# silently stop matching. Confirm against the deployed Streamlit version.
_CUSTOM_CSS = """
<style>
* {
font-family: Cambria, Arial, sans-serif !important;
}
h1, h2, h3, h4, h5 {
text-align: center;
margin-top: 0;
}
input, .stButton button, .stDownloadButton button {
border: 2px solid #ea445a;
border-radius: 5px;
padding: 10px;
}
.stInfo, .stSuccess {
border: 2px solid #ea445a;
border-radius: 5px;
padding: 10px;
}
.stButton {
margin-top: 10px;
}
/* Reduced Padding at the top */
.css-1d391kg {
padding-top: 0.5rem;
}
/* Centering the legend in the plot */
.stPlotlyChart {
display: block;
margin: 0 auto;
}
/* Adjusting for full width without scrolling */
.css-1lcbvhc {
padding-left: 0;
padding-right: 0;
}
/* Custom borders for input fields */
.stTextInput input, .stNumberInput input {
border: 2px solid #001A6E;
border-radius: 5px;
padding: 10px;
}
/* Tooltip styling */
.tooltip {
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
opacity: 0;
width: 300px;
background-color: #001A6E;
color: #fff;
text-align: center;
border-radius: 5px;
padding: 5px;
position: absolute;
z-index: 1;
bottom: 125%; /* Position the tooltip above */
left: 50%;
margin-left: -150px;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page header, then a 1:2 split: controls on the left, chart on the right.
st.title("π Gradient Descent Interactive Tool π")
column_widths = [1, 2]
col1, col2 = st.columns(column_widths)
# Left Section: function definition, descent parameters and step controls.
# NOTE(review): the emoji/Greek characters in these labels look mis-encoded
# in this copy (e.g. "Xβ" was probably "X₀"). Confirm against the original.
with col1:
    st.subheader("π§ Define Your Function")

    # Hover tooltip with input instructions. This is a raw-HTML block, so
    # Markdown syntax (**bold**, "- " bullets) is not rendered inside it.
    # Bold text and line breaks are therefore written as HTML tags.
    st.markdown(
        """
<div class="tooltip">
<label for="func_input">Enter a function of 'x':</label>
<span class="tooltiptext">
<b>How to input your function:</b><br>
Please give the inputs as mentioned below:<br>
&bull; x^n as x**n<br>
&bull; sin(x) as np.sin(x)<br>
&bull; log(x) as np.log(x)<br>
&bull; e^x or exp(x) as np.exp(x)
</span>
</div>
""",
        unsafe_allow_html=True,
    )

    # No value= argument: the widget's initial value comes from
    # st.session_state["func_input"], which is seeded at start-up.
    func_input = st.text_input(
        "π",
        key="func_input",
        on_change=reset_state,
    )

    st.subheader("βοΈ Gradient Descent Parameters")
    starting_point = st.number_input(
        "Starting Point (Xβ)",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="starting_point",
        on_change=reset_state,
    )
    learning_rate = st.number_input(
        "Learning Rate (Ε)",
        value=0.25,
        step=0.01,
        format="%.2f",
        key="learning_rate",
        on_change=reset_state,
    )

    col3, col4 = st.columns(2)
    with col3:
        if st.button("π Set Up Function"):
            # An invalid expression must not crash the whole script run.
            try:
                reset_state()
            except ValueError as e:
                st.error(f"β οΈ Error: {str(e)}")
    with col4:
        if st.button("βΆοΈ Next Iteration"):
            try:
                # Compute the whole step before touching session state, so a
                # failure cannot leave x_vals and y_vals with different lengths.
                grad = derivative(st.session_state.func_input, st.session_state.x)
                new_x = st.session_state.x - learning_rate * grad
                new_y = safe_eval(st.session_state.func_input, new_x)
            except Exception as e:
                st.error(f"β οΈ Error: {str(e)}")
            else:
                st.session_state.x = new_x
                st.session_state.iteration += 1
                st.session_state.x_vals.append(new_x)
                st.session_state.y_vals.append(new_y)
# Right Section: Visualization
with col2:
    st.subheader("π Gradient Descent Visualization")
    try:
        # Sample the function on the fixed window [-10, 10]. Points are
        # evaluated one at a time so non-vectorised expressions still work.
        # The values become a float array so NaN/inf can be masked below.
        x_plot = np.linspace(-10, 10, 400)
        y_plot = np.array(
            [safe_eval(st.session_state.func_input, x) for x in x_plot],
            dtype=float,
        )

        # y-range: pad the finite values by 5 and cap the top at 1000 so that
        # steep functions don't flatten the interesting region. NaN/inf values
        # (e.g. np.log of negatives) are ignored; plain min()/max() over them
        # gives order-dependent or NaN bounds. The cap is skipped when the
        # whole curve lies above it, which would otherwise produce an
        # inverted, empty range.
        finite_y = y_plot[np.isfinite(y_plot)]
        if finite_y.size:
            y_low = float(finite_y.min()) - 5
            y_high = float(finite_y.max()) + 5
            if y_low < 1000:
                y_high = min(y_high, 1000)
        else:
            y_low, y_high = -10.0, 10.0

        fig = go.Figure()

        # Function curve
        fig.add_trace(go.Scatter(
            x=x_plot,
            y=y_plot,
            mode="lines+markers",
            line=dict(color="blue", width=2),
            marker=dict(size=4, color="blue", symbol="circle"),
            name="Function",
        ))

        # Every gradient descent iterate so far (markers only, no labels)
        fig.add_trace(go.Scatter(
            x=st.session_state.x_vals,
            y=st.session_state.y_vals,
            mode="markers",
            marker=dict(color="red", size=10),
            name="Gradient Descent Points",
        ))

        # Tangent line at the current iterate, drawn across the whole window
        current_x = st.session_state.x
        tangent_x = np.linspace(-10, 10, 200)
        tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
        fig.add_trace(go.Scatter(
            x=tangent_x,
            y=tangent_y,
            mode="lines",
            line=dict(color="orange", width=3),
            name="Tangent Line",
        ))

        fig.update_layout(
            xaxis=dict(
                title="x-axis",
                range=[-10, 10],
                showline=True,
                linecolor="white",
                tickcolor="white",
                tickfont=dict(color="white"),
                ticks="outside",
            ),
            yaxis=dict(
                title="y-axis",
                range=[y_low, y_high],
                showline=True,
                linecolor="white",
                tickcolor="white",
                tickfont=dict(color="white"),
                ticks="outside",
            ),
            plot_bgcolor="black",
            paper_bgcolor="black",
            title="",
            margin=dict(l=10, r=10, t=10, b=10),
            width=800,
            height=400,
            showlegend=True,
            legend=dict(
                x=1.1,
                y=0.5,
                xanchor="left",
                yanchor="middle",
                orientation="v",
                font=dict(size=12, color="white"),
                bgcolor="black",
                bordercolor="white",
                borderwidth=2,
            ),
        )

        # Axis lines through the origin. add_hline/add_vline span the whole
        # plot area, whatever y-range was chosen above.
        fig.add_hline(y=0, line=dict(color="white", width=2))  # x-axis
        fig.add_vline(x=0, line=dict(color="white", width=2))  # y-axis

        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"β οΈ Error in visualization: {str(e)}")
# Iteration stats: step counter on the left, current x and f(x) on the right.
col5, col6 = st.columns(2)
iteration_msg = f"π§βπ» Iteration: {st.session_state.iteration}"
col5.info(iteration_msg)
position_msg = (
    f"β Current x: {st.session_state.x:.4f}, "
    f"Current f(x): {st.session_state.y_vals[-1]:.4f}"
)
col6.success(position_msg)