shubham680 commited on
Commit
f3f4d32
·
verified ·
1 Parent(s): 5c16461

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -66
app.py CHANGED
@@ -1,94 +1,189 @@
1
- import streamlit as st
2
- import numpy as np
3
- import sympy as sp
4
- import plotly.graph_objs as go
5
 
6
 
7
 
8
- st.title("Gradient Descent Visualizer")
9
 
10
- x = sp.Symbol("x")
11
 
12
- func_input = st.text_input("Enter Function","x^2")
13
 
14
- start_point = float(st.text_input("Starting Point", "2"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- #setup = st.button("Set Up")
17
 
18
- # if setup:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  # try:
20
- # expr = func.replace("^", "**")
21
- # expr_final = sp.sympify(expr)
22
- # func = sp.lambdify(x, expr_final, "numpy")
23
- # grad = sp.diff(expr_final,x)
24
- # gradient_func = sp.lambdify(x,grad,"numpy")
 
 
 
 
 
 
 
 
25
  # except Exception as e:
26
- # st.error(f"Error parsing function: {e}")
27
- expr = func_input.replace("^", "**")
28
- expr_final = sp.sympify(expr)
29
- func = sp.lambdify(x, expr_final, "numpy")
30
- grad = sp.diff(expr_final,x)
31
- gradient_func = sp.lambdify(x,grad,"numpy")
32
 
33
- learning_rate = float(st.text_input("Learning Rate","0.01"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
 
 
 
 
35
 
36
- # Intializing intial session state
37
- if 'points' not in st.session_state:
38
- st.session_state.points = [start_point]
39
 
40
- if 'step' not in st.session_state:
41
- st.session_state.step=0
42
 
 
 
 
 
43
 
44
- if st.button("Next Iteration"):
45
- lr = learning_rate
46
- x_old = float(st.session_state.points[-1])
47
- grad_val = gradient_func(x_old)
48
- x_new = x_old - lr * grad_val
49
- st.session_state.points.append(x_new)
50
- st.session_state.step +=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
  # Creating the plot
53
- if len(st.session_state.points) > 0:
54
  try:
55
- x_val = np.linspace(-6,6,500)
 
56
  y_val = func(x_val)
 
 
57
  iter_points = np.array(st.session_state.points)
58
  iter_y = func(iter_points)
59
 
60
- trace1 = go.Scatter(x=x_val,y=y_val,mode="lines",name="Function", line=dict(color="blue"))
61
- trace2 = go.Scatter(x=iter_points,y=iter_y,mode="markers+lines",name="Gradient Descent Path",marker=dict(color="red"))
62
- trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',marker=dict(color='green', size=15),text=[f"{iter_points[-1]:.6f}"],textposition="top center",name="Current Point")
63
- layout = go.Layout(title=f"Iteration {st.session_state.step}",xaxis=dict(title="x - axis"),yaxis=dict(title="y - axis"))
64
- fig = go.Figure(data=[trace1,trace2,trace3],layout=layout)
65
- st.plotly_chart(fig,use_container_width=True)
 
 
 
 
 
 
 
 
 
 
 
 
66
  st.success(f"Current Point = {iter_points[-1]}")
67
 
68
  except Exception as e:
69
  st.error(f"Plot error: {e}")
70
 
71
- # if len(st.session_state.points) > 0:
72
- # try:
73
- # x_vals = np.linspace(-10, 10, 300)
74
- # y_vals = func(x_vals)
75
- # iter_points = np.array(st.session_state.points)
76
- # iter_y = func(iter_points)
77
-
78
- # trace1 = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function', line=dict(color='blue'))
79
- # trace2 = go.Scatter(x=iter_points, y=iter_y, mode='markers+lines', name='Gradient Descent Path', marker=dict(color='red'))
80
- # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
81
- # marker=dict(color='green', size=10),
82
- # text=[f"{iter_points[-1]:.6f}"],
83
- # textposition="top center",
84
- # name="Current Point")
85
-
86
- # layout = go.Layout(title=f"Iteration {st.session_state.step}",
87
- # xaxis=dict(title="x - axis"),
88
- # yaxis=dict(title="y - axis"))
89
-
90
- # fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
91
- # st.plotly_chart(fig, use_container_width=True)
92
- # st.success(f"Current Point = {iter_points[-1]}")
93
- # except Exception as e:
94
- # st.error(f"Plot error: {e}")
 
1
+ # import streamlit as st
2
+ # import numpy as np
3
+ # import sympy as sp
4
+ # import plotly.graph_objs as go
5
 
6
 
7
 
8
+ # st.title("Gradient Descent Visualizer")
9
 
10
+ # x = sp.Symbol("x")
11
 
12
+ # func_input = st.text_input("Enter Function","x^2")
13
 
14
+ # start_point = float(st.text_input("Starting Point", "2"))
15
+
16
+ # #setup = st.button("Set Up")
17
+
18
+ # # if setup:
19
+ # # try:
20
+ # # expr = func.replace("^", "**")
21
+ # # expr_final = sp.sympify(expr)
22
+ # # func = sp.lambdify(x, expr_final, "numpy")
23
+ # # grad = sp.diff(expr_final,x)
24
+ # # gradient_func = sp.lambdify(x,grad,"numpy")
25
+ # # except Exception as e:
26
+ # # st.error(f"Error parsing function: {e}")
27
+ # expr = func_input.replace("^", "**")
28
+ # expr_final = sp.sympify(expr)
29
+ # func = sp.lambdify(x, expr_final, "numpy")
30
+ # grad = sp.diff(expr_final,x)
31
+ # gradient_func = sp.lambdify(x,grad,"numpy")
32
+
33
+ # learning_rate = float(st.text_input("Learning Rate","0.01"))
34
 
 
35
 
36
+ # # Intializing intial session state
37
+ # if 'points' not in st.session_state:
38
+ # st.session_state.points = [start_point]
39
+
40
+ # if 'step' not in st.session_state:
41
+ # st.session_state.step=0
42
+
43
+
44
+ # if st.button("Next Iteration"):
45
+ # lr = learning_rate
46
+ # x_old = float(st.session_state.points[-1])
47
+ # grad_val = gradient_func(x_old)
48
+ # x_new = x_old - lr * grad_val
49
+ # st.session_state.points.append(x_new)
50
+ # st.session_state.step +=1
51
+
52
+ # # Creating the plot
53
+ # if len(st.session_state.points) > 0:
54
  # try:
55
+ # x_val = np.linspace(-6,6,500)
56
+ # y_val = func(x_val)
57
+ # iter_points = np.array(st.session_state.points)
58
+ # iter_y = func(iter_points)
59
+
60
+ # trace1 = go.Scatter(x=x_val,y=y_val,mode="lines",name="Function", line=dict(color="blue"))
61
+ # trace2 = go.Scatter(x=iter_points,y=iter_y,mode="markers+lines",name="Gradient Descent Path",marker=dict(color="red"))
62
+ # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',marker=dict(color='green', size=15),text=[f"{iter_points[-1]:.6f}"],textposition="top center",name="Current Point")
63
+ # layout = go.Layout(title=f"Iteration {st.session_state.step}",xaxis=dict(title="x - axis"),yaxis=dict(title="y - axis"))
64
+ # fig = go.Figure(data=[trace1,trace2,trace3],layout=layout)
65
+ # st.plotly_chart(fig,use_container_width=True)
66
+ # st.success(f"Current Point = {iter_points[-1]}")
67
+
68
  # except Exception as e:
69
+ # st.error(f"Plot error: {e}")
 
 
 
 
 
70
 
71
+ # # if len(st.session_state.points) > 0:
72
+ # # try:
73
+ # # x_vals = np.linspace(-10, 10, 300)
74
+ # # y_vals = func(x_vals)
75
+ # # iter_points = np.array(st.session_state.points)
76
+ # # iter_y = func(iter_points)
77
+
78
+ # # trace1 = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function', line=dict(color='blue'))
79
+ # # trace2 = go.Scatter(x=iter_points, y=iter_y, mode='markers+lines', name='Gradient Descent Path', marker=dict(color='red'))
80
+ # # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
81
+ # # marker=dict(color='green', size=10),
82
+ # # text=[f"{iter_points[-1]:.6f}"],
83
+ # # textposition="top center",
84
+ # # name="Current Point")
85
+
86
+ # # layout = go.Layout(title=f"Iteration {st.session_state.step}",
87
+ # # xaxis=dict(title="x - axis"),
88
+ # # yaxis=dict(title="y - axis"))
89
+
90
+ # # fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
91
+ # # st.plotly_chart(fig, use_container_width=True)
92
+ # # st.success(f"Current Point = {iter_points[-1]}")
93
+ # # except Exception as e:
94
+ # # st.error(f"Plot error: {e}")
95
 
96
+ import streamlit as st
97
+ import numpy as np
98
+ import sympy as sp
99
+ import plotly.graph_objs as go
100
 
101
+ st.title("Gradient Descent Visualizer")
 
 
102
 
103
+ # Define the symbol 'x' for the function
104
+ x = sp.Symbol("x")
105
 
106
+ # User input for function, starting point, and learning rate
107
+ func_input = st.text_input("Enter Function", "x^2")
108
+ start_point = float(st.text_input("Starting Point", "2"))
109
+ learning_rate = float(st.text_input("Learning Rate", "0.01"))
110
 
111
+ # Set up button to initialize the function and gradient
112
+ if st.button("Set Up"):
113
+ try:
114
+ # Replace ^ with ** for exponentiation
115
+ expr = func_input.replace("^", "**")
116
+
117
+ # Parse the function
118
+ expr_final = sp.sympify(expr)
119
+
120
+ # Lambdify to create callable functions
121
+ func = sp.lambdify(x, expr_final, "numpy")
122
+
123
+ # Calculate the gradient
124
+ grad = sp.diff(expr_final, x)
125
+ gradient_func = sp.lambdify(x, grad, "numpy")
126
+
127
+ # Initialize session state for points and steps
128
+ if 'points' not in st.session_state:
129
+ st.session_state.points = [start_point]
130
+ if 'step' not in st.session_state:
131
+ st.session_state.step = 0
132
+
133
+ st.success("Function and Gradient Set Up Successfully!")
134
+
135
+ except Exception as e:
136
+ st.error(f"Error setting up function: {e}")
137
+
138
+ # Gradient Descent Iteration button
139
+ if 'points' in st.session_state and 'step' in st.session_state:
140
+ if st.button("Next Iteration"):
141
+ try:
142
+ # Get the current point and gradient value
143
+ x_old = float(st.session_state.points[-1])
144
+ grad_val = gradient_func(x_old)
145
+ x_new = x_old - learning_rate * grad_val
146
+
147
+ # Append the new point to the list of points
148
+ st.session_state.points.append(x_new)
149
+ st.session_state.step += 1
150
+
151
+ st.success(f"Iteration {st.session_state.step} Complete!")
152
+
153
+ except Exception as e:
154
+ st.error(f"Error in iteration: {e}")
155
 
156
  # Creating the plot
157
+ if 'points' in st.session_state and len(st.session_state.points) > 0:
158
  try:
159
+ # Create x-values for plotting the function
160
+ x_val = np.linspace(-6, 6, 500)
161
  y_val = func(x_val)
162
+
163
+ # Get the points visited by gradient descent
164
  iter_points = np.array(st.session_state.points)
165
  iter_y = func(iter_points)
166
 
167
+ # Plot the function and the gradient descent path
168
+ trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
169
+ trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
170
+ trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
171
+ marker=dict(color='green', size=15),
172
+ text=[f"{iter_points[-1]:.6f}"], textposition="top center",
173
+ name="Current Point")
174
+
175
+ # Layout for the plot
176
+ layout = go.Layout(
177
+ title=f"Iteration {st.session_state.step}",
178
+ xaxis=dict(title="x - axis"),
179
+ yaxis=dict(title="y - axis")
180
+ )
181
+
182
+ # Create the figure and display it
183
+ fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
184
+ st.plotly_chart(fig, use_container_width=True)
185
  st.success(f"Current Point = {iter_points[-1]}")
186
 
187
  except Exception as e:
188
  st.error(f"Plot error: {e}")
189