Spaces:
Sleeping
Sleeping
| from torch import nn | |
| import torch | |
| import torch.nn.functional as F | |
HIDDEN_SIZE = 128    # LSTM hidden-state dimensionality (also attention dim)
EMBEDDING_DIM = 128  # token embedding dimensionality
VOCAB_SIZE = 1980    # number of distinct tokens the embedding table covers
class Bandanau(nn.Module):
    """Bahdanau (additive) attention over LSTM encoder outputs.

    NOTE(review): the class name is a typo for "Bahdanau"; it is kept
    unchanged for backward compatibility with existing callers.
    """

    def __init__(self, HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = HIDDEN_SIZE
        # Projects the encoder outputs (the "keys").
        self.linearwk = nn.Linear(self.hidden_size, self.hidden_size)
        # Projects the final hidden state (the "query").
        self.linearwa = nn.Linear(self.hidden_size, self.hidden_size)
        # Collapses the combined projection to one scalar score per timestep.
        self.linearwv = nn.Linear(self.hidden_size, 1)

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # (batch, seq_len, hidden)
        final_hidden: torch.Tensor,  # (batch, hidden)
    ):
        """Return ``(context, attention_weights)``.

        ``context`` is (batch, hidden): the attention-weighted sum of
        ``lstm_outputs``. ``attention_weights`` is (batch, seq_len):
        softmax-normalized scores over the sequence dimension.
        """
        # (B, H) -> (B, 1, H) so the query broadcasts across timesteps.
        query = final_hidden.unsqueeze(1)
        keys = self.linearwk(lstm_outputs)                    # (B, S, H)
        # torch.tanh instead of the deprecated F.tanh.
        scores = self.linearwv(torch.tanh(keys + self.linearwa(query)))  # (B, S, 1)
        # Normalize scores over the sequence dimension.
        attention_weights = F.softmax(scores, dim=1)          # (B, S, 1)
        attention_weights = attention_weights.transpose(1, 2)  # (B, 1, S)
        # BUG FIX: the context must be a weighted sum of the raw encoder
        # outputs, not of their key projection (the original used wk_out).
        context = torch.bmm(attention_weights, lstm_outputs)  # (B, 1, H)
        return context.squeeze(1), attention_weights.squeeze(1)
| # %% | |
class LSTMConcatAttention(nn.Module):
    """Text classifier: embedding -> LSTM -> Bahdanau attention -> MLP head.

    ``forward`` returns both the 3-class logits and the attention weights
    so callers can inspect which timesteps drove a prediction.
    """

    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = Bandanau(HIDDEN_SIZE)
        # Classification head: three Linear->Dropout->Tanh stages, then a
        # final projection to the 3 output logits.
        head = []
        for in_dim, out_dim in ((HIDDEN_SIZE, 512), (512, 256), (256, 128)):
            head.append(nn.Linear(in_dim, out_dim))
            head.append(nn.Dropout(0.3))
            head.append(nn.Tanh())
        head.append(nn.Linear(128, 3))
        self.clf = nn.Sequential(*head)

    def forward(self, x):
        """x: (batch, seq_len) token ids -> (logits, attention_weights)."""
        emb = self.embedding(x)
        seq_out, (last_hidden, _) = self.lstm(emb)
        # Single-layer LSTM: squeeze the num_layers dim off the final state.
        context, weights = self.attn(seq_out, last_hidden.squeeze(0))
        logits = self.clf(context)
        return logits, weights