Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,21 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import numpy as np
|
| 3 |
import sympy as sp
|
| 4 |
import plotly.graph_objs as go
|
| 5 |
|
| 6 |
-
st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# Custom CSS for styling
|
| 10 |
st.markdown("""
|
| 11 |
<style>
|
| 12 |
-
/* Page background */
|
| 13 |
.stApp {
|
| 14 |
background-color: #f9f9f9;
|
| 15 |
font-family: 'Segoe UI', sans-serif;
|
| 16 |
}
|
| 17 |
-
|
| 18 |
-
/* Title */
|
| 19 |
h1 {
|
| 20 |
text-align: center;
|
| 21 |
color: #2C3E50;
|
|
@@ -23,15 +152,11 @@ st.markdown("""
|
|
| 23 |
font-weight: bold;
|
| 24 |
margin-bottom: 20px;
|
| 25 |
}
|
| 26 |
-
|
| 27 |
-
/* Input boxes */
|
| 28 |
.stTextInput > div > div > input {
|
| 29 |
border: 2px solid #3498DB;
|
| 30 |
border-radius: 8px;
|
| 31 |
padding: 8px;
|
| 32 |
}
|
| 33 |
-
|
| 34 |
-
/* Buttons */
|
| 35 |
div.stButton > button {
|
| 36 |
background-color: #3498DB;
|
| 37 |
color: white;
|
|
@@ -45,13 +170,9 @@ st.markdown("""
|
|
| 45 |
background-color: #2980B9;
|
| 46 |
transform: scale(1.05);
|
| 47 |
}
|
| 48 |
-
|
| 49 |
-
/* Success / Error messages */
|
| 50 |
.stAlert {
|
| 51 |
border-radius: 8px;
|
| 52 |
}
|
| 53 |
-
|
| 54 |
-
/* Chart section */
|
| 55 |
.block-container {
|
| 56 |
padding-top: 2rem;
|
| 57 |
padding-bottom: 2rem;
|
|
@@ -59,14 +180,14 @@ st.markdown("""
|
|
| 59 |
</style>
|
| 60 |
""", unsafe_allow_html=True)
|
| 61 |
|
| 62 |
-
|
| 63 |
st.title("Gradient Descent Visualizer")
|
| 64 |
|
| 65 |
x = sp.Symbol("x")
|
| 66 |
-
|
| 67 |
func_input = st.text_input("Enter Function", "x^2")
|
| 68 |
start_point = float(st.text_input("Starting Point", "2"))
|
| 69 |
learning_rate = float(st.text_input("Learning Rate", "0.01"))
|
|
|
|
| 70 |
|
| 71 |
if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
|
| 72 |
try:
|
|
@@ -85,28 +206,44 @@ if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.s
|
|
| 85 |
except Exception as e:
|
| 86 |
st.error(f"Error setting up function: {e}")
|
| 87 |
|
| 88 |
-
# Gradient Descent Iteration button
|
| 89 |
if 'func' in st.session_state and 'gradient_func' in st.session_state:
|
| 90 |
if st.button("Next Iteration"):
|
| 91 |
try:
|
| 92 |
-
# Get the current point and gradient value
|
| 93 |
x_old = float(st.session_state.points[-1])
|
| 94 |
grad_val = st.session_state.gradient_func(x_old)
|
| 95 |
x_new = x_old - learning_rate * grad_val
|
| 96 |
-
|
| 97 |
-
# Append the new point to the list of points
|
| 98 |
-
st.session_state.points.append(x_new)
|
| 99 |
-
st.session_state.step += 1
|
| 100 |
-
|
| 101 |
-
st.success(f"Iteration {st.session_state.step} Complete!")
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
except Exception as e:
|
| 104 |
st.error(f"Error in iteration: {e}")
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
if 'func' in st.session_state and len(st.session_state.points) > 0:
|
| 108 |
try:
|
| 109 |
-
# Create x-values for plotting the function
|
| 110 |
x_val = np.linspace(-10, 10, 500)
|
| 111 |
y_val = st.session_state.func(x_val)
|
| 112 |
|
|
@@ -114,7 +251,8 @@ if 'func' in st.session_state and len(st.session_state.points) > 0:
|
|
| 114 |
iter_y = st.session_state.func(iter_points)
|
| 115 |
|
| 116 |
trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
|
| 117 |
-
trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines",
|
|
|
|
| 118 |
trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
|
| 119 |
marker=dict(color='green', size=15),
|
| 120 |
text=[f"{iter_points[-1]:.6f}"], textposition="top center",
|
|
@@ -129,8 +267,5 @@ if 'func' in st.session_state and len(st.session_state.points) > 0:
|
|
| 129 |
fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
|
| 130 |
st.plotly_chart(fig, use_container_width=True)
|
| 131 |
st.success(f"Current Point = {iter_points[-1]}")
|
| 132 |
-
|
| 133 |
except Exception as e:
|
| 134 |
st.error(f"Plot error: {e}")
|
| 135 |
-
|
| 136 |
-
|
|
|
|
| 1 |
+
# import streamlit as st
|
| 2 |
+
# import numpy as np
|
| 3 |
+
# import sympy as sp
|
| 4 |
+
# import plotly.graph_objs as go
|
| 5 |
+
|
| 6 |
+
# st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# # Custom CSS for styling
|
| 10 |
+
# st.markdown("""
|
| 11 |
+
# <style>
|
| 12 |
+
# /* Page background */
|
| 13 |
+
# .stApp {
|
| 14 |
+
# background-color: #f9f9f9;
|
| 15 |
+
# font-family: 'Segoe UI', sans-serif;
|
| 16 |
+
# }
|
| 17 |
+
|
| 18 |
+
# /* Title */
|
| 19 |
+
# h1 {
|
| 20 |
+
# text-align: center;
|
| 21 |
+
# color: #2C3E50;
|
| 22 |
+
# font-size: 38px !important;
|
| 23 |
+
# font-weight: bold;
|
| 24 |
+
# margin-bottom: 20px;
|
| 25 |
+
# }
|
| 26 |
+
|
| 27 |
+
# /* Input boxes */
|
| 28 |
+
# .stTextInput > div > div > input {
|
| 29 |
+
# border: 2px solid #3498DB;
|
| 30 |
+
# border-radius: 8px;
|
| 31 |
+
# padding: 8px;
|
| 32 |
+
# }
|
| 33 |
+
|
| 34 |
+
# /* Buttons */
|
| 35 |
+
# div.stButton > button {
|
| 36 |
+
# background-color: #3498DB;
|
| 37 |
+
# color: white;
|
| 38 |
+
# border-radius: 10px;
|
| 39 |
+
# padding: 10px 24px;
|
| 40 |
+
# font-size: 16px;
|
| 41 |
+
# border: none;
|
| 42 |
+
# transition: 0.3s;
|
| 43 |
+
# }
|
| 44 |
+
# div.stButton > button:hover {
|
| 45 |
+
# background-color: #2980B9;
|
| 46 |
+
# transform: scale(1.05);
|
| 47 |
+
# }
|
| 48 |
+
|
| 49 |
+
# /* Success / Error messages */
|
| 50 |
+
# .stAlert {
|
| 51 |
+
# border-radius: 8px;
|
| 52 |
+
# }
|
| 53 |
+
|
| 54 |
+
# /* Chart section */
|
| 55 |
+
# .block-container {
|
| 56 |
+
# padding-top: 2rem;
|
| 57 |
+
# padding-bottom: 2rem;
|
| 58 |
+
# }
|
| 59 |
+
# </style>
|
| 60 |
+
# """, unsafe_allow_html=True)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# st.title("Gradient Descent Visualizer")
|
| 64 |
+
|
| 65 |
+
# x = sp.Symbol("x")
|
| 66 |
+
# # User input function, starting point, and learning rate
|
| 67 |
+
# func_input = st.text_input("Enter Function", "x^2")
|
| 68 |
+
# start_point = float(st.text_input("Starting Point", "2"))
|
| 69 |
+
# learning_rate = float(st.text_input("Learning Rate", "0.01"))
|
| 70 |
+
|
| 71 |
+
# if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
|
| 72 |
+
# try:
|
| 73 |
+
# expr = func_input.replace("^", "**")
|
| 74 |
+
# expr_final = sp.sympify(expr)
|
| 75 |
+
# func = sp.lambdify(x, expr_final, "numpy")
|
| 76 |
+
# grad = sp.diff(expr_final, x)
|
| 77 |
+
# gradient_func = sp.lambdify(x, grad, "numpy")
|
| 78 |
+
|
| 79 |
+
# st.session_state.func = func
|
| 80 |
+
# st.session_state.gradient_func = gradient_func
|
| 81 |
+
# st.session_state.points = [start_point]
|
| 82 |
+
# st.session_state.step = 0
|
| 83 |
+
# st.success("Function and Gradient Set Up Successfully!")
|
| 84 |
+
|
| 85 |
+
# except Exception as e:
|
| 86 |
+
# st.error(f"Error setting up function: {e}")
|
| 87 |
+
|
| 88 |
+
# # Gradient Descent Iteration button
|
| 89 |
+
# if 'func' in st.session_state and 'gradient_func' in st.session_state:
|
| 90 |
+
# if st.button("Next Iteration"):
|
| 91 |
+
# try:
|
| 92 |
+
# # Get the current point and gradient value
|
| 93 |
+
# x_old = float(st.session_state.points[-1])
|
| 94 |
+
# grad_val = st.session_state.gradient_func(x_old)
|
| 95 |
+
# x_new = x_old - learning_rate * grad_val
|
| 96 |
+
|
| 97 |
+
# # Append the new point to the list of points
|
| 98 |
+
# st.session_state.points.append(x_new)
|
| 99 |
+
# st.session_state.step += 1
|
| 100 |
+
|
| 101 |
+
# st.success(f"Iteration {st.session_state.step} Complete!")
|
| 102 |
+
|
| 103 |
+
# except Exception as e:
|
| 104 |
+
# st.error(f"Error in iteration: {e}")
|
| 105 |
+
|
| 106 |
+
# # Creating the plot
|
| 107 |
+
# if 'func' in st.session_state and len(st.session_state.points) > 0:
|
| 108 |
+
# try:
|
| 109 |
+
# # Create x-values for plotting the function
|
| 110 |
+
# x_val = np.linspace(-10, 10, 500)
|
| 111 |
+
# y_val = st.session_state.func(x_val)
|
| 112 |
+
|
| 113 |
+
# iter_points = np.array(st.session_state.points)
|
| 114 |
+
# iter_y = st.session_state.func(iter_points)
|
| 115 |
+
|
| 116 |
+
# trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
|
| 117 |
+
# trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines", name="Gradient Descent Path", marker=dict(color="red"))
|
| 118 |
+
# trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
|
| 119 |
+
# marker=dict(color='green', size=15),
|
| 120 |
+
# text=[f"{iter_points[-1]:.6f}"], textposition="top center",
|
| 121 |
+
# name="Current Point")
|
| 122 |
+
|
| 123 |
+
# layout = go.Layout(
|
| 124 |
+
# title=f"Iteration {st.session_state.step}",
|
| 125 |
+
# xaxis=dict(title="x - axis"),
|
| 126 |
+
# yaxis=dict(title="y - axis")
|
| 127 |
+
# )
|
| 128 |
+
|
| 129 |
+
# fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
|
| 130 |
+
# st.plotly_chart(fig, use_container_width=True)
|
| 131 |
+
# st.success(f"Current Point = {iter_points[-1]}")
|
| 132 |
+
|
| 133 |
+
# except Exception as e:
|
| 134 |
+
# st.error(f"Plot error: {e}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
import streamlit as st
|
| 138 |
import numpy as np
|
| 139 |
import sympy as sp
|
| 140 |
import plotly.graph_objs as go
|
| 141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
st.markdown("""
|
| 143 |
<style>
|
|
|
|
| 144 |
.stApp {
|
| 145 |
background-color: #f9f9f9;
|
| 146 |
font-family: 'Segoe UI', sans-serif;
|
| 147 |
}
|
|
|
|
|
|
|
| 148 |
h1 {
|
| 149 |
text-align: center;
|
| 150 |
color: #2C3E50;
|
|
|
|
| 152 |
font-weight: bold;
|
| 153 |
margin-bottom: 20px;
|
| 154 |
}
|
|
|
|
|
|
|
| 155 |
.stTextInput > div > div > input {
|
| 156 |
border: 2px solid #3498DB;
|
| 157 |
border-radius: 8px;
|
| 158 |
padding: 8px;
|
| 159 |
}
|
|
|
|
|
|
|
| 160 |
div.stButton > button {
|
| 161 |
background-color: #3498DB;
|
| 162 |
color: white;
|
|
|
|
| 170 |
background-color: #2980B9;
|
| 171 |
transform: scale(1.05);
|
| 172 |
}
|
|
|
|
|
|
|
| 173 |
.stAlert {
|
| 174 |
border-radius: 8px;
|
| 175 |
}
|
|
|
|
|
|
|
| 176 |
.block-container {
|
| 177 |
padding-top: 2rem;
|
| 178 |
padding-bottom: 2rem;
|
|
|
|
| 180 |
</style>
|
| 181 |
""", unsafe_allow_html=True)
|
| 182 |
|
|
|
|
| 183 |
st.title("Gradient Descent Visualizer")
|
| 184 |
|
| 185 |
x = sp.Symbol("x")
|
| 186 |
+
|
| 187 |
func_input = st.text_input("Enter Function", "x^2")
|
| 188 |
start_point = float(st.text_input("Starting Point", "2"))
|
| 189 |
learning_rate = float(st.text_input("Learning Rate", "0.01"))
|
| 190 |
+
num_iterations = int(st.text_input("Number of Iterations", "10"))
|
| 191 |
|
| 192 |
if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
|
| 193 |
try:
|
|
|
|
| 206 |
except Exception as e:
|
| 207 |
st.error(f"Error setting up function: {e}")
|
| 208 |
|
|
|
|
| 209 |
if 'func' in st.session_state and 'gradient_func' in st.session_state:
|
| 210 |
if st.button("Next Iteration"):
|
| 211 |
try:
|
|
|
|
| 212 |
x_old = float(st.session_state.points[-1])
|
| 213 |
grad_val = st.session_state.gradient_func(x_old)
|
| 214 |
x_new = x_old - learning_rate * grad_val
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
+
tolerance = 0.001
|
| 217 |
+
if abs(x_new - x_old) < tolerance or abs(grad_val) < tolerance:
|
| 218 |
+
st.success(f"Reached Minima at x = {x_new:.6f}")
|
| 219 |
+
else:
|
| 220 |
+
st.session_state.points.append(x_new)
|
| 221 |
+
st.session_state.step += 1
|
| 222 |
+
st.success(f"Iteration {st.session_state.step} Complete!")
|
| 223 |
except Exception as e:
|
| 224 |
st.error(f"Error in iteration: {e}")
|
| 225 |
|
| 226 |
+
if st.button("Run Iterations"):
|
| 227 |
+
try:
|
| 228 |
+
tolerance = 0.001
|
| 229 |
+
for i in range(num_iterations):
|
| 230 |
+
x_old = float(st.session_state.points[-1])
|
| 231 |
+
grad_val = st.session_state.gradient_func(x_old)
|
| 232 |
+
x_new = x_old - learning_rate * grad_val
|
| 233 |
+
|
| 234 |
+
if abs(x_new - x_old) < tolerance or abs(grad_val) < tolerance:
|
| 235 |
+
st.success(f"Reached Minima early at x = {x_new:.6f} (after {i+1} steps)")
|
| 236 |
+
break
|
| 237 |
+
|
| 238 |
+
st.session_state.points.append(x_new)
|
| 239 |
+
st.session_state.step += 1
|
| 240 |
+
|
| 241 |
+
st.success(f"Ran {st.session_state.step} Iterations in total")
|
| 242 |
+
except Exception as e:
|
| 243 |
+
st.error(f"Error in multiple iterations: {e}")
|
| 244 |
+
|
| 245 |
if 'func' in st.session_state and len(st.session_state.points) > 0:
|
| 246 |
try:
|
|
|
|
| 247 |
x_val = np.linspace(-10, 10, 500)
|
| 248 |
y_val = st.session_state.func(x_val)
|
| 249 |
|
|
|
|
| 251 |
iter_y = st.session_state.func(iter_points)
|
| 252 |
|
| 253 |
trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
|
| 254 |
+
trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines",
|
| 255 |
+
name="Gradient Descent Path", marker=dict(color="red"))
|
| 256 |
trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
|
| 257 |
marker=dict(color='green', size=15),
|
| 258 |
text=[f"{iter_points[-1]:.6f}"], textposition="top center",
|
|
|
|
| 267 |
fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
|
| 268 |
st.plotly_chart(fig, use_container_width=True)
|
| 269 |
st.success(f"Current Point = {iter_points[-1]}")
|
|
|
|
| 270 |
except Exception as e:
|
| 271 |
st.error(f"Plot error: {e}")
|
|
|
|
|
|