File size: 6,183 Bytes
742dc62
816fe9c
 
 
2a2c734
6a14f39
 
2a2c734
 
 
 
 
 
 
 
 
fa7e3f8
2a2c734
 
 
 
 
 
fa7e3f8
2a2c734
 
fa7e3f8
2a2c734
 
 
1aeffd1
26eec67
559203b
2a2c734
 
559203b
26eec67
559203b
 
 
26eec67
1aeffd1
 
 
 
 
26eec67
559203b
 
 
 
 
 
 
 
26eec67
1aeffd1
26eec67
1aeffd1
 
 
26eec67
1aeffd1
 
 
 
559203b
2a2c734
 
8bc07a1
2a2c734
 
fa7e3f8
2a2c734
fa7e3f8
2a2c734
 
 
 
fa7e3f8
2a2c734
fa7e3f8
2a2c734
 
 
 
 
 
 
fa7e3f8
2a2c734
 
 
 
 
 
fa7e3f8
2a2c734
 
fa7e3f8
2a2c734
 
 
 
 
 
 
 
 
 
 
fa7e3f8
2a2c734
8bc07a1
2a2c734
 
6a14f39
2a2c734
 
 
 
fa7e3f8
2a2c734
 
fa7e3f8
2a2c734
fa7e3f8
2a2c734
 
 
 
fa7e3f8
 
2a2c734
fa7e3f8
2a2c734
 
 
 
 
 
fa7e3f8
6952440
 
2a2c734
fa7e3f8
6952440
 
2a2c734
fa7e3f8
2a2c734
 
6a14f39
2a2c734
 
 
5284381
2a2c734
 
 
6952440
2a2c734
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import streamlit as st
import numpy as np
import plotly.graph_objects as go

# Page Configurations
st.set_page_config(page_title="Gradient Descent Visualization", layout="wide")
st.markdown("<h1 style='text-align: center; color: #FFD700;'>⚑ Gradient Descent Visualization ⚑</h1>", unsafe_allow_html=True)

# Custom CSS for Background and Buttons
# Streamlit renders the page inside its own full-size ``.stApp`` container,
# which paints the theme background over ``<body>``; a rule on ``body`` alone
# is hidden, so the dark background is applied to both selectors.
st.markdown("""
    <style>
        body, .stApp {
            background-color: #1E1E1E;  /* Dark grey background */
            color: white;  /* White text for contrast */
        }
        .stButton>button {
            background: linear-gradient(to right, #4CAF50, #2E7D32);  /* Green gradient */
            color: white;  /* White text */
            border-radius: 8px;
            padding: 8px 16px;
            font-size: 16px;
        }
        .stButton>button:hover {
            background: linear-gradient(to right, #2E7D32, #4CAF50);  /* Reverse gradient on hover */
        }
        .stMarkdown h3 {
            color: #03A9F4;  /* Blue color for section titles */
        }
    </style>
""", unsafe_allow_html=True)

# Safe Function Evaluation
def evaluate_function(expression, x_value):
    """Evaluate the user-supplied expression ``expression`` at ``x = x_value``.

    Only two names are visible to the expression: ``x`` (bound to
    ``x_value``) and ``np`` (NumPy). Python builtins are disabled.

    Args:
        expression: Expression text such as ``"x**2 + x"`` or ``"np.sin(x)"``.
        x_value: Value bound to the name ``x``.

    Returns:
        Whatever the expression evaluates to.

    Raises:
        ValueError: If the expression contains a double underscore.
        Exception: Any error raised while evaluating the expression itself
            (e.g. ``NameError``, ``SyntaxError``, ``ZeroDivisionError``).
    """
    # SECURITY: ``expression`` comes from a text box. ``eval`` is not a
    # sandbox even with builtins removed (``np`` alone exposes file I/O), so
    # this is only acceptable for a locally-run demo. Rejecting dunders blocks
    # the common ``().__class__.__subclasses__()`` escape route.
    if "__" in expression:
        raise ValueError("Double underscores are not allowed in the equation.")
    allowed_names = {"x": x_value, "np": np}  # Allow only x and numpy
    # The key must be exactly ``__builtins__``; any other spelling is ignored
    # and eval silently injects the real builtins (``__import__``, ``open``...).
    # An empty dict makes every builtin lookup fail with NameError.
    return eval(expression, {"__builtins__": {}}, allowed_names)

