File size: 2,033 Bytes
1ad47ee
 
 
 
 
 
 
 
 
 
 
 
 
 
42c6c29
3c8a80e
1ad47ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9289834
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from transformers import PreTrainedModel, AutoModel, AutoConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
from .configuration_embedder_with_mlp import EmbedderWithMLPConfig



class EmbedderWithMLP(PreTrainedModel):
    """HF transformer backbone followed by an MLP head emitting one logit.

    forward() returns raw logits plus an optional BCE-with-logits loss;
    inference() converts the logit to an integer score in [0, 24].
    """

    config_class = EmbedderWithMLPConfig

    def __init__(self, config, embedder=None):
        """Build the model.

        Args:
            config: EmbedderWithMLPConfig providing `model_name`,
                `input_size`, `hidden_layer_list`, and `dropout_ratio`.
            embedder: optional pre-built HF backbone (BERT, RoBERTa, etc.).
                If None, a fresh backbone is constructed from
                `config.model_name` (architecture only, random weights;
                actual weights load via `from_pretrained` on this wrapper).
        """
        super().__init__(config)
        if embedder is None:
            backbone_config = AutoConfig.from_pretrained(config.model_name)
            self.embedder = AutoModel.from_config(backbone_config)
        else:
            self.embedder = embedder  # HF backbone (BERT, RoBERTa, etc.)

        # MLP: [input_size, h1, ..., hk] with Linear+GELU+Dropout per hidden
        # layer, then a final Linear down to a single logit.
        dims = [config.input_size] + config.hidden_layer_list
        layers = []  # plain list is enough; nn.Sequential registers modules
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.GELU())
            layers.append(nn.Dropout(config.dropout_ratio))
        layers.append(nn.Linear(dims[-1], 1))
        self.mlp = nn.Sequential(*layers)

        self.post_init()  # required by PreTrainedModel (weight init etc.)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Encode inputs and score them with the MLP head.

        Args:
            input_ids: token ids, shape (batch, seq).
            attention_mask: padding mask, shape (batch, seq).
            labels: optional binary targets, shape (batch,); any numeric
                dtype (cast to float for the loss).

        Returns:
            dict with "logits" (raw, pre-sigmoid, shape (batch, 1)) and
            "loss" (scalar tensor, or None when labels is None).
        """
        outputs = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): assumes the backbone exposes `pooler_output`; some
        # architectures (e.g. DistilBERT) do not — confirm for the backbone used.
        emb = outputs.pooler_output
        logit = self.mlp(emb)

        loss = None
        if labels is not None:
            # BUGFIX: original used nn.BCELoss on raw logits. BCELoss requires
            # inputs in [0, 1], so any negative logit raised a runtime error.
            # BCEWithLogitsLoss fuses the sigmoid (numerically stable) and
            # matches the sigmoid applied in inference().
            loss = nn.BCEWithLogitsLoss()(logit.view(-1), labels.float())

        return {"loss": loss, "logits": logit}

    def inference(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Predict a rounded integer score in [0, 24] for one example.

        Args:
            input_ids: token ids, shape (1, seq) — `.item()` below assumes
                batch size 1.
            attention_mask: padding mask, shape (1, seq).
            labels: unused; accepted for signature compatibility.

        Returns:
            dict with "score": int in [0, 24].
        """
        outputs = self.embedder(input_ids=input_ids, attention_mask=attention_mask)
        emb = outputs.pooler_output
        logit = self.mlp(emb)
        prob = torch.sigmoid(logit)          # logit -> probability in [0, 1]
        score = (prob.view(-1) * 24).item()  # scale to 0–24; assumes batch of 1
        return {"score": round(score)}