Spaces:
Sleeping
Sleeping
File size: 6,604 Bytes
742dc62 816fe9c 647d9b5 209a2c7 647d9b5 15f83dc 209a2c7 1aeffd1 209a2c7 1aeffd1 209a2c7 15f83dc 647d9b5 15f83dc 1aeffd1 209a2c7 15f83dc 1aeffd1 209a2c7 1aeffd1 209a2c7 1aeffd1 559203b 647d9b5 1aeffd1 559203b 1aeffd1 559203b 647d9b5 559203b 1aeffd1 647d9b5 1aeffd1 647d9b5 559203b 647d9b5 559203b 1aeffd1 647d9b5 1aeffd1 647d9b5 1aeffd1 559203b 209a2c7 1aeffd1 15f83dc 1aeffd1 647d9b5 1aeffd1 647d9b5 1aeffd1 647d9b5 1aeffd1 647d9b5 1aeffd1 647d9b5 15f83dc 1aeffd1 559203b 15f83dc 1aeffd1 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 15f83dc 647d9b5 209a2c7 647d9b5 15f83dc 647d9b5 209a2c7 647d9b5 15f83dc 647d9b5 209a2c7 647d9b5 15f83dc 647d9b5 209a2c7 15f83dc 209a2c7 647d9b5 15f83dc 647d9b5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import streamlit as st
import numpy as np
import plotly.graph_objects as go
# Page setup: set_page_config is issued before any other st.* call in this script.
st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")

