Gowthamvemula commited on
Commit
1aeffd1
·
verified ·
1 Parent(s): 559203b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -96
app.py CHANGED
@@ -2,32 +2,51 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # App Configuration
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
-
8
- # Initialize session state variables
9
- if "x_current" not in st.session_state:
10
- st.session_state.x_current = 0.0
11
- if "iter_count" not in st.session_state:
12
- st.session_state.iter_count = 0
13
- if "history" not in st.session_state:
14
- st.session_state.history = []
15
- if "current_index" not in st.session_state:
16
- st.session_state.current_index = 0
17
- if "learning_rate" not in st.session_state:
18
- st.session_state.learning_rate = 0.1
19
-
20
- # Function evaluation
 
 
 
 
 
 
 
 
 
21
  def evaluate_function(expression, x_value):
22
- allowed_names = {"x": x_value, "np": np}
 
23
  return eval(expression, {"_builtins_": None}, allowed_names)
24
 
25
- # Compute derivative
26
  def compute_derivative(expression, x_value, h=1e-5):
 
27
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
28
 
29
- # Reset session state
 
 
 
 
 
 
 
30
  def reset_session_state():
 
31
  st.session_state.x_current = st.session_state.initial_point
32
  st.session_state.iter_count = 0
33
  st.session_state.history = [
@@ -35,79 +54,120 @@ def reset_session_state():
35
  ]
36
  st.session_state.current_index = 0
37
 
38
- # Tangent line calculation
39
- def calculate_tangent(expression, x_value, x_range):
40
- y_value = evaluate_function(expression, x_value)
41
- slope = compute_derivative(expression, x_value)
42
- return slope * (x_range - x_value) + y_value
43
-
44
- # App Title
45
- st.markdown("<h1 style='text-align: center;'>🌟 Gradient Descent Visualizer</h1>", unsafe_allow_html=True)
46
-
47
- # Input section
48
- st.markdown("### Input Parameters")
49
- col1, col2, col3, col4 = st.columns([2, 2, 2, 1])
50
 
