File size: 2,092 Bytes
b47c0e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from torch import nn
import torch
import torch.nn.functional as F

# Default model hyper-parameters.
HIDDEN_SIZE: int = 128  # width of the LSTM hidden state / attention projections
EMBEDDING_DIM: int = 128  # length of each token embedding vector
VOCAB_SIZE: int = 1980  # rows in the embedding table (valid token ids: 0..VOCAB_SIZE-1)


class Bandanau(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of LSTM outputs.

    The class name keeps its original spelling for backward compatibility.

    Scores are ``v^T tanh(W_k h_t + W_a q)``, where ``h_t`` are the per-step
    LSTM outputs and ``q`` is the query (e.g. the final hidden state). The
    returned context is the attention-weighted sum of the *projected*
    outputs ``W_k h_t``.

    Args:
        HIDDEN_SIZE: Feature size of both ``lstm_outputs`` and
            ``final_hidden``; also the size of the returned context vector.
    """

    def __init__(self, HIDDEN_SIZE) -> None:
        super().__init__()
        self.hidden_size = HIDDEN_SIZE
        # Key projection, applied to every time step of the LSTM outputs.
        self.linearwk = nn.Linear(self.hidden_size, self.hidden_size)
        # Query projection, applied to the final hidden state.
        self.linearwa = nn.Linear(self.hidden_size, self.hidden_size)
        # Scoring vector: collapses each tanh(key + query) to one scalar score.
        self.linearwv = nn.Linear(self.hidden_size, 1)

    def forward(
            self,
            lstm_outputs: torch.Tensor,  # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
            final_hidden: torch.Tensor,  # BATCH_SIZE x HIDDEN_SIZE
            mask=None,  # optional BATCH_SIZE x SEQ_LEN, True = attend
    ):
        """Attend over ``lstm_outputs`` using ``final_hidden`` as the query.

        Args:
            lstm_outputs: Per-step encoder outputs, shape (batch, seq_len, hidden).
            final_hidden: One query vector per batch item, shape (batch, hidden).
            mask: Optional tensor of shape (batch, seq_len), interpreted as
                boolean. Positions where it is False (e.g. padding) receive
                exactly zero attention weight. Each row should keep at least
                one position; a fully masked row falls back to uniform weights
                instead of producing NaN. ``None`` (default) attends to every
                position, which is the original behavior.

        Returns:
            ``(context, attention_weights)`` with shapes (batch, hidden) and
            (batch, seq_len). The weights in each row sum to 1.
        """
        keys = self.linearwk(lstm_outputs)  # B x S x H
        # Unsqueeze so the single query broadcasts over every time step.
        query = self.linearwa(final_hidden.unsqueeze(1))  # B x 1 x H

        scores = self.linearwv(torch.tanh(keys + query))  # B x S x 1

        if mask is not None:
            # Use the most negative finite value rather than -inf: excluded
            # positions still underflow to weight 0, but an all-masked row
            # gives uniform weights instead of NaN.
            scores = scores.masked_fill(
                ~mask.to(torch.bool).unsqueeze(-1),
                torch.finfo(scores.dtype).min,
            )

        # Normalise over the sequence axis, then lay weights out as B x 1 x S for bmm.
        attention_weights = F.softmax(scores, dim=1).transpose(1, 2)

        # NOTE: the context mixes the projected keys, not the raw lstm_outputs;
        # kept as-is so previously trained weights keep their meaning.
        context = torch.bmm(attention_weights, keys)  # B x 1 x H

        return context.squeeze(1), attention_weights.squeeze(1)


# %%
class LSTMConcatAttention(nn.Module):
    """Sequence classifier: embedding -> LSTM -> additive attention -> MLP.

    All constructor arguments are keyword-only and default to the values the
    model previously hard-coded, so ``LSTMConcatAttention()`` builds the same
    architecture (and accepts the same state_dict) as before.

    Args:
        vocab_size: Number of rows in the embedding table; token ids must lie
            in ``[0, vocab_size)``.
        embedding_dim: Size of each token embedding (LSTM input size).
        hidden_size: LSTM hidden size; also the attention/context size.
        num_classes: Number of output logits produced by the classifier head.
        dropout: Dropout probability used between the classifier's layers.
    """

    def __init__(
            self,
            *,
            vocab_size: int = VOCAB_SIZE,
            embedding_dim: int = EMBEDDING_DIM,
            hidden_size: int = HIDDEN_SIZE,
            num_classes: int = 3,
            dropout: float = 0.3,
    ) -> None:
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
        self.attn = Bandanau(hidden_size)
        # Layer order/count unchanged so existing checkpoint keys
        # (clf.0, clf.3, clf.6, clf.9) still load.
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.Dropout(dropout),
            nn.Tanh(),
            nn.Linear(512, 256),
            nn.Dropout(dropout),
            nn.Tanh(),
            nn.Linear(256, 128),
            nn.Dropout(dropout),
            nn.Tanh(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: Integer token ids; with ``batch_first=True`` this is expected
                to be (batch, seq_len).

        Returns:
            ``(logits, att_weights)``: raw (un-softmaxed) class scores of shape
            (batch, num_classes) and the attention weights of shape
            (batch, seq_len).
        """
        embeddings = self.embedding(x)
        outputs, (h_n, _) = self.lstm(embeddings)
        # h_n is (num_layers, batch, hidden); the last layer's state is the
        # attention query. For this single-layer LSTM this equals h_n.squeeze(0).
        att_hidden, att_weights = self.attn(outputs, h_n[-1])
        out = self.clf(att_hidden)
        return out, att_weights