Spaces:
Runtime error
Runtime error
| from copy import deepcopy | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
class RNNEnoder(nn.Module):
    """Encode padded word-id sequences with an embedding, an MLP and an RNN.

    Word id ``0`` is treated as padding: the length of a sequence is the
    number of non-zero ids in its row.

    All hyper-parameters are read from ``cfg.MODEL.LANGUAGE_BACKBONE``
    (``RNN_TYPE``, ``VARIABLE_LENGTH``, ``WORD_EMBEDDING_SIZE``,
    ``WORD_VEC_SIZE``, ``HIDDEN_SIZE``, ``BIDIRECTIONAL``,
    ``INPUT_DROPOUT_P``, ``DROPOUT_P``, ``N_LAYERS``, ``CORPUS_PATH``,
    ``VOCAB_SIZE``). ``RNN_TYPE`` is matched case-insensitively against the
    recurrent layers in ``torch.nn`` (e.g. ``"lstm"``, ``"GRU"``).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        lang_cfg = cfg.MODEL.LANGUAGE_BACKBONE
        self.rnn_type = lang_cfg.RNN_TYPE
        self.variable_length = lang_cfg.VARIABLE_LENGTH
        self.word_embedding_size = lang_cfg.WORD_EMBEDDING_SIZE
        self.word_vec_size = lang_cfg.WORD_VEC_SIZE
        self.hidden_size = lang_cfg.HIDDEN_SIZE
        self.bidirectional = lang_cfg.BIDIRECTIONAL
        self.input_dropout_p = lang_cfg.INPUT_DROPOUT_P
        self.dropout_p = lang_cfg.DROPOUT_P
        self.n_layers = lang_cfg.N_LAYERS
        self.corpus_path = lang_cfg.CORPUS_PATH
        self.vocab_size = lang_cfg.VOCAB_SIZE

        # language encoder
        self.embedding = nn.Embedding(self.vocab_size, self.word_embedding_size)
        self.input_dropout = nn.Dropout(self.input_dropout_p)
        self.mlp = nn.Sequential(
            nn.Linear(self.word_embedding_size, self.word_vec_size),
            nn.ReLU(),
        )
        rnn_cls = getattr(nn, self.rnn_type.upper())
        self.rnn = rnn_cls(
            self.word_vec_size,
            self.hidden_size,
            self.n_layers,
            batch_first=True,
            bidirectional=self.bidirectional,
            dropout=self.dropout_p,
        )
        self.num_dirs = 2 if self.bidirectional else 1

    def forward(self, input, mask=None):
        """Encode a batch of word ids.

        Args:
            input: long tensor (batch, seq_len) of word ids, 0 = padding.
            mask: accepted for interface compatibility; not used.

        Returns:
            dict with keys ``'hidden'``, ``'output'``, ``'embedded'`` and
            ``'final_output'`` (see :meth:`encode` for shapes).
        """
        word_id = input
        # Drop trailing columns beyond the longest sequence in the batch.
        # NOTE: the length is a count of non-zero ids, so this presumes
        # right-padded input.
        max_len = (word_id != 0).sum(1).max().item()
        word_id = word_id[:, :max_len]

        # The original called a non-existent ``self.RNNEncode``.
        output, hidden, embedded, final_output = self.encode(word_id)
        return {
            'hidden': hidden,
            'output': output,
            'embedded': embedded,
            'final_output': final_output,
        }

    def encode(self, input_labels):
        """Run embedding -> dropout -> MLP -> RNN on padded word ids.

        Inputs:
        - input_labels: long (batch, seq_len), 0 = padding

        Outputs:
        - output      : float (batch, max_len, hidden_size * num_dirs)
        - hidden      : float (batch, num_layers * num_dirs * hidden_size)
        - embedded    : float (batch, max_len, word_vec_size)
        - final_output: float (batch, hidden_size * num_dirs), ``output`` at
          the last non-padding step of every sequence

        When ``variable_length`` is set, sequences are packed before the RNN,
        so padded positions of ``output`` and ``embedded`` are zero and every
        sequence must contain at least one non-zero id (packing rejects
        zero-length sequences).
        """
        if self.variable_length:
            (input_lengths_list, sorted_lengths_list,
             sort_idxs, recover_idxs) = self.sort_inputs(input_labels)
            input_labels = input_labels[sort_idxs]
        else:
            # Lengths are still needed below to pick the last valid step.
            input_lengths_list = (input_labels != 0).sum(1).tolist()

        embedded = self.embedding(input_labels)  # (n, seq_len, word_embedding_size)
        embedded = self.input_dropout(embedded)  # (n, seq_len, word_embedding_size)
        embedded = self.mlp(embedded)  # (n, seq_len, word_vec_size)
        if self.variable_length:
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, sorted_lengths_list, batch_first=True)

        # forward rnn
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(embedded)

        # LSTM returns (h_n, c_n); keep the hidden state only. Compared
        # case-insensitively because the constructor upper-cases rnn_type.
        if self.rnn_type.lower() == 'lstm':
            hidden = hidden[0]

        if self.variable_length:
            # Unpack and undo the length sort so rows match the caller's order.
            embedded, _ = nn.utils.rnn.pad_packed_sequence(
                embedded, batch_first=True)  # (batch, max_len, word_vec_size)
            embedded = embedded[recover_idxs]
            output, _ = nn.utils.rnn.pad_packed_sequence(
                output, batch_first=True)  # (batch, max_len, hidden_size * num_dirs)
            output = output[recover_idxs]
            hidden = hidden[:, recover_idxs, :]  # (num_layers * num_dirs, batch, hidden_size)

        hidden = hidden.transpose(0, 1).contiguous()  # (batch, num_layers * num_dirs, hidden_size)
        hidden = hidden.view(hidden.size(0), -1)  # (batch, num_layers * num_dirs * hidden_size)

        # final output: gather output[b, length_b - 1] for every row at once
        lengths = torch.tensor(input_lengths_list, dtype=torch.long,
                               device=output.device)
        rows = torch.arange(output.size(0), device=output.device)
        final_output = output[rows, lengths - 1]  # (batch, num_dirs * hidden_size)

        return output, hidden, embedded, final_output

    def sort_inputs(self, input_labels):
        """Compute the permutation that orders sequences by descending length.

        Args:
            input_labels: long tensor (batch, seq_len), 0 = padding. The
                longest sequence must fill the whole ``seq_len`` dimension.

        Returns:
            tuple ``(input_lengths_list, sorted_input_lengths_list,
            sort_idxs, recover_idxs)``: the per-row lengths (list of int),
            the same lengths in descending order (list of int), a long tensor
            that sorts rows by descending length, and a long tensor that
            undoes that sort. Both tensors live on ``input_labels.device``.

        Raises:
            AssertionError: if the longest sequence is shorter than
                ``seq_len``.
        """
        device = input_labels.device
        input_lengths_list = (input_labels != 0).sum(1).tolist()
        assert max(input_lengths_list) == input_labels.size(1), \
            'longest sequence must span the full padded length'

        sort_idxs = np.argsort(input_lengths_list)[::-1].tolist()
        sorted_input_lengths_list = [input_lengths_list[i] for i in sort_idxs]

        # Inverse permutation: original row i sits at position recover[i]
        # after sorting.
        recover_idxs = [0] * len(sort_idxs)
        for position, original_row in enumerate(sort_idxs):
            recover_idxs[original_row] = position

        sort_idxs = torch.tensor(sort_idxs, dtype=torch.long, device=device)
        recover_idxs = torch.tensor(recover_idxs, dtype=torch.long, device=device)
        return input_lengths_list, sorted_input_lengths_list, sort_idxs, recover_idxs