trohith89 commited on
Commit
9d48ec4
Β·
verified Β·
1 Parent(s): 1a9168f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -130
app.py CHANGED
@@ -6,7 +6,10 @@ import plotly.graph_objects as go
6
  def safe_eval(func_str, x_val):
7
  """ Safely evaluates the function at a given x value. """
8
  allowed_names = {"x": x_val, "np": np}
9
- return eval(func_str, {"__builtins__": None}, allowed_names)
 
 
 
10
 
11
  # Function derivative using finite difference method
12
  def derivative(func_str, x_val, h=1e-5):
@@ -47,59 +50,43 @@ if "x" not in st.session_state:
47
  # Full-width layout
48
  st.set_page_config(layout="wide")
49
 
50
- # CSS Styles for Premium Design
51
  st.markdown(
52
  """
53
  <style>
54
- /* Header Styling */
 
 
55
  h1, h2, h3, h4, h5 {
56
- font-family: 'Arial', sans-serif;
57
- color: #333333;
58
  text-align: center;
 
59
  }
60
- /* Button Styling */
61
- .stButton button {
62
- background-color: #ffffff;
63
- color: #FF00FF; /* Pink text */
64
- border-radius: 8px;
65
- border: 2px solid #000000; /* Black border */
66
- padding: 10px 20px;
67
- font-size: 16px;
68
- font-weight: bold;
69
- transition: 0.3s;
70
- width: 100%;
71
- height: 50px;
72
- box-sizing: border-box;
73
- display: flex;
74
- align-items: center;
75
- justify-content: center;
76
- white-space: nowrap; /* Ensures text stays in one line */
77
- }
78
- .stButton button:hover {
79
- background-color: #000000; /* Black background on hover */
80
- color: #FF00FF; /* Pink text on hover */
81
- transform: scale(1.05);
82
- }
83
- /* Input Box Styling */
84
- input, select {
85
- border: 2px solid #ccc;
86
- padding: 8px;
87
  border-radius: 5px;
88
- font-size: 16px;
89
- font-family: 'Arial', sans-serif;
90
  }
91
- /* Tooltip Styling */
92
- div[role="tooltip"] {
93
- background-color: #0078D7;
94
- color: white;
95
- border-radius: 8px;
96
  padding: 10px;
97
  }
98
- /* Columns for buttons */
99
  .stButton {
100
- display: flex;
101
- flex-direction: column;
102
- gap: 15px;
 
 
 
 
 
 
 
 
 
 
 
 
103
  }
104
  </style>
105
  """,
@@ -107,49 +94,70 @@ st.markdown(
107
  )
108
 
109
  # Page Layout
110
- with st.container():
111
- st.header("Gradient Descent Interactive Tool")
 
 
 
 
 
112
 
113
- col1, col2 = st.columns([1, 2])
114
-
115
- # Left Section: User Input
116
- with col1:
117
- st.subheader("Define Your Function")
118
- func_input = st.text_input("Enter a function of 'x':", st.session_state.func_input, key="func_input", on_change=reset_state)
119
-
120
- st.write("Or choose a predefined function:")
121
- cols = st.columns(4) # Adjusted to 4 columns to increase button width
122
- for i, (btn_label, func_value) in enumerate(predefined_functions.items()):
123
- with cols[i]:
124
- if st.button(btn_label):
125
- st.session_state.func_input = func_value
126
- reset_state()
127
- st.rerun() # Re-run to ensure everything updates properly
128
-
129
- st.subheader("Gradient Descent Parameters")
130
- starting_point = st.number_input("Starting Point (Xβ‚’)", value=4.0, step=0.1, format="%.2f", key="starting_point", on_change=reset_state)
131
- learning_rate = st.number_input("Learning Rate (Ε‹)", value=0.25, step=0.01, format="%.2f", key="learning_rate", on_change=reset_state)
132
-
133
- col3, col4 = st.columns(2)
134
- with col3:
135
- if st.button("Set Up Function"):
136
- reset_state()
137
- with col4:
138
- if st.button("Next Iteration"):
139
- try:
140
- grad = derivative(st.session_state.func_input, st.session_state.x)
141
- st.session_state.x = st.session_state.x - learning_rate * grad
142
- st.session_state.iteration += 1
143
- st.session_state.x_vals.append(st.session_state.x)
144
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
145
- except Exception as e:
146
- st.error(f"Error: {str(e)}")
147
-
148
- # Right Section: Visualization
149
- with col2:
150
- st.subheader("Gradient Descent Visualization")
151
-
152
- # Plot the function, gradient descent points, and tangent line
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  x_plot = np.linspace(-10, 10, 400)
154
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
155
 
@@ -159,26 +167,25 @@ with st.container():
159
  fig.add_trace(go.Scatter(
160
  x=x_plot,
161
  y=y_plot,
162
- mode="lines",
163
  line=dict(color="blue", width=2),
 
164
  name="Function"
165
  ))
166
 
167
- # Gradient descent point (only show the current point)
168
  fig.add_trace(go.Scatter(
169
- x=[st.session_state.x],
170
- y=[st.session_state.y_vals[-1]],
171
  mode="markers",
172
- marker=dict(color="red", size=8),
173
- name="Current Gradient Descent Point",
174
- hovertemplate="x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>"
175
  ))
176
 
177
- # Tangent line at current point
178
  current_x = st.session_state.x
179
- tangent_x = np.linspace(current_x - 5, current_x + 5, 200) # Extended range
180
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
181
-
182
  fig.add_trace(go.Scatter(
183
  x=tangent_x,
184
  y=tangent_y,
@@ -187,55 +194,56 @@ with st.container():
187
  name="Tangent Line"
188
  ))
189
 
190
- # Add axis lines
191
- fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="black", width=2)) # x-axis
192
- fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="black", width=2)) # y-axis
193
-
194
- # Update layout
195
  fig.update_layout(
196
- plot_bgcolor="black",
197
  xaxis=dict(
198
- title="x",
199
- titlefont=dict(color="white"),
200
- tickfont=dict(color="white"),
201
- ticks="outside",
202
- tickcolor="white",
203
  range=[-10, 10],
204
  showline=True,
205
  linecolor="white",
206
- mirror=True
207
- ),
208
- yaxis=dict(
209
- title="y",
210
- titlefont=dict(color="white"),
211
  tickfont=dict(color="white"),
212
  ticks="outside",
213
- tickcolor="white",
214
- range=[-100, 100],
 
 
215
  showline=True,
216
  linecolor="white",
217
- mirror=True
218
- ),
219
- title=dict(
220
- text="Gradient Descent Visualization with Tangent Line",
221
- font=dict(color="white")
222
  ),
 
 
 
 
 
 
 
223
  legend=dict(
224
- font=dict(color="white"),
 
 
 
 
 
225
  bgcolor="black",
226
- bordercolor="black",
227
- borderwidth=1,
228
- orientation="h",
229
- xanchor="center",
230
- x=0.5,
231
- yanchor="bottom",
232
- y=1.1
233
  )
234
  )
