Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
| import torch | |
| from models.attention import Attention | |
| from utils.config import config | |
| class Decoder(nn.Module): | |
| def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout): | |
| super().__init__() | |
| self.output_dim = output_dim | |
| self.attention = Attention(hidden_dim) | |
| self.embedding = nn.Embedding(output_dim, embedding_dim) | |
| self.rnn = nn.GRU( | |
| embedding_dim + hidden_dim, | |
| hidden_dim, | |
| num_layers=n_layers, | |
| dropout=dropout if n_layers > 1 else 0 | |
| ) | |
| self.fc_out = nn.Linear(hidden_dim * 2, output_dim) | |
| self.dropout = nn.Dropout(dropout) | |
| def forward(self, input, hidden, encoder_outputs): | |
| # input: [batch_size] | |
| # hidden: [n_layers, batch_size, hidden_dim] | |
| # encoder_outputs: [src_len, batch_size, hidden_dim] | |
| input = input.unsqueeze(0) | |
| # input: [1, batch_size] | |
| embedded = self.dropout(self.embedding(input)) | |
| # embedded: [1, batch_size, embedding_dim] | |
| a = self.attention(hidden[-1], encoder_outputs) | |
| # a: [src_len, batch_size] | |
| a = a.permute(1, 0).unsqueeze(1) | |
| # a: [batch_size, 1, src_len] | |
| encoder_outputs = encoder_outputs.permute(1, 0, 2) | |
| # encoder_outputs: [batch_size, src_len, hidden_dim] | |
| weighted = torch.bmm(a, encoder_outputs) | |
| weighted = weighted.permute(1, 0, 2) | |
| # weighted: [1, batch_size, hidden_dim] | |
| rnn_input = torch.cat((embedded, weighted), dim=2) | |
| # rnn_input: [1, batch_size, embedding_dim + hidden_dim] | |
| output, hidden = self.rnn(rnn_input, hidden) | |
| # output: [1, batch_size, hidden_dim] | |
| # hidden: [n_layers, batch_size, hidden_dim] | |
| embedded = embedded.squeeze(0) | |
| output = output.squeeze(0) | |
| weighted = weighted.squeeze(0) | |
| prediction = self.fc_out(torch.cat((output, weighted), dim=1)) | |
| # prediction: [batch_size, output_dim] | |
| return prediction, hidden |