trohith89 commited on
Commit
afba365
Β·
verified Β·
1 Parent(s): 703a15c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -108
app.py CHANGED
@@ -4,6 +4,7 @@ import plotly.graph_objects as go
4
 
5
  # Safe function evaluation
6
  def safe_eval(func_str, x_val):
 
7
  allowed_names = {"x": x_val, "np": np}
8
  try:
9
  return eval(func_str, {"__builtins__": None}, allowed_names)
@@ -12,22 +13,24 @@ def safe_eval(func_str, x_val):
12
 
13
  # Function derivative using finite difference method
14
  def derivative(func_str, x_val, h=1e-5):
 
15
  return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
16
 
17
  # Tangent line equation
18
  def tangent_line(func_str, x_val, x_range):
 
19
  y_val = safe_eval(func_str, x_val)
20
  slope = derivative(func_str, x_val)
21
  return slope * (x_range - x_val) + y_val
22
 
23
- # Reset session state
24
  def reset_state():
25
  st.session_state.x = st.session_state.starting_point
26
  st.session_state.iteration = 0
27
  st.session_state.x_vals = [st.session_state.starting_point]
28
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
29
 
30
- # Initialize session state
31
  if "func_input" not in st.session_state:
32
  st.session_state.func_input = "x**2 + x"
33
  if "x" not in st.session_state:
@@ -36,117 +39,172 @@ if "x" not in st.session_state:
36
  st.session_state.x_vals = [4.0]
37
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
38
 
 
39
  st.set_page_config(layout="wide")
40
 
41
- # CSS for borders and font
42
  st.markdown(
43
  """
44
  <style>
45
  * {
46
  font-family: Cambria, Arial, sans-serif !important;
47
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  .stPlotlyChart {
49
- border: 5px solid #001A6E; /* Plot border */
50
- border-radius: 10px;
51
- padding: 5px;
52
  }
53
- .info-box {
54
- display: inline-block;
55
- background-color:#FFB6C1;
56
- border: 2px solid black;
57
- border-radius: 8px;
58
- padding: 10px 15px;
59
- margin: 5px 10px;
60
- text-align: center;
61
- font-size: 18px;
62
- color: black;
63
  }
64
- .info-container {
65
- display: flex;
66
- justify-content: flex-end;
67
- margin-top: 20px;
 
68
  }
69
- .center-title {
70
- text-align: center;
 
 
 
71
  }
72
- .button-style {
73
- background-color: black;
74
- color: #FF1493;
75
- border: 2px solid #FF1493;
76
- border-radius: 8px;
77
- padding: 10px 20px;
78
- font-size: 18px;
 
 
 
 
 
 
 
 
79
  }
80
- .button-style:active, .button-style:hover {
81
- background-color: black;
82
- color: #FF1493;
83
  }
84
  </style>
85
  """,
86
  unsafe_allow_html=True,
87
  )
88
 
89
- st.markdown("""<h1 class="center-title">🌟 Gradient Descent Interactive Tool 🌟</h1>""", unsafe_allow_html=True)
 
90
 
91
  col1, col2 = st.columns([1, 2])
92
 
93
- # Left Section
94
  with col1:
95
  st.subheader("πŸ”§ Define Your Function")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  func_input = st.text_input(
97
- "Enter a function of x (e.g., x**2 + x):",
98
- key="func_input",
99
  on_change=reset_state
100
  )
 
 
101
  starting_point = st.number_input(
102
- "Starting Point (Xβ‚€):",
103
- value=4.0,
104
- step=0.1,
105
- key="starting_point",
 
106
  on_change=reset_state
107
  )
108
  learning_rate = st.number_input(
109
- "Learning Rate (Ε‹):",
110
- value=0.25,
111
- step=0.01,
112
- key="learning_rate"
 
 
113
  )
