trohith89 commited on
Commit
db6ba1d
Β·
verified Β·
1 Parent(s): f5deb75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -221
app.py CHANGED
@@ -39,242 +39,245 @@ if "x" not in st.session_state:
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
 
42
- # Full-width layout
43
- st.set_page_config(layout="wide")
 
 
44
 
45
- # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
- st.markdown(
47
- """
48
- <style>
49
- * {
50
- font-family: Cambria, Arial, sans-serif !important;
51
- }
52
- h1, h2, h3, h4, h5 {
53
- text-align: center;
54
- margin-top: 0;
55
- }
56
- input, .stButton button, .stDownloadButton button {
57
- border: 2px solid #ea445a;
58
- border-radius: 5px;
59
- padding: 10px;
60
- }
61
- .stInfo, .stSuccess {
62
- border: 2px solid #ea445a;
63
- border-radius: 5px;
64
- padding: 10px;
65
- }
66
- .stButton {
67
- margin-top: 10px;
68
- }
69
- /* Reduced Padding at the top */
70
- .css-1d391kg {
71
- padding-top: 0.5rem;
72
- }
73
- /* Centering the legend in the plot */
74
- .stPlotlyChart {
75
- display: block;
76
- margin: 0 auto;
77
- }
78
- /* Adjusting for full width without scrolling */
79
- .css-1lcbvhc {
80
- padding-left: 0;
81
- padding-right: 0;
82
- }
83
- /* Custom borders for input fields */
84
- .stTextInput input, .stNumberInput input {
85
- border: 2px solid #001A6E;
86
- border-radius: 5px;
87
- padding: 10px;
88
- }
89
- /* Tooltip styling */
90
- .tooltip {
91
- position: relative;
92
- display: inline-block;
93
- cursor: pointer;
94
- }
95
- .tooltip .tooltiptext {
96
- visibility: hidden;
97
- opacity: 0;
98
- width: 300px;
99
- background-color: #001A6E;
100
- color: #fff;
101
- text-align: center;
102
- border-radius: 5px;
103
- padding: 5px;
104
- position: absolute;
105
- z-index: 1;
106
- bottom: 125%; /* Position the tooltip above */
107
- left: 50%;
108
- margin-left: -150px;
109
- transition: opacity 0.3s;
110
- }
111
- .tooltip:hover .tooltiptext {
112
- visibility: visible;
113
- opacity: 1;
114
- }
115
- </style>
116
- """,
117
- unsafe_allow_html=True,
118
- )
119
-
120
- # Page Layout
121
- st.title("🌟 Gradient Descent Interactive Tool 🌟")
122
-
123
- col1, col2 = st.columns([1, 2])
124
-
125
- # Left Section: User Input
126
- with col1:
127
- st.subheader("πŸ”§ Define Your Function")
128
-
129
- # Tooltip with instructions when hovering over the function input label
130
  st.markdown(
131
  """
132
- <div class="tooltip">
133
- <label for="func_input">Enter a function of 'x':</label>
134
- <span class="tooltiptext">
135
- **How to input your function:**
136
- - Please give the inputs as mentioned below
137
- - x^n as x**n,
138
- - sin(x) as np.sin(x)
139
- - log(x) as np.log(x),
140
- - e^x or exp(x) as np.exp(x).
141
- </span>
142
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  """,
144
- unsafe_allow_html=True
145
- )
146
-
147
- # Use text input for the user to define a function, but no `value` argument
148
- func_input = st.text_input(
149
- "πŸ‘‡",
150
- key="func_input",
151
- on_change=reset_state
152
  )
153
 
154
- st.subheader("βš™οΈ Gradient Descent Parameters")
155
- starting_point = st.number_input(
156
- "Starting Point (Xβ‚€)",
157
- value=4.0,
158
- step=0.1,
159
- format="%.2f",
160
- key="starting_point",
161
- on_change=reset_state
162
- )
163
- learning_rate = st.number_input(
164
- "Learning Rate (Ε‹)",
165
- value=0.25,
166
- step=0.01,
167
- format="%.2f",
168
- key="learning_rate",
169
- on_change=reset_state
170
- )
171
 
