| """ Implementation of ONMT RNN for Input Feeding Decoding """ |
| import torch |
| import torch.nn as nn |
|
|
|
|
class StackedLSTM(nn.Module):
    """
    Stack of ``nn.LSTMCell`` layers for input-feeding decoders.

    Unlike ``nn.LSTM``, this advances one time step per call, which lets
    the decoder concatenate the previous attention output onto the input
    before each step.
    """

    def __init__(self, num_layers, input_size, hidden_size, dropout):
        super(StackedLSTM, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        # First cell consumes the raw input; deeper cells consume the
        # hidden state produced by the cell below.
        self.layers = nn.ModuleList(
            nn.LSTMCell(input_size if depth == 0 else hidden_size,
                        hidden_size)
            for depth in range(num_layers)
        )

    def forward(self, input_feed, hidden):
        """Run one decoding step through every layer.

        Args:
            input_feed: input tensor for the bottom layer at this step.
            hidden: tuple ``(h_0, c_0)`` of states stacked over layers.

        Returns:
            ``(output, (h_1, c_1))`` where ``output`` is the top layer's
            new hidden state and the tuple is the updated stacked state.
        """
        h_prev, c_prev = hidden
        new_h, new_c = [], []
        top = self.num_layers - 1
        for depth, cell in enumerate(self.layers):
            h_i, c_i = cell(input_feed, (h_prev[depth], c_prev[depth]))
            new_h.append(h_i)
            new_c.append(c_i)
            # Inter-layer dropout; the topmost output is passed through
            # untouched.
            input_feed = self.dropout(h_i) if depth != top else h_i
        return input_feed, (torch.stack(new_h), torch.stack(new_c))
|
|
|
|
class StackedGRU(nn.Module):
    """
    Stack of ``nn.GRUCell`` layers for input-feeding decoders.

    Advances one time step per call so the decoder can feed the previous
    attention output back in at every step.
    """

    def __init__(self, num_layers, input_size, hidden_size, dropout):
        super(StackedGRU, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        # First cell consumes the raw input; deeper cells consume the
        # hidden state produced by the cell below.
        self.layers = nn.ModuleList(
            nn.GRUCell(input_size if depth == 0 else hidden_size,
                       hidden_size)
            for depth in range(num_layers)
        )

    def forward(self, input_feed, hidden):
        """Run one decoding step through every layer.

        Args:
            input_feed: input tensor for the bottom layer at this step.
            hidden: one-element tuple ``(h_0,)`` of states stacked over
                layers (tuple form mirrors the LSTM interface).

        Returns:
            ``(output, (h_1,))`` where ``output`` is the top layer's new
            hidden state and ``h_1`` is the updated stacked state.
        """
        new_h = []
        top = self.num_layers - 1
        for depth, cell in enumerate(self.layers):
            h_i = cell(input_feed, hidden[0][depth])
            new_h.append(h_i)
            # Inter-layer dropout; the topmost output is passed through
            # untouched.
            input_feed = self.dropout(h_i) if depth != top else h_i
        return input_feed, (torch.stack(new_h),)
|
|