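# Page residue from the web viewer this file was copied from (presumably a
# Hugging Face Space): "Spaces:" / "Sleeping" / "Sleeping".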
import time

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from matplotlib.figure import Figure
class MemristorLSTM(nn.Module):
    """Stacked LSTM whose gated, filtered output accumulates into a VRAM-like grid.

    The LSTM reads a sequence of scalar features. Its last time step is
    projected to ``memory_size * memory_size`` values. These pass through a
    sigmoid "gate" and then a tanh "filter", and are added element-wise to a
    persistent square ``memory`` grid.

    Args:
        memory_size: Side length of the square memory grid (default 4 -> 4x4).
        hidden_size: Hidden units per LSTM layer. The default of 50 must match
            any pretrained checkpoint loaded into this model.
        num_layers: Number of stacked LSTM layers. The default of 2 must match
            the checkpoint for the same reason.
    """

    def __init__(self, memory_size=4, hidden_size=50, num_layers=2):
        super().__init__()
        self.memory_size = memory_size
        self.lstm = nn.LSTM(
            input_size=1,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        # One output value per memory cell (16 for the default 4x4 grid).
        self.fc = nn.Linear(hidden_size, memory_size * memory_size)
        self.gate = nn.Sigmoid()  # squashes the projection into (0, 1)
        self.filter = nn.Tanh()  # applied after the gate, so values lie in (0, tanh(1))
        # Non-persistent buffer: it follows .to()/.double() together with the
        # parameters, but stays out of state_dict(). Checkpoints that have no
        # "memory" entry therefore still load with strict=True.
        self.register_buffer(
            "memory", torch.zeros(memory_size, memory_size), persistent=False
        )

    def reset_memory(self):
        """Zero the accumulated memory grid in place."""
        self.memory.zero_()

    def forward(self, x):
        """Run one step and add its contribution to the memory grid.

        Args:
            x: Input sequence laid out as (batch, seq_len, 1), because the LSTM
                uses ``batch_first=True`` and ``input_size=1``. The batch size
                must be 1, since the per-sample output is reshaped into a
                single grid.

        Returns:
            ``(filtered_output, memory)``.
            ``filtered_output`` has shape (1, memory_size**2).
            ``memory`` is the updated (memory_size, memory_size) grid.
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])  # last time step only
        filtered_output = self.filter(self.gate(output))
        new_memory = self.memory + filtered_output.view(
            self.memory_size, self.memory_size
        )
        # Store a detached copy so the autograd graph does not grow across calls.
        # Keeping it attached would hold every earlier step's graph alive, and a
        # later backward() would re-traverse freed graphs. The returned grid still
        # carries gradients with respect to this call's output.
        self.memory = new_memory.detach()
        return filtered_output, new_memory
def load_model(model_path="memristor_lstm.pth"):
    """Build a 4x4 MemristorLSTM and seed its non-``fc`` weights from a checkpoint.

    Checkpoint entries are copied only when the key already exists in the fresh
    model and does not contain the substring "fc". The output projection
    therefore keeps its new random initialisation on every call.

    Args:
        model_path: Path to a saved state dict, loaded onto the CPU.

    Returns:
        The model in eval mode, or ``None`` if any step of loading fails.
        In that case the error is printed.
    """
    net = MemristorLSTM(memory_size=4)
    try:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted checkpoints (newer PyTorch versions accept weights_only=True).
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        merged = net.state_dict()
        for key, tensor in checkpoint.items():
            if key in merged and 'fc' not in key:
                merged[key] = tensor
        net.load_state_dict(merged)
        net.eval()
    except Exception as exc:
        print(f"Error loading model: {exc}")
        return None
    return net
def visualize_memory_grid(memory_grid):
    """Render a 2-D memory grid as a heat map.

    Args:
        memory_grid: 2-D torch tensor (any device, with or without autograd
            history) or any array-like of numbers.

    Returns:
        A matplotlib ``Figure``. It is built with ``Figure`` directly rather
        than ``plt.subplots``, so it is never registered with pyplot. It is
        freed once the caller drops it, which means a redraw loop does not
        accumulate open figures.
    """
    if isinstance(memory_grid, torch.Tensor):
        # detach(): the grid may carry autograd history.
        # cpu(): numpy() only works on host tensors.
        data = memory_grid.detach().cpu().numpy()
    else:
        data = np.asarray(memory_grid)
    rows, cols = data.shape[:2]

    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    ax.imshow(data, cmap='viridis', interpolation='nearest')
    # Derive the title from the actual shape ("4x4" for the default grid).
    ax.set_title(f"{rows}x{cols} Memory Grid - VRAM")
    ax.set_xlabel("Memory Cells (Columns)")
    ax.set_ylabel("Memory Cells (Rows)")
    return fig
def generate_spike_train(length=50, threshold=0.5):
    """Draw a random binary spike train from NumPy's global RNG.

    Each of the ``length`` time steps gets an independent uniform draw in
    [0, 1). A step spikes (1) when its draw is strictly greater than
    ``threshold``; otherwise it is 0. For thresholds in [0, 1], the spike
    probability is ``1 - threshold``.

    Returns:
        1-D integer ndarray of 0s and 1s with ``length`` entries.
    """
    draws = np.random.rand(length)
    fired = np.greater(draws, threshold)
    return fired.astype(int)
def visualize_spike_train(spike_train):
    """Plot a spike train as a red line with a marker at every time step.

    Args:
        spike_train: 1-D sequence of 0/1 values, one per time step.

    Returns:
        A matplotlib ``Figure``. It is built with ``Figure`` directly rather
        than ``plt.subplots``, so it is not held in pyplot's figure registry.
        Repeated calls in a redraw loop therefore do not accumulate open
        figures; each one is freed once the caller drops it.
    """
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
    ax.set_title("Spike Train")
    ax.set_xlabel("Time Steps")
    ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
    ax.set_yticks([0, 1])  # binary signal: only label the two possible levels
    return fig
def generate_spikes_for_pressure(pressure, *, model=None, min_pressure=0.1, max_pressure=1.0):
    """Run one model step for a pressure reading and draw a spike train.

    Args:
        pressure: Pressure reading. It is min-max normalised against
            ``[min_pressure, max_pressure]``; the defaults reproduce the
            original fixed ``(pressure - 0.1) / 0.9`` scaling.
        model: Optional already-loaded model. When omitted, ``load_model()``
            is called on every invocation. A fresh model starts with a zero
            memory grid, and its ``fc`` layer is freshly initialised each time.
        min_pressure: Lower end of the normalisation range.
        max_pressure: Upper end of the normalisation range.

    Returns:
        ``(output, memory_grid, spike_train)``, or ``(None, None, None)`` if no
        model is available.

    Raises:
        ValueError: If ``min_pressure == max_pressure`` (zero-width range).
    """
    if max_pressure == min_pressure:
        raise ValueError("max_pressure must differ from min_pressure")
    pressure_normalized = (pressure - min_pressure) / (max_pressure - min_pressure)
    # Shape (batch=1, seq_len=1, features=1) for the batch-first LSTM.
    pressure_input = torch.tensor([[[pressure_normalized]]], dtype=torch.float32)

    if model is None:
        model = load_model()
    if model is None:
        return None, None, None

    # Inference only: skip building an autograd graph on every call.
    with torch.no_grad():
        output, memory_grid = model(pressure_input)
    # NOTE(review): the spike train is drawn independently of `pressure` and of
    # the model output; confirm whether it was meant to depend on them.
    spike_train = generate_spike_train(length=50)
    return output, memory_grid, spike_train
def _show_and_close(container, fig):
    """Render ``fig`` into a Streamlit container, then close it in pyplot."""
    container.pyplot(fig)
    # st.pyplot renders the figure but does not close it. Figures created with
    # plt.subplots stay registered with pyplot until plt.close(), so a live loop
    # that builds new figures every tick would otherwise keep all of them alive.
    plt.close(fig)


def app():
    """Streamlit page that periodically redraws the memory grid and a spike train.

    After the start button is pressed, the loop runs for the selected number of
    minutes. Every 0.2 s it runs one model step for the selected pressure and
    redraws both plots in place. The loop stops early with an error message if
    the model cannot be loaded.
    """
    st.title("Memristor VRAM & Spike Train Live Visualization")
    st.write("Watch how VRAM and spike trains change over time with pressure input.")
    pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
    duration = st.radio("Select Duration", [1, 2], index=0)  # minutes

    if not st.button("Start Live Visualization"):
        return

    # Placeholder so the "running" notice can be replaced when the run ends.
    status = st.empty()
    status.info(f"Running live update for {duration} minute(s)...")
    plot_col1, plot_col2 = st.columns(2)  # side-by-side layout
    # Empty containers are overwritten in place on every tick.
    plot_vram = plot_col1.empty()
    plot_spike_train = plot_col2.empty()

    # monotonic(): the deadline must not jump if the wall clock is adjusted.
    end_time = time.monotonic() + duration * 60
    while time.monotonic() < end_time:
        output, memory_grid, spike_train = generate_spikes_for_pressure(pressure)
        if output is None or memory_grid is None:
            # Retrying a failed load every 0.2 s for minutes achieves nothing.
            status.error("Model could not be loaded; see the server console for details.")
            return
        _show_and_close(plot_vram, visualize_memory_grid(memory_grid))
        _show_and_close(plot_spike_train, visualize_spike_train(spike_train))
        time.sleep(0.2)  # refresh interval

    status.success(f"Live update finished after {duration} minute(s).")
# Script entry point: build the Streamlit page only when this file is executed
# as the main module, not when it is imported.
if __name__ == "__main__":
    app()