import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import (
    BertPreTrainedModel,
    BERT_INPUTS_DOCSTRING,
    _TOKENIZER_FOR_DOC,
    _CHECKPOINT_FOR_DOC,
    BERT_START_DOCSTRING,
    _CONFIG_FOR_DOC,
    _SEQ_CLASS_EXPECTED_OUTPUT,
    _SEQ_CLASS_EXPECTED_LOSS,
    BertModel,
)
from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
    add_start_docstrings,
)


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output), e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        # tasks_map, e.g. {0: <task with num_labels=2>, 1: <task with num_labels=5>};
        # each entry only needs to expose a ``num_labels`` attribute.
        self.tasks = kwargs.get("tasks_map", {})
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # Task-specific output heads: one linear classifier per task.
        self.classifier1 = nn.Linear(config.hidden_size, self.tasks[0].num_labels)
        self.classifier2 = nn.Linear(config.hidden_size, self.tasks[1].num_labels)

        self.init_weights()

    @add_start_docstrings_to_model_forward(
        BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")
    )
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task_ids=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in
            :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed
            (Mean-Square loss), if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
""" return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) unique_task_ids_list = torch.unique(task_ids).tolist() loss_list = [] logits = None for unique_task_id in unique_task_ids_list: loss = None task_id_filter = task_ids == unique_task_id if unique_task_id == 0: logits = self.classifier1(pooled_output[task_id_filter]) elif unique_task_id == 1: logits = self.classifier2(pooled_output[task_id_filter]) if labels is not None: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.tasks[unique_task_id].num_labels), labels[task_id_filter].view(-1)) loss_list.append(loss) # logits are only used for eval. and in case of eval the batch is not multi task # For training only the loss is used if loss_list: loss = torch.stack(loss_list).mean() if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, )