| | import torch |
| | import transformers |
| | from torch import nn |
| | from torch.nn import CrossEntropyLoss |
| | from typing import Optional, Tuple, Union |
| | from transformers.modeling_outputs import SequenceClassifierOutput |
| | from transformers.models.bert.modeling_bert import ( |
| | BertPreTrainedModel, |
| | BERT_INPUTS_DOCSTRING, |
| | _TOKENIZER_FOR_DOC, |
| | _CHECKPOINT_FOR_DOC, |
| | BERT_START_DOCSTRING, |
| | _CONFIG_FOR_DOC, |
| | _SEQ_CLASS_EXPECTED_OUTPUT, |
| | _SEQ_CLASS_EXPECTED_LOSS, |
| | BertModel, |
| | ) |
| |
|
| | from transformers.file_utils import ( |
| | add_code_sample_docstrings, |
| | add_start_docstrings_to_model_forward, |
| | add_start_docstrings |
| | ) |
| |
|
@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
    # Multi-task variant: one shared BERT encoder feeding two linear heads
    # (`classifier1` for task id 0, `classifier2` for task id 1). Each example
    # in a batch is routed to its head through the `task_ids` forward argument.

    def __init__(self, config, **kwargs):
        """Build the shared encoder, the dropout layer and the two task heads.

        Args:
            config: Model configuration. Read here for ``hidden_size``,
                ``classifier_dropout`` and ``hidden_dropout_prob``.
            **kwargs: Must contain ``tasks_map``, a container indexable by the
                task ids ``0`` and ``1`` whose items expose ``num_labels``.

        Raises:
            ValueError: If ``tasks_map`` has no entry for task id 0 or 1.
        """
        # Initialise the base class with the real config rather than an empty
        # placeholder, so everything the base constructor derives from the
        # config reflects the actual model settings.
        super().__init__(config)

        self.tasks = kwargs.get("tasks_map", {})
        self.config = config

        try:
            first_task, second_task = self.tasks[0], self.tasks[1]
        except (KeyError, IndexError) as err:
            raise ValueError(
                "`tasks_map` must provide an entry for task ids 0 and 1"
            ) from err

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # Attribute names are kept as-is so existing checkpoints still load.
        self.classifier1 = nn.Linear(config.hidden_size, first_task.num_labels)
        self.classifier2 = nn.Linear(config.hidden_size, second_task.num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task_ids=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the per-task classification loss (Cross-Entropy). For an example belonging to task
            ``t``, the index should be in :obj:`[0, ..., tasks_map[t].num_labels - 1]`.
        task_ids (:obj:`torch.Tensor`, required despite the ``None`` default):
            Task id (``0`` or ``1``) of every example in the batch; used as a row mask over the pooled output, so it
            is assumed to have shape :obj:`(batch_size,)`.

        Returns:
            The loss is the unweighted mean of the per-task Cross-Entropy losses (``None`` when ``labels`` is not
            given). ``logits`` holds the output of the head for the highest task id present in ``task_ids`` only, so
            it covers the whole batch only when the batch contains a single task.

        Raises:
            ValueError: If ``task_ids`` is ``None`` or contains an id other than 0 or 1.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if task_ids is None:
            raise ValueError(
                "`task_ids` is required to route each example to its classification head"
            )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 of the encoder output is the pooled representation.
        pooled_output = self.dropout(outputs[1])

        heads = {0: self.classifier1, 1: self.classifier2}
        loss_fct = CrossEntropyLoss()

        # Defined up front so an empty batch still reaches the return below.
        loss = None
        logits = None
        loss_list = []
        for unique_task_id in torch.unique(task_ids).tolist():
            head = heads.get(unique_task_id)
            if head is None:
                raise ValueError(
                    f"Unsupported task id {unique_task_id}; expected 0 or 1"
                )

            # Boolean row mask selecting the examples of the current task.
            task_id_filter = task_ids == unique_task_id
            logits = head(pooled_output[task_id_filter])

            if labels is not None:
                loss_list.append(
                    loss_fct(
                        logits.view(-1, self.tasks[unique_task_id].num_labels),
                        labels[task_id_filter].view(-1),
                    )
                )

        if loss_list:
            # Every task contributes equally, whatever its share of the batch.
            loss = torch.stack(loss_list).mean()

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|