Darkester commited on
Commit
87169f1
·
verified ·
1 Parent(s): 8b535c8

Update README.md

Browse files

Файл модели (model.py) :

from transformers import BertConfig, BertModel
import torch.nn as nn

class EBertConfig(BertConfig):
model_type = "ebert"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.adapter_size = kwargs.pop('adapter_size', None)

class EBertModel(BertModel):
config_class = EBertConfig

def __init__(self, config: EBertConfig):
super().__init__(config)
if config.adapter_size:
self.adapters = nn.ModuleList([
nn.Sequential(
nn.Linear(config.hidden_size, config.adapter_size),
nn.ReLU(),
nn.Linear(config.adapter_size, config.hidden_size),
)
for _ in range(config.num_hidden_layers)
])
else:
self.adapters = None

def forward(self, *args, **kwargs):
outputs = super().forward(*args, **kwargs)
sequence_output = outputs.last_hidden_state

if self.adapters is not None:
for adapter in self.adapters:
sequence_output = sequence_output + adapter(sequence_output)

return outputs.__class__(
last_hidden_state=sequence_output,
pooler_output=outputs.pooler_output,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


Загрузка модели :

from transformers import BertTokenizerFast
from model import EBertConfig, EBertModel
import torch

checkpoint = "your-username/ebert-ru-facts"
config = EBertConfig.from_pretrained(checkpoint)
model = BertForMaskedLM(config)
model.bert = EBertModel(config)
model.load_state_dict(torch.load(f"{checkpoint}/model.safetensors"))
model.eval()

tokenizer = BertTokenizerFast.from_pretrained(checkpoint)

Files changed (1) hide show
  1. README.md +7 -3
README.md CHANGED
@@ -1,3 +1,7 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - ru
5
+ base_model:
6
+ - google-bert/bert-base-uncased
7
+ ---