| from transformers import BertTokenizerFast, BertConfig | |
| import torch | |
| from datasets import load_dataset | |
| from ebert_model import EBertConfig, EBertModel | |
| from transformers import BertForMaskedLM | |
| from safetensors.torch import load_file | |
| import os | |
# --- Model setup -------------------------------------------------------------
# Load tokenizer + config from the local checkpoint directory, then build a
# BertForMaskedLM whose encoder is swapped out for the custom EBertModel.
model_path = "./ebert"
tokenizer = BertTokenizerFast.from_pretrained(model_path)
config = EBertConfig.from_pretrained(model_path)
model = BertForMaskedLM(config)
model.bert = EBertModel(config)

# Restore weights from the safetensors checkpoint. strict=False because the
# custom encoder may not match every key of the saved state dict.
weights_path = f"{model_path}/model.safetensors"
if not os.path.exists(weights_path):
    raise FileNotFoundError(f"Файл {weights_path} не найден.")
model.load_state_dict(load_file(weights_path), strict=False)

# Inference-only: move to GPU when available and switch off dropout.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Evaluation data: Russian factual question/answer pairs.
dataset = load_dataset("Expotion/russian-facts-qa", split="train")
def predict_masked_text(example, mask_token_index=2):
    """Mask one token of a question+answer pair and return the model's top-5 guesses.

    Args:
        example: dataset row with string fields 'q' (question) and 'a' (answer).
        mask_token_index: position in the tokenized sequence to replace with
            [MASK]. Defaults to 2 (the original hard-coded position, i.e. the
            second token after [CLS]).

    Returns:
        dict with keys 'original_text', 'masked_text' (decoded sequence as the
        model saw it), 'predicted_tokens' (top-5 candidates for the masked
        position), and 'true_token' (the token that was masked out).
    """
    text = f"{example['q'].strip()} {example['a'].strip()}"
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Keep the unmasked ids as ground truth before masking anything.
    labels = inputs["input_ids"].clone()
    # Guard against sequences too short to hold the mask position.
    # NOTE(review): with padding="max_length" the length is always 512, so this
    # branch is effectively unreachable; kept for safety if padding changes.
    if inputs["input_ids"].size(1) <= mask_token_index:
        return {
            "original_text": text,
            "masked_text": "Слишком короткий текст",
            "predicted_tokens": [],
            "true_token": ""
        }
    # BUG FIX: the original masked a *clone* of input_ids but still fed the
    # untouched `inputs` to the model, so the model never saw [MASK] and
    # "predicted" a token it could read directly. Mask the tensor that is
    # actually passed to the model.
    inputs["input_ids"][0, mask_token_index] = tokenizer.mask_token_id
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Top-5 candidate ids at the masked position (avoid shadowing builtin `id`).
    top_ids = torch.topk(logits[0, mask_token_index], 5).indices.tolist()
    predicted_tokens = [tokenizer.decode([token_id]).strip() for token_id in top_ids]
    return {
        "original_text": text,
        "masked_text": tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True),
        "predicted_tokens": predicted_tokens,
        "true_token": tokenizer.decode([labels[0, mask_token_index]], skip_special_tokens=True)
    }
# Report the model size, then run the masked-prediction demo.
print(f"Общее количество параметров: {sum(p.numel() for p in model.parameters())}")

num_examples = 1
for idx in range(num_examples):
    result = predict_masked_text(dataset[idx])
    # Table-driven printing: same labels, same order, same output as before.
    for label, key in (
        ("Оригинал", "original_text"),
        ("Замаскированный", "masked_text"),
        ("Предсказание", "predicted_tokens"),
        ("Истина", "true_token"),
    ):
        print(f"{label}: {result[key]}")