trohith89 commited on
Commit
b65b429
·
verified ·
1 Parent(s): f517a8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -1
app.py CHANGED
@@ -165,4 +165,116 @@ with col1:
165
  value=0.25,
166
  step=0.01,
167
  format="%.2f",
168
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  value=0.25,
166
  step=0.01,
167
  format="%.2f",
168
+ key="learning_rate",
169
+ on_change=reset_state
170
+ )
171
+
172
+ col3, col4 = st.columns(2)
173
+ with col3:
174
+ if st.button("🔄 Set Up Function"):
175
+ reset_state()
176
+ with col4:
177
+ if st.button("▶️ Next Iteration"):
178
+ try:
179
+ grad = derivative(st.session_state.func_input, st.session_state.x)
180
+ st.session_state.x = st.session_state.x - learning_rate * grad
181
+ st.session_state.iteration += 1
182
+ st.session_state.x_vals.append(st.session_state.x)
183
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
184
+ except Exception as e:
185
+ st.error(f"⚠️ Error: {str(e)}")
186
+
187
+ # Right Section: Visualization
188
+ with col2:
189
+ st.subheader("📊 Gradient Descent Visualization")
190
+ try:
191
+ # Plot the function and all current and previous gradient descent points
192
+ x_plot = np.linspace(-10, 10, 400)
193
+ y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
194
+
195
+ fig = go.Figure()
196
+
197
+ # Function curve
198
+ fig.add_trace(go.Scatter(
199
+ x=x_plot,
200
+ y=y_plot,
201
+ mode="lines+markers",
202
+ line=dict(color="blue", width=2),
203
+ marker=dict(size=4, color="blue", symbol="circle"),
204
+ name="Function"
205
+ ))
206
+
207
+ # All gradient descent points (red points without coordinates)
208
+ fig.add_trace(go.Scatter(
209
+ x=st.session_state.x_vals,
210
+ y=st.session_state.y_vals,
211
+ mode="markers",
212
+ marker=dict(color="red", size=10),
213
+ name="Gradient Descent Points"
214
+ ))
215
+
216
+ # Tangent line at the current gradient descent point
217
+ current_x = st.session_state.x
218
+ tangent_x = np.linspace(current_x - 5, current_x + 5, 200) # Extended range for tangent line
219
+ tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
220
+ fig.add_trace(go.Scatter(
221
+ x=tangent_x,
222
+ y=tangent_y,
223
+ mode="lines",
224
+ line=dict(color="orange", width=3),
225
+ name="Tangent Line"
226
+ ))
227
+
228
+ # Dynamic zoom for better visibility
229
+ fig.update_layout(
230
+ xaxis=dict(
231
+ title="x-axis",
232
+ range=[-10, 10],
233
+ showline=True,
234
+ linecolor="white",
235
+ tickcolor="white",
236
+ tickfont=dict(color="white"),
237
+ ticks="outside",
238
+ ),
239
+ yaxis=dict(
240
+ title="y-axis",
241
+ range=[min(y_plot) - 5, max(y_plot) + 5],
242
+ showline=True,
243
+ linecolor="white",
244
+ tickcolor="white",
245
+ tickfont=dict(color="white"),
246
+ ticks="outside",
247
+ ),
248
+ plot_bgcolor="black",
249
+ paper_bgcolor="black",
250
+ title="",
251
+ margin=dict(l=10, r=10, t=10, b=10),
252
+ width=800,
253
+ height=400,
254
+ showlegend=True,
255
+ legend=dict(
256
+ x=1.1,
257
+ y=0.5,
258
+ xanchor="left",
259
+ yanchor="middle",
260
+ orientation="v",
261
+ font=dict(size=12, color="white"),
262
+ bgcolor="black",
263
+ bordercolor="white",
264
+ borderwidth=2,
265
+ )
266
+ )
267
+
268
+ # Axis lines for quadrants
269
+ fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
270
+ fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
271
+
272
+ st.plotly_chart(fig, use_container_width=True)
273
+
274
+ except Exception as e:
275
+ st.error(f"⚠️ Error in visualization: {str(e)}")
276
+
277
+ # Iteration stats and download
278
+ col5, col6 = st.columns(2)
279
+ col5.info(f"🧑‍💻 Iteration: {st.session_state.iteration}")
280
+ col6.success(f"✅ Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")