Spaces:
Sleeping
Sleeping
File size: 7,605 Bytes
bcde049 f3ec8f2 f216914 aa96b75 c932968 90c2389 b260d74 86e0690 b260d74 c932968 b260d74 86e0690 c932968 86e0690 e8d0245 c932968 86e0690 c932968 86e0690 ce2a035 aa96b75 7f37c44 aa96b75 7f37c44 c9d0458 aa96b75 a43c583 c8541e0 a43c583 f216914 6acb70a aa96b75 c932968 f216914 e8d0245 f216914 c932968 aa96b75 f216914 aa96b75 7f37c44 f216914 7f37c44 f216914 6acb70a c932968 c0093b4 f216914 aa96b75 7f37c44 aa96b75 a43c583 c9d0458 aa96b75 c0093b4 f216914 c932968 62e27a7 f216914 a35cd4f f216914 a35cd4f f216914 a35cd4f 62e27a7 a43c583 c932968 a43c583 5be05c5 62e27a7 5be05c5 f216914 62e27a7 f216914 5be05c5 62e27a7 f216914 5be05c5 62e27a7 f216914 5be05c5 62e27a7 5be05c5 62e27a7 5be05c5 62e27a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import streamlit as st
import numpy as np
import plotly.graph_objects as go
# Page setup: browser-tab title, wide layout, and a centered gold heading.
st.set_page_config(
    page_title="Interactive Gradient Descent Visualizer",
    layout="wide",
)
_TITLE_HTML = "<h1 style='text-align: center; color: #FFD700;'> π Gradient Descent Visualizer</h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Theme overrides injected as raw CSS: dark page background, gradient buttons
# with a hover glow, aqua headings, and a gold ".custom-text" helper class.
_CUSTOM_CSS = """
<style>
body {
background-color: #121212; /* Dark gray background for modern look */
color: white; /* White text for contrast */
}
.stButton>button {
background: linear-gradient(45deg, #FF7F50, #FF4500); /* Coral to OrangeRed gradient */
color: white; /* White button text */
border: none;
border-radius: 8px;
padding: 10px 20px;
font-size: 16px;
font-weight: bold;
transition: transform 0.2s ease, box-shadow 0.3s ease, filter 0.3s ease; /* Smooth hover effects */
}
.stButton>button:hover {
transform: scale(1.1); /* Slight zoom effect on hover */
box-shadow: 0 0 20px 10px rgba(255, 69, 0, 0.8); /* Glowing shadow effect */
background: linear-gradient(45deg, #FF4500, #FF7F50); /* Reverse gradient */
filter: brightness(1.2); /* Slightly brightens the button */
}
h1, h2, h3 {
color: #00FFFF; /* Aqua for headings */
}
.custom-text {
color: #FFD700; /* Gold for highlighted text */
font-weight: bold;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Restricted function evaluation
def evaluate_function(expression, x_value):
    """Evaluate a user-supplied expression in ``x`` at ``x_value``.

    Args:
        expression: Python expression text. It may reference ``x``, ``np``
            and a handful of numeric builtins (``abs``, ``min``, ``max``,
            ``pow``, ``round``).
        x_value: Value bound to the name ``x`` (a scalar or a NumPy array).

    Returns:
        Whatever the expression evaluates to.

    Raises:
        Exception: Anything raised while parsing or evaluating the expression,
            e.g. ``SyntaxError`` for malformed text or ``NameError`` for a
            name that is not whitelisted.
    """
    # SECURITY: eval() on user-typed text is not a real sandbox. Even with
    # builtins removed, `np` exposes file I/O and attribute access can reach
    # other objects. Use a proper expression parser if input is untrusted.
    allowed_names = {"x": x_value, "np": np}
    # The key must be spelled "__builtins__" (double underscores). Under any
    # other key Python silently injects the full builtins module, which makes
    # __import__, open, exec, ... available to the expression.
    restricted_globals = {
        "__builtins__": {
            "abs": abs,
            "min": min,
            "max": max,
            "pow": pow,
            "round": round,
        }
    }
    return eval(expression, restricted_globals, allowed_names)
# Numerical derivative via a symmetric (central) finite difference
def compute_derivative(expression, x_value, h=1e-5):
    """Approximate the derivative of ``expression`` at ``x_value``.

    Args:
        expression: Expression text passed through to ``evaluate_function``.
        x_value: Point at which the slope is estimated.
        h: Half-width of the sampling interval around ``x_value``.

    Returns:
        ``(f(x + h) - f(x - h)) / (2 * h)``.
    """
    ahead = evaluate_function(expression, x_value + h)
    behind = evaluate_function(expression, x_value - h)
    return (ahead - behind) / (2 * h)
# Tangent line through (x_value, f(x_value))
def calculate_tangent(expression, x_value, x_range):
    """Return the tangent line of ``expression`` at ``x_value`` over ``x_range``.

    Args:
        expression: Expression text passed through to the evaluation helpers.
        x_value: Point of tangency.
        x_range: x coordinates at which the tangent line is sampled.

    Returns:
        ``slope * (x_range - x_value) + f(x_value)``, where the slope comes
        from ``compute_derivative``.
    """
    height_at_point = evaluate_function(expression, x_value)
    gradient = compute_derivative(expression, x_value)
    offsets = x_range - x_value
    return gradient * offsets + height_at_point
# Restart the descent from the configured initial point
def reset_session_state():
    """Rewind the descent to ``initial_point`` for the current function.

    Reads ``initial_point`` and ``math_function`` from ``st.session_state``
    and overwrites ``x_current``, ``iter_count``, ``history`` and
    ``current_index``. The history is re-seeded with the single starting
    ``(x, f(x))`` pair.
    """
    state = st.session_state
    start_x = state.initial_point
    state.x_current = start_x
    state.iter_count = 0
    start_y = evaluate_function(state.math_function, start_x)
    state.history = [(start_x, start_y)]
    state.current_index = 0
# Seed session state on the first run only; later reruns keep existing values.
# The starting point matches the 4.0 default of the "Initial Value of x" input
# so the first plot and the first descent step begin where the UI says they do
# (previously the state started at 0.0 while the widget showed 4.00).
_DEFAULT_START_X = 4.0
if "x_current" not in st.session_state:
    st.session_state.x_current = _DEFAULT_START_X
if "iter_count" not in st.session_state:
    st.session_state.iter_count = 0
if "history" not in st.session_state:
    # "x**2 + x" is the default text of the function input box.
    st.session_state.history = [
        (_DEFAULT_START_X, evaluate_function("x**2 + x", _DEFAULT_START_X))
    ]
if "current_index" not in st.session_state:
    st.session_state.current_index = 0
if "learning_rate" not in st.session_state:
    st.session_state.learning_rate = 0.1
# Two-column layout: the right (plot) column gets twice the width of the left.
left_col, right_col = st.columns([1, 2])

# Left column: function input, descent parameters, and step/reset controls.
with left_col:
    st.markdown("<h3 style='color: #7FFF00;'>Input Your Function</h3>", unsafe_allow_html=True)
    function_input = st.text_input(
        "Enter Function:`Ex:'x**2`,`np.sin(x)`",
        "x**2 + x",
        key="math_function",
        on_change=reset_session_state,  # a new function restarts the descent
    )

    st.markdown("<h3 style='color: #FF69B4;'>Set Parameters</h3>", unsafe_allow_html=True)
    initial_point = st.number_input(
        "Initial Value of x",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="initial_point",
        on_change=reset_session_state,  # a new start point restarts the descent
    )
    st.number_input(
        "Learning Rate",
        value=st.session_state.learning_rate,
        step=0.01,
        format="%.2f",
        key="learning_rate",
    )  # Updates session state directly without reset

    st.markdown("<h3 style='color: #1E90FF;'>Controls</h3>", unsafe_allow_html=True)
    # NOTE(review): the leading "π" in the two button labels below looks like
    # a mis-decoded emoji; the original glyph could not be recovered.
    if st.button("π Run Descent Step", type="primary"):
        try:
            # Compute the whole step before touching session state, so a
            # failing expression cannot leave iter_count ahead of history.
            gradient = compute_derivative(function_input, st.session_state.x_current)
            next_x = st.session_state.x_current - st.session_state.learning_rate * gradient
            next_y = evaluate_function(function_input, next_x)
        except Exception as e:
            # UI boundary: any error from the user's expression is shown
            # in the page instead of crashing the script.
            st.error(f"Error: {str(e)}")
        else:
            st.session_state.x_current = next_x
            st.session_state.iter_count += 1
            st.session_state.history.append((next_x, next_y))
            # Jump the viewer to the step that was just taken.
            st.session_state.current_index = st.session_state.iter_count

    if st.button("π Reset"):
        reset_session_state()
# Right column: iteration browser and the function / tangent plot.
with right_col:
    st.markdown("<h3 style='color: #FF6347;'>Gradient Descent Visualization</h3>", unsafe_allow_html=True)

    # Read both navigation buttons first and only then render the iteration
    # label, so the label reflects the click made in this run. The arrow
    # labels were mis-decoded in the source (one was even split across two
    # lines); they are restored to the left/right arrow emoji.
    col1, col2, col3 = st.columns(3)
    with col1:
        prev_clicked = st.button("⬅️ Previous Iteration")
    with col3:
        next_clicked = st.button("➡️ Next Iteration")
    if prev_clicked and st.session_state.current_index > 0:
        st.session_state.current_index -= 1
    if next_clicked and st.session_state.current_index < st.session_state.iter_count:
        st.session_state.current_index += 1
    with col2:
        st.markdown(f"**Iteration:** {st.session_state.current_index}", unsafe_allow_html=True)

    # Look the selected step up once; both the readout and the plot reuse it.
    try:
        selected_point = st.session_state.history[st.session_state.current_index]
    except IndexError:
        selected_point = None
        st.warning("No iteration data available. Please run a descent step first.")
    else:
        selected_x, selected_y = selected_point
        st.markdown(f"x Value: <span style='color: #FFD700;'>{selected_x:.4f}</span>", unsafe_allow_html=True)
        st.markdown(f"f(x): <span style='color: #FFD700;'>{selected_y:.4f}</span>", unsafe_allow_html=True)

    # Prepare data for visualization
    x_range = np.linspace(-10, 10, 500)  # Fixed plotting window for x
    try:
        # Evaluated point by point so expressions that only work on scalars
        # are still plotted.
        y_range = [evaluate_function(function_input, x) for x in x_range]
        tangent_y = (
            calculate_tangent(function_input, selected_point[0], x_range)
            if selected_point is not None
            else None
        )
    except Exception as e:
        # UI boundary: an invalid expression is reported in the page rather
        # than aborting the script with a traceback.
        st.error(f"Error: {str(e)}")
    else:
        fig = go.Figure()
        # Function curve in orange
        fig.add_trace(go.Scatter(
            x=x_range,
            y=y_range,
            mode='lines',
            name='Function',
            line=dict(color='orange'),
        ))
        if selected_point is not None:
            # Marker for the selected iteration
            fig.add_trace(go.Scatter(
                x=[selected_point[0]],
                y=[selected_point[1]],
                mode='markers',
                name='Current Point',
                marker=dict(size=10, color='red'),
            ))
            # Dashed blue tangent line at the selected point
            fig.add_trace(go.Scatter(
                x=x_range,
                y=tangent_y,
                mode='lines',
                name='Tangent Line',
                line=dict(dash='dash', color='blue'),
            ))
        # Layout adjustments
        fig.update_layout(
            title="Gradient Descent Progress",
            xaxis_title="x",
            yaxis_title="f(x)",
            template="plotly_white",
            height=600,
        )
        st.plotly_chart(fig, use_container_width=True)
|