YashsharmaPhD commited on
Commit
e82e970
·
verified ·
1 Parent(s): ce87e89

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -3,13 +3,18 @@ import torch.nn as nn
3
  import matplotlib.pyplot as plt
4
  import streamlit as st
5
 
6
- # Simple LSTM with a VRAM-like 4x4 Memory Grid
7
  class MemristorLSTM(nn.Module):
8
  def __init__(self, memory_size=4):
9
  super(MemristorLSTM, self).__init__()
10
  self.memory_size = memory_size
11
  self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
12
- self.fc = nn.Linear(50, 1)
 
 
 
 
 
13
  # 4x4 memory grid initialized to zeros
14
  self.memory = torch.zeros(memory_size, memory_size)
15
 
@@ -18,9 +23,15 @@ class MemristorLSTM(nn.Module):
18
  lstm_out, _ = self.lstm(x)
19
  output = self.fc(lstm_out[:, -1, :])
20
 
 
 
 
 
 
 
21
  # Update memory (4x4 grid), simulating synaptic memory flow
22
- self.memory = self.memory + output.view(self.memory_size, self.memory_size)
23
- return output, self.memory
24
 
25
  # Load Pretrained Model
26
  def load_model(model_path="memristor_lstm.pth"):
@@ -62,7 +73,7 @@ def generate_spikes_for_pressure(pressure):
62
  visualize_memory_grid(memory_grid)
63
 
64
  # Optionally, print the output
65
- print(f"Output: {output}")
66
  print(f"Memory Grid after update:\n{memory_grid}")
67
 
68
  return output, memory_grid
 
3
  import matplotlib.pyplot as plt
4
  import streamlit as st
5
 
6
+ # Simple LSTM with a VRAM-like 4x4 Memory Grid, including gates and filters
7
  class MemristorLSTM(nn.Module):
8
def __init__(self, memory_size=4):
    """Build the LSTM with a gated, tanh-filtered output feeding a square memory grid.

    Args:
        memory_size: side length of the square memory grid; the final linear
            layer emits memory_size**2 values so one output fills one grid
            (default 4, i.e. a 4x4 grid of 16 values).
    """
    super(MemristorLSTM, self).__init__()
    self.memory_size = memory_size
    # Two stacked LSTM layers over scalar inputs; batch_first means the
    # expected input shape is (batch, seq_len, 1).
    self.lstm = nn.LSTM(input_size=1, hidden_size=50, num_layers=2, batch_first=True)
    # Project the 50-dim final hidden state to memory_size * memory_size values
    # (16 by default) so the output can be reshaped onto the memory grid.
    self.fc = nn.Linear(50, self.memory_size * self.memory_size)
    # Sigmoid gate to modulate (squash to (0, 1)) the projected output.
    self.gate = nn.Sigmoid()
    # Tanh filter applying an additional non-linearity to the gated output.
    self.filter = nn.Tanh()
    # memory_size x memory_size memory grid, initialized to zeros; updated
    # additively in forward() to simulate synaptic memory flow.
    self.memory = torch.zeros(memory_size, memory_size)
20
 
 
23
  lstm_out, _ = self.lstm(x)
24
  output = self.fc(lstm_out[:, -1, :])
25
 
26
+ # Apply sigmoid gate to control the flow of information
27
+ gated_output = self.gate(output)
28
+
29
+ # Apply tanh filter for non-linearity
30
+ filtered_output = self.filter(gated_output)
31
+
32
  # Update memory (4x4 grid), simulating synaptic memory flow
33
+ self.memory = self.memory + filtered_output.view(self.memory_size, self.memory_size)
34
+ return filtered_output, self.memory
35
 
36
  # Load Pretrained Model
37
  def load_model(model_path="memristor_lstm.pth"):
 
73
  visualize_memory_grid(memory_grid)
74
 
75
  # Optionally, print the output
76
+ print(f"Filtered Output: {output}")
77
  print(f"Memory Grid after update:\n{memory_grid}")
78
 
79
  return output, memory_grid