# Hosting-page residue captured with the file (not part of the program):
# trohith89's picture | Update app.py | d90f498 verified | raw | history blame | 5.61 kB
import streamlit as st
import numpy as np
import plotly.graph_objects as go
# Restricted (but NOT sandboxed) function evaluation
def safe_eval(func_str, x_val):
    """Evaluate the expression ``func_str`` with ``x`` bound to ``x_val``.

    The expression is run through ``eval`` with ``__builtins__`` set to
    ``None`` and a local namespace that exposes only ``x`` and ``np``.

    SECURITY NOTE: blanking ``__builtins__`` does not make ``eval`` safe.
    ``func_str`` is fed from a user-editable text input in this file, and a
    crafted expression can still reach arbitrary objects through attribute
    access. Replace this with a whitelist-based expression parser before
    exposing the app to untrusted users.

    Args:
        func_str: Python expression in the variable ``x``, e.g. ``"x**2 + x"``.
        x_val: Value substituted for ``x`` in the expression.

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: If evaluating the expression raises any ``Exception``.
            The original exception is chained as ``__cause__``.
    """
    allowed_names = {"x": x_val, "np": np}
    try:
        return eval(func_str, {"__builtins__": None}, allowed_names)
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Error evaluating the function: {e}") from e
# Numerical derivative via a symmetric (central) finite difference
def derivative(func_str, x_val, h=1e-5):
    """Approximate the derivative of the expression ``func_str`` at ``x_val``.

    The expression is evaluated through ``safe_eval`` at ``x_val + h`` and
    ``x_val - h``; the difference of the two results is divided by ``2 * h``.

    Args:
        func_str: Expression in ``x`` passed on to ``safe_eval``.
        x_val: Point at which the slope is estimated.
        h: Half-width of the finite-difference stencil.

    Returns:
        The finite-difference estimate of the slope.
    """
    ahead = safe_eval(func_str, x_val + h)
    behind = safe_eval(func_str, x_val - h)
    rise = ahead - behind
    run = 2 * h
    return rise / run
# Tangent line sampled over a range of x values
def tangent_line(func_str, x_val, x_range):
    """Return the tangent to ``func_str`` at ``x_val``, evaluated on ``x_range``.

    The point of tangency comes from ``safe_eval`` and the slope from
    ``derivative``; the result is ``slope * (x_range - x_val) + y``.

    Args:
        func_str: Expression in ``x`` describing the curve.
        x_val: x-coordinate of the point of tangency.
        x_range: x value(s) at which the tangent line is evaluated.

    Returns:
        The tangent line's y value(s) for ``x_range``.
    """
    y_at_point = safe_eval(func_str, x_val)
    gradient = derivative(func_str, x_val)
    offsets = x_range - x_val
    return gradient * offsets + y_at_point
# Reset session state
def reset_state():
    """Restart gradient descent from the currently selected starting point.

    Reads ``starting_point`` and ``func_input`` from ``st.session_state`` and
    resets ``x``, ``iteration``, ``x_vals`` and ``y_vals`` accordingly.

    Raises:
        ValueError: Propagated from ``safe_eval`` when the function cannot be
            evaluated at the starting point. Session state is left untouched
            in that case.
    """
    state = st.session_state
    start = state.starting_point

    # Evaluate before touching any state: if the expression is invalid this
    # raises, and committing afterwards keeps x_vals / y_vals the same length
    # instead of leaving a half-reset history behind.
    start_y = safe_eval(state.func_input, start)

    state.x = start
    state.iteration = 0
    state.x_vals = [start]
    state.y_vals = [start_y]
# Seed session state the first time the script runs in a session
if "func_input" not in st.session_state:
    st.session_state["func_input"] = "x**2 + x"

if "x" not in st.session_state:
    # The optimiser state is created as a group, keyed off the presence of "x".
    st.session_state["x"] = 4.0
    st.session_state["iteration"] = 0
    st.session_state["x_vals"] = [4.0]
    st.session_state["y_vals"] = [safe_eval(st.session_state["func_input"], 4.0)]
# Page-level configuration: wide layout
st.set_page_config(layout="wide")

