|
|
import streamlit as st |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import base64 |
|
|
import plotly.graph_objects as go |
|
|
|
|
|
st.set_page_config(layout="wide") |
|
|
|
|
|
|
|
|
def parse_function(func_str, x):
    """Evaluate the user-supplied expression ``func_str`` at ``x``.

    The expression may reference ``x``, ``np`` (NumPy) and a small whitelist
    of numeric builtins (``abs``, ``min``, ``max``, ``pow``, ``round``).
    ``x`` may be a Python/NumPy scalar or a NumPy array.

    Args:
        func_str: Expression in terms of ``x``, e.g. ``"x**2 - 4*x + 4"``.
        x: Point(s) at which to evaluate the expression.

    Returns:
        The expression's value. If the expression does not depend on ``x``
        (e.g. ``"5"``) while ``x`` is an array, the constant is broadcast to
        ``x``'s shape so it can still be plotted as a curve. If evaluation
        raises, an error is shown in the Streamlit UI and
        ``np.zeros_like(x)`` is returned instead.
    """
    # SECURITY NOTE(review): func_str comes straight from a sidebar text box,
    # i.e. it is untrusted. Evaluating it with an explicit, minimal namespace
    # (no module globals such as `st`, no `__import__`/`open`) removes the
    # most obvious abuse, but eval() is NOT a sandbox -- attribute-walking
    # tricks can still escape. Replace with a real expression parser before
    # exposing this app publicly.
    safe_builtins = {"abs": abs, "min": min, "max": max, "pow": pow, "round": round}
    # Everything lives in the globals dict so that lambdas/generator
    # expressions inside the user's expression can still see `x` and `np`.
    namespace = {"__builtins__": safe_builtins, "np": np, "x": x}
    try:
        result = eval(func_str, namespace)
        # A constant expression yields a scalar even for array input; give it
        # the same shape as x so callers can plot it against x.
        if np.ndim(result) == 0 and np.ndim(x) > 0:
            result = np.full_like(x, result, dtype=float)
        return result
    except Exception as e:
        st.error(f"Error evaluating function: {e}")
        return np.zeros_like(x)
|
|
|
|
|
|
|
|
def compute_gradient(func_str, x):
    """Numerically approximate the derivative of ``func_str`` at ``x``.

    Uses the central difference ``(f(x + h) - f(x - h)) / (2h)``, whose
    truncation error is O(h**2), versus O(h) for a one-sided difference.

    Args:
        func_str: Expression in terms of ``x`` (evaluated via parse_function).
        x: Scalar or NumPy array of evaluation points.

    Returns:
        The approximate derivative at ``x`` (same shape as ``x``).
    """
    # cbrt(machine epsilon) (~6e-6) is the usual step that balances
    # truncation against round-off for a central difference. Scaling it by
    # max(1, |x|) keeps x + h distinguishable from x for large |x|; a fixed
    # absolute step such as 1e-8 underflows to a zero gradient there.
    h = np.cbrt(np.finfo(float).eps) * np.maximum(1.0, np.abs(x))
    forward = parse_function(func_str, x + h)
    backward = parse_function(func_str, x - h)
    return (forward - backward) / (2.0 * h)
|
|
|
|
|
|
|
|
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    Args:
        image_path: Path to the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file contents, decoded to ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
|
|
|
|
|
def add_bg_from_local(image_file):
    """Use a local image as the fixed, full-window background of the app.

    The image is inlined as a base64 data URI in a ``<style>`` block that
    targets ``.stApp`` and is injected via ``st.markdown``.

    Args:
        image_file: Path to a local image file.
    """
    import mimetypes

    encoded_string = encode_image(image_file)
    # The data URI used to hard-code image/png even for the .jpg this app
    # actually passes. Derive the type from the file extension and fall back
    # to PNG only when the extension is unknown or not an image type.
    mime_type = mimetypes.guess_type(image_file)[0]
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"

    # Built line by line so no leading indentation leaks into the markdown
    # body (indented lines can be rendered as a code block).
    css = (
        "<style>\n"
        ".stApp {\n"
        f'    background-image: url("data:{mime_type};base64,{encoded_string}");\n'
        "    background-size: cover;\n"
        "    background-repeat: no-repeat;\n"
        "    background-attachment: fixed;\n"
        "}\n"
        "</style>\n"
    )
    st.markdown(css, unsafe_allow_html=True)
|
|
|
|
|
# ---- Page theme ------------------------------------------------------------
# Full-window background image for the whole app.
add_bg_from_local("Icon/rm183-kul-21.jpg")

# NOTE(review): `.reportview-container` is a class name from older Streamlit
# releases, and a quoted value like `"white"` is not valid for `background`,
# so browsers are expected to discard this declaration. Kept verbatim.
_legacy_container_css = """
<style>
.reportview-container {
background: "white"
}
</style>
"""

