Adityaganesh committed on
Commit
06bfef5
·
verified ·
1 Parent(s): fe51709

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ import streamlit as st
4
+ import sympy as sp
5
+
6
# Basic page chrome; set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide", page_icon="πŸ“‰")

# Inject custom CSS: rounded corners on buttons/inputs and restyled headings.
# NOTE(review): ".css-18e3th9" is an auto-generated Streamlit class name and is
# not stable across Streamlit versions — verify it still matches the padding target.
st.markdown(
    """
    <style>
    .css-18e3th9 { padding-top: 1rem; padding-bottom: 1rem; }
    .stButton>button { border-radius: 8px; font-weight: bold; }
    .stNumberInput>div>div>input { border-radius: 8px; }
    .stTextInput>div>div>input { border-radius: 8px; }
    .st-sidebar { background-color: #f9f9f9; border-right: 2px solid #eee; }
    h1, h2, h3 { color: #4a4e69; font-family: 'Arial'; }
    </style>
    """,
    unsafe_allow_html=True,
)
20
+
21
def parse_function(user_input):
    """Parse a user-supplied expression in ``x`` into numeric callables.

    Parameters
    ----------
    user_input : str
        An expression such as ``"2*x**2 + 3*x"`` using the variable ``x``.

    Returns
    -------
    tuple
        ``(func, gradient, gradient_str)`` where ``func`` and ``gradient`` are
        NumPy-backed callables of one variable and ``gradient_str`` is the
        symbolic derivative rendered as text.  On any parse or validation
        failure an ``st.error`` is shown and ``(None, None, None)`` is
        returned instead of raising.
    """
    x = sp.symbols('x')
    try:
        # NOTE: sympify eval's its input; acceptable for a local demo, but do
        # not expose this unrestricted to untrusted users.
        func_expr = sp.sympify(user_input)
        # Reject expressions with variables other than x: lambdify would
        # happily compile them and then raise NameError at call time, OUTSIDE
        # this try block, crashing the app instead of showing an error.
        extra = func_expr.free_symbols - {x}
        if extra:
            raise ValueError(f"unknown symbol(s): {', '.join(sorted(map(str, extra)))}")
        func = sp.lambdify(x, func_expr, 'numpy')
        gradient_expr = sp.diff(func_expr, x)
        gradient = sp.lambdify(x, gradient_expr, 'numpy')
        # NOTE(review): for a constant expression (e.g. "5") lambdify returns
        # a scalar even for array input, which the plotting code downstream
        # does not handle — confirm whether constants need broadcasting.
        return func, gradient, str(gradient_expr)
    except Exception as e:
        st.error(f"Invalid function input: {e}")
        return None, None, None
32
+
33
def gradient_descent_step(x, learning_rate, func, gradient):
    """Perform one gradient-descent update from the point ``x``.

    Evaluates the derivative at ``x``, moves against it scaled by
    ``learning_rate``, and reports the new point, the loss there, and the
    gradient that was used.

    Returns a ``(x_new, loss_new, grad)`` tuple.
    """
    slope = gradient(x)
    next_x = x - learning_rate * slope
    return next_x, func(next_x), slope
38
+
39
st.title("πŸ“‰ Interactive Gradient Descent Visualizer")

# Sidebar inputs; these module-level names are read by the setup callback and
# the main render section below.
# NOTE(review): editing these widgets does NOT reset the recorded descent
# history — the user must press the Setup button to start over.
with st.sidebar:
    st.header("πŸ”§ Parameters")
    # Expression in the variable "x", parsed later by parse_function().
    user_function = st.text_input("Enter a Function (e.g., 2*x**2 + 3*x)", value="x**2")
    start_point = st.number_input("Starting Point", value=5.0, step=0.00001, format="%.5f")
    learning_rate = st.number_input("Learning Rate", value=0.1, step=0.00001, format="%.5f")
45
+
46
def setup_gradient_descent():
    """on_click callback for the Setup button: (re)initialise the descent.

    Resets the iteration counter, the convergence flag, and the recorded
    x / loss / gradient histories in ``st.session_state``, restarting from the
    sidebar's current starting point.
    """
    # parse_function() returns (None, None, None) on bad input; without this
    # guard, pressing Setup after entering an invalid expression raises
    # TypeError ("'NoneType' object is not callable") inside the callback.
    if func is None:
        return
    st.session_state.iteration = 0
    st.session_state.final_minimum = False
    st.session_state.x_values = [start_point]
    st.session_state.loss_values = [func(start_point)]
    st.session_state.gradients = []