235
 
 
 
 
 
236
  st.plotly_chart(fig, use_container_width=True)
237
 
238
- # Iteration stats
239
- st.subheader("Iteration Counter")
240
- st.info(f"Iteration: {st.session_state.iteration}")
241
- st.success(f"Current x: {st.session_state.x:.4f}\nCurrent f(x): {st.session_state.y_vals[-1]:.4f}")
 
 
 
 
6
  def safe_eval(func_str, x_val):
7
  """ Safely evaluates the function at a given x value. """
8
  allowed_names = {"x": x_val, "np": np}
9
+ try:
10
+ return eval(func_str, {"__builtins__": None}, allowed_names)
11
+ except Exception as e:
12
+ raise ValueError(f"Error evaluating the function: {e}")
13
 
14
  # Function derivative using finite difference method
15
  def derivative(func_str, x_val, h=1e-5):
 
50
  # Full-width layout
51
  st.set_page_config(layout="wide")
52
 
53
+ # CSS Styles for Borders, Font and Reduced Padding
54
  st.markdown(
55
  """
56
  <style>
57
+ * {
58
+ font-family: Cambria, Arial, sans-serif !important;
59
+ }
60
  h1, h2, h3, h4, h5 {
 
 
61
  text-align: center;
62
+ margin-top: 0;
63
  }
64
+ input, .stButton button, .stDownloadButton button {
65
+ border: 2px solid #ea445a;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  border-radius: 5px;
67
+ padding: 10px;
 
68
  }
69
+ .stInfo, .stSuccess {
70
+ border: 2px solid #ea445a;
71
+ border-radius: 5px;
 
 
72
  padding: 10px;
73
  }
 
74
  .stButton {
75
+ margin-top: 10px;
76
+ }
77
+ /* Reduced Padding at the top */
78
+ .css-1d391kg {
79
+ padding-top: 0.5rem;
80
+ }
81
+ /* Centering the legend in the plot */
82
+ .stPlotlyChart {
83
+ display: block;
84
+ margin: 0 auto;
85
+ }
86
+ /* Adjusting for full width without scrolling */
87
+ .css-1lcbvhc {
88
+ padding-left: 0;
89
+ padding-right: 0;
90
  }
91
  </style>
92
  """,
 
94
  )
95
 
96
  # Page Layout
97
+ st.title("🌟 Gradient Descent Interactive Tool 🌟")
98
+
99
+ col1, col2 = st.columns([1, 2])
100
+
101
+ # Left Section: User Input
102
+ with col1:
103
+ st.subheader("πŸ”§ Define Your Function")
104
 
