"""Evaluate per-character prediction accuracy for JSONL prediction dumps."""

import json
import os
from collections import defaultdict

import numpy as np

try:
    # matplotlib is never used in this file's visible code; import it lazily
    # so the evaluation helpers still work on headless / minimal installs.
    import matplotlib.pyplot as plt
except ImportError:  # pragma: no cover
    plt = None


def calculate_accuracy(label, predict):
    """Return the fraction of characters of *label* matched by *predict*.

    Both strings are normalized first: "[PAD]" tokens are stripped and
    "[EOS]" tokens are mapped to "0".
    NOTE(review): "[EOS]" (5 chars) collapses to "0" (1 char), so
    normalization can change string lengths; `zip` silently truncates to the
    shorter string while the denominator is `len(label)` — confirm upstream
    guarantees equal lengths after normalization.

    Returns:
        float in [0, 1]; 0.0 when the normalized label is empty (the
        original raised ZeroDivisionError on an all-padding label).
    """
    label = label.replace("[PAD]", "").replace("[EOS]", "0")
    predict = predict.replace("[PAD]", "").replace("[EOS]", "0")
    total_chars = len(label)
    if total_chars == 0:
        # Guard against division by zero for labels that were pure padding.
        return 0.0
    correct_chars = sum(1 for l, p in zip(label, predict) if l == p)
    return correct_chars / total_chars


def evaluate_jsonl(file_path):
    """Average per-character accuracy over a predictions JSONL file.

    Each line must be a JSON object with string fields "label" and
    "predict". Lines whose raw lengths differ are skipped with a warning.

    Returns:
        float: mean accuracy over valid samples, or 0.0 if there are none.
        (Bug fix: the original returned a `(0, [])` tuple in the empty case
        while returning a scalar otherwise; callers json.dump the result.)
    """
    total_accuracy = 0.0
    sample_count = 0
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line)
            label = data['label']
            predict = data['predict']
            if len(label) != len(predict):
                print(f"警告: 第{sample_count+1}行长度不一致 (label:{len(label)} vs predict:{len(predict)})")
                continue
            accuracy = calculate_accuracy(label, predict)
            total_accuracy += accuracy
            sample_count += 1
    if sample_count == 0:
        return 0.0
    return total_accuracy / sample_count


def load_experiment_data(exp_path, dataset_name, T_list, acc_key="predict_acc"):
    """Collect accuracy-vs-T results for one experiment/dataset pair.

    For each t in *T_list*, reads <exp_path>/<dataset_name>_T<t>/all_results.json.
    If *acc_key* is missing there, the accuracy is recomputed from
    generated_predictions.jsonl and written back into all_results.json
    (cached for subsequent runs). Percent strings ("95%") are converted to
    floats.

    Returns:
        (results, missing_files): dict mapping t -> accuracy, plus the list
        of all_results.json paths that did not exist.
    """
    missing_files = []
    results = {}
    for t in T_list:
        file_path = os.path.join(exp_path, f"{dataset_name}_T{t}")
        filename = os.path.join(file_path, "all_results.json")
        if not os.path.exists(filename):
            missing_files.append(filename)
            continue
        try:
            with open(filename, 'r') as f:
                data = json.load(f)
            if acc_key not in data:
                # NOTE(review): "(unknown)" below looks like a lost {filename}
                # interpolation from whatever produced this file — confirm
                # against the original script before relying on the message.
                print(f"警告: (unknown) 中未找到键 '{acc_key}'")
                src_filename = os.path.join(file_path, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                # Cache the recomputed accuracy back into all_results.json.
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)
            acc = data[acc_key]
            if isinstance(acc, str) and acc.endswith('%'):
                acc = float(acc.strip('%')) / 100.0
            results[t] = acc
        except Exception as e:
            # Best-effort: report and skip unreadable/corrupt result dirs.
            print(f"处理文件 (unknown) 时出错: {str(e)}")
    return results, missing_files
# --- Script entry: backfill per-cell accuracy into each run's results. ---

experiments = [
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-073900",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-075910",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep300-20250618-082232",
]
# Bug fix: the original listed "sudoku_test" three times, which re-evaluated
# every experiment three times (and rewrote all_results.json three times).
dataset_names = ["sudoku_test"]
acc_key = "cell_acc"


def main():
    """For each experiment/dataset, ensure `acc_key` exists in all_results.json.

    If the key is missing, the accuracy is recomputed from
    generated_predictions.jsonl and written back (cached). Runs whose
    all_results.json does not exist are skipped with a warning instead of
    crashing the whole sweep.
    """
    for experiment in experiments:
        for dataset_name in dataset_names:
            file_path = os.path.join(experiment, dataset_name)
            filename = os.path.join(file_path, "all_results.json")
            if not os.path.exists(filename):
                # Mirror load_experiment_data's tolerance of missing runs.
                print(f"warning: missing {filename}, skipped")
                continue
            with open(filename, 'r') as f:
                data = json.load(f)
            if acc_key not in data:
                # NOTE(review): "(unknown)" looks like a lost {filename}
                # interpolation — confirm against the original script.
                print(f"警告: (unknown) 中未找到键 '{acc_key}'")
                src_filename = os.path.join(file_path, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
    # Guard so importing this module for its helpers does not kick off I/O.
    main()