trohith89 commited on
Commit
d90f498
Β·
verified Β·
1 Parent(s): ecb0b93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -107
app.py CHANGED
@@ -4,7 +4,6 @@ import plotly.graph_objects as go
4
 
5
  # Safe function evaluation
6
  def safe_eval(func_str, x_val):
7
- """ Safely evaluates the function at a given x value. """
8
  allowed_names = {"x": x_val, "np": np}
9
  try:
10
  return eval(func_str, {"__builtins__": None}, allowed_names)
@@ -13,24 +12,22 @@ def safe_eval(func_str, x_val):
13
 
14
  # Function derivative using finite difference method
15
  def derivative(func_str, x_val, h=1e-5):
16
- """ Numerically compute the derivative of the function at x using finite differences. """
17
  return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
18
 
19
  # Tangent line equation
20
  def tangent_line(func_str, x_val, x_range):
21
- """ Compute the tangent line at a given x value. """
22
  y_val = safe_eval(func_str, x_val)
23
  slope = derivative(func_str, x_val)
24
  return slope * (x_range - x_val) + y_val
25
 
26
- # Callback to reset session state
27
  def reset_state():
28
  st.session_state.x = st.session_state.starting_point
29
  st.session_state.iteration = 0
30
  st.session_state.x_vals = [st.session_state.starting_point]
31
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
32
 
33
- # Initialize session state variables
34
  if "func_input" not in st.session_state:
35
  st.session_state.func_input = "x**2 + x"
36
  if "x" not in st.session_state:
@@ -39,134 +36,72 @@ if "x" not in st.session_state:
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
 
42
- # Full-width layout
43
  st.set_page_config(layout="wide")
44
 
45
- # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
  st.markdown(
47
  """
48
  <style>
49
  * {
50
  font-family: Cambria, Arial, sans-serif !important;
51
  }
52
- h1, h2, h3, h4, h5 {
53
- text-align: center;
54
- margin-top: 0;
55
- }
56
- input, .stButton button, .stDownloadButton button {
57
- border: 2px solid #ea445a;
58
- border-radius: 5px;
59
- padding: 10px;
60
- }
61
- .stInfo, .stSuccess {
62
- border: 2px solid #ea445a;
63
- border-radius: 5px;
64
- padding: 10px;
65
- }
66
- .stButton {
67
- margin-top: 10px;
68
- }
69
- /* Reduced Padding at the top */
70
- .css-1d391kg {
71
- padding-top: 0.5rem;
72
- }
73
- /* Centering the legend in the plot */
74
  .stPlotlyChart {
75
- display: block;
76
- margin: 0 auto;
77
- border: 5px solid #001A6E; /* Border color for the plot */
78
- border-radius: 10px; /* Rounded corners for the border */
79
  padding: 5px;
80
  }
81
- /* Adjusting for full width without scrolling */
82
- .css-1lcbvhc {
83
- padding-left: 0;
84
- padding-right: 0;
85
- }
86
- /* Custom borders for input fields */
87
- .stTextInput input, .stNumberInput input {
88
- border: 2px solid #001A6E;
89
- border-radius: 5px;
90
- padding: 10px;
91
- }
92
  </style>
93
  """,
94
  unsafe_allow_html=True,
95
  )
96
 
97
- # Page Layout
98
  st.title("🌟 Gradient Descent Interactive Tool 🌟")
99
 
100
  col1, col2 = st.columns([1, 2])
101
 
102
- # Left Section: User Input
103
  with col1:
104
  st.subheader("πŸ”§ Define Your Function")
105
-
106
- st.markdown(
107
- """
108
- <div class="tooltip">
109
- <label for="func_input">Enter a function of 'x':</label>
110
- <span class="tooltiptext">
111
- **How to input your function:**
112
- - x^n as x**n,
113
- - sin(x) as np.sin(x),
114
- - log(x) as np.log(x),
115
- - e^x or exp(x) as np.exp(x).
116
- </span>
117
- </div>
118
- """,
119
- unsafe_allow_html=True
120
- )
121
-
122
  func_input = st.text_input(
123
- "πŸ‘‡",
124
- key="func_input",
125
  on_change=reset_state
126
  )
127
-
128
- st.subheader("βš™οΈ Gradient Descent Parameters")
129
  starting_point = st.number_input(
130
- "Starting Point (Xβ‚€)",
131
- value=4.0,
132
- step=0.1,
133
- format="%.2f",
134
- key="starting_point",
135
  on_change=reset_state
136
  )
