Spaces:
Runtime error
Runtime error
from typing import Tuple

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import flair
class LSTM(torch.nn.Module):
    """
    Simple LSTM implementation that returns the features used for (1) a CRF and (2) a span classifier.

    The padded batch is packed before it is fed to the LSTM so that padding
    positions are skipped, and the LSTM output is padded back afterwards.
    """

    def __init__(
        self,
        rnn_layers: int,
        hidden_size: int,
        bidirectional: bool,
        rnn_input_dim: int,
        dropout: float = 0.5,
    ):
        """
        :param rnn_layers: number of stacked LSTM layers
        :param hidden_size: hidden size of each LSTM direction
        :param bidirectional: whether to use a bidirectional LSTM; when True the
            output feature size is ``2 * hidden_size``
        :param rnn_input_dim: feature size of each input token vector, i.e. the
            ``input_size`` of the underlying ``torch.nn.LSTM``
        :param dropout: dropout applied between stacked LSTM layers; forced to 0.0
            when ``rnn_layers == 1`` (defaults to 0.5, the previously hard-coded value)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.rnn_input_dim = rnn_input_dim
        self.num_layers = rnn_layers
        # torch.nn.LSTM only applies dropout between layers (and warns when it is
        # non-zero with a single layer), so disable it for one-layer models.
        self.dropout = 0.0 if rnn_layers == 1 else dropout
        self.bidirectional = bidirectional
        self.batch_first = True

        self.lstm = torch.nn.LSTM(
            self.rnn_input_dim,
            self.hidden_size,
            num_layers=self.num_layers,
            dropout=self.dropout,
            bidirectional=self.bidirectional,
            batch_first=self.batch_first,
        )
        self.to(flair.device)

    def forward(
        self, sentence_tensor: torch.Tensor, sorted_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the LSTM over a padded batch, packing it so padding positions are skipped.

        :param sentence_tensor: padded input of shape (batch size, seq len, rnn_input_dim)
        :param sorted_lengths: true length of each sequence in the batch, as a 1D
            tensor (on any device, any integer dtype) or a list of ints; it does not
            need to be sorted because packing uses ``enforce_sorted=False``
        :return: tuple ``(output, output_lengths)``. ``output`` has shape
            (batch size, max(sorted_lengths), hidden_size * num_directions) and is
            zero at padding positions. ``output_lengths`` is a tensor holding the
            length of each sequence, in the original batch order.
        """
        # pack_padded_sequence requires the lengths as a 1D CPU int64 tensor; lengths
        # living on a CUDA device (e.g. flair.device) would otherwise raise a RuntimeError.
        lengths = torch.as_tensor(sorted_lengths, dtype=torch.int64).cpu()
        packed = pack_padded_sequence(
            sentence_tensor, lengths, batch_first=self.batch_first, enforce_sorted=False
        )
        packed_output, _ = self.lstm(packed)
        output, output_lengths = pad_packed_sequence(packed_output, batch_first=self.batch_first)
        return output, output_lengths