import time import torch from transformers import T5ForConditionalGeneration, AutoTokenizer from datasets import load_dataset # Путь к модели и данным model_path = "./unzipped_model" validation_file = "mt5_validation_data-1.jsonl" # Загрузка модели и токенизатора tokenizer = AutoTokenizer.from_pretrained(model_path) model = T5ForConditionalGeneration.from_pretrained(model_path) model.eval() # Используем GPU, если доступно device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) # Загрузка валидационной выборки dataset = load_dataset("json", data_files={"validation": validation_file}) val_data = dataset["validation"] # Функция предсказания def predict(text): inputs = tokenizer( text, return_tensors="pt", truncation=True, padding=True, max_length=256 ).to(device) outputs = model.generate( **inputs, max_length=64, num_beams=5, early_stopping=True ) return tokenizer.decode(outputs[0], skip_special_tokens=True) # Счётчики country_correct = 0 city_correct = 0 full_correct = 0 # Массив для хранения ошибочных предсказаний # (где хотя бы что-то не совпало между таргетом и предиктом) incorrect_samples = [] start_time = time.time() for example in val_data: text = example["text"] target = example["target"] # Получаем предсказание prediction = predict(text) # Разбиваем на "Страна:Город" - если нет двоеточия, ставим "unknown" для второй части pred_parts = prediction.split(":", 1) if len(pred_parts) == 2: pred_country, pred_city = pred_parts[0].strip(), pred_parts[1].strip() else: pred_country = pred_parts[0].strip() pred_city = "unknown" # Аналогично таргет target_parts = target.split(":", 1) if len(target_parts) == 2: true_country, true_city = target_parts[0].strip(), target_parts[1].strip() else: true_country = target_parts[0].strip() true_city = "unknown" # Сравниваем if pred_country == true_country: country_correct += 1 if pred_city == true_city: city_correct += 1 if (pred_country == true_country) and (pred_city == true_city): full_correct += 1 else: # Если хотя бы что-то не совпало, сохраним этот пример incorrect_samples.append({ "text": text, "target": f"{true_country}:{true_city}", "prediction": f"{pred_country}:{pred_city}" }) end_time = time.time() total_time = end_time - start_time num_examples = len(val_data) time_per_example = total_time / num_examples if num_examples > 0 else 0 # Выводим первые 80 «ошибочных» примеров (если их меньше, то все) print("Примеры, где хотя бы что-то не совпало (макс. 80):") for i, item in enumerate(incorrect_samples[:80]): print(f"\nПример {i+1}:") print(f"Текст: {item['text']}") print(f"Таргет: {item['target']}") print(f"Предсказание: {item['prediction']}") # После этого выводим статистику print("\nРезультаты:") print(f"Всего примеров валидации: {num_examples}") print(f"Совпало стран: {country_correct}") print(f"Совпало городов: {city_correct}") print(f"Полных совпадений (страна и город): {full_correct}") print(f"Общее время выполнения скрипта: {total_time:.4f} сек.") print(f"Время на одно предсказание: {time_per_example:.6f} сек.")