shubham680 commited on
Commit
1d0d871
·
verified ·
1 Parent(s): f3f4d32

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -124,6 +124,10 @@ if st.button("Set Up"):
124
  grad = sp.diff(expr_final, x)
125
  gradient_func = sp.lambdify(x, grad, "numpy")
126
 
 
 
 
 
127
  # Initialize session state for points and steps
128
  if 'points' not in st.session_state:
129
  st.session_state.points = [start_point]
@@ -136,12 +140,12 @@ if st.button("Set Up"):
136
  st.error(f"Error setting up function: {e}")
137
 
138
  # Gradient Descent Iteration button
139
- if 'points' in st.session_state and 'step' in st.session_state:
140
  if st.button("Next Iteration"):
141
  try:
142
  # Get the current point and gradient value
143
  x_old = float(st.session_state.points[-1])
144
- grad_val = gradient_func(x_old)
145
  x_new = x_old - learning_rate * grad_val
146
 
147
  # Append the new point to the list of points
@@ -154,15 +158,15 @@ if 'points' in st.session_state and 'step' in st.session_state:
154
  st.error(f"Error in iteration: {e}")
155
 
156
  # Creating the plot
157
- if 'points' in st.session_state and len(st.session_state.points) > 0:
158
  try:
159
  # Create x-values for plotting the function
160
  x_val = np.linspace(-6, 6, 500)
161
- y_val = func(x_val)
162
 
163
  # Get the points visited by gradient descent
164
  iter_points = np.array(st.session_state.points)
165
- iter_y = func(iter_points)
166
 
167
  # Plot the function and the gradient descent path
168
  trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
@@ -186,4 +190,3 @@ if 'points' in st.session_state and len(st.session_state.points) > 0:
186
 
187
  except Exception as e:
188
  st.error(f"Plot error: {e}")
189
-
 
124
  grad = sp.diff(expr_final, x)
125
  gradient_func = sp.lambdify(x, grad, "numpy")
126
 
127
+ # Store the function and gradient in session state
128
+ st.session_state.func = func
129
+ st.session_state.gradient_func = gradient_func
130
+
131
  # Initialize session state for points and steps
132
  if 'points' not in st.session_state:
133
  st.session_state.points = [start_point]
 
140
  st.error(f"Error setting up function: {e}")
141
 
142
  # Gradient Descent Iteration button
143
+ if 'func' in st.session_state and 'gradient_func' in st.session_state:
144
  if st.button("Next Iteration"):
145
  try:
146
  # Get the current point and gradient value
147
  x_old = float(st.session_state.points[-1])
148
+ grad_val = st.session_state.gradient_func(x_old)
149
  x_new = x_old - learning_rate * grad_val
150
 
151
  # Append the new point to the list of points
 
158
  st.error(f"Error in iteration: {e}")
159
 
160
  # Creating the plot
161
+ if 'func' in st.session_state and len(st.session_state.points) > 0:
162
  try:
163
  # Create x-values for plotting the function
164
  x_val = np.linspace(-6, 6, 500)
165
+ y_val = st.session_state.func(x_val)
166
 
167
  # Get the points visited by gradient descent
168
  iter_points = np.array(st.session_state.points)
169
+ iter_y = st.session_state.func(iter_points)
170
 
171
  # Plot the function and the gradient descent path
172
  trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
 
190
 
191
  except Exception as e:
192
  st.error(f"Plot error: {e}")