137
  learning_rate = st.number_input(
138
- "Learning Rate (Ε‹)",
139
- value=0.25,
140
- step=0.01,
141
- format="%.2f",
142
  key="learning_rate"
143
  )
144
-
145
- col3, col4 = st.columns(2)
146
- with col3:
147
- if st.button("πŸ”„ Set Up Function"):
148
- reset_state()
149
- with col4:
150
- if st.button("▢️ Next Iteration"):
151
- try:
152
- grad = derivative(st.session_state.func_input, st.session_state.x)
153
- st.session_state.x = st.session_state.x - learning_rate * grad
154
- st.session_state.iteration += 1
155
- st.session_state.x_vals.append(st.session_state.x)
156
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
157
- except Exception as e:
158
- st.error(f"⚠️ Error: {str(e)}")
159
-
160
- # Right Section: Visualization
161
  with col2:
162
- st.subheader("πŸ“Š Gradient Descent Visualization")
163
  try:
164
  x_plot = np.linspace(-10, 10, 400)
165
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
166
 
167
  fig = go.Figure()
168
 
169
- # Function curve
170
  fig.add_trace(go.Scatter(
171
  x=x_plot,
172
  y=y_plot,
@@ -184,7 +119,7 @@ with col2:
184
  name="Gradient Descent Points"
185
  ))
186
 
187
- # Tangent line at the current point
188
  current_x = st.session_state.x
189
  tangent_x = np.linspace(-10, 10, 200)
190
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
@@ -196,7 +131,7 @@ with col2:
196
  name="Tangent Line"
197
  ))
198
 
199
- # Update layout to include full quadrants and reposition the legend
200
  fig.update_layout(
201
  xaxis=dict(
202
  title="x-axis",
@@ -205,7 +140,6 @@ with col2:
205
  zerolinewidth=2,
206
  showgrid=True,
207
  gridcolor="lightgray",
208
- range=[-10, 10], # Adjust x-axis range to show all quadrants
209
  color="white"
210
  ),
211
  yaxis=dict(
@@ -215,34 +149,32 @@ with col2:
215
  zerolinewidth=2,
216
  showgrid=True,
217
  gridcolor="lightgray",
218
- range=[-100, 100], # Adjust y-axis range to show all quadrants
219
  color="white"
220
  ),
221
  plot_bgcolor="black",
222
  paper_bgcolor="black",
223
  font=dict(color="white"),
224
  legend=dict(
225
- x=0.5, # Center the legend horizontally
226
- y=1.15, # Position the legend above the plot
227
- xanchor="center", # Align legend horizontally by its center
228
- yanchor="bottom", # Align legend vertically by its bottom
229
  bgcolor="black",
230
  bordercolor="#001A6E",
231
  borderwidth=2
232
  ),
233
- margin=dict(l=10, r=10, t=40, b=10), # Adjust margins to accommodate legend at the top
234
  width=800,
235
  height=400,
236
  showlegend=True
237
  )
238
 
239
-
240
  st.plotly_chart(fig, use_container_width=True)
241
 
242
  except Exception as e:
243
  st.error(f"⚠️ Error in visualization: {str(e)}")
244
 
245
- col5, col6 = st.columns(2)
246
- col5.info(f"πŸ§‘β€πŸ’» Iteration {st.session_state.iteration}")
 
247
  col6.success(f"βœ… Current x: {st.session_state.x:.4f}")
248
  col7.warning(f"πŸ“ Current Point: ({st.session_state.x:.4f}, {st.session_state.y_vals[-1]:.4f})")
 
4
 
5
  # Safe function evaluation
6
  def safe_eval(func_str, x_val):
 
7
  allowed_names = {"x": x_val, "np": np}
8
  try:
9
  return eval(func_str, {"__builtins__": None}, allowed_names)
 
12
 
13
  # Function derivative using finite difference method
14
  def derivative(func_str, x_val, h=1e-5):
 
15
  return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
16
 
17
  # Tangent line equation
18
  def tangent_line(func_str, x_val, x_range):
 
19
  y_val = safe_eval(func_str, x_val)
20
  slope = derivative(func_str, x_val)
21
  return slope * (x_range - x_val) + y_val
22
 
23
+ # Reset session state
24
  def reset_state():
25
  st.session_state.x = st.session_state.starting_point
26
  st.session_state.iteration = 0
27
  st.session_state.x_vals = [st.session_state.starting_point]
