Mpavan45 commited on
Commit
62e27a7
·
verified ·
1 Parent(s): 7f37c44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -53
app.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # Title of the app
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
  st.title("🌟 Gradient Descent Visualizer")
8
  st.markdown("---") # Horizontal separator for cleaner layout
@@ -47,7 +47,7 @@ if "current_index" not in st.session_state:
47
  if "learning_rate" not in st.session_state:
48
  st.session_state.learning_rate = 0.1
49
 
50
- # Create two-column grid layout for the left side (more space for the right graph)
51
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
52
 
53
  # Left side content (Function Input and Gradient Descent Parameters)
@@ -65,17 +65,14 @@ with left_col:
65
  st.number_input(
66
  "Learning Rate", value=st.session_state.learning_rate, step=0.01, format="%.2f",
67
  key="learning_rate"
68
- ) # Updates session state directly without reset
69
  st.markdown("---")
70
 
71
- # Reset Button
72
  if st.button("🔄 Reset"):
73
  reset_session_state()
74
-
75
- # Right side content (Gradient Descent Updates and Progress)
76
- with right_col:
77
- st.header("Gradient Descent Updates")
78
- if st.button("🔄 Run Descent Step", type="primary"):
79
  try:
80
  gradient = compute_derivative(function_input, st.session_state.x_current)
81
  st.session_state.x_current -= st.session_state.learning_rate * gradient
@@ -86,8 +83,8 @@ with right_col:
86
  st.session_state.current_index = st.session_state.iter_count
87
  except Exception as e:
88
  st.error(f"Error: {str(e)}")
89
-
90
- # Navigation Buttons
91
  col1, col2 = st.columns(2)
92
  with col1:
93
  if st.button("⬅️ Previous") and st.session_state.current_index > 0:
@@ -96,7 +93,11 @@ with right_col:
96
  if st.button("➡️ Next") and st.session_state.current_index < st.session_state.iter_count:
97
  st.session_state.current_index += 1
98
 
99
- # Display selected iteration details
 
 
 
 
100
  try:
101
  selected_x, selected_y = st.session_state.history[st.session_state.current_index]
102
  st.subheader("Iteration Details")
@@ -106,53 +107,30 @@ with right_col:
106
  st.markdown("---")
107
  except IndexError:
108
  st.warning("No iteration data available. Please run a descent step first.")
109
-
110
- # Generate plot data
111
- x_vals = np.linspace(-10, 10, 400)
112
- y_vals = [evaluate_function(function_input, x) for x in x_vals]
113
 
114
- # Create the plot
115
- plot = go.Figure()
 
116
 
117
- # Add function plot
118
- plot.add_trace(
119
- go.Scatter(x=x_vals, y=y_vals, mode="lines", line=dict(color="green", width=3), name="Function Curve")
120
- )
121
 
122
- # Add gradient descent points up to the current index
123
- x_points, y_points = zip(*st.session_state.history[:st.session_state.current_index + 1])
124
- plot.add_trace(
125
- go.Scatter(
126
- x=x_points,
127
- y=y_points,
128
- mode="markers",
129
- marker=dict(color="red", size=10, symbol="diamond"),
130
- name="Descent Steps",
131
- )
132
- )
133
 
134
- # Add tangent line at the selected point
135
- tangent_x = np.linspace(selected_x - 2, selected_x + 2, 100)
136
- tangent_y = calculate_tangent(function_input, selected_x, tangent_x)
137
- plot.add_trace(
138
- go.Scatter(
139
- x=tangent_x,
140
- y=tangent_y,
141
- mode="lines",
142
- line=dict(color="blue", width=2, dash="dash"),
143
- name="Tangent Line",
144
- )
145
- )
146
 
147
- # Update plot layout
148
- plot.update_layout(
149
- title="Interactive Gradient Descent with Tangent Visualization",
150
  xaxis_title="x",
151
  yaxis_title="f(x)",
152
- template="plotly_dark",
153
- legend=dict(bgcolor="rgba(255,255,255,0.5)", bordercolor="gray", borderwidth=1),
154
- height=500, # Reduce the graph height for better fitting
155
  )
156
 
157
- # Display the plot in the right side
158
- st.plotly_chart(plot, use_container_width=True)
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Configure the page
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
  st.title("🌟 Gradient Descent Visualizer")
8
  st.markdown("---") # Horizontal separator for cleaner layout
 
47
  if "learning_rate" not in st.session_state:
48
  st.session_state.learning_rate = 0.1
49
 
50
+ # Layout configuration
51
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
52
 
53
  # Left side content (Function Input and Gradient Descent Parameters)
 
65
  st.number_input(
66
  "Learning Rate", value=st.session_state.learning_rate, step=0.01, format="%.2f",
67
  key="learning_rate"
68
+ )
69
  st.markdown("---")
70
 
71
+ # Buttons for controlling steps
72
  if st.button("🔄 Reset"):
73
  reset_session_state()
74
+
75
+ if st.button("▶️ Run Descent Step"):
 
 
 
76
  try:
77
  gradient = compute_derivative(function_input, st.session_state.x_current)
78
  st.session_state.x_current -= st.session_state.learning_rate * gradient
 
83
  st.session_state.current_index = st.session_state.iter_count
84
  except Exception as e:
85
  st.error(f"Error: {str(e)}")
86
+
87
+ # Navigation buttons
88
  col1, col2 = st.columns(2)
89
  with col1:
90
  if st.button("⬅️ Previous") and st.session_state.current_index > 0:
 
93
  if st.button("➡️ Next") and st.session_state.current_index < st.session_state.iter_count:
94
  st.session_state.current_index += 1
95
 
96
+ # Right side content (Interactive Gradient Descent Visualization)
97
+ with right_col:
98
+ st.header("Gradient Descent Visualization")
99
+
100
+ # Display iteration details at the top of the graph
101
  try:
102
  selected_x, selected_y = st.session_state.history[st.session_state.current_index]
103
  st.subheader("Iteration Details")
 
107
  st.markdown("---")
108
  except IndexError:
109
  st.warning("No iteration data available. Please run a descent step first.")
 
 
 
 
110
 
111
+ # Prepare data for visualization
112
+ x_range = np.linspace(-10, 10, 500) # Define range for x
113
+ y_range = [evaluate_function(function_input, x) for x in x_range]
114
 
115
+ # Plot function and gradient descent steps
116
+ fig = go.Figure()
117
+ fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Function'))
 
118
 
119
+ # Add current point
120
+ x_current, y_current = st.session_state.history[st.session_state.current_index]
121
+ fig.add_trace(go.Scatter(x=[x_current], y=[y_current], mode='markers', name='Current Point', marker=dict(size=10, color='red')))
 
 
 
 
 
 
 
 
122
 
123
+ # Add tangent line
124
+ tangent_y = calculate_tangent(function_input, x_current, x_range)
125
+ fig.add_trace(go.Scatter(x=x_range, y=tangent_y, mode='lines', name='Tangent Line', line=dict(dash='dash')))
 
 
 
 
 
 
 
 
 
126
 
127
+ # Layout adjustments
128
+ fig.update_layout(
129
+ title="Gradient Descent Progress",
130
  xaxis_title="x",
131
  yaxis_title="f(x)",
132
+ template="plotly_white",
133
+ height=600
 
134
  )
135
 
136
+ st.plotly_chart(fig, use_container_width=True)