# app.py -- commit 417b36c (verified) "Update app.py" by trohith89, 4.26 kB.
import numpy as np
import plotly.graph_objects as go
import streamlit as st

# Page titles (each rendered exactly once).
st.title("Machine Learning Project")
st.title("Gradient Descent Visualizer with Tangent Lines")
# Safe(r) function evaluation
import ast
import functools

# Names a user expression may reference besides ``x``: the numpy module plus a
# few common numpy math functions/constants, so inputs such as ``sin(x)`` (the
# example shown in the function input hint) work without the ``np.`` prefix.
_EXPR_NAMES = {
    "np": np,
    "sin": np.sin,
    "cos": np.cos,
    "tan": np.tan,
    "exp": np.exp,
    "log": np.log,
    "sqrt": np.sqrt,
    "abs": np.abs,
    "pi": np.pi,
    "e": np.e,
}


@functools.lru_cache(maxsize=64)
def _compile_expression(func_str):
    """Parse, vet and compile ``func_str`` once; repeated calls hit the cache.

    Raises:
        SyntaxError: ``func_str`` is not a single Python expression.
        ValueError: the expression references an underscore-prefixed name or
            attribute (e.g. ``().__class__``), the usual escape hatch out of a
            restricted ``eval``.
    """
    # Strip surrounding whitespace: builtin eval() tolerates leading blanks,
    # but ast.parse() would reject them with an IndentationError.
    tree = ast.parse(func_str.strip(), mode="eval")
    for node in ast.walk(tree):
        if isinstance(node, ast.Attribute) and node.attr.startswith("_"):
            raise ValueError(f"Access to attribute {node.attr!r} is not allowed")
        if isinstance(node, ast.Name) and node.id.startswith("_"):
            raise ValueError(f"Use of name {node.id!r} is not allowed")
    return compile(tree, "<user function>", "eval")


def safe_eval(func_str, x_val):
    """Evaluate the user expression ``func_str`` with ``x`` bound to ``x_val``.

    Only ``x``, ``np`` and the helpers in ``_EXPR_NAMES`` are visible; Python
    builtins are disabled.

    SECURITY NOTE(review): this is still ``eval`` on text typed by the app user.
    The AST check blocks the common dunder-based escapes, but it is not a real
    sandbox -- if untrusted users can reach the app, replace it with a proper
    expression parser/interpreter.

    Returns:
        The value of the expression (normally a float or numpy scalar).

    Raises:
        SyntaxError, ValueError: see ``_compile_expression``.
        Any exception raised while evaluating (NameError, ZeroDivisionError, ...).
    """
    # Names live in the globals dict so they are also visible inside nested
    # scopes (e.g. comprehensions) of the expression.
    namespace = {"__builtins__": {}, **_EXPR_NAMES, "x": x_val}
    return eval(_compile_expression(func_str), namespace)
# Numerical derivative (central difference)
def derivative(func_str, x_val, h=1e-5):
    """Approximate f'(x_val) for the expression ``func_str``.

    Uses the symmetric difference quotient (f(x + h) - f(x - h)) / (2h), with
    each side evaluated through ``safe_eval``.
    """
    f_right = safe_eval(func_str, x_val + h)
    f_left = safe_eval(func_str, x_val - h)
    return (f_right - f_left) / (2 * h)
# Tangent line at a point
def tangent_line(func_str, x_val, x_range):
    """Return y-values of the tangent to ``func_str`` at ``x_val``, sampled at ``x_range``.

    The line is y = f(x_val) + f'(x_val) * (x - x_val). ``x_range`` only needs
    to support subtraction by a scalar (the app passes an ``np.linspace`` array).
    """
    anchor_y = safe_eval(func_str, x_val)
    gradient = derivative(func_str, x_val)
    offsets = x_range - x_val
    return gradient * offsets + anchor_y
# Callback to reset session state
def reset_state():
    """Restart gradient descent from the current starting point.

    Used as the ``on_change`` callback of the function, starting-point and
    learning-rate inputs, so editing any of them discards the previous
    trajectory.

    f(start) is evaluated before any session-state field is written, so an
    invalid function can never leave ``x_vals`` and ``y_vals`` out of sync.
    """
    start = st.session_state.starting_point
    try:
        start_y = safe_eval(st.session_state.func_input, start)
    except Exception:
        # Invalid/undefined expression: keep the state consistent with a NaN
        # placeholder. The main script evaluates the function again when it
        # plots, which is where the failure surfaces to the user.
        start_y = float("nan")
    st.session_state.x = start
    st.session_state.iteration = 0
    st.session_state.x_vals = [start]
    st.session_state.y_vals = [start_y]
