trohith89 commited on
Commit
cfecba1
Β·
verified Β·
1 Parent(s): 4c2d271

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -224
app.py CHANGED
@@ -39,245 +39,242 @@ if "x" not in st.session_state:
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
 
42
- # Check for query parameters to navigate from home.py
43
- if st.experimental_get_query_params().get("page") == ["gradient-descent"]:
44
- # Full-width layout
45
- st.set_page_config(layout="wide")
46
 
47
- # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
48
- st.markdown(
49
- """
50
- <style>
51
- * {
52
- font-family: Cambria, Arial, sans-serif !important;
53
- }
54
- h1, h2, h3, h4, h5 {
55
- text-align: center;
56
- margin-top: 0;
57
- }
58
- input, .stButton button, .stDownloadButton button {
59
- border: 2px solid #ea445a;
60
- border-radius: 5px;
61
- padding: 10px;
62
- }
63
- .stInfo, .stSuccess {
64
- border: 2px solid #ea445a;
65
- border-radius: 5px;
66
- padding: 10px;
67
- }
68
- .stButton {
69
- margin-top: 10px;
70
- }
71
- /* Reduced Padding at the top */
72
- .css-1d391kg {
73
- padding-top: 0.5rem;
74
- }
75
- /* Centering the legend in the plot */
76
- .stPlotlyChart {
77
- display: block;
78
- margin: 0 auto;
79
- }
80
- /* Adjusting for full width without scrolling */
81
- .css-1lcbvhc {
82
- padding-left: 0;
83
- padding-right: 0;
84
- }
85
- /* Custom borders for input fields */
86
- .stTextInput input, .stNumberInput input {
87
- border: 2px solid #001A6E;
88
- border-radius: 5px;
89
- padding: 10px;
90
- }
91
- /* Tooltip styling */
92
- .tooltip {
93
- position: relative;
94
- display: inline-block;
95
- cursor: pointer;
96
- }
97
- .tooltip .tooltiptext {
98
- visibility: hidden;
99
- opacity: 0;
100
- width: 300px;
101
- background-color: #001A6E;
102
- color: #fff;
103
- text-align: center;
104
- border-radius: 5px;
105
- padding: 5px;
106
- position: absolute;
107
- z-index: 1;
108
- bottom: 125%; /* Position the tooltip above */
109
- left: 50%;
110
- margin-left: -150px;
111
- transition: opacity 0.3s;
112
- }
113
- .tooltip:hover .tooltiptext {
114
- visibility: visible;
115
- opacity: 1;
116
- }
117
- </style>
118
- """,
119
- unsafe_allow_html=True,
120
- )
121
-
122
- # Page Layout
123
- st.title("🌟 Gradient Descent Interactive Tool 🌟")
124
 
125
- col1, col2 = st.columns([1, 2])
 
126
 
127
- # Left Section: User Input
128
- with col1:
129
- st.subheader("πŸ”§ Define Your Function")
130
 
131
- # Tooltip with instructions when hovering over the function input label
132
- st.markdown(
133
- """
134
- <div class="tooltip">
135
- <label for="func_input">Enter a function of 'x':</label>
136
- <span class="tooltiptext">
137
- **How to input your function:**
138
- - Please give the inputs as mentioned below
139
- - x^n as x**n,
140
- - sin(x) as np.sin(x)
141
- - log(x) as np.log(x),
142
- - e^x or exp(x) as np.exp(x).
143
- </span>
144
- </div>
145
- """,
146
- unsafe_allow_html=True
147
- )
148
 
