Gowthamvemula commited on
Commit
8bc07a1
·
verified ·
1 Parent(s): fd9d533

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -114
app.py CHANGED
@@ -4,60 +4,24 @@ import plotly.graph_objects as go
4
 
5
  # Title and Header
6
  st.title("Interactive Gradient Descent Visualizer")
7
- st.markdown(
8
- "<h1 style='text-align: center; color: #00FA9A;'>✨ Gradient Descent Visualizer ✨</h1>",
9
- unsafe_allow_html=True,
10
- )
11
-
12
- # Custom CSS for Enhanced UI
13
- st.markdown(
14
- """
15
- <style>
16
- body {
17
- background: linear-gradient(to right, #141E30, #243B55);
18
- color: #E0FFFF;
19
- }
20
- .stButton>button {
21
- background: linear-gradient(to right, #00C6FF, #0072FF);
22
- color: white;
23
- border: none;
24
- border-radius: 10px;
25
- padding: 10px 15px;
26
- font-size: 16px;
27
- font-weight: bold;
28
- }
29
- .stButton>button:hover {
30
- background: linear-gradient(to right, #0072FF, #00C6FF);
31
- }
32
- .block-container {
33
- padding-top: 10px;
34
- }
35
- </style>
36
- """,
37
- unsafe_allow_html=True,
38
- )
39
 
40
  # Safe Function Evaluation
41
  def evaluate_function(expression, x_value):
42
- """Safely evaluates the mathematical function."""
43
- allowed_names = {"x": x_value, "np": np} # Allow only x and numpy
44
  return eval(expression, {"__builtins__": None}, allowed_names)
45
 
46
  # Compute Derivative
47
  def compute_derivative(expression, x_value, h=1e-5):
48
- """Numerically calculates the derivative at a given point."""
49
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
50
 
51
  # Tangent Line Calculation
52
  def calculate_tangent(expression, x_value, x_range):
53
- """Generates the tangent line for a given point."""
54
  y_value = evaluate_function(expression, x_value)
55
  slope = compute_derivative(expression, x_value)
56
  return slope * (x_range - x_value) + y_value
57
 
58
  # Reset Session State
59
  def reset_session_state():
60
- """Resets the session state for a fresh start."""
61
  st.session_state.x_current = st.session_state.initial_point
62
  st.session_state.iter_count = 0
