import time

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn

# Pressure range (MPa) shared by the UI slider and the input normalization,
# so the two cannot drift apart.
PRESSURE_MIN_MPA = 0.1
PRESSURE_MAX_MPA = 1.0


# Simple LSTM with a VRAM-like 4x4 Memory Grid, including gates and filters
class MemristorLSTM(nn.Module):
    """Two-layer LSTM whose output is accumulated into a square memory grid.

    The last LSTM time step is projected to ``memory_size * memory_size``
    values, passed through a sigmoid and then a tanh, and added to
    ``self.memory`` on every forward call.

    Args:
        memory_size: Side length of the square memory grid (default 4).
    """

    def __init__(self, memory_size=4):
        super().__init__()
        self.memory_size = memory_size
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
        # One output value per memory cell (16 values for a 4x4 grid).
        self.fc = nn.Linear(50, self.memory_size * self.memory_size)
        self.gate = nn.Sigmoid()  # Gate to modulate the output
        self.filter = nn.Tanh()  # Non-linearity filter
        # Registered as a buffer so the grid follows the module across
        # .to(device) calls; persistent=False keeps it out of state_dict so
        # the checkpoint key set is the same as with a plain attribute.
        self.register_buffer(
            "memory", torch.zeros(memory_size, memory_size), persistent=False
        )

    def forward(self, x):
        """Run the LSTM on ``x`` and accumulate the result into the memory grid.

        Args:
            x: Input tensor fed to a ``batch_first`` LSTM with
                ``input_size=1``. The ``view`` below reshapes the projected
                output to ``(memory_size, memory_size)``, which only works
                when the batch dimension is 1.

        Returns:
            Tuple ``(filtered_output, memory)``: the gated/filtered projection
            of the last time step, and the updated memory grid.
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])  # keep only the last time step
        gated_output = self.gate(output)
        filtered_output = self.filter(gated_output)
        # Accumulates: state persists across calls on the same instance.
        self.memory = self.memory + filtered_output.view(self.memory_size, self.memory_size)
        return filtered_output, self.memory


# Load Pretrained Model
def load_model(model_path="memristor_lstm.pth"):
    """Build a ``MemristorLSTM`` and load matching pretrained weights into it.

    Checkpoint entries are copied only when their key exists in the fresh
    model and does not contain ``'fc'``; the ``fc`` layer therefore keeps its
    freshly initialised weights.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        The model in eval mode, or ``None`` if loading failed (the error is
        printed).
    """
    model = MemristorLSTM(memory_size=4)
    try:
        # NOTE(security): torch.load unpickles the file; depending on the
        # installed PyTorch version this can execute arbitrary code, so only
        # load trusted checkpoints (consider weights_only=True where available).
        pretrained_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        compatible = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and 'fc' not in k
        }
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    except Exception as e:
        # Boundary handler: callers get None instead of an exception.
        print(f"Error loading model: {e}")
        return None
    model.eval()
    return model


# Visualize the 4x4 Memory Grid
def visualize_memory_grid(memory_grid):
    """Render the memory grid tensor as a heat map.

    Args:
        memory_grid: 2-D tensor to display.

    Returns:
        The matplotlib figure. The caller owns it and should close it once
        rendered.
    """
    fig, ax = plt.subplots(figsize=(6, 6))
    # .cpu() is a no-op for CPU tensors and makes CUDA tensors plottable.
    grid = memory_grid.detach().cpu().numpy()
    ax.imshow(grid, cmap='viridis', interpolation='nearest')
    ax.set_title("4x4 Memory Grid - VRAM")
    ax.set_xlabel("Memory Cells (Columns)")
    ax.set_ylabel("Memory Cells (Rows)")
    return fig  # Return the figure


# Generate Spike Train
def generate_spike_train(length=50, threshold=0.5):
    """Return a random 0/1 integer array of ``length`` samples.

    A sample is 1 where a uniform draw from ``np.random.rand`` exceeds
    ``threshold``, so a higher threshold yields fewer spikes.
    """
    spike_train = np.random.rand(length) > threshold
    return spike_train.astype(int)


# Visualize the Spike Train
def visualize_spike_train(spike_train):
    """Plot a 0/1 spike train against time-step index.

    Returns:
        The matplotlib figure. The caller owns it and should close it once
        rendered.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
    ax.set_title("Spike Train")
    ax.set_xlabel("Time Steps")
    ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
    ax.set_yticks([0, 1])
    return fig  # Return the figure


# Generate spikes and VRAM updates
def generate_spikes_for_pressure(pressure, model=None):
    """Run one model step for ``pressure`` and draw a fresh spike train.

    Args:
        pressure: Pressure in MPa; normalised with the
            ``PRESSURE_MIN_MPA``..``PRESSURE_MAX_MPA`` range before being fed
            to the model.
        model: Optional already-loaded model to reuse. When omitted, a model
            is loaded via ``load_model()`` on every call, which re-reads the
            checkpoint and starts from a zeroed memory grid each time.

    Returns:
        ``(output, memory_grid, spike_train)``, or ``(None, None, None)`` when
        no model is available.
    """
    if model is None:
        model = load_model()
    if model is None:
        return None, None, None

    pressure_normalized = (pressure - PRESSURE_MIN_MPA) / (PRESSURE_MAX_MPA - PRESSURE_MIN_MPA)
    pressure_input = torch.tensor([[pressure_normalized]], dtype=torch.float32).view(1, 1, 1)
    # Inference only: without no_grad the accumulated memory tensor would keep
    # an ever-growing autograd graph alive when one model is reused.
    with torch.no_grad():
        output, memory_grid = model(pressure_input)
    # NOTE: the spike train is random and does not depend on the pressure or
    # on the model output.
    spike_train = generate_spike_train(length=50)
    return output, memory_grid, spike_train


def _render_and_close(container, fig):
    """Draw ``fig`` into a Streamlit container, then release the figure."""
    try:
        container.pyplot(fig)
    finally:
        # pyplot keeps every figure it created until it is explicitly closed;
        # without this the live loop leaks one figure per frame.
        plt.close(fig)


# Streamlit UI
def app():
    """Streamlit page: pick pressure and duration, then animate both plots."""
    st.title("Memristor VRAM & Spike Train Live Visualization")
    st.write("Watch how VRAM and spike trains change over time with pressure input.")
    pressure = st.slider("Select Pressure (MPa)", PRESSURE_MIN_MPA, PRESSURE_MAX_MPA, 0.5, 0.1)
    duration = st.radio("Select Duration", [1, 2], index=0)

    if not st.button("Start Live Visualization"):
        return

    # Load once per run: avoids re-reading the checkpoint on every frame and
    # lets the memory grid accumulate across frames.
    model = load_model()
    if model is None:
        st.error("Could not load the model; check the application log for details.")
        return

    st.info(f"Running live update for {duration} minute(s)...")
    plot_col1, plot_col2 = st.columns(2)  # Side-by-side layout

    # Create empty containers to update plots in real-time
    plot_vram = plot_col1.empty()
    plot_spike_train = plot_col2.empty()

    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    end_time = time.monotonic() + duration * 60  # Run for selected time
    while time.monotonic() < end_time:
        output, memory_grid, spike_train = generate_spikes_for_pressure(pressure, model=model)
        if output is not None and memory_grid is not None:
            # Update plots dynamically in existing containers
            _render_and_close(plot_vram, visualize_memory_grid(memory_grid))
            _render_and_close(plot_spike_train, visualize_spike_train(spike_train))
        time.sleep(0.2)  # Update every 0.2 seconds (FAST UPDATES)


if __name__ == "__main__":
    app()