from typing import Optional, Union

import torch
from torch import nn
from transformers import (
    BertModel,
    BertPreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertOnlyMLMHead

from .configuration_bert import BertMultiTaskConfig


class BertForMultiTaskClassification(BertPreTrainedModel):
    config_class = BertMultiTaskConfig
    _tied_weights_keys = ["cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.tasks = config.tasks
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)

        # One linear head per classification task over the pooled [CLS] output;
        # the special task name "MLM" instead gets the standard BERT LM head
        # (tied to the input embeddings via `_tied_weights_keys`). For example,
        # tasks={"SENTIMENT": 3, "MLM": 0} (names illustrative; the label count
        # for "MLM" is unused).
        task_layers = {}
        for task_name, num_labels in self.tasks.items():
            if task_name.upper() == "MLM":
                self.cls = BertOnlyMLMHead(config)
            else:
                task_layers[task_name.upper()] = nn.Linear(config.hidden_size, num_labels)
        self.task_classifiers = nn.ModuleDict(task_layers)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        # This method tells the PreTrainedModel that self.cls.predictions.decoder is the output layer to be tied
        if hasattr(self, "cls"):
            return self.cls.predictions.decoder
        return None
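
    def set_output_embeddings(self, new_embeddings):
        # Counterpart to get_output_embeddings, added here as the standard
        # Transformers hook (not strictly required while
        # config.tie_word_embeddings is True): the library calls it when
        # resizing an un-tied LM head, e.g. after resize_token_embeddings
        # with tie_word_embeddings=False.
        if hasattr(self, "cls"):
            self.cls.predictions.decoder = new_embeddings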

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task: Optional[str] = None,  # for now, a single task per batch
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if task is None:
            raise ValueError(f"`task` must be specified and one of {sorted(self.task_classifiers)}, or 'MLM'.")
        if task.upper() == "MLM" and not hasattr(self, "cls"):
            raise ValueError("Model was not initialized with an MLM head.")

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        loss = None
        logits = None
        if task.upper() == "MLM":
            num_labels = self.config.vocab_size
            sequence_output = outputs[0]
            logits = self.cls(sequence_output)

        elif task.upper() in self.task_classifiers:
            classifier = self.task_classifiers[task.upper()]
            # Read the label count off the head itself rather than indexing
            # `self.tasks[task]`, which would raise a KeyError whenever the
            # caller's casing differs from the keys in `config.tasks`.
            num_labels = classifier.out_features
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output)
            logits = classifier(pooled_output)

        else:
            raise ValueError(f"Unknown task {task!r}; expected one of {sorted(self.task_classifiers)} or 'MLM'.")

        if labels is not None:
            # nn.CrossEntropyLoss ignores targets of -100 by default, which is
            # what masks out the non-predicted positions in the MLM case.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # MLM logits are wrapped in a SequenceClassifierOutput as well, so all
        # tasks share a single return type.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Register the custom classes so that checkpoints pushed to the Hub load via
# the Auto* APIs with `trust_remote_code=True`.
BertMultiTaskConfig.register_for_auto_class()
BertForMultiTaskClassification.register_for_auto_class("AutoModelForSequenceClassification")
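

# ---------------------------------------------------------------------------
# Usage sketch (illustrative). Assumes `BertMultiTaskConfig` accepts a
# `tasks` mapping of task name to label count, as `config.tasks` above
# implies; the task names and label counts are hypothetical, and the model
# weights here are freshly initialized rather than loaded from a fine-tuned
# checkpoint. Because of the relative import at the top, run this from the
# package context, e.g. `python -m <package>.modeling_bert`.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = BertMultiTaskConfig.from_pretrained(
        "bert-base-uncased", tasks={"MLM": 0, "TOPIC": 4}
    )
    model = BertForMultiTaskClassification(config)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    batch = tokenizer(["an example sentence"], return_tensors="pt")
    out = model(**batch, labels=torch.tensor([2]), task="TOPIC")
    print(out.loss, out.logits.shape)  # scalar loss, logits of shape (1, 4)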