Spaces:
Build error
Build error
| import streamlit as st | |
| import numpy as np | |
| import plotly.graph_objects as go | |
# Safe function evaluation
def safe_eval(func_str, x_val):
    """Evaluate the user-supplied expression ``func_str`` at ``x_val``.

    Only the names ``x`` (bound to ``x_val``) and ``np`` (NumPy) are visible
    to the expression; Python's built-in functions are hidden.

    Args:
        func_str: Python expression written in terms of ``x``,
            e.g. ``"x**2 + x"`` or ``"np.sin(x)"``.
        x_val: Value substituted for ``x`` (a scalar or a NumPy array).

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: if parsing or evaluating the expression fails for any
            reason; the original exception is chained as ``__cause__``.
    """
    # SECURITY NOTE(review): hiding builtins is NOT a sandbox. Attribute walks
    # such as ``().__class__.__base__.__subclasses__()``, or anything reachable
    # through ``np``, can still escape. ``func_str`` comes straight from a text
    # box, so do not expose this to untrusted users without a real expression
    # parser (e.g. an AST whitelist).
    allowed_names = {"x": x_val, "np": np}
    try:
        # An empty dict is the documented way to withhold builtins. With
        # ``None`` instead, referencing a builtin name produced a confusing
        # TypeError on CPython rather than "name '...' is not defined".
        return eval(func_str, {"__builtins__": {}}, allowed_names)
    except Exception as e:
        raise ValueError(f"Error evaluating the function: {e}") from e
# Function derivative using finite difference method
def derivative(func_str, x_val, h=1e-5):
    """Approximate f'(x_val) with a central finite difference.

    f is the expression ``func_str`` evaluated through ``safe_eval``; ``h``
    is the half-width of the symmetric stencil. Any ValueError raised by
    ``safe_eval`` propagates unchanged.
    """
    f_right = safe_eval(func_str, x_val + h)
    f_left = safe_eval(func_str, x_val - h)
    return (f_right - f_left) / (2 * h)
# Tangent line equation
def tangent_line(func_str, x_val, x_range):
    """Return y-values of the tangent to f at ``x_val``, sampled over ``x_range``.

    Uses point-slope form ``y = f(x0) + f'(x0) * (x - x0)``, where f is the
    expression ``func_str`` and f'(x0) comes from ``derivative``.
    """
    anchor_y = safe_eval(func_str, x_val)
    slope = derivative(func_str, x_val)
    offsets = x_range - x_val
    return slope * offsets + anchor_y
# Reset session state
def reset_state():
    """Restart the descent from the current ``starting_point`` input value.

    Used as the ``on_change`` callback of the function and starting-point
    inputs and by the Reset button. Sets ``x``, ``iteration``, ``x_vals`` and
    ``y_vals`` in ``st.session_state``.

    If the current expression cannot be evaluated, the stored starting
    y-value is NaN instead of raising. An exception escaping a widget
    callback would otherwise surface as an uncaught error in the app.
    """
    start = st.session_state.starting_point
    # Evaluate before touching state, so a bad expression can no longer leave
    # x/x_vals reset while y_vals still holds the previous run's values.
    try:
        start_y = safe_eval(st.session_state.func_input, start)
    except ValueError:
        start_y = float("nan")
    st.session_state.x = start
    st.session_state.iteration = 0
    st.session_state.x_vals = [start]
    st.session_state.y_vals = [start_y]
# Seed session state on the first run only; later reruns keep existing values.
if "func_input" not in st.session_state:
    st.session_state["func_input"] = "x**2 + x"

if "x" not in st.session_state:
    # Default descent start; matches the starting-point input's default value.
    st.session_state["x"] = 4.0
    st.session_state["iteration"] = 0
    st.session_state["x_vals"] = [st.session_state["x"]]
    st.session_state["y_vals"] = [
        safe_eval(st.session_state["func_input"], st.session_state["x"])
    ]
st.set_page_config(layout="wide")

