ajkndfjsdfasdf commited on
Commit
5559201
·
verified ·
1 Parent(s): 05c97e7

Update test.py

Browse files
Files changed (1) hide show
  1. test.py +76 -49
test.py CHANGED
@@ -1,9 +1,10 @@
 
 
1
  from transformers import T5ForConditionalGeneration, AutoTokenizer
2
  from datasets import load_dataset
3
- import torch
4
 
5
  # Путь к модели и данным
6
- model_path = "./flan-t5-autobatch"
7
  validation_file = "mt5_validation_data-1.jsonl"
8
 
9
  # Загрузка модели и токенизатора
@@ -11,9 +12,9 @@ tokenizer = AutoTokenizer.from_pretrained(model_path)
11
  model = T5ForConditionalGeneration.from_pretrained(model_path)
12
  model.eval()
13
 
14
- # Используем GPU если есть
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- model = model.to(device)
17
 
18
  # Загрузка валидационной выборки
19
  dataset = load_dataset("json", data_files={"validation": validation_file})
@@ -21,7 +22,14 @@ val_data = dataset["validation"]
21
 
22
  # Функция предсказания
23
  def predict(text):
24
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=256).to(device)
 
 
 
 
 
 
 
25
  outputs = model.generate(
26
  **inputs,
27
  max_length=64,
@@ -30,55 +38,74 @@ def predict(text):
30
  )
31
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
32
 
33
- # Подсчёт точности
34
- correct = 0
35
- correct_country = 0
36
- correct_city = 0
37
- results = []
38
 
39
- for idx, example in enumerate(val_data):
40
- text = example["text"]
41
- target = example["target"].strip()
42
- pred = predict(text).strip()
43
 
44
- # Полное совпадение
45
- if pred == target:
46
- correct += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Приведение к нижнему регистру и очистка пробелов
49
- target_parts = [x.strip().lower() for x in target.split(",")]
50
- pred_parts = [x.strip().lower() for x in pred.split(",")]
 
 
 
 
51
 
52
- # Сравнение страны (первый элемент)
53
- if target_parts and pred_parts and target_parts[0] == pred_parts[0]:
54
- correct_country += 1
 
 
 
 
 
 
 
 
 
 
 
55
 
56
- # Сравнение города (только если в target есть второй элемент)
57
- if len(target_parts) > 1:
58
- target_city = target_parts[-1]
59
- pred_city = pred_parts[-1] if len(pred_parts) > 1 else ""
60
- if target_city == pred_city:
61
- correct_city += 1
62
 
63
- results.append({
64
- "text": text,
65
- "pred": pred,
66
- "target": target,
67
- "match": pred == target
68
- })
69
 
70
- # Вывод до 80 первых НЕСОВПАВШИХ
71
- mismatched = [r for r in results if not r["match"]]
72
- print("📋 Несовпавшие предсказания (до 80):\n")
73
- for i, r in enumerate(mismatched[:80]):
74
- print(f"#{i+1}")
75
- print(f"📝 Вход: {r['text']}")
76
- print(f" Target: {r['target']}")
77
- print(f"🤖 Предсказание: {r['pred']}")
78
- print("-" * 50)
79
 
80
- # Accuracy
81
- total = len(val_data)
82
- print(f"\n✅ Accuracy (полное совпадение): {correct / total:.4f} ({correct}/{total})")
83
- print(f"🌍 Accuracy по странам: {correct_country / total:.4f} ({correct_country}/{total})")
84
- print(f"🏙️ Accuracy по городам: {correct_city / total:.4f} ({correct_city}/{total})\n")
 
 
 
 
1
+ import time
2
+ import torch
3
  from transformers import T5ForConditionalGeneration, AutoTokenizer
4
  from datasets import load_dataset
 
5
 
6
  # Путь к модели и данным
7
+ model_path = "./unzipped_model"
8
  validation_file = "mt5_validation_data-1.jsonl"
