ajkndfjsdfasdf committed on
Commit
8900576
·
verified ·
1 Parent(s): f3f4b7c

Create test.py

Browse files
Files changed (1) hide show
  1. test.py +57 -0
test.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Evaluate a fine-tuned mT5 checkpoint on a JSONL validation set: generate a
# prediction for every example, print a sample of the predictions, and report
# exact-match accuracy (prediction == target after stripping whitespace).
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
from datasets import load_dataset
import torch

# Paths to the model and the data
model_path = "./mt5-finetuned"
validation_file = "mt5_validation_data-1.jsonl"

# Maximum number of predictions printed before the final accuracy line
num_examples_to_show = 80

# Load the model and the tokenizer
tokenizer = MT5Tokenizer.from_pretrained(model_path)
model = MT5ForConditionalGeneration.from_pretrained(model_path)
model.eval()

# Use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Load the validation split
dataset = load_dataset("json", data_files={"validation": validation_file})
val_data = dataset["validation"]


def predict(text):
    """Return the model's decoded output for one input string.

    The input is tokenized with truncation at ``max_length=256`` and moved to
    ``device``; generation uses beam search (``num_beams=5``,
    ``early_stopping=True``) with ``max_length=64``. Only the first returned
    sequence is decoded, with special tokens removed.

    Args:
        text: The input passed to the tokenizer.

    Returns:
        The decoded prediction as a string (not stripped).
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=256,
    ).to(device)
    outputs = model.generate(
        **inputs,
        max_length=64,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Accuracy computation
correct = 0
results = []

for example in val_data:
    text = example["text"]
    target = example["target"].strip()
    pred = predict(text).strip()

    results.append((text, pred, target))
    if pred == target:
        correct += 1

# Examples
print("📋 Примеры предсказаний:\n")
for i, (text, pred, target) in enumerate(results[:num_examples_to_show]):
    print(f"#{i+1}")
    print(f"📝 Вход: {text}")
    print(f"✅ Target: {target}")
    print(f"🤖 Предсказание: {pred}")
    print("-" * 50)

# Accuracy
total = len(val_data)
# An empty validation set would otherwise raise ZeroDivisionError here,
# after all the model-loading work has already been done.
accuracy = correct / total if total else 0.0
print(f"\n✅ Accuracy: {accuracy:.4f} ({correct}/{total})\n")