trohith89 commited on
Commit
d6a0711
Β·
verified Β·
1 Parent(s): 8fa6273

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -146
app.py CHANGED
@@ -38,85 +38,12 @@ if "x" not in st.session_state:
38
  st.session_state.iteration = 0
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
 
 
41
 
42
  # Full-width layout
43
  st.set_page_config(layout="wide")
44
 
45
- # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
- st.markdown(
47
- """
48
- <style>
49
- * {
50
- font-family: Cambria, Arial, sans-serif !important;
51
- }
52
- h1, h2, h3, h4, h5 {
53
- text-align: center;
54
- margin-top: 0;
55
- }
56
- input, .stButton button, .stDownloadButton button {
57
- border: 2px solid #ea445a;
58
- border-radius: 5px;
59
- padding: 10px;
60
- }
61
- .stInfo, .stSuccess {
62
- border: 2px solid #ea445a;
63
- border-radius: 5px;
64
- padding: 10px;
65
- }
66
- .stButton {
67
- margin-top: 10px;
68
- }
69
- /* Reduced Padding at the top */
70
- .css-1d391kg {
71
- padding-top: 0.5rem;
72
- }
73
- /* Centering the legend in the plot */
74
- .stPlotlyChart {
75
- display: block;
76
- margin: 0 auto;
77
- }
78
- /* Adjusting for full width without scrolling */
79
- .css-1lcbvhc {
80
- padding-left: 0;
81
- padding-right: 0;
82
- }
83
- /* Custom borders for input fields */
84
- .stTextInput input, .stNumberInput input {
85
- border: 2px solid #001A6E;
86
- border-radius: 5px;
87
- padding: 10px;
88
- }
89
- /* Tooltip styling */
90
- .tooltip {
91
- position: relative;
92
- display: inline-block;
93
- cursor: pointer;
94
- }
95
- .tooltip .tooltiptext {
96
- visibility: hidden;
97
- opacity: 0;
98
- width: 300px;
99
- background-color: #001A6E;
100
- color: #fff;
101
- text-align: center;
102
- border-radius: 5px;
103
- padding: 5px;
104
- position: absolute;
105
- z-index: 1;
106
- bottom: 125%; /* Position the tooltip above */
107
- left: 50%;
108
- margin-left: -150px;
109
- transition: opacity 0.3s;
110
- }
111
- .tooltip:hover .tooltiptext {
112
- visibility: visible;
113
- opacity: 1;
114
- }
115
- </style>
116
- """,
117
- unsafe_allow_html=True,
118
- )
119
-
120
  # Page Layout
121
  st.title("🌟 Gradient Descent Interactive Tool 🌟")
122
 
@@ -126,27 +53,8 @@ col1, col2 = st.columns([1, 2])
126
  with col1:
127
  st.subheader("πŸ”§ Define Your Function")
128
 
129
- # Tooltip with instructions when hovering over the function input label
130
- st.markdown(
131
- """
132
- <div class="tooltip">
133
- <label for="func_input">Enter a function of 'x':</label>
134
- <span class="tooltiptext">
135
- **How to input your function:**
136
- - Please give the inputs as mentioned below
137
- - x^n as x**n,
138
- - sin(x) as np.sin(x)
139
- - log(x) as np.log(x),
140
- - e^x or exp(x) as np.exp(x).
141
- </span>
142
- </div>
143
- """,
144
- unsafe_allow_html=True
145
- )
146
-
147
- # Use text input for the user to define a function, but no `value` argument
148
  func_input = st.text_input(
149
- "πŸ‘‡",
150
  key="func_input",
151
  on_change=reset_state
152
  )
@@ -162,13 +70,16 @@ with col1:
162
  )
163
  learning_rate = st.number_input(
164
  "Learning Rate (Ε‹)",
165
- value=0.25,
166
  step=0.01,
167
  format="%.2f",
168
- key="learning_rate",
169
- on_change=reset_state
170
  )
171
 
 
 
 
 
172
  col3, col4 = st.columns(2)
173
  with col3:
174
  if st.button("πŸ”„ Set Up Function"):
@@ -177,7 +88,7 @@ with col1:
177
  if st.button("▢️ Next Iteration"):
178
  try:
179
  grad = derivative(st.session_state.func_input, st.session_state.x)
180
- st.session_state.x = st.session_state.x - learning_rate * grad
181
  st.session_state.iteration += 1
182
  st.session_state.x_vals.append(st.session_state.x)
183
  st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
@@ -188,7 +99,6 @@ with col1:
188
  with col2:
189
  st.subheader("πŸ“Š Gradient Descent Visualization")
190
  try:
191
- # Plot the function and all current and previous gradient descent points
192
  x_plot = np.linspace(-10, 10, 400)
193
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
194
 
@@ -198,13 +108,12 @@ with col2:
198
  fig.add_trace(go.Scatter(
199
  x=x_plot,
200
  y=y_plot,
201
- mode="lines+markers",
202
- line=dict(color="blue", width=2),
203
- marker=dict(size=4, color="blue", symbol="circle"),
204
  name="Function"
205
  ))
206
 
207
- # All gradient descent points (red points without coordinates)
208
  fig.add_trace(go.Scatter(
209
  x=st.session_state.x_vals,
210
  y=st.session_state.y_vals,
@@ -213,9 +122,9 @@ with col2:
213
  name="Gradient Descent Points"
214
  ))