# --- User inputs -------------------------------------------------------------
# Every widget restarts the descent (reset_state) when its value changes.
st.header("Define Your Function")
func_input = st.text_input(
    label="Enter a function of 'x' (e.g., x**2 + x, sin(x), x**3 - 3*x + 2):",
    value="x**2 + x",
    key="func_input",
    on_change=reset_state,
)

st.header("Gradient Descent Parameters")
starting_point = st.number_input(
    label="Starting Point",
    value=4.0,
    step=0.1,
    format="%.2f",
    key="starting_point",
    on_change=reset_state,
)
learning_rate = st.number_input(
    label="Learning Rate",
    value=0.1,
    step=0.01,
    format="%.2f",
    key="learning_rate",
    on_change=reset_state,
)
# First run of the session: seed the descent at the chosen starting point.
if "x" not in st.session_state:
    st.session_state["x"] = starting_point
    st.session_state["iteration"] = 0
    st.session_state["x_vals"] = [starting_point]
    st.session_state["y_vals"] = [safe_eval(func_input, starting_point)]
# "Next Iteration" button: take one gradient-descent step per click.
if st.button("Next Iteration"):
    try:
        # Compute the whole step before touching session state, so a failure
        # (e.g. the function is undefined at the new x) cannot leave x, the
        # iteration counter, x_vals and y_vals out of sync.
        grad = derivative(func_input, st.session_state.x)
        new_x = st.session_state.x - learning_rate * grad
        new_y = safe_eval(func_input, new_x)
        step_is_finite = bool(np.isfinite(new_x) and np.isfinite(new_y))
    except Exception as e:
        st.error(f"Error: {str(e)}")
    else:
        if step_is_finite:
            st.session_state.x = new_x
            st.session_state.iteration += 1
            st.session_state.x_vals.append(new_x)
            st.session_state.y_vals.append(new_y)
        else:
            # e.g. a too-large learning rate makes the iterates blow up to
            # inf/nan; refuse the step instead of committing it to the plot.
            st.error(
                "Step produced a non-finite value (the descent may be diverging); "
                "try a smaller learning rate."
            )
# Progress read-out for the current iterate.
st.subheader("Gradient Descent Progress")
st.write("Iteration: {}".format(st.session_state.iteration))
st.write("Current x: {:.4f}".format(st.session_state.x))
st.write("Current f(x): {:.4f}".format(st.session_state.y_vals[-1]))
# Plot the function, the gradient-descent iterates and the tangent line at the
# current iterate. Evaluation errors (e.g. the function is undefined somewhere
# in the window) are reported on the page instead of aborting the script.
try:
    # Sample at least [-10, 10], widened by a 1-unit margin when the descent
    # has moved outside it, so every iterate marker sits over the drawn curve.
    x_lo = min(-10.0, min(st.session_state.x_vals) - 1.0)
    x_hi = max(10.0, max(st.session_state.x_vals) + 1.0)
    x_plot = np.linspace(x_lo, x_hi, 400)
    y_plot = [safe_eval(func_input, x) for x in x_plot]

    # Tangent line over a +/-2 window around the current iterate.
    current_x = st.session_state.x
    tangent_x = np.linspace(current_x - 2, current_x + 2, 100)
    tangent_y = tangent_line(func_input, current_x, tangent_x)
except Exception as e:
    st.error(f"Could not plot the function: {e}")
else:
    fig = go.Figure()
    # Function curve.
    fig.add_trace(go.Scatter(x=x_plot, y=y_plot, mode="lines", name="Function"))
    # Gradient-descent iterates, in red.
    fig.add_trace(go.Scatter(
        x=st.session_state.x_vals,
        y=st.session_state.y_vals,
        mode="markers",
        marker=dict(color="red", size=8),
        name="Gradient Descent Points",
    ))
    # Tangent at the current iterate, drawn as a solid orange line.
    fig.add_trace(go.Scatter(
        x=tangent_x,
        y=tangent_y,
        mode="lines",
        line=dict(color="orange", width=3),
        name="Tangent Line",
    ))
    fig.update_layout(
        xaxis_title="x",
        yaxis_title="f(x)",
        title="Gradient Descent Visualization with Tangent Line",
    )
    st.plotly_chart(fig)