# nlp_project/models/kdnv_models.py
# Author: Kdnv — commit "movie kdnv update" (b47c0e4)
# NOTE(review): the lines above were Hugging Face page residue pasted as bare
# text; converted to comments so the module parses.
from torch import nn
import torch
import torch.nn.functional as F
HIDDEN_SIZE = 128
EMBEDDING_DIM = 128
VOCAB_SIZE = 1980
class Bandanau(nn.Module):
def __init__(self, HIDDEN_SIZE) -> None:
super().__init__()
self.hidden_size = HIDDEN_SIZE
self.linearwk = nn.Linear(self.hidden_size, self.hidden_size)
self.linearwa = nn.Linear(self.hidden_size, self.hidden_size)
self.linearwv = nn.Linear(self.hidden_size, 1)
def forward(
self,
lstm_outputs: torch.Tensor, # BATCH_SIZE x SEQ_LEN x HIDDEN_SIZE
final_hidden: torch.Tensor # BATCH_SIZE x HIDDEN_SIZE
):
final_hidden = final_hidden.unsqueeze(1)
wk_out = self.linearwk(lstm_outputs)
wa_out = self.linearwa(final_hidden)
plus = F.tanh(wk_out + wa_out)
wv_out = self.linearwv(plus)
attention_weights = F.softmax(wv_out, dim=1)
attention_weights = attention_weights.transpose(1, 2)
context = torch.bmm(attention_weights, wk_out)
context = context.squeeze(1)
attention_weights = attention_weights.squeeze(1)
return context, attention_weights
# %%
class LSTMConcatAttention(nn.Module):
def __init__(self) -> None:
super().__init__()
self.embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
# self.embedding = embedding_layer
self.lstm = nn.LSTM(EMBEDDING_DIM, HIDDEN_SIZE, batch_first=True)
self.attn = Bandanau(HIDDEN_SIZE)
self.clf = nn.Sequential(
nn.Linear(HIDDEN_SIZE, 512),
nn.Dropout(0.3),
nn.Tanh(),
nn.Linear(512, 256),
nn.Dropout(0.3),
nn.Tanh(),
nn.Linear(256, 128),
nn.Dropout(0.3),
nn.Tanh(),
nn.Linear(128, 3)
)
def forward(self, x):
embeddings = self.embedding(x)
outputs, (h_n, _) = self.lstm(embeddings)
att_hidden, att_weights = self.attn(outputs, h_n.squeeze(0))
out = self.clf(att_hidden)
return out, att_weights