import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    DebertaV2Model,
    DebertaV2PreTrainedModel,
    DebertaV2Config,
)


class DebertaV3SequenceClassifier(DebertaV2PreTrainedModel):
    """Binary classifier: a DeBERTa-v3 encoder (served by the v2 architecture classes)
    with mean pooling and a single-logit head."""

    def __init__(self, config: DebertaV2Config):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # Hidden size is read off the embedding LayerNorm so it always matches the loaded weights.
        self.d_model = self.deberta.embeddings.LayerNorm.weight.shape[0]
        self.head = nn.Linear(self.d_model, 1)
        self.post_init()

    @property
    def device(self):
        return self.head.weight.device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        x = self.deberta(input_ids, attention_mask=attention_mask).last_hidden_state
        # Mean-pool over the sequence dimension, excluding padding tokens from the average.
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            pooled = (x * mask).sum(dim=-2) / mask.sum(dim=-2).clamp(min=1e-6)
        else:
            pooled = x.mean(dim=-2)
        logits = self.head(pooled)
        probs = torch.sigmoid(logits)  # F.sigmoid is deprecated
        out = {'logits': logits, 'probs': probs}
        # Optional BCE loss when binary labels are supplied (the argument was otherwise unused).
        if labels is not None:
            out['loss'] = F.binary_cross_entropy_with_logits(logits.squeeze(-1), labels.float())
        return out
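
# Minimal usage sketch (assumption: "microsoft/deberta-v3-base" as the checkpoint; any
# DeBERTa-v3 checkpoint that loads with the DebertaV2* classes should behave the same).
# The classification head is newly initialised, so from_pretrained will warn about
# randomly initialised weights until the model is fine-tuned.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    name = "microsoft/deberta-v3-base"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = DebertaV3SequenceClassifier.from_pretrained(name).eval()

    batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
    with torch.no_grad():
        out = model(**batch)
    print(out['probs'])  # shape (batch_size, 1), sigmoid probabilities in [0, 1]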