114
- if st.button("πŸ› οΈ Setup", key="setup_button", help="Click to reset the function and starting point",
115
- use_container_width=True,
116
- on_click=reset_state,
117
- args=(),
118
- kwargs={}):
119
- pass
120
- if st.button("πŸ”„ Next Iteration", key="next_iteration_button", help="Click to perform the next gradient descent iteration",
121
- use_container_width=True):
122
- try:
123
- grad = derivative(st.session_state.func_input, st.session_state.x)
124
- st.session_state.x = st.session_state.x - learning_rate * grad
125
- st.session_state.iteration += 1
126
- st.session_state.x_vals.append(st.session_state.x)
127
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
128
- except Exception as e:
129
- st.error(f"⚠️ Error: {str(e)}")
130
-
131
- # Right Section - Visualization
132
  with col2:
133
- st.subheader("πŸ“Š Visualization")
134
  try:
 
135
  x_plot = np.linspace(-10, 10, 400)
136
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
137
 
138
  fig = go.Figure()
139
 
140
- # Function plot
141
  fig.add_trace(go.Scatter(
142
- x=x_plot,
143
- y=y_plot,
144
- mode="lines",
145
- line=dict(color="blue", width=2),
 
146
  name="Function"
147
  ))
148
 
149
- # Gradient descent points
150
  fig.add_trace(go.Scatter(
151
  x=st.session_state.x_vals,
152
  y=st.session_state.y_vals,
@@ -155,9 +213,9 @@ with col2:
155
  name="Gradient Descent Points"
156
  ))
157
 
158
- # Tangent line
159
  current_x = st.session_state.x
160
- tangent_x = np.linspace(-10, 10, 200)
161
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
162
  fig.add_trace(go.Scatter(
163
  x=tangent_x,
@@ -167,58 +225,56 @@ with col2:
167
  name="Tangent Line"
168
  ))
169
 
170
- # Plot layout
171
  fig.update_layout(
172
  xaxis=dict(
173
  title="x-axis",
174
- zeroline=True,
175
- zerolinecolor="white",
176
- zerolinewidth=2,
177
- showgrid=True,
178
- gridcolor="lightgray",
179
- color="white"
180
  ),
181
  yaxis=dict(
182
  title="y-axis",
183
- zeroline=True,
184
- zerolinecolor="white",
185
- zerolinewidth=2,
186
- showgrid=True,
187
- gridcolor="lightgray",
188
- range=[-120, 120], # Adjusted y-axis range
189
- color="white"
190
  ),
191
  plot_bgcolor="black",
192
  paper_bgcolor="black",
193
- font=dict(color="white"),
194
- legend=dict(
195
- x=0.6,
196
- y=1.0,
197
- bgcolor="black",
198
- bordercolor="#001A6E",
199
- borderwidth=2
200
- ),
201
- margin=dict(l=10, r=100, t=10, b=10), # Increase right margin here
202
  width=800,
203
  height=400,
204
- showlegend=True
 
 
 
 
 
 
 
 
 
 
 
205
  )
206
 
 
 
 
 
207
  st.plotly_chart(fig, use_container_width=True)
208
 
209
  except Exception as e:
210
  st.error(f"⚠️ Error in visualization: {str(e)}")
211
 
212
- st.markdown(
213
- f"""
214
- <div class="info-container">
215
- <div class="info-box">
216
- πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}
217
- </div>
218
- <div class="info-box">
219
- πŸ“ Current Point: ({st.session_state.x:.4f}, {st.session_state.y_vals[-1]:.4f})
220
- </div>
221
- </div>
222
- """,
223
- unsafe_allow_html=True,
224
- )
 
4
 
5
  # Safe function evaluation
6
  def safe_eval(func_str, x_val):
7
+ """ Safely evaluates the function at a given x value. """
8
  allowed_names = {"x": x_val, "np": np}
9
  try:
10
  return eval(func_str, {"__builtins__": None}, allowed_names)
 
13
 
14
  # Function derivative using finite difference method
15
  def derivative(func_str, x_val, h=1e-5):
16
+ """ Numerically compute the derivative of the function at x using finite differences. """
17
  return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
18
 
19
  # Tangent line equation
20
  def tangent_line(func_str, x_val, x_range):
21
+ """ Compute the tangent line at a given x value. """
22
  y_val = safe_eval(func_str, x_val)
23
  slope = derivative(func_str, x_val)
24
  return slope * (x_range - x_val) + y_val
25
 
