shubham680 commited on
Commit
6883bcd
·
verified ·
1 Parent(s): 3123a33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +166 -31
app.py CHANGED
@@ -1,21 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import numpy as np
3
  import sympy as sp
4
  import plotly.graph_objs as go
5
 
6
- st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
-
8
-
9
- # Custom CSS for styling
10
  st.markdown("""
11
  <style>
12
- /* Page background */
13
  .stApp {
14
  background-color: #f9f9f9;
15
  font-family: 'Segoe UI', sans-serif;
16
  }
17
-
18
- /* Title */
19
  h1 {
20
  text-align: center;
21
  color: #2C3E50;
@@ -23,15 +152,11 @@ st.markdown("""
23
  font-weight: bold;
24
  margin-bottom: 20px;
25
  }
26
-
27
- /* Input boxes */
28
  .stTextInput > div > div > input {
29
  border: 2px solid #3498DB;
30
  border-radius: 8px;
31
  padding: 8px;
32
  }
33
-
34
- /* Buttons */
35
  div.stButton > button {
36
  background-color: #3498DB;
37
  color: white;
@@ -45,13 +170,9 @@ st.markdown("""
45
  background-color: #2980B9;
46
  transform: scale(1.05);
47
  }
48
-
49
- /* Success / Error messages */
50
  .stAlert {
51
  border-radius: 8px;
52
  }
53
-
54
- /* Chart section */
55
  .block-container {
56
  padding-top: 2rem;
57
  padding-bottom: 2rem;
@@ -59,14 +180,14 @@ st.markdown("""
59
  </style>
60
  """, unsafe_allow_html=True)
61
 
62
-
63
  st.title("Gradient Descent Visualizer")
64
 
65
  x = sp.Symbol("x")
66
- # User input function, starting point, and learning rate
67
  func_input = st.text_input("Enter Function", "x^2")
68
  start_point = float(st.text_input("Starting Point", "2"))
69
  learning_rate = float(st.text_input("Learning Rate", "0.01"))
 
70
 
71
  if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
72
  try:
@@ -85,28 +206,44 @@ if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.s
85
  except Exception as e:
86
  st.error(f"Error setting up function: {e}")
87
 
88
- # Gradient Descent Iteration button
89
  if 'func' in st.session_state and 'gradient_func' in st.session_state:
90
  if st.button("Next Iteration"):
91
  try:
92
- # Get the current point and gradient value
93
  x_old = float(st.session_state.points[-1])
94
  grad_val = st.session_state.gradient_func(x_old)
95
  x_new = x_old - learning_rate * grad_val
96
-
97
- # Append the new point to the list of points
98
- st.session_state.points.append(x_new)
99
- st.session_state.step += 1
100
-
101
- st.success(f"Iteration {st.session_state.step} Complete!")
102
 
 
 
 
 
 
 
 
103
  except Exception as e:
104
  st.error(f"Error in iteration: {e}")
105
 
106
- # Creating the plot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if 'func' in st.session_state and len(st.session_state.points) > 0:
108
  try:
109
- # Create x-values for plotting the function
110
  x_val = np.linspace(-10, 10, 500)
111
  y_val = st.session_state.func(x_val)
112
 
@@ -114,7 +251,8 @@ if 'func' in st.session_state and len(st.session_state.points) > 0:
114
  iter_y = st.session_state.func(iter_points)
115
 
116
  trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
117
- trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
 
118
  trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
119
  marker=dict(color='green', size=15),
120
  text=[f"{iter_points[-1]:.6f}"], textposition="top center",
@@ -129,8 +267,5 @@ if 'func' in st.session_state and len(st.session_state.points) > 0:
129
  fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
130
  st.plotly_chart(fig, use_container_width=True)
131
  st.success(f"Current Point = {iter_points[-1]}")
132
-
133
  except Exception as e:
134
  st.error(f"Plot error: {e}")
135
-
136
-
 
1
+ # import streamlit as st
2
+ # import numpy as np
3
+ # import sympy as sp
4
+ # import plotly.graph_objs as go
5
+
6
+ # st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
+
8
+
9
+ # # Custom CSS for styling
10
+ # st.markdown("""
11
+ # <style>
12
+ # /* Page background */
13
+ # .stApp {
14
+ # background-color: #f9f9f9;
15
+ # font-family: 'Segoe UI', sans-serif;
16
+ # }
17
+
18
+ # /* Title */
19
+ # h1 {
20
+ # text-align: center;
21
+ # color: #2C3E50;
22
+ # font-size: 38px !important;
23
+ # font-weight: bold;
24
+ # margin-bottom: 20px;
25
+ # }
26
+
27
+ # /* Input boxes */
28
+ # .stTextInput > div > div > input {
29
+ # border: 2px solid #3498DB;
30
+ # border-radius: 8px;
31
+ # padding: 8px;
32
+ # }
33
+
34
+ # /* Buttons */
35
+ # div.stButton > button {
36
+ # background-color: #3498DB;
37
+ # color: white;
38
+ # border-radius: 10px;
39
+ # padding: 10px 24px;
40
+ # font-size: 16px;
41
+ # border: none;
42
+ # transition: 0.3s;
43
+ # }
44
+ # div.stButton > button:hover {
45
+ # background-color: #2980B9;
46
+ # transform: scale(1.05);
47
+ # }
48
+
49
+ # /* Success / Error messages */
50
+ # .stAlert {
51
+ # border-radius: 8px;
52
+ # }
53
+
54
+ # /* Chart section */
55
+ # .block-container {
56
+ # padding-top: 2rem;
57
+ # padding-bottom: 2rem;
58
+ # }
59
+ # </style>
60
+ # """, unsafe_allow_html=True)
61
+
62
+
63
+ # st.title("Gradient Descent Visualizer")
64
+
65
+ # x = sp.Symbol("x")
66
+ # # User input function, starting point, and learning rate
67
+ # func_input = st.text_input("Enter Function", "x^2")
68
+ # start_point = float(st.text_input("Starting Point", "2"))
69
+ # learning_rate = float(st.text_input("Learning Rate", "0.01"))
70
+
71
+ # if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
72
+ # try:
73
+ # expr = func_input.replace("^", "**")
74
+ # expr_final = sp.sympify(expr)
75
+ # func = sp.lambdify(x, expr_final, "numpy")
76
+ # grad = sp.diff(expr_final, x)
77
+ # gradient_func = sp.lambdify(x, grad, "numpy")
78
+
79
+ # st.session_state.func = func
80
+ # st.session_state.gradient_func = gradient_func
81
+ # st.session_state.points = [start_point]
82
+ # st.session_state.step = 0
83
+ # st.success("Function and Gradient Set Up Successfully!")
84
+
85
+ # except Exception as e:
86
+ # st.error(f"Error setting up function: {e}")
87
+
88
+ # # Gradient Descent Iteration button
89
+ # if 'func' in st.session_state and 'gradient_func' in st.session_state:
90
+ # if st.button("Next Iteration"):
91
+ # try:
92
+ # # Get the current point and gradient value
93
+ # x_old = float(st.session_state.points[-1])
94
+ # grad_val = st.session_state.gradient_func(x_old)
95
+ # x_new = x_old - learning_rate * grad_val
96
+
97
+ # # Append the new point to the list of points
98
+ # st.session_state.points.append(x_new)
99
+ # st.session_state.step += 1
100
+
101
+ # st.success(f"Iteration {st.session_state.step} Complete!")
102
+
103
+ # except Exception as e:
104
+ # st.error(f"Error in iteration: {e}")
105
+
106
+ # # Creating the plot
107
+ # if 'func' in st.session_state and len(st.session_state.points) > 0:
108
+ # try:
109
+ # # Create x-values for plotting the function
110
+ # x_val = np.linspace(-10, 10, 500)
111
+ # y_val = st.session_state.func(x_val)
112
+
113
+ # iter_points = np.array(st.session_state.points)
114
+ # iter_y = st.session_state.func(iter_points)
115
+
116
+ # trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
117
+ # trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
118
+ # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
119
+ # marker=dict(color='green', size=15),
120
+ # text=[f"{iter_points[-1]:.6f}"], textposition="top center",
121
+ # name="Current Point")
122
+
123
+ # layout = go.Layout(
124
+ # title=f"Iteration {st.session_state.step}",
125
+ # xaxis=dict(title="x - axis"),
126
+ # yaxis=dict(title="y - axis")
127
+ # )
128
+
129
+ # fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
130
+ # st.plotly_chart(fig, use_container_width=True)
131
+ # st.success(f"Current Point = {iter_points[-1]}")
132
+
133
+ # except Exception as e:
134
+ # st.error(f"Plot error: {e}")
135
+
136
+
137
  import streamlit as st
138
  import numpy as np
139
  import sympy as sp
140
  import plotly.graph_objs as go
141
 
 
 
 
 
142
  st.markdown("""
143
  <style>
 
144
  .stApp {
145
  background-color: #f9f9f9;
146
  font-family: 'Segoe UI', sans-serif;
147
  }
 
 
148
  h1 {
149
  text-align: center;
150
  color: #2C3E50;
 
152
  font-weight: bold;
153
  margin-bottom: 20px;
154
  }
 
 
155
  .stTextInput > div > div > input {
156
  border: 2px solid #3498DB;
157
  border-radius: 8px;
158
  padding: 8px;
159
  }
 
 
160
  div.stButton > button {
161
  background-color: #3498DB;
162
  color: white;
 
170
  background-color: #2980B9;
171
  transform: scale(1.05);
172
  }
 
 
173
  .stAlert {
174
  border-radius: 8px;
175
  }
 
 
176
  .block-container {
177
  padding-top: 2rem;
178
  padding-bottom: 2rem;
 
180
  </style>
181
  """, unsafe_allow_html=True)
182
 
 
183
  st.title("Gradient Descent Visualizer")
184
 
185
  x = sp.Symbol("x")
186
+
187
  func_input = st.text_input("Enter Function", "x^2")
188
  start_point = float(st.text_input("Starting Point", "2"))
189
  learning_rate = float(st.text_input("Learning Rate", "0.01"))
190
+ num_iterations = int(st.text_input("Number of Iterations", "10"))
191
 
192
  if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
193
  try:
 
206
  except Exception as e:
207
  st.error(f"Error setting up function: {e}")
208
 
 
209
  if 'func' in st.session_state and 'gradient_func' in st.session_state:
210
  if st.button("Next Iteration"):
211
  try:
 
212
  x_old = float(st.session_state.points[-1])
213
  grad_val = st.session_state.gradient_func(x_old)
214
  x_new = x_old - learning_rate * grad_val
 
 
 
 
 
 
215
 
216
+ tolerance = 0.001
217
+ if abs(x_new - x_old) < tolerance or abs(grad_val) < tolerance:
218
+ st.success(f"Reached Minima at x = {x_new:.6f}")
219
+ else:
220
+ st.session_state.points.append(x_new)
221
+ st.session_state.step += 1
222
+ st.success(f"Iteration {st.session_state.step} Complete!")
223
  except Exception as e:
224
  st.error(f"Error in iteration: {e}")
225
 
226
+ if st.button("Run Iterations"):
227
+ try:
228
+ tolerance = 0.001
229
+ for i in range(num_iterations):
230
+ x_old = float(st.session_state.points[-1])
231
+ grad_val = st.session_state.gradient_func(x_old)
232
+ x_new = x_old - learning_rate * grad_val
233
+
234
+ if abs(x_new - x_old) < tolerance or abs(grad_val) < tolerance:
235
+ st.success(f"Reached Minima early at x = {x_new:.6f} (after {i+1} steps)")
236
+ break
237
+
238
+ st.session_state.points.append(x_new)
239
+ st.session_state.step += 1
240
+
241
+ st.success(f"Ran {st.session_state.step} Iterations in total")
242
+ except Exception as e:
243
+ st.error(f"Error in multiple iterations: {e}")
244
+
245
  if 'func' in st.session_state and len(st.session_state.points) > 0:
246
  try:
 
247
  x_val = np.linspace(-10, 10, 500)
248
  y_val = st.session_state.func(x_val)
249
 
 
251
  iter_y = st.session_state.func(iter_points)
252
 
253
  trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
254
+ trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines",
255
+ name="Gradient Descent Path", marker=dict(color="red"))
256
  trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
257
  marker=dict(color='green', size=15),
258
  text=[f"{iter_points[-1]:.6f}"], textposition="top center",
 
267
  fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
268
  st.plotly_chart(fig, use_container_width=True)
269
  st.success(f"Current Point = {iter_points[-1]}")
 
270
  except Exception as e:
271
  st.error(f"Plot error: {e}")