from typing import Optional, Union

import torch
from torch import nn
from transformers import (
    BertModel,
    BertPreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertOnlyMLMHead

from .configuration_bert import BertMultiTaskConfig


class BertForMultiTaskClassification(BertPreTrainedModel):
    config_class = BertMultiTaskConfig
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.tasks = config.tasks
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # One linear classification head per task; the (optional) MLM task instead
        # reuses the standard BERT MLM head so its decoder can be tied to the input embeddings.
        task_layers = {}
        for task_name, num_labels in self.tasks.items():
            if task_name.upper() == "MLM":
                self.cls = BertOnlyMLMHead(config)
            else:
                task_layers[task_name.upper()] = nn.Linear(config.hidden_size, num_labels)
        self.task_classifiers = nn.ModuleDict(task_layers)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # Tells PreTrainedModel that self.cls.predictions.decoder is the output layer to be tied
        if hasattr(self, "cls"):
            return self.cls.predictions.decoder
        return None

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task: Optional[str] = None,  # For now the model uses a single task per batch
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor`, *optional*):
            Labels for computing the loss. For classification tasks the shape is `(batch_size,)` with indices in
            `[0, ..., num_labels - 1]` for the selected `task`; for the MLM task the shape is
            `(batch_size, sequence_length)` with ignored positions set to `-100`. The loss is always cross-entropy.
        task (`str`):
            Name of the task whose head should be used (case-insensitive); must be one of the tasks the model
            was configured with. A single task is used per batch.
""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict if task is None: raise ValueError(f"Task must be specified and one of {self.task_classifiers.keys()}") if task.upper() == "MLM": if not hasattr(self, "cls"): raise ValueError("Model was not initialized with an MLM head.") outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) loss = None logits = None num_labels = self.config.vocab_size if task.upper() == "MLM" else self.tasks[task] if task.upper() == "MLM": sequence_output = outputs[0] logits = self.cls(sequence_output) elif task.upper() in self.task_classifiers: pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.task_classifiers[task.upper()](pooled_output) else: raise ValueError(f"Invalid task: {task}") if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) BertMultiTaskConfig.register_for_auto_class() BertForMultiTaskClassification.register_for_auto_class("AutoModelForSequenceClassification")