Mpavan45 commited on
Commit
ce2a035
·
verified ·
1 Parent(s): 8b14f8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -122
app.py CHANGED
@@ -2,164 +2,141 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # Page configuration
6
- st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
-
8
- # Add custom CSS styles
9
- st.markdown(
10
- """
11
- <style>
12
- body {
13
- background-color: #f0f8ff;
14
- font-family: Arial, sans-serif;
15
- }
16
- .stButton>button {
17
- font-size: 16px;
18
- padding: 10px 20px;
19
- border-radius: 8px;
20
- border: none;
21
- width: 100%;
22
- }
23
- .stButton .blue button {
24
- background-color: #1e90ff;
25
- color: white;
26
- }
27
- .stButton .green button {
28
- background-color: #4CAF50;
29
- color: white;
30
- }
31
- .stButton .orange button {
32
- background-color: #FF9800;
33
- color: white;
34
- }
35
- .stButton button:hover {
36
- opacity: 0.9;
37
- }
38
- </style>
39
- """,
40
- unsafe_allow_html=True,
41
- )
42
-
43
- # Helper functions
44
  def evaluate_function(expression, x_value):
45
- """Safely evaluates the function at a given x value."""
46
- allowed_names = {"x": x_value, "np": np}
47
  return eval(expression, {"_builtins_": None}, allowed_names)
48
 
 
49
  def compute_derivative(expression, x_value, h=1e-5):
50
- """Numerically computes the derivative using finite differences."""
51
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
52
 
 
53
  def calculate_tangent(expression, x_value, x_range):
54
- """Calculates the tangent line at the current x value."""
55
  y_value = evaluate_function(expression, x_value)
56
  slope = compute_derivative(expression, x_value)
57
  return slope * (x_range - x_value) + y_value
58
 
59
- def reset_session():
60
- """Resets the session state."""
61
- if "start_point" not in st.session_state:
62
- st.session_state.start_point = 4.0 # Default value if not initialized
63
-
64
- if "func_input" in st.session_state:
65
- st.session_state.x_current = st.session_state.start_point
66
- st.session_state.iter_count = 0
67
- st.session_state.iter_data = [{"Iteration": 0, "x": st.session_state.start_point, "f(x)": evaluate_function(st.session_state.func_input, st.session_state.start_point)}]
68
- else:
69
- st.session_state.x_current = st.session_state.start_point
70
- st.session_state.iter_count = 0
71
- st.session_state.iter_data = [{"Iteration": 0, "x": st.session_state.start_point, "f(x)": evaluate_function("x**2 + x", st.session_state.start_point)}]
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
  # Initialize session state
74
  if "x_current" not in st.session_state:
75
- st.session_state.x_current = 4.0
76
  st.session_state.iter_count = 0
77
- st.session_state.iter_data = [{"Iteration": 0, "x": 4.0, "f(x)": evaluate_function("x**2 + x", 4.0)}]
78
- st.session_state.func_input = "x**2 + x" # Default function
79
 
80
- # Layout with two columns
81
- col1, col2 = st.columns([1, 2])
82
-
83
- # Left column for function input and settings
84
- with col1:
85
- st.subheader("📋 Select or Define a Function")
86
- predefined_func = st.radio(
87
- "Choose a predefined function or select 'Custom':",
88
- options=["x**2 + x", "np.sin(x)", "x**3 - 3*x + 2", "Custom"],
89
- )
90
- if predefined_func == "Custom":
91
- func_input = st.text_input("Enter your custom function:", value="x**2 + x", key="func_input")
92
- else:
93
- func_input = predefined_func
94
 
95
- st.session_state.func_input = func_input # Store the selected/custom function in session state
 
 
 
 
 
96
 
97
- start_point = st.number_input("Starting Point (x):", value=4.0, step=0.1, format="%.2f", key="start_point", on_change=reset_session)
98
- learning_rate = st.number_input("Learning Rate:", value=0.1, step=0.01, format="%.2f", key="learning_rate", on_change=reset_session)
99
 
100
- # Buttons for actions with styling (change key to make it unique)
101
- col1.button("🔄 Next Step", key="next_step", use_container_width=True, help="Perform one step of gradient descent", on_click=None)
102
- col1.button("🔁 Reset", key="reset", use_container_width=True, help="Reset the gradient descent", on_click=reset_session)
103
- iteration_data_button = col1.button("📊 Iteration Data", key="iteration_data_button", use_container_width=True)
 
 
 
104
 
105
- # Right column for visualization
106
  with col2:
