# trohith89's picture
# Update app.py
# 0824e67 verified
import streamlit as st
import numpy as np
import plotly.graph_objects as go
# Safe function evaluation
def safe_eval(func_str, x_val):
    """Evaluate the user-supplied expression ``func_str`` with ``x`` bound to ``x_val``.

    Only two names are visible to the expression: ``x`` (the value passed in)
    and ``np`` (NumPy). No Python builtins are available.

    Args:
        func_str: Expression such as ``"x**2 + np.sin(x)"``.
        x_val: Value bound to the name ``x`` during evaluation.

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: If the expression is rejected or fails to evaluate.
    """
    # SECURITY NOTE(review): this is eval() on text typed by a web user.
    # Emptying __builtins__ is NOT a sandbox: dunder attribute chains such as
    # ().__class__.__bases__[0].__subclasses__() can still reach arbitrary
    # objects. Rejecting "__" blocks the well-known escapes, but a real
    # expression parser / whitelist should replace eval() here.
    if isinstance(func_str, str) and "__" in func_str:
        raise ValueError("Error evaluating the function: '__' is not allowed")
    allowed_names = {"x": x_val, "np": np}
    try:
        # An empty builtins dict gives a clear NameError (e.g. for abs/max)
        # instead of "'NoneType' object is not subscriptable".
        return eval(func_str, {"__builtins__": {}}, allowed_names)
    except Exception as e:
        raise ValueError(f"Error evaluating the function: {e}") from e
# Function derivative using finite difference method
def derivative(func_str, x_val, h=1e-5):
    """Estimate the slope of the expression at ``x_val`` with a central difference.

    Args:
        func_str: Expression in ``x`` accepted by ``safe_eval``.
        x_val: Point at which the slope is estimated.
        h: Half-width of the symmetric sampling interval.

    Returns:
        ``(f(x_val + h) - f(x_val - h)) / (2 * h)``.
    """
    ahead = safe_eval(func_str, x_val + h)
    behind = safe_eval(func_str, x_val - h)
    return (ahead - behind) / (2 * h)
# Tangent line equation
def tangent_line(func_str, x_val, x_range):
    """Evaluate the tangent to the expression at ``x_val`` over ``x_range``.

    Args:
        func_str: Expression in ``x`` accepted by ``safe_eval``.
        x_val: Point of tangency.
        x_range: Positions at which to evaluate the tangent (the plot passes
            a NumPy ``linspace`` array).

    Returns:
        ``f'(x_val) * (x_range - x_val) + f(x_val)``.
    """
    anchor_y = safe_eval(func_str, x_val)
    gradient = derivative(func_str, x_val)
    offsets = x_range - x_val
    return gradient * offsets + anchor_y
# Callback to reset session state
def reset_state():
    """Restart the gradient descent from ``st.session_state.starting_point``.

    Used as the ``on_change`` callback of the input widgets and by the
    "Set Up Function" button. Resets ``x``, ``iteration``, ``x_vals`` and
    ``y_vals`` in ``st.session_state``.

    If the current function cannot be evaluated at the starting point,
    f(x0) is stored as NaN instead of raising. The session state therefore
    stays self-consistent, and the rerun is not aborted by an exception
    inside a widget callback.
    """
    start = st.session_state.starting_point
    # Evaluate before touching any state so a failure cannot leave
    # x_vals and y_vals out of step with each other.
    try:
        start_y = safe_eval(st.session_state.func_input, start)
    except ValueError:
        start_y = float("nan")
    st.session_state.x = start
    st.session_state.iteration = 0
    st.session_state.x_vals = [start]
    st.session_state.y_vals = [start_y]
# Initialize session state variables.
# Streamlit re-executes this script on every interaction, so defaults are
# only filled in for keys that are still missing.
st.session_state.setdefault("func_input", "x**2 + x")
if "x" not in st.session_state:
    st.session_state["x"] = 4.0
    st.session_state["iteration"] = 0
    st.session_state["x_vals"] = [4.0]
    st.session_state["y_vals"] = [safe_eval(st.session_state["func_input"], 4.0)]
