trohith89 commited on
Commit
925f165
Β·
verified Β·
1 Parent(s): c3ffa81

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -4
app.py CHANGED
@@ -1,6 +1,186 @@
1
- # Gradient Descent Visualization with Updated Axes and Color Scheme
 
 
2
 
3
- # Visualization Section in Streamlit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  with col2:
5
  st.subheader("πŸ“Š Gradient Descent Visualization")
6
  try:
@@ -57,13 +237,12 @@ with col2:
57
  zerolinewidth=2,
58
  showgrid=True,
59
  gridcolor="lightgray",
60
- range=[min(y_plot) - 10, max(y_plot) + 10], # Adjust to show negative y-axis
61
  color="white"
62
  ),
63
  plot_bgcolor="black",
64
  paper_bgcolor="black",
65
  font=dict(color="white"),
66
- title="Gradient Descent Visualization",
67
  width=800,
68
  height=400,
69
  showlegend=True
@@ -73,3 +252,7 @@ with col2:
73
 
74
  except Exception as e:
75
  st.error(f"⚠️ Error in visualization: {str(e)}")
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
 
5
+ # Safe function evaluation
6
+ def safe_eval(func_str, x_val):
7
+ """ Safely evaluates the function at a given x value. """
8
+ allowed_names = {"x": x_val, "np": np}
9
+ try:
10
+ return eval(func_str, {"__builtins__": None}, allowed_names)
11
+ except Exception as e:
12
+ raise ValueError(f"Error evaluating the function: {e}")
13
+
14
+ # Function derivative using finite difference method
15
+ def derivative(func_str, x_val, h=1e-5):
16
+ """ Numerically compute the derivative of the function at x using finite differences. """
17
+ return (safe_eval(func_str, x_val + h) - safe_eval(func_str, x_val - h)) / (2 * h)
18
+
19
+ # Tangent line equation
20
+ def tangent_line(func_str, x_val, x_range):
21
+ """ Compute the tangent line at a given x value. """
22
+ y_val = safe_eval(func_str, x_val)
23
+ slope = derivative(func_str, x_val)
24
+ return slope * (x_range - x_val) + y_val
25
+
26
+ # Callback to reset session state
27
+ def reset_state():
28
+ st.session_state.x = st.session_state.starting_point
29
+ st.session_state.iteration = 0
30
+ st.session_state.x_vals = [st.session_state.starting_point]
31
+ st.session_state.y_vals = [safe_eval(st.session_state.func_input, st.session_state.starting_point)]
32
+
33
+ # Initialize session state variables
34
+ if "func_input" not in st.session_state:
35
+ st.session_state.func_input = "x**2 + x"
36
+ if "x" not in st.session_state:
37
+ st.session_state.x = 4.0
38
+ st.session_state.iteration = 0
39
+ st.session_state.x_vals = [4.0]
40
+ st.session_state.y_vals = [safe_eval(st.session_state.func_input, 4.0)]
41
+
42
+ # Full-width layout
43
+ st.set_page_config(layout="wide")
44
+
45
+ # CSS Styles for Borders, Font, Reduced Padding, and Custom Border Color
46
+ st.markdown(
47
+ """
48
+ <style>
49
+ * {
50
+ font-family: Cambria, Arial, sans-serif !important;
51
+ }
52
+ h1, h2, h3, h4, h5 {
53
+ text-align: center;
54
+ margin-top: 0;
55
+ }
56
+ input, .stButton button, .stDownloadButton button {
57
+ border: 2px solid #ea445a;
58
+ border-radius: 5px;
59
+ padding: 10px;
60
+ }
61
+ .stInfo, .stSuccess {
62
+ border: 2px solid #ea445a;
63
+ border-radius: 5px;
64
+ padding: 10px;
65
+ }
66
+ .stButton {
67
+ margin-top: 10px;
68
+ }
69
+ /* Reduced Padding at the top */
70
+ .css-1d391kg {
71
+ padding-top: 0.5rem;
72
+ }
73
+ /* Centering the legend in the plot */
74
+ .stPlotlyChart {
75
+ display: block;
76
+ margin: 0 auto;
77
+ }
78
+ /* Adjusting for full width without scrolling */
79
+ .css-1lcbvhc {
80
+ padding-left: 0;
81
+ padding-right: 0;
82
+ }
83
+ /* Custom borders for input fields */
84
+ .stTextInput input, .stNumberInput input {
85
+ border: 2px solid #001A6E;
86
+ border-radius: 5px;
87
+ padding: 10px;
88
+ }
89
+ /* Tooltip styling */
90
+ .tooltip {
91
+ position: relative;
92
+ display: inline-block;
93
+ cursor: pointer;
94
+ }
95
+ .tooltip .tooltiptext {
96
+ visibility: hidden;
97
+ opacity: 0;
98
+ width: 300px;
99
+ background-color: #001A6E;
100
+ color: #fff;
101
+ text-align: center;
102
+ border-radius: 5px;
103
+ padding: 5px;
104
+ position: absolute;
105
+ z-index: 1;
106
+ bottom: 125%; /* Position the tooltip above */
107
+ left: 50%;
108
+ margin-left: -150px;
109
+ transition: opacity 0.3s;
110
+ }
111
+ .tooltip:hover .tooltiptext {
112
+ visibility: visible;
113
+ opacity: 1;
114
+ }
115
+ </style>
116
+ """,
117
+ unsafe_allow_html=True,
118
+ )
119
+
120
+ # Page Layout
121
+ st.title("🌟 Gradient Descent Interactive Tool 🌟")
122
+
123
+ col1, col2 = st.columns([1, 2])
124
+
125
+ # Left Section: User Input
126
+ with col1:
127
+ st.subheader("πŸ”§ Define Your Function")
128
+
129
+ st.markdown(
130
+ """
131
+ <div class="tooltip">
132
+ <label for="func_input">Enter a function of 'x':</label>
133
+ <span class="tooltiptext">
134
+ **How to input your function:**
135
+ - x^n as x**n,
136
+ - sin(x) as np.sin(x),
137
+ - log(x) as np.log(x),
138
+ - e^x or exp(x) as np.exp(x).
139
+ </span>
140
+ </div>
141
+ """,
142
+ unsafe_allow_html=True
143
+ )
144
+
145
+ func_input = st.text_input(
146
+ "πŸ‘‡",
147
+ key="func_input",
148
+ on_change=reset_state
149
+ )
150
+
151
+ st.subheader("βš™οΈ Gradient Descent Parameters")
152
+ starting_point = st.number_input(
153
+ "Starting Point (Xβ‚€)",
154
+ value=4.0,
155
+ step=0.1,
156
+ format="%.2f",
157
+ key="starting_point",
158
+ on_change=reset_state
159
+ )
160
+ learning_rate = st.number_input(
161
+ "Learning Rate (Ε‹)",
162
+ value=0.25,
163
+ step=0.01,
164
+ format="%.2f",
165
+ key="learning_rate"
166
+ )
167
+
168
+ col3, col4 = st.columns(2)
169
+ with col3:
170
+ if st.button("πŸ”„ Set Up Function"):
171
+ reset_state()
172
+ with col4:
173
+ if st.button("▢️ Next Iteration"):
174
+ try:
175
+ grad = derivative(st.session_state.func_input, st.session_state.x)
176
+ st.session_state.x = st.session_state.x - learning_rate * grad
177
+ st.session_state.iteration += 1
178
+ st.session_state.x_vals.append(st.session_state.x)
179
+ st.session_state.y_vals.append(safe_eval(st.session_state.func_input, st.session_state.x))
180
+ except Exception as e:
181
+ st.error(f"⚠️ Error: {str(e)}")
182
+
183
+ # Right Section: Visualization
184
  with col2:
185
  st.subheader("πŸ“Š Gradient Descent Visualization")
186
  try:
 
237
  zerolinewidth=2,
238
  showgrid=True,
239
  gridcolor="lightgray",
240
+ range=[min(y_plot) - 10, max(y_plot) + 10],
241
  color="white"
242
  ),
243
  plot_bgcolor="black",
244
  paper_bgcolor="black",
245
  font=dict(color="white"),
 
246
  width=800,
247
  height=400,
248
  showlegend=True
 
252
 
253
  except Exception as e:
254
  st.error(f"⚠️ Error in visualization: {str(e)}")
255
+
256
+ col5, col6 = st.columns(2)
257
+ col5.info(f"πŸ§‘β€πŸ’» Iteration: {st.session_state.iteration}")
258
+ col6.success(f"βœ… Current x: {st.session_state.x:.4f}, Current f(x): {st.session_state.y_vals[-1]:.4f}")