Mpavan45 commited on
Commit
aa96b75
Β·
verified Β·
1 Parent(s): f8ec4d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +126 -111
app.py CHANGED
@@ -2,124 +2,139 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # Set up the page title and layout
6
- st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
  st.title("🌟 Gradient Descent Visualizer")
8
- st.markdown("## Understand Gradient Descent with Visualizations")
9
- st.markdown("---")
10
 
11
  # Safe function evaluation
12
- def safe_eval(func_str, x_val):
13
- """ Safely evaluates the function at a given x value. """
14
- allowed_names = {"x": x_val, "np": np}
15
- return eval(func_str, {"_builtins_": None}, allowed_names)
16
-
17
- # Derivative using finite difference method
18
- def derivative(func_str, x_val, h=1e-5):
19
- """ Calculates the derivative of the function at a point x using numerical methods. """
20
- return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
21
-
22
- # Compute tangent line
23
- def tangent_line(func_str, x_val, x_range):
24
- """ Computes the tangent line at a given x value over a specified x range. """
25
- y_val = safe_eval(func_str, x_val)
26
- slope = derivative(func_str, x_val)
27
- return slope * (x_range - x_val) + y_val
28
-
29
- # Reset state on input changes
30
- def reset_state():
31
- st.session_state.x = st.session_state.starting_point
32
- st.session_state.iteration = 0
33
- st.session_state.x_vals = [st.session_state.starting_point]
34
- st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
35
-
36
- # Sidebar for user input
37
- st.sidebar.header("πŸ”§ Function and Parameters")
38
- st.sidebar.markdown("Enter a mathematical function for gradient descent:")
39
-
40
- # Function input
41
- func_input = st.sidebar.text_input(
42
- "Function of x (e.g., x*2 + x):", "x*2 + x", key="func_input", on_change=reset_state
43
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- # Gradient Descent parameters
46
- st.sidebar.markdown("Set the starting point and learning rate:")
47
- starting_point = st.sidebar.number_input(
48
- "Starting Point", value=4.0, step=0.1, format="%.2f", key="starting_point", on_change=reset_state
49
  )
50
- learning_rate = st.sidebar.number_input(
51
- "Learning Rate", value=0.1, step=0.01, format="%.2f", key="learning_rate", on_change=reset_state
 
 
 
 
 
 
 
 
52
  )
53
 
54
- # Initialize session state variables
55
- if "x" not in st.session_state:
56
- st.session_state.x = starting_point
57
- st.session_state.iteration = 0
58
- st.session_state.x_vals = [starting_point]
59
- st.session_state.y_vals = [safe_eval(func_input, starting_point)]
60
-
61
- # Function to handle iteration movement
62
- def update_iteration(step):
63
- """Moves to next or previous iteration when the buttons are clicked."""
64
- if 0 <= st.session_state.iteration + step < len(st.session_state.x_vals):
65
- st.session_state.iteration += step
66
- st.session_state.x = st.session_state.x_vals[st.session_state.iteration]
67
- st.session_state.y_vals = [safe_eval(func_input, st.session_state.x)]
68
-
69
-
70
-
71
- # Buttons to run descent step and reset
72
- col1, col2 = st.columns([1, 1])
73
- with col1:
74
- if st.button("πŸ”„ Run Descent Step"):
75
- grad = derivative(func_input, st.session_state.x)
76
- st.session_state.x -= learning_rate * grad
77
- st.session_state.iteration += 1
78
- st.session_state.x_vals.append(st.session_state.x)
79
- st.session_state.y_vals.append(safe_eval(func_input, st.session_state.x))
80
-
81
- with col2:
82
- if st.button("πŸ”„ Reset", on_click=reset_state):
83
- pass
84
-
85
- # Plotting
86
- x_range = np.linspace(-2, 6, 400) # Zoomed-in x-axis range for better visualization
87
- y_range = [safe_eval(func_input, x) for x in x_range]
88
 
89
- # Create the plot
90
- fig = go.Figure()
91
-
92
- # Plot the function
93
- fig.add_trace(go.Scatter(x=x_range, y=y_range, mode="lines", line=dict(color="royalblue"), name="Function"))
94
-
95
- # Plot gradient descent points
96
- fig.add_trace(go.Scatter(
97
- x=st.session_state.x_vals[:st.session_state.iteration+1], y=st.session_state.y_vals[:st.session_state.iteration+1],
98
- mode="markers", marker=dict(color="red", size=8), name="Gradient Descent Points"
99
- ))
100
-
101
- # Plot tangent line at current point
102
- current_x = st.session_state.x
103
- current_y = safe_eval(func_input, current_x)
104
- slope = derivative(func_input, current_x)
105
-
106
- tangent_x = np.linspace(current_x - 1, current_x + 1, 100) # Smaller range for tangent line
107
- tangent_y = tangent_line(func_input, current_x, tangent_x)
108
-
109
- fig.add_trace(go.Scatter(
110
- x=tangent_x, y=tangent_y, mode="lines",
111
- line=dict(color="orange", dash="dash"), name="Tangent Line"
112
- ))
113
-
114
- # Customize the layout for clear visibility
115
- fig.update_layout(
116
- title="πŸ“‰ Gradient Descent Visualization",
117
- xaxis=dict(title="x", range=[-2, 6]), # Zoomed-in range for better visualization
118
- yaxis=dict(title="f(x)"),
119
- template="plotly_white", # Light background for better contrast
120
- legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
121
- margin=dict(l=50, r=50, t=50, b=50) # Adjust margins for better padding
122
  )
123
 
124
- # Display the plot
125
- st.plotly_chart(fig, use_container_width=True)
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Title of the app
6
+ st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
  st.title("🌟 Gradient Descent Visualizer")
8
+ st.markdown("---") # Horizontal separator for cleaner layout
 
9
 
10
  # Safe function evaluation
11
+ def evaluate_function(expression, x_value):
12
+ """Safely evaluates the mathematical function."""
13
+ allowed_names = {"x": x_value, "np": np} # Allow only x and numpy
14
+ return eval(expression, {"_builtins_": None}, allowed_names)
15
+
16
+ # Compute derivative using finite difference
17
+ def compute_derivative(expression, x_value, h=1e-5):
18
+ """Numerically calculates the derivative at a given point."""
19
+ return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
20
+
21
+ # Tangent line calculation
22
+ def calculate_tangent(expression, x_value, x_range):
23
+ """Generates the tangent line for a given point."""
24
+ y_value = evaluate_function(expression, x_value)
25
+ slope = compute_derivative(expression, x_value)
26
+ return slope * (x_range - x_value) + y_value
27
+
28
+ # Reset state
29
+ def reset_session_state():
30
+ st.session_state.x_current = st.session_state.initial_point
31
+ st.session_state.iter_count = 0
32
+ st.session_state.x_points = [st.session_state.initial_point]
33
+ st.session_state.y_points = [evaluate_function(st.session_state.math_function, st.session_state.initial_point)]
34
+
35
+ # Create two-column grid layout for the left side
36
+ left_col, right_col = st.columns([2, 1]) # 2 for left, 1 for right grid proportion
37
+
38
+ # Left side content (Function Input and Gradient Descent Parameters)
39
+ with left_col:
40
+ st.header("Input Your Function")
41
+ st.markdown("Define a mathematical function (e.g., `x**2`, `np.sin(x)`, `x**3 - 3*x + 2`):")
42
+ function_input = st.text_input("Enter Function:", "x**2 + x", key="math_function", on_change=reset_session_state)
43
+ st.markdown("---")
44
+
45
+ st.header("Set Parameters for Gradient Descent")
46
+ st.markdown("Configure the starting point and learning rate:")
47
+ initial_point = st.number_input(
48
+ "Initial Value of x", value=4.0, step=0.1, format="%.2f", key="initial_point", on_change=reset_session_state
49
+ )
50
+ learning_rate = st.number_input(
51
+ "Learning Rate", value=0.1, step=0.01, format="%.2f", key="learning_rate", on_change=reset_session_state
52
+ )
53
+ st.markdown("---")
54
+
55
+ # Right side content (Gradient Descent Updates and Progress)
56
+ with right_col:
57
+ st.header("Gradient Descent Updates")
58
+ if "x_current" not in st.session_state:
59
+ st.session_state.x_current = initial_point
60
+ st.session_state.iter_count = 0
61
+ st.session_state.x_points = [initial_point]
62
+ st.session_state.y_points = [evaluate_function(function_input, initial_point)]
63
+
64
+ if st.button("πŸ”„ Run Descent Step", type="primary"):
65
+ try:
66
+ gradient = compute_derivative(function_input, st.session_state.x_current)
67
+ st.session_state.x_current -= learning_rate * gradient
68
+ st.session_state.iter_count += 1
69
+ st.session_state.x_points.append(st.session_state.x_current)
70
+ st.session_state.y_points.append(evaluate_function(function_input, st.session_state.x_current))
71
+ except Exception as e:
72
+ st.error(f"Error: {str(e)}")
73
+
74
+ # Gradient Descent Progress Section with Different Style
75
+ st.subheader("Gradient Descent Progress")
76
+ st.markdown(f"**Iteration:** {st.session_state.iter_count}")
77
+ st.markdown(f"**Current x Value:** {st.session_state.x_current:.4f}")
78
+ st.markdown(f"**Current Function Value (f(x)):** {st.session_state.y_points[-1]:.4f}")
79
+ st.markdown("---")
80
+
81
+ # Styling the updates section
82
+ st.markdown(
83
+ f'<div style="background-color:#f0f0f0; padding: 10px; border-radius: 5px;">'
84
+ f'<strong>Iteration: </strong>{st.session_state.iter_count} <br>'
85
+ f'<strong>Current x Value: </strong>{st.session_state.x_current:.4f} <br>'
86
+ f'<strong>Current f(x): </strong>{st.session_state.y_points[-1]:.4f}</div>',
87
+ unsafe_allow_html=True
88
+ )
89
+
90
+ # Generate plot data
91
+ x_vals = np.linspace(-10, 10, 400)
92
+ y_vals = [evaluate_function(function_input, x) for x in x_vals]
93
+
94
+ # Create the plot
95
+ plot = go.Figure()
96
 
97
+ # Add function plot
98
+ plot.add_trace(
99
+ go.Scatter(x=x_vals, y=y_vals, mode="lines", line=dict(color="green", width=3), name="Function Curve")
 
100
  )
101
+
102
+ # Add gradient descent points
103
+ plot.add_trace(
104
+ go.Scatter(
105
+ x=st.session_state.x_points,
106
+ y=st.session_state.y_points,
107
+ mode="markers",
108
+ marker=dict(color="red", size=10, symbol="diamond"),
109
+ name="Descent Steps",
110
+ )
111
  )
112
 
113
+ # Add tangent line
114
+ current_x = st.session_state.x_current
115
+ current_y = evaluate_function(function_input, current_x)
116
+ slope = compute_derivative(function_input, current_x)
117
+ tangent_x = np.linspace(current_x - 2, current_x + 2, 100)
118
+ tangent_y = calculate_tangent(function_input, current_x, tangent_x)
119
+
120
+ plot.add_trace(
121
+ go.Scatter(
122
+ x=tangent_x,
123
+ y=tangent_y,
124
+ mode="lines",
125
+ line=dict(color="blue", width=2, dash="dash"),
126
+ name="Tangent Line",
127
+ )
128
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
+ # Update plot layout
131
+ plot.update_layout(
132
+ title="Interactive Gradient Descent with Tangent Visualization",
133
+ xaxis_title="x",
134
+ yaxis_title="f(x)",
135
+ template="plotly_dark",
136
+ legend=dict(bgcolor="rgba(255,255,255,0.5)", bordercolor="gray", borderwidth=1),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  )
138
 
139
+ # Display the plot in the left side
140
+ st.plotly_chart(plot, use_container_width=True)