import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel

from .configuration_deberta_arg_classifier import DebertaConfig


class DebertaArgClassifier(PreTrainedModel):
    """Multi-label argument classifier: a DeBERTa encoder with a linear head."""

    config_class = DebertaConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone encoder; pretrained weights are fetched from the Hub at init.
        self.bert = AutoModel.from_pretrained("microsoft/deberta-large")
        # One logit per label, scored independently (multi-label setup).
        self.classifier = nn.Linear(self.bert.config.hidden_size, config.num_labels)
        # BCEWithLogitsLoss applies its own sigmoid, so it is fed raw logits.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        encoder_output = self.bert(input_ids, attention_mask=attention_mask)
        cls_embeddings = self._cls_embeddings(encoder_output)
        logits = self.classifier(cls_embeddings)
        # Independent per-label probabilities for the caller.
        output = torch.sigmoid(logits)
        loss = None
        if labels is not None:
            # Loss is computed on the raw logits; `labels` is expected to be
            # a float multi-hot tensor of shape (batch, num_labels).
            loss = self.criterion(logits, labels)
        return {"loss": loss, "output": output}

    def _cls_embeddings(self, output):
        """Return the embedding of the [CLS] token for each text."""
        last_hidden_state = output[0]
        # [CLS] sits at position 0 of every sequence.
        cls_embeddings = last_hidden_state[:, 0]
        return cls_embeddings
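

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model definition: builds the
    # classifier and runs one forward pass. It assumes the local DebertaConfig
    # accepts `num_labels` like a standard PretrainedConfig, and that the
    # matching "microsoft/deberta-large" tokenizer is used; the label count
    # and example texts below are arbitrary illustrations.
    from transformers import AutoTokenizer

    config = DebertaConfig(num_labels=5)  # assumed config signature
    model = DebertaArgClassifier(config)
    tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large")

    batch = tokenizer(
        ["We should ban plastic bags.", "The weather is nice today."],
        padding=True,
        return_tensors="pt",
    )
    # Multi-hot float targets: one row per text, one column per label.
    labels = torch.zeros(2, config.num_labels)
    labels[0, 0] = 1.0

    out = model(batch["input_ids"], batch["attention_mask"], labels=labels)
    print(out["loss"], out["output"].shape)  # scalar loss, (2, num_labels) probs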