trohith89 commited on
Commit
1a9168f
·
verified ·
1 Parent(s): 2f3f580

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -16
app.py CHANGED
@@ -1,10 +1,31 @@
1
- # Use text input for the user to define a function
2
- func_input = st.text_input(
3
- "Enter a function of 'x':",
4
- st.session_state.func_input,
5
- key="func_input",
6
- on_change=reset_state
7
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  # Predefined function buttons
10
  predefined_functions = {
@@ -14,12 +35,207 @@ predefined_functions = {
14
  "log(x)": "np.log(x)"
15
  }
16
 
17
- # Display the predefined function buttons
18
- st.write("Or choose a predefined function:")
19
- cols = st.columns(4)
20
- for i, (btn_label, func_value) in enumerate(predefined_functions.items()):
21
- with cols[i]:
22
- if st.button(btn_label):
23
- st.session_state.func_input = func_value
24
- reset_state() # Ensure that the reset function is called to update the state
25
- st.rerun() # Re-run to ensure everything updates properly
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+
5
+ # Safe function evaluation
6
+ def safe_eval(func_str, x_val):
7
+ """ Safely evaluates the function at a given x value. """
8
+ allowed_names = {"x": x_val, "np": np}
9
+ return eval(func_str, {"__builtins__": None}, allowed_names)
10
+
11
+ # Function derivative using finite difference method
12
+ def derivative(func_str, x_val, h=1e-5):
13
+ """ Numerically compute the derivative of the function at x using finite differences. """
14
+ return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
15
+
16
+ # Tangent line equation
17
+ def tangent_line(func_str, x_val, x_range):
18
+ """ Compute the tangent line at a given x value. """
19
+ y_val = safe_eval(func_str, x_val)
20
+ slope = derivative(func_str, x_val)
21
+ return slope * (x_range - x_val) + y_val
22
+
23
+ # Callback to reset session state
24
+ def reset_state():
25
+ st.session_state.x = st.session_state.starting_point
26
+ st.session_state.iteration = 0
27
+ st.session_state.x_vals = [st.session_state.starting_point]
28
+ st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
29
 
30
  # Predefined function buttons
31
  predefined_functions = {
 
35
  "log(x)": "np.log(x)"
36
  }
37
 
38
+ # Initialize session state variables
39
+ if "func_input" not in st.session_state:
40
+ st.session_state.func_input = "x**2 + x"
41
+ if "x" not in st.session_state:
42
+ st.session_state.x = 4.0
43
+ st.session_state.iteration = 0
44
+ st.session_state.x_vals = [4.0]
45
+ st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
46
+
47
+ # Full-width layout
48
+ st.set_page_config(layout="wide")
49
+
50
+ # CSS Styles for Premium Design
51
+ st.markdown(
52
+ """
53
+ <style>
54
+ /* Header Styling */
55
+ h1, h2, h3, h4, h5 {
56
+ font-family: 'Arial', sans-serif;
57
+ color: #333333;
58
+ text-align: center;
59
+ }
60
+ /* Button Styling */
61
+ .stButton button {
62
+ background-color: #ffffff;
63
+ color: #FF00FF; /* Pink text */
64
+ border-radius: 8px;
65
+ border: 2px solid #000000; /* Black border */
66
+ padding: 10px 20px;
67
+ font-size: 16px;
68
+ font-weight: bold;
69
+ transition: 0.3s;
70
+ width: 100%;
71
+ height: 50px;
72
+ box-sizing: border-box;
73
+ display: flex;
74
+ align-items: center;
75
+ justify-content: center;
76
+ white-space: nowrap; /* Ensures text stays in one line */
77
+ }
78
+ .stButton button:hover {
79
+ background-color: #000000; /* Black background on hover */
80
+ color: #FF00FF; /* Pink text on hover */
81
+ transform: scale(1.05);
82
+ }
83
+ /* Input Box Styling */
84
+ input, select {
85
+ border: 2px solid #ccc;
86
+ padding: 8px;
87
+ border-radius: 5px;
88
+ font-size: 16px;
89
+ font-family: 'Arial', sans-serif;
90
+ }
91
+ /* Tooltip Styling */
92
+ div[role="tooltip"] {
93
+ background-color: #0078D7;
94
+ color: white;
95
+ border-radius: 8px;
96
+ padding: 10px;
97
+ }
98
+ /* Columns for buttons */
99
+ .stButton {
100
+ display: flex;
101
+ flex-direction: column;
102
+ gap: 15px;
103
+ }
104
+ </style>
105
+ """,
106
+ unsafe_allow_html=True,
107
+ )
108
+
109
+ # Page Layout
110
+ with st.container():
111
+ st.header("Gradient Descent Interactive Tool")
112
+
113
+ col1, col2 = st.columns([1, 2])
114
+
115
+ # Left Section: User Input
116
+ with col1:
117
+ st.subheader("Define Your Function")
118
+ func_input = st.text_input("Enter a function of 'x':", st.session_state.func_input, key="func_input", on_change=reset_state)
119
+
120
+ st.write("Or choose a predefined function:")
121
+ cols = st.columns(4) # Adjusted to 4 columns to increase button width
122
+ for i, (btn_label, func_value) in enumerate(predefined_functions.items()):
123
+ with cols[i]:
124
+ if st.button(btn_label):
125
+ st.session_state.func_input = func_value
126
+ reset_state()
127
+ st.rerun() # Re-run to ensure everything updates properly
128
+
129
+ st.subheader("Gradient Descent Parameters")
130
+ starting_point = st.number_input("Starting Point (Xₒ)", value=4.0, step=0.1, format="%.2f", key="starting_point", on_change=reset_state)
131
+ learning_rate = st.number_input("Learning Rate (ŋ)", value=0.25, step=0.01, format="%.2f", key="learning_rate", on_change=reset_state)
132
+
133
+ col3, col4 = st.columns(2)
134
+ with col3:
135
+ if st.button("Set Up Function"):
136
+ reset_state()
137
+ with col4:
138
+ if st.button("Next Iteration"):
139
+ try:
140
+ grad = derivative(st.session_state.func_input, st.session_state.x)
141
+ st.session_state.x = st.session_state.x - learning_rate * grad
142
+ st.session_state.iteration += 1
143
+ st.session_state.x_vals.append(st.session_state.x)
144
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
145
+ except Exception as e:
146
+ st.error(f"Error: {str(e)}")
147
+
148
+ # Right Section: Visualization
149
+ with col2:
150
+ st.subheader("Gradient Descent Visualization")
151
+
152
+ # Plot the function, gradient descent points, and tangent line
153
+ x_plot = np.linspace(-10, 10, 400)
154
+ y_plot = [safe_eval(st.session_state.func_input, x) for x in x_plot]
155
+
156
+ fig = go.Figure()
157
+
158
+ # Function curve
159
+ fig.add_trace(go.Scatter(
160
+ x=x_plot,
161
+ y=y_plot,
162
+ mode="lines",
163
+ line=dict(color="blue", width=2),
164
+ name="Function"
165
+ ))
166
+
167
+ # Gradient descent point (only show the current point)
168
+ fig.add_trace(go.Scatter(
169
+ x=[st.session_state.x],
170
+ y=[st.session_state.y_vals[-1]],
171
+ mode="markers",
172
+ marker=dict(color="red", size=8),
173
+ name="Current Gradient Descent Point",
174
+ hovertemplate="x: %{x:.2f}<br>y: %{y:.2f}<extra></extra>"
175
+ ))
176
+
177
+ # Tangent line at current point
178
+ current_x = st.session_state.x
179
+ tangent_x = np.linspace(current_x - 5, current_x + 5, 200) # Extended range
180
+ tangent_y = tangent_line(st.session_state.func_input, current_x, tangent_x)
181
+
182
+ fig.add_trace(go.Scatter(
183
+ x=tangent_x,
184
+ y=tangent_y,
185
+ mode="lines",
186
+ line=dict(color="orange", width=3),
187
+ name="Tangent Line"
188
+ ))
189
+
190
+ # Add axis lines
191
+ fig.add_shape(type="line", x0=-10, x1=10, y0=0, y1=0, line=dict(color="black", width=2)) # x-axis
192
+ fig.add_shape(type="line", x0=0, x1=0, y0=-100, y1=100, line=dict(color="black", width=2)) # y-axis
193
+
194
+ # Update layout
195
+ fig.update_layout(
196
+ plot_bgcolor="black",
197
+ xaxis=dict(
198
+ title="x",
199
+ titlefont=dict(color="white"),
200
+ tickfont=dict(color="white"),
201
+ ticks="outside",
202
+ tickcolor="white",
203
+ range=[-10, 10],
204
+ showline=True,
205
+ linecolor="white",
206
+ mirror=True
207
+ ),
208
+ yaxis=dict(
209
+ title="y",
210
+ titlefont=dict(color="white"),
211
+ tickfont=dict(color="white"),
212
+ ticks="outside",
213
+ tickcolor="white",
214
+ range=[-100, 100],
215
+ showline=True,
216
+ linecolor="white",
217
+ mirror=True
218
+ ),
219
+ title=dict(
220
+ text="Gradient Descent Visualization with Tangent Line",
221
+ font=dict(color="white")
222
+ ),
223
+ legend=dict(
224
+ font=dict(color="white"),
225
+ bgcolor="black",
226
+ bordercolor="black",
227
+ borderwidth=1,
228
+ orientation="h",
229
+ xanchor="center",
230
+ x=0.5,
231
+ yanchor="bottom",
232
+ y=1.1
233
+ )
234
+ )
235
+
236
+ st.plotly_chart(fig, use_container_width=True)
237
+
238
+ # Iteration stats
239
+ st.subheader("Iteration Counter")
240
+ st.info(f"Iteration: {st.session_state.iteration}")
241
+ st.success(f"Current x: {st.session_state.x:.4f}\nCurrent f(x): {st.session_state.y_vals[-1]:.4f}")