YashsharmaPhD commited on
Commit
78cdd90
·
verified ·
1 Parent(s): 3062fb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -71
app.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
3
  import matplotlib.pyplot as plt
4
  import streamlit as st
5
  import time
 
6
 
7
  # Simple LSTM with a VRAM-like 4x4 Memory Grid, including gates and filters
8
  class MemristorLSTM(nn.Module):
@@ -11,46 +12,27 @@ class MemristorLSTM(nn.Module):
11
  self.memory_size = memory_size
12
  self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
13
  self.fc = nn.Linear(50, self.memory_size * self.memory_size) # Output 16 values (4x4 grid)
14
- # Gate to modulate the output
15
- self.gate = nn.Sigmoid()
16
- # Tanh filter to apply non-linearity to the output
17
- self.filter = nn.Tanh()
18
-
19
- # 4x4 memory grid initialized to zeros
20
- self.memory = torch.zeros(memory_size, memory_size)
21
 
22
  def forward(self, x):
23
- # Forward pass through LSTM
24
  lstm_out, _ = self.lstm(x)
25
  output = self.fc(lstm_out[:, -1, :])
26
-
27
- # Apply sigmoid gate to control the flow of information
28
  gated_output = self.gate(output)
29
-
30
- # Apply tanh filter for non-linearity
31
  filtered_output = self.filter(gated_output)
32
-
33
- # Update memory (4x4 grid), simulating synaptic memory flow
34
  self.memory = self.memory + filtered_output.view(self.memory_size, self.memory_size)
35
  return filtered_output, self.memory
36
 
37
- # Load Pretrained Model with Weights Compatibility
38
  def load_model(model_path="memristor_lstm.pth"):
39
  model = MemristorLSTM(memory_size=4)
40
  try:
41
- # Load the pretrained weights
42
  pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'))
43
-
44
- # Extract the state_dict of the LSTM layers only
45
  model_dict = model.state_dict()
46
-
47
- # Only update the LSTM layers with the pretrained weights, leaving fc layer to be reinitialized
48
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'fc' not in k}
49
-
50
- # Update the model's state_dict
51
  model_dict.update(pretrained_dict)
52
  model.load_state_dict(model_dict)
53
-
54
  model.eval()
55
  return model
56
  except Exception as e:
@@ -64,75 +46,64 @@ def visualize_memory_grid(memory_grid, plot_container):
64
  ax.set_title("4x4 Memory Grid - VRAM")
65
  ax.set_xlabel("Memory Cells (Columns)")
66
  ax.set_ylabel("Memory Cells (Rows)")
67
- plot_container.pyplot(fig) # Render plot in Streamlit using st.pyplot
 
 
 
 
 
 
 
68
 
69
- # Visualize the Spike Train Output
70
  def visualize_spike_train(spike_train, plot_container):
71
- # Plot the spike train (1's and 0's for spike events)
72
  fig, ax = plt.subplots(figsize=(6, 4))
73
  ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
74
  ax.set_title("Spike Train")
75
  ax.set_xlabel("Time Steps")
76
  ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
77
- ax.set_yticks([0, 1]) # Show only 0 and 1 on the y-axis
 
78
  plot_container.pyplot(fig)
 
79
 
80
- # Generate spikes for the given pressure value
81
- def generate_spikes_for_pressure(pressure, memory_plot_container, spike_plot_container, duration=1):
82
- # Simulate the pressure as the input (scale to appropriate format for LSTM)
83
- # Assuming the pressure range is from 0.1 to 1.0 MPa, and normalize the value
84
- pressure_normalized = (pressure - 0.1) / (1.0 - 0.1) # Normalize between 0 and 1
85
-
86
- # Reshape the pressure value to match LSTM input format: (batch_size, seq_len, input_size)
87
  pressure_input = torch.tensor([[pressure_normalized]], dtype=torch.float32).view(1, 1, 1)
88
-
89
- # Load the pre-trained model
90
- model = load_model() # Ensure you have 'memristor_lstm.pth' in the correct path
91
-
92
- if model is not None:
93
- # Loop for live plotting over the given duration
94
- start_time = time.time()
95
- while time.time() - start_time < duration * 60: # duration in minutes
96
- # Forward pass through the model
97
- output, memory_grid = model(pressure_input)
98
-
99
- # Visualize the memory grid (4x4 grid) on the left plot container
100
- visualize_memory_grid(memory_grid, memory_plot_container)
101
 
102
- # Convert the output into a spike train by thresholding the filtered output
103
- spike_train = (output.detach().numpy() > 0).astype(int) # 1 for spike, 0 for no spike
104
-
105
- # Visualize the spike train on the right plot container
106
- visualize_spike_train(spike_train, spike_plot_container)
107
-
108
- # Delay to simulate real-time plotting (e.g., update every second)
109
- time.sleep(1)
110
 
111
- # Streamlit UI for interacting with pressure input
112
  def app():
113
- st.title("Memristor Response Spike Generator with Pressure Input")
114
- st.write("Generate spike patterns and visualize memory passing in a 4x4 memory grid.")
115
 
116
  pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
