trohith89 commited on
Commit
aee509d
·
verified ·
1 Parent(s): d130e9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -293
app.py CHANGED
@@ -1,293 +1,170 @@
1
- # Import necessary libraries
2
- import streamlit as st
3
- import numpy as np
4
- import plotly.graph_objects as go
5
- import math
6
-
7
- # Full-page layout
8
- st.set_page_config(layout="wide", page_title="Gradient Descent Visualizer")
9
-
10
- # Main Title
11
- st.title("")
12
- st.title("Gradient Descent Visualizer")
13
-
14
- # CSS for full-page layout and styling (no scrollbars)
15
- st.markdown("""
16
- <style>
17
- body {
18
- font-family: 'serif'; /* Serif font for a mathematical feel */
19
- background-color: #161748; /* Dark background */
20
- color: white;
21
- width:100%:
22
- height:100%;
23
- overflow: hidden; /* Hide scrollbars */
24
- }
25
- .block-container {
26
- padding: 1rem; /* Padding for page container */
27
- margin: 0; /* Remove margin */
28
- max-width: 100%; /* Full page width */
29
- }
30
- .stButton>button {
31
- background-color: #000000;
32
- color: #ff5e6c;
33
- border-radius: 8px;
34
- border: 2px solid #dbb6ee;
35
- }
36
- .stTextInput>div>div>input {
37
- color: white;
38
- background-color: #161748;
39
- # border: 2px solid #dbb6ee;
40
- border-radius: 8px;
41
- }
42
- .stNumberInput>div>div>input {
43
- color: white;
44
- background-color: #161748;
45
- border: 2px solid #dbb6ee;
46
- border-radius: 8px;
47
- }
48
- .stPlotlyChart {
49
- border: 2px solid #dbb6ee;
50
- border-radius: 15px;
51
- margin: 0;
52
- padding: 0;
53
- }
54
- .iteration-info {
55
- color: black;
56
- font-size: 18px;
57
- font-weight: bold;
58
- background-color: #39a0ca;
59
- padding: 6px;
60
- border-radius: 8px;
61
- display: inline-block;
62
- }
63
- </style>
64
- """, unsafe_allow_html=True)
65
-
66
- # Divide the layout into two columns
67
- left_col, right_col = st.columns(2)
68
-
69
- # Left column for inputs and buttons
70
- with left_col:
71
- st.markdown("<div class='component-container'></div>", unsafe_allow_html=True) # Border for input section
72
- st.markdown("## Function")
73
-
74
- if 'text_input_value' not in st.session_state:
75
- st.session_state.text_input_value = "x**2 + 3*x + 5"
76
-
77
- # Function buttons
78
- st.write("Functions you should try (click to auto format):")
79
- col1, col2, col3, col4, col5 = st.columns(5)
80
- with col1:
81
- if st.button("x^2", key="x2"):
82
- st.session_state.text_input_value = "x**2"
83
- with col2:
84
- if st.button("x^3", key="x3"):
85
- st.session_state.text_input_value = "x**3"
86
- with col3:
87
- if st.button("sin(x)", key="sinx"):
88
- st.session_state.text_input_value = "math.sin(x)"
89
- with col4:
90
- if st.button("sin(1/x)", key="sin1x"):
91
- st.session_state.text_input_value = "math.sin(1/x)"
92
- with col5:
93
- if st.button("log(x)", key="logx"):
94
- st.session_state.text_input_value = "math.log(x)"
95
-
96
- # Custom function input
97
- st.text_input("## Enter a function of your choice :", value=st.session_state.text_input_value, key="text_input")
98
-
99
- # Starting point input
100
- start_point = st.number_input("## Start point :", value=2)
101
-
102
- # Learning rate input
103
- learn_rate = st.number_input("## Learning Rate (η) :", value=0.25)
104
-
105
- # Setup button
106
- if st.button("Set Up"):
107
- st.session_state.iteration = 0
108
- st.session_state.theta_history = [start_point]
109
- st.session_state.current_fn = st.session_state.text_input_value
110
- st.write("Setup complete! Click 'Next Iteration' to start.")
111
-
112
- # Gradient descent function with error handling
113
- def gradient_descent(fn, start_point, learning_rate, num_iterations):
114
- theta = start_point
115
- theta_history = [theta]
116
-
117
- # Define function gradients manually
118
- def get_gradient(fn, x):
119
- epsilon = 1e-6
120
- try:
121
- if "x**2" in fn:
122
- return 2 * x # derivative of x^2
123
- elif "x**3" in fn:
124
- return 3 * x**2 # derivative of x^3
125
- elif "sin(x)" in fn:
126
- return math.cos(x) # derivative of sin(x)
127
- elif "sin(1/x)" in fn:
128
- return -math.cos(1/x) / (x**2) # derivative of sin(1/x)
129
- elif "log(x)" in fn:
130
- return 1 / x # derivative of log(x)
131
- else:
132
- return 0 # default to 0 if function is unsupported
133
- except:
134
- return 0 # Handle undefined behavior
135
-
136
- for _ in range(num_iterations):
137
- gradient = get_gradient(fn, theta)
138
- theta = theta - learning_rate * gradient
139
- if abs(theta) > 1e10:
140
- theta = np.sign(theta) * 1e10
141
- theta_history.append(theta)
142
-
143
- return theta_history
144
-
145
- def plot(fn, theta_history, iteration):
146
- # Convert history to float values
147
- theta_history = [float(theta) for theta in theta_history]
148
- if not theta_history:
149
- st.write("No iterations yet. Please click 'Next Iteration'.")
150
- return
151
-
152
- x = np.linspace(-10, 10, 100)
153
- y = []
154
-
155
- # Handle edge cases for invalid function evaluations
156
- for i in x:
157
- try:
158
- if "x**2" in fn:
159
- y.append(i**2)
160
- elif "x**3" in fn:
161
- y.append(i**3)
162
- elif "sin(x)" in fn:
163
- y.append(math.sin(i))
164
- elif "sin(1/x)" in fn:
165
- if i != 0:
166
- y.append(math.sin(1/i))
167
- else:
168
- y.append(np.nan)
169
- elif "log(x)" in fn:
170
- if i > 0:
171
- y.append(math.log(i))
172
- else:
173
- y.append(np.nan)
174
- else:
175
- y.append(np.nan)
176
- except:
177
- y.append(np.nan)
178
-
179
- # Remove NaN values from x and y
180
- x_valid = x[~np.isnan(y)]
181
- y_valid = np.array(y)[~np.isnan(y)]
182
-
183
- last_theta = theta_history[-1]
184
- meeting_y = None
185
- try:
186
- meeting_y = eval(fn.replace('x', str(last_theta))) if 'x' in fn else 0
187
- except:
188
- pass
189
-
190
- # Numerical derivative using central difference
191
- epsilon = 1e-6
192
- try:
193
- derivative = (eval(fn.replace('x', str(last_theta + epsilon))) - eval(fn.replace('x', str(last_theta - epsilon)))) / (2 * epsilon)
194
- except:
195
- derivative = 0
196
- slope = derivative
197
- intercept = meeting_y - slope * last_theta if meeting_y is not None else 0
198
- tangent_y = slope * x_valid + intercept
199
-
200
- fig = go.Figure(data=[
201
- # Function Line
202
- go.Scatter(x=x_valid, y=y_valid, mode='lines', name='Function',
203
- line=dict(color='blue')),
204
- # Gradient Descent Points
205
- go.Scatter(x=theta_history,
206
- y=[eval(fn.replace('x', str(theta))) for theta in theta_history],
207
- mode='markers', name='Gradient Descent',
208
- marker=dict(color='red', size=10)), # All points are red
209
- # Tangent Line
210
- go.Scatter(x=x_valid, y=tangent_y, mode='lines', name='Tangent',
211
- line=dict(color='orange')),
212
- # Tangent Point (Red)
213
- go.Scatter(x=[last_theta], y=[meeting_y], mode='markers', name='Tangent Point',
214
- marker=dict(color='red', size=12))
215
- ])
216
-
217
- # Update layout for styling
218
- fig.update_layout(
219
- annotations=[
220
- dict(
221
- xref='paper', yref='paper', x=0.05, y=0.1,
222
- xanchor='left', yanchor='bottom',
223
- text=f"<b>Next Iteration: {iteration}</b>",
224
- showarrow=False,
225
- font=dict(size=20, color='black'),
226
- bgcolor="#f95d9b", borderpad=5, bordercolor="black", borderwidth=2
227
- ),
228
- dict(
229
- xref='paper', yref='paper', x=1, y=0,
230
- xanchor='right', yanchor='bottom',
231
- text=f"Current Point: ({last_theta:.6f}, {meeting_y if meeting_y is not None else 'N/A'})",
232
- showarrow=False,
233
- font=dict(size=14, color='black'),
234
- bgcolor="#39a0ca", borderpad=5, bordercolor="black", borderwidth=2
235
- )
236
- ],
237
- xaxis_title='x-axis',
238
- yaxis_title='y-axis',
239
- hovermode='x unified',
240
- xaxis=dict(
241
- range=[-10, 10],
242
- showgrid=True, gridcolor='black',
243
- titlefont=dict(color='black'),
244
- tickfont=dict(color='black') # Make x-axis numbers black
245
- ),
246
- yaxis=dict(
247
- range=[-10, 10],
248
- showgrid=True, gridcolor='black',
249
- titlefont=dict(color='black'),
250
- tickfont=dict(color='black') # Make y-axis numbers black
251
- ),
252
- paper_bgcolor='white', # White background
253
- plot_bgcolor='white', # White plot background
254
- legend=dict(
255
- yanchor='top', xanchor='right', x=1, y=0.99,
256
- font=dict(color='black')
257
- ),
258
- title="Gradient Descent Visualization", titlefont=dict(color='black')
259
- )
260
-
261
- # Display the plot
262
- st.plotly_chart(fig, use_container_width=True, config={'displayModeBar': False})
263
-
264
- return last_theta, meeting_y
265
-
266
- def main():
267
- with right_col:
268
- if 'iteration' not in st.session_state:
269
- st.session_state.iteration = 0
270
- st.session_state.theta_history = [start_point]
271
- st.session_state.current_fn = st.session_state.text_input_value
272
-
273
- theta_history = st.session_state.theta_history
274
- iteration = st.session_state.iteration
275
- current_fn = st.session_state.current_fn
276
-
277
- if st.button("Next Iteration", key="next_iter"):
278
- iteration += 1
279
- theta_history = gradient_descent(current_fn, start_point, learn_rate, iteration)
280
- st.session_state.iteration = iteration
281
- st.session_state.theta_history = theta_history
282
-
283
- # Plot the function and gradient descent
284
- last_theta, meeting_y = plot(current_fn, theta_history, iteration)
285
-
286
- # Display iteration and point details
287
- st.markdown(f"## Iteration: {int(iteration)}")
288
- st.markdown(f"The tangent is meeting the plot at point **({last_theta}, {meeting_y if meeting_y is not None else 'N/A'})**")
289
-
290
- # Run the app
291
- if __name__ == "__main__":
292
- main()
293
-
 
