File size: 1,390 Bytes
b61b1b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from transformers import BertConfig, BertModel
import torch.nn as nn

class EBertConfig(BertConfig):
    """BERT configuration extended with an optional adapter bottleneck size.

    Args:
        adapter_size: Width of the bottleneck layer of each adapter. ``None``
            (the default) or any other falsy value means no adapters are built.
        **kwargs: Forwarded unchanged to ``BertConfig``.
    """

    model_type = "ebert"

    def __init__(self, *, adapter_size=None, **kwargs):
        # ``adapter_size`` is taken as an explicit parameter so it is removed
        # from ``kwargs`` *before* delegating; the previous version popped it
        # only after ``super().__init__`` had already received it.
        super().__init__(**kwargs)
        self.adapter_size = adapter_size

class EBertModel(BertModel):
    """``BertModel`` with an optional stack of residual bottleneck adapters.

    When ``config.adapter_size`` is truthy, ``config.num_hidden_layers``
    adapters (Linear -> ReLU -> Linear) are built. In ``forward`` they are all
    applied one after another, each as a residual update, to the encoder's
    final hidden states. They run after the whole BERT encoder, not between
    encoder layers.
    """

    config_class = EBertConfig

    def __init__(self, config: EBertConfig):
        super().__init__(config)
        # ``getattr`` keeps a plain ``BertConfig`` (no ``adapter_size``) usable.
        adapter_size = getattr(config, "adapter_size", None)
        if adapter_size:
            self.adapters = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Linear(config.hidden_size, adapter_size),
                        nn.ReLU(),
                        nn.Linear(adapter_size, config.hidden_size),
                    )
                    for _ in range(config.num_hidden_layers)
                ]
            )
            # ``super().__init__`` ran ``post_init()`` before these modules
            # existed, so give them the same initialisation as the rest of
            # the model.
            self.adapters.apply(self._init_weights)
        else:
            self.adapters = None

    def forward(self, *args, **kwargs):
        """Run ``BertModel.forward`` and refine its output with the adapters.

        Accepts exactly the arguments of ``BertModel.forward``. The return
        value has the same type as the parent's: a ``ModelOutput``, or a tuple
        ``(last_hidden_state, pooler_output, *rest)`` when ``return_dict`` is
        off.

        With adapters enabled, ``last_hidden_state`` is replaced by the
        adapted states. ``pooler_output`` is recomputed from those states when
        the model has a pooler. Every other field (``hidden_states``,
        ``attentions``, ``past_key_values``, ...) is passed through unchanged;
        note that ``hidden_states`` therefore holds pre-adapter states.
        Without adapters, the parent's output is returned as is.
        """
        outputs = super().forward(*args, **kwargs)
        if self.adapters is None:
            return outputs

        # With return_dict=False the parent returns a plain tuple, which has
        # no attribute access.
        is_tuple = isinstance(outputs, tuple)
        sequence_output = outputs[0] if is_tuple else outputs.last_hidden_state

        for adapter in self.adapters:
            # Residual update: each adapter adds a correction to its input.
            sequence_output = sequence_output + adapter(sequence_output)

        # Re-pool so heads built on ``pooler_output`` see the adapted states.
        pooled_output = (
            self.pooler(sequence_output) if self.pooler is not None else None
        )

        if is_tuple:
            return (sequence_output, pooled_output) + outputs[2:]

        # Update in place rather than rebuilding, so that fields such as
        # ``past_key_values`` and ``cross_attentions`` are preserved.
        outputs.last_hidden_state = sequence_output
        if pooled_output is not None:
            outputs.pooler_output = pooled_output
        return outputs