Mpavan45 commited on
Commit
1580b42
ยท
verified ยท
1 Parent(s): f6819c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -48
app.py CHANGED
@@ -5,30 +5,23 @@ import plotly.graph_objects as go
5
  # Set up the page title and layout
6
  st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
  st.title("๐ŸŒŸ Gradient Descent Visualizer")
8
- st.markdown("## Visualizing Gradient Descent with Tangent Lines")
9
  st.markdown("---")
10
 
11
  # Safe function evaluation
12
  def safe_eval(func_str, x_val):
13
- """
14
- Safely evaluates the function at a given x value.
15
- Only allows numpy operations and 'x' as the variable.
16
- """
17
  allowed_names = {"x": x_val, "np": np}
18
  return eval(func_str, {"_builtins_": None}, allowed_names)
19
 
20
  # Derivative using finite difference method
21
  def derivative(func_str, x_val, h=1e-5):
22
- """
23
- Calculates the derivative of the function at a point x using numerical methods.
24
- """
25
  return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
26
 
27
  # Compute tangent line
28
  def tangent_line(func_str, x_val, x_range):
29
- """
30
- Computes the tangent line at a given x value over a specified x range.
31
- """
32
  y_val = safe_eval(func_str, x_val)
33
  slope = derivative(func_str, x_val)
34
  return slope * (x_range - x_val) + y_val
@@ -40,9 +33,9 @@ def reset_state():
40
  st.session_state.x_vals = [st.session_state.starting_point]
41
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
42
 
43
- # Sidebar for user input with customized background and font color
44
  st.sidebar.header("๐Ÿ”ง Function and Parameters")
45
- st.sidebar.markdown("<p style='color:#FF5733; font-size:16px;'>Enter a mathematical function for gradient descent:</p>", unsafe_allow_html=True)
46
 
47
  # Function input