28
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
29
 
30
+ # Initialize session state
31
  if "func_input" not in st.session_state:
32
  st.session_state.func_input = "x**2 + x"
33
  if "x" not in st.session_state:
 
36
  st.session_state.x_vals = [4.0]
37
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
38
 
 
39
  st.set_page_config(layout="wide")
40
 
41
+ # CSS for borders and font
42
  st.markdown(
43
  """
44
  <style>
45
  * {
46
  font-family: Cambria, Arial, sans-serif !important;
47
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  .stPlotlyChart {
49
+ border: 5px solid #001A6E; /* Plot border */
50
+ border-radius: 10px;
 
 
51
  padding: 5px;
52
  }
 
 
 
 
 
 
 
 
 
 
 
53
  </style>
54
  """,
55
  unsafe_allow_html=True,
56
  )
57
 
 
58
  st.title("🌟 Gradient Descent Interactive Tool 🌟")
59
 
60
  col1, col2 = st.columns([1, 2])
61
 
62
+ # Left Section
63
  with col1:
64
  st.subheader("πŸ”§ Define Your Function")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  func_input = st.text_input(
66
+ "Enter a function of x (e.g., x**2 + x):",
67
+ key="func_input",
68
  on_change=reset_state
69
  )
 
 
70
  starting_point = st.number_input(
71
+ "Starting Point (Xβ‚€):",
72
+ value=4.0,
73
+ step=0.1,
74
+ key="starting_point",
 
75
  on_change=reset_state
76
  )
77
  learning_rate = st.number_input(
78
+ "Learning Rate (Ε‹):",
79
+ value=0.25,
80
+ step=0.01,
 
81
  key="learning_rate"
82
  )
83
+ if st.button("Reset"):
84
+ reset_state()
85
+ if st.button("Next Iteration"):
86
+ try:
87
+ grad = derivative(st.session_state.func_input, st.session_state.x)
88
+ st.session_state.x = st.session_state.x - learning_rate * grad
89
+ st.session_state.iteration += 1
90
+ st.session_state.x_vals.append(st.session_state.x)
91
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
92
+ except Exception as e:
93
+ st.error(f"⚠️ Error: {str(e)}")
94
+
95
+ # Right Section - Visualization
 
 
 
 
96
  with col2:
97
+ st.subheader("πŸ“Š Visualization")
98
  try:
99
  x_plot = np.linspace(-10, 10, 400)
100
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
101
 
102
  fig = go.Figure()
103
 
104
+ # Function plot
105
  fig.add_trace(go.Scatter(
106
  x=x_plot,
107
  y=y_plot,
 
119
  name="Gradient Descent Points"
120
  ))
121
 
122
+ # Tangent line
123
  current_x = st.session_state.x
124
  tangent_x = np.linspace(-10, 10, 200)
125
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
 
131
  name="Tangent Line"
132
  ))
133
 
134
+ # Plot layout
135
  fig.update_layout(
136
  xaxis=dict(
137
  title="x-axis",
 
140
  zerolinewidth=2,
141
  showgrid=True,
142
  gridcolor="lightgray",
 
143
  color="white"
144
  ),
145
  yaxis=dict(
 
149
  zerolinewidth=2,
150
  showgrid=True,
151
  gridcolor="lightgray",
152
+ range=[0, max(y_plot) + 10], # Show non-negative y-axis only
153
  color="white"
154
  ),
155
  plot_bgcolor="black",
156
  paper_bgcolor="black",
157
  font=dict(color="white"),
158
  legend=dict(
159
+ x=0.6, # Legend slightly left for border visibility
160
+ y=1.0,
 
 
161
  bgcolor="black",
162
  bordercolor="#001A6E",
163
  borderwidth=2
164
  ),
165
+ margin=dict(l=10, r=80, t=10, b=10), # Expand right border
166
  width=800,
167
  height=400,
168
  showlegend=True
169
  )
170
 
 
171
  st.plotly_chart(fig, use_container_width=True)
172
 
173
  except Exception as e:
174
  st.error(f"⚠️ Error in visualization: {str(e)}")
175
 
176
+ # Display iteration and current point info
177
+ col5, col6, col7 = st.columns(3)
178
+ col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
179
  col6.success(f"βœ… Current x: {st.session_state.x:.4f}")
180
  col7.warning(f"πŸ“ Current Point: ({st.session_state.x:.4f}, {st.session_state.y_vals[-1]:.4f})")