shubham680 commited on
Commit
bba7bc6
·
verified ·
1 Parent(s): 63cea96

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -186
app.py CHANGED
@@ -1,162 +1,34 @@
1
- # import streamlit as st
2
- # import numpy as np
3
- # import sympy as sp
4
- # import plotly.graph_objs as go
5
-
6
- # # Custom CSS for styling
7
- # st.markdown("""
8
- # <style>
9
- # /* Page background */
10
- # .stApp {
11
- # background-color: #f9f9f9;
12
- # font-family: 'Segoe UI', sans-serif;
13
- # }
14
-
15
- # /* Title */
16
- # h1 {
17
- # text-align: center;
18
- # color: #2C3E50;
19
- # font-size: 38px !important;
20
- # font-weight: bold;
21
- # margin-bottom: 20px;
22
- # }
23
-
24
- # /* Input boxes */
25
- # .stTextInput > div > div > input {
26
- # border: 2px solid #3498DB;
27
- # border-radius: 8px;
28
- # padding: 8px;
29
- # }
30
-
31
- # /* Buttons */
32
- # div.stButton > button {
33
- # background-color: #3498DB;
34
- # color: white;
35
- # border-radius: 10px;
36
- # padding: 10px 24px;
37
- # font-size: 16px;
38
- # border: none;
39
- # transition: 0.3s;
40
- # }
41
- # div.stButton > button:hover {
42
- # background-color: #2980B9;
43
- # transform: scale(1.05);
44
- # }
45
-
46
- # /* Success / Error messages */
47
- # .stAlert {
48
- # border-radius: 8px;
49
- # }
50
-
51
- # /* Chart section */
52
- # .block-container {
53
- # padding-top: 2rem;
54
- # padding-bottom: 2rem;
55
- # }
56
- # </style>
57
- # """, unsafe_allow_html=True)
58
-
59
-
60
- # st.title("Gradient Descent Visualizer")
61
-
62
- # x = sp.Symbol("x")
63
- # # User input function, starting point, and learning rate
64
- # func_input = st.text_input("Enter Function", "x^2")
65
- # start_point = float(st.text_input("Starting Point", "2"))
66
- # learning_rate = float(st.text_input("Learning Rate", "0.01"))
67
-
68
- # if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
69
- # try:
70
- # expr = func_input.replace("^", "**")
71
- # expr_final = sp.sympify(expr)
72
- # func = sp.lambdify(x, expr_final, "numpy")
73
- # grad = sp.diff(expr_final, x)
74
- # gradient_func = sp.lambdify(x, grad, "numpy")
75
-
76
- # st.session_state.func = func
77
- # st.session_state.gradient_func = gradient_func
78
- # st.session_state.points = [start_point]
79
- # st.session_state.step = 0
80
- # st.success("Function and Gradient Set Up Successfully!")
81
-
82
- # except Exception as e:
83
- # st.error(f"Error setting up function: {e}")
84
-
85
- # # Gradient Descent Iteration button
86
- # if 'func' in st.session_state and 'gradient_func' in st.session_state:
87
- # if st.button("Next Iteration"):
88
- # try:
89
- # # Get the current point and gradient value
90
- # x_old = float(st.session_state.points[-1])
91
- # grad_val = st.session_state.gradient_func(x_old)
92
- # x_new = x_old - learning_rate * grad_val
93
-
94
- # # Append the new point to the list of points
95
- # st.session_state.points.append(x_new)
96
- # st.session_state.step += 1
97
-
98
- # st.success(f"Iteration {st.session_state.step} Complete!")
99
-
100
- # except Exception as e:
101
- # st.error(f"Error in iteration: {e}")
102
-
103
- # # Creating the plot
104
- # if 'func' in st.session_state and len(st.session_state.points) > 0:
105
- # try:
106
- # # Create x-values for plotting the function
107
- # x_val = np.linspace(-10, 10, 500)
108
- # y_val = st.session_state.func(x_val)
109
-
110
- # iter_points = np.array(st.session_state.points)
111
- # iter_y = st.session_state.func(iter_points)
112
-
113
- # trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
114
- # trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
115
- # trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
116
- # marker=dict(color='green', size=15),
117
- # text=[f"{iter_points[-1]:.6f}"], textposition="top center",
118
- # name="Current Point")
119
-
120
- # layout = go.Layout(
121
- # title=f"Iteration {st.session_state.step}",
122
- # xaxis=dict(title="x - axis"),
123
- # yaxis=dict(title="y - axis")
124
- # )
125
-
126
- # fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
127
- # st.plotly_chart(fig, use_container_width=True)
128
- # st.success(f"Current Point = {iter_points[-1]}")
129
-
130
- # except Exception as e:
131
- # st.error(f"Plot error: {e}")
132
-
133
-
134
  import streamlit as st
