File size: 4,273 Bytes
f2716e0
 
 
ce87e89
cd6adf8
78cdd90
f2716e0
e82e970
7f1acfb
 
 
 
 
a2b1856
78cdd90
 
 
7f1acfb
 
f2716e0
7f1acfb
e82e970
 
 
 
f2716e0
78cdd90
ce87e89
 
7f1acfb
a2b1856
 
 
 
 
ce87e89
 
7f1acfb
ce87e89
 
7f1acfb
ce87e89
3bd112b
7f1acfb
 
ce87e89
7f1acfb
 
78cdd90
3bd112b
78cdd90
 
 
 
 
ce87e89
78cdd90
3bd112b
08a68c4
cd6adf8
 
 
 
78cdd90
3bd112b
 
08a68c4
78cdd90
 
 
ce87e89
78cdd90
ce87e89
78cdd90
 
 
 
 
ce87e89
78cdd90
7f1acfb
78cdd90
 
7f1acfb
ce87e89
78cdd90
cd6adf8
78cdd90
 
7f1acfb
78cdd90
3062fb3
3bd112b
 
 
 
78cdd90
 
 
3062fb3
78cdd90
3bd112b
 
 
7f1acfb
3bd112b
7f1acfb
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import streamlit as st
import time
import numpy as np

