shubham680 commited on
Commit
f9eefe7
·
verified ·
1 Parent(s): 8b8fb42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -112
app.py CHANGED
@@ -1,98 +1,3 @@
1
- # import streamlit as st
2
- # import numpy as np
3
- # import sympy as sp
4
- # import plotly.graph_objs as go
5
-
6
-
7
-
8
- # st.title("Gradient Descent Visualizer")
9
-
10
- # x = sp.Symbol("x")
11
-
12
- # func_input = st.text_input("Enter Function","x^2")
13
-
14
- # start_point = float(st.text_input("Starting Point", "2"))
15
-
16
- # #setup = st.button("Set Up")
17
-
18
- # # if setup:
19
- # # try:
20
- # # expr = func.replace("^", "**")
21
- # # expr_final = sp.sympify(expr)
22
- # # func = sp.lambdify(x, expr_final, "numpy")
23
- # # grad = sp.diff(expr_final,x)
24
- # # gradient_func = sp.lambdify(x,grad,"numpy")
25
- # # except Exception as e:
26
- # # st.error(f"Error parsing function: {e}")
27
- # expr = func_input.replace("^", "**")
28
- # expr_final = sp.sympify(expr)
29
- # func = sp.lambdify(x, expr_final, "numpy")
30
- # grad = sp.diff(expr_final,x)
31
- # gradient_func = sp.lambdify(x,grad,"numpy")
32
-
33
- # learning_rate = float(st.text_input("Learning Rate","0.01"))
34
-
35
-
36
- # # Intializing intial session state
37
- # if 'points' not in st.session_state:
38
- # st.session_state.points = [start_point]
39
-
40
- # if 'step' not in st.session_state:
41
- # st.session_state.step=0
42
-
43
-
44
- # if st.button("Next Iteration"):
45
- # lr = learning_rate
46
- # x_old = float(st.session_state.points[-1])
47
- # grad_val = gradient_func(x_old)
48
- # x_new = x_old - lr * grad_val
49
- # st.session_state.points.append(x_new)
50
- # st.session_state.step +=1
51
-
52
- # # Creating the plot
53
- # if len(st.session_state.points) > 0:
54
- # try:
55
- # x_val = np.linspace(-6,6,500)
56
- # y_val = func(x_val)
57
- # iter_points = np.array(st.session_state.points)
58
- # iter_y = func(iter_points)
59
-
60
- # trace1 = go.Scatter(x=x_val,y=y_val,mode="lines",name="Function", line=dict(color="blue"))
61
- # trace2 = go.Scatter(x=iter_points,y=iter_y,mode="markers+lines",name="Gradient Descent Path",marker=dict(color="red"))
62
- # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',marker=dict(color='green', size=15),text=[f"{iter_points[-1]:.6f}"],textposition="top center",name="Current Point")
63
- # layout = go.Layout(title=f"Iteration {st.session_state.step}",xaxis=dict(title="x - axis"),yaxis=dict(title="y - axis"))
64
- # fig = go.Figure(data=[trace1,trace2,trace3],layout=layout)
65
- # st.plotly_chart(fig,use_container_width=True)
66
- # st.success(f"Current Point = {iter_points[-1]}")
67
-
68
- # except Exception as e:
69
- # st.error(f"Plot error: {e}")
70
-
71
- # # if len(st.session_state.points) > 0:
72
- # # try:
73
- # # x_vals = np.linspace(-10, 10, 300)
74
- # # y_vals = func(x_vals)
75
- # # iter_points = np.array(st.session_state.points)
76
- # # iter_y = func(iter_points)
77
-
78
- # # trace1 = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function', line=dict(color='blue'))
79
- # # trace2 = go.Scatter(x=iter_points, y=iter_y, mode='markers+lines', name='Gradient Descent Path', marker=dict(color='red'))
80
- # # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
81
- # # marker=dict(color='green', size=10),
82
- # # text=[f"{iter_points[-1]:.6f}"],
83
- # # textposition="top center",
84
- # # name="Current Point")
85
-
86
- # # layout = go.Layout(title=f"Iteration {st.session_state.step}",
87
- # # xaxis=dict(title="x - axis"),
88
- # # yaxis=dict(title="y - axis"))
89
-
90
- # # fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
91
- # # st.plotly_chart(fig, use_container_width=True)
92
- # # st.success(f"Current Point = {iter_points[-1]}")
93
- # # except Exception as e:
94
- # # st.error(f"Plot error: {e}")
95
-
96
  import streamlit as st
