File size: 4,029 Bytes
5559201 05c97e7 8900576 5559201 8900576 05c97e7 8900576 5559201 8900576 5559201 8900576 5559201 8900576 5559201 8900576 5559201 8900576 5559201 8900576 5559201 05c97e7 5559201 05c97e7 5559201 05c97e7 5559201 05c97e7 5559201 8900576 5559201 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import time
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer
from datasets import load_dataset
# Путь к модели и данным
model_path = "./unzipped_model"
validation_file = "mt5_validation_data-1.jsonl"
# Загрузка модели и токенизатора
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.eval()
# Используем GPU, если доступно
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Загрузка валидационной выборки
dataset = load_dataset("json", data_files={"validation": validation_file})
val_data = dataset["validation"]
# Функция предсказания
def predict(text):
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256
).to(device)
outputs = model.generate(
**inputs,
max_length=64,
num_beams=5,
early_stopping=True
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Счётчики
country_correct = 0
city_correct = 0
full_correct = 0
# Массив для хранения ошибочных предсказаний
# (где хотя бы что-то не совпало между таргетом и предиктом)
incorrect_samples = []
start_time = time.time()
for example in val_data:
text = example["text"]
target = example["target"]
# Получаем предсказание
prediction = predict(text)
# Разбиваем на "Страна:Город" - если нет двоеточия, ставим "unknown" для второй части
pred_parts = prediction.split(":", 1)
if len(pred_parts) == 2:
pred_country, pred_city = pred_parts[0].strip(), pred_parts[1].strip()
else:
pred_country = pred_parts[0].strip()
pred_city = "unknown"
# Аналогично таргет
target_parts = target.split(":", 1)
if len(target_parts) == 2:
true_country, true_city = target_parts[0].strip(), target_parts[1].strip()
else:
true_country = target_parts[0].strip()
true_city = "unknown"
# Сравниваем
if pred_country == true_country:
country_correct += 1
if pred_city == true_city:
city_correct += 1
if (pred_country == true_country) and (pred_city == true_city):
full_correct += 1
else:
# Если хотя бы что-то не совпало, сохраним этот пример
incorrect_samples.append({
"text": text,
"target": f"{true_country}:{true_city}",
"prediction": f"{pred_country}:{pred_city}"
})
end_time = time.time()
total_time = end_time - start_time
num_examples = len(val_data)
time_per_example = total_time / num_examples if num_examples > 0 else 0
# Выводим первые 80 «ошибочных» примеров (если их меньше, то все)
print("Примеры, где хотя бы что-то не совпало (макс. 80):")
for i, item in enumerate(incorrect_samples[:80]):
print(f"\nПример {i+1}:")
print(f"Текст: {item['text']}")
print(f"Таргет: {item['target']}")
print(f"Предсказание: {item['prediction']}")
# После этого выводим статистику
print("\nРезультаты:")
print(f"Всего примеров валидации: {num_examples}")
print(f"Совпало стран: {country_correct}")
print(f"Совпало городов: {city_correct}")
print(f"Полных совпадений (страна и город): {full_correct}")
print(f"Общее время выполнения скрипта: {total_time:.4f} сек.")
print(f"Время на одно предсказание: {time_per_example:.6f} сек.")
|