Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import streamlit as st
|
| 4 |
+
import sympy as sp
|
| 5 |
+
|
| 6 |
+
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide", page_icon="π")
|
| 7 |
+
st.markdown(
|
| 8 |
+
"""
|
| 9 |
+
<style>
|
| 10 |
+
.css-18e3th9 { padding-top: 1rem; padding-bottom: 1rem; }
|
| 11 |
+
.stButton>button { border-radius: 8px; font-weight: bold; }
|
| 12 |
+
.stNumberInput>div>div>input { border-radius: 8px; }
|
| 13 |
+
.stTextInput>div>div>input { border-radius: 8px; }
|
| 14 |
+
.st-sidebar { background-color: #f9f9f9; border-right: 2px solid #eee; }
|
| 15 |
+
h1, h2, h3 { color: #4a4e69; font-family: 'Arial'; }
|
| 16 |
+
</style>
|
| 17 |
+
""",
|
| 18 |
+
unsafe_allow_html=True,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def parse_function(user_input):
|
| 22 |
+
x = sp.symbols('x')
|
| 23 |
+
try:
|
| 24 |
+
func_expr = sp.sympify(user_input)
|
| 25 |
+
func = sp.lambdify(x, func_expr, 'numpy')
|
| 26 |
+
gradient_expr = sp.diff(func_expr, x)
|
| 27 |
+
gradient = sp.lambdify(x, gradient_expr, 'numpy')
|
| 28 |
+
return func, gradient, str(gradient_expr)
|
| 29 |
+
except Exception as e:
|
| 30 |
+
st.error(f"Invalid function input: {e}")
|
| 31 |
+
return None, None, None
|
| 32 |
+
|
| 33 |
+
def gradient_descent_step(x, learning_rate, func, gradient):
|
| 34 |
+
grad = gradient(x)
|
| 35 |
+
x_new = x - learning_rate * grad
|
| 36 |
+
loss_new = func(x_new)
|
| 37 |
+
return x_new, loss_new, grad
|
| 38 |
+
|
| 39 |
+
st.title("π Interactive Gradient Descent Visualizer")
|
| 40 |
+
with st.sidebar:
|
| 41 |
+
st.header("π§ Parameters")
|
| 42 |
+
user_function = st.text_input("Enter a Function (e.g., 2*x**2 + 3*x)", value="x**2")
|
| 43 |
+
start_point = st.number_input("Starting Point", value=5.0, step=0.00001, format="%.5f")
|
| 44 |
+
learning_rate = st.number_input("Learning Rate", value=0.1, step=0.00001, format="%.5f")
|
| 45 |
+
|
| 46 |
+
def setup_gradient_descent():
|
| 47 |
+
st.session_state.iteration = 0
|
| 48 |
+
st.session_state.final_minimum = False
|
| 49 |
+
st.session_state.x_values = [start_point]
|
| 50 |
+
st.session_state.loss_values = [func(start_point)]
|
| 51 |
+
st.session_state.gradients = []
|
| 52 |
+
|
| 53 |
+
st.button("π Setup", on_click=setup_gradient_descent)
|
| 54 |
+
|
| 55 |
+
func, gradient, gradient_str = parse_function(user_function)
|
| 56 |
+
|
| 57 |
+
if func and gradient:
|
| 58 |
+
if 'iteration' not in st.session_state:
|
| 59 |
+
st.session_state.iteration = 0
|
| 60 |
+
st.session_state.final_minimum = False
|
| 61 |
+
st.session_state.x_values = [start_point]
|
| 62 |
+
st.session_state.loss_values = [func(start_point)]
|
| 63 |
+
st.session_state.gradients = []
|
| 64 |
+
|
| 65 |
+
col1, col2 = st.columns([2, 1])
|
| 66 |
+
|
| 67 |
+
with col1:
|
| 68 |
+
st.subheader(f"π Gradient Descent Visualization for: $f(x) = {user_function}$")
|
| 69 |
+
st.markdown(f"**Gradient (f'(x)):** $f'(x) = {gradient_str}$")
|
| 70 |
+
|
| 71 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 72 |
+
x_range = np.linspace(-10, 10, 500)
|
| 73 |
+
y_range = func(x_range)
|
| 74 |
+
|
| 75 |
+
ax.plot(x_range, y_range, label="Function: f(x)", color='#0077b6', linewidth=2)
|
| 76 |
+
|
| 77 |
+
for i in range(st.session_state.iteration + 1):
|
| 78 |
+
x = st.session_state.x_values[i]
|
| 79 |
+
y = func(x)
|
| 80 |
+
ax.scatter(x, y, color="#d90429", s=60, zorder=5, label="Descent Point" if i == 0 else "")
|
| 81 |
+
|
| 82 |
+
if st.session_state.iteration < len(st.session_state.x_values):
|
| 83 |
+
x = st.session_state.x_values[st.session_state.iteration]
|
| 84 |
+
y = func(x)
|
| 85 |
+
grad = gradient(x)
|
| 86 |
+
tangent = grad * (x_range - x) + y
|
| 87 |
+
ax.plot(x_range, tangent, color="#ffb703", linestyle="--", linewidth=1.5, alpha=0.8, label="Tangent Line")
|
| 88 |
+
|
| 89 |
+
ax.set_xlim([-10, 10])
|
| 90 |
+
ax.set_ylim([min(func(x_range)) - 1, max(func(x_range)) + 1])
|
| 91 |
+
ax.set_xlabel("x", fontsize=12, labelpad=10)
|
| 92 |
+
ax.set_ylabel("f(x)", fontsize=12, labelpad=10)
|
| 93 |
+
ax.set_title("Gradient Descent with Tangent Lines", fontsize=14, fontweight="bold", color="#4a4e69")
|
| 94 |
+
ax.legend(loc="upper right", frameon=True, fontsize=10)
|
| 95 |
+
ax.grid(alpha=0.3)
|
| 96 |
+
st.pyplot(fig)
|
| 97 |
+
|
| 98 |
+
with col2:
|
| 99 |
+
st.subheader("π Progress")
|
| 100 |
+
for i, (x, loss) in enumerate(zip(st.session_state.x_values, st.session_state.loss_values)):
|
| 101 |
+
st.write(f"Iteration {i}: x = {x:.4f}, f(x) = {loss:.4f}")
|
| 102 |
+
|
| 103 |
+
current_gradient = gradient(st.session_state.x_values[-1])
|
| 104 |
+
if abs(current_gradient) < 1e-6 and not st.session_state.final_minimum:
|
| 105 |
+
st.success(f"π― Final Minimum Reached: $x = {st.session_state.x_values[-1]:.4f}, f(x) = {st.session_state.loss_values[-1]:.4f}$")
|
| 106 |
+
st.session_state.final_minimum = True
|
| 107 |
+
|
| 108 |
+
def next_iteration():
|
| 109 |
+
if not st.session_state.final_minimum:
|
| 110 |
+
new_x, new_loss, grad = gradient_descent_step(st.session_state.x_values[-1], learning_rate, func, gradient)
|
| 111 |
+
st.session_state.x_values.append(new_x)
|
| 112 |
+
st.session_state.loss_values.append(new_loss)
|
| 113 |
+
st.session_state.gradients.append(grad)
|
| 114 |
+
st.session_state.iteration += 1
|
| 115 |
+
|
| 116 |
+
st.sidebar.button("β Next Iteration", on_click=next_iteration)
|
| 117 |
+
|
| 118 |
+
if st.session_state.final_minimum:
|
| 119 |
+
st.info("The gradient is close to zero. Further iterations may not improve the result.")
|