trohith89 commited on
Commit
8e83ddf
·
verified ·
1 Parent(s): 2d0a6c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -280
app.py CHANGED
@@ -1,280 +0,0 @@
1
- import streamlit as st
2
- import numpy as np
3
- import plotly.graph_objects as go
4
-
5
- # Safe function evaluation
6
- def safe_eval(func_str, x_val):
7
- """ Safely evaluates the function at a given x value. """
8
- allowed_names = {"x": x_val, "np": np}
9
- try:
10
- return eval(func_str, {"__builtins__": None}, allowed_names)
11
- except Exception as e:
12
- raise ValueError(f"Error evaluating the function: {e}")
13
-
14
- # Function derivative using finite difference method
15
- def derivative(func_str, x_val, h=1e-5):
16
- """ Numerically compute the derivative of the function at x using finite differences. """
17
- return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
18
-
19
- # Tangent line equation
20
- def tangent_line(func_str, x_val, x_range):
21
- """ Compute the tangent line at a given x value. """
22
- y_val = safe_eval(func_str, x_val)
23
- slope = derivative(func_str, x_val)
24
- return slope * (x_range - x_val) + y_val
25
-
26
- # Callback to reset session state
27
- def reset_state():
28
- st.session_state.x = st.session_state.starting_point
29
- st.session_state.iteration = 0
30
- st.session_state.x_vals = [st.session_state.starting_point]
31
- st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
32
-
33
- # Initialize session state variables
34
- if "func_input" not in st.session_state:
35
- st.session_state.func_input = "x**2 + x"
36
- if "x" not in st.session_state:
37
- st.session_state.x = 4.0
38
- st.session_state.iteration = 0
39
- st.session_state.x_vals = [4.0]
40
- st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
-
42
- # Full-width layout
43
- st.set_page_config(layout="wide")
44
-
45
- # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
- st.markdown(
47
- """
48
- <style>
49
- * {
50
- font-family: Cambria, Arial, sans-serif !important;
51
- }
52
- h1, h2, h3, h4, h5 {
53
- text-align: center;
54
- margin-top: 0;
55
- }
56
- input, .stButton button, .stDownloadButton button {
57
- border: 2px solid #ea445a;
58
- border-radius: 5px;
59
- padding: 10px;
60
- }
61
- .stInfo, .stSuccess {
62
- border: 2px solid #ea445a;
63
- border-radius: 5px;
64
- padding: 10px;
65
- }
66
- .stButton {
67
- margin-top: 10px;
68
- }
69
- /* Reduced Padding at the top */
70
- .css-1d391kg {
71
- padding-top: 0.5rem;
72
- }
73
- /* Centering the legend in the plot */
74
- .stPlotlyChart {
75
- display: block;
76
- margin: 0 auto;
77
- }
78
- /* Adjusting for full width without scrolling */
79
- .css-1lcbvhc {
80
- padding-left: 0;
81
- padding-right: 0;
82
- }
83
- /* Custom borders for input fields */
84
- .stTextInput input, .stNumberInput input {
85
- border: 2px solid #001A6E;
86
- border-radius: 5px;
87
- padding: 10px;
88
- }
89
- /* Hoverable tooltip styling */
90
- .tooltip {
91
- position: relative;
92
- display: inline-block;
93
- }
94
- .tooltip .tooltiptext {
95
- visibility: hidden;
96
- opacity: 0;
97
- width: 300px;
98
- background-color: #001A6E;
99
- color: #fff;
100
- text-align: center;
101
- border-radius: 5px;
102
- padding: 5px;
103
- position: absolute;
104
- z-index: 1;
105
- bottom: 125%; /* Position the tooltip above */
106
- left: 50%;
107
- margin-left: -150px;
108
- transition: opacity 0.3s;
109
- }
110
- .tooltip:hover .tooltiptext {
111
- visibility: visible;
112
- opacity: 1;
113
- }
114
- </style>
115
- """,
116
- unsafe_allow_html=True,
117
- )
118
-
119
- # Page Layout
120
- st.title("🌟 Gradient Descent Interactive Tool 🌟")
121
-
122
- col1, col2 = st.columns([1, 2])
123
-
124
- # Left Section: User Input
125
- with col1:
126
- st.subheader("🔧 Define Your Function")
127
-
128
- # Tooltip with instructions when hovering over the function input
129
- st.markdown(
130
- """
131
- <div class="tooltip">
132
- <input type="text" readonly style="border: none; color: transparent; background-color: transparent;">
133
- <span class="tooltiptext">
134
- **How to input your function:**
135
- - Use `x` for the variable.
136
- - Example: For \(x^2\), input `x**2`.
137
- - For sine function, input `np.sin(x)`.
138
- - For logarithm, input `np.log(x)`.
139
- - For other mathematical operations, use `np` (e.g., `np.exp(x)`).
140
- </span>
141
- </div>
142
- """,
143
- unsafe_allow_html=True
144
- )
145
-
146
- # Use text input for the user to define a function
147
- func_input = st.text_input(
148
- "Enter a function of 'x':",
149
- st.session_state.func_input,
150
- key="func_input",
151
- on_change=reset_state
152
- )
153
-
154
- st.subheader("⚙️ Gradient Descent Parameters")
155
- starting_point = st.number_input(
156
- "Starting Point (X₀)",
157
- value=4.0,
158
- step=0.1,
159
- format="%.2f",
160
- key="starting_point",
161
- on_change=reset_state
162
- )
163
- learning_rate = st.number_input(
164
- "Learning Rate (ŋ)",
165
- value=0.25,
166
- step=0.01,
167
- format="%.2f",
168
- key="learning_rate",
169
- on_change=reset_state
170
- )
171
-
172
- col3, col4 = st.columns(2)
173
- with col3:
174
- if st.button("🔄 Set Up Function"):
175
- reset_state()
176
- with col4:
177
- if st.button("▶️ Next Iteration"):
178
- try:
179
- grad = derivative(st.session_state.func_input, st.session_state.x)
180
- st.session_state.x = st.session_state.x - learning_rate * grad
181
- st.session_state.iteration += 1
182
- st.session_state.x_vals.append(st.session_state.x)
183
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
184
- except Exception as e:
185
- st.error(f"⚠️ Error: {str(e)}")
186
-
187
- # Right Section: Visualization
188
- with col2:
189
- st.subheader("📊 Gradient Descent Visualization")
190
- try:
191
- # Plot the function and all current and previous gradient descent points
192
- x_plot = np.linspace(-10, 10, 400)
193
- y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
194
-
195
- fig = go.Figure()
196
-
197
- # Function curve
198
- fig.add_trace(go.Scatter(
199
- x=x_plot,
200
- y=y_plot,
201
- mode="lines+markers",
202
- line=dict(color="blue", width=2),
203
- marker=dict(size=4, color="blue", symbol="circle"),
204
- name="Function"
205
- ))
206
-
207
- # All gradient descent points (red points without coordinates)
208
- fig.add_trace(go.Scatter(
209
- x=st.session_state.x_vals,
210
- y=st.session_state.y_vals,
211
- mode="markers",
212
- marker=dict(color="red", size=10),
213
- name="Gradient Descent Points"
214
- ))
215
-
216
- # Tangent line at the current gradient descent point
217
- current_x = st.session_state.x
218
- tangent_x = np.linspace(current_x - 5, current_x + 5, 200) # Extended range for tangent line
219
- tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
220
- fig.add_trace(go.Scatter(
221
- x=tangent_x,
222
- y=tangent_y,
223
- mode="lines",
224
- line=dict(color="orange", width=3),
225
- name="Tangent Line"
226
- ))
227
-
228
- # Dynamic zoom for better visibility
229
- fig.update_layout(
230
- xaxis=dict(
231
- title="x-axis",
232
- range=[-10, 10],
233
- showline=True,
234
- linecolor="white",
235
- tickcolor="white",
236
- tickfont=dict(color="white"),
237
- ticks="outside",
238
- ),
239
- yaxis=dict(
240
- title="y-axis",
241
- range=[min(y_plot) - 5, max(y_plot) + 5],
242
- showline=True,
243
- linecolor="white",
244
- tickcolor="white",
245
- tickfont=dict(color="white"),
246
- ticks="outside",
247
- ),
248
- plot_bgcolor="black",
249
- paper_bgcolor="black",
250
- title="",
251
- margin=dict(l=10, r=10, t=10, b=10),
252
- width=800,
253
- height=400,
254
- showlegend=True,
255
- legend=dict(
256
- x=1.1,
257
- y=0.5,
258
- xanchor="left",
259
- yanchor="middle",
260
- orientation="v",
261
- font=dict(size=12, color="white"),
262
- bgcolor="black",
263
- bordercolor="white",
264
- borderwidth=2,
265
- )
266
- )
267
-
268
- # Axis lines for quadrants
269
- fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
270
- fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
271
-
272
- st.plotly_chart(fig, use_container_width=True)
273
-
274
- except Exception as e:
275
- st.error(f"⚠️ Error in visualization: {str(e)}")
276
-
277
- # Iteration stats and download
278
- col5, col6 = st.columns(2)
279
- col5.info(f"🧑‍💻 Iteration: {st.session_state.iteration}")
280
- col6.success(f"✅ Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")