File size: 4,400 Bytes
f3f4d32
 
 
 
6d52f3e
d5172eb
 
 
 
 
 
 
 
 
 
 
bba7bc6
d5172eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bba7bc6
 
 
b91e62f
bba7bc6
d5172eb
 
 
bba7bc6
63cea96
bba7bc6
6883bcd
bba7bc6
 
 
6883bcd
63cea96
bba7bc6
f3f4d32
 
 
 
 
 
 
1d0d871
 
8b8fb42
 
bba7bc6
 
f3f4d32
bba7bc6
f3f4d32
1d0d871
bba7bc6
f3f4d32
 
1d0d871
f3f4d32
 
b91e62f
 
 
f3f4d32
bba7bc6
6d52f3e
6883bcd
 
 
 
 
 
 
 
 
 
 
 
 
 
1d0d871
6d52f3e
f9eefe7
1d0d871
bba7bc6
6d52f3e
1d0d871
6d52f3e
f3f4d32
6883bcd
 
f3f4d32
 
 
 
bba7bc6
 
 
 
4bfefe0
 
 
bba7bc6
 
f3f4d32
9592fc9
bba7bc6
9592fc9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
import numpy as np 
import sympy as sp
import plotly.graph_objs as go

# Page-wide CSS overrides: app background and font, centred title, bordered
# text inputs, coloured buttons with a hover effect, rounded alerts, and a
# capped content width.
_PAGE_STYLE = """
    <style>
    .stApp {
        background-color: #f9f9f9;
        font-family: 'Segoe UI', sans-serif;
    }
    h1 {
        text-align: center;
        color: #2C3E50;
        font-size: 38px !important;
        font-weight: bold;
        margin-bottom: 20px;
    }
    .stTextInput > div > div > input {
        border: 2px solid #3498DB;
        border-radius: 8px;
        padding: 8px;
    }
    div.stButton > button {
        background-color: #3498DB;
        color: white;
        border-radius: 10px;
        padding: 10px 24px;
        font-size: 16px;
        border: none;
        transition: 0.3s;
    }
    div.stButton > button:hover {
        background-color: #2980B9;
        transform: scale(1.05);
    }
    .stAlert {
        border-radius: 8px;
    }
    .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
        max-width: 1200px;
    }
    </style>
"""

# The stylesheet is raw HTML, so it is passed with unsafe_allow_html=True.
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

st.title("Gradient Descent Visualizer")

# Symbol the user's expression is written in and differentiated against.
x = sp.Symbol("x")

# Render every input widget before parsing anything, so a typo in one field
# does not prevent the remaining fields from being drawn.
func_input = st.text_input("Enter Function", "x^2")
start_point_text = st.text_input("Starting Point", "2")
learning_rate_text = st.text_input("Learning Rate", "0.01")
num_iterations_text = st.text_input("Number of Iterations", "10")

try:
    start_point = float(start_point_text)
    learning_rate = float(learning_rate_text)
    num_iterations = int(num_iterations_text)
except ValueError as err:
    # Non-numeric text would otherwise surface as an unhandled traceback.
    # Report it and end this script run; the next edit triggers a rerun.
    st.error(f"Invalid numeric input: {err}")
    st.stop()

# (Re)build the objective and its derivative when "Set Up" is pressed, or when
# the session holds no function / point history yet.
if st.button("Set Up") or 'func' not in st.session_state or 'points' not in st.session_state:
    try:
        # Let users write powers as "x^2"; SymPy expects "**".
        expr = func_input.replace("^", "**")
        # NOTE(security): sympify() evaluates the string with eval(), so this
        # is only safe when the app is used with trusted input.
        expr_final = sp.sympify(expr)

        # Reject stray symbols now: lambdify(x, ...) would accept them and the
        # generated function would only fail later, when it is first called.
        extra_symbols = expr_final.free_symbols - {x}
        if extra_symbols:
            names = ", ".join(sorted(str(sym) for sym in extra_symbols))
            raise ValueError(f"expression must depend only on x (found: {names})")

        func = sp.lambdify(x, expr_final, "numpy")
        grad = sp.diff(expr_final, x)
        gradient_func = sp.lambdify(x, grad, "numpy")

        # Reset the descent history so it starts from the chosen start point.
        st.session_state.func = func
        st.session_state.gradient_func = gradient_func
        st.session_state.points = [start_point]
        st.session_state.step = 0
        st.success("Function and Gradient Set Up Successfully!")

    except Exception as e:
        # UI boundary: show any parsing/differentiation failure to the user.
        st.error(f"Error setting up function: {e}")

