import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import (
    DebertaV2Model,
    DebertaV2PreTrainedModel,
    DebertaV2Config,
)

class DebertaV3SequenceClassifier(DebertaV2PreTrainedModel):
    """DeBERTa-v2/v3 encoder with a single-logit classification head.

    Pools the encoder's last hidden state by mean over the sequence
    dimension (masked by ``attention_mask`` when provided, so padding
    positions do not dilute the average) and projects it to one logit
    per example, suitable for binary classification / regression.
    """

    def __init__(self, config: DebertaV2Config):
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # Hidden width read off the embedding LayerNorm's weight shape.
        # NOTE(review): presumably equal to config.hidden_size — kept as-is
        # to avoid changing behavior for configs where embeddings differ.
        self.d_model = self.deberta.embeddings.LayerNorm.weight.shape[0]
        self.head = nn.Linear(self.d_model, 1)
        self.post_init()  # HF hook: weight init / tying after construction

    @property
    def device(self) -> torch.device:
        # Device of the head's weight; assumes the whole model sits on one device.
        return self.head.weight.device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the encoder, pool, and classify.

        Args:
            input_ids: Token id tensor passed through to the encoder.
            attention_mask: Optional 0/1 mask; 1 marks real tokens. Also used
                to exclude padding from the mean pooling.
            labels: Optional targets; when given, a BCE-with-logits loss is
                computed and returned under ``'loss'``.
            **kwargs: Ignored (accepted for Trainer-style call compatibility).

        Returns:
            Dict with ``'logits'`` (shape ``(batch, 1)``), ``'probs'``
            (sigmoid of logits), and ``'loss'`` when ``labels`` is provided.
        """
        x = self.deberta(input_ids, attention_mask=attention_mask).last_hidden_state

        if attention_mask is not None:
            # Masked mean pooling: zero out padding positions, then divide by
            # the per-example count of real tokens (clamped to avoid /0 on an
            # all-padding row).
            mask = attention_mask.unsqueeze(-1).to(x.dtype)
            pooled = (x * mask).sum(dim=-2) / mask.sum(dim=-2).clamp(min=1e-9)
        else:
            pooled = x.mean(dim=-2)

        logits = self.head(pooled)
        # torch.sigmoid: F.sigmoid is deprecated and forwards here anyway.
        probs = torch.sigmoid(logits)

        out = {'logits': logits, 'probs': probs}
        if labels is not None:
            # Previously `labels` was accepted but ignored; compute the loss so
            # HF Trainer-style training loops work out of the box.
            out['loss'] = F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), labels.to(logits.dtype)
            )
        return out