149
- # Use text input for the user to define a function
150
- func_input = st.text_input(
151
- "πŸ‘‡",
152
- key="func_input",
153
- on_change=reset_state
154
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- st.subheader("βš™οΈ Gradient Descent Parameters")
157
- starting_point = st.number_input(
158
- "Starting Point (Xβ‚€)",
159
- value=4.0,
160
- step=0.1,
161
- format="%.2f",
162
- key="starting_point",
163
- on_change=reset_state
164
- )
165
- learning_rate = st.number_input(
166
- "Learning Rate (Ε‹)",
167
- value=0.25,
168
- step=0.01,
169
- format="%.2f",
170
- key="learning_rate",
171
- on_change=reset_state
172
- )
173
 
174
- col3, col4 = st.columns(2)
175
- with col3:
176
- if st.button("πŸ”„ Set Up Function"):
177
- reset_state()
178
- with col4:
179
- if st.button("▢️ Next Iteration"):
180
- try:
181
- grad = derivative(st.session_state.func_input, st.session_state.x)
182
- st.session_state.x = st.session_state.x - learning_rate * grad
183
- st.session_state.iteration += 1
184
- st.session_state.x_vals.append(st.session_state.x)
185
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
186
- except Exception as e:
187
- st.error(f"⚠️ Error: {str(e)}")
188
 
189
- # Right Section: Visualization
190
- with col2:
191
- st.subheader("πŸ“Š Gradient Descent Visualization")
192
- try:
193
- # Plot the function and all current and previous gradient descent points
194
- x_plot = np.linspace(-10, 10, 400)
195
- y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
196
 
197
- fig = go.Figure()
198
 
199
- # Function curve
200
- fig.add_trace(go.Scatter(
201
- x=x_plot,
202
- y=y_plot,
203
- mode="lines+markers",
204
- line=dict(color="blue", width=2),
205
- marker=dict(size=4, color="blue", symbol="circle"),
206
- name="Function"
207
- ))
208
 
209
- # All gradient descent points (red points without coordinates)
210
- fig.add_trace(go.Scatter(
211
- x=st.session_state.x_vals,
212
- y=st.session_state.y_vals,
213
- mode="markers",
214
- marker=dict(color="red", size=10),
215
- name="Gradient Descent Points"
216
- ))
217
 
218
- # Tangent line at the current gradient descent point
219
- current_x = st.session_state.x
220
- tangent_x = np.linspace(-10, 10, 200)
221
- tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
222
- fig.add_trace(go.Scatter(
223
- x=tangent_x,
224
- y=tangent_y,
225
- mode="lines",
226
- line=dict(color="orange", width=3),
227
- name="Tangent Line"
228
- ))
229
 
230
- # Dynamic zoom for better visibility
231
- fig.update_layout(
232
- xaxis=dict(
233
- title="x-axis",
234
- range=[-10, 10],
235
- showline=True,
236
- linecolor="white",
237
- tickcolor="white",
238
- tickfont=dict(color="white"),
239
- ticks="outside",
240
- ),
241
- yaxis=dict(
242
- title="y-axis",
243
- range=[min(y_plot) - 5, min(max(y_plot) + 5, 1000)],
244
- showline=True,
245
- linecolor="white",
246
- tickcolor="white",
247
- tickfont=dict(color="white"),
248
- ticks="outside",
249
- ),
250
- plot_bgcolor="black",
251
- paper_bgcolor="black",
252
- title="",
253
- margin=dict(l=10, r=10, t=10, b=10),
254
- width=800,
255
- height=400,
256
- showlegend=True,
257
- legend=dict(
258
- x=1.1,
259
- y=0.5,
260
- xanchor="left",
261
- yanchor="middle",
262
- orientation="v",
263
- font=dict(size=12, color="white"),
264
- bgcolor="black",
265
- bordercolor="white",
266
- borderwidth=2,
267
- )
268
  )
 
269
 
270
- # Axis lines for quadrants
271
- fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2))
272
- fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2))
273
-
274
- st.plotly_chart(fig, use_container_width=True)
275
 
276
- except Exception as e:
277
- st.error(f"⚠️ Error in visualization: {str(e)}")
278
 
279
- # Iteration stats and download
280
- col5, col6 = st.columns(2)
281
- col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
282
- col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")
283
 
 
 
 
 
 
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
 
42
+ # Full-width layout
43
+ st.set_page_config(layout="wide")
 
 
44
 