# Compute Derivative
def compute_derivative(expression, x_value, h=1e-5):
    """Approximate the derivative of ``expression`` at ``x_value``.

    Uses the symmetric (central) difference quotient
    ``(f(x + h) - f(x - h)) / (2h)``, evaluating ``f`` through
    ``evaluate_function``.

    Args:
        expression: Expression text in the variable ``x``.
        x_value: Point at which to estimate the slope.
        h: Half-width of the difference interval.

    Returns:
        The estimated slope at ``x_value``.
    """
    ahead = evaluate_function(expression, x_value + h)
    behind = evaluate_function(expression, x_value - h)
    return (ahead - behind) / (2 * h)

# Tangent Line Calculation
def calculate_tangent(expression, x_value, x_range):
    """Return the tangent line to ``expression`` at ``x_value``, sampled on ``x_range``.

    The line passes through ``(x_value, f(x_value))`` with the slope
    estimated by ``compute_derivative``.

    Args:
        expression: Expression text in the variable ``x``.
        x_value: Point of tangency.
        x_range: Sample positions (used with NumPy arithmetic, e.g. an array).

    Returns:
        The tangent line's y-values at each position in ``x_range``.
    """
    anchor_y = evaluate_function(expression, x_value)
    gradient = compute_derivative(expression, x_value)
    offsets = x_range - x_value
    return gradient * offsets + anchor_y

# Reset Session State
def reset_session_state():
    """Rewind the descent to the configured starting value.

    Reads ``initial_point`` and ``math_function`` from ``st.session_state``.
    It then resets ``x_current``, ``iter_count``, ``history`` (a single entry
    holding the start point and its function value) and ``current_index``.
    Used as the widgets' ``on_change`` callback and by the Restart button.
    """
    state = st.session_state
    start = state.initial_point
    state.x_current = start
    state.iter_count = 0
    state.history = [(start, evaluate_function(state.math_function, start))]
    state.current_index = 0

# Initialize Session State
# Seed the run from the same defaults the input widgets below start with
# (equation "x**2 + x", starting x of 4.0). This keeps the first plot and the
# first "Compute Next Step" consistent with the values shown on screen.
# ``.get`` prefers the widget values when they already exist in session state.
_default_x = st.session_state.get("initial_point", 4.0)
_default_function = st.session_state.get("math_function", "x**2 + x")
if "x_current" not in st.session_state:
    st.session_state.x_current = _default_x
if "iter_count" not in st.session_state:
    st.session_state.iter_count = 0
if "history" not in st.session_state:
    # List of (x, f(x)) pairs, one per completed step (index 0 = start point).
    st.session_state.history = [(_default_x, evaluate_function(_default_function, _default_x))]
if "current_index" not in st.session_state:
    st.session_state.current_index = 0
if "learning_rate" not in st.session_state:
    st.session_state.learning_rate = 0.1

# Layout
left_col, right_col = st.columns([1, 2])

