File size: 4,957 Bytes
06bfef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
import sympy as sp

st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide", page_icon="πŸ“‰")

# Cosmetic CSS tweaks injected once at startup (rounded inputs/buttons,
# sidebar tint, heading colors).
_PAGE_CSS = """
    <style>
    .css-18e3th9 { padding-top: 1rem; padding-bottom: 1rem; }
    .stButton>button { border-radius: 8px; font-weight: bold; }
    .stNumberInput>div>div>input { border-radius: 8px; }
    .stTextInput>div>div>input { border-radius: 8px; }
    .st-sidebar { background-color: #f9f9f9; border-right: 2px solid #eee; }
    h1, h2, h3 { color: #4a4e69; font-family: 'Arial'; }
    </style>
    """

st.markdown(_PAGE_CSS, unsafe_allow_html=True)

def _broadcast_like_input(f):
    """Wrap a lambdified callable so array input always yields array output.

    ``sp.lambdify`` of a constant expression (e.g. the user enters "5", or
    the derivative of "3*x") returns a plain scalar even when called with an
    array, which breaks ``ax.plot`` and ``min(func(x_range))`` downstream.
    Scalar input is passed through unchanged.
    """
    def wrapped(x_val):
        y = f(x_val)
        # Expand a scalar result to the input's shape; leave everything else alone.
        if np.ndim(x_val) != 0 and np.ndim(y) == 0:
            y = np.full(np.shape(x_val), float(y))
        return y
    return wrapped

def parse_function(user_input):
    """Parse a user-typed expression in ``x`` into numeric callables.

    Returns ``(func, gradient, gradient_str)`` where ``func`` evaluates
    f(x), ``gradient`` evaluates f'(x) (both accept scalars or numpy
    arrays), and ``gradient_str`` is the symbolic derivative as text.
    On any parse/differentiation failure, shows a Streamlit error and
    returns ``(None, None, None)``.
    """
    x = sp.symbols('x')
    try:
        func_expr = sp.sympify(user_input)
        gradient_expr = sp.diff(func_expr, x)
        func = _broadcast_like_input(sp.lambdify(x, func_expr, 'numpy'))
        gradient = _broadcast_like_input(sp.lambdify(x, gradient_expr, 'numpy'))
        return func, gradient, str(gradient_expr)
    except Exception as e:
        st.error(f"Invalid function input: {e}")
        return None, None, None

def gradient_descent_step(x, learning_rate, func, gradient):
    """Take a single gradient-descent step from position ``x``.

    Returns ``(x_new, loss_new, grad)``: the updated position
    ``x - learning_rate * f'(x)``, the loss ``f(x_new)`` at that new
    position, and ``grad``, the gradient evaluated at the *starting* point.
    """
    slope = gradient(x)
    updated = x - learning_rate * slope
    return updated, func(updated), slope

st.title("πŸ“‰ Interactive Gradient Descent Visualizer")
with st.sidebar:
    st.header("πŸ”§ Parameters")
    user_function = st.text_input("Enter a Function (e.g., 2*x**2 + 3*x)", value="x**2")
    start_point = st.number_input("Starting Point", value=5.0, step=0.00001, format="%.5f")
    learning_rate = st.number_input("Learning Rate", value=0.1, step=0.00001, format="%.5f")

    def setup_gradient_descent():
        """Reset all per-session descent state for a fresh run.

        ``func`` resolves to the module-level parse result from the run in
        which this callback was defined. When the function text failed to
        parse, ``func`` is None and calling it would raise TypeError — so
        bail out and leave the existing state untouched.
        """
        if func is None:
            return
        st.session_state.iteration = 0
        st.session_state.final_minimum = False
        st.session_state.x_values = [start_point]
        st.session_state.loss_values = [func(start_point)]
        st.session_state.gradients = []

    st.button("πŸš€ Setup", on_click=setup_gradient_descent)

# Parse after the widgets so the callables reflect the current text input.
func, gradient, gradient_str = parse_function(user_function)