# Global font plus button and sidebar styling.
_theme_css = """
<style>
body {
font-family: 'Roboto', sans-serif;
}
.stButton>button {
color: white;
border-radius: 8px;
padding: 10px 20px;
font-weight: bold;
transition: background-color 0.3s ease;
}
.stButton>button:hover {
background-color: Black;
border-color: white;
color: white;
}
.sidebar .sidebar-content {
padding: 2rem;
}
.stApp {
font-family: 'Roboto', sans-serif;
}
</style>
"""

# Inject both stylesheets in the same order as before.
for _css_block in (_legacy_container_css, _theme_css):
    st.markdown(_css_block, unsafe_allow_html=True)
|
|
|
|
|
# ---- Page header -----------------------------------------------------------
# Inline the animated header icon as a base64 data URI. encode_image reads the
# file inside a `with` block, so the handle is closed deterministically; the
# previous bare `open(...).read()` never closed it explicitly.
base64_gif = encode_image("Icon/wave-chart-ezgif.com-gif-maker.gif")

st.markdown(
    f"""
<h1 style='text-align: center; color: Black; margin-top: -50px; padding-top: 0px;'>
Interactive Gradient Descent Visualizer
<img src="data:image/gif;base64,{base64_gif}" alt="Icon" style="width: 85px; margin-right: 10px;">
</h1>
""",
    unsafe_allow_html=True,
)

# Short introduction shown under the title.
st.markdown(
    """
<p style="color: black;">
Explore how gradient descent works visually and interactively.
Adjust parameters and watch as the algorithm converges towards the minimum of a function.
</p>
""",
    unsafe_allow_html=True,
)
|
|
|
|
|
# ---- Sidebar: function selection ------------------------------------------
st.sidebar.header("Input Parameters")

# Preset expressions. Picking "Custom Polynomial" swaps the text box to a
# free-form default; any other choice pre-fills the box with the preset.
function_options = ["x**2", "x**3", "np.sin(x)", "1/x", "Custom Polynomial"]
selected_function = st.sidebar.selectbox("Choose a function:", function_options)

if selected_function != "Custom Polynomial":
    input_label = f"Modify the selected function ({selected_function}):"
    input_default = selected_function
else:
    input_label = "Enter custom polynomial in terms of x:"
    input_default = "x**2 - 4*x + 4"

# The text in this box is what parse_function later evaluates.
func_str = st.sidebar.text_input(input_label, value=input_default)
|
|
|
|
|
|
|
|
# ---- Descent state ---------------------------------------------------------
# x_vals / y_vals hold the descent trajectory, starting with the initial
# point; current_step counts "Next Iteration" clicks since the last reset.
if "x_vals" not in st.session_state:
    st.session_state.x_vals = []
if "y_vals" not in st.session_state:
    st.session_state.y_vals = []
if "current_step" not in st.session_state:
    st.session_state.current_step = 0

initial_x = st.sidebar.number_input("Initial Point (x):", value=0.00)
learning_rate = st.sidebar.number_input("Learning Rate:", value=0.1, step=0.01, format="%.2f")

# A trajectory computed for a different expression is meaningless against the
# new curve: its y values, and the gradients that produced it, belong to the
# old function. Restart the descent whenever the expression text changes.
if st.session_state.get("active_func_str") != func_str:
    st.session_state.active_func_str = func_str
    st.session_state.current_step = 0

# Before the first step, keep the starting point in sync with the inputs.
if st.session_state.current_step == 0:
    st.session_state.x_vals = [initial_x]
    st.session_state.y_vals = [parse_function(func_str, initial_x)]

col1, col2 = st.sidebar.columns(2)

if col1.button("Reset Graph"):
    st.session_state.x_vals = [initial_x]
    st.session_state.y_vals = [parse_function(func_str, initial_x)]
    st.session_state.current_step = 0

if col2.button("Next Iteration"):
    current_x = st.session_state.x_vals[-1]
    grad = compute_gradient(func_str, current_x)
    # Gradient-descent update: step against the slope, scaled by the rate.
    next_x = current_x - learning_rate * grad
    st.session_state.x_vals.append(next_x)
    st.session_state.y_vals.append(parse_function(func_str, next_x))
    st.session_state.current_step += 1
|
|
|
|
|
# ---- Figure: function curve and descent path --------------------------------
# Sample the function on a dense grid over the fixed plotting window.
x_vals = np.linspace(-20, 30, 1000)
y_vals = parse_function(func_str, x_vals)

curve_trace = go.Scatter(
    x=x_vals,
    y=y_vals,
    mode='lines',
    name='Function Curve',
    line=dict(color='teal'),
)
fig = go.Figure(data=[curve_trace])

# Overlay the visited points once at least one step has been taken.
if st.session_state.current_step > 0:
    path_trace = go.Scatter(
        x=st.session_state.x_vals,
        y=st.session_state.y_vals,
        mode='markers+lines',
        name='Gradient Descent Steps',
        marker=dict(color='red', size=10),
        line=dict(dash='dash', width=1.5),
    )
    fig.add_trace(path_trace)
