YashsharmaPhD commited on
Commit
3bd112b
·
verified ·
1 Parent(s): 78cdd90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -40,15 +40,14 @@ def load_model(model_path="memristor_lstm.pth"):
40
  return None
41
 
42
  # Visualize the 4x4 Memory Grid
43
- def visualize_memory_grid(memory_grid, plot_container):
44
  fig, ax = plt.subplots(figsize=(6, 6))
45
  ax.imshow(memory_grid.detach().numpy(), cmap='viridis', interpolation='nearest')
46
  ax.set_title("4x4 Memory Grid - VRAM")
47
  ax.set_xlabel("Memory Cells (Columns)")
48
  ax.set_ylabel("Memory Cells (Rows)")
49
 
50
- plot_container.pyplot(fig) # Streamlit plot
51
- plt.close(fig) # Close the figure to save memory
52
 
53
  # Generate Spike Train
54
  def generate_spike_train(length=50, threshold=0.5):
@@ -56,16 +55,15 @@ def generate_spike_train(length=50, threshold=0.5):
56
  return spike_train.astype(int)
57
 
58
  # Visualize the Spike Train
59
- def visualize_spike_train(spike_train, plot_container):
60
  fig, ax = plt.subplots(figsize=(6, 4))
61
  ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
62
  ax.set_title("Spike Train")
63
  ax.set_xlabel("Time Steps")
64
  ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
65
  ax.set_yticks([0, 1])
66
-
67
- plot_container.pyplot(fig)
68
- plt.close(fig)
69
 
70
  # Generate spikes and VRAM updates
71
  def generate_spikes_for_pressure(pressure):
@@ -92,18 +90,20 @@ def app():
92
 
93
  plot_col1, plot_col2 = st.columns(2) # Side-by-side layout
94
 
 
 
 
 
95
  end_time = time.time() + duration * 60 # Run for selected time
96
  while time.time() < end_time:
97
  output, memory_grid, spike_train = generate_spikes_for_pressure(pressure)
98
 
99
  if output is not None and memory_grid is not None:
100
- with plot_col1:
101
- visualize_memory_grid(memory_grid, st)
102
-
103
- with plot_col2:
104
- visualize_spike_train(spike_train, st)
105
 
106
- time.sleep(1) # Update every second
107
 
108
  if __name__ == "__main__":
109
  app()
 
40
  return None
41
 
42
  # Visualize the 4x4 Memory Grid
43
+ def visualize_memory_grid(memory_grid):
44
  fig, ax = plt.subplots(figsize=(6, 6))
45
  ax.imshow(memory_grid.detach().numpy(), cmap='viridis', interpolation='nearest')
46
  ax.set_title("4x4 Memory Grid - VRAM")
47
  ax.set_xlabel("Memory Cells (Columns)")
48
  ax.set_ylabel("Memory Cells (Rows)")
49
 
50
+ return fig # Return the figure
 
51
 
52
  # Generate Spike Train
53
  def generate_spike_train(length=50, threshold=0.5):
 
55
  return spike_train.astype(int)
56
 
57
  # Visualize the Spike Train
58
+ def visualize_spike_train(spike_train):
59
  fig, ax = plt.subplots(figsize=(6, 4))
60
  ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
61
  ax.set_title("Spike Train")
62
  ax.set_xlabel("Time Steps")
63
  ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
64
  ax.set_yticks([0, 1])
65
+
66
+ return fig # Return the figure
 
67
 
68
  # Generate spikes and VRAM updates
69
  def generate_spikes_for_pressure(pressure):
 
90
 
91
  plot_col1, plot_col2 = st.columns(2) # Side-by-side layout
92
 
93
+ # Create empty containers to update plots in real-time
94
+ plot_vram = plot_col1.empty()
95
+ plot_spike_train = plot_col2.empty()
96
+
97
  end_time = time.time() + duration * 60 # Run for selected time
98
  while time.time() < end_time:
99
  output, memory_grid, spike_train = generate_spikes_for_pressure(pressure)
100
 
101
  if output is not None and memory_grid is not None:
102
+ # Update plots dynamically in existing containers
103
+ plot_vram.pyplot(visualize_memory_grid(memory_grid))
104
+ plot_spike_train.pyplot(visualize_spike_train(spike_train))
 
 
105
 
106
+ time.sleep(0.2) # Update every 0.2 seconds (FAST UPDATES)
107
 
108
  if __name__ == "__main__":
109
  app()