172
- col3, col4 = st.columns(2)
173
- with col3:
174
- if st.button("πŸ”„ Set Up Function"):
175
- reset_state()
176
- with col4:
177
- if st.button("▢️ Next Iteration"):
178
- try:
179
- grad = derivative(st.session_state.func_input, st.session_state.x)
180
- st.session_state.x = st.session_state.x - learning_rate * grad
181
- st.session_state.iteration += 1
182
- st.session_state.x_vals.append(st.session_state.x)
183
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
184
- except Exception as e:
185
- st.error(f"⚠️ Error: {str(e)}")
186
 
187
- # Right Section: Visualization
188
- with col2:
189
- st.subheader("πŸ“Š Gradient Descent Visualization")
190
- try:
191
- # Plot the function and all current and previous gradient descent points
192
- x_plot = np.linspace(-10, 10, 400)
193
- y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
- fig = go.Figure()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Function curve
198
- fig.add_trace(go.Scatter(
199
- x=x_plot,
200
- y=y_plot,
201
- mode="lines+markers",
202
- line=dict(color="blue", width=2),
203
- marker=dict(size=4, color="blue", symbol="circle"),
204
- name="Function"
205
- ))
206
 
207
- # All gradient descent points (red points without coordinates)
208
- fig.add_trace(go.Scatter(
209
- x=st.session_state.x_vals,
210
- y=st.session_state.y_vals,
211
- mode="markers",
212
- marker=dict(color="red", size=10),
213
- name="Gradient Descent Points"
214
- ))
215
 
216
- # Tangent line at the current gradient descent point
217
- current_x = st.session_state.x
218
- tangent_x = np.linspace(-10, 10, 200) # Adjusting range to cover entire plot
219
- tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
220
- fig.add_trace(go.Scatter(
221
- x=tangent_x,
222
- y=tangent_y,
223
- mode="lines",
224
- line=dict(color="orange", width=3),
225
- name="Tangent Line"
226
- ))
227
 
228
- # Dynamic zoom for better visibility
229
- fig.update_layout(
230
- xaxis=dict(
231
- title="x-axis",
232
- range=[-10, 10],
233
- showline=True,
234
- linecolor="white",
235
- tickcolor="white",
236
- tickfont=dict(color="white"),
237
- ticks="outside",
238
- ),
239
- yaxis=dict(
240
- title="y-axis",
241
- range=[min(y_plot) - 5, min(max(y_plot) + 5, 1000)], # Limiting the max y to 1000
242
- showline=True,
243
- linecolor="white",
244
- tickcolor="white",
245
- tickfont=dict(color="white"),
246
- ticks="outside",
247
- ),
248
- plot_bgcolor="black",
249
- paper_bgcolor="black",
250
- title="",
251
- margin=dict(l=10, r=10, t=10, b=10),
252
- width=800,
253
- height=400,
254
- showlegend=True,
255
- legend=dict(
256
- x=1.1,
257
- y=0.5,
258
- xanchor="left",
259
- yanchor="middle",
260
- orientation="v",
261
- font=dict(size=12, color="white"),
262
- bgcolor="black",
263
- bordercolor="white",
264
- borderwidth=2,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  )
266
- )
267
 
268
- # Axis lines for quadrants
269
- fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
270
- fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
271
 
272
- st.plotly_chart(fig, use_container_width=True)
273
 
274
- except Exception as e:
275
- st.error(f"⚠️ Error in visualization: {str(e)}")
 
 
 
 
 
276
 
277
- # Iteration stats and download
278
- col5, col6 = st.columns(2)
279
- col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
280
- col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")
 
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
 
42
+ # Check for query parameters to navigate from home.py
43
+ if st.experimental_get_query_params().get("page") == ["gradient-descent"]:
44
+ # Full-width layout
45
+ st.set_page_config(layout="wide")
46
 
