File size: 3,625 Bytes
855a350
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import time
import torch
from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration, AutoTokenizer, MT5Tokenizer
from datasets import load_dataset

# Пути к моделям
mt5_path = "./mt5"              # Локальная MT5 модель
byt5_path = "./unzipped_model"            # Локальная или скачанная ByT5 модель

# Путь к данным
validation_file = "mt5_validation_data-1.jsonl"

# Загрузка моделей и токенизаторов
mt5_tokenizer = MT5Tokenizer.from_pretrained(mt5_path)
mt5_model = MT5ForConditionalGeneration.from_pretrained(mt5_path).eval()

byt5_tokenizer = AutoTokenizer.from_pretrained(byt5_path)
byt5_model = T5ForConditionalGeneration.from_pretrained(byt5_path).eval()

# Выбор устройства
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mt5_model.to(device)
byt5_model.to(device)

# Загрузка валидационной выборки
dataset = load_dataset("json", data_files={"validation": validation_file})
val_data = dataset["validation"]

# Функция предсказания
def predict(model, tokenizer, text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256
    ).to(device)

    outputs = model.generate(
        **inputs,
        max_length=64,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Статистика
country_match = 0
city_match = 0
full_match = 0
mismatch_samples = []

start_time = time.time()

for example in val_data:
    text = example["text"]

    # Предсказания от MT5 и ByT5
    mt5_pred = predict(mt5_model, mt5_tokenizer, text)
    byt5_pred = predict(byt5_model, byt5_tokenizer, text)

    def split_prediction(pred):
        parts = pred.split(":", 1)
        if len(parts) == 2:
            return parts[0].strip(), parts[1].strip()
        else:
            return parts[0].strip(), "unknown"

    mt5_country, mt5_city = split_prediction(mt5_pred)
    byt5_country, byt5_city = split_prediction(byt5_pred)

    if mt5_country == byt5_country:
        country_match += 1
    if mt5_city == byt5_city:
        city_match += 1
    if mt5_country == byt5_country and mt5_city == byt5_city:
        full_match += 1
    else:
        mismatch_samples.append({
            "text": text,
            "mt5_prediction": f"{mt5_country}:{mt5_city}",
            "byt5_prediction": f"{byt5_country}:{byt5_city}"
        })

end_time = time.time()
total_time = end_time - start_time
num_examples = len(val_data)
time_per_example = total_time / num_examples if num_examples > 0 else 0

# Вывод различий
print("Примеры, где хотя бы что-то не совпало между MT5 и ByT5 (макс. 80):")
for i, item in enumerate(mismatch_samples[:80]):
    print(f"\nПример {i+1}:")
    print(f"Текст:         {item['text']}")
    print(f"MT5 предсказал:  {item['mt5_prediction']}")
    print(f"ByT5 предсказал: {item['byt5_prediction']}")

# Итоги
print("\nРезультаты сравнения MT5 vs ByT5:")
print(f"Всего примеров: {num_examples}")
print(f"Совпало стран: {country_match}")
print(f"Совпало городов: {city_match}")
print(f"Полных совпадений: {full_match}")
print(f"Общее время выполнения: {total_time:.4f} сек.")
print(f"Время на одно сравнение: {time_per_example:.6f} сек.")