|
|
import streamlit as st
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import base64
|
|
|
import plotly.graph_objects as go
|
|
|
|
|
|
# Switch the page to Streamlit's wide layout.
# NOTE: Streamlit requires set_page_config to be the first st.* command the
# script executes, so keep this above every other Streamlit call.
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
|
|
def parse_function(func_str, x):
    """Evaluate the user-supplied expression *func_str* at *x*.

    Args:
        func_str: A Python expression in the variable ``x``. NumPy is
            available as ``np`` (e.g. ``"np.sin(x)"``), and Python built-ins
            remain available.
        x: A scalar or a NumPy array of sample points.

    Returns:
        The expression's value at *x*. When *x* is an array and the
        expression evaluates to a scalar (e.g. the constant ``"5"``), the
        value is broadcast to a float array shaped like *x* so it can be
        plotted. If evaluation fails, the error is shown in the Streamlit UI
        and ``np.zeros_like(x)`` is returned instead.
    """
    # NOTE(security): func_str comes straight from a text input, and eval()
    # executes arbitrary Python. Passing an explicit namespace keeps the
    # expression away from this module's globals, but it is NOT a sandbox.
    # Replace eval with a real expression parser before exposing this app to
    # untrusted users.
    try:
        result = eval(func_str, {"np": np}, {"x": x})
        # A constant expression ignores x and yields a scalar. Expand it so
        # array callers (the plotted curve) always get one y per x.
        if np.ndim(x) > 0 and np.ndim(result) == 0:
            result = np.full(np.shape(x), result, dtype=float)
        return result
    except Exception as e:
        # Any user typo (SyntaxError, NameError, ZeroDivisionError, ...) lands
        # here. Report it in the UI instead of crashing the script run.
        st.error(f"Error evaluating function: {e}")
        return np.zeros_like(x)
|
|
|
|
|
|
|
|
|
def compute_gradient(func_str, x):
    """Numerically estimate the derivative of *func_str* at *x*.

    Uses a central difference, (f(x + h) - f(x - h)) / (2h), whose truncation
    error is O(h**2) (a forward difference is only O(h)). It works for scalar
    or array *x*, since the step is computed elementwise.

    Args:
        func_str: Expression in ``x``, evaluated via ``parse_function``.
        x: Point(s) at which to estimate the slope.

    Returns:
        The estimated derivative, with the same scalar/array form as *x*.
    """
    # The step is relative to |x|. A fixed 1e-8 step is smaller than the
    # float spacing once |x| is roughly 1e8, so x + 1e-8 == x and the
    # gradient collapses to 0. A step of ~1e-5 (about eps**(1/3)) balances
    # truncation against rounding error for a central difference.
    h = 1e-5 * np.maximum(1.0, np.abs(x))
    return (parse_function(func_str, x + h) - parse_function(func_str, x - h)) / (2.0 * h)
|
|
|
|
|
|
|
|
|
def encode_image(image_path):
    """Read the file at *image_path* and return its base64 encoding as ``str``.

    The file is opened in binary mode, so any image format (or any other
    file) can be embedded, e.g. in an HTML/CSS data URI.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()
|
|
|
|
|
|
def add_bg_from_local(image_file):
    """Use the local image at *image_file* as the app's full-page background.

    The file is base64-inlined (via ``encode_image``) into a CSS data URI on
    the ``.stApp`` container. The image is scaled to cover the page, not
    repeated, and fixed in place while scrolling.

    Args:
        image_file: Path to a local image file.
    """
    import mimetypes

    encoded_string = encode_image(image_file)

    # Label the data URI with the file's real type. A hard-coded "image/png"
    # mislabels JPEGs such as the one this script passes in. Fall back to PNG
    # when the extension is unknown or not an image type.
    mime_type, _ = mimetypes.guess_type(image_file)
    if mime_type is None or not mime_type.startswith("image/"):
        mime_type = "image/png"

    st.markdown(
        f"""