def _descent_step(state, step_size):
    """Apply one gradient-descent update to the history held in ``state``.

    Reads the latest entry of ``state.points``, evaluates
    ``state.gradient_func`` there, appends ``x - step_size * gradient`` to
    ``state.points`` and increments ``state.step``.

    Raises:
        ValueError: if the updated point is NaN or infinite. Nothing is
            appended in that case, so the history stays plottable.
    """
    x_old = float(state.points[-1])
    grad_val = state.gradient_func(x_old)
    x_new = x_old - step_size * grad_val
    if not np.isfinite(x_new):
        raise ValueError(
            f"update from x = {x_old} produced a non-finite value ({x_new}); "
            "try a smaller learning rate or a different starting point"
        )
    state.points.append(x_new)
    state.step += 1


if 'func' in st.session_state and 'gradient_func' in st.session_state:
    if st.button("Next Iteration"):
        try:
            _descent_step(st.session_state, learning_rate)
            st.success(f"Iteration {st.session_state.step} Complete!")
        except Exception as e:
            st.error(f"Error in iteration: {e}")

    if st.button("Run Iterations"):
        try:
            # Steps completed before a failure are kept in the history.
            for _ in range(num_iterations):
                _descent_step(st.session_state, learning_rate)

            st.success(f"Ran {st.session_state.step} Iterations in total")
        except Exception as e:
            st.error(f"Error in multiple iterations: {e}")

def _evaluate_on(func, xs):
    """Evaluate ``func`` on the array ``xs`` and return an array shaped like ``xs``.

    A function lambdified from a constant expression returns a scalar rather
    than an array, so the result is broadcast to ``xs.shape``.
    """
    return np.broadcast_to(np.asarray(func(xs)), xs.shape).copy()


if 'func' in st.session_state and len(st.session_state.points) > 0:
    try:
        iter_points = np.array(st.session_state.points)

        # Sample the curve on [-10, 10], widened when needed so that every
        # finite iterate lies on the drawn curve.
        x_min, x_max = -10.0, 10.0
        finite_points = iter_points[np.isfinite(iter_points)]
        if finite_points.size:
            x_min = min(x_min, float(finite_points.min()))
            x_max = max(x_max, float(finite_points.max()))

        x_val = np.linspace(x_min, x_max, 500)
        y_val = _evaluate_on(st.session_state.func, x_val)
        iter_y = _evaluate_on(st.session_state.func, iter_points)

        # Objective curve, the path of visited points, and the latest point
        # highlighted with its x value as a label.
        trace1 = go.Scatter(x=x_val, y=y_val, mode="lines", name="Function", line=dict(color="blue"))
        trace2 = go.Scatter(x=iter_points, y=iter_y, mode="markers+lines",
                            name="Gradient Descent Path", marker=dict(color="red"))
        trace3 = go.Scatter(x=[iter_points[-1]], y=[iter_y[-1]], mode='markers+text',
                            marker=dict(color='green', size=15),
                            text=[f"{iter_points[-1]:.6f}"], textposition="top center",
                            name="Current Point")

        layout = go.Layout(
            title=f"Iteration {st.session_state.step}",
            xaxis=dict(title="x - axis"),
            yaxis=dict(title="y - axis"),
            width=1000,
            height=600
        )

        fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
        st.plotly_chart(fig, use_container_width=True)
        st.success(f"Current Point = {iter_points[-1]}")
    except Exception as e:
        # UI boundary: show evaluation/plotting failures instead of crashing.
        st.error(f"Plot error: {e}")