import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    DebertaV2Model,
    DebertaV2PreTrainedModel,
    DebertaV2Config,
)


class DebertaV3SequenceClassifier(DebertaV2PreTrainedModel):
    """DeBERTa-v2/v3 encoder with a single-logit classification head.

    Mean-pools the encoder's last hidden state over the sequence dimension
    (excluding padding when an attention mask is given) and projects the
    pooled vector to one logit per example.
    """

    def __init__(self, config: DebertaV2Config):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # Hidden width read off the embedding LayerNorm weights rather than
        # config.hidden_size, so it follows the actual checkpoint tensors.
        self.d_model = self.deberta.embeddings.LayerNorm.weight.shape[0]
        self.head = nn.Linear(self.d_model, 1)
        # Standard HF weight init / tying hook.
        self.post_init()

    @property
    def device(self):
        # Device of the head's weight; assumes the whole model sits on a
        # single device.
        return self.head.weight.device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode, pool, and classify a batch of sequences.

        Args:
            input_ids: token id tensor, presumably (batch, seq).
            attention_mask: 1 for real tokens, 0 for padding. When given,
                padding positions are excluded from the mean pooling (the
                original code averaged over pad embeddings too — bug fix).
            labels: optional targets broadcastable to (batch,); when
                provided, a BCE-with-logits loss is added under 'loss'.
            **kwargs: ignored; accepted for HF Trainer compatibility.

        Returns:
            dict with 'logits' (batch, 1), 'probs' (batch, 1), and — only
            when labels were supplied — a scalar 'loss'.
        """
        x = self.deberta(input_ids, attention_mask=attention_mask).last_hidden_state
        if attention_mask is not None:
            # Masked mean: zero out pad positions, then divide by the count
            # of real tokens (clamped to avoid division by zero on an
            # all-padding row).
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            pooled = (x * mask).sum(dim=-2) / mask.sum(dim=-2).clamp(min=1e-9)
        else:
            pooled = x.mean(dim=-2)
        logits = self.head(pooled)
        # F.sigmoid is deprecated; torch.sigmoid is the supported form.
        probs = torch.sigmoid(logits)
        out = {'logits': logits, 'probs': probs}
        if labels is not None:
            # Single-logit head => binary classification with BCE on logits
            # (numerically stabler than BCE on probs).
            out['loss'] = F.binary_cross_entropy_with_logits(
                logits.view(-1), labels.to(logits.dtype).view(-1)
            )
        return out