import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional, Tuple, Union

from .configuration import MultiLabelClassifierConfig

class MultiLabelClassifierModel(PreTrainedModel):
    """Transformer backbone followed by a GRU head for multi-label classification."""

    config_class = MultiLabelClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Pretrained transformer backbone, loaded by name through torch.hub.
        self.nlp_model = torch.hub.load('huggingface/pytorch-transformers', 'model',
                                        config.transformer_name)
        # GRU over the per-token embeddings produced by the backbone.
        self.rnn = nn.GRU(config.embedding_dim,
                          config.hidden_dim,
                          num_layers=config.num_layers,
                          bidirectional=config.bidirectional,
                          batch_first=True,
                          # nn.GRU only applies dropout between stacked layers.
                          dropout=0 if config.num_layers < 2 else config.dropout)
        self.dropout = nn.Dropout(config.dropout)
        # One output unit per label; a bidirectional GRU doubles the feature size.
        self.out = nn.Linear(config.hidden_dim * 2 if config.bidirectional else config.hidden_dim,
                             config.num_classes)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Run the backbone; request a ModelOutput internally so the token
        # embeddings can be read by attribute regardless of `return_dict`.
        output = self.nlp_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # `hidden` has shape (num_layers * num_directions, batch, hidden_dim).
        _, hidden = self.rnn(output.last_hidden_state)

        if self.rnn.bidirectional:
            # Concatenate the last layer's forward and backward hidden states.
            hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])

        logits = self.out(hidden)

        if not return_dict:
            return tuple(v for v in (logits, output.hidden_states, output.attentions) if v is not None)

        return SequenceClassifierOutput(
            logits=logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )
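

# --- Hypothetical usage sketch (not part of the original module) -------------
# A minimal example of how this model might be exercised, assuming that
# MultiLabelClassifierConfig exposes the fields referenced above
# (transformer_name, embedding_dim, hidden_dim, num_layers, bidirectional,
# dropout, num_classes). The field values and checkpoint name below are
# illustrative assumptions. Because the import at the top of this file is
# relative, this block only runs when the package is executed as a module
# (e.g. `python -m <package>.modeling`, package name hypothetical).
if __name__ == "__main__":
    config = MultiLabelClassifierConfig(
        transformer_name="bert-base-uncased",  # assumed backbone checkpoint
        embedding_dim=768,                     # must match the backbone's hidden size
        hidden_dim=256,
        num_layers=2,
        bidirectional=True,
        dropout=0.25,
        num_classes=5,
    )
    model = MultiLabelClassifierModel(config)
    model.eval()

    # Tokenizer loaded through the same torch.hub entry point as the backbone.
    tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer',
                               config.transformer_name)
    batch = tokenizer(["an example sentence", "another example"],
                      padding=True, return_tensors="pt")

    with torch.no_grad():
        out = model(**batch)

    # Multi-label setup: independent sigmoid per class rather than a softmax.
    probs = torch.sigmoid(out.logits)
    predictions = (probs > 0.5).long()
    print(probs.shape, predictions)

    # For training, pair the logits with torch.nn.BCEWithLogitsLoss against
    # multi-hot label vectors of shape (batch, num_classes).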