48
  func_input = st.sidebar.text_input(
@@ -50,7 +43,7 @@ func_input = st.sidebar.text_input(
50
  )
51
 
52
  # Gradient Descent parameters
53
- st.sidebar.markdown("<p style='color:#FF5733; font-size:16px;'>Set the starting point and learning rate:</p>", unsafe_allow_html=True)
54
  starting_point = st.sidebar.number_input(
55
  "Starting Point", value=4.0, step=0.1, format="%.2f", key="starting_point", on_change=reset_state
56
  )
@@ -65,26 +58,28 @@ if "x" not in st.session_state:
65
  st.session_state.x_vals = [starting_point]
66
  st.session_state.y_vals = [safe_eval(func_input, starting_point)]
67
 
68
- # Perform one iteration when the button is pressed
69
- if st.sidebar.button("๐Ÿ”„ Run Descent Step"):
70
- try:
71
- grad = derivative(func_input, st.session_state.x)
72
- st.session_state.x -= learning_rate * grad
73
- st.session_state.iteration += 1
74
- st.session_state.x_vals.append(st.session_state.x)
75
- st.session_state.y_vals.append(safe_eval(func_input, st.session_state.x))
76
- except Exception as e:
77
- st.sidebar.error(f"Error: {str(e)}")
78
-
79
- # Reset button
80
- if st.sidebar.button("๐Ÿ”„ Reset"):
81
- reset_state()
82
-
83
- # Display gradient descent progress
84
- st.subheader("๐Ÿงฎ Gradient Descent Progress")
85
- st.write(f"*Iteration:* {st.session_state.iteration}")
86
- st.write(f"*Current x:* {st.session_state.x:.4f}")
87
- st.write(f"*Current f(x):* {st.session_state.y_vals[-1]:.4f}")
 
 
88
 
89
  # Plotting
90
  x_range = np.linspace(-2, 6, 400) # Zoomed-in x-axis range for better visualization
@@ -93,16 +88,16 @@ y_range = [safe_eval(func_input, x) for x in x_range]
93
  # Create the plot
94
  fig = go.Figure()
95
 
96
- # Plot the function with a cool gradient color
97
- fig.add_trace(go.Scatter(x=x_range, y=y_range, mode="lines", line=dict(color="mediumslateblue", width=3), name="Function"))
98
 
99
- # Plot gradient descent points with purple color
100
  fig.add_trace(go.Scatter(
101
- x=st.session_state.x_vals, y=st.session_state.y_vals,
102
- mode="markers", marker=dict(color="crimson", size=10, symbol="circle"), name="Gradient Descent Points"
103
  ))
104
 
105
- # Plot tangent line at current point with dotted line
106
  current_x = st.session_state.x
107
  current_y = safe_eval(func_input, current_x)
108
  slope = derivative(func_input, current_x)
@@ -112,19 +107,17 @@ tangent_y = tangent_line(func_input, current_x, tangent_x)
112
 
113
  fig.add_trace(go.Scatter(
114
  x=tangent_x, y=tangent_y, mode="lines",
115
- line=dict(color="gold", dash="dot", width=2), name="Tangent Line"
116
  ))
117
 
118
- # Customize the layout with a soft gradient background
119
  fig.update_layout(
120
  title="๐Ÿ“‰ Gradient Descent Visualization",
121
- xaxis=dict(title="x", range=[-2, 6], showgrid=False),
122
- yaxis=dict(title="f(x)", showgrid=False),
123
- template="plotly",
124
- plot_bgcolor="rgb(250, 250, 250)", # Light background for a soft look
125
- paper_bgcolor="rgb(250, 250, 250)", # Consistent soft background color
126
- font=dict(family="Arial, sans-serif", size=14, color="black"),
127
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
 
128
  )
129
 
130
  # Display the plot
 
5
  # Set up the page title and layout
6
  st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
  st.title("๐ŸŒŸ Gradient Descent Visualizer")
8
+ st.markdown("## Understand Gradient Descent with Visualizations")
9
  st.markdown("---")
10
 
11
  # Safe function evaluation
12
  def safe_eval(func_str, x_val):
13
+ """ Safely evaluates the function at a given x value. """
 
 
 
14
  allowed_names = {"x": x_val, "np": np}
15
  return eval(func_str, {"_builtins_": None}, allowed_names)
16
 
17
  # Derivative using finite difference method
18
  def derivative(func_str, x_val, h=1e-5):
19
+ """ Calculates the derivative of the function at a point x using numerical methods. """
 
 
20
  return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
21
 
22
  # Compute tangent line
23
  def tangent_line(func_str, x_val, x_range):
24
+ """ Computes the tangent line at a given x value over a specified x range. """
 
 
25
  y_val = safe_eval(func_str, x_val)
26
  slope = derivative(func_str, x_val)
27
  return slope * (x_range - x_val) + y_val
 
33
  st.session_state.x_vals = [st.session_state.starting_point]
34
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
35
 
36
+ # Sidebar for user input
37
  st.sidebar.header("๐Ÿ”ง Function and Parameters")
38
+ st.sidebar.markdown("Enter a mathematical function for gradient descent:")
39
 
40
  # Function input
41
  func_input = st.sidebar.text_input(
 
43
  )
44
 
45
  # Gradient Descent parameters
46
+ st.sidebar.markdown("Set the starting point and learning rate:")
47
  starting_point = st.sidebar.number_input(
48
  "Starting Point", value=4.0, step=0.1, format="%.2f", key="starting_point", on_change=reset_state
49
  )
 
58
  st.session_state.x_vals = [starting_point]
59
  st.session_state.y_vals = [safe_eval(func_input, starting_point)]
60
 
61
+ # Function to handle iteration movement
62
+ def update_iteration(step):
63
+ if st.session_state.iteration + step >= 0 and st.session_state.iteration + step < len(st.session_state.x_vals):
64
+ st.session_state.iteration += step
65
+ st.session_state.x = st.session_state.x_vals[st.session_state.iteration]
66
+ st.session_state.y_vals = [safe_eval(func_input, st.session_state.x)]
67
+
68
+ # Button handlers for iteration
69
+ col1, col2, col3 = st.columns([1, 3, 1])
70
+ with col1:
71
+ if st.button("โช Previous", key="prev", on_click=update_iteration, args=(-1,)):
72
+ pass
73
+
74
+ with col2:
75
+ if st.button("๐Ÿ”„ Current Iteration", key="current"):
76
+ st.write(f"Iteration: {st.session_state.iteration}")
77
+ st.write(f"Current x: {st.session_state.x:.4f}")
78
+ st.write(f"f(x): {st.session_state.y_vals[-1]:.4f}")
79
+
80
+ with col3:
81
+ if st.button("โฉ Next", key="next", on_click=update_iteration, args=(1,)):
82
+ pass
83
 
84
  # Plotting
85
  x_range = np.linspace(-2, 6, 400) # Zoomed-in x-axis range for better visualization
 
88
  # Create the plot
89
  fig = go.Figure()
90
 
91
+ # Plot the function
92
+ fig.add_trace(go.Scatter(x=x_range, y=y_range, mode="lines", line=dict(color="royalblue"), name="Function"))
93
 
94
+ # Plot gradient descent points
95
  fig.add_trace(go.Scatter(
96
+ x=st.session_state.x_vals[:st.session_state.iteration+1], y=st.session_state.y_vals[:st.session_state.iteration+1],
97
+ mode="markers", marker=dict(color="red", size=8), name="Gradient Descent Points"
98
  ))
99
 
100
+ # Plot tangent line at current point
101
  current_x = st.session_state.x
102
  current_y = safe_eval(func_input, current_x)
103
  slope = derivative(func_input, current_x)
 
107
 
108
  fig.add_trace(go.Scatter(
109
  x=tangent_x, y=tangent_y, mode="lines",
110
+ line=dict(color="orange", dash="dash"), name="Tangent Line"
111
  ))
112
 
113
+ # Customize the layout for clear visibility
114
  fig.update_layout(
115
  title="๐Ÿ“‰ Gradient Descent Visualization",
116
+ xaxis=dict(title="x", range=[-2, 6]), # Zoomed-in range for better visualization
117
+ yaxis=dict(title="f(x)"),
118
+ template="plotly_white", # Light background for better contrast
 
 
 
119
  legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
120
+ margin=dict(l=50, r=50, t=50, b=50) # Adjust margins for better padding
121
  )
122
 
123
  # Display the plot