63
  st.session_state.history = [
@@ -77,38 +41,35 @@ if "current_index" not in st.session_state:
77
  if "learning_rate" not in st.session_state:
78
  st.session_state.learning_rate = 0.1
79
 
80
- # Layout: Left (Inputs) and Right (Visualization)
81
- left_col, right_col = st.columns(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Left Column: Inputs and Controls
84
- with left_col:
85
- st.markdown("### 🧮 Input Your Function")
86
- function_input = st.text_input(
87
- "Enter Function (e.g., x**2, np.sin(x))",
88
- "x**2 + x",
89
- key="math_function",
90
- on_change=reset_session_state
91
- )
92
-
93
- st.markdown("### ⚙️ Parameters")
94
- initial_point = st.number_input(
95
- "🔢 Initial Value of x",
96
- value=4.0,
97
- step=0.1,
98
- format="%.2f",
99
- key="initial_point",
100
- on_change=reset_session_state
101
- )
102
- st.number_input(
103
- "📏 Learning Rate",
104
- value=st.session_state.learning_rate,
105
- step=0.01,
106
- format="%.2f",
107
- key="learning_rate"
108
- )
109
-
110
- st.markdown("### 🎛️ Controls")
111
- if st.button("🚀 Run Descent Step"):
112
  try:
113
  gradient = compute_derivative(function_input, st.session_state.x_current)
114
  st.session_state.x_current -= st.session_state.learning_rate * gradient
@@ -119,52 +80,43 @@ with left_col:
119
  st.session_state.current_index = st.session_state.iter_count
120
  except Exception as e:
121
  st.error(f"Error: {str(e)}")
122
- if st.button("🔄 Reset"):
 
123
  reset_session_state()
 
 
 
 
 
 
124
 
125
- # Right Column: Visualization
126
- with right_col:
127
- st.markdown("### 📉 Gradient Descent Visualization")
128
- col1, col2, col3 = st.columns([1, 1, 1])
129
-
130
- with col1:
131
- if st.button("⏮️ Previous") and st.session_state.current_index > 0:
132
- st.session_state.current_index -= 1
133
- with col2:
134
- st.markdown(
135
- f"<p style='text-align: center;'>Iteration: <strong>{st.session_state.current_index}</strong></p>",
136
- unsafe_allow_html=True,
137
- )
138
- with col3:
139
- if st.button("⏭️ Next") and st.session_state.current_index < st.session_state.iter_count:
140
- st.session_state.current_index += 1
141
 
142
- # Visualization
143
- try:
144
- selected_x, selected_y = st.session_state.history[st.session_state.current_index]
145
- st.markdown(f"🧾 **x Value:** `{selected_x:.4f}`")
146
- st.markdown(f"📊 **f(x):** `{selected_y:.4f}`")
147
- except IndexError:
148
- st.warning("No iteration data available. Please run a descent step first.")
149
-
150
- x_range = np.linspace(-10, 10, 500)
151
- y_range = [evaluate_function(function_input, x) for x in x_range]
152
-
153
- fig = go.Figure()
154
- fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Function', line=dict(color='blue')))
155
-
156
- x_current, y_current = st.session_state.history[st.session_state.current_index]
157
- fig.add_trace(go.Scatter(x=[x_current], y=[y_current], mode='markers', name='Current Point', marker=dict(size=12, color='red')))
158
-
159
- tangent_y = calculate_tangent(function_input, x_current, x_range)
160
- fig.add_trace(go.Scatter(x=x_range, y=tangent_y, mode='lines', name='Tangent Line', line=dict(dash='dash', color='yellow')))
161
-
162
- fig.update_layout(
163
- title="Gradient Descent Progress 🌟",
164
- xaxis_title="x",
165
- yaxis_title="f(x)",
166
- template="plotly_dark",
167
- height=500, # Adjusted height for better visibility
168
- width=700, # Adjusted width for better visibility
169
- )
170
- st.plotly_chart(fig, use_container_width=False)
 
4
 
5
  # Title and Header
6
  st.title("Interactive Gradient Descent Visualizer")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Safe Function Evaluation
9
  def evaluate_function(expression, x_value):
10
+ allowed_names = {"x": x_value, "np": np}
 
11
  return eval(expression, {"__builtins__": None}, allowed_names)
12
 
13
  # Compute Derivative
14
  def compute_derivative(expression, x_value, h=1e-5):
 
15
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
16
 
17
  # Tangent Line Calculation
18
  def calculate_tangent(expression, x_value, x_range):
 
19
  y_value = evaluate_function(expression, x_value)
20
  slope = compute_derivative(expression, x_value)
21
  return slope * (x_range - x_value) + y_value
22
 
23
  # Reset Session State
24
  def reset_session_state():
 
25
  st.session_state.x_current = st.session_state.initial_point
26
  st.session_state.iter_count = 0
27
  st.session_state.history = [
 
41
  if "learning_rate" not in st.session_state:
42
  st.session_state.learning_rate = 0.1
43
 
44
+ # Layout: Inputs and Visualization
45
+ st.markdown("### Input Your Function")
46
+ function_input = st.text_input(
47
+ "Enter Function (e.g., x**2, np.sin(x))",
48
+ "x**2 + x",
49
+ key="math_function",
50
+ on_change=reset_session_state
51
+ )
52
+ st.markdown("### Parameters")
53
+ initial_point = st.number_input(
54
+ "Initial Value of x",
55
+ value=4.0,
56
+ step=0.1,
57
+ format="%.2f",
58
+ key="initial_point",
59
+ on_change=reset_session_state
60
+ )
61
+ st.number_input(
62
+ "Learning Rate",
63
+ value=st.session_state.learning_rate,
64
+ step=0.01,
65
+ format="%.2f",
66
+ key="learning_rate"
67
+ )
68
 
69
+ st.markdown("### Controls")
70
+ col1, col2, col3, col4 = st.columns(4)
71
+ with col1:
72
+ if st.button("Run Step 🚀"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  try:
74
  gradient = compute_derivative(function_input, st.session_state.x_current)
75
  st.session_state.x_current -= st.session_state.learning_rate * gradient
 
80
  st.session_state.current_index = st.session_state.iter_count
81
  except Exception as e:
82
  st.error(f"Error: {str(e)}")
83
+ with col2:
84
+ if st.button("Reset 🔄"):
85
  reset_session_state()
86
+ with col3:
87
+ if st.button("⏮️ Previous") and st.session_state.current_index > 0:
88
+ st.session_state.current_index -= 1
89
+ with col4:
90
+ if st.button("⏭️ Next") and st.session_state.current_index < st.session_state.iter_count:
91
+ st.session_state.current_index += 1
92
 
93
+ # Visualization
94
+ st.markdown("### Visualization")
95
+ try:
96
+ selected_x, selected_y = st.session_state.history[st.session_state.current_index]
97
+ st.markdown(f"**x Value:** `{selected_x:.4f}`, **f(x):** `{selected_y:.4f}`")
98
+ except IndexError:
99
+ st.warning("No iteration data available. Please run a descent step first.")
 
 
 
 
 
 
 
 
 
100
 
101
+ x_range = np.linspace(-10, 10, 500)
102
+ y_range = [evaluate_function(function_input, x) for x in x_range]
103
+
104
+ fig = go.Figure()
105
+ fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Function', line=dict(color='blue')))
106
+
107
+ x_current, y_current = st.session_state.history[st.session_state.current_index]
108
+ fig.add_trace(go.Scatter(x=[x_current], y=[y_current], mode='markers', name='Current Point', marker=dict(size=8, color='red')))
109
+
110
+ tangent_y = calculate_tangent(function_input, x_current, x_range)
111
+ fig.add_trace(go.Scatter(x=x_range, y=tangent_y, mode='lines', name='Tangent Line', line=dict(dash='dash', color='yellow')))
112
+
113
+ fig.update_layout(
114
+ title="Gradient Descent Progress",
115
+ xaxis_title="x",
116
+ yaxis_title="f(x)",
117
+ template="plotly_dark",
118
+ height=400, # Reduced height
119
+ width=600, # Reduced width
120
+ margin=dict(l=20, r=20, t=40, b=20), # Compact margins
121
+ )
122
+ st.plotly_chart(fig, use_container_width=False)