107
- st.subheader("📈 Visualization")
108
  x_vals = np.linspace(-10, 10, 400)
109
- y_vals = [evaluate_function(st.session_state.func_input, x) for x in x_vals]
110
 
 
111
  plot = go.Figure()
112
 
113
- # Function curve
114
- plot.add_trace(go.Scatter(x=x_vals, y=y_vals, mode="lines", line=dict(color="green", width=3), name="Function Curve"))
 
 
115
 
116
- # Gradient Descent Points
117
- x_points = [entry["x"] for entry in st.session_state.iter_data]
118
- y_points = [entry["f(x)"] for entry in st.session_state.iter_data]
119
  plot.add_trace(
120
  go.Scatter(
121
- x=x_points, y=y_points, mode="markers", marker=dict(color="blue", size=10), name="Descent Points"
 
 
 
 
122
  )
123
  )
124
 
125
- # Tangent Line
126
- if st.session_state.iter_count > 0:
127
- current_x = st.session_state.x_current
128
- tangent_x = np.linspace(current_x - 2, current_x + 2, 100)
129
- tangent_y = calculate_tangent(st.session_state.func_input, current_x, tangent_x)
130
- plot.add_trace(
131
- go.Scatter(
132
- x=tangent_x, y=tangent_y, mode="lines", line=dict(color="orange", dash="dash"), name="Tangent Line"
133
- )
 
 
 
 
 
134
  )
 
135
 
 
136
  plot.update_layout(
137
- title="Gradient Descent Visualization",
138
  xaxis_title="x",
139
  yaxis_title="f(x)",
140
- template="plotly_white",
141
- legend=dict(bgcolor="rgba(255,255,255,0.8)", bordercolor="gray"),
142
  )
143
 
 
144
  st.plotly_chart(plot)
145
-
146
- # Perform gradient descent operation when button is clicked
147
- if st.button("🔄 Next Step", key="next_step", use_container_width=True):
148
- try:
149
- grad = compute_derivative(st.session_state.func_input, st.session_state.x_current)
150
- st.session_state.x_current = st.session_state.x_current - st.session_state.learning_rate * grad
151
- st.session_state.iter_count += 1
152
-
153
- # Add current iteration data
154
- st.session_state.iter_data.append({
155
- "Iteration": st.session_state.iter_count,
156
- "x": st.session_state.x_current,
157
- "f(x)": evaluate_function(st.session_state.func_input, st.session_state.x_current),
158
- })
159
- except Exception as e:
160
- st.error(f"Error: {e}")
161
-
162
- # Show iteration data in a new section when the button is clicked
163
- if iteration_data_button:
164
- st.subheader("📊 Iteration Data")
165
- st.table(st.session_state.iter_data)
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Title of the app
6
+ st.title("Interactive Gradient Descent Visualizer")
7
+ st.markdown("---") # Horizontal separator for cleaner layout
8
+
9
+ # Safe function evaluation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def evaluate_function(expression, x_value):
11
+ """Safely evaluates the mathematical function."""
12
+ allowed_names = {"x": x_value, "np": np} # Allow only x and numpy
13
  return eval(expression, {"_builtins_": None}, allowed_names)
14
 
15
+ # Compute derivative using finite difference
16
  def compute_derivative(expression, x_value, h=1e-5):
17
+ """Numerically calculates the derivative at a given point."""
18
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
19
 
20
+ # Tangent line calculation
21
  def calculate_tangent(expression, x_value, x_range):
22
+ """Generates the tangent line for a given point."""
23
  y_value = evaluate_function(expression, x_value)
24
  slope = compute_derivative(expression, x_value)
25
  return slope * (x_range - x_value) + y_value
26
 
27
+ # Reset state
28
+ def reset_session_state():
29
+ st.session_state.x_current = st.session_state.initial_point
30
+ st.session_state.iter_count = 0
31
+ st.session_state.x_points = [st.session_state.initial_point]
32
+ st.session_state.y_points = [evaluate_function(st.session_state.math_function, st.session_state.initial_point)]
33
+
34
+ # Input section for function
35
+ st.header("Input Your Function")
36
+ st.markdown("Define a mathematical function (e.g., `x**2`, `np.sin(x)`, `x**3 - 3*x + 2`):")
37
+ function_input = st.text_input("Enter Function:", "x**2 + x", key="math_function", on_change=reset_session_state)
38
+ st.markdown("---")
39
+
40
+ # Gradient descent parameters
41
+ st.header("Set Parameters for Gradient Descent")
42
+ st.markdown("Configure the starting point and learning rate:")
43
+ col1, col2 = st.columns(2)
44
+ with col1:
45
+ initial_point = st.number_input(
46
+ "Initial Value of x", value=4.0, step=0.1, format="%.2f", key="initial_point", on_change=reset_session_state
47
+ )
48
+ with col2:
49
+ learning_rate = st.number_input(
50
+ "Learning Rate", value=0.1, step=0.01, format="%.2f", key="learning_rate", on_change=reset_session_state
51
+ )
52
+ st.markdown("---")
53
 
