File size: 888 Bytes
ec97649 25b5df9 ec97649 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
DebertaV2Model,
DebertaV2PreTrainedModel,
DebertaV2Config,
)
class DebertaV3SequenceClassifier(DebertaV2PreTrainedModel):
    """DeBERTa encoder followed by mean pooling and a single-logit linear head.

    The encoder's ``last_hidden_state`` is averaged over the sequence axis
    and projected to one logit per example; the sigmoid of that logit is
    returned alongside it.
    """

    def __init__(self, config: DebertaV2Config):
        """Build the encoder and the one-unit classification head.

        Args:
            config: Configuration used to construct the ``DebertaV2Model``
                backbone.
        """
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # Take the head's input width from the embedding LayerNorm so it
        # always matches the tensor the encoder actually produces.
        self.d_model = self.deberta.embeddings.LayerNorm.weight.shape[0]
        self.head = nn.Linear(self.d_model, 1)
        self.post_init()

    @property
    def device(self):
        """Return the device holding the classification head's weights."""
        return self.head.weight.device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode the inputs and return the pooled logit and its sigmoid.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Optional mask passed to the encoder and used to
                exclude masked positions from the pooled average. Assumed
                to have shape (batch, seq) with 1 for real tokens and 0 for
                padding -- TODO confirm against the caller's collator.
            labels: Accepted for call-site compatibility; not used here
                (no loss is computed).
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``'logits'`` (output of the linear head) and
            ``'probs'`` (``sigmoid(logits)``).
        """
        hidden = self.deberta(input_ids, attention_mask=attention_mask).last_hidden_state

        if attention_mask is None:
            # No mask available: every position counts equally.
            pooled = hidden.mean(dim=-2)
        else:
            # Average only over unmasked positions so that the amount of
            # padding in a batch cannot change an example's prediction.
            weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # Count tokens on the original mask, then clamp so an all-masked
            # row yields zeros instead of a division by zero.
            token_count = attention_mask.sum(dim=-1, keepdim=True).clamp(min=1)
            pooled = (hidden * weights).sum(dim=-2) / token_count.to(hidden.dtype)

        logits = self.head(pooled)
        probs = torch.sigmoid(logits)
        return {'logits': logits, 'probs': probs}
|