# Source-viewer metadata captured with this file: Spaces status "Sleeping"; file size 4,575 bytes; revision ids e5aa730, 182f1c9, e5aa730.
import streamlit as st
import torch
import numpy as np
import plotly.graph_objs as go
# Page-wide CSS overrides. The stylesheet is kept in a named constant and
# injected with unsafe_allow_html=True so the <style> tag is emitted as HTML.
_PAGE_STYLE = """
<style>
.stApp {
background-color: #f9f9f9;
font-family: 'Segoe UI', sans-serif;
}
h1 {
text-align: center;
color: #2C3E50;
font-size: 38px !important;
font-weight: bold;
margin-bottom: 20px;
}
.stTextInput > div > div > input {
border: 2px solid #3498DB;
border-radius: 8px;
padding: 8px;
}
div.stButton > button {
background-color: #3498DB;
color: white;
border-radius: 10px;
padding: 10px 24px;
font-size: 16px;
border: none;
transition: 0.3s;
}
div.stButton > button:hover {
background-color: #2980B9;
transform: scale(1.05);
}
.stAlert {
border-radius: 8px;
}
.block-container {
padding-top: 2rem;
padding-bottom: 2rem;
max-width: 1200px;
}
</style>
"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)
st.title("Gradient Descent Visualizer")

# Render every input widget first so that all of them stay visible even when
# one of the numeric fields currently holds text that cannot be parsed.
func_input = st.text_input("Enter Function of x", "x**2")
start_point_text = st.text_input("Starting Point", "2")
learning_rate_text = st.text_input("Learning Rate", "0.01")
num_iterations_text = st.text_input("Number of Iterations", "10")

# Convert the numeric fields. A malformed value is reported in the page and
# the current script run is halted, instead of letting ValueError escape as
# an unhandled exception.
try:
    start_point = float(start_point_text)
    learning_rate = float(learning_rate_text)
    num_iterations = int(num_iterations_text)
except ValueError as e:
    st.error(f"Invalid numeric input: {e}")
    st.stop()
def make_function(expr: str):
    """Build a callable ``f(x)`` from a Python expression in ``x``.

    The expression is compiled once, here, so a malformed expression raises
    ``SyntaxError`` immediately instead of on the first call, and repeated
    calls do not re-parse the text.

    Args:
        expr: Python expression text. It may reference ``x`` (the argument
            of the returned function) and the ``torch`` module.

    Returns:
        A one-argument function that evaluates the expression with ``x``
        bound to its argument and returns whatever the expression yields.

    Raises:
        SyntaxError: if ``expr`` is not a valid Python expression.
    """
    # SECURITY: eval() runs arbitrary Python, and builtins remain reachable
    # from the expression. Only use this with input from a trusted user.
    # Surrounding whitespace is stripped because compile(), unlike eval() on
    # a string, rejects leading indentation.
    code = compile(expr.strip(), "<function of x>", "eval")

    def func(x):
        return eval(code, {"x": x, "torch": torch})

    return func
# (Re)initialise the optimisation state when the user presses "Set Up", or
# when the session does not hold a function / point history yet.
# st.button is evaluated first so the button is rendered on every run.
needs_setup = (
    st.button("Set Up")
    or 'func' not in st.session_state
    or 'points' not in st.session_state
)
if needs_setup:
    try:
        func = make_function(func_input)
        # The returned callable only evaluates the expression when it is
        # called, so call it once at the starting point: an expression that
        # cannot be evaluated is then reported here rather than being
        # announced as a successful set-up.
        func(torch.tensor([start_point], dtype=torch.float32))
    except Exception as e:
        st.error(f"Error setting up function: {e}")
    else:
        # Commit the new state only after the function proved usable, so a
        # previously working function is not replaced by a broken one.
        st.session_state.func = func
        st.session_state.points = [start_point]
        st.session_state.step = 0
        st.success("Function Set Up Successfully with PyTorch!")
def gradient_step(x_val, func, lr):
    """Perform one gradient-descent update ``x_val - lr * f'(x_val)``.

    Args:
        x_val: Current position, as a number.
        func: Callable applied to a one-element float32 tensor holding
            ``x_val``; its result is differentiated with autograd.
        lr: Learning rate (step size multiplier).

    Returns:
        A ``(new_x, grad)`` tuple: the updated position and the gradient of
        ``func`` at ``x_val``.
    """
    x = torch.tensor([x_val], dtype=torch.float32, requires_grad=True)
    y = func(x)

    # A result that is not attached to the autograd graph (a plain number,
    # or a tensor that does not require grad) does not depend on x, so its
    # derivative is zero; calling backward() on it would raise instead.
    if torch.is_tensor(y) and y.requires_grad:
        y.backward()
        # x.grad stays None when nothing was propagated back to x.
        grad = 0.0 if x.grad is None else x.grad.item()
    else:
        grad = 0.0

    new_x = x_val - lr * grad
    return new_x, grad
# Iteration controls: shown only once a function is stored in the session.
if 'func' in st.session_state:
    session = st.session_state

    def _advance_once():
        """Take one descent step from the latest point and return its gradient."""
        current = float(session.points[-1])
        updated, gradient = gradient_step(current, session.func, learning_rate)
        session.points.append(updated)
        session.step += 1
        return gradient

    # Single step: report the gradient used for this update.
    if st.button("Next Iteration"):
        try:
            grad_val = _advance_once()
            st.success(f"Iteration {st.session_state.step} Complete! (grad={grad_val:.6f})")
        except Exception as e:
            st.error(f"Error in iteration: {e}")

    # Batch of steps: report the running total of iterations afterwards.
    if st.button("Run Iterations"):
        try:
            for _ in range(num_iterations):
                _advance_once()
            st.success(f"Ran {st.session_state.step} Iterations in total")
        except Exception as e:
            st.error(f"Error in multiple iterations: {e}")
def _evaluate_for_plot(func, xs):
    """Evaluate *func* on the tensor *xs* and return the values as a NumPy array.

    A result that is not a tensor (for example when the expression is a bare
    number) is broadcast to the shape of *xs* so it can still be drawn.
    """
    ys = func(xs)
    if not torch.is_tensor(ys):
        ys = torch.full_like(xs, float(ys))
    return ys.detach().numpy()


# Draw the function, the path taken so far and the current point.
if 'func' in st.session_state and st.session_state.points:
    try:
        iter_points = np.array(st.session_state.points)

        # Plot at least [-10, 10], widened so that every finite point visited
        # so far lies on the drawn curve (a path that leaves the default
        # interval would otherwise be shown without the function under it).
        x_low, x_high = -10.0, 10.0
        finite_points = iter_points[np.isfinite(iter_points)]
        if finite_points.size:
            x_low = min(x_low, float(finite_points.min()))
            x_high = max(x_high, float(finite_points.max()))

        x_val = np.linspace(x_low, x_high, 500)
        x_torch = torch.tensor(x_val, dtype=torch.float32)
        y_val = _evaluate_for_plot(st.session_state.func, x_torch)

        iter_torch = torch.tensor(iter_points, dtype=torch.float32)
        iter_y = _evaluate_for_plot(st.session_state.func, iter_torch)

        trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
        trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines",
                            name="Gradient Descent Path", marker=dict(color="red"))
        trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
                            marker=dict(color='green', size=15),
                            text=[f"{iter_points[-1]:.6f}"], textposition="top center",
                            name="Current Point")
        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis"),
            yaxis=dict(title="y - axis"),
            width=1000,
            height=600
        )
        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        st.error(f"Plot error: {e}")
# End of script.