# app.py, last updated by YashsharmaPhD in commit 3bd112b ("Update app.py", verified).
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import streamlit as st
import time
import numpy as np
# Two-layer LSTM feeding a VRAM-like square memory grid (4x4 by default),
# with a sigmoid gate and a tanh filter on the output.
class MemristorLSTM(nn.Module):
    """LSTM whose last-step output is accumulated into a square memory grid.

    Every forward pass maps the LSTM's final time step to
    ``memory_size * memory_size`` values. It passes them through a sigmoid
    "gate" and a tanh "filter", then adds them cell-wise to ``self.memory``.
    Because tanh(sigmoid(.)) lies in (0, tanh(1)), each call increases every
    cell by a positive amount below about 0.762.
    """

    def __init__(self, memory_size=4):
        super().__init__()
        self.memory_size = memory_size
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
        # One output per grid cell (16 for the default 4x4 grid).
        self.fc = nn.Linear(50, self.memory_size * self.memory_size)
        self.gate = nn.Sigmoid()  # squashes the fc output into (0, 1)
        self.filter = nn.Tanh()   # applied on top of the gated output
        # Non-persistent buffer: it follows the module through .to(device) but
        # stays out of state_dict(), so checkpoint keys are unchanged.
        self.register_buffer(
            "memory", torch.zeros(memory_size, memory_size), persistent=False
        )

    def reset_memory(self):
        """Replace the accumulated memory grid with zeros."""
        self.memory = torch.zeros_like(self.memory)

    def forward(self, x):
        """Run the LSTM on ``x`` and add the filtered output to the memory grid.

        Args:
            x: input for the batch-first LSTM, read as (batch, seq_len, 1).
                The grid update reshapes the output to
                (memory_size, memory_size), so the batch size must be 1.

        Returns:
            ``(filtered_output, memory)``. ``filtered_output`` has shape
            (batch, memory_size ** 2). ``memory`` is the updated grid, which
            is also stored on the module.
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])  # last time step only
        gated_output = self.gate(output)
        filtered_output = self.filter(gated_output)
        # Detach the previous state. Otherwise the autograd graph grows with
        # every call, and a later backward() runs into graph buffers that an
        # earlier backward() already freed.
        self.memory = self.memory.detach() + filtered_output.view(
            self.memory_size, self.memory_size
        )
        return filtered_output, self.memory
# Load pretrained weights into a fresh model (the fc head is not restored).
def load_model(model_path="memristor_lstm.pth"):
    """Build a MemristorLSTM and restore its non-``fc`` weights from a checkpoint.

    The checkpoint is read as a mapping of parameter names to tensors. An
    entry is ignored if its key is not in the new model's state_dict or if it
    belongs to the ``fc`` sub-module, so ``fc`` keeps its fresh initialization.

    Args:
        model_path: path of the checkpoint file, loaded onto the CPU.

    Returns:
        The model in eval mode, or None if loading fails for any reason (the
        error is printed).
    """
    model = MemristorLSTM(memory_size=4)
    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (recent PyTorch versions also accept weights_only=True).
        pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        # Skip the fc head by its top-level module name rather than a substring
        # test, so a key that merely contains "fc" is not dropped by accident.
        # NOTE(review): presumably the checkpoint's head does not match this
        # model's output size. Confirm before restoring it.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and k.split('.', 1)[0] != 'fc'
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    except Exception as e:  # best-effort boundary: callers treat None as "no model"
        print(f"Error loading model: {e}")
        return None
    model.eval()
    return model
# Visualize the memory grid as a heatmap
def visualize_memory_grid(memory_grid):
    """Render a 2-D memory grid as a heatmap and return the matplotlib Figure.

    Args:
        memory_grid: 2-D tensor or array-like of cell values. A tensor may
            live on any device and may require grad.

    Returns:
        A new Figure. It is created through pyplot, so the caller should
        ``plt.close(fig)`` once it has been displayed.
    """
    # .cpu() is required before .numpy() for tensors on an accelerator.
    grid = torch.as_tensor(memory_grid).detach().cpu().numpy()
    rows, cols = grid.shape[:2]
    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(grid, cmap='viridis', interpolation='nearest')
    # Without a colorbar the heatmap only shows relative values.
    fig.colorbar(image, ax=ax)
    ax.set_title(f"{rows}x{cols} Memory Grid - VRAM")
    ax.set_xlabel("Memory Cells (Columns)")
    ax.set_ylabel("Memory Cells (Rows)")
    return fig
# Draw a random binary spike train
def generate_spike_train(length=50, threshold=0.5):
    """Return a random 0/1 spike train of ``length`` steps.

    Each step draws a uniform sample in [0, 1) from NumPy's global random
    state. The step spikes (1) when its sample exceeds ``threshold`` and stays
    silent (0) otherwise.
    """
    samples = np.random.rand(length)
    fired = np.greater(samples, threshold)
    return fired.astype(int)
# Plot a spike train over time
def visualize_spike_train(spike_train):
    """Plot a 0/1 spike train as a red line with circle markers.

    The y-axis ticks are limited to 0 and 1. Returns the new matplotlib Figure.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    axes.plot(spike_train, color='r', marker='o', markersize=5, linestyle='-')
    axes.set(
        title="Spike Train",
        xlabel="Time Steps",
        ylabel="Spike Event (1 = Spike, 0 = No Spike)",
        yticks=[0, 1],
    )
    return figure
