Mpavan45 commited on
Commit
f216914
·
verified ·
1 Parent(s): 4194c27

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -35
app.py CHANGED
@@ -2,10 +2,9 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # Configure the page
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
- st.title("🌟 Gradient Descent Visualizer")
8
- st.markdown("---") # Horizontal separator for cleaner layout
9
 
10
  # Safe function evaluation
11
  def evaluate_function(expression, x_value):
@@ -47,32 +46,41 @@ if "current_index" not in st.session_state:
47
  if "learning_rate" not in st.session_state:
48
  st.session_state.learning_rate = 0.1
49
 
50
- # Layout configuration
51
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
52
 
53
  # Left side content (Function Input and Gradient Descent Parameters)
54
  with left_col:
55
- st.header("Input Your Function")
56
- st.markdown("Define a mathematical function (e.g., `x**2`, `np.sin(x)`, `x**3 - 3*x + 2`):")
57
- function_input = st.text_input("Enter Function:", "x**2 + x", key="math_function", on_change=reset_session_state)
58
- st.markdown("---")
 
 
 
59
 
60
- st.header("Set Parameters for Gradient Descent")
61
- st.markdown("Configure the starting point and learning rate:")
62
  initial_point = st.number_input(
63
- "Initial Value of x", value=4.0, step=0.1, format="%.2f", key="initial_point", on_change=reset_session_state
 
 
 
 
 
64
  )
65
  st.number_input(
66
- "Learning Rate", value=st.session_state.learning_rate, step=0.01, format="%.2f",
 
 
 
67
  key="learning_rate"
68
- )
69
- st.markdown("---")
70
 
71
- # Buttons for controlling steps
72
  if st.button("🔄 Reset"):
73
  reset_session_state()
74
-
75
- if st.button("▶️ Run Descent Step"):
76
  try:
77
  gradient = compute_derivative(function_input, st.session_state.x_current)
78
  st.session_state.x_current -= st.session_state.learning_rate * gradient
@@ -83,28 +91,26 @@ with left_col:
83
  st.session_state.current_index = st.session_state.iter_count
84
  except Exception as e:
85
  st.error(f"Error: {str(e)}")
 
 
 
 
86
 
87
- # Navigation buttons
88
- col1, col2 = st.columns(2)
89
  with col1:
90
- if st.button("⬅️ Previous") and st.session_state.current_index > 0:
91
  st.session_state.current_index -= 1
92
  with col2:
93
- if st.button("➡️ Next") and st.session_state.current_index < st.session_state.iter_count:
 
 
94
  st.session_state.current_index += 1
95
-
96
- # Right side content (Interactive Gradient Descent Visualization)
97
- with right_col:
98
- st.header("Gradient Descent Visualization")
99
 
100
- # Display iteration details at the top of the graph
101
  try:
102
  selected_x, selected_y = st.session_state.history[st.session_state.current_index]
103
- st.subheader("Iteration Details")
104
- st.markdown(f"**Iteration:** {st.session_state.current_index}")
105
- st.markdown(f"**x Value:** {selected_x:.4f}")
106
- st.markdown(f"**f(x):** {selected_y:.4f}")
107
- st.markdown("---")
108
  except IndexError:
109
  st.warning("No iteration data available. Please run a descent step first.")
110
 
@@ -112,17 +118,35 @@ with right_col:
112
  x_range = np.linspace(-10, 10, 500) # Define range for x
113
  y_range = [evaluate_function(function_input, x) for x in x_range]
114
 
115
- # Plot function and gradient descent steps
116
  fig = go.Figure()
117
- fig.add_trace(go.Scatter(x=x_range, y=y_range, mode='lines', name='Function',line=dict(color='orange')))
 
 
 
 
 
 
118
 
119
  # Add current point
120
  x_current, y_current = st.session_state.history[st.session_state.current_index]
121
- fig.add_trace(go.Scatter(x=[x_current], y=[y_current], mode='markers', name='Current Point', marker=dict(size=10, color='red')))
 
 
 
 
 
 
122
 
123
  # Add tangent line
124
  tangent_y = calculate_tangent(function_input, x_current, x_range)
125
- fig.add_trace(go.Scatter(x=x_range, y=tangent_y, mode='lines', name='Tangent Line', line=dict(dash='dash',color='green')))
 
 
 
 
 
 
126
 