97
  import numpy as np
98
  import sympy as sp
@@ -100,38 +5,26 @@ import plotly.graph_objs as go
100
 
101
  st.title("Gradient Descent Visualizer")
102
 
103
- # Define the symbol 'x' for the function
104
  x = sp.Symbol("x")
105
 
106
- # User input for function, starting point, and learning rate
107
  func_input = st.text_input("Enter Function", "x^2")
108
  start_point = float(st.text_input("Starting Point", "2"))
109
  learning_rate = float(st.text_input("Learning Rate", "0.01"))
110
 
111
- # Set up button to initialize the function and gradient
112
  if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
113
  try:
114
- # Replace ^ with ** for exponentiation
115
  expr = func_input.replace("^", "**")
116
-
117
- # Parse the function
118
  expr_final = sp.sympify(expr)
119
-
120
- # Lambdify to create callable functions
121
  func = sp.lambdify(x, expr_final, "numpy")
122
-
123
- # Calculate the gradient
124
  grad = sp.diff(expr_final, x)
125
  gradient_func = sp.lambdify(x, grad, "numpy")
126
 
127
- # Store the function and gradient in session state
128
  st.session_state.func = func
129
  st.session_state.gradient_func = gradient_func
130
-
131
- # Initialize session state for points and steps (resetting points on setup)
132
  st.session_state.points = [start_point]
133
  st.session_state.step = 0
134
-
135
  st.success("Function and Gradient Set Up Successfully!")
136
 
137
  except Exception as e:
@@ -159,14 +52,12 @@ if 'func' in st.session_state and 'gradient_func' in st.session_state:
159
  if 'func' in st.session_state and len(st.session_state.points) > 0:
160
  try:
161
  # Create x-values for plotting the function
162
- x_val = np.linspace(-6, 6, 500)
163
  y_val = st.session_state.func(x_val)
164
 
165
- # Get the points visited by gradient descent
166
  iter_points = np.array(st.session_state.points)
167
  iter_y = st.session_state.func(iter_points)
168
 
169
- # Plot the function and the gradient descent path
170
  trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
171
  trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
172
  trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import numpy as np
3
  import sympy as sp
 
5
 
6
  st.title("Gradient Descent Visualizer")
7
 
8
+
9
  x = sp.Symbol("x")
10
 
11
+ # User input function, starting point, and learning rate
12
  func_input = st.text_input("Enter Function", "x^2")
13
  start_point = float(st.text_input("Starting Point", "2"))
14
  learning_rate = float(st.text_input("Learning Rate", "0.01"))
15
 
 
16
  if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
17
  try:
 
18
  expr = func_input.replace("^", "**")
 
 
19
  expr_final = sp.sympify(expr)
 
 
20
  func = sp.lambdify(x, expr_final, "numpy")
 
 
21
  grad = sp.diff(expr_final, x)
22
  gradient_func = sp.lambdify(x, grad, "numpy")
23
 
 
24
  st.session_state.func = func
25
  st.session_state.gradient_func = gradient_func
 
 
26
  st.session_state.points = [start_point]
27
  st.session_state.step = 0
 
28
  st.success("Function and Gradient Set Up Successfully!")
29
 
30
  except Exception as e:
 
52
  if 'func' in st.session_state and len(st.session_state.points) > 0:
53
  try:
54
  # Create x-values for plotting the function
55
+ x_val = np.linspace(-10, 10, 500)
56
  y_val = st.session_state.func(x_val)
57
 
 
58
  iter_points = np.array(st.session_state.points)
59
  iter_y = st.session_state.func(iter_points)
60
 
 
61
  trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
62
  trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
63
  trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',