# Full-width layout
st.set_page_config(layout="wide")
# Global stylesheet: Cambria font, centred headings, custom border colours for
# buttons/inputs/alerts, reduced top padding, and a hover-tooltip style used
# by the function-input label.
page_css = """
<style>
* {
font-family: Cambria, Arial, sans-serif !important;
}
h1, h2, h3, h4, h5 {
text-align: center;
margin-top: 0;
}
input, .stButton button, .stDownloadButton button {
border: 2px solid #ea445a;
border-radius: 5px;
padding: 10px;
}
.stInfo, .stSuccess {
border: 2px solid #ea445a;
border-radius: 5px;
padding: 10px;
}
.stButton {
margin-top: 10px;
}
/* Reduced Padding at the top */
.css-1d391kg {
padding-top: 0.5rem;
}
/* Centering the legend in the plot */
.stPlotlyChart {
display: block;
margin: 0 auto;
}
/* Adjusting for full width without scrolling */
.css-1lcbvhc {
padding-left: 0;
padding-right: 0;
}
/* Custom borders for input fields */
.stTextInput input, .stNumberInput input {
border: 2px solid #001A6E;
border-radius: 5px;
padding: 10px;
}
/* Tooltip styling */
.tooltip {
position: relative;
display: inline-block;
cursor: pointer;
}
.tooltip .tooltiptext {
visibility: hidden;
opacity: 0;
width: 300px;
background-color: #001A6E;
color: #fff;
text-align: center;
border-radius: 5px;
padding: 5px;
position: absolute;
z-index: 1;
bottom: 125%; /* Position the tooltip above */
left: 50%;
margin-left: -150px;
transition: opacity 0.3s;
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
</style>
"""
st.markdown(body=page_css, unsafe_allow_html=True)
# Page Layout
st.title("🌟 Gradient Descent Visualization Tool 🌟")
col1, col2 = st.columns([1, 2])
# Left Section: User Input
with col1:
    st.subheader("🔧 Define Your Function")
    # Label with a hover tooltip explaining the input syntax. The content is a
    # raw HTML block, inside which Markdown (**bold**, "- " lists) is NOT
    # rendered, so the tooltip is formatted with HTML tags instead.
    st.markdown(
        """
<div class="tooltip">
<label for="func_input">Enter a function of 'x':</label>
<span class="tooltiptext">
<b>How to input your function:</b><br>
Please give the inputs as mentioned below:<br>
&bull; x^n as x**n<br>
&bull; sin(x) as np.sin(x)<br>
&bull; log(x) as np.log(x)<br>
&bull; e^x or exp(x) as np.exp(x)
</span>
</div>
""",
        unsafe_allow_html=True,
    )
    # The value lives in st.session_state.func_input (seeded at start-up), so
    # no `value=` argument is passed to the widget.
    func_input = st.text_input(
        "👇",
        key="func_input",
        on_change=reset_state,
    )
    st.subheader("⚙️ Gradient Descent Parameters")
    starting_point = st.number_input(
        "Starting Point (X₀)",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="starting_point",
        on_change=reset_state,
    )
    learning_rate = st.number_input(
        "Learning Rate (η)",
        value=0.25,
        step=0.01,
        format="%.2f",
        key="learning_rate",
        on_change=reset_state,
    )
    col3, col4 = st.columns(2)
    with col3:
        if st.button("🔄 Set Up Function"):
            reset_state()
    with col4:
        if st.button("▢️ Next Iteration"):
            try:
                # One gradient-descent step: x <- x - lr * f'(x)
                grad = derivative(st.session_state.func_input, st.session_state.x)
                new_x = st.session_state.x - learning_rate * grad
                new_y = safe_eval(st.session_state.func_input, new_x)
            except Exception as e:
                st.error(f"⚠️ Error: {str(e)}")
            else:
                # Commit only after every evaluation succeeded, so x, iteration,
                # x_vals and y_vals always stay in step with each other.
                st.session_state.x = new_x
                st.session_state.iteration += 1
                st.session_state.x_vals.append(new_x)
                st.session_state.y_vals.append(new_y)