215
 
216
- # Tangent line at the current gradient descent point
217
  current_x = st.session_state.x
218
- tangent_x = np.linspace(-10, 10, 200) # Adjusting range to cover entire plot
219
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
220
  fig.add_trace(go.Scatter(
221
  x=tangent_x,
@@ -225,56 +134,20 @@ with col2:
225
  name="Tangent Line"
226
  ))
227
 
228
- # Dynamic zoom for better visibility
229
  fig.update_layout(
230
- xaxis=dict(
231
- title="x-axis",
232
- range=[-10, 10],
233
- showline=True,
234
- linecolor="white",
235
- tickcolor="white",
236
- tickfont=dict(color="white"),
237
- ticks="outside",
238
- ),
239
- yaxis=dict(
240
- title="y-axis",
241
- range=[min(y_plot) - 5, min(max(y_plot) + 5, 1000)], # Limiting the max y to 1000
242
- showline=True,
243
- linecolor="white",
244
- tickcolor="white",
245
- tickfont=dict(color="white"),
246
- ticks="outside",
247
- ),
248
- plot_bgcolor="black",
249
- paper_bgcolor="black",
250
- title="",
251
- margin=dict(l=10, r=10, t=10, b=10),
252
  width=800,
253
  height=400,
254
- showlegend=True,
255
- legend=dict(
256
- x=1.1,
257
- y=0.5,
258
- xanchor="left",
259
- yanchor="middle",
260
- orientation="v",
261
- font=dict(size=12, color="white"),
262
- bgcolor="black",
263
- bordercolor="white",
264
- borderwidth=2,
265
- )
266
  )
267
 
268
- # Axis lines for quadrants
269
- fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
270
- fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
271
-
272
  st.plotly_chart(fig, use_container_width=True)
273
 
274
  except Exception as e:
275
  st.error(f"⚠️ Error in visualization: {str(e)}")
276
 
277
- # Iteration stats and download
278
  col5, col6 = st.columns(2)
279
  col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
280
  col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")
 
38
  st.session_state.iteration = 0
39
  st.session_state.x_vals = [4.0]
40
  st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
+ if "learning_rate" not in st.session_state:
42
+ st.session_state.learning_rate = 0.25
43
 
44
  # Full-width layout
45
  st.set_page_config(layout="wide")
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Page Layout
48
  st.title("🌟 Gradient Descent Interactive Tool 🌟")
49
 
 
53
  with col1:
54
  st.subheader("πŸ”§ Define Your Function")
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  func_input = st.text_input(
57
+ "Enter a function of 'x':",
58
  key="func_input",
59
  on_change=reset_state
60
  )
 
70
  )
71
  learning_rate = st.number_input(
72
  "Learning Rate (Ε‹)",
73
+ value=st.session_state.learning_rate,
74
  step=0.01,
75
  format="%.2f",
76
+ key="new_learning_rate"
 
77
  )
78
 
79
+ # Update learning rate without resetting the state
80
+ if learning_rate != st.session_state.learning_rate:
81
+ st.session_state.learning_rate = learning_rate
82
+
83
  col3, col4 = st.columns(2)
84
  with col3:
85
  if st.button("πŸ”„ Set Up Function"):
 
88
  if st.button("▢️ Next Iteration"):
89
  try:
90
  grad = derivative(st.session_state.func_input, st.session_state.x)
91
+ st.session_state.x = st.session_state.x - st.session_state.learning_rate * grad
92
  st.session_state.iteration += 1
93
  st.session_state.x_vals.append(st.session_state.x)
94
  st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
 
99
  with col2:
100
  st.subheader("πŸ“Š Gradient Descent Visualization")
101
  try:
 
102
  x_plot = np.linspace(-10, 10, 400)
103
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
104
 
 
108
  fig.add_trace(go.Scatter(
109
  x=x_plot,
110
  y=y_plot,
111
+ mode="lines",
112
+ line=dict(color="blue", width=2),
 
113
  name="Function"
114
  ))
115
 
116
+ # Gradient descent points
117
  fig.add_trace(go.Scatter(
118
  x=st.session_state.x_vals,
119
  y=st.session_state.y_vals,
 
122
  name="Gradient Descent Points"
123
  ))
124
 
125
+ # Tangent line at the current point
126
  current_x = st.session_state.x
127
+ tangent_x = np.linspace(-10, 10, 200)
128
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
129
  fig.add_trace(go.Scatter(
130
  x=tangent_x,
 
134
  name="Tangent Line"
135
  ))
136
 
 
137
  fig.update_layout(
138
+ xaxis=dict(title="x-axis"),
139
+ yaxis=dict(title="y-axis"),
140
+ title="Gradient Descent Visualization",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  width=800,
142
  height=400,
143
+ showlegend=True
 
 
 
 
 
 
 
 
 
 
 
144
  )
145
 
 
 
 
 
146
  st.plotly_chart(fig, use_container_width=True)
147
 
148
  except Exception as e:
149
  st.error(f"⚠️ Error in visualization: {str(e)}")
150
 
 
151
  col5, col6 = st.columns(2)
152
  col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
153
  col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")