from transformers import BertTokenizerFast, BertConfig
import torch
from datasets import load_dataset
from ebert_model import EBertConfig, EBertModel
from transformers import BertForMaskedLM
from safetensors.torch import load_file
import os

# --- Model / tokenizer setup -------------------------------------------------
# Loads a locally fine-tuned EBert model wrapped in BertForMaskedLM so we can
# reuse the standard MLM head on top of the custom encoder.
model_path = "./ebert"
tokenizer = BertTokenizerFast.from_pretrained(model_path)
config = EBertConfig.from_pretrained(model_path)
model = BertForMaskedLM(config)
# Replace the stock encoder with the custom EBert encoder; the MLM head and
# embeddings come from the checkpoint below.
model.bert = EBertModel(config)

weights_path = f"{model_path}/model.safetensors"
if os.path.exists(weights_path):
    state_dict = load_file(weights_path)
    # strict=False: the swapped-in EBertModel may not cover every key in the
    # checkpoint (and vice versa); unmatched keys are silently skipped.
    model.load_state_dict(state_dict, strict=False)
else:
    raise FileNotFoundError(f"Файл {weights_path} не найден.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

dataset = load_dataset("Expotion/russian-facts-qa", split="train")


def predict_masked_text(example, mask_token_index=2):
    """Mask one token of a QA pair and return the model's top-5 predictions.

    Args:
        example: dataset row with string fields ``q`` (question) and ``a``
            (answer); they are concatenated into a single input text.
        mask_token_index: position in the tokenized sequence to replace with
            ``[MASK]``. Defaults to 2 (the second real token, after [CLS]).

    Returns:
        dict with keys ``original_text``, ``masked_text`` (decoded input with
        the mask applied), ``predicted_tokens`` (top-5 candidates for the
        masked position), and ``true_token`` (the token that was masked out).
    """
    text = f"{example['q'].strip()} {example['a'].strip()}"
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    input_ids = inputs["input_ids"].clone()
    # Keep an unmasked copy so we can report the ground-truth token later.
    labels = input_ids.clone()

    # Guard against sequences too short to contain the mask position.
    if input_ids.size(1) <= mask_token_index:
        return {
            "original_text": text,
            "masked_text": "Слишком короткий текст",
            "predicted_tokens": [],
            "true_token": "",
        }

    input_ids[0, mask_token_index] = tokenizer.mask_token_id
    # BUG FIX: feed the *masked* ids to the model. The original code masked a
    # clone but ran the forward pass on the untouched inputs, so the model was
    # "predicting" a token it could already see.
    inputs["input_ids"] = input_ids

    with torch.no_grad():
        outputs = model(**inputs)

    predictions = outputs.logits
    predicted_token_ids = torch.topk(predictions[0, mask_token_index], 5).indices.tolist()
    predicted_tokens = [tokenizer.decode([id]).strip() for id in predicted_token_ids]

    return {
        "original_text": text,
        "masked_text": tokenizer.decode(input_ids[0], skip_special_tokens=True),
        "predicted_tokens": predicted_tokens,
        "true_token": tokenizer.decode(
            [labels[0, mask_token_index]], skip_special_tokens=True
        ),
    }


# --- Driver ------------------------------------------------------------------
total_params = sum(p.numel() for p in model.parameters())
print(f"Общее количество параметров: {total_params}")

num_examples = 1
for i in range(num_examples):
    result = predict_masked_text(dataset[i])
    print(f"Оригинал: {result['original_text']}")
    print(f"Замаскированный: {result['masked_text']}")
    print(f"Предсказание: {result['predicted_tokens']}")
    print(f"Истина: {result['true_token']}")