| |
| |
class AI(nn.Module):
    """Classify a sequence of token indices with averaged, gated projections.

    The forward pass embeds the indices, projects every position through
    ``num_heads`` weight matrices and averages the results (once through
    ``tanh`` to form a candidate state, once through ``sigmoid`` to form a
    gate), multiplies the two, averages over dimension 1, mixes the pooled
    state through ``num_heads`` square matrices (again averaged, then
    ``tanh``) and finally maps it to ``nclass`` outputs with a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        nclass: Size of the output produced by the final linear layer.
        embed_dim: Width of each embedding vector.
        num_heads: Number of weight matrices in each of the three parameter
            lists; must be at least 1.
        system_dim: Width of the internal state.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1 (averaging over zero
            heads would divide by zero).
    """

    def __init__(
        self,
        vocab_size: int = 230,
        nclass: int = 2,
        embed_dim: int = 512,
        num_heads: int = 8,
        system_dim: int = 256,
    ) -> None:
        super().__init__()
        if num_heads < 1:
            raise ValueError(f"num_heads must be at least 1, got {num_heads}")

        # NOTE: sub-modules and parameters are created in the same order as
        # before so that seeded initialisation and state_dict keys are stable.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.system_dim = system_dim
        self.embed_dim = embed_dim

        # Per-head projections that produce the candidate state.
        self.body = nn.ParameterList(
            [nn.Parameter(torch.randn(embed_dim, system_dim)) for _ in range(num_heads)]
        )
        # Per-head projections that produce the multiplicative gate.
        self.forgets_x = nn.ParameterList(
            [nn.Parameter(torch.randn(embed_dim, system_dim)) for _ in range(num_heads)]
        )
        # Per-head square matrices applied to the pooled state.
        self.matrix = nn.ParameterList(
            [nn.Parameter(torch.randn(system_dim, system_dim)) for _ in range(num_heads)]
        )

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        # NOTE(review): not used by forward(); kept so that existing
        # checkpoints (which contain its weight and bias) still load.
        self.normalization = nn.LayerNorm(system_dim)
        self.linear = nn.Linear(system_dim, nclass)

    @staticmethod
    def _mean_projection(inputs, weights):
        """Return the average of ``inputs @ w`` over every ``w`` in ``weights``."""
        # Summing the products directly (instead of adding into a freshly
        # allocated zero tensor) keeps the result on the device and dtype of
        # the inputs and the parameters.
        return sum(inputs @ weight for weight in weights) / len(weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class scores for a tensor of token indices.

        Args:
            x: Integer index tensor accepted by the embedding layer. A 2-D
                input is embedded to three dimensions; a 1-D input is
                embedded to two dimensions and then given a leading
                dimension of size 1.

        Returns:
            Tensor whose last dimension has size ``nclass``; for 1-D or 2-D
            index input the result is 2-D with the leading dimension of the
            (possibly expanded) embedded input.
        """
        x = self.embedding(x)
        # Add the leading dimension *before* anything is sized from x, so a
        # single un-batched sequence is handled like a batch of one.
        if x.dim() == 2:
            x = x.unsqueeze(0)

        candidate = self.tanh(self._mean_projection(x, self.body))
        gate = self.sigmoid(self._mean_projection(x, self.forgets_x))

        # Gate the candidate state, then average over dimension 1.
        memory = (candidate * gate).mean(dim=1)

        # The head average starts from zero (not random noise), so repeated
        # calls with the same input and weights give the same output.
        memory = self.tanh(self._mean_projection(memory, self.matrix))
        return self.linear(memory)