| import torch |
| import torch.nn as nn |
| import matplotlib.pyplot as plt |
| import streamlit as st |
| import time |
| import numpy as np |
|
|
| |
class MemristorLSTM(nn.Module):
    """Stacked LSTM whose last-step output is accumulated into 2-D memory grids.

    The LSTM reads a ``(batch, seq_len, 1)`` sequence. The hidden state of the
    final time step is projected by ``fc`` to ``grid_count`` flattened
    ``memory_size x memory_size`` maps. Those maps are squashed by a sigmoid and
    then a tanh, so every element lies in ``(0, tanh(1))``. Each map is then
    added into the matching tensor in ``self.memory_grids``. The grids persist
    across calls, so repeated forwards accumulate.

    Args:
        memory_size: side length of each square memory grid.
        grid_count: number of memory grids (and output chunks).
        hidden_size: LSTM hidden width (default 50, the original value).
        num_layers: number of stacked LSTM layers (default 2, the original value).
    """

    def __init__(self, memory_size=12, grid_count=3, hidden_size=50, num_layers=2):
        super().__init__()
        self.memory_size = memory_size
        self.grid_count = grid_count
        self.lstm = nn.LSTM(
            input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True
        )
        self.fc = nn.Linear(hidden_size, memory_size * memory_size * grid_count)
        self.gate = nn.Sigmoid()
        self.filter = nn.Tanh()

        # Plain tensors rather than parameters/buffers: they are not saved in
        # state_dict(), and .to()/.cuda() do not move them.
        self.memory_grids = [torch.zeros(memory_size, memory_size) for _ in range(grid_count)]

    def reset_memory(self):
        """Zero every memory grid in place (the list and tensor objects are kept)."""
        for grid in self.memory_grids:
            grid.zero_()

    def forward(self, x):
        """Run the LSTM on ``x`` and add the projected output into the memory grids.

        Args:
            x: float tensor of shape ``(batch, seq_len, 1)`` (batch_first LSTM with
                ``input_size=1``).

        Returns:
            ``(split_outputs, memory_grids)``:
            - ``split_outputs`` is a tuple of ``grid_count`` tensors of shape
              ``(batch, memory_size ** 2)``. They stay attached to the autograd
              graph, so a loss can still be computed on them.
            - ``memory_grids`` is ``self.memory_grids``, a list of
              ``(memory_size, memory_size)`` tensors updated in place.
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])
        filtered_output = self.filter(self.gate(output))

        split_outputs = torch.chunk(filtered_output, self.grid_count, dim=-1)

        side = self.memory_size
        # The grids are long-lived state. Updating them with grad-tracking tensors
        # would chain every call into one ever-growing autograd graph (a memory
        # leak, and a second backward() would fail), so accumulate without grad.
        with torch.no_grad():
            for grid, update in zip(self.memory_grids, split_outputs):
                # Sum over the batch so batch sizes > 1 work; for batch == 1 this is
                # exactly the single sample's map.
                contribution = update.detach().reshape(-1, side, side).sum(dim=0)
                grid.add_(contribution.to(grid.device))

        return split_outputs, self.memory_grids
|
|
| |
class LIFNeuron:
    """Discrete-time leaky integrate-and-fire neuron.

    Each step adds ``input_current`` to the membrane potential ``v``. It also
    leaks a fraction ``dt / tau`` of the current potential toward 0, which is
    forward-Euler integration of the leak. When ``v`` reaches ``v_th``, the
    neuron emits a spike and ``v`` is set to ``v_reset``.

    Why ``dt`` exists: leaking ``v / tau`` per step without a time step (that is,
    dt = 1) multiplies ``v`` by ``1 - 1/tau``. For tau = 0.02 that factor is -49,
    so the potential flips sign and diverges. Requiring ``dt <= tau`` keeps the
    leak factor in ``(0, 1]``.

    Args:
        tau: membrane time constant, in the same time unit as ``dt``
            (presumably seconds, i.e. 20 ms by default — TODO confirm).
        v_th: firing threshold.
        v_reset: potential immediately after a spike.
        dt: integration time step; must satisfy ``0 < dt <= tau``.

    Raises:
        ValueError: if ``tau`` or ``dt`` is not positive, or ``dt > tau``.
    """

    def __init__(self, tau=0.02, v_th=1.0, v_reset=0.0, dt=1e-3):
        if tau <= 0 or dt <= 0:
            raise ValueError("tau and dt must be positive")
        if dt > tau:
            raise ValueError("dt must not exceed tau (the Euler leak step would overshoot)")
        self.tau = tau
        self.v_th = v_th
        self.v_reset = v_reset
        self.dt = dt
        self.v = 0.0

    def step(self, input_current):
        """Advance one time step with the given input; return 1 on a spike, else 0."""
        self.v += input_current - self.v * (self.dt / self.tau)
        if self.v >= self.v_th:
            self.v = self.v_reset
            return 1
        return 0
|
|
| |
def load_model(model_path="memristor_lstm.pth", grid_count=3, memory_size=12):
    """Build a MemristorLSTM and load whatever pretrained weights are compatible.

    ``fc.*`` entries are always skipped: ``fc`` keeps its fresh random
    initialisation because its output width depends on ``grid_count`` and
    ``memory_size``. Every other checkpoint entry is copied only if the model
    has a tensor with the same name and shape. A checkpoint saved with a
    different LSTM shape therefore no longer aborts the whole load; the skipped
    weights are reported instead.

    Args:
        model_path: path to a state_dict saved with ``torch.save``.
        grid_count: number of memory grids for the new model.
        memory_size: side length of each memory grid (default 12, the original value).

    Returns:
        The model in eval mode, or ``None`` if the checkpoint could not be read or applied.
    """
    model = MemristorLSTM(memory_size=memory_size, grid_count=grid_count)
    try:
        # NOTE(security): torch.load unpickles the file. Only load trusted
        # checkpoints; newer torch versions accept weights_only=True to restrict this.
        pretrained_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model_dict = model.state_dict()
        compatible = {
            k: v
            for k, v in pretrained_dict.items()
            if not k.startswith("fc.") and k in model_dict and v.shape == model_dict[k].shape
        }
        model_dict.update(compatible)
        model.load_state_dict(model_dict)
    except Exception as e:  # boundary for the UI: the caller checks for None
        print(f"Error loading model: {e}")
        return None

    skipped = [k for k in model_dict if not k.startswith("fc.") and k not in compatible]
    if skipped:
        print(f"Warning: {len(skipped)} weight(s) not loaded from {model_path}: {skipped}")

    model.eval()
    return model
|
|
| |
def visualize_multi_grid_vram(memory_grids):
    """Draw each memory grid as a heat map, side by side in a single figure.

    Args:
        memory_grids: non-empty sequence of 2-D torch tensors (any device).

    Returns:
        The matplotlib Figure. The caller owns it and should ``plt.close`` it
        when done, to avoid accumulating open figures.

    Raises:
        ValueError: if ``memory_grids`` is empty.
    """
    if len(memory_grids) == 0:
        raise ValueError("memory_grids must contain at least one grid")

    # squeeze=False always yields a 2-D array of Axes, so a single grid needs no special case.
    fig, axes = plt.subplots(1, len(memory_grids), figsize=(15, 6), squeeze=False)

    for number, (ax, grid) in enumerate(zip(axes[0], memory_grids), start=1):
        # .cpu() so grids living on an accelerator can still be converted to NumPy.
        ax.imshow(grid.detach().cpu().numpy(), cmap='viridis', interpolation='nearest')
        ax.set_title(f"VRAM Grid {number}")
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")

    return fig
|
|
| |
def generate_lif_spike_train(length=50, input_current=0.5):
    """Simulate a fresh LIFNeuron driven by a constant input for ``length`` steps.

    Returns:
        ``(spike_train, membrane_potentials)``: two lists of length ``length``.
        Entry ``t`` of the first is the 0/1 result of step ``t``. Entry ``t`` of
        the second is the neuron's ``v`` read right after that step, so it is
        already reset on a spike step.
    """
    neuron = LIFNeuron()
    spikes = []
    potentials = []
    for _ in range(length):
        spikes.append(neuron.step(input_current))
        potentials.append(neuron.v)
    return spikes, potentials
|
|
| |
def visualize_lif_spike_train(spike_train, membrane_potentials):
    """Plot a spike train (top) above its membrane-potential trace (bottom).

    Both inputs are per-time-step sequences. Returns the two-row matplotlib
    Figure; the caller owns it.
    """
    fig, (spike_ax, potential_ax) = plt.subplots(2, 1, figsize=(6, 8))

    # Top panel: one 0/1 marker per time step.
    spike_ax.plot(spike_train, marker='o', linestyle='-', color='b', markersize=5)
    spike_ax.set_yticks([0, 1])

    # Bottom panel: potential after each step.
    # NOTE(review): LIFNeuron's threshold defaults to a unitless 1.0, so the
    # "mV" unit in the label below may not be accurate — confirm.
    potential_ax.plot(membrane_potentials, color='r', marker='x', linestyle='-', markersize=5)

    captions = (
        (spike_ax, "LIF Neuron Spike Train", "Spike Event (1 = Spike, 0 = No Spike)"),
        (potential_ax, "Membrane Potential Over Time", "Membrane Potential (mV)"),
    )
    for ax, title, y_label in captions:
        ax.set_title(title)
        ax.set_xlabel("Time Steps")
        ax.set_ylabel(y_label)
        ax.grid(True)

    return fig
|
|
| |
def generate_spikes_for_pressure(pressure, grid_count=3, model=None, p_min=0.1, p_max=1.0):
    """Turn one pressure reading into memory-grid updates and an LIF spike train.

    The pressure is min-max normalised over ``[p_min, p_max]``. The defaults
    match the app's slider range and reproduce the original hard-coded
    normalisation. The normalised value is then used twice:
    - fed to the model as a single-step sequence of shape ``(1, 1, 1)``;
    - used as the constant input current of a 50-step LIF simulation.

    Args:
        pressure: the pressure value (the app's slider labels it MPa).
        grid_count: grid count for a freshly loaded model; ignored when ``model`` is given.
        model: optional already-loaded model to reuse. When ``None`` (the
            default), ``load_model`` is called on every invocation. That re-reads
            the checkpoint and starts from empty memory grids each time.
        p_min: pressure that maps to 0.0.
        p_max: pressure that maps to 1.0.

    Returns:
        ``(outputs, memory_grids, spike_train, membrane_potentials)``, or four
        ``None`` values if no model is available.

    Raises:
        ValueError: if ``p_max <= p_min``.
    """
    if p_max <= p_min:
        raise ValueError("p_max must be greater than p_min")
    pressure_normalized = (pressure - p_min) / (p_max - p_min)
    pressure_input = torch.tensor([[[pressure_normalized]]], dtype=torch.float32)

    if model is None:
        model = load_model(grid_count=grid_count)
    if model is None:
        return None, None, None, None

    # Inference only: don't record an autograd graph (this also keeps the model's
    # persistent memory grids free of gradient history).
    with torch.no_grad():
        outputs, memory_grids = model(pressure_input)
    spike_train, membrane_potentials = generate_lif_spike_train(
        length=50, input_current=pressure_normalized
    )
    return outputs, memory_grids, spike_train, membrane_potentials
|
|
| |
def app():
    """Streamlit page: live-updating memory-grid heat maps and LIF spike plots.

    The pressure, grid count and duration are read from widgets. After the
    start button is pressed, both figures are redrawn about every 0.2 s until
    the selected number of minutes has elapsed. The run stops early, with an
    error message, if the model cannot be loaded.
    """
    st.title("Neuromorphic Multi-Grid VRAM & LIF Spiking Network")
    st.write("Observe how pressure input affects the VRAM grids, LIF spikes, and membrane potential in real-time.")

    pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
    grid_count = st.slider("Number of VRAM Grids", 1, 5, 3)
    duration = st.radio("Select Duration", [1, 2], index=0)

    if not st.button("Start Live Visualization"):
        return

    st.info(f"Running live update for {duration} minute(s)...")

    # Elements render in creation order. Put the "(Above)" caption directly under
    # the VRAM plot and the "(Below)" caption directly over the LIF plot, so each
    # caption sits next to the plot it names.
    plot_vram = st.empty()
    st.write("### VRAM Grids (Above)")
    st.write("### LIF Spike Train & Membrane Potential (Below)")
    plot_spike_train = st.empty()

    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    end_time = time.monotonic() + duration * 60
    while time.monotonic() < end_time:
        outputs, memory_grids, spike_train, membrane_potentials = generate_spikes_for_pressure(
            pressure, grid_count
        )
        if outputs is None or memory_grids is None:
            st.error("Model could not be loaded; stopping live visualization.")
            break

        # Close each figure once rendered. A new one is created every frame, and
        # pyplot otherwise keeps them all open.
        vram_fig = visualize_multi_grid_vram(memory_grids)
        plot_vram.pyplot(vram_fig)
        plt.close(vram_fig)

        spike_fig = visualize_lif_spike_train(spike_train, membrane_potentials)
        plot_spike_train.pyplot(spike_fig)
        plt.close(spike_fig)

        time.sleep(0.2)
|
|
if __name__ == "__main__":
    # Build the page when this module runs as the main script (presumably via
    # `streamlit run`, given the st.* calls inside app()).
    app()
|
|