<style>
.stApp {{
background-image: url(data:{mime_type};base64,{encoded_string});
background-size: cover;
background-repeat: no-repeat;
background-attachment: fixed;
}}
</style>
""",
        unsafe_allow_html=True
    )
|
|
|
|
|
|
# Apply the full-page background image. The path is relative, so it is
# resolved against the process's current working directory.
add_bg_from_local("Icons/rm183-kul-21.jpg")

# NOTE(review): this rule looks like a no-op. `background: "white"` is invalid
# CSS (a quoted string is not a colour value), so browsers drop the
# declaration. `.reportview-container` also appears to be an older Streamlit
# DOM class name, so confirm it still exists before "fixing" the value. Making
# it valid could paint over the background image set above.
st.markdown(
"""
<style>
.reportview-container {
background: "white"
}
</style>
""",
unsafe_allow_html=True
)
|
|
|
|
|
|
# Site-wide typography plus button and sidebar styling, injected once per run.
# NOTE(review): the resting button gets white text but no background colour.
# Confirm the label stays readable on Streamlit's default button background.
# `.sidebar .sidebar-content` may be a legacy Streamlit class name, so verify
# it still matches.
_GLOBAL_CSS = """
<style>
body {
font-family: 'Roboto', sans-serif;
}
.stButton>button {
color: white;
border-radius: 8px;
padding: 10px 20px;
font-weight: bold;
transition: background-color 0.3s ease;
}
.stButton>button:hover {
background-color: Black;
border-color: white;
color: white;
}
.sidebar .sidebar-content {
padding: 2rem;
}
.stApp {
font-family: 'Roboto', sans-serif;
}
</style>
"""

st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
# Title banner: the animated GIF icon is inlined as a base64 data URI next to
# the heading. encode_image reads the file inside a `with` block, so the file
# handle is closed instead of being left to the garbage collector.
base64_gif = encode_image("Icons/wave-chart-ezgif.com-gif-maker.gif")

st.markdown(
    f"""
<h1 style='text-align: center; color: Black; margin-top: -50px; padding-top: 0px;'>
Interactive Gradient Descent Visualizer
<img src="data:image/gif;base64,{base64_gif}" alt="Icon" style="width: 85px; margin-right: 10px;">
</h1>
""",
    unsafe_allow_html=True
)

# One-paragraph introduction shown under the title.
st.markdown(
    """
