trohith89 commited on
Commit
2f3f580
·
verified ·
1 Parent(s): 3d7465a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -240
app.py CHANGED
@@ -1,34 +1,10 @@
1
- import streamlit as st
2
- import numpy as np
3
- import plotly.graph_objects as go
4
-
5
- # Safe function evaluation
6
- def safe_eval(func_str, x_val):
7
- """ Safely evaluates the function at a given x value. """
8
- allowed_names = {"x": x_val, "np": np}
9
- try:
10
- return eval(func_str, {"__builtins__": None}, allowed_names)
11
- except Exception as e:
12
- raise ValueError(f"Error evaluating the function: {e}")
13
-
14
- # Function derivative using finite difference method
15
- def derivative(func_str, x_val, h=1e-5):
16
- """ Numerically compute the derivative of the function at x using finite differences. """
17
- return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
18
-
19
- # Tangent line equation
20
- def tangent_line(func_str, x_val, x_range):
21
- """ Compute the tangent line at a given x value. """
22
- y_val = safe_eval(func_str, x_val)
23
- slope = derivative(func_str, x_val)
24
- return slope * (x_range - x_val) + y_val
25
-
26
- # Callback to reset session state
27
- def reset_state():
28
- st.session_state.x = st.session_state.starting_point
29
- st.session_state.iteration = 0
30
- st.session_state.x_vals = [st.session_state.starting_point]
31
- st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
32
 
33
  # Predefined function buttons
