ajkndfjsdfasdf commited on
Commit
855a350
·
verified ·
1 Parent(s): 5559201

Create compare.py

Browse files
Files changed (1) hide show
  1. compare.py +105 -0
compare.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration, AutoTokenizer, MT5Tokenizer
4
+ from datasets import load_dataset
5
+
6
+ # Пути к моделям
7
+ mt5_path = "./mt5" # Локальная MT5 модель
8
+ byt5_path = "./unzipped_model" # Локальная или скачанная ByT5 модель
9
+
10
+ # Путь к данным
11
+ validation_file = "mt5_validation_data-1.jsonl"
12
+
13
+ # Загрузка моделей и токенизаторов
14
+ mt5_tokenizer = MT5Tokenizer.from_pretrained(mt5_path)
15
+ mt5_model = MT5ForConditionalGeneration.from_pretrained(mt5_path).eval()
16
+
17
+ byt5_tokenizer = AutoTokenizer.from_pretrained(byt5_path)
18
+ byt5_model = T5ForConditionalGeneration.from_pretrained(byt5_path).eval()
19
+
20
+ # Выбор устройства
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ mt5_model.to(device)
23
+ byt5_model.to(device)
24
+
25
+ # Загрузка валидационной выборки
26
+ dataset = load_dataset("json", data_files={"validation": validation_file})
27
+ val_data = dataset["validation"]
28
+
29
+ # Функция предсказания
30
+ def predict(model, tokenizer, text):
31
+ inputs = tokenizer(
32
+ text,
33
+ return_tensors="pt",
34
+ truncation=True,
35
+ padding=True,
36
+ max_length=256
37
+ ).to(device)
38
+
39
+ outputs = model.generate(
40
+ **inputs,
41
+ max_length=64,
42
+ num_beams=5,
43
+ early_stopping=True
44
+ )
45
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
46
+
47
+ # Статистика
48
+ country_match = 0
49
+ city_match = 0
50
+ full_match = 0
51
+ mismatch_samples = []
52
+
53
+ start_time = time.time()
54
+
55
+ for example in val_data:
56
+ text = example["text"]
57
+
58
+ # Предсказания от MT5 и ByT5
59
+ mt5_pred = predict(mt5_model, mt5_tokenizer, text)
60
+ byt5_pred = predict(byt5_model, byt5_tokenizer, text)
61
+
62
+ def split_prediction(pred):
63
+ parts = pred.split(":", 1)
64
+ if len(parts) == 2:
65
+ return parts[0].strip(), parts[1].strip()
66
+ else:
67
+ return parts[0].strip(), "unknown"
68
+
69
+ mt5_country, mt5_city = split_prediction(mt5_pred)
70
+ byt5_country, byt5_city = split_prediction(byt5_pred)
71
+
72
+ if mt5_country == byt5_country:
73
+ country_match += 1
74
+ if mt5_city == byt5_city:
75
+ city_match += 1
76
+ if mt5_country == byt5_country and mt5_city == byt5_city:
77
+ full_match += 1
78
+ else:
79
+ mismatch_samples.append({
80
+ "text": text,
81
+ "mt5_prediction": f"{mt5_country}:{mt5_city}",
82
+ "byt5_prediction": f"{byt5_country}:{byt5_city}"
83
+ })
84
+
85
+ end_time = time.time()
86
+ total_time = end_time - start_time
87
+ num_examples = len(val_data)
88
+ time_per_example = total_time / num_examples if num_examples > 0 else 0
89
+
90
+ # Вывод различий
91
+ print("Примеры, где хотя бы что-то не совпало между MT5 и ByT5 (макс. 80):")
92
+ for i, item in enumerate(mismatch_samples[:80]):
93
+ print(f"\nПример {i+1}:")
94
+ print(f"Текст: {item['text']}")
95
+ print(f"MT5 предсказал: {item['mt5_prediction']}")
96
+ print(f"ByT5 предсказал: {item['byt5_prediction']}")
97
+
98
+ # Итоги
99
+ print("\nРезультаты сравнения MT5 vs ByT5:")
100
+ print(f"Всего примеров: {num_examples}")
101
+ print(f"Совпало стран: {country_match}")
102
+ print(f"Совпало городов: {city_match}")
103
+ print(f"Полных совпадений: {full_match}")
104
+ print(f"Общее время выполнения: {total_time:.4f} сек.")
105
+ print(f"Время на одно сравнение: {time_per_example:.6f} сек.")