Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -124,6 +124,10 @@ if st.button("Set Up"):
|
|
| 124 |
grad = sp.diff(expr_final, x)
|
| 125 |
gradient_func = sp.lambdify(x, grad, "numpy")
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
# Initialize session state for points and steps
|
| 128 |
if 'points' not in st.session_state:
|
| 129 |
st.session_state.points = [start_point]
|
|
@@ -136,12 +140,12 @@ if st.button("Set Up"):
|
|
| 136 |
st.error(f"Error setting up function: {e}")
|
| 137 |
|
| 138 |
# Gradient Descent Iteration button
|
| 139 |
-
if '
|
| 140 |
if st.button("Next Iteration"):
|
| 141 |
try:
|
| 142 |
# Get the current point and gradient value
|
| 143 |
x_old = float(st.session_state.points[-1])
|
| 144 |
-
grad_val = gradient_func(x_old)
|
| 145 |
x_new = x_old - learning_rate * grad_val
|
| 146 |
|
| 147 |
# Append the new point to the list of points
|
|
@@ -154,15 +158,15 @@ if 'points' in st.session_state and 'step' in st.session_state:
|
|
| 154 |
st.error(f"Error in iteration: {e}")
|
| 155 |
|
| 156 |
# Creating the plot
|
| 157 |
-
if '
|
| 158 |
try:
|
| 159 |
# Create x-values for plotting the function
|
| 160 |
x_val = np.linspace(-6, 6, 500)
|
| 161 |
-
y_val = func(x_val)
|
| 162 |
|
| 163 |
# Get the points visited by gradient descent
|
| 164 |
iter_points = np.array(st.session_state.points)
|
| 165 |
-
iter_y = func(iter_points)
|
| 166 |
|
| 167 |
# Plot the function and the gradient descent path
|
| 168 |
trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
|
|
@@ -186,4 +190,3 @@ if 'points' in st.session_state and len(st.session_state.points) > 0:
|
|
| 186 |
|
| 187 |
except Exception as e:
|
| 188 |
st.error(f"Plot error: {e}")
|
| 189 |
-
|
|
|
|
| 124 |
grad = sp.diff(expr_final, x)
|
| 125 |
gradient_func = sp.lambdify(x, grad, "numpy")
|
| 126 |
|
| 127 |
+
# Store the function and gradient in session state
|
| 128 |
+
st.session_state.func = func
|
| 129 |
+
st.session_state.gradient_func = gradient_func
|
| 130 |
+
|
| 131 |
# Initialize session state for points and steps
|
| 132 |
if 'points' not in st.session_state:
|
| 133 |
st.session_state.points = [start_point]
|
|
|
|
| 140 |
st.error(f"Error setting up function: {e}")
|
| 141 |
|
| 142 |
# Gradient Descent Iteration button
|
| 143 |
+
if 'func' in st.session_state and 'gradient_func' in st.session_state:
|
| 144 |
if st.button("Next Iteration"):
|
| 145 |
try:
|
| 146 |
# Get the current point and gradient value
|
| 147 |
x_old = float(st.session_state.points[-1])
|
| 148 |
+
grad_val = st.session_state.gradient_func(x_old)
|
| 149 |
x_new = x_old - learning_rate * grad_val
|
| 150 |
|
| 151 |
# Append the new point to the list of points
|
|
|
|
| 158 |
st.error(f"Error in iteration: {e}")
|
| 159 |
|
| 160 |
# Creating the plot
|
| 161 |
+
if 'func' in st.session_state and len(st.session_state.points) > 0:
|
| 162 |
try:
|
| 163 |
# Create x-values for plotting the function
|
| 164 |
x_val = np.linspace(-6, 6, 500)
|
| 165 |
+
y_val = st.session_state.func(x_val)
|
| 166 |
|
| 167 |
# Get the points visited by gradient descent
|
| 168 |
iter_points = np.array(st.session_state.points)
|
| 169 |
+
iter_y = st.session_state.func(iter_points)
|
| 170 |
|
| 171 |
# Plot the function and the gradient descent path
|
| 172 |
trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
|
|
|
|
| 190 |
|
| 191 |
except Exception as e:
|
| 192 |
st.error(f"Plot error: {e}")
|
|
|