34
  predefined_functions = {
@@ -38,212 +14,12 @@ predefined_functions = {
38
  "log(x)": "np.log(x)"
39
  }
40
 
41
- # Initialize session state variables
42
- if "func_input" not in st.session_state:
43
- st.session_state.func_input = "x**2 + x"
44
- if "x" not in st.session_state:
45
- st.session_state.x = 4.0
46
- st.session_state.iteration = 0
47
- st.session_state.x_vals = [4.0]
48
- st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
49
-
50
- # Full-width layout
51
- st.set_page_config(layout="wide")
52
-
53
- # CSS Styles for Borders, Font and Reduced Padding
54
- st.markdown(
55
- """
56
- <style>
57
- * {
58
- font-family: Cambria, Arial, sans-serif !important;
59
- }
60
- h1, h2, h3, h4, h5 {
61
- text-align: center;
62
- margin-top: 0;
63
- }
64
- input, .stButton button, .stDownloadButton button {
65
- border: 2px solid #ea445a;
66
- border-radius: 5px;
67
- padding: 10px;
68
- }
69
- .stInfo, .stSuccess {
70
- border: 2px solid #ea445a;
71
- border-radius: 5px;
72
- padding: 10px;
73
- }
74
- .stButton {
75
- margin-top: 10px;
76
- }
77
- /* Reduced Padding at the top */
78
- .css-1d391kg {
79
- padding-top: 0.5rem;
80
- }
81
- /* Centering the legend in the plot */
82
- .stPlotlyChart {
83
- display: block;
84
- margin: 0 auto;
85
- }
86
- /* Adjusting for full width without scrolling */
87
- .css-1lcbvhc {
88
- padding-left: 0;
89
- padding-right: 0;
90
- }
91
- </style>
92
- """,
93
- unsafe_allow_html=True,
94
- )
95
-
96
- # Page Layout
97
- st.title("🌟 Gradient Descent Interactive Tool 🌟")
98
-
99
- col1, col2 = st.columns([1, 2])
100
-
101
- # Left Section: User Input
102
- with col1:
103
- st.subheader("🔧 Define Your Function")
104
-
105
- # Display the predefined function buttons
106
- st.write("Or choose a predefined function:")
107
- cols = st.columns(4)
108
- for i, (btn_label, func_value) in enumerate(predefined_functions.items()):
109
- with cols[i]:
110
- if st.button(btn_label):
111
- st.session_state.func_input = func_value
112
- reset_state() # Ensure that the reset function is called to update the state
113
- st.experimental_rerun() # Re-run to ensure everything updates properly
114
-
115
- # Use text input for the user to define a function
116
- func_input = st.text_input(
117
- "Enter a function of 'x':",
118
- st.session_state.func_input,
119
- key="func_input",
120
- on_change=reset_state
121
- )
122
-
123
- st.subheader("⚙️ Gradient Descent Parameters")
124
- starting_point = st.number_input(
125
- "Starting Point (X₀)",
126
- value=4.0,
127
- step=0.1,
128
- format="%.2f",
129
- key="starting_point",
130
- on_change=reset_state
131
- )
132
- learning_rate = st.number_input(
133
- "Learning Rate (ŋ)",
134
- value=0.25,
135
- step=0.01,
136
- format="%.2f",
137
- key="learning_rate",
138
- on_change=reset_state
139
- )
140
-
141
- col3, col4 = st.columns(2)
142
- with col3:
143
- if st.button("🔄 Set Up Function"):
144
- reset_state()
145
- with col4:
146
- if st.button("▶️ Next Iteration"):
147
- try:
148
- grad = derivative(st.session_state.func_input, st.session_state.x)
149
- st.session_state.x = st.session_state.x - learning_rate * grad
150
- st.session_state.iteration += 1
151
- st.session_state.x_vals.append(st.session_state.x)
152
- st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
153
- except Exception as e:
154
- st.error(f"⚠️ Error: {str(e)}")
155
-
156
- # Right Section: Visualization
157
- with col2:
158
- st.subheader("📊 Gradient Descent Visualization")
159
- try:
160
- # Plot the function and all current and previous gradient descent points
161
- x_plot = np.linspace(-10, 10, 400)
162
- y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
163
-
164
- fig = go.Figure()
165
-
166
- # Function curve
167
- fig.add_trace(go.Scatter(
168
- x=x_plot,
169
- y=y_plot,
170
- mode="lines+markers",
171
- line=dict(color="blue", width=2),
172
- marker=dict(size=4, color="blue", symbol="circle"),
173
- name="Function"
174
- ))
175
-
176
- # All gradient descent points (red points without coordinates)
177
- fig.add_trace(go.Scatter(
178
- x=st.session_state.x_vals,
179
- y=st.session_state.y_vals,
180
- mode="markers",
181
- marker=dict(color="red", size=10),
182
- name="Gradient Descent Points"
183
- ))
184
-
185
- # Tangent line at the current gradient descent point
186
- current_x = st.session_state.x
187
- tangent_x = np.linspace(current_x - 5, current_x + 5, 200) # Extended range for tangent line
188
- tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
189
- fig.add_trace(go.Scatter(
190
- x=tangent_x,
191
- y=tangent_y,
192
- mode="lines",
193
- line=dict(color="orange", width=3),
194
- name="Tangent Line"
195
- ))
196
-
197
- # Dynamic zoom for better visibility
198
- fig.update_layout(
199
- xaxis=dict(
200
- title="x-axis",
201
- range=[-10, 10],
202
- showline=True,
203
- linecolor="white",
204
- tickcolor="white",
205
- tickfont=dict(color="white"),
206
- ticks="outside",
207
- ),
208
- yaxis=dict(
209
- title="y-axis",
210
- range=[min(y_plot) - 5, max(y_plot) + 5],
211
- showline=True,
212
- linecolor="white",
213
- tickcolor="white",
214
- tickfont=dict(color="white"),
215
- ticks="outside",
216
- ),
217
- plot_bgcolor="black",
218
- paper_bgcolor="black",
219
- title="",
220
- margin=dict(l=10, r=10, t=10, b=10),
221
- width=800,
222
- height=400,
223
- showlegend=True,
224
- legend=dict(
225
- x=1.1,
226
- y=0.5,
227
- xanchor="left",
228
- yanchor="middle",
229
- orientation="v",
230
- font=dict(size=12, color="white"),
231
- bgcolor="black",
232
- bordercolor="white",
233
- borderwidth=2,
234
- )
235
- )
236
-
237
- # Axis lines for quadrants
238
- fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="white", width=2)) # x-axis
239
- fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="white", width=2)) # y-axis
240
-
241
- st.plotly_chart(fig, use_container_width=True)
242
-
243
- except Exception as e:
244
- st.error(f"⚠️ Error in visualization: {str(e)}")
245
-
246
- # Iteration stats and download
247
- col5, col6 = st.columns(2)
248
- col5.info(f"🧑‍💻 Iteration: {st.session_state.iteration}")
249
- col6.success(f"✅ Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")
 
1
+ # Use text input for the user to define a function
2
+ func_input = st.text_input(
3
+ "Enter a function of 'x':",
4
+ st.session_state.func_input,
5
+ key="func_input",
6
+ on_change=reset_state
7
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Predefined function buttons
10
  predefined_functions = {
 
14
  "log(x)": "np.log(x)"
15
  }
16
 
17
+ # Display the predefined function buttons
18
+ st.write("Or choose a predefined function:")
19
+ cols = st.columns(4)
20
+ for i, (btn_label, func_value) in enumerate(predefined_functions.items()):
21
+ with cols[i]:
22
+ if st.button(btn_label):
23
+ st.session_state.func_input = func_value
24
+ reset_state() # Ensure that the reset function is called to update the state
25
+ st.rerun() # Re-run to ensure everything updates properly