45
+ # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
+ st.markdown(
47
+ """
48
+ <style>
49
+ * {
50
+ font-family: Cambria, Arial, sans-serif !important;
51
+ }
52
+ h1, h2, h3, h4, h5 {
53
+ text-align: center;
54
+ margin-top: 0;
55
+ }
56
+ input, .stButton button, .stDownloadButton button {
57
+ border: 2px solid #ea445a;
58
+ border-radius: 5px;
59
+ padding: 10px;
60
+ }
61
+ .stInfo, .stSuccess {
62
+ border: 2px solid #ea445a;
63
+ border-radius: 5px;
64
+ padding: 10px;
65
+ }
66
+ .stButton {
67
+ margin-top: 10px;
68
+ }
69
+ /* Reduced Padding at the top */
70
+ .css-1d391kg {
71
+ padding-top: 0.5rem;
72
+ }
73
+ /* Centering the legend in the plot */
74
+ .stPlotlyChart {
75
+ display: block;
76
+ margin: 0 auto;
77
+ }
78
+ /* Adjusting for full width without scrolling */
79
+ .css-1lcbvhc {
80
+ padding-left: 0;
81
+ padding-right: 0;
82
+ }
83
+ /* Custom borders for input fields */
84
+ .stTextInput input, .stNumberInput input {
85
+ border: 2px solid #001A6E;
86
+ border-radius: 5px;
87
+ padding: 10px;
88
+ }
89
+ /* Tooltip styling */
90
+ .tooltip {
91
+ position: relative;
92
+ display: inline-block;
93
+ cursor: pointer;
94
+ }
95
+ .tooltip .tooltiptext {
96
+ visibility: hidden;
97
+ opacity: 0;
98
+ width: 300px;
99
+ background-color: #001A6E;
100
+ color: #fff;
101
+ text-align: center;
102
+ border-radius: 5px;
103
+ padding: 5px;
104
+ position: absolute;
105
+ z-index: 1;
106
+ bottom: 125%; /* Position the tooltip above */
107
+ left: 50%;
108
+ margin-left: -150px;
109
+ transition: opacity 0.3s;
110
+ }
111
+ .tooltip:hover .tooltiptext {
112
+ visibility: visible;
113
+ opacity: 1;
114
+ }
115
+ </style>
116
+ """,
117
+ unsafe_allow_html=True,
118
+ )
 
 
 
119
 
120
+ # Page Layout
121
+ st.title("🌟 Gradient Descent Interactive Tool 🌟")
122
 
123
+ col1, col2 = st.columns([1, 2])
 
 
124
 
125
+ # Left Section: User Input
126
+ with col1:
127
+ st.subheader("πŸ”§ Define Your Function")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
+ # Tooltip with instructions when hovering over the function input label
130
+ st.markdown(
131
+ """
132
+ <div class="tooltip">
133
+ <label for="func_input">Enter a function of 'x':</label>
134
+ <span class="tooltiptext">
135
+ **How to input your function:**
136
+ - Please give the inputs as mentioned below
137
+ - x^n as x**n,
138
+ - sin(x) as np.sin(x)
139
+ - log(x) as np.log(x),
140
+ - e^x or exp(x) as np.exp(x).
141
+ </span>
142
+ </div>
143
+ """,
144
+ unsafe_allow_html=True
145
+ )
146
+
147
+ # Use text input for the user to define a function, but no `value` argument
148
+ func_input = st.text_input(
149
+ "πŸ‘‡",
150
+ key="func_input",
151
+ on_change=reset_state
152
+ )
153
 
154
+ st.subheader("βš™οΈ Gradient Descent Parameters")
155
+ starting_point = st.number_input(
156
+ "Starting Point (Xβ‚€)",
157
+ value=4.0,
158
+ step=0.1,
159
+ format="%.2f",
160
+ key="starting_point",
161
+ on_change=reset_state
162
+ )
163
+ learning_rate = st.number_input(
164
+ "Learning Rate (Ε‹)",
165
+ value=0.25,
166
+ step=0.01,
167
+ format="%.2f",
168
+ key="learning_rate",
169
+ on_change=reset_state
170
+ )
171
 