26
+ # Callback to reset session state
27
  def reset_state():
28
  st.session_state.x = st.session_state.starting_point
29
  st.session_state.iteration = 0
30
  st.session_state.x_vals = [st.session_state.starting_point]
31
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
32
 
33
+ # Initialize session state variables
34
  if "func_input" not in st.session_state:
35
  st.session_state.func_input = "x**2 + x"
36
  if "x" not in st.session_state:
 
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
 
42
+ # Full-width layout
43
  st.set_page_config(layout="wide")
44
 
45
+ # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
  st.markdown(
47
  """
48
  <style>
49
  * {
50
  font-family: Cambria, Arial, sans-serif !important;
51
  }
52
+ h1, h2, h3, h4, h5 {
53
+ text-align: center;
54
+ margin-top: 0;
55
+ }
56
+ input, .stButton button, .stDownloadButton button {
57
+ border: 2px solid #ea445a;
58
+ border-radius: 5px;
59
+ padding: 10px;
60
+ }
61
+ .stInfo, .stSuccess {
62
+ border: 2px solid #ea445a;
63
+ border-radius: 5px;
64
+ padding: 10px;
65
+ }
66
+ .stButton {
67
+ margin-top: 10px;
68
+ }
69
+ /* Reduced Padding at the top */
70
+ .css-1d391kg {
71
+ padding-top: 0.5rem;
72
+ }
73
+ /* Centering the legend in the plot */
74
  .stPlotlyChart {
75
+ display: block;
76
+ margin: 0 auto;
 
77
  }
78
+ /* Adjusting for full width without scrolling */
79
+ .css-1lcbvhc {
80
+ padding-left: 0;
81
+ padding-right: 0;
 
 
 
 
 
 
82
  }
83
+ /* Custom borders for input fields */
84
+ .stTextInput input, .stNumberInput input {
85
+ border: 2px solid #001A6E;
86
+ border-radius: 5px;
87
+ padding: 10px;
88
  }
89
+ /* Tooltip styling */
90
+ .tooltip {
91
+ position: relative;
92
+ display: inline-block;
93
+ cursor: pointer;
94
  }
95
+ .tooltip .tooltiptext {
96
+ visibility: hidden;
97
+ opacity: 0;
98
+ width: 300px;
99
+ background-color: #001A6E;
100
+ color: #fff;
101
+ text-align: center;
102
+ border-radius: 5px;
103
+ padding: 5px;
104
+ position: absolute;
105
+ z-index: 1;
106
+ bottom: 125%; /* Position the tooltip above */
107
+ left: 50%;
108
+ margin-left: -150px;
109
+ transition: opacity 0.3s;
110
  }
111
+ .tooltip:hover .tooltiptext {
112
+ visibility: visible;
113
+ opacity: 1;
114
  }
115
  </style>
116
  """,
117
  unsafe_allow_html=True,
118
  )
119
 
120
+ # Page Layout
121
+ st.title("🌟 Gradient Descent Interactive Tool 🌟")
122
 
123
  col1, col2 = st.columns([1, 2])
124
 
125
+ # Left Section: User Input
126
  with col1:
127
  st.subheader("πŸ”§ Define Your Function")
128
+
129
+ # Tooltip with instructions when hovering over the function input label
130
+ st.markdown(
131
+ """
132
+ <div class="tooltip">
133
+ <label for="func_input">Enter a function of 'x':</label>
134
+ <span class="tooltiptext">
135
+ **How to input your function:**
136
+ - Please give the inputs as mentioned below
137
+ - x^n as x**n,
138
+ - sin(x) as np.sin(x)
139
+ - log(x) as np.log(x),
140
+ - e^x or exp(x) as np.exp(x).
141
+ </span>
142
+ </div>
143
+ """,
144
+ unsafe_allow_html=True
145
+ )
146
+
147
+ # Use text input for the user to define a function, but no value argument
148
  func_input = st.text_input(
149
+ "πŸ‘‡",
150
+ key="func_input",
151
  on_change=reset_state
152
  )
