from transformers import BertConfig, BertModel
import torch.nn as nn


class EBertConfig(BertConfig):
    model_type = "ebert"

    def __init__(self, adapter_size=None, **kwargs):
        # Take adapter_size as an explicit argument rather than popping it
        # from kwargs after super().__init__() has already consumed them.
        super().__init__(**kwargs)
        self.adapter_size = adapter_size


class EBertModel(BertModel):
    config_class = EBertConfig

    def __init__(self, config: EBertConfig):
        super().__init__(config)
        if config.adapter_size:
            # One bottleneck adapter (down-project -> ReLU -> up-project)
            # per transformer layer.
            self.adapters = nn.ModuleList([
                nn.Sequential(
                    nn.Linear(config.hidden_size, config.adapter_size),
                    nn.ReLU(),
                    nn.Linear(config.adapter_size, config.hidden_size),
                )
                for _ in range(config.num_hidden_layers)
            ])
        else:
            self.adapters = None
    def forward(self, *args, **kwargs):
        # Request a ModelOutput so last_hidden_state is addressable below.
        kwargs.setdefault("return_dict", True)
        outputs = super().forward(*args, **kwargs)
        sequence_output = outputs.last_hidden_state
        if self.adapters is not None:
            # Apply the adapters sequentially as residual bottleneck blocks
            # on the final encoder hidden states.
            for adapter in self.adapters:
                sequence_output = sequence_output + adapter(sequence_output)
        # Update the output in place instead of rebuilding it, so optional
        # fields (e.g. cross_attentions) are not silently dropped.
        outputs.last_hidden_state = sequence_output
        return outputs
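

# Minimal usage sketch, not part of the file above: register the custom
# classes with the Auto* factories (a real transformers API) and run a
# forward pass with the BERT backbone frozen so only the adapters train.
# The shapes and adapter_size value here are illustrative assumptions.
if __name__ == "__main__":
    import torch
    from transformers import AutoConfig, AutoModel

    AutoConfig.register("ebert", EBertConfig)
    AutoModel.register(EBertConfig, EBertModel)

    config = EBertConfig(adapter_size=64)
    model = EBertModel(config)

    # Freeze everything except the adapter parameters.
    for name, param in model.named_parameters():
        param.requires_grad = name.startswith("adapters")

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    outputs = model(input_ids=input_ids)
    print(outputs.last_hidden_state.shape)  # torch.Size([2, 16, 768])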