# Main view: rendered only when the function text parsed successfully.
if func and gradient:
    # First run of a session: seed the descent state (same fields the
    # sidebar Setup callback writes).
    if 'iteration' not in st.session_state:
        st.session_state.iteration = 0
        st.session_state.final_minimum = False
        st.session_state.x_values = [start_point]
        st.session_state.loss_values = [func(start_point)]
        st.session_state.gradients = []

    col1, col2 = st.columns([2, 1])

    with col1:
        st.subheader(f"πŸ” Gradient Descent Visualization for: $f(x) = {user_function}$")
        st.markdown(f"**Gradient (f'(x)):**  $f'(x) = {gradient_str}$")

        # Plot f over a fixed [-10, 10] window.
        # NOTE(review): assumes func(array) returns an array — lambdify of a
        # constant expression yields a scalar, which would break plotting;
        # confirm upstream.
        fig, ax = plt.subplots(figsize=(10, 6))
        x_range = np.linspace(-10, 10, 500)
        y_range = func(x_range)

        ax.plot(x_range, y_range, label="Function: f(x)", color='#0077b6', linewidth=2)

        # Mark every visited iterate; only the first scatter gets a legend
        # label so the legend shows a single "Descent Point" entry.
        for i in range(st.session_state.iteration + 1):
            x = st.session_state.x_values[i]
            y = func(x)
            ax.scatter(x, y, color="#d90429", s=60, zorder=5, label="Descent Point" if i == 0 else "")

        # Draw the tangent line at the current iterate. The guard is always
        # true in practice (iteration == len(x_values) - 1 after each step).
        if st.session_state.iteration < len(st.session_state.x_values):
            x = st.session_state.x_values[st.session_state.iteration]
            y = func(x)
            grad = gradient(x)
            tangent = grad * (x_range - x) + y
            ax.plot(x_range, tangent, color="#ffb703", linestyle="--", linewidth=1.5, alpha=0.8, label="Tangent Line")

        ax.set_xlim([-10, 10])
        # y-limits padded by 1 around the function's range on this window.
        ax.set_ylim([min(func(x_range)) - 1, max(func(x_range)) + 1])
        ax.set_xlabel("x", fontsize=12, labelpad=10)
        ax.set_ylabel("f(x)", fontsize=12, labelpad=10)
        ax.set_title("Gradient Descent with Tangent Lines", fontsize=14, fontweight="bold", color="#4a4e69")
        ax.legend(loc="upper right", frameon=True, fontsize=10)
        ax.grid(alpha=0.3)
        st.pyplot(fig)

    with col2:
        # Iteration history: one line per visited (x, f(x)) pair.
        st.subheader("πŸ“Š Progress")
        for i, (x, loss) in enumerate(zip(st.session_state.x_values, st.session_state.loss_values)):
            st.write(f"Iteration {i}: x = {x:.4f}, f(x) = {loss:.4f}")

        # Convergence check: once |f'(x)| drops below 1e-6, latch
        # final_minimum so "Next Iteration" clicks become no-ops.
        current_gradient = gradient(st.session_state.x_values[-1])
        if abs(current_gradient) < 1e-6 and not st.session_state.final_minimum:
            st.success(f"🎯 Final Minimum Reached: $x = {st.session_state.x_values[-1]:.4f}, f(x) = {st.session_state.loss_values[-1]:.4f}$")
            st.session_state.final_minimum = True

    def next_iteration():
        # Button callback: advance one descent step unless converged.
        if not st.session_state.final_minimum:
            new_x, new_loss, grad = gradient_descent_step(st.session_state.x_values[-1], learning_rate, func, gradient)
            st.session_state.x_values.append(new_x)
            st.session_state.loss_values.append(new_loss)
            st.session_state.gradients.append(grad)
            st.session_state.iteration += 1

    st.sidebar.button("⏭ Next Iteration", on_click=next_iteration)

    if st.session_state.final_minimum:
        st.info("The gradient is close to zero. Further iterations may not improve the result.")