47
+ # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  st.markdown(
49
  """
50
+ <style>
51
+ * {
52
+ font-family: Cambria, Arial, sans-serif !important;
53
+ }
54
+ h1, h2, h3, h4, h5 {
55
+ text-align: center;
56
+ margin-top: 0;
57
+ }
58
+ input, .stButton button, .stDownloadButton button {
59
+ border: 2px solid #ea445a;
60
+ border-radius: 5px;
61
+ padding: 10px;
62
+ }
63
+ .stInfo, .stSuccess {
64
+ border: 2px solid #ea445a;
65
+ border-radius: 5px;
66
+ padding: 10px;
67
+ }
68
+ .stButton {
69
+ margin-top: 10px;
70
+ }
71
+ /* Reduced Padding at the top */
72
+ .css-1d391kg {
73
+ padding-top: 0.5rem;
74
+ }
75
+ /* Centering the legend in the plot */
76
+ .stPlotlyChart {
77
+ display: block;
78
+ margin: 0 auto;
79
+ }
80
+ /* Adjusting for full width without scrolling */
81
+ .css-1lcbvhc {
82
+ padding-left: 0;
83
+ padding-right: 0;
84
+ }
85
+ /* Custom borders for input fields */
86
+ .stTextInput input, .stNumberInput input {
87
+ border: 2px solid #001A6E;
88
+ border-radius: 5px;
89
+ padding: 10px;
90
+ }
91
+ /* Tooltip styling */
92
+ .tooltip {
93
+ position: relative;
94
+ display: inline-block;
95
+ cursor: pointer;
96
+ }
97
+ .tooltip .tooltiptext {
98
+ visibility: hidden;
99
+ opacity: 0;
100
+ width: 300px;
101
+ background-color: #001A6E;
102
+ color: #fff;
103
+ text-align: center;
104
+ border-radius: 5px;
105
+ padding: 5px;
106
+ position: absolute;
107
+ z-index: 1;
108
+ bottom: 125%; /* Position the tooltip above */
109
+ left: 50%;
110
+ margin-left: -150px;
111
+ transition: opacity 0.3s;
112
+ }
113
+ .tooltip:hover .tooltiptext {
114
+ visibility: visible;
115
+ opacity: 1;
116
+ }
117
+ </style>
118
  """,
119
+ unsafe_allow_html=True,
 
 
 
 
 
 
 
120
  )
121
 
122
+ # Page Layout
123
+ st.title("🌟 Gradient Descent Interactive Tool 🌟")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
+ col1, col2 = st.columns([1, 2])
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ # Left Section: User Input
128
+ with col1:
129
+ st.subheader("πŸ”§ Define Your Function")
130
+
131
+ # Tooltip with instructions when hovering over the function input label
132
+ st.markdown(
133
+ """
134
+ <div class="tooltip">
135
+ <label for="func_input">Enter a function of 'x':</label>
136
+ <span class="tooltiptext">
137
+ **How to input your function:**
138
+ - Please give the inputs as mentioned below
139
+ - x^n as x**n,
140
+ - sin(x) as np.sin(x)
141
+ - log(x) as np.log(x),
142
+ - e^x or exp(x) as np.exp(x).
143
+ </span>
144
+ </div>
145
+ """,
146
+ unsafe_allow_html=True
147
+ )
148
 
