Spaces:
Build error
Build error
| import streamlit as st | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# Selectable objectives, keyed by the label shown in the sidebar.
# Each value is an (f, f') pair: the function and its derivative. The callables
# are evaluated both on a NumPy grid (for plotting) and on scalars (for the
# individual descent steps), so they must work element-wise.
functions = {
    "sin(x)": (np.sin, np.cos),
    "x^2": (lambda t: t**2, lambda t: 2*t),
    "x": (lambda t: t, lambda t: np.ones_like(t)),
    "x^3": (lambda t: np.power(t, 3), lambda t: 3 * np.power(t, 2)),
    "e^x": (np.exp, np.exp),
}
# Gradient Descent Simulation
def gradient_descent_step(func, derivative, x, learning_rate):
    """Take one gradient-descent step from ``x``.

    Args:
        func: The objective f.
        derivative: Its derivative f', evaluated at the current ``x``.
        x: Current position.
        learning_rate: Step size multiplying the gradient.

    Returns:
        ``(new_x, func(new_x))`` where ``new_x = x - learning_rate * f'(x)``.
    """
    step = learning_rate * derivative(x)
    new_x = x - step
    return new_x, func(new_x)
# Function to plot the function and gradient descent points
def plot_gradient_descent(func, derivative, points, learning_rate):
    """Render f(x), the descent iterates and their tangent lines in Streamlit.

    Args:
        func: Element-wise callable f; evaluated on a 400-point grid over [-10, 10].
        derivative: Callable f'; evaluated at each iterate's x to get the
            tangent slope.
        points: Sequence of ``(x, y)`` pairs, one per iterate.
        learning_rate: Only shown in the plot title.
    """
    x_vals = np.linspace(-10, 10, 400)
    y_vals = func(x_vals)

    # Draw on an explicit Figure rather than pyplot's implicit "current figure",
    # hand that figure to st.pyplot, and close it afterwards. Streamlit re-runs
    # this script on every interaction; without the close, each run would leave
    # another open figure registered with pyplot.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(x_vals, y_vals, label="f(x)")
        for i, (x, y) in enumerate(points):
            # Plot the point
            ax.scatter(x, y, color='red')
            # Tangent through (x, y): y + f'(x) * (t - x), drawn over the full grid.
            slope = derivative(x)
            tangent_line = slope * (x_vals - x) + y
            # Only the first tangent gets a label so the legend has one entry for them.
            ax.plot(x_vals, tangent_line, '--', color='gray', alpha=0.5,
                    label=f"Tangent at iteration {i}" if i == 0 else "")
        ax.set_title(f"Gradient Descent with Learning Rate {learning_rate}")
        ax.set_xlabel("x")
        ax.set_ylabel("f(x)")
        ax.axhline(0, color='black', linewidth=0.5)
        ax.axvline(0, color='black', linewidth=0.5)
        ax.grid(True)
        ax.legend()
        st.pyplot(fig)
    finally:
        plt.close(fig)
# Streamlit app
st.title("Learning Rate Optimization")
# NOTE(review): `use_column_width` is deprecated in newer Streamlit releases;
# confirm against the Streamlit version this app is pinned to.
st.sidebar.image('Innomatics-Logo1.png', use_column_width=True)

# Seed per-session state only when a key is absent, so values written by
# earlier reruns of this script are preserved.
_session_defaults = (
    ('current_iteration', 0),
    ('points', []),
    ('x', None),
)
for _state_key, _state_default in _session_defaults:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# User input for selecting the function
function_name = st.sidebar.selectbox("Select a function to plot", list(functions))

# Starting points: -10..10 in unit steps (rounded to 2 dp) plus 0.99, ascending.
# The grid values are all distinct, so a set union followed by sorting yields
# the same list as appending-if-missing and then sorting.
starting_points = sorted(
    set(np.round(np.linspace(-10, 10, 21), 2).tolist()) | {0.99}
)
starting_point = st.sidebar.selectbox(
    "Select Starting Point",
    starting_points,
    index=starting_points.index(5.0),
)

# Learning rates: 100 evenly spaced values in [0.001, 1.0] (rounded to 3 dp)
# plus 0.44, ascending. The spacing exceeds 0.001, so rounding creates no duplicates.
learning_rates = sorted(
    set(np.round(np.linspace(0.001, 1.0, 100), 3).tolist()) | {0.44}
)
default_learning_rate = 0.1
# Pre-select the entry nearest the default; ties resolve to the first index.
closest_index = min(
    range(len(learning_rates)),
    key=lambda idx: abs(learning_rates[idx] - default_learning_rate),
)
learning_rate = st.sidebar.selectbox(
    "Select Learning Rate",
    learning_rates,
    index=closest_index,
)

# Number of iterations: 1..50, defaulting to 10.
iterations_list = list(range(1, 51))
iterations = st.sidebar.selectbox(
    "Select Number of Iterations",
    iterations_list,
    index=iterations_list.index(10),
)
# Submit button
submit_button = st.sidebar.button(label='Submit')
# Handle form submission: (re)start a run from the currently selected settings.
if submit_button:
    st.session_state.current_iteration = 0
    st.session_state.points = [(starting_point, functions[function_name][0](starting_point))]
    st.session_state.x = starting_point
    # Pin the settings this run was started with. Sidebar widgets can change
    # between reruns without a new Submit. Later steps and the plot must keep
    # using the submitted function and learning rate so every point lies on
    # the same curve.
    st.session_state.run_settings = (function_name, starting_point, learning_rate)

# None until Submit has been pressed at least once in this session.
run_settings = st.session_state.get('run_settings')

if run_settings is not None and run_settings != (function_name, starting_point, learning_rate):
    st.sidebar.info("Settings changed - press Submit to restart with them.")

# Next iteration button
if st.sidebar.button("Next Iteration"):
    if run_settings is None or st.session_state.x is None:
        # No run has been started, so there is no x to step from.
        st.sidebar.warning("Press Submit to start a run first.")
    elif st.session_state.current_iteration < iterations:
        run_function, _, run_learning_rate = run_settings
        func, derivative = functions[run_function]
        x, y = gradient_descent_step(func, derivative, st.session_state.x, run_learning_rate)
        st.session_state.current_iteration += 1
        st.session_state.x = x
        st.session_state.points.append((x, y))
    else:
        st.sidebar.info(f"Reached the limit of {iterations} iterations.")

# Plot on every rerun, using the settings the current run was started with.
if st.session_state.points and run_settings is not None:
    run_function, _, run_learning_rate = run_settings
    func, derivative = functions[run_function]
    plot_gradient_descent(func, derivative, st.session_state.points, run_learning_rate)