127
  # Layout adjustments
128
  fig.update_layout(
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
+ # Title of the app
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
+ st.markdown("## 🌟 Gradient Descent Visualizer")
 
8
 
9
  # Safe function evaluation
10
  def evaluate_function(expression, x_value):
 
46
  if "learning_rate" not in st.session_state:
47
  st.session_state.learning_rate = 0.1
48
 
49
+ # Create two-column grid layout for the left side (more space for the right graph)
50
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
51
 
52
  # Left side content (Function Input and Gradient Descent Parameters)
53
  with left_col:
54
+ st.markdown("### Input Your Function")
55
+ function_input = st.text_input(
56
+ "Enter Function:",
57
+ "x**2 + x",
58
+ key="math_function",
59
+ on_change=reset_session_state
60
+ )
61
 
62
+ st.markdown("### Set Parameters")
 
63
  initial_point = st.number_input(
64
+ "Initial Value of x",
65
+ value=4.0,
66
+ step=0.1,
67
+ format="%.2f",
68
+ key="initial_point",
69
+ on_change=reset_session_state
70
  )
71
  st.number_input(
72
+ "Learning Rate",
73
+ value=st.session_state.learning_rate,
74
+ step=0.01,
75
+ format="%.2f",
76
  key="learning_rate"
77
+ ) # Updates session state directly without reset
 
78
 
79
+ st.markdown("### Controls")
80
  if st.button("🔄 Reset"):
81
  reset_session_state()
82
+
83
+ if st.button("🔄 Run Descent Step", type="primary"):
84
  try:
85
  gradient = compute_derivative(function_input, st.session_state.x_current)
86
  st.session_state.x_current -= st.session_state.learning_rate * gradient
 
91
  st.session_state.current_index = st.session_state.iter_count
92
  except Exception as e:
93
  st.error(f"Error: {str(e)}")
94
+
95
+ # Right side content (Visualization and Iteration Details)
96
+ with right_col:
97
+ st.markdown("### Gradient Descent Visualization")
98
 
99
+ # Display iteration details using buttons
100
+ col1, col2, col3 = st.columns(3)
101
  with col1:
102
+ if st.button("⬅️ Previous Iteration") and st.session_state.current_index > 0:
103
  st.session_state.current_index -= 1
104
  with col2:
105
+ st.markdown(f"**Iteration:** {st.session_state.current_index}", unsafe_allow_html=True)
106
+ with col3:
107
+ if st.button("➡️ Next Iteration") and st.session_state.current_index < st.session_state.iter_count:
108
  st.session_state.current_index += 1
 
 
 
 
109
 
 
110
  try:
111
  selected_x, selected_y = st.session_state.history[st.session_state.current_index]
112
+ st.markdown(f"x Value: `{selected_x:.4f}`")
113
+ st.markdown(f"f(x): `{selected_y:.4f}`")
 
 
 
114
  except IndexError:
115
  st.warning("No iteration data available. Please run a descent step first.")
116
 
 
118
  x_range = np.linspace(-10, 10, 500) # Define range for x
119
  y_range = [evaluate_function(function_input, x) for x in x_range]
120
 
121
+ # Plot function curve with orange color
122
  fig = go.Figure()
123
+ fig.add_trace(go.Scatter(
124
+ x=x_range,
125
+ y=y_range,
126
+ mode='lines',
127
+ name='Function',
128
+ line=dict(color='orange') # Curve color set to orange
129
+ ))
130
 
131
  # Add current point
132
  x_current, y_current = st.session_state.history[st.session_state.current_index]
133
+ fig.add_trace(go.Scatter(
134
+ x=[x_current],
135
+ y=[y_current],
136
+ mode='markers',
137
+ name='Current Point',
138
+ marker=dict(size=10, color='red')
139
+ ))
140
 
141
  # Add tangent line
142
  tangent_y = calculate_tangent(function_input, x_current, x_range)
143
+ fig.add_trace(go.Scatter(
144
+ x=x_range,
145
+ y=tangent_y,
146
+ mode='lines',
147
+ name='Tangent Line',
148
+ line=dict(dash='dash', color='blue') # Tangent line in blue
149
+ ))
150
 
151
  # Layout adjustments
152
  fig.update_layout(