Update test.py
Browse files
test.py
CHANGED
|
@@ -1,14 +1,14 @@
|
|
| 1 |
-
from transformers import
|
| 2 |
from datasets import load_dataset
|
| 3 |
import torch
|
| 4 |
|
| 5 |
# Путь к модели и данным
|
| 6 |
-
model_path = "./
|
| 7 |
validation_file = "mt5_validation_data-1.jsonl"
|
| 8 |
|
| 9 |
# Загрузка модели и токенизатора
|
| 10 |
-
tokenizer =
|
| 11 |
-
model =
|
| 12 |
model.eval()
|
| 13 |
|
| 14 |
# Используем GPU если есть
|
|
@@ -32,6 +32,8 @@ def predict(text):
|
|
| 32 |
|
| 33 |
# Подсчёт точности
|
| 34 |
correct = 0
|
|
|
|
|
|
|
| 35 |
results = []
|
| 36 |
|
| 37 |
for idx, example in enumerate(val_data):
|
|
@@ -39,19 +41,44 @@ for idx, example in enumerate(val_data):
|
|
| 39 |
target = example["target"].strip()
|
| 40 |
pred = predict(text).strip()
|
| 41 |
|
| 42 |
-
|
| 43 |
if pred == target:
|
| 44 |
correct += 1
|
| 45 |
|
| 46 |
-
#
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
print(f"#{i+1}")
|
| 50 |
-
print(f"📝 Вход: {text}")
|
| 51 |
-
print(f"✅ Target: {target}")
|
| 52 |
-
print(f"🤖 Предсказание: {pred}")
|
| 53 |
print("-" * 50)
|
| 54 |
|
| 55 |
# Accuracy
|
| 56 |
-
|
| 57 |
-
print(f"\n✅ Accuracy: {
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import T5ForConditionalGeneration, AutoTokenizer
|
| 2 |
from datasets import load_dataset
|
| 3 |
import torch
|
| 4 |
|
| 5 |
# Path to the model directory and name of the validation data file
|
| 6 |
+
model_path = "./flan-t5-autobatch"
|
| 7 |
validation_file = "mt5_validation_data-1.jsonl"
|
| 8 |
|
| 9 |
# Load the tokenizer and the model from the same model_path
|
| 10 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 11 |
+
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
| 12 |
model.eval()
|
| 13 |
|
| 14 |
# Use the GPU if one is available
|
|
|
|
| 32 |
|
| 33 |
# Accuracy counters, updated once per validation example by the loop below
|
| 34 |
correct = 0  # predictions that equal the target exactly
|
| 35 |
+
correct_country = 0  # first comma-separated field matches, case-insensitively
|
| 36 |
+
correct_city = 0  # last field matches; only counted when the target has more than one field
|
| 37 |
results = []  # one dict per example with keys "text", "pred", "target", "match"
|
| 38 |
|
| 39 |
for idx, example in enumerate(val_data):
|
|
|
|
| 41 |
target = example["target"].strip()
|
| 42 |
pred = predict(text).strip()
|
| 43 |
|
| 44 |
+
# Полное совпадение
|
| 45 |
if pred == target:
|
| 46 |
correct += 1
|
| 47 |
|
| 48 |
+
# Приведение к нижнему регистру и очистка пробелов
|
| 49 |
+
target_parts = [x.strip().lower() for x in target.split(",")]
|
| 50 |
+
pred_parts = [x.strip().lower() for x in pred.split(",")]
|
| 51 |
+
|
| 52 |
+
# Сравнение страны (первый элемент)
|
| 53 |
+
if target_parts and pred_parts and target_parts[0] == pred_parts[0]:
|
| 54 |
+
correct_country += 1
|
| 55 |
+
|
| 56 |
+
# Сравнение города (только если в target есть второй элемент)
|
| 57 |
+
if len(target_parts) > 1:
|
| 58 |
+
target_city = target_parts[-1]
|
| 59 |
+
pred_city = pred_parts[-1] if len(pred_parts) > 1 else ""
|
| 60 |
+
if target_city == pred_city:
|
| 61 |
+
correct_city += 1
|
| 62 |
+
|
| 63 |
+
results.append({
|
| 64 |
+
"text": text,
|
| 65 |
+
"pred": pred,
|
| 66 |
+
"target": target,
|
| 67 |
+
"match": pred == target
|
| 68 |
+
})
|
| 69 |
+
|
| 70 |
+
# Print at most the first 80 examples whose prediction differs from the
# target, numbered from 1 and separated by a 50-character dashed rule.
mismatched = [rec for rec in results if not rec["match"]]
print("📋 Несовпавшие предсказания (до 80):\n")
separator = "-" * 50
for number, rec in enumerate(mismatched[:80], start=1):
    print(f"#{number}")
    print(f"📝 Вход: {rec['text']}")
    print(f"✅ Target: {rec['target']}")
    print(f"🤖 Предсказание: {rec['pred']}")
    print(separator)
|
| 79 |
|
| 80 |
# Accuracy summary.
#
# `total` is the number of validation examples. Ratios are computed against
# max(denominator, 1) so that an empty validation set prints 0.0000 instead
# of raising ZeroDivisionError; the printed "(x/y)" counts stay unmodified.
total = len(val_data)
safe_total = max(total, 1)
print(f"\n✅ Accuracy (полное совпадение): {correct / safe_total:.4f} ({correct}/{total})")
print(f"🌍 Accuracy по странам: {correct_country / safe_total:.4f} ({correct_country}/{total})")

# The evaluation loop only compares cities when the target has more than one
# comma-separated field, so examples without a city can never be counted as
# correct. Use the number of targets that actually contain a city as the
# denominator; dividing by `total` would systematically deflate the metric.
city_total = sum(1 for rec in results if "," in rec["target"])
print(f"🏙️ Accuracy по городам: {correct_city / max(city_total, 1):.4f} ({correct_city}/{city_total})\n")
|