# File header captured from a web file viewer (not part of the program):
# author santosh7; last commit b9e61ac "Update app.py" (verified); size 3.95 kB.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
# Registry of selectable objective functions, keyed by the label shown in the
# sidebar. Each value is a (f, f_prime) pair of callables that accept either a
# scalar or a NumPy array. Insertion order is the order shown in the selectbox.
functions = dict(
    [
        ("sin(x)", (np.sin, np.cos)),
        ("x^2", (lambda v: v**2, lambda v: 2*v)),
        ("x", (lambda v: v, lambda v: np.ones_like(v))),
        ("x^3", (lambda v: np.power(v, 3), lambda v: 3 * np.power(v, 2))),
        ("e^x", (np.exp, np.exp)),
    ]
)
# Gradient Descent Simulation
def gradient_descent_step(func, derivative, x, learning_rate):
grad = derivative(x)
x = x - learning_rate * grad
return x, func(x)
# Plot the function, the gradient-descent iterates and their tangent lines.
def plot_gradient_descent(func, derivative, points, learning_rate, x_range=(-10, 10)):
    """Render f(x), the descent iterates and their tangents in Streamlit.

    Args:
        func: Callable f(x); evaluated on a NumPy grid over ``x_range``.
        derivative: Callable f'(x); gives the slope of each tangent line.
        points: Sequence of ``(x, f(x))`` pairs visited by gradient descent,
            in iteration order.
        learning_rate: Step size; used only in the plot title.
        x_range: ``(low, high)`` interval over which f and the tangents are
            drawn. Defaults to ``(-10, 10)``, the previously fixed range.
    """
    x_vals = np.linspace(x_range[0], x_range[1], 400)
    y_vals = func(x_vals)

    # Work on an explicit Figure/Axes rather than pyplot's implicit global
    # figure, so this exact figure can be handed to Streamlit and then closed.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x_vals, y_vals, label="f(x)")
    for i, (x, y) in enumerate(points):
        # The iterate itself.
        ax.scatter(x, y, color='red')
        # Tangent at (x, y): t(u) = f'(x) * (u - x) + y.
        slope = derivative(x)
        tangent_line = slope * (x_vals - x) + y
        # Only the first tangent gets a legend entry; labels starting with an
        # underscore are excluded from the legend by matplotlib.
        ax.plot(
            x_vals,
            tangent_line,
            '--',
            color='gray',
            alpha=0.5,
            label=f"Tangent at iteration {i}" if i == 0 else "_nolegend_",
        )
    ax.set_title(f"Gradient Descent with Learning Rate {learning_rate}")
    ax.set_xlabel("x")
    ax.set_ylabel("f(x)")
    ax.axhline(0, color='black', linewidth=0.5)
    ax.axvline(0, color='black', linewidth=0.5)
    ax.grid(True)
    ax.legend()
    st.pyplot(fig)
    # Streamlit reruns the whole script on every widget interaction. Without an
    # explicit close, each rerun would leave one more figure open in pyplot.
    plt.close(fig)
# --- Streamlit app ------------------------------------------------------------
st.title("Learning Rate Optimization")
# NOTE(review): newer Streamlit releases deprecate `use_column_width` in favour
# of `use_container_width`; confirm against the installed Streamlit version.
st.sidebar.image('Innomatics-Logo1.png', use_column_width=True)


def _init_session_state():
    """Create the session keys this app relies on, leaving existing ones alone."""
    defaults = {
        'current_iteration': 0,  # descent steps taken in the current run
        'points': [],            # (x, f(x)) history, starting point first
        'x': None,               # current iterate; stays None until "Submit"
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


_init_session_state()
# Sidebar: pick the objective, listed in the registry's insertion order.
function_name = st.sidebar.selectbox("Select a function to plot", list(functions))

# Candidate starting points: the whole numbers -10..10 (as floats rounded to
# 2 decimals) plus the extra value 0.99, in ascending order; 5.0 preselected.
starting_points = sorted(set(np.round(np.linspace(-10, 10, 21), 2).tolist()) | {0.99})
starting_point = st.sidebar.selectbox(
    "Select Starting Point",
    starting_points,
    index=starting_points.index(5.0),
)
# Candidate learning rates: 100 evenly spaced values in [0.001, 1.0], rounded
# to 3 decimals, with 0.44 added if absent, kept in ascending order.
learning_rates = np.round(np.linspace(0.001, 1.0, 100), 3).tolist()
if 0.44 not in learning_rates:
    learning_rates.append(0.44)
learning_rates.sort()

# Preselect the candidate nearest to 0.1 (the earliest one on a tie).
default_learning_rate = 0.1
closest_index = min(
    range(len(learning_rates)),
    key=lambda i: abs(learning_rates[i] - default_learning_rate),
)
learning_rate = st.sidebar.selectbox("Select Learning Rate", learning_rates, index=closest_index)

# Iteration budget per run: any whole number from 1 to 50, defaulting to 10.
iterations_list = list(range(1, 51))
iterations = st.sidebar.selectbox(
    "Select Number of Iterations",
    iterations_list,
    index=iterations_list.index(10),
)
# Sidebar controls: "Submit" (re)starts a run; "Next Iteration" advances it.
submit_button = st.sidebar.button(label='Submit')

# Submit resets the run: the history becomes [(x0, f(x0))] with x0 the chosen
# starting point, and the iteration counter goes back to 0.
if submit_button:
    st.session_state.current_iteration = 0
    st.session_state.points = [(starting_point, functions[function_name][0](starting_point))]
    st.session_state.x = starting_point

# Each "Next Iteration" click performs one descent step, up to `iterations`.
# NOTE(review): the step and the plot use the function currently selected in
# the sidebar. Changing it mid-run without re-submitting mixes points computed
# with different functions.
if st.sidebar.button("Next Iteration"):
    if st.session_state.x is None:
        # No run has been started. Stepping from x=None would crash inside the
        # derivative, so ask the user to start a run first.
        st.sidebar.warning("Click 'Submit' to start a run before stepping.")
    elif st.session_state.current_iteration < iterations:
        st.session_state.current_iteration += 1
        x, y = gradient_descent_step(functions[function_name][0], functions[function_name][1], st.session_state.x, learning_rate)
        st.session_state.x = x
        st.session_state.points.append((x, y))
    else:
        st.sidebar.info(f"Reached the limit of {iterations} iterations. Click 'Submit' to restart.")

# Redraw the current run (if any) on every rerun.
if st.session_state.points:
    plot_gradient_descent(functions[function_name][0], functions[function_name][1], st.session_state.points, learning_rate)