172
+ col3, col4 = st.columns(2)
173
+ with col3:
174
+ if st.button("πŸ”„ Set Up Function"):
175
+ reset_state()
176
+ with col4:
177
+ if st.button("▢️ Next Iteration"):
178
+ try:
179
+ grad = derivative(st.session_state.func_input, st.session_state.x)
180
+ st.session_state.x = st.session_state.x - learning_rate * grad
181
+ st.session_state.iteration += 1
182
+ st.session_state.x_vals.append(st.session_state.x)
183
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
184
+ except Exception as e:
185
+ st.error(f"⚠️ Error: {str(e)}")
186
 
187
+ # Right Section: Visualization
188
+ with col2:
189
+ st.subheader("πŸ“Š Gradient Descent Visualization")
190
+ try:
191
+ # Plot the function and all current and previous gradient descent points
192
+ x_plot = np.linspace(-10, 10, 400)
193
+ y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
194
 
195
+ fig = go.Figure()
196
 
197
+ # Function curve
198
+ fig.add_trace(go.Scatter(
199
+ x=x_plot,
200
+ y=y_plot,
201
+ mode="lines+markers",
202
+ line=dict(color="blue", width=2),
203
+ marker=dict(size=4, color="blue", symbol="circle"),
204
+ name="Function"
205
+ ))
206
 
207
+ # All gradient descent points (red points without coordinates)
208
+ fig.add_trace(go.Scatter(
209
+ x=st.session_state.x_vals,
210
+ y=st.session_state.y_vals,
211
+ mode="markers",
212
+ marker=dict(color="red", size=10),
213
+ name="Gradient Descent Points"
214
+ ))
215
 
216
+ # Tangent line at the current gradient descent point
217
+ current_x = st.session_state.x
218
+ tangent_x = np.linspace(-10, 10, 200) # Adjusting range to cover entire plot
219
+ tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
220
+ fig.add_trace(go.Scatter(
221
+ x=tangent_x,
222
+ y=tangent_y,
223
+ mode="lines",
224
+ line=dict(color="orange", width=3),
225
+ name="Tangent Line"
226
+ ))
227
 
228
+ # Dynamic zoom for better visibility
229
+ fig.update_layout(
230
+ xaxis=dict(
231
+ title="x-axis",
232
+ range=[-10, 10],
233
+ showline=True,
234
+ linecolor="white",
235
+ tickcolor="white",
236
+ tickfont=dict(color="white"),
237
+ ticks="outside",
238
+ ),
239
+ yaxis=dict(
240
+ title="y-axis",
241
+ range=[min(y_plot) - 5, min(max(y_plot) + 5, 1000)], # Limiting the max y to 1000
242
+ showline=True,
243
+ linecolor="white",
244
+ tickcolor="white",
245
+ tickfont=dict(color="white"),
246
+ ticks="outside",
247
+ ),
248
+ plot_bgcolor="black",
249
+ paper_bgcolor="black",
250
+ title="",
251
+ margin=dict(l=10, r=10, t=10, b=10),
252
+ width=800,
253
+ height=400,
254
+ showlegend=True,
255
+ legend=dict(
256
+ x=1.1,
257
+ y=0.5,
258
+ xanchor="left",
259
+ yanchor="middle",
260
+ orientation="v",
261
+ font=dict(size=12, color="white"),
262
+ bgcolor="black",
263
+ bordercolor="white",
264
+ borderwidth=2,
 
265
  )
266
+ )
267
 
268
+ # Axis lines for quadrants
269
+ fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
270
+ fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
 
 
271
 
272
+ st.plotly_chart(fig, use_container_width=True)
 
273
 
274
+ except Exception as e:
275
+ st.error(f"⚠️ Error in visualization: {str(e)}")
 
 
276
 
277
+ # Iteration stats and download
278
+ col5, col6 = st.columns(2)
279
+ col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
280
+ col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")