# Adityaganesh's picture
# Create app.py
# 06bfef5 verified
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
import sympy as sp
# Browser-tab title and icon, plus the wide page layout.
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide", page_icon="📉")

# Cosmetic CSS injected as raw HTML (needs unsafe_allow_html=True).
# NOTE(review): `.css-18e3th9` and `.st-sidebar` look like generated/internal
# Streamlit class names that can change between releases — confirm they still
# match elements in the deployed Streamlit version.
st.markdown(
    """
<style>
.css-18e3th9 { padding-top: 1rem; padding-bottom: 1rem; }
.stButton>button { border-radius: 8px; font-weight: bold; }
.stNumberInput>div>div>input { border-radius: 8px; }
.stTextInput>div>div>input { border-radius: 8px; }
.st-sidebar { background-color: #f9f9f9; border-right: 2px solid #eee; }
h1, h2, h3 { color: #4a4e69; font-family: 'Arial'; }
</style>
""",
    unsafe_allow_html=True,
)
def _lambdify_broadcast(symbol, expr):
    """Compile ``expr`` into a NumPy callable whose output matches its input's shape.

    ``sympy.lambdify`` turns a constant expression (``5``, or the derivative of
    a linear function) into a callable that returns one bare scalar even when
    given an array. The scalar is broadcast here to one value per sample, so
    array-based plotting code keeps working.
    """
    compiled = sp.lambdify(symbol, expr, "numpy")

    def evaluate(value):
        result = compiled(value)
        if np.ndim(result) < np.ndim(value):
            return np.full(np.shape(value), result)
        return result

    return evaluate


def parse_function(user_input):
    """Turn the user's expression text into f(x), f'(x) callables and f'(x) text.

    Args:
        user_input: Expression in the single variable ``x`` using SymPy syntax,
            e.g. ``"2*x**2 + 3*x"``.

    Returns:
        ``(func, gradient, gradient_str)``. ``func`` and ``gradient`` evaluate
        f and f' on scalars or NumPy arrays. ``gradient_str`` is ``str()`` of
        the symbolic derivative. If the text cannot be parsed, is not a single
        expression in ``x``, or cannot be differentiated/compiled, an error is
        shown via ``st.error`` and ``(None, None, None)`` is returned.
    """
    # Declaring x real lets SymPy differentiate e.g. Abs(x) to sign(x) instead of
    # producing re(x)/im(x) derivative terms. Passing it through `locals` makes
    # the parsed text use this exact symbol.
    x = sp.symbols("x", real=True)
    try:
        # SECURITY: sympify() evaluates its string argument with eval(). The text
        # comes straight from a text box, so do not expose this app to untrusted
        # users without a stricter parser or a sandbox.
        func_expr = sp.sympify(user_input, locals={"x": x})
        if not isinstance(func_expr, sp.Expr):
            raise ValueError("expected a single expression in x")
        # Any other free symbol (e.g. "y") would only fail later, during numeric
        # evaluation, outside this try block. Reject it here instead.
        extra = func_expr.free_symbols - {x}
        if extra:
            names = ", ".join(sorted(map(str, extra)))
            raise ValueError(f"only the variable x is allowed, found: {names}")
        gradient_expr = sp.diff(func_expr, x)
        func = _lambdify_broadcast(x, func_expr)
        gradient = _lambdify_broadcast(x, gradient_expr)
    except Exception as e:  # SymPy raises many unrelated exception types here.
        st.error(f"Invalid function input: {e}")
        return None, None, None
    return func, gradient, str(gradient_expr)
def gradient_descent_step(x, learning_rate, func, gradient):
    """Advance ``x`` by one plain gradient-descent update.

    Computes ``x - learning_rate * gradient(x)`` and evaluates ``func`` there.

    Args:
        x: Current position.
        learning_rate: Step size that multiplies the gradient.
        func: Objective f, called once on the updated position.
        gradient: Derivative f', called once on the old position ``x``.

    Returns:
        Tuple ``(x_new, loss_new, grad)``: the updated position, f evaluated
        at it, and the gradient that was taken at the old ``x``.
    """
    slope = gradient(x)
    stepped = x - learning_rate * slope
    return stepped, func(stepped), slope
st.title("📉 Interactive Gradient Descent Visualizer")

# Sidebar inputs. The module-level names bound here (user_function,
# start_point, learning_rate) are read by the rest of the script.
with st.sidebar:
    st.header("🔧 Parameters")
    user_function = st.text_input("Enter a Function (e.g., 2*x**2 + 3*x)", value="x**2")
    start_point = st.number_input("Starting Point", value=5.0, step=0.00001, format="%.5f")
    learning_rate = st.number_input("Learning Rate", value=0.1, step=0.00001, format="%.5f")
def setup_gradient_descent():
    """Reset the stored descent so it restarts at the sidebar's starting point.

    Used as the Setup button's ``on_click`` callback. Streamlit runs callbacks
    before re-running the script, so the module-level ``func`` and
    ``start_point`` read here are the ones from the most recent script run.
    ``func`` is ``None`` when that run could not parse the user's function.
    There is then nothing to evaluate, so this returns without changes instead
    of raising ``TypeError``.
    """
    if func is None:
        return
    st.session_state.iteration = 0
    st.session_state.final_minimum = False
    st.session_state.x_values = [start_point]
    st.session_state.loss_values = [func(start_point)]
    st.session_state.gradients = []  # per-step gradients, appended by next_iteration


