| import json |
| import matplotlib.pyplot as plt |
| import os |
| import numpy as np |
| from collections import defaultdict |
|
|
def calculate_accuracy(label, predict):
    """Compute the per-character accuracy of a single sample.

    Both strings are cleaned first: every "[PAD]" token is removed and every
    "[EOS]" token is replaced by "0".  Accuracy is the number of matching
    positions (comparison stops at the shorter cleaned string, via zip)
    divided by the cleaned label length.

    Args:
        label: ground-truth string, may contain [PAD]/[EOS] tokens.
        predict: predicted string, may contain [PAD]/[EOS] tokens.

    Returns:
        float in [0, 1]; 0.0 when the cleaned label is empty (previously
        this raised ZeroDivisionError for an all-padding label).
    """
    label = label.replace("[PAD]", "").replace("[EOS]", "0")
    predict = predict.replace("[PAD]", "").replace("[EOS]", "0")
    total_chars = len(label)
    if total_chars == 0:
        # Guard: an empty cleaned label would otherwise divide by zero.
        return 0.0
    correct_chars = sum(1 for l, p in zip(label, predict) if l == p)
    return correct_chars / total_chars
| |
|
|
def evaluate_jsonl(file_path):
    """Average per-sample character accuracy over a JSONL predictions file.

    Each line must be a JSON object with string fields "label" and
    "predict".  Samples whose raw label/predict lengths differ are skipped
    with a warning and do not count toward the average.

    Args:
        file_path: path to a generated_predictions.jsonl file.

    Returns:
        float: mean accuracy over all usable samples; 0.0 when there are
        none.  (Bug fix: the empty case previously returned a ``(0, [])``
        tuple while the normal path returned a float — callers store the
        result as a scalar in all_results.json.)
    """
    total_accuracy = 0.0
    sample_count = 0

    with open(file_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            label = data['label']
            predict = data['predict']

            # Raw-length mismatch: the prediction is unusable, skip it.
            if len(label) != len(predict):
                print(f"警告: 第{sample_count+1}行长度不一致 (label:{len(label)} vs predict:{len(predict)})")
                continue

            total_accuracy += calculate_accuracy(label, predict)
            sample_count += 1

    if sample_count == 0:
        return 0.0

    return total_accuracy / sample_count
|
|
def load_experiment_data(exp_path, dataset_name, T_list, acc_key="predict_acc"):
    """Collect accuracy values for one experiment across several T settings.

    For each ``t`` in ``T_list`` this reads
    ``<exp_path>/<dataset_name>_T<t>/all_results.json``.  When ``acc_key``
    is absent there, the accuracy is recomputed from
    ``generated_predictions.jsonl`` via ``evaluate_jsonl()`` and written
    back into ``all_results.json``.  Percentage strings such as "97.5%"
    are converted to floats.

    Args:
        exp_path: experiment output directory.
        dataset_name: dataset prefix of the per-T result folders.
        T_list: iterable of T values to look up.
        acc_key: key holding the accuracy inside all_results.json.

    Returns:
        (results, missing_files): dict mapping t -> accuracy, and the list
        of all_results.json paths that did not exist.
    """
    missing_files = []
    results = {}

    for t in T_list:
        run_dir = os.path.join(exp_path, f"{dataset_name}_T{t}")
        filename = os.path.join(run_dir, "all_results.json")
        if not os.path.exists(filename):
            missing_files.append(filename)
            continue
        try:
            # Read first and close the handle before a possible rewrite below
            # (the original re-opened the same file for writing while the
            # read handle was still open, shadowing `f`).
            with open(filename, 'r') as f:
                data = json.load(f)
            if acc_key not in data:
                # Recompute the accuracy and persist it so later runs can
                # read it directly.  (Message fix: print the actual path.)
                print(f"警告: {filename} 中未找到键 '{acc_key}'")
                src_filename = os.path.join(run_dir, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)

            acc = data[acc_key]
            # Older result files may store the accuracy as a "97.5%" string.
            if isinstance(acc, str) and acc.endswith('%'):
                acc = float(acc.strip('%')) / 100.0
            results[t] = acc
        except Exception as e:
            print(f"处理文件 {filename} 时出错: {str(e)}")

    return results, missing_files
|
|
|
|
| |
| |
| |
| |
| |
|
|
# ---------------------------------------------------------------------------
# Script entry: backfill the cell-level accuracy key into each experiment's
# all_results.json, recomputing it from generated_predictions.jsonl when
# the key is missing.
# ---------------------------------------------------------------------------
experiments = [
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-073900",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-075910",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep300-20250618-082232",
]
dataset_names = ["sudoku_test", "sudoku_test", "sudoku_test"]

acc_key = "cell_acc"

for experiment in experiments:
    # De-duplicate while preserving order: the list above repeats the same
    # dataset, which made every experiment be processed three times.
    for dataset_name in dict.fromkeys(dataset_names):
        file_path = os.path.join(experiment, dataset_name)
        filename = os.path.join(file_path, "all_results.json")
        # Skip missing result files instead of crashing (consistent with
        # load_experiment_data's handling above).
        if not os.path.exists(filename):
            print(f"警告: 未找到文件 {filename}")
            continue
        with open(filename, 'r') as f:
            data = json.load(f)
        if acc_key not in data:
            # Message fix: print the actual path instead of "(unknown)".
            print(f"警告: {filename} 中未找到键 '{acc_key}'")
            src_filename = os.path.join(file_path, "generated_predictions.jsonl")
            data[acc_key] = evaluate_jsonl(src_filename)
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)
|
|
|
|