105
+ # Display the predefined function buttons
106
+ st.write("Or choose a predefined function:")
107
+ cols = st.columns(4)
108
+ for i, (btn_label, func_value) in enumerate(predefined_functions.items()):
109
+ with cols[i]:
110
+ if st.button(btn_label):
111
+ st.session_state.func_input = func_value
112
+ reset_state() # Ensure that the reset function is called to update the state
113
+ st.experimental_rerun() # Re-run to ensure everything updates properly
114
+
115
+ # Use text input for the user to define a function
116
+ func_input = st.text_input(
117
+ "Enter a function of 'x':",
118
+ st.session_state.func_input,
119
+ key="func_input",
120
+ on_change=reset_state
121
+ )
122
+
123
+ st.subheader("βš™οΈ Gradient Descent Parameters")
124
+ starting_point = st.number_input(
125
+ "Starting Point (Xβ‚€)",
126
+ value=4.0,
127
+ step=0.1,
128
+ format="%.2f",
129
+ key="starting_point",
130
+ on_change=reset_state
131
+ )
132
+ learning_rate = st.number_input(
133
+ "Learning Rate (Ε‹)",
134
+ value=0.25,
135
+ step=0.01,
136
+ format="%.2f",
137
+ key="learning_rate",
138
+ on_change=reset_state
139
+ )
140
+
141
+ col3, col4 = st.columns(2)
142
+ with col3:
143
+ if st.button("πŸ”„ Set Up Function"):
144
+ reset_state()
145
+ with col4:
146
+ if st.button("▢️ Next Iteration"):
147
+ try:
148
+ grad = derivative(st.session_state.func_input, st.session_state.x)
149
+ st.session_state.x = st.session_state.x - learning_rate * grad
150
+ st.session_state.iteration += 1
151
+ st.session_state.x_vals.append(st.session_state.x)
152
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
153
+ except Exception as e:
154
+ st.error(f"⚠️ Error: {str(e)}")
155
+
156
+ # Right Section: Visualization
157
+ with col2:
158
+ st.subheader("πŸ“Š Gradient Descent Visualization")
159
+ try:
160
+ # Plot the function and all current and previous gradient descent points
161
  x_plot = np.linspace(-10, 10, 400)
162
  y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
163
 
 
167
  fig.add_trace(go.Scatter(
168
  x=x_plot,
169
  y=y_plot,
170
+ mode="lines+markers",
171
  line=dict(color="blue", width=2),
172
+ marker=dict(size=4, color="blue", symbol="circle"),
173
  name="Function"
174
  ))
175
 
176
+ # All gradient descent points (red points without coordinates)
177
  fig.add_trace(go.Scatter(
178
+ x=st.session_state.x_vals,
179
+ y=st.session_state.y_vals,
180
  mode="markers",
181
+ marker=dict(color="red", size=10),
182
+ name="Gradient Descent Points"
 
183
  ))
184
 
185
+ # Tangent line at the current gradient descent point
186
  current_x = st.session_state.x
187
+ tangent_x = np.linspace(current_x - 5, current_x + 5, 200) # Extended range for tangent line
188
  tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
 
189
  fig.add_trace(go.Scatter(
190
  x=tangent_x,
191
  y=tangent_y,
 
194
  name="Tangent Line"
195
  ))
196
 
197
+ # Dynamic zoom for better visibility
 
 
 
 
198
  fig.update_layout(
 
199
  xaxis=dict(
200
+ title="x-axis",
 
 
 
 
201
  range=[-10, 10],
202
  showline=True,
203
  linecolor="white",
204
+ tickcolor="white",
 
 
 
 
205
  tickfont=dict(color="white"),
206
  ticks="outside",
207
+ ),
208
+ yaxis=dict(
209
+ title="y-axis",
210
+ range=[min(y_plot) - 5, max(y_plot) + 5],
211
  showline=True,
212
  linecolor="white",
213
+ tickcolor="white",
214
+ tickfont=dict(color="white"),
215
+ ticks="outside",
 
 
216
  ),
217
+ plot_bgcolor="black",
218
+ paper_bgcolor="black",
219
+ title="",
220
+ margin=dict(l=10, r=10, t=10, b=10),
221
+ width=800,
222
+ height=400,
223
+ showlegend=True,
224
  legend=dict(
225
+ x=1.1,
226
+ y=0.5,
227
+ xanchor="left",
228
+ yanchor="middle",
229
+ orientation="v",
230
+ font=dict(size=12, color="white"),
231
  bgcolor="black",
232
+ bordercolor="white",
233
+ borderwidth=2,
 
 
 
 
 
234
  )
235
  )
236
 
237
+ # Axis lines for quadrants
238
+ fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
239
+ fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
240
+
241
  st.plotly_chart(fig, use_container_width=True)
242
 
243
+ except Exception as e:
244
+ st.error(f"⚠️ Error in visualization: {str(e)}")
245
+
246
+ # Iteration stats and download
247
+ col5, col6 = st.columns(2)
248
+ col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
249
+ col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")