# EBERT_ru_MLM / proverka.py — masked-LM sanity check for the custom EBERT model.
# (Hugging Face upload metadata removed: uploader Darkester, "Upload 2 files", commit b61b1b5)
from transformers import BertTokenizerFast, BertConfig
import torch
from datasets import load_dataset
from ebert_model import EBertConfig, EBertModel
from transformers import BertForMaskedLM
from safetensors.torch import load_file
import os
# --- Model / data setup ------------------------------------------------------
model_path = "./ebert"

tokenizer = BertTokenizerFast.from_pretrained(model_path)
config = EBertConfig.from_pretrained(model_path)

# Build a standard MLM wrapper, then swap its encoder for the custom EBERT
# backbone so the LM head sits on top of the custom model.
model = BertForMaskedLM(config)
model.bert = EBertModel(config)

weights_path = f"{model_path}/model.safetensors"
if os.path.exists(weights_path):
    # strict=False: the custom backbone may not match every checkpoint key;
    # missing/unexpected keys are tolerated on purpose.
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict, strict=False)
else:
    raise FileNotFoundError(f"Файл {weights_path} не найден.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # inference only — disable dropout etc.

dataset = load_dataset("Expotion/russian-facts-qa", split="train")
def predict_masked_text(example):
    """Mask one token of a question+answer pair and return the model's top-5 guesses.

    Args:
        example: dataset row with string fields ``q`` (question) and ``a``
            (answer); they are concatenated into a single input text.

    Returns:
        dict with keys:
            ``original_text``   — the unmasked input text,
            ``masked_text``     — decoded input with the mask applied
                                  (or a short-text notice),
            ``predicted_tokens``— top-5 candidate tokens for the masked slot,
            ``true_token``      — the token that was masked out.
    """
    text = f"{example['q'].strip()} {example['a'].strip()}"
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding="max_length"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    input_ids = inputs["input_ids"].clone()
    labels = input_ids.clone()  # keep the unmasked ids to recover the true token

    # Always mask the third position (index 2); index 0 is typically [CLS].
    mask_token_index = 2
    if input_ids.size(1) <= mask_token_index:
        return {
            "original_text": text,
            "masked_text": "Слишком короткий текст",
            "predicted_tokens": [],
            "true_token": "",
        }

    input_ids[0, mask_token_index] = tokenizer.mask_token_id
    # BUGFIX: the mask was written into a clone, but the model was being fed the
    # original (unmasked) ids — feed the masked ids so [MASK] actually reaches
    # the model.
    inputs["input_ids"] = input_ids

    with torch.no_grad():
        outputs = model(**inputs)

    predictions = outputs.logits
    predicted_token_ids = torch.topk(predictions[0, mask_token_index], 5).indices.tolist()
    predicted_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in predicted_token_ids]

    return {
        "original_text": text,
        "masked_text": tokenizer.decode(input_ids[0], skip_special_tokens=True),
        "predicted_tokens": predicted_tokens,
        "true_token": tokenizer.decode([labels[0, mask_token_index]], skip_special_tokens=True),
    }
# --- Report model size, then run a few qualitative MLM checks ----------------
total_params = sum(p.numel() for p in model.parameters())
print(f"Общее количество параметров: {total_params}")

num_examples = 1
for i in range(num_examples):
    result = predict_masked_text(dataset[i])
    print(f"Оригинал: {result['original_text']}")
    print(f"Замаскированный: {result['masked_text']}")
    print(f"Предсказание: {result['predicted_tokens']}")
    print(f"Истина: {result['true_token']}")