import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel


class DistilBertClassifier(PreTrainedModel):
    """DistilBERT encoder with a small MLP head producing 91 class logits."""

    def __init__(self, bert_config, model_name='distilbert-base-uncased', tokenizer_len=30528, freeze_bert=False):
        super().__init__(bert_config)

        # Load the pretrained encoder and resize its embedding matrix so it
        # matches the tokenizer size (extra special tokens may have been added).
        self.bert = AutoModel.from_pretrained(model_name)
        self.bert.resize_token_embeddings(tokenizer_len)

        # Classification head applied to the [CLS] token representation.
        self.classifier = nn.Sequential(
            nn.GELU(),
            nn.Linear(self.bert.config.hidden_size, 300),
            nn.GELU(),
            nn.Dropout(0.05),
            nn.Linear(300, 91)
        )

        # Optionally freeze the encoder so only the classification head trains.
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Hidden state of the first ([CLS]) token feeds the classification head.
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)
        return logits