54
  # Initialize session state
55
  if "x_current" not in st.session_state:
56
+ st.session_state.x_current = initial_point
57
  st.session_state.iter_count = 0
58
+ st.session_state.x_points = [initial_point]
59
+ st.session_state.y_points = [evaluate_function(function_input, initial_point)]
60
 
61
+ # Gradient Descent Step
62
+ if st.button("Perform Gradient Descent Step", type="primary"):
63
+ try:
64
+ gradient = compute_derivative(function_input, st.session_state.x_current)
65
+ st.session_state.x_current -= learning_rate * gradient
66
+ st.session_state.iter_count += 1
67
+ st.session_state.x_points.append(st.session_state.x_current)
68
+ st.session_state.y_points.append(evaluate_function(function_input, st.session_state.x_current))
69
+ except Exception as e:
70
+ st.error(f"Error: {str(e)}")
 
 
 
 
71
 
72
+ # Display the progress
73
+ st.subheader("Gradient Descent Updates")
74
+ st.markdown(f"**Iteration:** {st.session_state.iter_count}")
75
+ st.markdown(f"**Current x Value:** {st.session_state.x_current:.4f}")
76
+ st.markdown(f"**Current Function Value (f(x)):** {st.session_state.y_points[-1]:.4f}")
77
+ st.markdown("---")
78
 
79
+ # Create two columns: left for inputs and right for visualization
80
+ col1, col2 = st.columns([1, 2]) # Adjust width of the columns as needed
81
 
82
+ # Left column with function inputs and progress
83
+ with col1:
84
+ st.header("Gradient Descent Progress")
85
+ st.markdown(f"**Iteration:** {st.session_state.iter_count}")
86
+ st.markdown(f"**Current x Value:** {st.session_state.x_current:.4f}")
87
+ st.markdown(f"**Current Function Value (f(x)):** {st.session_state.y_points[-1]:.4f}")
88
+ st.markdown("---")
89
 
90
+ # Right column with visualization
91
  with col2:
92
+ # Generate plot data
93
  x_vals = np.linspace(-10, 10, 400)
94
+ y_vals = [evaluate_function(function_input, x) for x in x_vals]
95
 
96
+ # Create the plot
97
  plot = go.Figure()
98
 
99
+ # Add function plot
100
+ plot.add_trace(
101
+ go.Scatter(x=x_vals, y=y_vals, mode="lines", line=dict(color="green", width=3), name="Function Curve")
102
+ )
103
 
104
+ # Add gradient descent points
 
 
105
  plot.add_trace(
106
  go.Scatter(
107
+ x=st.session_state.x_points,
108
+ y=st.session_state.y_points,
109
+ mode="markers",
110
+ marker=dict(color="red", size=10, symbol="diamond"),
111
+ name="Descent Steps",
112
  )
113
  )
114
 
115
+ # Add tangent line
116
+ current_x = st.session_state.x_current
117
+ current_y = evaluate_function(function_input, current_x)
118
+ slope = compute_derivative(function_input, current_x)
119
+ tangent_x = np.linspace(current_x - 2, current_x + 2, 100)
120
+ tangent_y = calculate_tangent(function_input, current_x, tangent_x)
121
+
122
+ plot.add_trace(
123
+ go.Scatter(
124
+ x=tangent_x,
125
+ y=tangent_y,
126
+ mode="lines",
127
+ line=dict(color="blue", width=2, dash="dash"),
128
+ name="Tangent Line",
129
  )
130
+ )
131
 
132
+ # Update plot layout
133
  plot.update_layout(
134
+ title="Interactive Gradient Descent with Tangent Visualization",
135
  xaxis_title="x",
136
  yaxis_title="f(x)",
137
+ template="plotly_dark",
138
+ legend=dict(bgcolor="rgba(255,255,255,0.5)", bordercolor="gray", borderwidth=1),
139
  )
140
 
141
+ # Display the plot
142
  st.plotly_chart(plot)