st.sidebar.button("🚀 Setup", on_click=setup_gradient_descent)
# Gradients whose magnitude falls below this count as "converged".
GRADIENT_TOLERANCE = 1e-6


def _evaluate(fn, values):
    """Evaluate ``fn`` on ``values`` and broadcast the result to ``values``' shape.

    ``sympy.lambdify`` turns a constant expression (e.g. ``f(x) = 5``) into a
    callable that returns a single scalar even for array input. That would
    break the array-based plotting below.
    """
    return np.broadcast_to(fn(values), np.shape(values))


def _plot_descent(objective, derivative, visited):
    """Build a figure showing f(x), every visited point and the latest tangent line.

    Args:
        objective: Callable f(x) accepting scalars and NumPy arrays.
        derivative: Callable f'(x) with the same calling convention.
        visited: Sequence of x positions, oldest first (at least one element).

    Returns:
        The matplotlib ``Figure``. The caller is responsible for closing it.
    """
    points = np.asarray(visited, dtype=float)
    points = points[np.isfinite(points)]
    # Default window is [-10, 10]. Widen it so iterates that overshoot stay visible.
    lo = float(np.min(np.append(points, -10.0)))
    hi = float(np.max(np.append(points, 10.0)))
    x_range = np.linspace(lo, hi, 500)
    y_range = _evaluate(objective, x_range)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x_range, y_range, label="Function: f(x)", color='#0077b6', linewidth=2)
    ax.scatter(points, _evaluate(objective, points), color="#d90429", s=60, zorder=5, label="Descent Point")

    current_x = visited[-1]
    if np.isfinite(current_x):
        current_y = objective(current_x)
        slope = derivative(current_x)
        if np.isfinite(current_y) and np.isfinite(slope):
            tangent = slope * (x_range - current_x) + current_y
            ax.plot(x_range, tangent, color="#ffb703", linestyle="--", linewidth=1.5, alpha=0.8, label="Tangent Line")

    ax.set_xlim([lo, hi])
    # set_ylim rejects NaN/inf, which functions such as log(x) or 1/x produce on
    # part of the range, so only finite samples set the vertical extent.
    finite_y = y_range[np.isfinite(y_range)]
    if finite_y.size:
        ax.set_ylim([finite_y.min() - 1, finite_y.max() + 1])
    ax.set_xlabel("x", fontsize=12, labelpad=10)
    ax.set_ylabel("f(x)", fontsize=12, labelpad=10)
    ax.set_title("Gradient Descent with Tangent Lines", fontsize=14, fontweight="bold", color="#4a4e69")
    ax.legend(loc="upper right", frameon=True, fontsize=10)
    ax.grid(alpha=0.3)
    return fig


func, gradient, gradient_str = parse_function(user_function)
if func is not None and gradient is not None:
    # (Re)initialise on the first run and whenever the function text changes, so
    # the stored losses always belong to the function being plotted.
    if st.session_state.get("function_text") != user_function:
        setup_gradient_descent()
        st.session_state.function_text = user_function

    x_values = st.session_state.x_values
    current_x = x_values[-1]
    if np.isfinite(current_x) and abs(gradient(current_x)) < GRADIENT_TOLERANCE:
        st.session_state.final_minimum = True

    col1, col2 = st.columns([2, 1])
    with col1:
        st.subheader(f"🔍 Gradient Descent Visualization for: $f(x) = {user_function}$")
        st.markdown(f"**Gradient (f'(x)):** $f'(x) = {gradient_str}$")
        fig = _plot_descent(func, gradient, x_values)
        st.pyplot(fig)
        # pyplot keeps every figure it creates alive until closed; without this,
        # each rerun would leave one more figure behind.
        plt.close(fig)

    with col2:
        st.subheader("📊 Progress")
        for i, (x, loss) in enumerate(zip(x_values, st.session_state.loss_values)):
            st.write(f"Iteration {i}: x = {x:.4f}, f(x) = {loss:.4f}")
        # Rendered on every rerun while converged, so the result stays visible.
        if st.session_state.final_minimum:
            st.success(f"🎯 Final Minimum Reached: $x = {current_x:.4f}, f(x) = {st.session_state.loss_values[-1]:.4f}$")
        elif not np.isfinite(current_x):
            st.warning("The iterates diverged to a non-finite value. Lower the learning rate and press Setup.")

    def next_iteration():
        """Take one gradient-descent step (the Next Iteration button's callback).

        Uses the module-level ``func``, ``gradient`` and ``learning_rate``.
        Does nothing once converged, or once the latest iterate is inf/NaN,
        because stepping from a non-finite value only appends more NaNs.
        """
        last_x = st.session_state.x_values[-1]
        if st.session_state.final_minimum or not np.isfinite(last_x):
            return
        new_x, new_loss, grad = gradient_descent_step(last_x, learning_rate, func, gradient)
        st.session_state.x_values.append(new_x)
        st.session_state.loss_values.append(new_loss)
        st.session_state.gradients.append(grad)
        st.session_state.iteration += 1

    st.sidebar.button("⏭ Next Iteration", on_click=next_iteration)
    if st.session_state.final_minimum:
        st.info("The gradient is close to zero. Further iterations may not improve the result.")