<p style="color: black;">
Explore how gradient descent works visually and interactively.
Adjust parameters and watch as the algorithm converges towards the minimum of a function.
</p>
""",
    unsafe_allow_html=True
)
|
|
|
|
|
|
st.sidebar.header("Input Parameters")

# Preset expressions in x. The last entry switches to a free-form polynomial.
function_options = ["x**2", "x**3", "np.sin(x)", "1/x", "Custom Polynomial"]
selected_function = st.sidebar.selectbox("Choose a function:", function_options)

# Either branch leaves the user with an editable expression string in `func_str`.
is_custom = selected_function == "Custom Polynomial"
if is_custom:
    prompt, starting_expr = "Enter custom polynomial in terms of x:", "x**2 - 4*x + 4"
else:
    prompt, starting_expr = f"Modify the selected function ({selected_function}):", selected_function
func_str = st.sidebar.text_input(prompt, value=starting_expr)
|
|
|
|
|
|
|
|
|
# Descent history kept per browser session. Streamlit reruns the whole script
# on every interaction, so these must live in session_state to survive clicks.
if "x_vals" not in st.session_state:
    st.session_state.x_vals = []
if "y_vals" not in st.session_state:
    st.session_state.y_vals = []
if "current_step" not in st.session_state:
    st.session_state.current_step = 0

initial_x = st.sidebar.number_input("Initial Point (x):", value=0.00)
learning_rate = st.sidebar.number_input("Learning Rate:", value=0.1, step=0.01, format="%.2f")

# Restart the descent whenever the expression changes. Otherwise the stored
# x/y history (and the path drawn on the plot) would belong to the previous
# function.
if "active_func" not in st.session_state or st.session_state.active_func != func_str:
    st.session_state.active_func = func_str
    st.session_state.current_step = 0

# Before the first step, keep the starting point in sync with the sidebar input.
if st.session_state.current_step == 0:
    st.session_state.x_vals = [initial_x]
    st.session_state.y_vals = [parse_function(func_str, initial_x)]
|
|
|
|
|
|
col1, col2 = st.sidebar.columns(2)

# "Reset Graph": discard the path and start again from the sidebar's initial point.
if col1.button("Reset Graph"):
    st.session_state.current_step = 0
    st.session_state.x_vals = [initial_x]
    st.session_state.y_vals = [parse_function(func_str, initial_x)]

# "Next Iteration": one gradient-descent update, x <- x - learning_rate * f'(x).
if col2.button("Next Iteration"):
    x_prev = st.session_state.x_vals[-1]
    x_next = x_prev - learning_rate * compute_gradient(func_str, x_prev)
    st.session_state.x_vals.append(x_next)
    st.session_state.y_vals.append(parse_function(func_str, x_next))
    st.session_state.current_step += 1
|
|
|
|
|
|
# Sample the function densely over the fixed plotting window [-20, 30].
x_vals = np.linspace(-20, 30, 1000)
y_vals = parse_function(func_str, x_vals)

curve_trace = go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Function Curve', line=dict(color='teal'))
fig = go.Figure(data=[curve_trace])

# Once at least one step has been taken, overlay the visited points.
if st.session_state.current_step > 0:
    path_trace = go.Scatter(
        x=st.session_state.x_vals,
        y=st.session_state.y_vals,
        mode='markers+lines',
        name='Gradient Descent Steps',
        marker=dict(color='red', size=10),
        line=dict(dash='dash', width=1.5),
    )
    fig.add_trace(path_trace)
|
|
|
|
|
|
|
|
|
def draw_tangent(fig, func_str, x_point):
    """Overlay the tangent of *func_str* at *x_point* onto *fig* (in place).

    Two traces are appended:
    * a dotted green line y = f'(x0) * (x - x0) + f(x0), sampled at 1000
      points over x in [-20, 30];
    * a blue marker at the point of tangency (x0, f(x0)).

    Returns ``None``.
    """
    y_point = parse_function(func_str, x_point)
    slope = compute_gradient(func_str, x_point)

    line_x = np.linspace(-20, 30, 1000)
    line_y = slope * (line_x - x_point) + y_point

    fig.add_traces([
        go.Scatter(
            x=line_x, y=line_y, mode='lines', name=f'Tangent at x={x_point:.2f}',
            line=dict(dash='dot', color='green', width=2),
        ),
        go.Scatter(
            x=[x_point], y=[y_point], mode='markers', name='Tangent Point',
            marker=dict(color='blue', size=12, symbol='circle'),
        ),
    ])
|
|
|
|
|
|
|
|
|
# Overlay the tangent line at the most recent descent point.
if st.session_state.x_vals:
    draw_tangent(fig, func_str, st.session_state.x_vals[-1])

# Fixed viewport. The two black shapes draw the y = 0 and x = 0 axes across it.
# NOTE: the old `titlefont` attributes are deprecated (and dropped in recent
# Plotly releases). `title=dict(text=..., font=...)` is the supported form.
fig.update_layout(
    shapes=[
        dict(type="line", x0=-20, y0=0, x1=30, y1=0, line=dict(color="black", width=2)),
        dict(type="line", x0=0, y0=-110, x1=0, y1=120, line=dict(color="black", width=2)),
    ],
    xaxis=dict(
        title=dict(text='x', font=dict(color='black')),
        range=[-20, 30],
        showline=True,
        linecolor='black',
        linewidth=2,
        mirror=True,
        ticks='inside',
        tickfont=dict(color='black'),
    ),
    yaxis=dict(
        title=dict(text='y', font=dict(color='black')),
        range=[-110, 120],
        showline=True,
        linecolor='Black',
        linewidth=2,
        mirror=True,
        ticks='inside',
        tickfont=dict(color='black'),
    ),
    # Transparent backgrounds let the page's background image show through.
    plot_bgcolor='rgba(0, 0, 0, 0)',
    paper_bgcolor='rgba(0, 0, 0, 0)',
    font=dict(color='black'),
    legend=dict(
        font=dict(color='black'),
        x=1.05,
        xanchor='left',
        y=1,
        yanchor='top'
    ),
    width=800, height=500,
    template="plotly_white",
    title=dict(text="Gradient Descent on the Selected Function", font=dict(color='black')),
    margin=dict(l=50, r=50, t=50, b=50),
)

st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
|
# Tabulate the full descent history once at least one step has been taken.
if st.session_state.current_step > 0:
    # Number the rows from the recorded history itself rather than from the
    # separate current_step counter. This keeps all three columns the same
    # length (pandas raises ValueError on ragged columns).
    iteration_data = {
        "Iteration": list(range(len(st.session_state.x_vals))),
        "x Value": [f"{x_val:.5f}" for x_val in st.session_state.x_vals],
        "y Value": [f"{y_val:.5f}" for y_val in st.session_state.y_vals],
    }
    iteration_df = pd.DataFrame(iteration_data)

    st.markdown("<h3 style='color: black;'>Iteration Details</h3>", unsafe_allow_html=True)
    # Keep pandas' default HTML escaping. The table is injected with
    # unsafe_allow_html, so cell text must never be interpreted as markup.
    st.markdown(
        iteration_df.to_html(index=False),
        unsafe_allow_html=True
    )
|
|
|
|
|
|
|
|
|
# Styling for the HTML tables produced by DataFrame.to_html (class "dataframe").
_TABLE_CSS = """
<style>
.dataframe {
color: black;
font-size: 14px;
border-collapse: collapse;
width: 100%;
}
.dataframe th, .dataframe td {
padding: 8px;
text-align: center;
border: 1px solid black;
}
.dataframe th {
background-color: #f2f2f2;
border: 2px solid black;
}
</style>
"""

st.markdown(_TABLE_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
# Sidebar summary of where the descent currently stands.
status_panel = st.sidebar
history = st.session_state

status_panel.subheader("Current Status")
status_panel.write(f"Iteration: {history.current_step}")
status_panel.write(f"Current x: {history.x_vals[-1]:.5f}")
status_panel.write(f"Current y: {history.y_vals[-1]:.5f}")
|
|
|
|