# model, info = ryann.ParaLSTM(system_dim=256, embed_dim=512, num_heads=8,
#                              nclass=2, vocab_size=230, typelstm='standart')
# OR
class AI(nn.Module):
    """Gated multi-head projection classifier over token sequences.

    Every token embedding is projected by ``num_heads`` weight matrices and
    the projections are averaged. One averaged projection goes through
    ``tanh`` (candidate values), a second one through ``sigmoid`` (gate);
    their product is mean-pooled over the sequence axis, mixed by a third
    averaged set of square matrices, squashed with ``tanh`` and fed to a
    linear classifier.

    Args:
        vocab_size: Number of rows in the embedding table.
        nclass: Number of output classes (size of the logits vector).
        embed_dim: Width of the token embeddings.
        num_heads: Number of parallel weight matrices averaged at each
            stage. Must be at least 1.
        system_dim: Width of the internal (hidden) representation.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, vocab_size=230, nclass=2, embed_dim=512, num_heads=8,
                 system_dim=256):
        super().__init__()
        if num_heads < 1:
            # Averaging over zero heads would divide by zero and yield NaNs.
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.system_dim = system_dim
        self.embed_dim = embed_dim

        # Per-head projections embed_dim -> system_dim for candidate values.
        self.body = nn.ParameterList(
            [nn.Parameter(torch.randn(embed_dim, system_dim))
             for _ in range(num_heads)]
        )
        # Per-head projections embed_dim -> system_dim for the gate.
        self.forgets_x = nn.ParameterList(
            [nn.Parameter(torch.randn(embed_dim, system_dim))
             for _ in range(num_heads)]
        )
        # Per-head square mixing matrices applied to the pooled state.
        self.matrix = nn.ParameterList(
            [nn.Parameter(torch.randn(system_dim, system_dim))
             for _ in range(num_heads)]
        )

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        # NOTE: not used by forward(); kept so existing checkpoints
        # (state_dict keys) keep loading.
        self.normalization = nn.LayerNorm(system_dim)
        self.linear = nn.Linear(system_dim, nclass)

    @staticmethod
    def _average_heads(inputs, weights):
        """Return the mean over ``weights`` of ``inputs @ weight``."""
        return sum(inputs @ weight for weight in weights) / len(weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: Integer token ids, either ``(batch, seq)`` or an unbatched
                ``(seq,)`` sequence, which is treated as a batch of one.

        Returns:
            Logits of shape ``(batch, nclass)``.
        """
        x = self.embedding(x)
        if x.dim() == 2:
            # Unbatched input: add the batch axis *before* anything is
            # shaped from x, so every intermediate is (batch, seq, dim).
            x = x.unsqueeze(0)

        # No explicit torch.zeros buffers: all intermediates inherit the
        # device and dtype of the parameters.
        candidate = self.tanh(self._average_heads(x, self.body))
        gate = self.sigmoid(self._average_heads(x, self.forgets_x))

        # Gate the candidates, then mean-pool over the sequence axis.
        memory = (candidate * gate).mean(dim=1)

        # Deterministic mixing: the accumulator starts from zero like the
        # other stages instead of from fresh random noise on every call.
        memory = self.tanh(self._average_heads(memory, self.matrix))
        return self.linear(memory)