# Retrieved from the Hugging Face Hub file view (committer: noanabeshima).
# Commit f737748 (verified): "Upload folder using huggingface_hub".
# File size: 888 bytes.
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
DebertaV2Model,
DebertaV2PreTrainedModel,
DebertaV2Config,
)
class DebertaV3SequenceClassifier(DebertaV2PreTrainedModel):
    """DeBERTa encoder followed by a mean-pooled, single-output linear head.

    The encoder's last hidden states are averaged along their second-to-last
    axis, projected to one value by ``head``, and returned both as a raw
    logit and as its sigmoid.
    """

    def __init__(self, config: DebertaV2Config):
        """Create the encoder and the linear head described by ``config``."""
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # The head's input width is taken from the size of the encoder's
        # embedding LayerNorm weight.
        embedding_norm = self.deberta.embeddings.LayerNorm
        self.d_model = embedding_norm.weight.shape[0]
        self.head = nn.Linear(self.d_model, 1)
        self.post_init()

    @property
    def device(self):
        """Return the device holding the head's weight tensor."""
        head_weight = self.head.weight
        return head_weight.device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the encoder and the head on a batch.

        Args:
            input_ids: Passed to the encoder as its first positional argument.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            labels: Accepted but not used; no loss is computed here.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``'logits'`` (the head's output) and ``'probs'``
            (the sigmoid of those logits).
        """
        encoder_output = self.deberta(input_ids, attention_mask=attention_mask)
        hidden_states = encoder_output.last_hidden_state
        # Assumes hidden_states is (batch, seq, hidden), so dim=-2 averages
        # over sequence positions -- TODO confirm.
        # NOTE(review): attention_mask is not applied to this mean, so every
        # position contributes equally, including any padded ones.
        pooled = hidden_states.mean(dim=-2)
        logits = self.head(pooled)
        probs = F.sigmoid(logits)
        return {'logits': logits, 'probs': probs}