Spaces:
Sleeping
Sleeping
File size: 4,273 Bytes
f2716e0 ce87e89 cd6adf8 78cdd90 f2716e0 e82e970 7f1acfb a2b1856 78cdd90 7f1acfb f2716e0 7f1acfb e82e970 f2716e0 78cdd90 ce87e89 7f1acfb a2b1856 ce87e89 7f1acfb ce87e89 7f1acfb ce87e89 3bd112b 7f1acfb ce87e89 7f1acfb 78cdd90 3bd112b 78cdd90 ce87e89 78cdd90 3bd112b 08a68c4 cd6adf8 78cdd90 3bd112b 08a68c4 78cdd90 ce87e89 78cdd90 ce87e89 78cdd90 ce87e89 78cdd90 7f1acfb 78cdd90 7f1acfb ce87e89 78cdd90 cd6adf8 78cdd90 7f1acfb 78cdd90 3062fb3 3bd112b 78cdd90 3062fb3 78cdd90 3bd112b 7f1acfb 3bd112b 7f1acfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import time

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from matplotlib.figure import Figure
# Simple LSTM with a VRAM-like memory grid (4x4 by default), including gates and filters
class MemristorLSTM(nn.Module):
    """LSTM whose last-step output is projected onto a square, accumulating memory grid.

    Each forward pass takes the final time step of the LSTM output, applies a
    linear layer, a sigmoid "gate" and a tanh "filter", and adds the result into
    a running ``memory_size x memory_size`` grid kept on the module.

    Args:
        memory_size: side length of the square memory grid (default 4 -> 16 cells).
        hidden_size: LSTM hidden size (default 50).
        num_layers: number of stacked LSTM layers (default 2).
    """

    def __init__(self, memory_size=4, hidden_size=50, num_layers=2):
        super().__init__()
        self.memory_size = memory_size
        self.lstm = nn.LSTM(
            input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True
        )
        # One output per memory cell (16 values for the default 4x4 grid).
        self.fc = nn.Linear(hidden_size, memory_size * memory_size)
        self.gate = nn.Sigmoid()  # Gate to modulate the output
        self.filter = nn.Tanh()  # Non-linearity filter
        # Non-persistent buffer: it moves with .to(device)/.cuda() like the weights,
        # but is not added to state_dict, so checkpoint keys stay unchanged.
        self.register_buffer(
            "memory", torch.zeros(memory_size, memory_size), persistent=False
        )

    def reset_memory(self):
        """Reset the accumulated memory grid to zeros."""
        # Rebind instead of zero_(): the previous grid may share storage with a
        # tensor already returned from forward().
        self.memory = torch.zeros_like(self.memory)

    def forward(self, x):
        """Run one step and add its output into the memory grid.

        Args:
            x: input of shape (batch, seq_len, 1); the LSTM is batch_first with
                input_size=1.

        Returns:
            (filtered_output, memory): ``filtered_output`` has shape
            (batch, memory_size**2); ``memory`` is the updated
            (memory_size, memory_size) grid.
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])
        gated_output = self.gate(output)
        # NOTE: sigmoid then tanh bounds every cell update to (0, tanh(1) ~= 0.76).
        filtered_output = self.filter(gated_output)
        # The memory never feeds back into the network, so a batch of B samples
        # contributes the same total update as B sequential single-sample calls.
        update = filtered_output.view(-1, self.memory_size, self.memory_size).sum(dim=0)
        new_memory = self.memory + update
        # Return a grid that still carries this call's gradient, but store a
        # detached copy so repeated calls do not chain an ever-growing graph.
        self.memory = new_memory.detach()
        return filtered_output, new_memory
# Load Pretrained Model
def load_model(model_path="memristor_lstm.pth", *, exclude=("fc",)):
    """Build a MemristorLSTM and copy in compatible weights from a checkpoint.

    Checkpoint tensors are skipped (the model keeps its fresh initialisation for
    them) when the model has no such key, the key contains any substring in
    ``exclude`` (by default the ``fc`` projection), or the shape differs.

    Args:
        model_path: file written by ``torch.save``; either a state_dict-like
            mapping or a whole ``nn.Module``.
        exclude: keyword-only; key substrings that are never loaded.

    Returns:
        The model in eval mode, or ``None`` if the checkpoint could not be used
        (the error is printed).
    """
    model = MemristorLSTM(memory_size=4)
    try:
        # NOTE(review): torch.load unpickles arbitrary objects; only point this at
        # trusted checkpoints (torch>=1.13 offers weights_only=True for state_dicts).
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        if isinstance(checkpoint, nn.Module):
            # Whole model saved with torch.save(model): use its weights.
            checkpoint = checkpoint.state_dict()
        model_dict = model.state_dict()
        pretrained_dict = {}
        for key, value in checkpoint.items():
            if key not in model_dict or any(part in key for part in exclude):
                continue
            shape = getattr(value, "shape", None)
            if shape != model_dict[key].shape:
                # load_state_dict would reject this tensor and fail the whole load;
                # keep the fresh initialisation for this key instead.
                print(f"Skipping {key}: checkpoint shape {shape} != model shape "
                      f"{tuple(model_dict[key].shape)}")
                continue
            pretrained_dict[key] = value
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model.eval()
        return model
    except Exception as e:
        # Best-effort loader: report the problem and signal failure with None.
        print(f"Error loading model: {e}")
        return None
# Visualize the Memory Grid (4x4 by default)
def visualize_memory_grid(memory_grid):
    """Render a 2-D memory grid as a heat map.

    Args:
        memory_grid: 2-D torch tensor (any device; autograd history is dropped)
            or any 2-D array-like.

    Returns:
        A ``matplotlib.figure.Figure`` containing the heat map.
    """
    if isinstance(memory_grid, torch.Tensor):
        # .cpu() so CUDA tensors can be converted to NumPy as well.
        grid = memory_grid.detach().cpu().numpy()
    else:
        grid = np.asarray(memory_grid)
    rows, cols = grid.shape
    # A bare Figure is not registered with pyplot's global figure manager, so
    # figures created on every refresh are garbage-collected instead of piling up.
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    ax.imshow(grid, cmap='viridis', interpolation='nearest')
    ax.set_title(f"{rows}x{cols} Memory Grid - VRAM")
    ax.set_xlabel("Memory Cells (Columns)")
    ax.set_ylabel("Memory Cells (Rows)")
    return fig
# Generate Spike Train
def generate_spike_train(length=50, threshold=0.5, *, rng=None):
    """Return a random binary spike train.

    Each time step independently spikes (1) when a uniform [0, 1) draw exceeds
    ``threshold``. For ``threshold`` in [0, 1] the expected firing rate is
    therefore ``1 - threshold``.

    Args:
        length: number of time steps.
        threshold: draw value a step must exceed to spike.
        rng: keyword-only ``numpy.random.Generator`` for reproducible trains.
            When ``None``, NumPy's global RNG (``np.random.rand``) is used.

    Returns:
        1-D integer array of 0/1 values with ``length`` elements.
    """
    draws = np.random.rand(length) if rng is None else rng.random(length)
    return (draws > threshold).astype(int)
# Visualize the Spike Train
def visualize_spike_train(spike_train):
    """Plot a 0/1 spike train against its time-step index.

    Args:
        spike_train: 1-D sequence of 0/1 values, one per time step.

    Returns:
        A ``matplotlib.figure.Figure`` containing the plot.
    """
    # A bare Figure is not registered with pyplot's global figure manager, so
    # figures created on every refresh are garbage-collected instead of piling up.
    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    ax.plot(spike_train, marker='o', linestyle='-', color='r', markersize=5)
    ax.set_title("Spike Train")
    ax.set_xlabel("Time Steps")
    ax.set_ylabel("Spike Event (1 = Spike, 0 = No Spike)")
    ax.set_yticks([0, 1])
    return fig
# Generate spikes and VRAM updates
# Model reused across calls so its memory grid can accumulate; filled on first use.
_MODEL_CACHE = {}


def generate_spikes_for_pressure(pressure, model=None):
    """Run one model step for a pressure value and draw a spike train.

    The pressure is min-max normalised from [0.1, 1.0] (MPa) to [0, 1] and fed
    to the model as a single-step sequence of shape (1, 1, 1).

    Args:
        pressure: pressure value, expected within [0.1, 1.0].
        model: optional model to use. When ``None``, a model is obtained from
            ``load_model()`` on first successful load and reused afterwards, so
            the checkpoint is not re-read on every call and the model's memory
            grid is not reset between calls.

    Returns:
        ``(output, memory_grid, spike_train)``, or ``(None, None, None)`` if no
        model is available.
    """
    if model is None:
        model = _MODEL_CACHE.get("model")
        if model is None:
            model = load_model()
            if model is None:
                return None, None, None
            _MODEL_CACHE["model"] = model

    pressure_normalized = (pressure - 0.1) / (1.0 - 0.1)
    pressure_input = torch.tensor([[[pressure_normalized]]], dtype=torch.float32)
    # Inference only: without no_grad a long-lived model would keep autograd
    # history attached to its stored memory grid.
    with torch.no_grad():
        output, memory_grid = model(pressure_input)
    # NOTE(review): the spike train is independent of the pressure and the model
    # output; confirm this is intended.
    spike_train = generate_spike_train(length=50)
    return output, memory_grid, spike_train
# Streamlit UI
def app():
    """Build the Streamlit page and, on request, stream live VRAM / spike plots.

    After "Start Live Visualization" is pressed, both plots are refreshed about
    every 0.2 s for the selected number of minutes.
    """
    st.title("Memristor VRAM & Spike Train Live Visualization")
    st.write("Watch how VRAM and spike trains change over time with pressure input.")
    pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
    duration = st.radio("Select Duration", [1, 2], index=0)
    if not st.button("Start Live Visualization"):
        return

    st.info(f"Running live update for {duration} minute(s)...")
    plot_col1, plot_col2 = st.columns(2)  # Side-by-side layout
    # Empty placeholders are overwritten on each refresh instead of stacking plots.
    plot_vram = plot_col1.empty()
    plot_spike_train = plot_col2.empty()

    # monotonic() is unaffected by wall-clock adjustments, unlike time.time().
    end_time = time.monotonic() + duration * 60  # Run for selected time
    while time.monotonic() < end_time:
        output, memory_grid, spike_train = generate_spikes_for_pressure(pressure)
        if output is None or memory_grid is None:
            # No model: stop instead of looping for minutes with nothing to show.
            st.error("Model could not be loaded; live visualization stopped.")
            break
        frames = (
            (plot_vram, visualize_memory_grid(memory_grid)),
            (plot_spike_train, visualize_spike_train(spike_train)),
        )
        for placeholder, fig in frames:
            placeholder.pyplot(fig)
            # pyplot keeps figures it manages alive until closed; for figures it
            # does not manage this is a no-op.
            plt.close(fig)
        time.sleep(0.2)  # Update every 0.2 seconds (FAST UPDATES)


if __name__ == "__main__":
    app()
|