File size: 4,029 Bytes
5559201
 
05c97e7
8900576
 
 
5559201
8900576
 
 
05c97e7
 
8900576
 
5559201
8900576
5559201
8900576
 
 
 
 
 
 
5559201
 
 
 
 
 
 
 
8900576
 
 
 
 
 
 
 
5559201
 
 
 
8900576
5559201
 
 
8900576
5559201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8900576
5559201
 
 
 
 
 
 
05c97e7
5559201
 
 
 
 
 
 
 
 
 
 
 
 
 
05c97e7
5559201
 
05c97e7
5559201
 
05c97e7
5559201
 
 
 
 
 
 
8900576
5559201
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import time
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer
from datasets import load_dataset

# Путь к модели и данным
model_path = "./unzipped_model"
validation_file = "mt5_validation_data-1.jsonl"

# Загрузка модели и токенизатора
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.eval()

# Используем GPU, если доступно
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Загрузка валидационной выборки
dataset = load_dataset("json", data_files={"validation": validation_file})
val_data = dataset["validation"]

# Функция предсказания
def predict(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    ).to(device)
    
    outputs = model.generate(
        **inputs,
        max_length=64,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Счётчики
country_correct = 0
city_correct = 0
full_correct = 0

# Массив для хранения ошибочных предсказаний
# (где хотя бы что-то не совпало между таргетом и предиктом)
incorrect_samples = []

start_time = time.time()

for example in val_data:
    text = example["text"]
    target = example["target"]
    
    # Получаем предсказание
    prediction = predict(text)
    
    # Разбиваем на "Страна:Город" - если нет двоеточия, ставим "unknown" для второй части
    pred_parts = prediction.split(":", 1)
    if len(pred_parts) == 2:
        pred_country, pred_city = pred_parts[0].strip(), pred_parts[1].strip()
    else:
        pred_country = pred_parts[0].strip()
        pred_city = "unknown"

    # Аналогично таргет
    target_parts = target.split(":", 1)
    if len(target_parts) == 2:
        true_country, true_city = target_parts[0].strip(), target_parts[1].strip()
    else:
        true_country = target_parts[0].strip()
        true_city = "unknown"

    # Сравниваем
    if pred_country == true_country:
        country_correct += 1
    if pred_city == true_city:
        city_correct += 1
    if (pred_country == true_country) and (pred_city == true_city):
        full_correct += 1
    else:
        # Если хотя бы что-то не совпало, сохраним этот пример
        incorrect_samples.append({
            "text": text,
            "target": f"{true_country}:{true_city}",
            "prediction": f"{pred_country}:{pred_city}"
        })

end_time = time.time()
total_time = end_time - start_time

num_examples = len(val_data)
time_per_example = total_time / num_examples if num_examples > 0 else 0

# Выводим первые 80 «ошибочных» примеров (если их меньше, то все)
print("Примеры, где хотя бы что-то не совпало (макс. 80):")
for i, item in enumerate(incorrect_samples[:80]):
    print(f"\nПример {i+1}:")
    print(f"Текст:        {item['text']}")
    print(f"Таргет:       {item['target']}")
    print(f"Предсказание: {item['prediction']}")

# После этого выводим статистику
print("\nРезультаты:")
print(f"Всего примеров валидации: {num_examples}")
print(f"Совпало стран: {country_correct}")
print(f"Совпало городов: {city_correct}")
print(f"Полных совпадений (страна и город): {full_correct}")
print(f"Общее время выполнения скрипта: {total_time:.4f} сек.")
print(f"Время на одно предсказание: {time_per_example:.6f} сек.")