153
+
154
+ st.subheader("βš™οΈ Gradient Descent Parameters")
155
  starting_point = st.number_input(
156
+ "Starting Point (Xβ‚€)",
157
+ value=4.0,
158
+ step=0.1,
159
+ format="%.2f",
160
+ key="starting_point",
161
  on_change=reset_state
162
  )
163
  learning_rate = st.number_input(
164
+ "Learning Rate (Ε‹)",
165
+ value=0.25,
166
+ step=0.01,
167
+ format="%.2f",
168
+ key="learning_rate",
169
+ on_change=reset_state
170
  )
171
+
172
+ col3, col4 = st.columns(2)
173
+ with col3:
174
+ if st.button("πŸ”„ Set Up Function"):
175
+ reset_state()
176
+ with col4:
177
+ if st.button("▢️ Next Iteration"):
178
+ try:
179
+ grad = derivative(st.session_state.func_input, st.session_state.x)
180
+ st.session_state.x = st.session_state.x - learning_rate * grad
181
+ st.session_state.iteration += 1
182
+ st.session_state.x_vals.append(st.session_state.x)
183
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
184
+ except Exception as e:
185
+ st.error(f"⚠️ Error: {str(e)}")
186
+
187
+ # Right Section: Visualization
 
188
  with col2:
189
+ st.subheader("πŸ“Š Gradient Descent Visualization")
190
  try:
191
+ # Plot the function and all current and previous gradient descent points
192
  x_plot = np.linspace(-10, 10, 400)
193
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
194
 
195
  fig = go.Figure()
196
 
197
+ # Function curve
198
  fig.add_trace(go.Scatter(
199
+ x=x_plot,
200
+ y=y_plot,
201
+ mode="lines+markers",
202
+ line=dict(color="blue", width=2),
203
+ marker=dict(size=4, color="blue", symbol="circle"),
204
  name="Function"
205
  ))
206
 
207
+ # All gradient descent points (red points without coordinates)
208
  fig.add_trace(go.Scatter(
209
  x=st.session_state.x_vals,
210
  y=st.session_state.y_vals,
 
213
  name="Gradient Descent Points"
214
  ))
215
 
216
+ # Tangent line at the current gradient descent point
217
  current_x = st.session_state.x
218
+ tangent_x = np.linspace(-10, 10, 200) # Adjusting range to cover entire plot
219
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
220
  fig.add_trace(go.Scatter(
221
  x=tangent_x,
 
225
  name="Tangent Line"
226
  ))
227
 
228
+ # Dynamic zoom for better visibility
229
  fig.update_layout(
230
  xaxis=dict(
231
  title="x-axis",
232
+ range=[-10, 10],
233
+ showline=True,
234
+ linecolor="white",
235
+ tickcolor="white",
236
+ tickfont=dict(color="white"),
237
+ ticks="outside",
238
  ),
239
  yaxis=dict(
240
  title="y-axis",
241
+ range=[min(y_plot) - 5, min(max(y_plot) + 5, 1000)], # Limiting the max y to 1000
242
+ showline=True,
243
+ linecolor="white",
244
+ tickcolor="white",
245
+ tickfont=dict(color="white"),
246
+ ticks="outside",
 
247
  ),
248
  plot_bgcolor="black",
249
  paper_bgcolor="black",
250
+ title="",
251
+ margin=dict(l=10, r=10, t=10, b=10),
 
 
 
 
 
 
 
252
  width=800,
253
  height=400,
254
+ showlegend=True,
255
+ legend=dict(
256
+ x=1.1,
257
+ y=0.5,
258
+ xanchor="left",
259
+ yanchor="middle",
260
+ orientation="v",
261
+ font=dict(size=12, color="white"),
262
+ bgcolor="black",
263
+ bordercolor="white",
264
+ borderwidth=2,
265
+ )
266
  )
267
 
268
+ # Axis lines for quadrants
269
+ fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
270
+ fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
271
+
272
  st.plotly_chart(fig, use_container_width=True)
273
 
274
  except Exception as e:
275
  st.error(f"⚠️ Error in visualization: {str(e)}")
276
 
277
+ # Iteration stats and download
278
+ col5, col6 = st.columns(2)
279
+ col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
280
+ col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")