135
  import numpy as np
136
  import sympy as sp
137
  import plotly.graph_objs as go
138
- import time
139
 
140
- # -------------------------------
141
- # Custom CSS
142
- # -------------------------------
143
  st.markdown("""
144
  <style>
 
145
  .stApp {
146
  background-color: #f9f9f9;
147
  font-family: 'Segoe UI', sans-serif;
148
  }
 
 
149
  h1 {
150
  text-align: center;
151
  color: #2C3E50;
152
  font-size: 38px !important;
153
  font-weight: bold;
 
154
  }
 
 
155
  .stTextInput > div > div > input {
156
  border: 2px solid #3498DB;
157
  border-radius: 8px;
158
  padding: 8px;
159
  }
 
 
160
  div.stButton > button {
161
  background-color: #3498DB;
162
  color: white;
@@ -170,32 +42,30 @@ st.markdown("""
170
  background-color: #2980B9;
171
  transform: scale(1.05);
172
  }
 
 
173
  .stAlert {
174
  border-radius: 8px;
175
  }
 
 
 
 
 
 
176
  </style>
177
  """, unsafe_allow_html=True)
178
 
179
- # -------------------------------
180
- # Original Logic (unchanged)
181
- # -------------------------------
182
- st.title("📉 Gradient Descent Visualizer")
183
 
184
- x = sp.Symbol("x")
185
-
186
- # Sidebar Inputs
187
- with st.sidebar:
188
- st.header("⚙️ Controls")
189
- func_input = st.text_input("Enter Function", "x^2")
190
- start_point = float(st.text_input("Starting Point", "2"))
191
- learning_rate = float(st.text_input("Learning Rate", "0.01"))
192
 
193
- col1, col2 = st.columns(2)
194
- setup = col1.button("Set Up")
195
- iterate = col2.button("Next Iteration")
 
 
196
 
197
- # Function Setup
198
- if setup or 'func' not in st.session_state or 'points' not in st.session_state:
199
  try:
200
  expr = func_input.replace("^", "**")
201
  expr_final = sp.sympify(expr)
@@ -207,38 +77,36 @@ if setup or 'func' not in st.session_state or 'points' not in st.session_state:
207
  st.session_state.gradient_func = gradient_func
208
  st.session_state.points = [start_point]
209
  st.session_state.step = 0
210
- st.sidebar.success("Function and Gradient Set Up!")
211
-
212
  except Exception as e:
213
- st.sidebar.error(f"Error: {e}")
214
 
215
- # Iteration
216
  if 'func' in st.session_state and 'gradient_func' in st.session_state:
217
- if iterate:
218
  try:
 
219
  x_old = float(st.session_state.points[-1])
220
  grad_val = st.session_state.gradient_func(x_old)
221
  x_new = x_old - learning_rate * grad_val
 
 
222
  st.session_state.points.append(x_new)
223
  st.session_state.step += 1
224
 
225
- # Progress bar for interactivity
226
- progress = st.progress(0)
227
- for i in range(100):
228
- time.sleep(0.01)
229
- progress.progress(i+1)
230
-
231
- st.success(f"✨ Iteration {st.session_state.step} Complete!")
232
 
233
  except Exception as e:
234
- st.sidebar.error(f"Error: {e}")
235
 
236
- # Visualization
237
  if 'func' in st.session_state and len(st.session_state.points) > 0:
238
  try:
 
239
  x_val = np.linspace(-10, 10, 500)
240
  y_val = st.session_state.func(x_val)
241
-
242
  iter_points = np.array(st.session_state.points)
243
  iter_y = st.session_state.func(iter_points)
244
 
@@ -248,25 +116,18 @@ if 'func' in st.session_state and len(st.session_state.points) > 0:
248
  marker=dict(color='green', size=15),
249
  text=[f"{iter_points[-1]:.6f}"], textposition="top center",
250
  name="Current Point")
251
-
252
- layout = go.Layout(title=f"Iteration {st.session_state.step}",
253
- xaxis=dict(title="x-axis"),
254
- yaxis=dict(title="y-axis"))
255
-
 
 
256
  fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
257
-
258
- st.markdown("### 📊 Gradient Descent Path")
259
  st.plotly_chart(fig, use_container_width=True)
260
-
261
- # Live Metrics
262
- col1, col2 = st.columns(2)
263
- col1.metric("🔢 Iteration", st.session_state.step)
264
- col2.metric("📍 Current x", f"{iter_points[-1]:.6f}")
265
-
266
- # Expander for extra info
267
- with st.expander("📘 Function & Gradient Details"):
268
- st.write("Function:", expr_final)
269
- st.write("Gradient:", grad)
270
 
271
  except Exception as e:
272
  st.error(f"Plot error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import numpy as np
3
  import sympy as sp
4
  import plotly.graph_objs as go
 
5
 
6
+ # Custom CSS for styling
 
 
7
  st.markdown("""
