Файл модели (`model.py`):
from transformers import BertConfig, BertModel
import torch.nn as nn
class EBertConfig(BertConfig):
    """BertConfig extended with an optional bottleneck-adapter width.

    Args:
        adapter_size: Hidden width of the per-layer adapter bottleneck;
            ``None`` (default) disables adapters entirely.
        **kwargs: Forwarded unchanged to :class:`BertConfig`.
    """

    model_type = "ebert"

    def __init__(self, adapter_size=None, **kwargs):
        # Take adapter_size as an explicit keyword instead of popping it from
        # kwargs *after* super().__init__ has already consumed the dict — in
        # the original, PretrainedConfig also stored it as an unknown kwarg
        # and the late pop was redundant. Explicit keyword keeps serialization
        # (to_dict/from_pretrained round-trip) working the standard HF way.
        super().__init__(**kwargs)
        self.adapter_size = adapter_size
class EBertModel(BertModel):
    """BertModel with optional residual bottleneck adapters.

    When ``config.adapter_size`` is set, one adapter (Linear down-projection,
    ReLU, Linear up-projection) is created per encoder layer. NOTE: all
    adapters are applied sequentially to the *final* encoder output — not one
    per encoder layer — matching the original implementation; checkpoints
    trained with that scheme stay compatible.
    """

    config_class = EBertConfig

    def __init__(self, config: EBertConfig):
        super().__init__(config)
        if config.adapter_size:
            # Bottleneck adapters: hidden -> adapter_size -> hidden,
            # added residually in forward().
            self.adapters = nn.ModuleList(
                nn.Sequential(
                    nn.Linear(config.hidden_size, config.adapter_size),
                    nn.ReLU(),
                    nn.Linear(config.adapter_size, config.hidden_size),
                )
                for _ in range(config.num_hidden_layers)
            )
        else:
            self.adapters = None

    def _apply_adapters(self, sequence_output):
        """Residually apply each adapter, in order, to the hidden states."""
        if self.adapters is not None:
            for adapter in self.adapters:
                sequence_output = sequence_output + adapter(sequence_output)
        return sequence_output

    def forward(self, *args, **kwargs):
        outputs = super().forward(*args, **kwargs)
        # Fix: with return_dict=False the parent returns a plain tuple and the
        # original `.last_hidden_state` access raised AttributeError.
        if isinstance(outputs, tuple):
            return (self._apply_adapters(outputs[0]),) + outputs[1:]
        # Update the ModelOutput in place instead of re-constructing it — the
        # original rebuild listed only four fields and silently dropped any
        # others the parent populated (e.g. past_key_values).
        outputs.last_hidden_state = self._apply_adapters(outputs.last_hidden_state)
        return outputs
Загрузка модели:
from transformers import BertForMaskedLM, BertTokenizerFast
from transformers.modeling_utils import load_state_dict
from model import EBertConfig, EBertModel
import torch

# NOTE(review): `checkpoint` is used both as a Hub repo id (from_pretrained)
# and as a local directory path for the weights file below — the repo must be
# present locally (cloned/downloaded) for the direct file path to resolve.
checkpoint = "your-username/ebert-ru-facts"

config = EBertConfig.from_pretrained(checkpoint)

# Build an MLM head and swap in the adapter-augmented encoder.
# Fix: BertForMaskedLM was used without being imported (NameError).
model = BertForMaskedLM(config)
model.bert = EBertModel(config)

# Fix: torch.load cannot parse the safetensors format; use transformers'
# checkpoint loader, which dispatches to safetensors for *.safetensors files.
state_dict = load_state_dict(f"{checkpoint}/model.safetensors")
model.load_state_dict(state_dict)
model.eval()

tokenizer = BertTokenizerFast.from_pretrained(checkpoint)