# Page-wide font override plus a framed border around every Plotly chart.
_PAGE_CSS = """
<style>
* {
font-family: Cambria, Arial, sans-serif !important;
}
.stPlotlyChart {
border: 5px solid #001A6E; /* Plot border */
border-radius: 10px;
padding: 5px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# NOTE(review): several labels below look like mis-encoded emoji/Greek
# characters (e.g. "π", "Xβ", "Ε"); they are kept byte-for-byte. Confirm the
# intended text.
st.title("π Gradient Descent Interactive Tool π")
col1, col2 = st.columns([1, 2])
# Left Section
with col1:
    st.subheader("π§ Define Your Function")
    # Both inputs are bound to session-state keys; editing either one restarts
    # the descent through reset_state.
    func_input = st.text_input(
        "Enter a function of x (e.g., x**2 + x):",
        key="func_input",
        on_change=reset_state,
    )
    starting_point = st.number_input(
        "Starting Point (Xβ):",
        value=4.0,
        step=0.1,
        key="starting_point",
        on_change=reset_state,
    )
    learning_rate = st.number_input(
        "Learning Rate (Ε):",
        value=0.25,
        step=0.01,
        key="learning_rate",
    )
    if st.button("Reset"):
        reset_state()
    if st.button("Next Iteration"):
        # Compute the whole step first and commit it only if every evaluation
        # succeeded. Previously a failure in the final evaluation left x,
        # iteration and x_vals advanced but y_vals one entry short.
        try:
            func = st.session_state.func_input
            grad = derivative(func, st.session_state.x)
            new_x = st.session_state.x - learning_rate * grad
            new_y = safe_eval(func, new_x)
        except Exception as e:  # UI boundary: report any failure to the user
            st.error(f"β οΈ Error: {str(e)}")
        else:
            st.session_state.x = new_x
            st.session_state.iteration += 1
            st.session_state.x_vals.append(new_x)
            st.session_state.y_vals.append(new_y)
# Right Section - Visualization
with col2:
    st.subheader("π Visualization")
    try:
        x_plot = np.linspace(-10, 10, 400)
        # One sample at a time: a whole-array call would return a single
        # scalar for constant expressions such as "5".
        y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
        fig = go.Figure()
        # Function plot
        fig.add_trace(go.Scatter(
            x=x_plot,
            y=y_plot,
            mode="lines",
            line=dict(color="blue", width=2),
            name="Function"
        ))
        # Gradient descent points
        fig.add_trace(go.Scatter(
            x=st.session_state.x_vals,
            y=st.session_state.y_vals,
            mode="markers",
            marker=dict(color="red", size=10),
            name="Gradient Descent Points"
        ))
        # Tangent line at the current descent position
        current_x = st.session_state.x
        tangent_x = np.linspace(-10, 10, 200)
        tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
        fig.add_trace(go.Scatter(
            x=tangent_x,
            y=tangent_y,
            mode="lines",
            line=dict(color="orange", width=3),
            name="Tangent Line"
        ))
        # y-axis range. It still starts at 0 for non-negative functions, as
        # before, but now extends below 0 when the curve dips negative; the
        # old [0, max + 10] hid those regions and the descent points there.
        # Non-finite samples (e.g. np.log(x) for x <= 0) are ignored because
        # they could make max() return NaN.
        finite_y = np.asarray(y_plot, dtype=float)
        finite_y = finite_y[np.isfinite(finite_y)]
        if finite_y.size:
            y_min = float(finite_y.min())
            y_max = float(finite_y.max())
            y_range = [0 if y_min >= 0 else y_min - 10, y_max + 10]
        else:
            y_range = None  # nothing finite to fit; let Plotly autorange
        # Plot layout
        fig.update_layout(
            xaxis=dict(
                title="x-axis",
                zeroline=True,
                zerolinecolor="white",
                zerolinewidth=2,
                showgrid=True,
                gridcolor="lightgray",
                color="white"
            ),
            yaxis=dict(
                title="y-axis",
                zeroline=True,
                zerolinecolor="white",
                zerolinewidth=2,
                showgrid=True,
                gridcolor="lightgray",
                range=y_range,
                color="white"
            ),
            plot_bgcolor="black",
            paper_bgcolor="black",
            font=dict(color="white"),
            legend=dict(
                x=0.6,  # Legend slightly left for border visibility
                y=1.0,
                bgcolor="black",
                bordercolor="#001A6E",
                borderwidth=2
            ),
            margin=dict(l=10, r=80, t=10, b=10),  # Expand right border
            width=800,
            height=400,
            showlegend=True
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:  # UI boundary: show the failure instead of crashing
        st.error(f"β οΈ Error in visualization: {str(e)}")
# Summary row under the chart: iteration count, current x, current (x, f(x)).
col5, col6, col7 = st.columns(3)
_x_now = st.session_state.x
col5.info(f"π§βπ» Iteration: {st.session_state.iteration}")
col6.success(f"β Current x: {_x_now:.4f}")
col7.warning(f"π Current Point: ({_x_now:.4f}, {st.session_state.y_vals[-1]:.4f})")