# CSS injected into the page: a global font override and a border around
# Plotly charts.
_PAGE_CSS = """
<style>
* {
font-family: Cambria, Arial, sans-serif !important;
}
.stPlotlyChart {
border: 5px solid #001A6E; /* Plot border */
border-radius: 10px;
padding: 5px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

st.title("🌟 Gradient Descent Interactive Tool 🌟")

# Two columns in a 1:2 width ratio; col1 holds the controls, col2 the plot.
col1, col2 = st.columns([1, 2])
# Left Section: function definition, parameters and stepping controls
# NOTE(review): the original indentation was lost; both buttons are assumed to
# belong to the left column - confirm.
with col1:
    st.subheader("🔧 Define Your Function")
    # Changing the function or the starting point restarts the descent.
    func_input = st.text_input(
        "Enter a function of x (e.g., x**2 + x):",
        key="func_input",
        on_change=reset_state,
    )
    starting_point = st.number_input(
        "Starting Point (X₀):",
        value=4.0,
        step=0.1,
        key="starting_point",
        on_change=reset_state,
    )
    learning_rate = st.number_input(
        "Learning Rate (ŋ):",
        value=0.25,
        step=0.01,
        key="learning_rate",
    )

    if st.button("Reset"):
        # reset_state evaluates the function and raises ValueError when the
        # expression is invalid; report it instead of crashing the script.
        try:
            reset_state()
        except ValueError as e:
            st.error(f"⚠️ Error: {str(e)}")

    if st.button("Next Iteration"):
        # Compute the whole step first and commit only if every evaluation
        # succeeded, so a failure (bad expression, overflow on divergence)
        # cannot leave x_vals and y_vals with different lengths.
        try:
            grad = derivative(st.session_state.func_input, st.session_state.x)
            next_x = st.session_state.x - learning_rate * grad
            next_y = safe_eval(st.session_state.func_input, next_x)
        except Exception as e:
            st.error(f"⚠️ Error: {str(e)}")
        else:
            st.session_state.x = next_x
            st.session_state.iteration += 1
            st.session_state.x_vals.append(next_x)
            st.session_state.y_vals.append(next_y)
# Right Section - Visualization: curve, visited points and current tangent
with col2:
    st.subheader("📊 Visualization")
    try:
        x_plot = np.linspace(-10, 10, 400)
        y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]

        # Upper y-axis bound from finite samples only. The built-in max()
        # returns NaN whenever the first sample is NaN (e.g. np.sqrt(x) or
        # np.log(x) at x = -10), which would produce an invalid axis range.
        y_arr = np.asarray(y_plot, dtype=float)
        finite_y = y_arr[np.isfinite(y_arr)]
        y_top = (float(finite_y.max()) if finite_y.size else 0.0) + 10

        fig = go.Figure()

        # Function plot
        fig.add_trace(go.Scatter(
            x=x_plot,
            y=y_plot,
            mode="lines",
            line=dict(color="blue", width=2),
            name="Function",
        ))

        # Gradient descent points
        fig.add_trace(go.Scatter(
            x=st.session_state.x_vals,
            y=st.session_state.y_vals,
            mode="markers",
            marker=dict(color="red", size=10),
            name="Gradient Descent Points",
        ))

        # Tangent line at the current x
        current_x = st.session_state.x
        tangent_x = np.linspace(-10, 10, 200)
        tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
        fig.add_trace(go.Scatter(
            x=tangent_x,
            y=tangent_y,
            mode="lines",
            line=dict(color="orange", width=3),
            name="Tangent Line",
        ))

        # Plot layout
        fig.update_layout(
            xaxis=dict(
                title="x-axis",
                zeroline=True,
                zerolinecolor="white",
                zerolinewidth=2,
                showgrid=True,
                gridcolor="lightgray",
                color="white",
            ),
            yaxis=dict(
                title="y-axis",
                zeroline=True,
                zerolinecolor="white",
                zerolinewidth=2,
                showgrid=True,
                gridcolor="lightgray",
                range=[0, y_top],  # Show non-negative y-axis only
                color="white",
            ),
            plot_bgcolor="black",
            paper_bgcolor="black",
            font=dict(color="white"),
            legend=dict(
                x=0.6,  # Legend slightly left for border visibility
                y=1.0,
                bgcolor="black",
                bordercolor="#001A6E",
                borderwidth=2,
            ),
            margin=dict(l=10, r=80, t=10, b=10),  # Expand right border
            width=800,
            height=400,
            showlegend=True,
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"⚠️ Error in visualization: {str(e)}")
# Display iteration and current point info
# NOTE(review): the original indentation was lost; this row is assumed to sit
# at top level (full width below both columns) - confirm.
col5, col6, col7 = st.columns(3)
col5.info(f"🧑‍💻 Iteration: {st.session_state.iteration}")
col6.success(f"✅ Current x: {st.session_state.x:.4f}")
col7.warning(f"📍 Current Point: ({st.session_state.x:.4f}, {st.session_state.y_vals[-1]:.4f})")