st.button("πŸš€ Setup", on_click=setup_gradient_descent)
54
+
55
# Compile the user's expression; all three are None if parsing failed.
func, gradient, gradient_str = parse_function(user_function)

if func and gradient:
    # First run of a session: seed the descent history in session_state so it
    # survives Streamlit's rerun-on-interaction model.
    # NOTE(review): this runs only once per session — changing the function,
    # starting point, or learning rate afterwards does NOT reset the history;
    # the user must press the Setup button.
    if 'iteration' not in st.session_state:
        st.session_state.iteration = 0
        st.session_state.final_minimum = False
        st.session_state.x_values = [start_point]
        st.session_state.loss_values = [func(start_point)]
        st.session_state.gradients = []

    # Wide plot column on the left, textual progress log on the right.
    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader(f"πŸ” Gradient Descent Visualization for: $f(x) = {user_function}$")
        st.markdown(f"**Gradient (f'(x)):** $f'(x) = {gradient_str}$")

        fig, ax = plt.subplots(figsize=(10, 6))
        x_range = np.linspace(-10, 10, 500)
        # NOTE(review): for a constant expression (e.g. "5") the lambdified
        # func returns a scalar here, and ax.plot / min / max below would
        # fail — confirm whether constant inputs need special-casing.
        y_range = func(x_range)

        ax.plot(x_range, y_range, label="Function: f(x)", color='#0077b6', linewidth=2)

        # Re-draw every visited iterate; only the first gets a legend entry.
        for i in range(st.session_state.iteration + 1):
            x = st.session_state.x_values[i]
            y = func(x)
            ax.scatter(x, y, color="#d90429", s=60, zorder=5, label="Descent Point" if i == 0 else "")

        # Tangent line at the most recent iterate.
        # NOTE(review): len(x_values) is always iteration + 1, so this guard
        # is always true — presumably intended as a safety check only.
        if st.session_state.iteration < len(st.session_state.x_values):
            x = st.session_state.x_values[st.session_state.iteration]
            y = func(x)
            grad = gradient(x)
            # Point-slope form: t(u) = grad * (u - x) + f(x).
            tangent = grad * (x_range - x) + y
            ax.plot(x_range, tangent, color="#ffb703", linestyle="--", linewidth=1.5, alpha=0.8, label="Tangent Line")

        ax.set_xlim([-10, 10])
        ax.set_ylim([min(func(x_range)) - 1, max(func(x_range)) + 1])
        ax.set_xlabel("x", fontsize=12, labelpad=10)
        ax.set_ylabel("f(x)", fontsize=12, labelpad=10)
        ax.set_title("Gradient Descent with Tangent Lines", fontsize=14, fontweight="bold", color="#4a4e69")
        ax.legend(loc="upper right", frameon=True, fontsize=10)
        ax.grid(alpha=0.3)
        st.pyplot(fig)

    with col2:
        st.subheader("πŸ“Š Progress")
        # Full history of iterates and their losses so far.
        for i, (x, loss) in enumerate(zip(st.session_state.x_values, st.session_state.loss_values)):
            st.write(f"Iteration {i}: x = {x:.4f}, f(x) = {loss:.4f}")

        # Convergence test: a (near-)zero derivative at the latest iterate.
        current_gradient = gradient(st.session_state.x_values[-1])
        if abs(current_gradient) < 1e-6 and not st.session_state.final_minimum:
            st.success(f"🎯 Final Minimum Reached: $x = {st.session_state.x_values[-1]:.4f}, f(x) = {st.session_state.loss_values[-1]:.4f}$")
            st.session_state.final_minimum = True

    # on_click callback: append exactly one descent step to the history
    # (no-op once convergence has been flagged).
    def next_iteration():
        if not st.session_state.final_minimum:
            new_x, new_loss, grad = gradient_descent_step(st.session_state.x_values[-1], learning_rate, func, gradient)
            st.session_state.x_values.append(new_x)
            st.session_state.loss_values.append(new_loss)
            st.session_state.gradients.append(grad)
            st.session_state.iteration += 1

    st.sidebar.button("⏭ Next Iteration", on_click=next_iteration)

    if st.session_state.final_minimum:
        st.info("The gradient is close to zero. Further iterations may not improve the result.")