YashsharmaPhD commited on
Commit
f2716e0
·
verified ·
1 Parent(s): 60c3083

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import streamlit as st
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ # Define the MemoryAugmentedLSTM model
8
+ class MemoryAugmentedLSTM(nn.Module):
9
+ def __init__(self, input_size, hidden_size, output_size, num_memories=10):
10
+ super(MemoryAugmentedLSTM, self).__init__()
11
+ self.lstm = nn.LSTM(input_size, hidden_size)
12
+ self.fc = nn.Linear(hidden_size, output_size)
13
+
14
+ # Memory matrix (example: 10 slots, each with hidden_size dimensions)
15
+ self.memory_bank = torch.zeros(num_memories, hidden_size)
16
+
17
+ self.num_memories = num_memories
18
+ self.hidden_size = hidden_size
19
+
20
+ def forward(self, x, pressure):
21
+ lstm_out, _ = self.lstm(x)
22
+
23
+ # Memory allocation based on pressure
24
+ memory_idx = self._get_memory_idx(pressure)
25
+
26
+ # Read from memory (get the most relevant memory slot)
27
+ memory_out = self.memory_bank[memory_idx]
28
+
29
+ # Combine the LSTM output and the relevant memory
30
+ combined_output = lstm_out[-1] + memory_out
31
+
32
+ # Final prediction
33
+ prediction = self.fc(combined_output)
34
+ return prediction
35
+
36
+ def _get_memory_idx(self, pressure):
37
+ # Normalize the pressure to determine the most relevant memory slot
38
+ normalized_pressure = (pressure - 0.1) / (1.0 - 0.1) # Normalize between 0 and 1
39
+ memory_idx = int(normalized_pressure * (self.num_memories - 1))
40
+ return memory_idx
41
+
42
+ # Load the trained model
43
+ model = MemoryAugmentedLSTM(input_size=10, hidden_size=50, output_size=1, num_memories=10)
44
+ model.load_state_dict(torch.load('memristor_lstm.pth'))
45
+
46
+ # Streamlit UI components
47
+ st.title("Pressure to Spike Prediction Model")
48
+ st.write("Select a pressure value between 0.1 MPa and 1 MPa:")
49
+
50
+ # Slider for selecting pressure
51
+ pressure_input = st.slider("Pressure (MPa)", 0.1, 1.0, 0.2, step=0.1)
52
+
53
+ # Example input data (adjust according to your model input)
54
+ input_data = torch.randn(5, 1, 10) # Example random input data for the LSTM
55
+
56
+ # Make prediction using the selected pressure
57
+ prediction = model(input_data, pressure_input).detach().numpy()
58
+
59
+ # Display the prediction outcome (spike generation)
60
+ st.write("Predicted spike outcome:")
61
+
62
+ # Plot the spikes (for demonstration, we use a simple plot with noise for now)
63
+ spike_time = np.linspace(0, 10, 100) # Time from 0 to 10 (for example)
64
+ spike_amplitude = prediction * np.sin(spike_time) # Example of a spike pattern
65
+
66
+ # Plot the spike generation
67
+ plt.figure(figsize=(10, 6))
68
+ plt.plot(spike_time, spike_amplitude, label=f"Spike Pattern for {pressure_input} MPa")
69
+ plt.title(f"Spike Generation for Pressure = {pressure_input} MPa")
70
+ plt.xlabel("Time (s)")
71
+ plt.ylabel("Spike Amplitude")
72
+ plt.legend()
73
+ plt.grid(True)
74
+ st.pyplot(plt)
75
+