1
+ Rohith Ramdass
2
+ rohith.r18
3
+ Idle
4
+
5
+ Kande Chandrika_IN1241221 — Today at 14:47
6
+ hi
7
+ Rohith Ramdass — Today at 14:48
8
+ hi
9
+ Kande Chandrika_IN1241221 — Today at 14:48
10
+ wr r u
11
+ Rohith Ramdass — Today at 14:48
12
+ i can se u
13
+ Kande Chandrika_IN1241221 — Today at 14:48
14
+ can u send ur space with code also
15
+ Rohith Ramdass — Today at 14:48
16
+ there u have place?
17
+ Kande Chandrika_IN1241221 — Today at 14:48
18
+ haa
19
+ Rohith Ramdass Today at 14:49
20
+ cmg
21
+ Kande Chandrika_IN1241221 — Today at 14:49
22
+ ok
23
+ Rohith Ramdass Today at 15:18
24
+ import cv2
25
+ import numpy as np
26
+
27
+ # Create a window and set its callback function
28
+ cv2.namedWindow('Painting')
29
+ cv2.setMouseCallback('Painting', lambda event, x, y, flags, param: mouse_event(event, x, y, flags, param))
30
+ Expand
31
+ message.txt
32
+ 4 KB
33
+ Rohith Ramdass — Today at 15:33
34
+ import streamlit as st
35
+
36
+ # Title
37
+ st.title("Machine Learning Project")
38
+
39
+ # header
40
+ Expand
41
+ app1.py
42
+ 4 KB
43
+ Kande Chandrika_IN1241221 — Today at 20:29
44
+ import streamlit as st
45
+ import numpy as np
46
+ import plotly.graph_objects as go
47
+
48
+ # Title of the app
49
+ st.title("Gradient Descent Visualizer with Tangent Lines")
50
+ Expand
51
+ message.txt
52
+ 5 KB
53
+ Rohith Ramdass — Today at 20:43
54
+ Thanks
55
+ 
56
+ Kande Chandrika_IN1241221
57
+ k.chandrika.
58
+ import streamlit as st
59
+ import numpy as np
60
+ import plotly.graph_objects as go
61
+
62
+ # Title of the app
63
+ st.title("Gradient Descent Visualizer with Tangent Lines")
64
+
65
+ # Safe function evaluation
66
+ def safe_eval(func_str, x_val):
67
+ """ Safely evaluates the function at a given x value. """
68
+ allowed_names = {"x": x_val, "np": np} # Only allow x and numpy
69
+ return eval(func_str, {"__builtins__": None}, allowed_names)
70
+
71
+ # Function derivative using finite difference method
72
+ def derivative(func_str, x_val, h=1e-5):
73
+ """ Numerically compute the derivative of the function at x using finite differences. """
74
+ return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
75
+
76
+ # Tangent Line Equation
77
+ def tangent_line(func_str, x_val, x_range):
78
+ """ Compute the tangent line at a given x value. """
79
+ y_val = safe_eval(func_str, x_val)
80
+ slope = derivative(func_str, x_val)
81
+ return slope * (x_range - x_val) + y_val
82
+
83
+ # Callback to reset session state
84
+ def reset_state():
85
+ st.session_state.x = st.session_state.starting_point
86
+ st.session_state.iteration = 0
87
+ st.session_state.x_vals = [st.session_state.starting_point]
88
+ st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
89
+
90
+ # Function input
91
+ st.header("Define Your Function")
92
+ func_input = st.text_input("Enter a function of 'x' (e.g., x**2 + x, sin(x), x**3 - 3*x + 2):", "x**2 + x", key="func_input", on_change=reset_state)
93
+
94
+ # Starting Point and Learning Rate
95
+ st.header("Gradient Descent Parameters")
96
+ starting_point = st.number_input("Starting Point", value=4.0, step=0.1, format="%.2f", key="starting_point", on_change=reset_state)
97
+ learning_rate = st.number_input("Learning Rate", value=0.1, step=0.01, format="%.2f", key="learning_rate", on_change=reset_state)
98
+
99
+ # Initialize session state variables if they don't exist
100
+ if "x" not in st.session_state:
101
+ st.session_state.x = starting_point
102
+ st.session_state.iteration = 0
103
+ st.session_state.x_vals = [starting_point]
104
+ st.session_state.y_vals = [safe_eval(func_input, starting_point)]
105
+
106
+ # "Next Iteration" button logic
107
+ if st.button("Next Iteration"):
108
+ try:
109
+ # Perform one iteration of gradient descent
110
+ grad = derivative(func_input, st.session_state.x)
111
+ st.session_state.x = st.session_state.x - learning_rate * grad
112
+ st.session_state.iteration += 1
113
+
114
+ # Save the new x and y values
115
+ st.session_state.x_vals.append(st.session_state.x)
116
+ st.session_state.y_vals.append(safe_eval(func_input, st.session_state.x))
117
+ except Exception as e:
118
+ st.error(f"Error: {str(e)}")
119
+
120
+ # Display iteration results
121
+ st.subheader("Gradient Descent Progress")
122
+ st.write(f"Iteration: {st.session_state.iteration}")
123
+ st.write(f"Current x: {st.session_state.x:.4f}")
124
+ st.write(f"Current f(x): {st.session_state.y_vals[-1]:.4f}")
125
+
126
+ # Plot the function, gradient descent points, and tangent line
127
+ x_plot = np.linspace(-10, 10, 400)
128
+ y_plot = [safe_eval(func_input, x) for x in x_plot]
129
+
130
+ fig = go.Figure()
131
+
132
+ # Add the function curve
133
+ fig.add_trace(go.Scatter(x=x_plot, y=y_plot, mode="lines", name="Function"))
134
+
135
+ # Add gradient descent points in red
136
+ fig.add_trace(go.Scatter(
137
+ x=st.session_state.x_vals,
138
+ y=st.session_state.y_vals,
139
+ mode="markers",
140
+ marker=dict(color="red", size=8),
141
+ name="Gradient Descent Points"
142
+ ))
143
+
144
+ # Add the tangent line at the current point
145
+ current_x = st.session_state.x
146
+ current_y = safe_eval(func_input, current_x)
147
+ slope = derivative(func_input, current_x)
148
+
149
+ # Generate tangent line range
150
+ tangent_x = np.linspace(current_x - 2, current_x + 2, 100)
151
+ tangent_y = tangent_line(func_input, current_x, tangent_x)
152
+
153
+ # Plot the tangent line as a straight solid line
154
+ fig.add_trace(go.Scatter(
155
+ x=tangent_x,
156
+ y=tangent_y,
157
+ mode="lines",
158
+ line=dict(color="orange", width=3),
159
+ name="Tangent Line"
160
+ ))
161
+
162
+ # Update layout
163
+ fig.update_layout(
164
+ xaxis_title="x",
165
+ yaxis_title="f(x)",
166
+ title="Gradient Descent Visualization with Tangent Line"
167
+ )
168
+
169
+ # Display the plot
170
+ st.plotly_chart(fig)