from transformers import BertPreTrainedModel, BertModel, AutoConfig, AutoModelForTokenClassification
import torch
import torch.nn as nn
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import TransformersKwargs, can_return_tuple
from transformers.processing_utils import Unpack
from .configuration_multilabelbert import MultiLabelBertConfig
from typing import Optional


class BertForMultiLabelTokenClassification(BertPreTrainedModel):
    """BERT encoder with a multi-label classification head applied per token.

    Unlike the standard single-label token classifier (softmax + cross-entropy),
    every token gets ``num_labels`` independent sigmoid outputs trained with
    binary cross-entropy, so a single token may carry several labels at once.
    """

    config_class = MultiLabelBertConfig

    def __init__(self, config):
        """Build the encoder and classification head from ``config``.

        Args:
            config: A ``MultiLabelBertConfig`` providing ``num_labels``,
                ``hidden_size``, ``hidden_dropout_prob`` and (optionally)
                ``classifier_dropout``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooling layer: only per-token hidden states are needed here.
        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        special_tokens_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        """Run the encoder and the multi-label token classification head.

        Args:
            labels: Multi-hot targets — presumably of shape
                ``(batch, seq_len, num_labels)`` with entries in {0, 1};
                TODO confirm against the data pipeline. Cast to float
                internally because ``BCEWithLogitsLoss`` requires
                floating-point targets.
            special_tokens_mask: Optional ``(batch, seq_len)`` mask where 1
                marks positions to exclude from the loss (e.g. special tokens).
            Remaining arguments are forwarded unchanged to ``BertModel``.

        Returns:
            A ``TokenClassifierOutput`` with ``loss`` (when ``labels`` is
            given), per-token ``logits`` of shape
            ``(batch, seq_len, num_labels)``, and the encoder's hidden
            states / attentions. ``@can_return_tuple`` converts it to a
            plain tuple when the caller asks for one.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # Element-wise BCE (reduction="none") so special-token positions
            # can be filtered out before averaging.
            loss_fct = nn.BCEWithLogitsLoss(reduction="none")
            # FIX: BCEWithLogitsLoss raises on integer targets; multi-hot
            # label tensors are commonly integer-typed, so cast explicitly.
            loss = loss_fct(logits, labels.float())
            if special_tokens_mask is not None:
                # Boolean indexing with the (batch, seq) mask keeps only the
                # loss rows of real tokens -> shape (n_kept, num_labels).
                loss = loss[special_tokens_mask != 1].mean()
            else:
                # NOTE(review): without special_tokens_mask, padding/special
                # positions are averaged into the loss — confirm callers
                # always pass the mask during training.
                loss = loss.mean()

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


AutoModelForTokenClassification.register(MultiLabelBertConfig, BertForMultiLabelTokenClassification)
BertForMultiLabelTokenClassification.register_for_auto_class('AutoModelForTokenClassification')