# Left Column: Inputs
with left_col:
    st.markdown("### πŸ“ Define Your Equation")
    function_input = st.text_input(
        "Input Equation (e.g., `x**2`, `np.sin(x)`):",
        "x**2 + x",
        key="math_function",
        on_change=reset_session_state
    )
    st.markdown("### πŸ”§ Configure Settings")
    initial_point = st.number_input(
        "Starting Value of x:",
        value=4.0,
        step=0.1,
        format="%.2f",
        key="initial_point",
        on_change=reset_session_state
    )
    # The widget's value comes from st.session_state.learning_rate, which is
    # seeded at start-up. Passing ``value=`` as well makes Streamlit warn that
    # the widget "was created with a default value but also had its value set
    # via the Session State API".
    st.number_input(
        "Step Size (Learning Rate):",
        step=0.01,
        format="%.2f",
        key="learning_rate"
    )

    st.markdown("### πŸ•ΉοΈ Actions")
    col1, col2 = st.columns(2)
    with col1:
        if st.button("πŸŒ€ Compute Next Step"):
            try:
                gradient = compute_derivative(function_input, st.session_state.x_current)
                next_x = st.session_state.x_current - st.session_state.learning_rate * gradient
                next_y = evaluate_function(function_input, next_x)
            except Exception as e:
                st.error(f"Error: {str(e)}")
            else:
                # Commit only once both evaluations have succeeded. This way a
                # failing expression cannot leave x_current/iter_count ahead
                # of history.
                st.session_state.x_current = next_x
                st.session_state.iter_count += 1
                st.session_state.history.append((next_x, next_y))
                st.session_state.current_index = st.session_state.iter_count
    with col2:
        if st.button("πŸ”ƒ Restart"):
            reset_session_state()

# Right Column: Visualization
with right_col:
    st.markdown("### πŸ“ˆ Gradient Descent Steps")

    # Navigation Buttons
    col1, col2, col3 = st.columns(3)
    # Streamlit renders top to bottom, so both buttons are read (and the index
    # updated) before the step counter is drawn. Otherwise a "Next" click,
    # handled in col3, would land after col2 had already shown the old index.
    with col1:
        previous_clicked = st.button("⬅️ Previous Step")
    with col3:
        next_clicked = st.button("➑️ Next Step")
    if previous_clicked and st.session_state.current_index > 0:
        st.session_state.current_index -= 1
    if next_clicked and st.session_state.current_index < st.session_state.iter_count:
        st.session_state.current_index += 1
    with col2:
        st.markdown(f"<p style='text-align: center;'>Step Count: <strong>{st.session_state.current_index}</strong></p>", unsafe_allow_html=True)

    # Look up the selected (x, f(x)) pair once; it drives the readout and the plot.
    history = st.session_state.history
    index = st.session_state.current_index
    selected = history[index] if 0 <= index < len(history) else None
    if selected is None:
        st.warning("No data to display. Perform a computation first.")
    else:
        selected_x, selected_y = selected
        st.markdown(f"πŸ“ **Current x:** `{selected_x:.4f}`")
        st.markdown(f"πŸ“ˆ **f(x) at Current Step:** `{selected_y:.4f}`")

    # Visualization
    # Plot x in [-10, 10] by default, widened when the selected point lies
    # outside that window so the marker still sits on the drawn curve.
    x_min, x_max = -10.0, 10.0
    if selected is not None and np.isfinite(selected_x):
        x_min = min(x_min, selected_x - 1.0)
        x_max = max(x_max, selected_x + 1.0)
    x_range = np.linspace(x_min, x_max, 500)

    try:
        y_range = [evaluate_function(function_input, x) for x in x_range]

        fig = go.Figure()
        fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Equation', line=dict(color='#FFD700')))

        if selected is not None:
            fig.add_trace(go.Scatter(x=[selected_x], y=[selected_y], mode='markers', name='Current Position', marker=dict(size=12, color='#FF4500')))

            # Tangent line at the selected step's position.
            tangent_y = calculate_tangent(function_input, selected_x, x_range)
            fig.add_trace(go.Scatter(x=x_range, y=tangent_y, mode='lines', name='Tangent', line=dict(dash='dash', color='#00FFFF')))
    except Exception as e:
        # A bad equation is reported in the UI instead of crashing the page.
        st.error(f"Error: {str(e)}")
    else:
        fig.update_layout(
            title="Gradient Descent Progress",
            xaxis_title="x",
            yaxis_title="f(x)",
            template="plotly_dark",
            height=500,
            width=900,
            margin=dict(l=20, r=20, t=50, b=20),
        )

        st.plotly_chart(fig, use_container_width=True)