# Scrape artifact (page header, not part of the script):
# File size: 2,619 Bytes — commit b61b1b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# (line-number gutter from the original page scrape removed)
from transformers import BertTokenizerFast, BertConfig
import torch
from datasets import load_dataset
from ebert_model import EBertConfig, EBertModel
from transformers import BertForMaskedLM
from safetensors.torch import load_file
import os

# Location of the fine-tuned EBert checkpoint on disk.
model_path = "./ebert"

# Tokenizer and configuration both come from the checkpoint directory.
tokenizer = BertTokenizerFast.from_pretrained(model_path)
config = EBertConfig.from_pretrained(model_path)

# Build a standard MLM head, then swap its encoder for the custom EBert encoder.
model = BertForMaskedLM(config)
model.bert = EBertModel(config)

# Load the safetensors weights; strict=False tolerates key differences
# between the stock BERT layout and the EBert encoder — presumably
# intentional given the encoder swap above (TODO confirm).
weights_path = f"{model_path}/model.safetensors"
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"Файл {weights_path} не найден.")
model.load_state_dict(load_file(weights_path), strict=False)

# Run on GPU when available; switch to inference mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

dataset = load_dataset("Expotion/russian-facts-qa", split="train")

def predict_masked_text(example, mask_token_index=2):
    """Mask one token of a QA pair and report the model's top-5 predictions.

    Joins the question and answer into one string, replaces the token at
    ``mask_token_index`` with ``[MASK]``, and runs the masked sequence
    through the model.

    Args:
        example: dataset row with string fields ``'q'`` (question) and
            ``'a'`` (answer).
        mask_token_index: position in the tokenized sequence to mask.
            Defaults to 2 (the original hard-coded position).

    Returns:
        dict with keys ``original_text``, ``masked_text`` (decoded masked
        sequence), ``predicted_tokens`` (top-5 candidates at the masked
        position), and ``true_token`` (the ground-truth token that was
        masked out).
    """
    text = f"{example['q'].strip()} {example['a'].strip()}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Untouched copy of the token ids so the true token can be reported later.
    labels = inputs["input_ids"].clone()

    if inputs["input_ids"].size(1) <= mask_token_index:
        return {
            "original_text": text,
            "masked_text": "Слишком короткий текст",
            "predicted_tokens": [],
            "true_token": ""
        }

    # BUG FIX: mask the tensor that is actually fed to the model. The
    # original code masked a detached clone and then passed the unmodified
    # `inputs` to the model, so the model always saw the true token and the
    # "prediction" was trivial reconstruction.
    inputs["input_ids"][0, mask_token_index] = tokenizer.mask_token_id

    with torch.no_grad():
        logits = model(**inputs).logits

    # Top-5 candidate token ids at the masked position.
    top_ids = torch.topk(logits[0, mask_token_index], 5).indices.tolist()
    predicted_tokens = [tokenizer.decode([tok_id]).strip() for tok_id in top_ids]

    return {
        "original_text": text,
        "masked_text": tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True),
        "predicted_tokens": predicted_tokens,
        "true_token": tokenizer.decode([labels[0, mask_token_index]], skip_special_tokens=True)
    }

# Report the model size, then run the masked-prediction demo on a few rows.
total_params = 0
for tensor in model.parameters():
    total_params += tensor.numel()
print(f"Общее количество параметров: {total_params}")

num_examples = 1
for idx in range(num_examples):
    result = predict_masked_text(dataset[idx])
    for label, key in (
        ("Оригинал", "original_text"),
        ("Замаскированный", "masked_text"),
        ("Предсказание", "predicted_tokens"),
        ("Истина", "true_token"),
    ):
        print(f"{label}: {result[key]}")