# Right Section: Visualization
with col2:
    st.subheader("📊 Gradient Descent Visualization")
    try:
        # Sample the function on the fixed window x in [-10, 10]
        x_plot = np.linspace(-10, 10, 400)
        y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
        fig = go.Figure()
        # Function curve
        fig.add_trace(go.Scatter(
            x=x_plot,
            y=y_plot,
            mode="lines+markers",
            line=dict(color="blue", width=2),
            marker=dict(size=4, color="blue", symbol="circle"),
            name="Function",
        ))
        # Every point visited by gradient descent so far (red markers, no labels)
        fig.add_trace(go.Scatter(
            x=st.session_state.x_vals,
            y=st.session_state.y_vals,
            mode="markers",
            marker=dict(color="red", size=10),
            name="Gradient Descent Points",
        ))
        # Tangent line at the current gradient descent point, across the plot width
        current_x = st.session_state.x
        tangent_x = np.linspace(-10, 10, 200)
        tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
        fig.add_trace(go.Scatter(
            x=tangent_x,
            y=tangent_y,
            mode="lines",
            line=dict(color="orange", width=3),
            name="Tangent Line",
        ))
        # y-axis range from the finite samples only. NaN/inf values (e.g. from
        # np.log(x) for x <= 0) make plain min()/max() order-dependent and can
        # produce a NaN range. The top edge is capped at 1000 so steep
        # functions stay readable. The cap is dropped when it would fall at or
        # below the bottom edge; before, functions whose minimum on [-10, 10]
        # is above ~1000 got an inverted axis.
        y_arr = np.asarray(y_plot, dtype=float)
        finite_y = y_arr[np.isfinite(y_arr)]
        if finite_y.size:
            y_low = float(finite_y.min()) - 5
            y_high = min(float(finite_y.max()) + 5, 1000)
            if y_high <= y_low:
                y_high = float(finite_y.max()) + 5
        else:
            # Nothing plottable: fall back to a small window around the origin
            y_low, y_high = -10, 10
        fig.update_layout(
            xaxis=dict(
                title="x-axis",
                range=[-10, 10],
                showline=True,
                linecolor="white",
                tickcolor="white",
                tickfont=dict(color="white"),
                ticks="outside",
            ),
            yaxis=dict(
                title="y-axis",
                range=[y_low, y_high],
                showline=True,
                linecolor="white",
                tickcolor="white",
                tickfont=dict(color="white"),
                ticks="outside",
            ),
            plot_bgcolor="black",
            paper_bgcolor="black",
            title="",
            margin=dict(l=10, r=10, t=10, b=10),
            width=800,
            height=400,
            showlegend=True,
            legend=dict(
                x=1.1,
                y=0.5,
                xanchor="left",
                yanchor="middle",
                orientation="v",
                font=dict(size=12, color="white"),
                bgcolor="black",
                bordercolor="white",
                borderwidth=2,
            ),
        )
        # Reference axes through the origin. The vertical line uses paper
        # coordinates (0..1) so it spans the full plot height for any y-range,
        # instead of a hard-coded -100..100.
        fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0,
                      line=dict(color="white", width=2))  # x-axis
        fig.add_shape(type="line", xref="x", yref="paper", x0=0, x1=0, y0=0, y1=1,
                      line=dict(color="white", width=2))  # y-axis
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"⚠️ Error in visualization: {str(e)}")
# Iteration stats: step counter and the current point of the descent
col5, col6 = st.columns(2)
col5.info(f"🧑‍💻 Iteration: {st.session_state.iteration}")
col6.success(f"✅ Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")