117
- duration = st.selectbox("Select Duration", [1, 2], help="Choose duration for the live plot (minutes)")
118
 
119
- if st.button(f"Generate Spikes & Visualize Memory for {pressure} MPa"):
120
- st.info(f"Generating spikes for {pressure} MPa...")
121
 
122
- # Create columns for both plots (left for memory grid, right for spike train)
123
- col1, col2 = st.columns(2)
124
 
125
- # Create plot containers for each column
126
- with col1:
127
- memory_plot_container = st.empty()
128
 
129
- with col2:
130
- spike_plot_container = st.empty()
 
131
 
132
- # Generate spikes and memory updates for the given pressure
133
- generate_spikes_for_pressure(pressure, memory_plot_container, spike_plot_container, duration)
134
 
135
- st.success(f"Spike generation and memory passing complete for {pressure} MPa!")
136
 
137
  if __name__ == "__main__":
138
  app()
 
3
  import matplotlib.pyplot as plt
4
  import streamlit as st
5
  import time
6
+ import numpy as np
7
 
8
  # Simple LSTM with a VRAM-like 4x4 Memory Grid, including gates and filters
9
  class MemristorLSTM(nn.Module):
 
12
  self.memory_size = memory_size
13
  self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
14
  self.fc = nn.Linear(50, self.memory_size * self.memory_size) # Output 16 values (4x4 grid)
15
+ self.gate = nn.Sigmoid() # Gate to modulate the output
16
+ self.filter = nn.Tanh() # Non-linearity filter
17
+ self.memory = torch.zeros(memory_size, memory_size) # 4x4 memory grid initialized to zeros
 
 
 
 
18
 
19
  def forward(self, x):
 
20
  lstm_out, _ = self.lstm(x)
21
  output = self.fc(lstm_out[:, -1, :])
 
 
22
  gated_output = self.gate(output)
 
 
23
  filtered_output = self.filter(gated_output)
 
 
24
  self.memory = self.memory + filtered_output.view(self.memory_size, self.memory_size)
25
  return filtered_output, self.memory
26
 
27
+ # Load Pretrained Model
28
  def load_model(model_path="memristor_lstm.pth"):
29
  model = MemristorLSTM(memory_size=4)
30
  try:
 
31
  pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'))
 
 
32
  model_dict = model.state_dict()
 
 
33
  pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'fc' not in k}
 
 
34
  model_dict.update(pretrained_dict)
35
  model.load_state_dict(model_dict)
 
36
  model.eval()
37
  return model
38
  except Exception as e:
 
46
  ax.set_title("4x4 Memory Grid - VRAM")
47
  ax.set_xlabel("Memory Cells (Columns)")
48
  ax.set_ylabel("Memory Cells (Rows)")
49
+
50
+ plot_container.pyplot(fig) # Streamlit plot
51
+ plt.close(fig) # Close the figure to save memory
52
+
53
+ # Generate Spike Train
54
+ def generate_spike_train(length=50, threshold=0.5):
55
+ spike_train = np.random.rand(length) > threshold
56
+ return spike_train.astype(int)
57
 
58
+ # Visualize the Spike Train
59
  def visualize_spike_train(spike_train, plot_container):
 
60
  fig, ax = plt.subplots(figsize=(6, 4))
61
  ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
62
  ax.set_title("Spike Train")
63
  ax.set_xlabel("Time Steps")
64
  ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
65
+ ax.set_yticks([0, 1])
66
+
67
  plot_container.pyplot(fig)
68
+ plt.close(fig)
69
 
70
+ # Generate spikes and VRAM updates
71
+ def generate_spikes_for_pressure(pressure):
72
+ pressure_normalized = (pressure - 0.1) / (1.0 - 0.1)
 
 
 
 
73
  pressure_input = torch.tensor([[pressure_normalized]], dtype=torch.float32).view(1, 1, 1)
74
+ model = load_model()
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ if model is not None:
77
+ output, memory_grid = model(pressure_input)
78
+ spike_train = generate_spike_train(length=50)
79
+ return output, memory_grid, spike_train
80
+ return None, None, None
 
 
 
81
 
82
+ # Streamlit UI
83
  def app():
84
+ st.title("Memristor VRAM & Spike Train Live Visualization")
85
+ st.write("Watch how VRAM and spike trains change over time with pressure input.")
86
 
87
  pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
88
+ duration = st.radio("Select Duration", [1, 2], index=0)
89
 
90
+ if st.button("Start Live Visualization"):
91
+ st.info(f"Running live update for {duration} minute(s)...")
92
 
93
+ plot_col1, plot_col2 = st.columns(2) # Side-by-side layout
 
94
 
95
+ end_time = time.time() + duration * 60 # Run for selected time
96
+ while time.time() < end_time:
97
+ output, memory_grid, spike_train = generate_spikes_for_pressure(pressure)
98
 
99
+ if output is not None and memory_grid is not None:
100
+ with plot_col1:
101
+ visualize_memory_grid(memory_grid, st)
102
 
103
+ with plot_col2:
104
+ visualize_spike_train(spike_train, st)
105
 
106
+ time.sleep(1) # Update every second
107
 
108
  if __name__ == "__main__":
109
  app()