8
  <style>
9
+ /* Page background */
10
  .stApp {
11
  background-color: #f9f9f9;
12
  font-family: 'Segoe UI', sans-serif;
13
  }
14
+
15
+ /* Title */
16
  h1 {
17
  text-align: center;
18
  color: #2C3E50;
19
  font-size: 38px !important;
20
  font-weight: bold;
21
+ margin-bottom: 20px;
22
  }
23
+
24
+ /* Input boxes */
25
  .stTextInput > div > div > input {
26
  border: 2px solid #3498DB;
27
  border-radius: 8px;
28
  padding: 8px;
29
  }
30
+
31
+ /* Buttons */
32
  div.stButton > button {
33
  background-color: #3498DB;
34
  color: white;
 
42
  background-color: #2980B9;
43
  transform: scale(1.05);
44
  }
45
+
46
+ /* Success / Error messages */
47
  .stAlert {
48
  border-radius: 8px;
49
  }
50
+
51
+ /* Chart section */
52
+ .block-container {
53
+ padding-top: 2rem;
54
+ padding-bottom: 2rem;
55
+ }
56
  </style>
57
  """, unsafe_allow_html=True)
58
 
 
 
 
 
59
 
60
+ st.title("Gradient Descent Visualizer")
 
 
 
 
 
 
 
61
 
62
+ x = sp.Symbol("x")
63
+ # User input function, starting point, and learning rate
64
+ func_input = st.text_input("Enter Function", "x^2")
65
+ start_point = float(st.text_input("Starting Point", "2"))
66
+ learning_rate = float(st.text_input("Learning Rate", "0.01"))
67
 
68
+ if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
 
69
  try:
70
  expr = func_input.replace("^", "**")
71
  expr_final = sp.sympify(expr)
 
77
  st.session_state.gradient_func = gradient_func
78
  st.session_state.points = [start_point]
79
  st.session_state.step = 0
80
+ st.success("Function and Gradient Set Up Successfully!")
81
+
82
  except Exception as e:
83
+ st.error(f"Error setting up function: {e}")
84
 
85
+ # Gradient Descent Iteration button
86
  if 'func' in st.session_state and 'gradient_func' in st.session_state:
87
+ if st.button("Next Iteration"):
88
  try:
89
+ # Get the current point and gradient value
90
  x_old = float(st.session_state.points[-1])
91
  grad_val = st.session_state.gradient_func(x_old)
92
  x_new = x_old - learning_rate * grad_val
93
+
94
+ # Append the new point to the list of points
95
  st.session_state.points.append(x_new)
96
  st.session_state.step += 1
97
 
98
+ st.success(f"Iteration {st.session_state.step} Complete!")
 
 
 
 
 
 
99
 
100
  except Exception as e:
101
+ st.error(f"Error in iteration: {e}")
102
 
103
+ # Creating the plot
104
  if 'func' in st.session_state and len(st.session_state.points) > 0:
105
  try:
106
+ # Create x-values for plotting the function
107
  x_val = np.linspace(-10, 10, 500)
108
  y_val = st.session_state.func(x_val)
109
+
110
  iter_points = np.array(st.session_state.points)
111
  iter_y = st.session_state.func(iter_points)
112
 
 
116
  marker=dict(color='green', size=15),
117
  text=[f"{iter_points[-1]:.6f}"], textposition="top center",
118
  name="Current Point")
119
+
120
+ layout = go.Layout(
121
+ title=f"Iteration {st.session_state.step}",
122
+ xaxis=dict(title="x - axis"),
123
+ yaxis=dict(title="y - axis")
124
+ )
125
+
126
  fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
 
 
127
  st.plotly_chart(fig, use_container_width=True)
128
+ st.success(f"Current Point = {iter_points[-1]}")
 
 
 
 
 
 
 
 
 
129
 
130
  except Exception as e:
131
  st.error(f"Plot error: {e}")
132
+
133
+