Spaces:
Sleeping
Sleeping
File size: 6,183 Bytes
742dc62 816fe9c 2a2c734 6a14f39 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 1aeffd1 26eec67 559203b 2a2c734 559203b 26eec67 559203b 26eec67 1aeffd1 26eec67 559203b 26eec67 1aeffd1 26eec67 1aeffd1 26eec67 1aeffd1 559203b 2a2c734 8bc07a1 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 8bc07a1 2a2c734 6a14f39 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 2a2c734 fa7e3f8 6952440 2a2c734 fa7e3f8 6952440 2a2c734 fa7e3f8 2a2c734 6a14f39 2a2c734 5284381 2a2c734 6952440 2a2c734 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import streamlit as st
import numpy as np
import plotly.graph_objects as go
# Page Configurations
st.set_page_config(layout="wide", page_title="Gradient Descent Visualization")

# Centered page heading rendered as raw HTML.
_TITLE_HTML = "<h1 style='text-align: center; color: #FFD700;'>β‘ Gradient Descent Visualization β‘</h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Custom CSS for Background and Buttons: dark page body, green gradient
# buttons (reversed on hover) and blue level-3 markdown headings.
_CUSTOM_CSS = """
<style>
body {
background-color: #1E1E1E; /* Dark grey background */
color: white; /* White text for contrast */
}
.stButton>button {
background: linear-gradient(to right, #4CAF50, #2E7D32); /* Green gradient */
color: white; /* White text */
border-radius: 8px;
padding: 8px 16px;
font-size: 16px;
}
.stButton>button:hover {
background: linear-gradient(to right, #2E7D32, #4CAF50); /* Reverse gradient on hover */
}
.stMarkdown h3 {
color: #03A9F4; /* Blue color for section titles */
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Safe Function Evaluation
def evaluate_function(expression, x_value):
    """Evaluate the user-supplied *expression* at ``x = x_value``.

    The expression may reference ``x``, ``np`` (NumPy) and a few pure numeric
    builtins (``abs``, ``min``, ``max``, ``pow``, ``round``). Every other
    builtin -- ``__import__``, ``open``, ``eval``, ... -- raises ``NameError``.

    Raises:
        Whatever evaluating *expression* raises (``SyntaxError``,
        ``NameError``, ``ZeroDivisionError``, ``OverflowError``, ...).
    """
    # The key must be spelled exactly "__builtins__": with any other key
    # eval() inserts the real builtins module and the expression could call
    # __import__('os') and friends.
    # NOTE(review): hiding builtins is NOT a sandbox -- attribute walks such as
    # ().__class__.__base__.__subclasses__(), or NumPy's own module attributes,
    # can still reach dangerous objects. Only expose this to trusted users.
    namespace = {
        "__builtins__": {},
        "abs": abs,
        "min": min,
        "max": max,
        "pow": pow,
        "round": round,
        "np": np,
        "x": x_value,
    }
    # Names live in the globals mapping (not a separate locals mapping) so
    # they are also visible inside lambdas/comprehensions in the expression.
    return eval(expression, namespace)
# Compute Derivative
def compute_derivative(expression, x_value, h=1e-5):
    """Approximate the derivative of *expression* at *x_value*.

    Uses the central difference (f(x + h) - f(x - h)) / (2h), evaluating the
    right-hand point first, then the left-hand point.
    """
    f_right = evaluate_function(expression, x_value + h)
    f_left = evaluate_function(expression, x_value - h)
    rise = f_right - f_left
    return rise / (2 * h)
# Tangent Line Calculation
def calculate_tangent(expression, x_value, x_range):
    """Return the tangent line of *expression* at *x_value*, sampled at *x_range*.

    Point-slope form: y = f'(x0) * (x - x0) + f(x0). *x_range* may be a NumPy
    array, in which case the result is an array of the same shape.
    """
    anchor_y = evaluate_function(expression, x_value)
    gradient = compute_derivative(expression, x_value)
    offsets = x_range - x_value
    return gradient * offsets + anchor_y
# Reset Session State
def reset_session_state():
    """Restart the descent from the "Starting Value of x" widget.

    Reads ``initial_point`` and ``math_function`` from session state and
    resets the iterate, the iteration counter, the history (seeded with the
    starting point and its function value) and the displayed step index.

    The starting value is evaluated *before* any state is modified. If the
    expression is invalid, the exception propagates and the previous state is
    left intact instead of half-reset (which would desynchronise
    ``history`` from ``iter_count``).
    """
    start = st.session_state.initial_point
    start_value = evaluate_function(st.session_state.math_function, start)
    # Commit only after evaluation succeeded.
    st.session_state.x_current = start
    st.session_state.iter_count = 0
    st.session_state.history = [(start, start_value)]
    st.session_state.current_index = 0
# Initialize Session State
# First-run seeds must match the widget defaults declared below (equation
# "x**2 + x", starting value 4.0). Otherwise the first render shows the marker
# at a different x than the "Starting Value of x" box, and the first
# "Compute Next Step" descends from the wrong point.
_DEFAULT_EXPRESSION = "x**2 + x"
_DEFAULT_START = 4.0
if "x_current" not in st.session_state:
    st.session_state.x_current = _DEFAULT_START
if "iter_count" not in st.session_state:
    st.session_state.iter_count = 0
if "history" not in st.session_state:
    # Each entry is an (x, f(x)) pair; the first one is the starting point.
    st.session_state.history = [
        (_DEFAULT_START, evaluate_function(_DEFAULT_EXPRESSION, _DEFAULT_START))
    ]
if "current_index" not in st.session_state:
    st.session_state.current_index = 0
if "learning_rate" not in st.session_state:
    # Backs the "Step Size (Learning Rate)" widget (key="learning_rate").
    st.session_state.learning_rate = 0.1
# Layout
left_col, right_col = st.columns([1, 2])
# Left Column: Inputs
with left_col:
    st.markdown("### π Define Your Equation")
    function_input = st.text_input(
        "Input Equation (e.g., `x**2`, `np.sin(x)`):",
        "x**2 + x",
        key="math_function",
        on_change=reset_session_state,
    )
    st.markdown("### π§ Configure Settings")
    initial_point = st.number_input(
        "Starting Value of x:",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="initial_point",
        on_change=reset_session_state,
    )
    # The widget's value comes from st.session_state.learning_rate, which the
    # session-state initialisation seeds. Passing value= as well makes
    # Streamlit warn that the widget has both a default and a Session State
    # value.
    st.number_input(
        "Step Size (Learning Rate):",
        step=0.01,
        format="%.2f",
        key="learning_rate",
    )
    st.markdown("### πΉοΈ Actions")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("π Compute Next Step"):
            try:
                # Compute everything first: evaluating f at the new point can
                # still fail (e.g. OverflowError when the descent diverges).
                x_old = st.session_state.x_current
                gradient = compute_derivative(function_input, x_old)
                x_new = x_old - st.session_state.learning_rate * gradient
                y_new = evaluate_function(function_input, x_new)
            except Exception as e:
                st.error(f"Error: {str(e)}")
            else:
                # Commit atomically so history always holds iter_count + 1 entries.
                st.session_state.x_current = x_new
                st.session_state.iter_count += 1
                st.session_state.history.append((x_new, y_new))
                st.session_state.current_index = len(st.session_state.history) - 1
    with col2:
        if st.button("π Restart"):
            try:
                reset_session_state()
            except Exception as e:
                st.error(f"Error: {str(e)}")
# Right Column: Visualization
def _build_figure(expression, x_current, y_current):
    """Build the Plotly figure: the curve, the current point and its tangent.

    The curve and tangent are sampled on the fixed interval [-10, 10]. Any
    exception raised while evaluating *expression* propagates to the caller.
    """
    x_range = np.linspace(-10, 10, 500)
    # Evaluated point by point so scalar-only expressions keep working.
    y_range = [evaluate_function(expression, x) for x in x_range]
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Equation', line=dict(color='#FFD700')))
    fig.add_trace(go.Scatter(x=[x_current], y=[y_current], mode='markers', name='Current Position', marker=dict(size=12, color='#FF4500')))
    # Tangent line at the currently selected step.
    tangent_y = calculate_tangent(expression, x_current, x_range)
    fig.add_trace(go.Scatter(x=x_range, y=tangent_y, mode='lines', name='Tangent', line=dict(dash='dash', color='#00FFFF')))
    fig.update_layout(
        title="Gradient Descent Progress",
        xaxis_title="x",
        yaxis_title="f(x)",
        template="plotly_dark",
        height=500,
        width=900,
        margin=dict(l=20, r=20, t=50, b=20),
    )
    return fig


with right_col:
    st.markdown("### π Gradient Descent Steps")
    history = st.session_state.history
    last_index = len(history) - 1
    # Navigation Buttons
    col1, col2, col3 = st.columns(3)
    # Handle both buttons before writing the counter into col2, so the counter
    # reflects this rerun's click instead of lagging one step behind "Next".
    with col1:
        if st.button("⬅️ Previous Step") and st.session_state.current_index > 0:
            st.session_state.current_index -= 1
    with col3:
        if st.button("➡️ Next Step") and st.session_state.current_index < last_index:
            st.session_state.current_index += 1
    # Clamp once so a stale index can never raise IndexError below.
    st.session_state.current_index = max(0, min(st.session_state.current_index, last_index))
    with col2:
        st.markdown(f"<p style='text-align: center;'>Step Count: <strong>{st.session_state.current_index}</strong></p>", unsafe_allow_html=True)

    if not history:
        st.warning("No data to display. Perform a computation first.")
    else:
        selected_x, selected_y = history[st.session_state.current_index]
        st.markdown(f"π **Current x:** `{selected_x:.4f}`")
        st.markdown(f"π **f(x) at Current Step:** `{selected_y:.4f}`")
        # Visualization
        try:
            fig = _build_figure(function_input, selected_x, selected_y)
        except Exception as e:
            # An invalid expression should produce a message, not a traceback.
            st.error(f"Error: {e}")
        else:
            st.plotly_chart(fig, use_container_width=True)
|