ajkndfjsdfasdf commited on
Commit
05c97e7
·
verified ·
1 Parent(s): d47a7fe

Update test.py

Browse files
Files changed (1) hide show
  1. test.py +40 -13
test.py CHANGED
@@ -1,14 +1,14 @@
1
- from transformers import MT5ForConditionalGeneration, MT5Tokenizer
2
  from datasets import load_dataset
3
  import torch
4
 
5
  # Путь к модели и данным
6
- model_path = "./mt5-finetuned"
7
  validation_file = "mt5_validation_data-1.jsonl"
8
 
9
  # Загрузка модели и токенизатора
10
- tokenizer = MT5Tokenizer.from_pretrained(model_path)
11
- model = MT5ForConditionalGeneration.from_pretrained(model_path)
12
  model.eval()
13
 
14
  # Используем GPU если есть
@@ -32,6 +32,8 @@ def predict(text):
32
 
33
  # Подсчёт точности
34
  correct = 0
 
 
35
  results = []
36
 
37
  for idx, example in enumerate(val_data):
@@ -39,19 +41,44 @@ for idx, example in enumerate(val_data):
39
  target = example["target"].strip()
40
  pred = predict(text).strip()
41
 
42
- results.append((text, pred, target))
43
  if pred == target:
44
  correct += 1
45
 
46
- # Примеры
47
- print("📋 Примеры предсказаний:\n")
48
- for i, (text, pred, target) in enumerate(results[:80]): # кол-во примеров
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  print(f"#{i+1}")
50
- print(f"📝 Вход: {text}")
51
- print(f"✅ Target: {target}")
52
- print(f"🤖 Предсказание: {pred}")
53
  print("-" * 50)
54
 
55
  # Accuracy
56
- accuracy = correct / len(val_data)
57
- print(f"\n✅ Accuracy: {accuracy:.4f} ({correct}/{len(val_data)})\n")
 
 
 
1
+ from transformers import T5ForConditionalGeneration, AutoTokenizer
2
  from datasets import load_dataset
3
  import torch
4
 
5
  # Путь к модели и данным
6
+ model_path = "./flan-t5-autobatch"
7
  validation_file = "mt5_validation_data-1.jsonl"
8
 
9
  # Загрузка модели и токенизатора
10
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
11
+ model = T5ForConditionalGeneration.from_pretrained(model_path)
12
  model.eval()
13
 
14
  # Используем GPU если есть
 
32
 
33
  # Подсчёт точности
34
  correct = 0
35
+ correct_country = 0
36
+ correct_city = 0
37
  results = []
38
 
39
  for idx, example in enumerate(val_data):
 
41
  target = example["target"].strip()
42
  pred = predict(text).strip()
43
 
44
+ # Полное совпадение
45
  if pred == target:
46
  correct += 1
47
 
48
+ # Приведение к нижнему регистру и очистка пробелов
49
+ target_parts = [x.strip().lower() for x in target.split(",")]
50
+ pred_parts = [x.strip().lower() for x in pred.split(",")]
51
+
52
+ # Сравнение страны (первый элемент)
53
+ if target_parts and pred_parts and target_parts[0] == pred_parts[0]:
54
+ correct_country += 1
55
+
56
+ # Сравнение города (только если в target есть второй элемент)
57
+ if len(target_parts) > 1:
58
+ target_city = target_parts[-1]
59
+ pred_city = pred_parts[-1] if len(pred_parts) > 1 else ""
60
+ if target_city == pred_city:
61
+ correct_city += 1
62
+
63
+ results.append({
64
+ "text": text,
65
+ "pred": pred,
66
+ "target": target,
67
+ "match": pred == target
68
+ })
69
+
70
+ # Вывод до 80 первых НЕСОВПАВШИХ
71
+ mismatched = [r for r in results if not r["match"]]
72
+ print("📋 Несовпавшие предсказания (до 80):\n")
73
+ for i, r in enumerate(mismatched[:80]):
74
  print(f"#{i+1}")
75
+ print(f"📝 Вход: {r['text']}")
76
+ print(f"✅ Target: {r['target']}")
77
+ print(f"🤖 Предсказание: {r['pred']}")
78
  print("-" * 50)
79
 
80
  # Accuracy
81
+ total = len(val_data)
82
+ print(f"\n✅ Accuracy (полное совпадение): {correct / total:.4f} ({correct}/{total})")
83
+ print(f"🌍 Accuracy по странам: {correct_country / total:.4f} ({correct_country}/{total})")
84
+ print(f"🏙️ Accuracy по городам: {correct_city / total:.4f} ({correct_city}/{total})\n")