File size: 1,283 Bytes
1712f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
from transformers import AutoModel, PreTrainedModel

class ElectraForMultiLabel(PreTrainedModel):
    """ELECTRA-style encoder with a multi-label classification head.

    Pools the [CLS] token embedding, applies dropout, and projects it to
    ``config.num_labels`` logits. When ``labels`` are supplied, computes a
    BCE-with-logits loss, optionally weighted per class via ``pos_weight``.
    """

    def __init__(self, config, pos_weight=None):
        """
        Args:
            config: Hugging Face model config; must provide ``num_labels``
                and carry the checkpoint name in ``_name_or_path``.
            pos_weight: Optional per-class positive-example weight tensor
                for the BCE loss. Registered as a buffer so it follows the
                model across device moves and is saved in ``state_dict``.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        # NOTE(review): ``_name_or_path`` is a private HF config attribute
        # (the checkpoint the config was loaded from) — works, but fragile.
        self.enc = AutoModel.from_pretrained(config._name_or_path)
        hid = self.enc.config.hidden_size
        self.dropout = torch.nn.Dropout(0.2)
        self.head = torch.nn.Linear(hid, self.num_labels)
        # Register pos_weight as a buffer (automatic device placement).
        if pos_weight is not None:
            self.register_buffer("pos_weight", pos_weight)
        else:
            self.pos_weight = None
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Run the encoder and head.

        Returns:
            dict with ``"logits"`` of shape (batch, num_labels) and
            ``"loss"`` (a scalar tensor, or None when ``labels`` is None).
        """
        # Use the [CLS] token embedding as the sequence representation.
        h = self.enc(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
        h = self.dropout(h)
        logits = self.head(h)

        loss = None
        if labels is not None:
            # BCE-with-logits requires float targets; multi-hot labels are
            # often integer tensors, which would raise a dtype error.
            # The functional form avoids rebuilding a loss module every
            # call and accepts pos_weight=None directly.
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                logits, labels.float(), pos_weight=self.pos_weight
            )

        return {"loss": loss, "logits": logits}