from torch import nn
import torch
import torch.nn.functional as F

HIDDEN_SIZE = 128
EMBEDDING_DIM = 128
VOCAB_SIZE = 1980


class BahdanauAttention(nn.Module):
    """Additive (Bahdanau) attention over LSTM outputs."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_wk = nn.Linear(self.hidden_size, self.hidden_size)  # projects encoder outputs (keys)
        self.linear_wa = nn.Linear(self.hidden_size, self.hidden_size)  # projects the final hidden state (query)
        self.linear_wv = nn.Linear(self.hidden_size, 1)                 # scores each timestep

    def forward(
        self,
        lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
    ):
        # BATCH_SIZE x 1 x HIDDEN_SIZE, so the query broadcasts over SEQ_LEN
        final_hidden = final_hidden.unsqueeze(1)

        keys = self.linear_wk(lstm_outputs)                 # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
        query = self.linear_wa(final_hidden)                # BATCH_SIZE x 1 x HIDDEN_SIZE
        scores = self.linear_wv(torch.tanh(keys + query))   # BATCH_SIZE x SEQ_LEN x 1

        # Normalize scores over the sequence dimension
        attention_weights = F.softmax(scores, dim=1)
        attention_weights = attention_weights.transpose(1, 2)  # BATCH_SIZE x 1 x SEQ_LEN

        # Context is the attention-weighted sum of the raw LSTM outputs (the values)
        context = torch.bmm(attention_weights, lstm_outputs)   # BATCH_SIZE x 1 x HIDDEN_SIZE

        context = context.squeeze(1)                      # BATCH_SIZE x HIDDEN_SIZE
        attention_weights = attention_weights.squeeze(1)  # BATCH_SIZE x SEQ_LEN
        return context, attention_weights


# %%
class LSTMConcatAttention(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        # self.embedding = embedding_layer  # swap in a pretrained embedding layer here
        self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
        self.attn = BahdanauAttention(HIDDEN_SIZE)
        self.clf = nn.Sequential(
            nn.Linear(HIDDEN_SIZE, 512),
            nn.Dropout(0.3),
            nn.Tanh(),
            nn.Linear(512, 256),
            nn.Dropout(0.3),
            nn.Tanh(),
            nn.Linear(256, 128),
            nn.Dropout(0.3),
            nn.Tanh(),
            nn.Linear(128, 3),
        )

    def forward(self, x):
        embeddings = self.embedding(x)             # BATCH_SIZE x SEQ_LEN x EMBEDDING_DIM
        outputs, (h_n, _) = self.lstm(embeddings)  # h_n: 1 x BATCH_SIZE x HIDDEN_SIZE
        # Single-layer LSTM, so drop the layer dimension from h_n
        att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
        out = self.clf(att_hidden)                 # BATCH_SIZE x 3 class logits
        return out, att_weights
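
# %%
# A minimal smoke test, not part of the model code above: it runs a random
# batch of token ids through the classifier and checks the output shapes,
# plus that each attention distribution sums to 1 over the sequence.
# BATCH_SIZE and SEQ_LEN below are arbitrary values chosen for this check.
if __name__ == "__main__":
    BATCH_SIZE, SEQ_LEN = 4, 32  # hypothetical test dimensions
    model = LSTMConcatAttention()
    model.eval()  # disable dropout for a deterministic pass

    dummy_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))
    with torch.no_grad():
        logits, att_weights = model(dummy_ids)

    assert logits.shape == (BATCH_SIZE, 3)            # one logit per class
    assert att_weights.shape == (BATCH_SIZE, SEQ_LEN)
    # softmax over the sequence dimension means each row sums to ~1
    assert torch.allclose(att_weights.sum(dim=1), torch.ones(BATCH_SIZE), atol=1e-5)
    print("logits:", logits.shape, "attention:", att_weights.shape)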