Gowthamvemula commited on
Commit
647d9b5
ยท
verified ยท
1 Parent(s): 0dc3960

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -99
app.py CHANGED
@@ -2,51 +2,51 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # Title and App Configuration
6
- st.set_page_config(page_title="Gradient Descent Visualizer", layout="wide")
7
- st.markdown(
8
- """
 
 
9
  <style>
10
  body {
11
- background-color: #F9F9F9; /* Light background */
 
12
  }
13
  .stButton>button {
14
- background: linear-gradient(90deg, #00C6FF, #0072FF); /* Blue gradient */
15
- color: white;
16
- font-size: 16px;
17
- padding: 10px;
18
  border-radius: 5px;
 
 
19
  }
20
  .stButton>button:hover {
21
- background: linear-gradient(90deg, #0072FF, #00C6FF); /* Reversed gradient on hover */
22
- }
23
- .st-tabs [data-baseweb="tab"] {
24
- font-size: 18px;
25
- padding: 10px 20px;
26
  }
27
  </style>
28
- """,
29
- unsafe_allow_html=True
30
- )
31
- st.markdown("<h1 style='text-align: center;'>๐ŸŒŸ Gradient Descent Visualizer</h1>", unsafe_allow_html=True)
32
 
33
  # Safe function evaluation
34
  def evaluate_function(expression, x_value):
 
35
  allowed_names = {"x": x_value, "np": np} # Allow only x and numpy
36
  return eval(expression, {"_builtins_": None}, allowed_names)
37
 
38
  # Compute derivative using finite difference
39
  def compute_derivative(expression, x_value, h=1e-5):
 
40
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
41
 
42
  # Tangent line calculation
43
  def calculate_tangent(expression, x_value, x_range):
 
44
  y_value = evaluate_function(expression, x_value)
45
  slope = compute_derivative(expression, x_value)
46
  return slope * (x_range - x_value) + y_value
47
 
48
- # Reset session state
49
  def reset_session_state():
 
50
  st.session_state.x_current = st.session_state.initial_point
51
  st.session_state.iter_count = 0
