Mpavan45 commited on
Commit
c9d0458
Β·
verified Β·
1 Parent(s): a35cd4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -12
app.py CHANGED
@@ -2,6 +2,16 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
 
 
 
 
 
 
 
 
 
 
5
  # Title of the app
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
  st.title("🌟 Gradient Descent Visualizer")
@@ -29,8 +39,8 @@ def calculate_tangent(expression, x_value, x_range):
29
  def reset_session_state():
30
  st.session_state.x_current = st.session_state.initial_point
31
  st.session_state.iter_count = 0
32
- st.session_state.x_points = [st.session_state.initial_point]
33
- st.session_state.y_points = [evaluate_function(st.session_state.math_function, st.session_state.initial_point)]
34
 
35
  # Create two-column grid layout for the left side (more space for the right graph)
36
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
@@ -59,19 +69,14 @@ with left_col:
59
  # Right side content (Gradient Descent Updates and Progress)
60
  with right_col:
61
  st.header("Gradient Descent Updates")
62
- if "x_current" not in st.session_state:
63
- st.session_state.x_current = initial_point
64
- st.session_state.iter_count = 0
65
- st.session_state.x_points = [initial_point]
66
- st.session_state.y_points = [evaluate_function(function_input, initial_point)]
67
-
68
  if st.button("πŸ”„ Run Descent Step", type="primary"):
69
  try:
70
  gradient = compute_derivative(function_input, st.session_state.x_current)
71
  st.session_state.x_current -= learning_rate * gradient
72
  st.session_state.iter_count += 1
73
- st.session_state.x_points.append(st.session_state.x_current)
74
- st.session_state.y_points.append(evaluate_function(function_input, st.session_state.x_current))
 
75
  except Exception as e:
76
  st.error(f"Error: {str(e)}")
77
 
@@ -105,10 +110,11 @@ with right_col:
105
  )
106
 
107
  # Add gradient descent points
 
108
  plot.add_trace(
109
  go.Scatter(
110
- x=st.session_state.x_points,
111
- y=st.session_state.y_points,
112
  mode="markers",
113
  marker=dict(color="red", size=10, symbol="diamond"),
114
  name="Descent Steps",
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Initialize session state variables if not already present
6
+ if "x_current" not in st.session_state:
7
+ st.session_state.x_current = 0.0 # Default starting point
8
+ if "iter_count" not in st.session_state:
9
+ st.session_state.iter_count = 0
10
+ if "history" not in st.session_state:
11
+ st.session_state.history = [] # Store (x, f(x)) for each iteration
12
+ if "current_index" not in st.session_state:
13
+ st.session_state.current_index = 0
14
+
15
  # Title of the app
16
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
17
  st.title("🌟 Gradient Descent Visualizer")
 
39
  def reset_session_state():
40
  st.session_state.x_current = st.session_state.initial_point
41
  st.session_state.iter_count = 0
42
+ st.session_state.history = [(st.session_state.initial_point, evaluate_function(st.session_state.math_function, st.session_state.initial_point))]
43
+ st.session_state.current_index = 0
44
 
45
  # Create two-column grid layout for the left side (more space for the right graph)
46
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
 
69
  # Right side content (Gradient Descent Updates and Progress)
70
  with right_col:
71
  st.header("Gradient Descent Updates")
 
 
 
 
 
 
72
  if st.button("πŸ”„ Run Descent Step", type="primary"):
73
  try:
74
  gradient = compute_derivative(function_input, st.session_state.x_current)
75
  st.session_state.x_current -= learning_rate * gradient
76
  st.session_state.iter_count += 1
77
+ current_y = evaluate_function(function_input, st.session_state.x_current)
78
+ st.session_state.history.append((st.session_state.x_current, current_y))
79
+ st.session_state.current_index = st.session_state.iter_count
80
  except Exception as e:
81
  st.error(f"Error: {str(e)}")
82
 
 
110
  )
111
 
112
  # Add gradient descent points
113
+ x_points, y_points = zip(*st.session_state.history)
114
  plot.add_trace(
115
  go.Scatter(
116
+ x=x_points,
117
+ y=y_points,
118
  mode="markers",
119
  marker=dict(color="red", size=10, symbol="diamond"),
120
  name="Descent Steps",