# Generate spikes and VRAM updates
def generate_spikes_for_pressure(pressure, model=None):
    """Feed one normalized pressure value through the model and draw a spike train.

    Args:
        pressure: pressure value. It is scaled with (pressure - 0.1) / (1.0 - 0.1),
            which matches the 0.1-1.0 range of the slider in ``app``.
        model: optional already-loaded model to reuse. If None (the original
            behaviour), a new one is built with ``load_model()`` on every call,
            which means the memory grid starts from zeros each time.

    Returns:
        ``(output, memory_grid, spike_train)``, or ``(None, None, None)`` when
        no model could be loaded.
    """
    if model is None:
        model = load_model()
    if model is None:
        return None, None, None
    pressure_normalized = (pressure - 0.1) / (1.0 - 0.1)
    # Shape (batch=1, seq_len=1, features=1) for the batch-first LSTM.
    pressure_input = torch.tensor([[pressure_normalized]], dtype=torch.float32).view(1, 1, 1)
    # Inference only: do not build an autograd graph on every UI refresh.
    with torch.no_grad():
        output, memory_grid = model(pressure_input)
    # NOTE(review): the spike train is drawn at random and depends on neither
    # the pressure nor the model output.
    spike_train = generate_spike_train(length=50)
    return output, memory_grid, spike_train


# Streamlit UI
def _show_figure(placeholder, fig):
    """Draw ``fig`` into a Streamlit placeholder, then release it from pyplot."""
    placeholder.pyplot(fig)
    # pyplot keeps every figure created with plt.subplots alive until it is
    # closed. Without this, each refresh would leak two figures.
    plt.close(fig)


def app():
    """Streamlit page that animates the memory grid and a random spike train.

    After the button is pressed, the page refreshes about every 0.2 s for the
    chosen 1 or 2 minutes. Each refresh runs the same loaded model once more on
    the selected pressure, so its memory grid keeps accumulating.
    """
    st.title("Memristor VRAM & Spike Train Live Visualization")
    st.write("Watch how VRAM and spike trains change over time with pressure input.")
    pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
    duration = st.radio("Select Duration", [1, 2], index=0)
    if not st.button("Start Live Visualization"):
        return
    # Load once and reuse the model. Rebuilding it every frame reset the memory
    # grid to zeros and re-initialized the (not restored) fc layer each time.
    model = load_model()
    if model is None:
        st.error("Could not load the model; see the console output for details.")
        return
    st.info(f"Running live update for {duration} minute(s)...")
    plot_col1, plot_col2 = st.columns(2)  # side-by-side layout
    # Placeholders that are overwritten in place on every refresh.
    plot_vram = plot_col1.empty()
    plot_spike_train = plot_col2.empty()
    # monotonic() is unaffected by wall-clock adjustments during the run.
    end_time = time.monotonic() + duration * 60
    while time.monotonic() < end_time:
        output, memory_grid, spike_train = generate_spikes_for_pressure(pressure, model=model)
        if output is not None and memory_grid is not None:
            _show_figure(plot_vram, visualize_memory_grid(memory_grid))
            _show_figure(plot_spike_train, visualize_spike_train(spike_train))
        time.sleep(0.2)  # refresh roughly every 0.2 s
# Script entry point: build the page only when this file is the main module.
if __name__ == "__main__":
    app()