9
 
10
  # Загрузка модели и токенизатора
 
12
  model = T5ForConditionalGeneration.from_pretrained(model_path)
13
  model.eval()
14
 
15
+ # Используем GPU, если доступно
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model.to(device)
18
 
19
  # Загрузка валидационной выборки
20
  dataset = load_dataset("json", data_files={"validation": validation_file})
 
22
 
23
  # Функция предсказания
24
  def predict(text):
25
+ inputs = tokenizer(
26
+ text,
27
+ return_tensors="pt",
28
+ truncation=True,
29
+ padding=True,
30
+ max_length=256
31
+ ).to(device)
32
+
33
  outputs = model.generate(
34
  **inputs,
35
  max_length=64,
 
38
  )
39
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
40
 
41
+ # Счётчики
42
+ country_correct = 0
43
+ city_correct = 0
44
+ full_correct = 0
 
45
 
46
+ # Массив для хранения ошибочных предсказаний
47
+ # (где хотя бы что-то не совпало между таргетом и предиктом)
48
+ incorrect_samples = []
 
49
 
50
+ start_time = time.time()
51
+
52
+ for example in val_data:
53
+ text = example["text"]
54
+ target = example["target"]
55
+
56
+ # Получаем предсказание
57
+ prediction = predict(text)
58
+
59
+ # Разбиваем на "Страна:Город" - если нет двоеточия, ставим "unknown" для второй части
60
+ pred_parts = prediction.split(":", 1)
61
+ if len(pred_parts) == 2:
62
+ pred_country, pred_city = pred_parts[0].strip(), pred_parts[1].strip()
63
+ else:
64
+ pred_country = pred_parts[0].strip()
65
+ pred_city = "unknown"
66
 
67
+ # Аналогично таргет
68
+ target_parts = target.split(":", 1)
69
+ if len(target_parts) == 2:
70
+ true_country, true_city = target_parts[0].strip(), target_parts[1].strip()
71
+ else:
72
+ true_country = target_parts[0].strip()
73
+ true_city = "unknown"
74
 
75
+ # Сравниваем
76
+ if pred_country == true_country:
77
+ country_correct += 1
78
+ if pred_city == true_city:
79
+ city_correct += 1
80
+ if (pred_country == true_country) and (pred_city == true_city):
81
+ full_correct += 1
82
+ else:
83
+ # Если хотя бы что-то не совпало, сохраним этот пример
84
+ incorrect_samples.append({
85
+ "text": text,
86
+ "target": f"{true_country}:{true_city}",
87
+ "prediction": f"{pred_country}:{pred_city}"
88
+ })
89
 
90
+ end_time = time.time()
91
+ total_time = end_time - start_time
 
 
 
 
92
 
93
+ num_examples = len(val_data)
94
+ time_per_example = total_time / num_examples if num_examples > 0 else 0
 
 
 
 
95
 
96
+ # Выводим первые 80 «ошибочных» примеров (если их меньше, то все)
97
+ print("Приме��ы, где хотя бы что-то не совпало (макс. 80):")
98
+ for i, item in enumerate(incorrect_samples[:80]):
99
+ print(f"\nПример {i+1}:")
100
+ print(f"Текст: {item['text']}")
101
+ print(f"Таргет: {item['target']}")
102
+ print(f"Предсказание: {item['prediction']}")
 
 
103
 
104
+ # После этого выводим статистику
105
+ print("\nРезультаты:")
106
+ print(f"Всего примеров валидации: {num_examples}")
107
+ print(f"Совпало стран: {country_correct}")
108
+ print(f"Совпало городов: {city_correct}")
109
+ print(f"Полных совпадений (страна и город): {full_correct}")
110
+ print(f"Общее время выполнения скрипта: {total_time:.4f} сек.")
111
+ print(f"Время на одно предсказание: {time_per_example:.6f} сек.")