# Simple LSTM with a VRAM-like memory grid (4x4 by default), including gates and filters
class MemristorLSTM(nn.Module):
    """LSTM whose last-step features are projected onto a square, accumulating memory grid.

    Each forward pass maps the final LSTM time step to ``memory_size ** 2``
    values, squashes them with a sigmoid "gate" followed by a tanh "filter"
    (so every value lies in ``(0, tanh(1))``, roughly ``(0, 0.762)``), and
    adds them into a persistent ``memory_size x memory_size`` grid. The grid
    keeps accumulating across calls until :meth:`reset_memory` is called.

    Args:
        memory_size: Side length of the square memory grid (default 4 -> 4x4).
        hidden_size: LSTM hidden width, also the input width of ``fc``.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, memory_size=4, hidden_size=50, num_layers=2):
        super().__init__()
        self.memory_size = memory_size
        self.lstm = nn.LSTM(
            input_size=1, hidden_size=hidden_size, num_layers=num_layers, batch_first=True
        )
        # One output per grid cell (16 values for the default 4x4 grid).
        self.fc = nn.Linear(hidden_size, memory_size * memory_size)
        self.gate = nn.Sigmoid()   # Gate to modulate the output, range (0, 1)
        self.filter = nn.Tanh()    # Non-linearity filter applied after the gate
        # A buffer (rather than a plain attribute) so .to(device)/.cpu() move it with
        # the module. persistent=False keeps it out of state_dict(), exactly like the
        # previous plain attribute, so existing checkpoints still load unchanged.
        self.register_buffer(
            "memory", torch.zeros(memory_size, memory_size), persistent=False
        )

    def reset_memory(self):
        """Reset the accumulated memory grid to zeros.

        Rebinds rather than zeroing in place, so tensors previously returned by
        :meth:`forward` are not mutated behind the caller's back.
        """
        self.memory = torch.zeros_like(self.memory)

    def forward(self, x):
        """Run one step and accumulate its result into the memory grid.

        Args:
            x: Input sequence, batch-first with one feature per step, i.e.
                ``(batch, seq_len, 1)``. The grid update reshapes the output to
                ``(memory_size, memory_size)``, so the batch size must be 1.

        Returns:
            ``(filtered_output, memory)``. ``filtered_output`` has shape
            ``(1, memory_size ** 2)``. ``memory`` is the updated grid, which
            still carries autograd history back to this call's output.
        """
        lstm_out, _ = self.lstm(x)
        output = self.fc(lstm_out[:, -1, :])  # features of the last time step only
        filtered_output = self.filter(self.gate(output))
        updated = self.memory + filtered_output.view(self.memory_size, self.memory_size)
        # Store a detached copy. Otherwise each call's graph is chained onto the
        # previous call's, so graphs accumulate without bound and a later backward()
        # tries to traverse graphs that an earlier backward() already freed.
        self.memory = updated.detach()
        return filtered_output, updated

# Load Pretrained Model
def load_model(model_path="memristor_lstm.pth"):
    """Build a 4x4 MemristorLSTM and load the LSTM weights from a checkpoint.

    The checkpoint may be a state dict or a whole pickled module. Only entries
    whose key exists in the fresh model are used, and any key containing
    ``"fc"`` is skipped. As a result the ``fc`` projection keeps its fresh
    random initialisation on every call.

    Args:
        model_path: Path to the checkpoint file, loaded onto the CPU.

    Returns:
        The model in eval mode, or ``None`` (after printing the error) if the
        checkpoint cannot be read or its weights do not fit the model.
    """
    model = MemristorLSTM(memory_size=4)
    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Newer PyTorch versions support weights_only=True for this.
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        # torch.save(model) stores the whole module, not a dict; normalise it.
        if isinstance(checkpoint, nn.Module):
            checkpoint = checkpoint.state_dict()
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v for k, v in checkpoint.items() if k in model_dict and 'fc' not in k
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    except Exception as e:
        # Best-effort boundary: callers treat None as "no model available".
        print(f"Error loading model: {e}")
        return None
    model.eval()
    return model

# Visualize the Memory Grid (4x4 by default)
def visualize_memory_grid(memory_grid):
    """Render a 2-D memory grid as a heat map and return the new Figure.

    Args:
        memory_grid: A torch tensor (on any device, with or without autograd
            history) or anything ``np.asarray`` accepts, shaped ``(rows, cols)``.

    Returns:
        The matplotlib Figure. pyplot keeps it open until ``plt.close(fig)``
        is called, so the caller should close it after displaying it.
    """
    if isinstance(memory_grid, torch.Tensor):
        # .cpu() so CUDA tensors work too; .detach() drops autograd history.
        data = memory_grid.detach().cpu().numpy()
    else:
        data = np.asarray(memory_grid)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(data, cmap='viridis', interpolation='nearest')
    rows, cols = data.shape[:2]
    ax.set_title(f"{rows}x{cols} Memory Grid - VRAM")
    ax.set_xlabel("Memory Cells (Columns)")
    ax.set_ylabel("Memory Cells (Rows)")

    return fig

# Generate Spike Train
def generate_spike_train(length=50, threshold=0.5, rng=None):
    """Return a random binary spike train.

    Each time step draws a uniform sample on ``[0, 1)``. The step is a spike
    (1) when the sample exceeds ``threshold`` and 0 otherwise, so for
    ``0 <= threshold <= 1`` the spike probability is ``1 - threshold``.

    Args:
        length: Number of time steps.
        threshold: Spike threshold on the uniform draw.
        rng: Optional ``numpy.random.Generator`` for reproducible trains. When
            ``None``, numpy's global RNG (``np.random.rand``) is used.

    Returns:
        1-D integer ndarray of 0s and 1s with ``length`` entries.
    """
    draws = np.random.rand(length) if rng is None else rng.random(length)
    return (draws > threshold).astype(int)

# Visualize the Spike Train
def visualize_spike_train(spike_train):
    """Plot a 0/1 spike train against its time-step index.

    Draws red circle markers joined by a solid line, with y ticks only at 0
    and 1, and returns the newly created Figure. The caller is responsible
    for closing it.
    """
    figure, axes = plt.subplots(figsize=(6, 4))
    # "ro-": red, circle markers, solid connecting line.
    axes.plot(spike_train, "ro-", markersize=5)
    axes.set(
        title="Spike Train",
        xlabel="Time Steps",
        ylabel="Spike Event (1 = Spike, 0 = No Spike)",
        yticks=[0, 1],
    )
    return figure

# Generate spikes and VRAM updates
def generate_spikes_for_pressure(pressure, p_min=0.1, p_max=1.0):
    """Run one model step for ``pressure`` and draw a fresh spike train.

    ``pressure`` is min-max normalised with ``[p_min, p_max]`` and fed to the
    model as a one-step sequence of shape ``(1, 1, 1)``. The defaults match
    the 0.1-1.0 range of the app's pressure slider.

    ``load_model()`` is called on every invocation. Each call therefore uses
    a fresh model: the memory grid starts from zeros, and ``fc`` is freshly
    randomly initialised because ``load_model`` does not load ``fc`` weights.
    NOTE(review): the spike train is drawn independently of ``pressure``;
    confirm that this is intended.

    Args:
        pressure: Raw pressure value (the app labels it MPa).
        p_min: Pressure mapped to 0 by the normalisation.
        p_max: Pressure mapped to 1 by the normalisation.

    Returns:
        ``(output, memory_grid, spike_train)``, or ``(None, None, None)`` if
        the model could not be loaded.

    Raises:
        ValueError: If ``p_min == p_max``, which would divide by zero.
    """
    if p_max == p_min:
        raise ValueError("p_max must differ from p_min")
    pressure_normalized = (pressure - p_min) / (p_max - p_min)
    # Shape (batch=1, seq_len=1, features=1), matching the batch-first LSTM input.
    pressure_input = torch.tensor([[[pressure_normalized]]], dtype=torch.float32)

    model = load_model()
    if model is None:
        return None, None, None

    # Inference only: skip building an autograd graph nobody backpropagates through.
    with torch.no_grad():
        output, memory_grid = model(pressure_input)
    spike_train = generate_spike_train(length=50)
    return output, memory_grid, spike_train

# Streamlit UI
def app():
    """Streamlit page: choose pressure and duration, then stream live plots.

    After "Start Live Visualization" is pressed, the page refreshes the
    memory-grid heat map and the spike-train plot side by side about every
    0.2 s for the selected number of minutes. It stops early with an error
    message if the model cannot be loaded.
    """
    st.title("Memristor VRAM & Spike Train Live Visualization")
    st.write("Watch how VRAM and spike trains change over time with pressure input.")

    pressure = st.slider("Select Pressure (MPa)", 0.1, 1.0, 0.5, 0.1)
    duration = st.radio("Select Duration", [1, 2], index=0)

    if not st.button("Start Live Visualization"):
        return

    st.info(f"Running live update for {duration} minute(s)...")

    plot_col1, plot_col2 = st.columns(2)  # Side-by-side layout
    # Placeholders that are overwritten in place on every refresh.
    plot_vram = plot_col1.empty()
    plot_spike_train = plot_col2.empty()

    # Monotonic clock: unaffected by wall-clock adjustments during the run.
    end_time = time.monotonic() + duration * 60
    while time.monotonic() < end_time:
        output, memory_grid, spike_train = generate_spikes_for_pressure(pressure)
        if output is None or memory_grid is None:
            st.error("Model could not be loaded; stopping live visualization.")
            break

        vram_fig = visualize_memory_grid(memory_grid)
        spike_fig = visualize_spike_train(spike_train)
        try:
            plot_vram.pyplot(vram_fig)
            plot_spike_train.pyplot(spike_fig)
        finally:
            # pyplot keeps every figure alive until closed. Without this, roughly
            # 10 figures per second would pile up for the whole run.
            plt.close(vram_fig)
            plt.close(spike_fig)

        time.sleep(0.2)  # Refresh interval (~5 updates per second)

# Script entry point. app() builds the page with st.* calls, so this file is
# meant to be launched as a Streamlit script (e.g. `streamlit run <this file>`).
if __name__ == "__main__":
    app()