52
  st.session_state.history = [
@@ -56,30 +56,31 @@ def reset_session_state():
56
 
57
  # Initialize session state variables
58
  if "x_current" not in st.session_state:
59
- st.session_state.x_current = 0.0
60
  if "iter_count" not in st.session_state:
61
  st.session_state.iter_count = 0
62
  if "history" not in st.session_state:
63
- st.session_state.history = [(0.0, evaluate_function("x**2 + x", 0.0))]
64
  if "current_index" not in st.session_state:
65
  st.session_state.current_index = 0
66
  if "learning_rate" not in st.session_state:
67
  st.session_state.learning_rate = 0.1
68
 
69
- # Layout Configuration
70
- left_col, right_col = st.columns([1, 2])
71
 
72
- # Left Column: Inputs
73
  with left_col:
74
- st.markdown("### Input Function and Parameters")
75
  function_input = st.text_input(
76
- "Function (e.g., 'x**2 + x'):",
77
  "x**2 + x",
78
  key="math_function",
79
  on_change=reset_session_state
80
  )
81
- st.number_input(
82
- "Initial Value of x:",
 
83
  value=4.0,
84
  step=0.1,
85
  format="%.2f",
@@ -87,14 +88,16 @@ with left_col:
87
  on_change=reset_session_state
88
  )
89
  st.number_input(
90
- "Learning Rate:",
91
  value=st.session_state.learning_rate,
92
  step=0.01,
93
  format="%.2f",
94
  key="learning_rate"
95
- )
96
- st.markdown("### Controls")
97
- if st.button("๐Ÿ”„ Run Descent Step"):
 
 
98
  try:
99
  gradient = compute_derivative(function_input, st.session_state.x_current)
100
  st.session_state.x_current -= st.session_state.learning_rate * gradient
@@ -108,72 +111,69 @@ with left_col:
108
  if st.button("๐Ÿ”„ Reset"):
109
  reset_session_state()
110
 
111
- # Right Column: Visualization and Details
112
  with right_col:
113
- tabs = st.tabs(["๐Ÿ“ˆ Visualization", "๐Ÿ“‹ Iteration Details"])
114
 
115
- # Tab 1: Visualization
116
- with tabs[0]:
117
- st.markdown("### Iteration Details")
118
-
119
- # Check if there are enough iterations for a slider
120
- max_iter = len(st.session_state.history) - 1
121
-
122
- if max_iter > 0:
123
- # Slider for dynamic iteration selection
124
- selected_iteration = st.slider(
125
- "Select Iteration",
126
- min_value=0,
127
- max_value=max_iter,
128
- value=st.session_state.current_index,
129
- step=1
130
- )
131
- st.session_state.current_index = selected_iteration # Update the current index based on slider selection
132
-
133
- # Display selected iteration details dynamically
134
- x_current, y_current = st.session_state.history[selected_iteration]
135
- st.markdown(f"**Iteration:** {selected_iteration}")
136
- st.markdown(f"**x Value:** {x_current:.4f}")
137
- st.markdown(f"**f(x):** {y_current:.4f}")
138
-
139
- st.markdown("### Gradient Descent Visualization")
140
-
141
- # Prepare data for visualization
142
- x_range = np.linspace(-10, 10, 500)
143
- y_range = [evaluate_function(function_input, x) for x in x_range]
144
-
145
- # Plot function curve
146
- fig = go.Figure()
147
- fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Function', line=dict(color='orange')))
148
-
149
- # Add current point
150
- fig.add_trace(go.Scatter(
151
- x=[x_current],
152
- y=[y_current],
153
- mode='markers',
154
- name='Current Point',
155
- marker=dict(size=10, color='red')
156
- ))
157
-
158
- # Add tangent line
159
- tangent_y = calculate_tangent(function_input, x_current, x_range)
160
- fig.add_trace(go.Scatter(
161
- x=x_range,
162
- y=tangent_y,
163
- mode='lines',
164
- name='Tangent Line',
165
- line=dict(dash='dash', color='blue')
166
- ))
167
-
168
- # Update layout
169
- fig.update_layout(
170
- title=f"Gradient Descent Progress: Iteration {selected_iteration}",
171
- xaxis_title="x",
172
- yaxis_title="f(x)",
173
- height=500,
174
- template="plotly_white"
175
- )
176
-
177
- st.plotly_chart(fig, use_container_width=True)
178
- else:
179
- st.warning("Not enough iterations to display. Run more steps to visualize gradient descent.")
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Title of the app
6
+ st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
+ st.markdown("<h1 style='text-align: center;'> โœจ Gradient Descent Visualizer โœจ</h1>", unsafe_allow_html=True)
8
+
9
+ # Custom CSS for background and button color
10
+ st.markdown("""
11
  <style>
12
  body {
13
+ background-color: black; /* Set background color to black */
14
+ color: white; /* Set text color to white for visibility */
15
  }
16
  .stButton>button {
17
+ background-color: #00FFFF; /* Light Cyan color */
18
+ color: black;
 
 
19
  border-radius: 5px;
20
+ padding: 10px 20px;
21
+ font-size: 16px;
22
  }
23
  .stButton>button:hover {
24
+ background-color: #00CED1; /* Darker cyan on hover */
 
 
 
 
25
  }
26
  </style>
27
+ """, unsafe_allow_html=True)
 
 
 
28
 
29
  # Safe function evaluation
30
  def evaluate_function(expression, x_value):
31
+ """Safely evaluates the mathematical function."""
32
  allowed_names = {"x": x_value, "np": np} # Allow only x and numpy
33
  return eval(expression, {"_builtins_": None}, allowed_names)
34
 
35
  # Compute derivative using finite difference
36
  def compute_derivative(expression, x_value, h=1e-5):
37
+ """Numerically calculates the derivative at a given point."""
38
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
39
 
40
  # Tangent line calculation
41
  def calculate_tangent(expression, x_value, x_range):
42
+ """Generates the tangent line for a given point."""
43
  y_value = evaluate_function(expression, x_value)
44
  slope = compute_derivative(expression, x_value)
45
  return slope * (x_range - x_value) + y_value
46
 
47
+ # Reset state
48
  def reset_session_state():
49
+ """Resets the session state for a fresh start."""
50
  st.session_state.x_current = st.session_state.initial_point
51
  st.session_state.iter_count = 0
52
  st.session_state.history = [
 
56
 
57
  # Initialize session state variables
58
  if "x_current" not in st.session_state:
59
+ st.session_state.x_current = 0.0 # Default starting point
60
  if "iter_count" not in st.session_state:
61
  st.session_state.iter_count = 0
62
  if "history" not in st.session_state:
63
+ st.session_state.history = [(0.0, evaluate_function("x**2 + x", 0.0))] # Default function example
64
  if "current_index" not in st.session_state:
65
  st.session_state.current_index = 0
66
  if "learning_rate" not in st.session_state:
67
  st.session_state.learning_rate = 0.1
68
 
69
+ # Create two-column grid layout for the left side (more space for the right graph)
70
+ left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
71
 
72
+ # Left side content (Function Input and Gradient Descent Parameters)
73
  with left_col:
74
+ st.markdown("### ๐Ÿงฎ Input Your Function")
75
  function_input = st.text_input(
76
+ "Enter Function: Example: `x**2`, `np.sin(x)`",
77
  "x**2 + x",
78
  key="math_function",
79
  on_change=reset_session_state
80
  )
81
+ st.markdown("### โš™๏ธ Set Parameters")
82
+ initial_point = st.number_input(
83
+ "๐Ÿ”ข Initial Value of x",
84
  value=4.0,
85
  step=0.1,
86
  format="%.2f",
 
88
  on_change=reset_session_state
89
  )
90
  st.number_input(
91
+ "๐Ÿ“ Learning Rate",
92
  value=st.session_state.learning_rate,
93
  step=0.01,
94
  format="%.2f",
95
  key="learning_rate"
96
+ ) # Updates session state directly without reset
97
+
98
+ st.markdown("### ๐ŸŽ›๏ธ Controls")
99
+
100
+ if st.button("๐Ÿš€ Run Descent Step", type="primary"):
101
  try:
102
  gradient = compute_derivative(function_input, st.session_state.x_current)
103
  st.session_state.x_current -= st.session_state.learning_rate * gradient
 
111
  if st.button("๐Ÿ”„ Reset"):
112
  reset_session_state()
113
 
114
+ # Right side content (Visualization and Iteration Details)
115
  with right_col:
116
+ st.markdown("### ๐Ÿ“‰ Gradient Descent Visualization")
117
 
118
+ # Display iteration details using buttons
119
+ col1, col2, col3 = st.columns(3)
120
+ with col1:
121
+ if st.button("โฌ…๏ธ Previous Iteration") and st.session_state.current_index > 0:
122
+ st.session_state.current_index -= 1
123
+ with col2:
124
+ st.markdown(f"**๐Ÿ”„ Iteration:** {st.session_state.current_index}", unsafe_allow_html=True)
125
+ with col3:
126
+ if st.button("โžก๏ธ Next Iteration") and st.session_state.current_index < st.session_state.iter_count:
127
+ st.session_state.current_index += 1
128
+
129
+ try:
130
+ selected_x, selected_y = st.session_state.history[st.session_state.current_index]
131
+ st.markdown(f"๐Ÿงพ **x Value:** `{selected_x:.4f}`")
132
+ st.markdown(f"๐Ÿ“Š **f(x):** `{selected_y:.4f}`")
133
+ except IndexError:
134
+ st.warning("No iteration data available. Please run a descent step first.")
135
+
136
+ # Prepare data for visualization
137
+ x_range = np.linspace(-10, 10, 500) # Define range for x
138
+ y_range = [evaluate_function(function_input, x) for x in x_range]
139
+
140
+ # Plot function curve with orange color
141
+ fig = go.Figure()
142
+ fig.add_trace(go.Scatter(
143
+ x=x_range,
144
+ y=y_range,
145
+ mode='lines',
146
+ name='Function',
147
+ line=dict(color='orange') # Curve color set to orange
148
+ ))
149
+
150
+ # Add current point
151
+ x_current, y_current = st.session_state.history[st.session_state.current_index]
152
+ fig.add_trace(go.Scatter(
153
+ x=[x_current],
154
+ y=[y_current],
155
+ mode='markers',
156
+ name='Current Point',
157
+ marker=dict(size=10, color='red')
158
+ ))
159
+
160
+ # Add tangent line
161
+ tangent_y = calculate_tangent(function_input, x_current, x_range)
162
+ fig.add_trace(go.Scatter(
163
+ x=x_range,
164
+ y=tangent_y,
165
+ mode='lines',
166
+ name='Tangent Line',
167
+ line=dict(dash='dash', color='blue') # Tangent line in blue
168
+ ))
169
+
170
+ # Layout adjustments
171
+ fig.update_layout(
172
+ title="Gradient Descent Progress ๐ŸŒŸ",
173
+ xaxis_title="x",
174
+ yaxis_title="f(x)",
175
+ template="plotly_white",
176
+ height=600
177
+ )
178
+
179
+ st.plotly_chart(fig, use_container_width=True)