Mpavan45 commited on
Commit
a43c583
·
verified ·
1 Parent(s): c9d0458

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -2,16 +2,6 @@ import streamlit as st
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
5
- # Initialize session state variables if not already present
6
- if "x_current" not in st.session_state:
7
- st.session_state.x_current = 0.0 # Default starting point
8
- if "iter_count" not in st.session_state:
9
- st.session_state.iter_count = 0
10
- if "history" not in st.session_state:
11
- st.session_state.history = [] # Store (x, f(x)) for each iteration
12
- if "current_index" not in st.session_state:
13
- st.session_state.current_index = 0
14
-
15
  # Title of the app
16
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
17
  st.title("🌟 Gradient Descent Visualizer")
@@ -42,6 +32,16 @@ def reset_session_state():
42
  st.session_state.history = [(st.session_state.initial_point, evaluate_function(st.session_state.math_function, st.session_state.initial_point))]
43
  st.session_state.current_index = 0
44
 
 
 
 
 
 
 
 
 
 
 
45
  # Create two-column grid layout for the left side (more space for the right graph)
46
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
47
 
@@ -74,8 +74,9 @@ with right_col:
74
  gradient = compute_derivative(function_input, st.session_state.x_current)
75
  st.session_state.x_current -= learning_rate * gradient
76
  st.session_state.iter_count += 1
77
- current_y = evaluate_function(function_input, st.session_state.x_current)
78
- st.session_state.history.append((st.session_state.x_current, current_y))
 
79
  st.session_state.current_index = st.session_state.iter_count
80
  except Exception as e:
81
  st.error(f"Error: {str(e)}")
@@ -90,12 +91,15 @@ with right_col:
90
  st.session_state.current_index += 1
91
 
92
  # Display selected iteration details
93
- selected_x, selected_y = st.session_state.history[st.session_state.current_index]
94
- st.subheader("Iteration Details")
95
- st.markdown(f"**Iteration:** {st.session_state.current_index}")
96
- st.markdown(f"**x Value:** {selected_x:.4f}")
97
- st.markdown(f"**f(x):** {selected_y:.4f}")
98
- st.markdown("---")
 
 
 
99
 
100
  # Generate plot data
101
  x_vals = np.linspace(-10, 10, 400)
@@ -110,7 +114,8 @@ with right_col:
110
  )
111
 
112
  # Add gradient descent points
113
- x_points, y_points = zip(*st.session_state.history)
 
114
  plot.add_trace(
115
  go.Scatter(
116
  x=x_points,
 
2
  import numpy as np
3
  import plotly.graph_objects as go
4
 
 
 
 
 
 
 
 
 
 
 
5
  # Title of the app
6
  st.set_page_config(page_title="Interactive Gradient Descent Visualizer", layout="wide")
7
  st.title("🌟 Gradient Descent Visualizer")
 
32
  st.session_state.history = [(st.session_state.initial_point, evaluate_function(st.session_state.math_function, st.session_state.initial_point))]
33
  st.session_state.current_index = 0
34
 
35
+ # Initialize session state variables
36
+ if "x_current" not in st.session_state:
37
+ st.session_state.x_current = 0.0 # Default starting point
38
+ if "iter_count" not in st.session_state:
39
+ st.session_state.iter_count = 0
40
+ if "history" not in st.session_state:
41
+ st.session_state.history = [(0.0, evaluate_function("x**2 + x", 0.0))] # Default function example
42
+ if "current_index" not in st.session_state:
43
+ st.session_state.current_index = 0
44
+
45
  # Create two-column grid layout for the left side (more space for the right graph)
46
  left_col, right_col = st.columns([1, 2]) # 1 for left, 2 for right grid proportion
47
 
 
74
  gradient = compute_derivative(function_input, st.session_state.x_current)
75
  st.session_state.x_current -= learning_rate * gradient
76
  st.session_state.iter_count += 1
77
+ st.session_state.history.append(
78
+ (st.session_state.x_current, evaluate_function(function_input, st.session_state.x_current))
79
+ )
80
  st.session_state.current_index = st.session_state.iter_count
81
  except Exception as e:
82
  st.error(f"Error: {str(e)}")
 
91
  st.session_state.current_index += 1
92
 
93
  # Display selected iteration details
94
+ try:
95
+ selected_x, selected_y = st.session_state.history[st.session_state.current_index]
96
+ st.subheader("Iteration Details")
97
+ st.markdown(f"**Iteration:** {st.session_state.current_index}")
98
+ st.markdown(f"**x Value:** {selected_x:.4f}")
99
+ st.markdown(f"**f(x):** {selected_y:.4f}")
100
+ st.markdown("---")
101
+ except IndexError:
102
+ st.warning("No iteration data available. Please run a descent step first.")
103
 
104
  # Generate plot data
105
  x_vals = np.linspace(-10, 10, 400)
 
114
  )
115
 
116
  # Add gradient descent points
117
+ x_points = [point[0] for point in st.session_state.history]
118
+ y_points = [point[1] for point in st.session_state.history]
119
  plot.add_trace(
120
  go.Scatter(
121
  x=x_points,