Nikhithapotnuru commited on
Commit
af00ec9
·
verified ·
1 Parent(s): 0562900

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -2
app.py CHANGED
@@ -1,4 +1,62 @@
1
  import streamlit as st
 
 
 
2
 
3
- st.title("👋 Hey There")
4
- st.header("Let's play with gradient descent.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import sympy as sp
5
 
6
+ # --- Streamlit App ---
7
+ st.title("🔽 Gradient Descent Visualizer")
8
+
9
+ # Function input
10
+ st.subheader("Define Function")
11
+ func_str = st.text_input("Enter a function in terms of x:", "x**2 + x")
12
+
13
+ # Sympy setup
14
+ x = sp.symbols("x")
15
+ try:
16
+ func = sp.sympify(func_str)
17
+ derivative = sp.diff(func, x)
18
+ except Exception as e:
19
+ st.error(f"Invalid function: {e}")
20
+ st.stop()
21
+
22
+ # Parameters
23
+ st.subheader("Parameters")
24
+ start_point = st.number_input("Starting Point", value=5.0)
25
+ learning_rate = st.number_input("Learning Rate", value=0.25, step=0.01)
26
+ iterations = st.slider("Number of Iterations", 1, 50, 16)
27
+
28
+ # Convert sympy to numpy function
29
+ f_np = sp.lambdify(x, func, "numpy")
30
+ fprime_np = sp.lambdify(x, derivative, "numpy")
31
+
32
+ # Gradient descent iterations
33
+ points = [start_point]
34
+ for i in range(iterations):
35
+ grad = fprime_np(points[-1])
36
+ new_point = points[-1] - learning_rate * grad
37
+ points.append(new_point)
38
+
39
+ # Final point
40
+ current_point = points[-1]
41
+
42
+ # Plot function and descent path
43
+ st.subheader(f"Iteration {iterations}")
44
+ x_vals = np.linspace(-6, 6, 400)
45
+ y_vals = f_np(x_vals)
46
+
47
+ fig, ax = plt.subplots()
48
+ ax.plot(x_vals, y_vals, label=str(func))
49
+ ax.axhline(0, color="brown", linewidth=1)
50
+ ax.axvline(0, color="gray", linewidth=1)
51
+
52
+ # Plot descent points
53
+ y_points = f_np(np.array(points))
54
+ ax.plot(points, y_points, "ro-")
55
+
56
+ ax.set_xlabel("x - axis")
57
+ ax.set_ylabel("y - axis")
58
+ ax.legend()
59
+ st.pyplot(fig)
60
+
61
+ # Show current point
62
+ st.success(f"📍 Current Point: {current_point}")