51
- with col1:
52
- function_input = st.text_input("Function (`x**2`, `np.sin(x)`) :", "x**2 + x", key="math_function")
53
- with col2:
54
- st.session_state.initial_point = st.number_input("Initial x Value:", value=4.0, step=0.1, format="%.2f")
55
- with col3:
56
- st.session_state.learning_rate = st.number_input("Learning Rate:", value=0.1, step=0.01, format="%.2f")
57
- with col4:
58
- if st.button("Reset"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  reset_session_state()
60
 
61
- # Run Gradient Descent Step
62
- if st.button("Run Descent Step"):
63
- try:
64
- gradient = compute_derivative(function_input, st.session_state.x_current)
65
- st.session_state.x_current -= st.session_state.learning_rate * gradient
66
- st.session_state.iter_count += 1
67
- st.session_state.history.append(
68
- (st.session_state.x_current, evaluate_function(function_input, st.session_state.x_current))
69
- )
70
- st.session_state.current_index = st.session_state.iter_count
71
- except Exception as e:
72
- st.error(f"Error: {str(e)}")
73
-
74
- # Tabs for content
75
- tab1, tab2 = st.tabs(["📈 Graph", "ℹ️ Iteration Details"])
76
-
77
- # Tab 1: Visualization
78
- with tab1:
79
  st.markdown("### Gradient Descent Visualization")
80
- x_range = np.linspace(-10, 10, 500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  y_range = [evaluate_function(function_input, x) for x in x_range]
82
-
 
83
  fig = go.Figure()
84
  fig.add_trace(go.Scatter(
85
- x=x_range,
86
- y=y_range,
87
- mode="lines",
88
- name="Function",
89
- line=dict(color="orange")
90
  ))
91
-
92
- if st.session_state.history:
93
- x_current, y_current = st.session_state.history[st.session_state.current_index]
94
- fig.add_trace(go.Scatter(
95
- x=[x_current],
96
- y=[y_current],
97
- mode="markers",
98
- name="Current Point",
99
- marker=dict(size=10, color="red")
100
- ))
101
-
102
- tangent_y = calculate_tangent(function_input, x_current, x_range)
103
- fig.add_trace(go.Scatter(
104
- x=x_range,
105
- y=tangent_y,
106
- mode="lines",
107
- name="Tangent Line",
108
- line=dict(dash="dash", color="blue")
109
- ))
110
-
 
 
111
  fig.update_layout(
112
  title="Gradient Descent Progress",
113
  xaxis_title="x",
@@ -115,16 +175,5 @@ with tab1:
115
  template="plotly_white",
116
  height=600
117
  )
118
- st.plotly_chart(fig, use_container_width=True)
119
-
120
- # Tab 2: Iteration Details
121
- with tab2:
122
- st.markdown("### Iteration Details")
123
- if st.session_state.history:
124
- st.progress(st.session_state.current_index / max(1, st.session_state.iter_count))
125
- x_value, y_value = st.session_state.history[st.session_state.current_index]
126
- st.markdown(f"- **Iteration:** {st.session_state.current_index}")
127
- st.markdown(f"- **x Value:** `{x_value:.4f}`")
128
- st.markdown(f"- **f(x):** `{y_value:.4f}`")
129
- else:
130
- st.warning("No iteration data available. Please run a descent step first.")
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Title of the app
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
+ st.markdown("<h1 style='text-align: center;'> 🌟 Gradient Descent Visualizer</h1>", unsafe_allow_html=True)
8
+
9
+ # Custom CSS for background and button color
10
+ st.markdown("""
11
+ <style>
12
+ body {
13
+ background-color: black; /* Set background color to black */
14
+ color: white; /* Set text color to white for visibility */
15
+ }
16
+ .stButton>button {
17
+ background-color: #00FFFF; /* Light Cyan color */
18
+ color: black;
19
+ border-radius: 5px;
20
+ padding: 10px 20px;
21
+ font-size: 16px;
22
+ }
23
+ .stButton>button:hover {
24
+ background-color: #00CED1; /* Darker cyan on hover */
25
+ }
26
+ </style>
27
+ """, unsafe_allow_html=True)
28
+
29
+ # Safe function evaluation
30
  def evaluate_function(expression, x_value):
31
+ """Safely evaluates the mathematical function."""
32
+ allowed_names = {"x": x_value, "np": np} # Allow only x and numpy
33
  return eval(expression, {"_builtins_": None}, allowed_names)
34
 
35
+ # Compute derivative using finite difference
36
  def compute_derivative(expression, x_value, h=1e-5):
37
+ """Numerically calculates the derivative at a given point."""
38
  return (evaluate_function(expression, x_value + h) - evaluate_function(expression, x_value - h)) / (2 * h)
39
 
40
+ # Tangent line calculation
41
+ def calculate_tangent(expression, x_value, x_range):
42
+ """Generates the tangent line for a given point."""
43
+ y_value = evaluate_function(expression, x_value)
44
+ slope = compute_derivative(expression, x_value)
45
+ return slope * (x_range - x_value) + y_value
46
+
47
+ # Reset state
48
  def reset_session_state():
49
+ """Resets the session state for a fresh start."""
50
  st.session_state.x_current = st.session_state.initial_point
51
  st.session_state.iter_count = 0
52
  st.session_state.history = [
 
54
  ]
55
  st.session_state.current_index = 0
56
 
57
+ # Initialize session state variables
58
+ if "x_current" not in st.session_state:
59
+ st.session_state.x_current = 0.0 # Default starting point
60
+ if "iter_count" not in st.session_state:
61
+ st.session_state.iter_count = 0
62
+ if "history" not in st.session_state:
63
+ st.session_state.history = [(0.0, evaluate_function("x**2 + x", 0.0))] # Default function example
64
+ if "current_index" not in st.session_state:
65
+ st.session_state.current_index = 0
66
+ if "learning_rate" not in st.session_state:
67
+ st.session_state.learning_rate = 0.1
 
68
 
69
+ # Create two-column grid layout for the left side (more space for the right graph)
70
+ left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
71
+
72
+ # Left side content (Function Input and Gradient Descent Parameters)
73
+ with left_col:
74
+ st.markdown("### Input Your Function")
75
+ function_input = st.text_input(
76
+ "Enter Function:`Ex:'x**2`,`np.sin(x)`,",
77
+ "x**2 + x",
78
+ key="math_function",
79
+ on_change=reset_session_state
80
+ )
81
+ st.markdown("### Set Parameters")
82
+ initial_point = st.number_input(
83
+ "Initial Value of x",
84
+ value=4.0,
85
+ step=0.1,
86
+ format="%.2f",
87
+ key="initial_point",
88
+ on_change=reset_session_state
89
+ )
90
+ st.number_input(
91
+ "Learning Rate",
92
+ value=st.session_state.learning_rate,
93
+ step=0.01,
94
+ format="%.2f",
95
+ key="learning_rate"
96
+ ) # Updates session state directly without reset
97
+
98
+ st.markdown("### Controls")
99
+
100
+ if st.button("🔄 Run Descent Step", type="primary"):
101
+ try:
102
+ gradient = compute_derivative(function_input, st.session_state.x_current)
103
+ st.session_state.x_current -= st.session_state.learning_rate * gradient
104
+ st.session_state.iter_count += 1
105
+ st.session_state.history.append(
106
+ (st.session_state.x_current, evaluate_function(function_input, st.session_state.x_current))
107
+ )
108
+ st.session_state.current_index = st.session_state.iter_count
109
+ except Exception as e:
110
+ st.error(f"Error: {str(e)}")
111
+ if st.button("🔄 Reset"):
112
  reset_session_state()
113
 
114
+ # Right side content (Visualization and Iteration Details)
115
+ with right_col:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  st.markdown("### Gradient Descent Visualization")
117
+
118
+ # Display iteration details using buttons
119
+ col1, col2, col3 = st.columns(3)
120
+ with col1:
121
+ if st.button("⬅️ Previous Iteration") and st.session_state.current_index > 0:
122
+ st.session_state.current_index -= 1
123
+ with col2:
124
+ st.markdown(f"**Iteration:** {st.session_state.current_index}", unsafe_allow_html=True)
125
+ with col3:
126
+ if st.button("➡️ Next Iteration") and st.session_state.current_index < st.session_state.iter_count:
127
+ st.session_state.current_index += 1
128
+
129
+ try:
130
+ selected_x, selected_y = st.session_state.history[st.session_state.current_index]
131
+ st.markdown(f"x Value: `{selected_x:.4f}`")
132
+ st.markdown(f"f(x): `{selected_y:.4f}`")
133
+ except IndexError:
134
+ st.warning("No iteration data available. Please run a descent step first.")
135
+
136
+ # Prepare data for visualization
137
+ x_range = np.linspace(-10, 10, 500) # Define range for x
138
  y_range = [evaluate_function(function_input, x) for x in x_range]
139
+
140
+ # Plot function curve with orange color
141
  fig = go.Figure()
142
  fig.add_trace(go.Scatter(
143
+ x=x_range,
144
+ y=y_range,
145
+ mode='lines',
146
+ name='Function',
147
+ line=dict(color='orange') # Curve color set to orange
148
  ))
149
+
150
+ # Add current point
151
+ x_current, y_current = st.session_state.history[st.session_state.current_index]
152
+ fig.add_trace(go.Scatter(
153
+ x=[x_current],
154
+ y=[y_current],
155
+ mode='markers',
156
+ name='Current Point',
157
+ marker=dict(size=10, color='red')
158
+ ))
159
+
160
+ # Add tangent line
161
+ tangent_y = calculate_tangent(function_input, x_current, x_range)
162
+ fig.add_trace(go.Scatter(
163
+ x=x_range,
164
+ y=tangent_y,
165
+ mode='lines',
166
+ name='Tangent Line',
167
+ line=dict(dash='dash', color='blue') # Tangent line in blue
168
+ ))
169
+
170
+ # Layout adjustments
171
  fig.update_layout(
172
  title="Gradient Descent Progress",
173
  xaxis_title="x",
 
175
  template="plotly_white",
176
  height=600
177
  )
178
+
179
+ st.plotly_chart(fig, use_container_width=True)