from typing import Optional, Tuple, Union
import logging
import json
import os

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from transformers.modeling_outputs import (
    TokenClassifierOutput,
    SequenceClassifierOutput,
)
import floret


logger = logging.getLogger(__name__)

def get_info(label_map):
    """Return the number of labels per task, e.g. {"ner": 9} for a map with nine NER tags."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict

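# Hypothetical illustration of `get_info` (the label map below is an example,
# not part of this module):
#
#   label_map = {"ner": ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"]}
#   get_info(label_map)  # -> {"ner": 5}
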
class ModelForSequenceAndTokenClassification(PreTrainedModel):
    """
    A BERT-style encoder with a token-classification head and, when `do_classif`
    is enabled, an additional sequence-classification head trained jointly.
    """

    config_class = BertConfig

    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(
        self, config, num_sequence_labels=None, num_token_labels=None, do_classif=False
    ):
        super().__init__(config)

        # Fall back to the config's label map when explicit counts are not given.
        if num_sequence_labels is None:
            self.num_token_labels = len(config.id2label)
            self.num_sequence_labels = 2
        else:
            self.num_token_labels = num_token_labels
            self.num_sequence_labels = num_sequence_labels

        self.config = config
        self.do_classif = do_classif

        # Auxiliary floret model; `config.filename` must point to the
        # floret .bin file on disk.
        self.model = floret.load_model(self.config.filename)

        self.bert = AutoModel.from_config(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # Token-level head, applied to every hidden state.
        self.token_classifier = nn.Linear(config.hidden_size, self.num_token_labels)

        if do_classif:
            # Sequence-level head, applied to the pooled output.
            self.sequence_classifier = nn.Linear(
                config.hidden_size, self.num_sequence_labels
            )

        # Initialize weights and apply final processing.
        self.post_init()

    def get_do_classif(self):
        # Renamed from `do_classif`: the instance attribute set in __init__
        # shadowed the method of the same name, making it uncallable.
        return self.do_classif

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        token_labels: Optional[torch.Tensor] = None,
        sequence_labels: Optional[torch.Tensor] = None,
        offset_mapping: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[torch.Tensor, ...],
        TokenClassifierOutput,
        Tuple[SequenceClassifierOutput, TokenClassifierOutput],
    ]:
        r"""
        token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., num_token_labels - 1]`.
        sequence_labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should
            be in `[0, ..., num_sequence_labels - 1]`. If `num_sequence_labels == 1` a
            regression loss is computed (Mean-Square loss); if `num_sequence_labels > 1` a
            classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Token-level logits from the per-token hidden states.
        token_output = outputs[0]
        token_output = self.dropout(token_output)
        token_logits = self.token_classifier(token_output)

        if self.do_classif:
            # Sequence-level logits from the pooled representation.
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            sequence_logits = self.sequence_classifier(pooled_output)

        loss = None
        if token_labels is not None:
            loss_fct = CrossEntropyLoss()
            loss_tokens = loss_fct(
                token_logits.view(-1, self.num_token_labels), token_labels.view(-1)
            )

            if self.do_classif:
                # Infer the problem type when it is not set, mirroring the
                # upstream sequence-classification heads; without this, an
                # unset `problem_type` left `loss_sequence` undefined below.
                if self.config.problem_type is None:
                    if self.num_sequence_labels == 1:
                        self.config.problem_type = "regression"
                    elif self.num_sequence_labels > 1 and sequence_labels.dtype in (
                        torch.long,
                        torch.int,
                    ):
                        self.config.problem_type = "single_label_classification"
                    else:
                        self.config.problem_type = "multi_label_classification"

                if self.config.problem_type == "regression":
                    loss_fct = MSELoss()
                    if self.num_sequence_labels == 1:
                        loss_sequence = loss_fct(
                            sequence_logits.squeeze(), sequence_labels.squeeze()
                        )
                    else:
                        loss_sequence = loss_fct(sequence_logits, sequence_labels)
                elif self.config.problem_type == "single_label_classification":
                    loss_fct = CrossEntropyLoss()
                    loss_sequence = loss_fct(
                        sequence_logits.view(-1, self.num_sequence_labels),
                        sequence_labels.view(-1),
                    )
                elif self.config.problem_type == "multi_label_classification":
                    loss_fct = BCEWithLogitsLoss()
                    loss_sequence = loss_fct(sequence_logits, sequence_labels)

                # Joint objective: sum of token- and sequence-level losses.
                loss = loss_tokens + loss_sequence
            else:
                loss = loss_tokens

        if not return_dict:
            if self.do_classif:
                output = (
                    sequence_logits,
                    token_logits,
                ) + outputs[2:]
            else:
                output = (token_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        if self.do_classif:
            # Both outputs carry the same joint loss.
            return SequenceClassifierOutput(
                loss=loss,
                logits=sequence_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            ), TokenClassifierOutput(
                loss=loss,
                logits=token_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        else:
            return TokenClassifierOutput(
                loss=loss,
                logits=token_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
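
# ---------------------------------------------------------------------------
# Minimal usage sketch. The checkpoint name, the floret model path, and the
# label counts below are placeholder assumptions for illustration, not values
# defined by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = AutoConfig.from_pretrained("bert-base-uncased")  # assumed checkpoint
    config.filename = "model.floret.bin"  # hypothetical path to the floret model
    model = ModelForSequenceAndTokenClassification(
        config, num_sequence_labels=2, num_token_labels=9, do_classif=True
    )

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer("A short example sentence.", return_tensors="pt")

    # With do_classif=True and return_dict on, forward returns a pair of outputs.
    seq_out, tok_out = model(**batch)
    print(seq_out.logits.shape)  # (1, num_sequence_labels)
    print(tok_out.logits.shape)  # (1, sequence_length, num_token_labels)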