docker-default / test.py
ajkndfjsdfasdf's picture
Update test.py
5559201 verified
import time
import torch
from transformers import T5ForConditionalGeneration, AutoTokenizer
from datasets import load_dataset
# Путь к модели и данным
model_path = "./unzipped_model"
validation_file = "mt5_validation_data-1.jsonl"
# Загрузка модели и токенизатора
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
model.eval()
# Используем GPU, если доступно
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Загрузка валидационной выборки
dataset = load_dataset("json", data_files={"validation": validation_file})
val_data = dataset["validation"]
# Функция предсказания
def predict(text):
inputs = tokenizer(
text,
return_tensors="pt",
truncation=True,
padding=True,
max_length=256
).to(device)
outputs = model.generate(
**inputs,
max_length=64,
num_beams=5,
early_stopping=True
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Счётчики
country_correct = 0
city_correct = 0
full_correct = 0
# Массив для хранения ошибочных предсказаний
# (где хотя бы что-то не совпало между таргетом и предиктом)
incorrect_samples = []
start_time = time.time()
for example in val_data:
text = example["text"]
target = example["target"]
# Получаем предсказание
prediction = predict(text)
# Разбиваем на "Страна:Город" - если нет двоеточия, ставим "unknown" для второй части
pred_parts = prediction.split(":", 1)
if len(pred_parts) == 2:
pred_country, pred_city = pred_parts[0].strip(), pred_parts[1].strip()
else:
pred_country = pred_parts[0].strip()
pred_city = "unknown"
# Аналогично таргет
target_parts = target.split(":", 1)
if len(target_parts) == 2:
true_country, true_city = target_parts[0].strip(), target_parts[1].strip()
else:
true_country = target_parts[0].strip()
true_city = "unknown"
# Сравниваем
if pred_country == true_country:
country_correct += 1
if pred_city == true_city:
city_correct += 1
if (pred_country == true_country) and (pred_city == true_city):
full_correct += 1
else:
# Если хотя бы что-то не совпало, сохраним этот пример
incorrect_samples.append({
"text": text,
"target": f"{true_country}:{true_city}",
"prediction": f"{pred_country}:{pred_city}"
})
end_time = time.time()
total_time = end_time - start_time
num_examples = len(val_data)
time_per_example = total_time / num_examples if num_examples > 0 else 0
# Выводим первые 80 «ошибочных» примеров (если их меньше, то все)
print("Примеры, где хотя бы что-то не совпало (макс. 80):")
for i, item in enumerate(incorrect_samples[:80]):
print(f"\nПример {i+1}:")
print(f"Текст: {item['text']}")
print(f"Таргет: {item['target']}")
print(f"Предсказание: {item['prediction']}")
# После этого выводим статистику
print("\nРезультаты:")
print(f"Всего примеров валидации: {num_examples}")
print(f"Совпало стран: {country_correct}")
print(f"Совпало городов: {city_correct}")
print(f"Полных совпадений (страна и город): {full_correct}")
print(f"Общее время выполнения скрипта: {total_time:.4f} сек.")
print(f"Время на одно предсказание: {time_per_example:.6f} сек.")