|
|
|
|
|
|
|
|
def draw_tangent(fig, func_str, x_point):
    """Overlay the tangent of ``func_str`` at ``x_point`` onto ``fig``.

    Appends two traces to the figure, in this order:
    a dotted green line sampled at 1000 points over x in [-20, 30], and a
    blue marker at the tangency point (x_point, f(x_point)).

    Args:
        fig: Plotly figure, mutated in place.
        func_str: Expression in terms of ``x`` (see parse_function).
        x_point: x coordinate of the tangency point.
    """
    y_point = parse_function(func_str, x_point)
    slope = compute_gradient(func_str, x_point)

    xs = np.linspace(-20, 30, 1000)
    # Point-slope form of the tangent line.
    ys = slope * (xs - x_point) + y_point

    line_style = dict(dash='dot', color='green', width=2)
    point_style = dict(color='blue', size=12, symbol='circle')

    fig.add_trace(go.Scatter(
        x=xs,
        y=ys,
        mode='lines',
        name=f'Tangent at x={x_point:.2f}',
        line=line_style,
    ))
    fig.add_trace(go.Scatter(
        x=[x_point],
        y=[y_point],
        mode='markers',
        name='Tangent Point',
        marker=point_style,
    ))
|
|
|
|
|
|
|
|
# ---- Tangent, layout and render ---------------------------------------------
# Draw the tangent at the most recent point of the trajectory.
if st.session_state.x_vals:
    draw_tangent(fig, func_str, st.session_state.x_vals[-1])

# Axis/figure titles use the `title=dict(text=..., font=...)` form: the old
# `titlefont` properties are deprecated and were removed in Plotly 6, where
# passing them raises an error.
fig.update_layout(
    shapes=[
        # Horizontal axis line (y = 0) across the visible x range.
        dict(type="line", x0=-20, y0=0, x1=30, y1=0, line=dict(color="black", width=2)),
        # Vertical axis line (x = 0) across the visible y range.
        dict(type="line", x0=0, y0=-110, x1=0, y1=120, line=dict(color="black", width=2)),
    ],
    xaxis=dict(
        title=dict(text='x', font=dict(color='black')),
        range=[-20, 30],
        showline=True,
        linecolor='black',
        linewidth=2,
        mirror=True,
        ticks='inside',
        tickfont=dict(color='black'),
    ),
    yaxis=dict(
        title=dict(text='y', font=dict(color='black')),
        range=[-110, 120],
        showline=True,
        linecolor='Black',
        linewidth=2,
        mirror=True,
        ticks='inside',
        tickfont=dict(color='black'),
    ),
    # Transparent backgrounds let the page background image show through.
    plot_bgcolor='rgba(0, 0, 0, 0)',
    paper_bgcolor='rgba(0, 0, 0, 0)',
    font=dict(color='black'),
    legend=dict(
        font=dict(color='black'),
        x=1.05,
        xanchor='left',
        y=1,
        yanchor='top',
    ),
    width=800,
    height=500,
    template="plotly_white",
    title=dict(text="Gradient Descent on the Selected Function", font=dict(color='black')),
    margin=dict(l=50, r=50, t=50, b=50),
)

st.plotly_chart(fig, use_container_width=True)
|
|
|
|
|
# ---- Iteration table -------------------------------------------------------
if st.session_state.current_step > 0:
    # Number the rows from the trajectory itself rather than from the separate
    # current_step counter, so the "Iteration" column can never disagree in
    # length with the value columns.
    iteration_data = {
        "Iteration": list(range(len(st.session_state.x_vals))),
        "x Value": [f"{x_val:.5f}" for x_val in st.session_state.x_vals],
        "y Value": [f"{y_val:.5f}" for y_val in st.session_state.y_vals],
    }
    iteration_df = pd.DataFrame(iteration_data)

    st.markdown("<h3 style='color: black;'>Iteration Details</h3>", unsafe_allow_html=True)
    # The cells are plain formatted numbers, so pandas' default HTML escaping
    # leaves them unchanged. It is no longer disabled, so markup can never be
    # injected through this unsafe_allow_html render.
    st.markdown(
        iteration_df.to_html(index=False),
        unsafe_allow_html=True,
    )

# Styling for tables produced by DataFrame.to_html (which tags them with the
# "dataframe" CSS class).
st.markdown("""
<style>
.dataframe {
color: black;
font-size: 14px;
border-collapse: collapse;
width: 100%;
}
.dataframe th, .dataframe td {
padding: 8px;
text-align: center;
border: 1px solid black;
}
.dataframe th {
background-color: #f2f2f2;
border: 2px solid black;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
# ---- Sidebar: latest point on the trajectory --------------------------------
_sidebar = st.sidebar
_state = st.session_state

_sidebar.subheader("Current Status")
_sidebar.write(f"Iteration: {_state.current_step}")
_sidebar.write(f"Current x: {_state.x_vals[-1]:.5f}")
_sidebar.write(f"Current y: {_state.y_vals[-1]:.5f}")
|
|
|