# Provenance (web-page header captured with the file): author YashsharmaPhD,
# change "Update app.py", revision 5a16c55 (verified).
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import streamlit as st
import time
import numpy as np
# LSTM model with a multi-grid VRAM (memory grid) simulation
class MemristorLSTM(nn.Module):
    """Two-layer LSTM whose last-step output is accumulated into square grids.

    The last time step of the LSTM output is projected by a linear layer to
    ``memory_size * memory_size * grid_count`` values, passed through a
    sigmoid and then a tanh, and split into ``grid_count`` equal chunks.
    Each chunk is added to its own ``memory_size x memory_size`` grid, so the
    grids are a running sum over every ``forward`` call made on the instance.

    Args:
        memory_size: Side length of each square memory grid.
        grid_count: Number of independent memory grids.
    """

    def __init__(self, memory_size=12, grid_count=3):
        super().__init__()
        self.memory_size = memory_size
        self.grid_count = grid_count  # Number of VRAM grids
        self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
        # A single linear head produces every cell of every grid at once.
        self.fc = nn.Linear(50, self.memory_size * self.memory_size * self.grid_count)
        self.gate = nn.Sigmoid()
        self.filter = nn.Tanh()
        # Kept as a plain Python list (not registered buffers), so the grids
        # are not part of state_dict().
        self.memory_grids = [torch.zeros(memory_size, memory_size) for _ in range(grid_count)]

    def forward(self, x):
        """Run the network and accumulate its output into the memory grids.

        Args:
            x: Tensor of shape ``(batch, seq_len, 1)`` (the LSTM is built with
                ``input_size=1`` and ``batch_first=True``).

        Returns:
            A tuple ``(split_outputs, memory_grids)``: ``split_outputs`` is a
            tuple of ``grid_count`` tensors of shape
            ``(batch, memory_size * memory_size)``; ``memory_grids`` is the
            instance's own list of accumulated grids (not a copy).
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])  # only the last time step is used
        filtered_output = self.filter(self.gate(output))

        # The projected width is an exact multiple of grid_count, so this
        # yields exactly grid_count equally sized chunks.
        split_outputs = torch.chunk(filtered_output, self.grid_count, dim=-1)

        for grid, chunk in zip(self.memory_grids, split_outputs):
            # Detach first: the grids are persistent state, and adding
            # graph-attached tensors in place would make them require grad and
            # keep extending an autograd graph on every call.
            # Summing over the batch dimension makes each sample contribute
            # one update; for a batch of one this equals a plain reshape.
            update = chunk.detach().reshape(-1, self.memory_size, self.memory_size).sum(dim=0)
            grid += update

        return split_outputs, self.memory_grids
# LIF neuron model (for advanced spiking)
class LIFNeuron:
    """Leaky integrate-and-fire neuron advanced with an explicit Euler step.

    Each call to :meth:`step` applies
    ``v += dt * (input_current - v / tau)`` and emits a spike (resetting the
    potential) once ``v`` reaches the threshold.

    NOTE: the update multiplies the previous potential by ``(1 - dt / tau)``.
    With the defaults (``dt=1.0``, ``tau=0.02``) that factor is ``-49``, so
    the potential flips sign and grows between spikes. The default ``dt`` is
    kept at ``1.0`` so existing callers get the same numbers as before; pass
    ``dt < 2 * tau`` for a numerically stable trajectory.

    Args:
        tau: Membrane time constant (divides the leak term).
        v_th: Spiking threshold.
        v_reset: Potential assigned right after a spike.
        dt: Integration step used by :meth:`step`.
    """

    def __init__(self, tau=0.02, v_th=1.0, v_reset=0.0, dt=1.0):
        self.tau = tau  # Time constant of the leak term
        self.v_th = v_th  # Spiking threshold
        self.v_reset = v_reset  # Reset voltage after spike
        self.dt = dt  # Integration step
        self.v = 0.0  # Membrane potential

    def step(self, input_current):
        """Advance the neuron by one step.

        Args:
            input_current: Input drive for this step.

        Returns:
            ``1`` if the potential reached ``v_th`` (and was reset to
            ``v_reset``), otherwise ``0``.
        """
        self.v += self.dt * (input_current - (self.v / self.tau))
        if self.v >= self.v_th:
            self.v = self.v_reset  # Spike and reset
            return 1
        return 0
# Load pretrained model
def load_model(model_path="memristor_lstm.pth", grid_count=3):
    """Build a ``MemristorLSTM`` and overlay matching pretrained weights.

    Only checkpoint entries whose key exists in the fresh model and does not
    contain ``'fc'`` are copied in; everything else keeps its freshly
    initialised value. The model is switched to eval mode before returning.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.
        grid_count: Number of memory grids for the constructed model.

    Returns:
        The prepared model, or ``None`` if anything in the loading sequence
        raised (the error is printed).
    """
    network = MemristorLSTM(memory_size=12, grid_count=grid_count)
    try:
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        merged_state = network.state_dict()
        for key, value in checkpoint.items():
            # Skip keys the model does not have and every 'fc' entry.
            if key in merged_state and 'fc' not in key:
                merged_state[key] = value
        network.load_state_dict(merged_state)
        network.eval()
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
    return network
# Visualize multi-grid VRAM network
def visualize_multi_grid_vram(memory_grids):
    """Draw every memory grid as a heat map, side by side in one row.

    Args:
        memory_grids: Sequence of 2-D tensors; each is detached and converted
            to NumPy before plotting.

    Returns:
        The matplotlib figure holding one subplot per grid. The caller owns
        the figure.
    """
    column_count = len(memory_grids)
    # squeeze=False always yields a 2-D array of axes, so a single grid needs
    # no special-casing.
    fig, axes_grid = plt.subplots(1, column_count, figsize=(15, 6), squeeze=False)
    for number, (ax, grid) in enumerate(zip(axes_grid[0], memory_grids), start=1):
        pixels = grid.detach().numpy()
        ax.imshow(pixels, cmap='viridis', interpolation='nearest')
        ax.set_title(f"VRAM Grid {number}")
        ax.set_xlabel("Columns")
        ax.set_ylabel("Rows")
    return fig
# Generate spike train using a LIF neuron (driven by the pressure-derived current)
def generate_lif_spike_train(length=50, input_current=0.5):
    """Simulate one fresh ``LIFNeuron`` under a constant input.

    Args:
        length: Number of simulation steps.
        input_current: Value passed to ``LIFNeuron.step`` on every step.

    Returns:
        A tuple ``(spike_train, membrane_potentials)`` of two lists with one
        entry per step: the value returned by ``step`` and the neuron's ``v``
        right after that step.
    """
    neuron = LIFNeuron()
    spikes, potentials = [], []
    for _ in range(length):
        spikes.append(neuron.step(input_current))
        potentials.append(neuron.v)  # recorded after any reset
    return spikes, potentials
# Visualize LIF spike train and membrane potential
def visualize_lif_spike_train(spike_train, membrane_potentials):
    """Plot the spike train and the membrane potential in two stacked panels.

    Args:
        spike_train: Per-step spike values, plotted in the upper panel.
        membrane_potentials: Per-step potentials, plotted in the lower panel.

    Returns:
        The matplotlib figure containing both panels. The caller owns the
        figure.
    """
    def _decorate(ax, title, y_label):
        # Shared axis cosmetics for both panels.
        ax.set_title(title)
        ax.set_xlabel("Time Steps")
        ax.set_ylabel(y_label)
        ax.grid(True)

    fig, panels = plt.subplots(2, 1, figsize=(6, 8))
    spike_ax, potential_ax = panels

    # Upper panel: spike events
    spike_ax.plot(spike_train, marker='o', linestyle='-', color='b', markersize=5)
    spike_ax.set_yticks([0, 1])
    _decorate(spike_ax, "LIF Neuron Spike Train", "Spike Event (1 = Spike, 0 = No Spike)")

    # Lower panel: membrane potential
    potential_ax.plot(membrane_potentials, color='r', marker='x', linestyle='-', markersize=5)
    _decorate(potential_ax, "Membrane Potential Over Time", "Membrane Potential (mV)")

    return fig
# Generate multi-grid outputs, VRAM updates and a LIF spike train for one pressure
def generate_spikes_for_pressure(pressure, grid_count=3):
    """Run the model and the LIF simulation for a single pressure value.

    The pressure is rescaled with ``(pressure - 0.1) / (1.0 - 0.1)`` and used
    both as the single-element model input and as the constant LIF input
    current.

    NOTE(review): a new model is loaded on every call, so the memory grids
    returned here start from zero each time, and ``load_model`` leaves the
    ``fc`` layer at its fresh initialisation. Confirm this is intended.

    Args:
        pressure: Pressure value to visualise.
        grid_count: Number of memory grids for the loaded model.

    Returns:
        ``(outputs, memory_grids, spike_train, membrane_potentials)``, or
        four ``None`` values when the model could not be loaded.
    """
    pressure_normalized = (pressure - 0.1) / (1.0 - 0.1)

    model = load_model(grid_count=grid_count)
    if model is None:
        return None, None, None, None

    pressure_input = torch.tensor([[[pressure_normalized]]], dtype=torch.float32)  # Shape (1,1,1)
    # Pure inference: disable autograd so no graph is built or retained.
    with torch.no_grad():
        outputs, memory_grids = model(pressure_input)

    spike_train, membrane_potentials = generate_lif_spike_train(
        length=50, input_current=pressure_normalized
    )
    return outputs, memory_grids, spike_train, membrane_potentials
# Streamlit UI for real-time visualization
def app():
    """Render the Streamlit page and, on request, run the live update loop.

    The page offers sliders for pressure and grid count plus a duration
    choice (in minutes). After the start button is pressed, the two plot
    placeholders are refreshed roughly every 0.2 seconds until the chosen
    duration elapses or the model fails to load.
    """
    st.title("Neuromorphic Multi-Grid VRAM & LIF Spiking Network")
    st.write("Observe how pressure input affects the VRAM grids, LIF spikes, and membrane potential in real-time.")

    pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
    grid_count = st.slider("Number of VRAM Grids", 1, 5, 3)
    duration = st.radio("Select Duration", [1, 2], index=0)

    if not st.button("Start Live Visualization"):
        return

    st.info(f"Running live update for {duration} minute(s)...")

    # Placeholders that are overwritten in place on every frame
    plot_vram = st.empty()
    plot_spike_train = st.empty()

    # Layout: VRAM plots on top, spike train and membrane potential below
    st.write("### VRAM Grids (Above)")
    st.write("### LIF Spike Train & Membrane Potential (Below)")

    # A monotonic clock keeps the run length immune to wall-clock adjustments.
    deadline = time.monotonic() + duration * 60
    while time.monotonic() < deadline:
        outputs, memory_grids, spike_train, membrane_potentials = generate_spikes_for_pressure(pressure, grid_count)

        if outputs is None or memory_grids is None:
            # Retrying for minutes would only repeat the same failure silently.
            st.error("Model could not be loaded; live visualization stopped.")
            break

        # Close each figure once rendered; otherwise pyplot keeps every frame's
        # figures alive and memory grows for the whole run.
        vram_fig = visualize_multi_grid_vram(memory_grids)
        plot_vram.pyplot(vram_fig)
        plt.close(vram_fig)

        spike_fig = visualize_lif_spike_train(spike_train, membrane_potentials)
        plot_spike_train.pyplot(spike_fig)
        plt.close(spike_fig)

        time.sleep(0.2)  # Update every 0.2 seconds


if __name__ == "__main__":
    app()