149
+ # Use text input for the user to define a function
150
+ func_input = st.text_input(
151
+ "πŸ‘‡",
152
+ key="func_input",
153
+ on_change=reset_state
154
+ )
155
+
156
+ st.subheader("βš™οΈ Gradient Descent Parameters")
157
+ starting_point = st.number_input(
158
+ "Starting Point (Xβ‚€)",
159
+ value=4.0,
160
+ step=0.1,
161
+ format="%.2f",
162
+ key="starting_point",
163
+ on_change=reset_state
164
+ )
165
+ learning_rate = st.number_input(
166
+ "Learning Rate (Ε‹)",
167
+ value=0.25,
168
+ step=0.01,
169
+ format="%.2f",
170
+ key="learning_rate",
171
+ on_change=reset_state
172
+ )
173
+
174
+ col3, col4 = st.columns(2)
175
+ with col3:
176
+ if st.button("πŸ”„ Set Up Function"):
177
+ reset_state()
178
+ with col4:
179
+ if st.button("▢️ Next Iteration"):
180
+ try:
181
+ grad = derivative(st.session_state.func_input, st.session_state.x)
182
+ st.session_state.x = st.session_state.x - learning_rate * grad
183
+ st.session_state.iteration += 1
184
+ st.session_state.x_vals.append(st.session_state.x)
185
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
186
+ except Exception as e:
187
+ st.error(f"⚠️ Error: {str(e)}")
188
 
189
+ # Right Section: Visualization
190
+ with col2:
191
+ st.subheader("πŸ“Š Gradient Descent Visualization")
192
+ try:
193
+ # Plot the function and all current and previous gradient descent points
194
+ x_plot = np.linspace(-10, 10, 400)
195
+ y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
 
 
196
 
197
+ fig = go.Figure()
 
 
 
 
 
 
 
198
 
199
+ # Function curve
200
+ fig.add_trace(go.Scatter(
201
+ x=x_plot,
202
+ y=y_plot,
203
+ mode="lines+markers",
204
+ line=dict(color="blue", width=2),
205
+ marker=dict(size=4, color="blue", symbol="circle"),
206
+ name="Function"
207
+ ))
 
 
208
 
209
+ # All gradient descent points (red points without coordinates)
210
+ fig.add_trace(go.Scatter(
211
+ x=st.session_state.x_vals,
212
+ y=st.session_state.y_vals,
213
+ mode="markers",
214
+ marker=dict(color="red", size=10),
215
+ name="Gradient Descent Points"
216
+ ))
217
+
218
+ # Tangent line at the current gradient descent point
219
+ current_x = st.session_state.x
220
+ tangent_x = np.linspace(-10, 10, 200)
221
+ tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
222
+ fig.add_trace(go.Scatter(
223
+ x=tangent_x,
224
+ y=tangent_y,
225
+ mode="lines",
226
+ line=dict(color="orange", width=3),
227
+ name="Tangent Line"
228
+ ))
229
+
230
+ # Dynamic zoom for better visibility
231
+ fig.update_layout(
232
+ xaxis=dict(
233
+ title="x-axis",
234
+ range=[-10, 10],
235
+ showline=True,
236
+ linecolor="white",
237
+ tickcolor="white",
238
+ tickfont=dict(color="white"),
239
+ ticks="outside",
240
+ ),
241
+ yaxis=dict(
242
+ title="y-axis",
243
+ range=[min(y_plot) - 5, min(max(y_plot) + 5, 1000)],
244
+ showline=True,
245
+ linecolor="white",
246
+ tickcolor="white",
247
+ tickfont=dict(color="white"),
248
+ ticks="outside",
249
+ ),
250
+ plot_bgcolor="black",
251
+ paper_bgcolor="black",
252
+ title="",
253
+ margin=dict(l=10, r=10, t=10, b=10),
254
+ width=800,
255
+ height=400,
256
+ showlegend=True,
257
+ legend=dict(
258
+ x=1.1,
259
+ y=0.5,
260
+ xanchor="left",
261
+ yanchor="middle",
262
+ orientation="v",
263
+ font=dict(size=12, color="white"),
264
+ bgcolor="black",
265
+ bordercolor="white",
266
+ borderwidth=2,
267
+ )
268
  )
 
269
 
270
+ # Axis lines for quadrants
271
+ fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2))
272
+ fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2))
273
 
274
+ st.plotly_chart(fig, use_container_width=True)
275
 
276
+ except Exception as e:
277
+ st.error(f"⚠️ Error in visualization: {str(e)}")
278
+
279
+ # Iteration stats and download
280
+ col5, col6 = st.columns(2)
281
+ col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
282
+ col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")
283