| from transformers import BertPreTrainedModel, BertModel, AutoConfig, AutoModelForTokenClassification |
| import torch |
| import torch.nn as nn |
|
|
| from transformers.modeling_outputs import TokenClassifierOutput |
| from transformers.utils import TransformersKwargs, can_return_tuple |
| from transformers.processing_utils import Unpack |
|
|
| from .configuration_multilabelbert import MultiLabelBertConfig |
|
|
| from typing import Optional |
|
|
|
|
class BertForMultiLabelTokenClassification(BertPreTrainedModel):
    """BERT encoder with a multi-label classification head applied to every token.

    The head is a single linear layer producing ``config.num_labels`` logits per
    token. Training uses element-wise binary cross-entropy on those logits, so
    each label is predicted independently and several labels may be active for
    the same token.
    """

    config_class = MultiLabelBertConfig

    def __init__(self, config):
        """Build the encoder (without pooling layer), dropout and linear head.

        Args:
            config: Model configuration. Reads ``num_labels``, ``hidden_size``,
                ``classifier_dropout`` and ``hidden_dropout_prob``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        # Fall back to the encoder's dropout rate when no head-specific rate is set.
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        special_tokens_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | TokenClassifierOutput:
        """Run the encoder and head; optionally compute the multi-label loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            inputs_embeds: Forwarded unchanged to the underlying ``BertModel``.
            labels: Multi-hot targets with the same shape as the logits
                (presumably ``(batch, seq_len, num_labels)`` -- TODO confirm).
                Any numeric or boolean dtype is accepted; values are cast to
                the logits' dtype before the loss is computed.
            special_tokens_mask: Optional mask whose entries equal to ``1``
                mark positions that are excluded from the loss. It indexes the
                leading dimensions of the per-element loss (presumably
                ``(batch, seq_len)`` -- TODO confirm).
            **kwargs: Extra keyword arguments forwarded to ``BertModel``.

        Returns:
            A ``TokenClassifierOutput`` holding ``loss`` (``None`` when
            ``labels`` is not given), ``logits``, and the encoder's
            ``hidden_states`` and ``attentions``.
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True,
            **kwargs,
        )

        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs floating-point targets; integer multi-hot
            # labels would otherwise raise a dtype error.
            targets = labels.to(device=logits.device, dtype=logits.dtype)
            # Keep the per-element loss so that positions can be filtered
            # before averaging.
            per_element_loss = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)

            if special_tokens_mask is not None:
                keep = special_tokens_mask.to(logits.device) != 1
                selected = per_element_loss[keep]
                if selected.numel() == 0:
                    # Every position is masked out: the mean of an empty tensor
                    # is NaN and would poison the gradients. Return a zero that
                    # is still attached to the graph instead.
                    loss = per_element_loss.sum() * 0.0
                else:
                    loss = selected.mean()
            else:
                loss = per_element_loss.mean()

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
# Map the custom config to this model so that
# AutoModelForTokenClassification can resolve it.
AutoModelForTokenClassification.register(
    config_class=MultiLabelBertConfig,
    model_class=BertForMultiLabelTokenClassification,
)
# Tag the model class with the auto class it belongs to.
BertForMultiLabelTokenClassification.register_for_auto_class(
    auto_class='AutoModelForTokenClassification',
)
|
|