# Centered, colored page heading rendered as raw HTML.
_TITLE_HTML = "<h1 style='text-align: center; color: #00FA9A;'>โจ Gradient Descent Visualizer โจ</h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Custom CSS: body gradient/text colour, gradient buttons with a hover swap,
# and zero padding on the main block container.
# NOTE(review): Streamlit draws its own app container over <body>, so the body
# gradient may not be visible, and no element in this file carries the
# "iteration-controls" class -- confirm whether those rules have any effect.
_CUSTOM_CSS = """
<style>
body {
background: linear-gradient(to right, #141E30, #243B55);
color: #E0FFFF;
}
.stButton>button {
background: linear-gradient(to right, #00C6FF, #0072FF);
color: white;
border: none;
border-radius: 10px;
padding: 10px 15px;
font-size: 16px;
font-weight: bold;
}
.stButton>button:hover {
background: linear-gradient(to right, #0072FF, #00C6FF);
}
.iteration-controls button {
width: 100%;
margin: 5px 0;
}
.block-container {
padding: 0;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Function evaluation with a restricted namespace
def evaluate_function(expression, x_value):
    """Evaluate a user-supplied expression in terms of ``x``.

    Only two names are visible to the expression: ``x`` (bound to
    ``x_value``) and ``np`` (NumPy). Python builtins are not available.

    Args:
        expression: Python expression text, e.g. ``"x**2 + x"`` or
            ``"np.sin(x)"``.
        x_value: Value bound to ``x``; a scalar or a NumPy array both work
            as long as the expression supports it.

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: if the expression contains a double-underscore name.
        Any exception raised while compiling/evaluating the expression
        (``SyntaxError``, ``NameError`` for names other than ``x``/``np``, ...).
    """
    # SECURITY NOTE(review): eval() on user text is NOT a real sandbox. The
    # measures below block casual escapes (builtins, dunder attribute walks
    # such as np.__builtins__ or ().__class__), but anything reachable through
    # np is still callable. Do not expose this to untrusted users without a
    # proper expression parser.
    if "__" in expression:
        raise ValueError("Double-underscore names are not allowed in the expression.")
    allowed_names = {"x": x_value, "np": np}  # Allow only x and numpy
    # The key must be spelled "__builtins__": when it is absent, eval() inserts
    # the real builtins module, which the old "_builtins_" key did not prevent.
    return eval(expression, {"__builtins__": {}}, allowed_names)
# Numerical derivative via a central finite difference
def compute_derivative(expression, x_value, h=1e-5):
    """Approximate the derivative of the expression at ``x_value``.

    Uses the symmetric difference quotient ``(f(x + h) - f(x - h)) / (2h)``,
    where ``f`` is ``expression`` evaluated through ``evaluate_function``.

    Args:
        expression: Expression text in terms of ``x``.
        x_value: Point at which to differentiate.
        h: Half-width of the difference stencil.
    """
    ahead = evaluate_function(expression, x_value + h)
    behind = evaluate_function(expression, x_value - h)
    return (ahead - behind) / (2 * h)
# Tangent line at a point, sampled over a range of x values
def calculate_tangent(expression, x_value, x_range):
    """Return the tangent line to the expression at ``x_value``.

    Point-slope form: ``f'(x0) * (x - x0) + f(x0)``, evaluated for every
    ``x`` in ``x_range``. ``x_range`` must support element-wise arithmetic;
    this script passes a NumPy array.
    """
    anchor_y = evaluate_function(expression, x_value)
    gradient = compute_derivative(expression, x_value)
    offsets = x_range - x_value
    return gradient * offsets + anchor_y
# Reset state
def reset_session_state():
    """Restart the descent from the current initial point and function.

    Registered as the ``on_change`` callback of the function and
    initial-point widgets and called by the Reset button. Reads
    ``initial_point`` and ``math_function`` from ``st.session_state`` and
    rewrites ``x_current``, ``iter_count``, ``history`` and ``current_index``.

    Raises:
        Whatever ``evaluate_function`` raises for an invalid expression. In
        that case no session-state value has been modified.
    """
    start_x = st.session_state.initial_point
    # Evaluate before touching any state. Previously x_current and iter_count
    # were reset first, so an invalid expression left them out of sync with the
    # old history/current_index, and the next descent step displayed a stale
    # history entry.
    start_y = evaluate_function(st.session_state.math_function, start_x)
    st.session_state.x_current = start_x
    st.session_state.iter_count = 0
    st.session_state.history = [(start_x, start_y)]
    st.session_state.current_index = 0
# Initialize session state variables on the first run of a session.
# The start point and function mirror the defaults of the "initial_point"
# (4.0) and "math_function" ("x**2 + x") widgets created below. Previously the
# start point was 0.0, so the first descent step ran from a point the UI never
# showed.
_DEFAULT_START_X = 4.0
_DEFAULT_FUNCTION = "x**2 + x"
if "x_current" not in st.session_state:
    st.session_state.x_current = _DEFAULT_START_X
if "iter_count" not in st.session_state:
    st.session_state.iter_count = 0
if "history" not in st.session_state:
    # History entries are (x, f(x)) pairs; index 0 is the starting point.
    st.session_state.history = [
        (_DEFAULT_START_X, evaluate_function(_DEFAULT_FUNCTION, _DEFAULT_START_X))
    ]
if "current_index" not in st.session_state:
    st.session_state.current_index = 0
if "learning_rate" not in st.session_state:
    # Seeds the "learning_rate" number_input widget, which shares this key.
    st.session_state.learning_rate = 0.1
# Create a two-column layout with equal widths:
# inputs and controls on the left, visualization on the right.
left_col, right_col = st.columns(2)

# Left side content
with left_col:
    st.markdown("### ๐งฎ Input Your Function")
    # Committing a new expression triggers reset_session_state, which
    # restarts the descent.
    function_input = st.text_input(
        "Enter Function: Example: `x**2`, `np.sin(x)`",
        "x**2 + x",
        key="math_function",
        on_change=reset_session_state,
    )

    st.markdown("### โ๏ธ Set Parameters")
    initial_point = st.number_input(
        "๐ข Initial Value of x",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="initial_point",
        on_change=reset_session_state,
    )
    # The widget's value comes from st.session_state["learning_rate"], which is
    # seeded earlier in the script. Passing value= as well made Streamlit warn
    # that the widget had both a default value and a Session State value.
    # Changing the rate does not reset the descent.
    st.number_input(
        "๐ Learning Rate",
        step=0.01,
        format="%.2f",
        key="learning_rate",
    )

    st.markdown("### ๐๏ธ Controls")
    if st.button("๐ Run Descent Step"):
        try:
            # One gradient-descent update: x <- x - learning_rate * f'(x).
            gradient = compute_derivative(function_input, st.session_state.x_current)
            st.session_state.x_current -= st.session_state.learning_rate * gradient
            st.session_state.iter_count += 1
            st.session_state.history.append(
                (st.session_state.x_current, evaluate_function(function_input, st.session_state.x_current))
            )
            # Jump the iteration browser to the newest step.
            st.session_state.current_index = st.session_state.iter_count
        except Exception as e:
            st.error(f"Error: {str(e)}")
    if st.button("๐ Reset"):
        reset_session_state()
# Right side content: iteration browser plus a plot of the function, the
# selected point and its tangent line.
with right_col:
    st.markdown("### ๐ Gradient Descent Visualization")

    # Iteration control buttons
    col1, col2, col3 = st.columns([1, 1, 1])
    with col1:
        if st.button("โฎ๏ธ Previous") and st.session_state.current_index > 0:
            st.session_state.current_index -= 1
    with col3:
        if st.button("โญ๏ธ Next") and st.session_state.current_index < st.session_state.iter_count:
            st.session_state.current_index += 1
    # The label goes into the middle column only after both button handlers
    # have run. Rendering it before the Next handler made it show the previous
    # index after a Next click, while the values and plot below used the new one.
    with col2:
        st.markdown(
            f"<p style='text-align: center;'>Iteration: <strong>{st.session_state.current_index}</strong></p>",
            unsafe_allow_html=True,
        )

    # (x, f(x)) of the selected iteration, or None when history has no entry for it.
    try:
        current_point = st.session_state.history[st.session_state.current_index]
    except IndexError:
        current_point = None
        st.warning("No iteration data available. Please run a descent step first.")
    else:
        selected_x, selected_y = current_point
        st.markdown(f"๐งพ **x Value:** `{selected_x:.4f}`")
        st.markdown(f"๐ **f(x):** `{selected_y:.4f}`")

    # Prepare data for visualization over the fixed window [-10, 10].
    x_range = np.linspace(-10, 10, 500)
    try:
        # Evaluated point by point, not vectorized over the array.
        y_range = [evaluate_function(function_input, x) for x in x_range]

        # Plot function curve
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=x_range,
            y=y_range,
            mode='lines',
            name='Function',
            line=dict(color='blue'),  # Blue color for curve
        ))

        if current_point is not None:
            # Add current point
            x_current, y_current = current_point
            fig.add_trace(go.Scatter(
                x=[x_current],
                y=[y_current],
                mode='markers',
                name='Current Point',
                marker=dict(size=12, color='red'),  # Red for current point
            ))
            # Add tangent line at the current point
            tangent_y = calculate_tangent(function_input, x_current, x_range)
            fig.add_trace(go.Scatter(
                x=x_range,
                y=tangent_y,
                mode='lines',
                name='Tangent Line',
                line=dict(dash='dash', color='yellow'),  # Yellow dashed line for tangent
            ))
    except Exception as e:
        # An invalid expression previously crashed the whole page with a
        # traceback here; report it the way the descent-step handler does.
        st.error(f"Error: {str(e)}")
    else:
        # Layout adjustments
        fig.update_layout(
            title="Gradient Descent Progress ๐",
            xaxis_title="x",
            yaxis_title="f(x)",
            template="plotly_dark",
            height=600,
        )
        st.plotly_chart(fig, use_container_width=True)
|