File size: 3,648 Bytes
077b816
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import matplotlib.pyplot as plt
import os
import numpy as np
from collections import defaultdict

def calculate_accuracy(label, predict):
    """Return the per-character accuracy of *predict* against *label*.

    Both strings are normalized first: "[PAD]" tokens are removed and
    "[EOS]" is replaced by "0".  (NOTE(review): mapping EOS to the digit
    "0" looks deliberate for this sudoku token scheme — confirm.)

    Args:
        label: ground-truth string.
        predict: predicted string; characters beyond the normalized label
            length are ignored because of zip truncation.

    Returns:
        Fraction of matching characters in [0.0, 1.0].  Returns 0.0 when
        the normalized label is empty (previously raised ZeroDivisionError).
    """
    label = label.replace("[PAD]", "").replace("[EOS]", "0")
    predict = predict.replace("[PAD]", "").replace("[EOS]", "0")
    total_chars = len(label)
    if total_chars == 0:
        # Guard: an all-padding label would otherwise divide by zero.
        return 0.0
    correct_chars = sum(1 for l, p in zip(label, predict) if l == p)
    return correct_chars / total_chars

def evaluate_jsonl(file_path):
    """Compute the mean per-character accuracy over a predictions JSONL file.

    Each line must be a JSON object with string fields "label" and
    "predict".  Pairs whose raw lengths differ are reported and skipped
    so they do not skew the mean.

    Args:
        file_path: path to a generated_predictions.jsonl file.

    Returns:
        The average accuracy as a float; 0.0 when no valid sample exists.

    Bug fix: the empty-input path previously returned the tuple ``(0, [])``
    while the normal path returned a bare float — callers assign the result
    straight into ``all_results.json``, so the tuple corrupted that file.
    """
    total_accuracy = 0.0
    sample_count = 0

    with open(file_path, 'r') as f:
        for line in f:
            data = json.loads(line)
            label = data['label']
            predict = data['predict']

            # Skip malformed pairs rather than averaging garbage.
            if len(label) != len(predict):
                print(f"警告: 第{sample_count+1}行长度不一致 (label:{len(label)} vs predict:{len(predict)})")
                continue

            total_accuracy += calculate_accuracy(label, predict)
            sample_count += 1

    if sample_count == 0:
        # Consistent scalar return type (was: `return 0, []`).
        return 0.0

    return total_accuracy / sample_count

def load_experiment_data(exp_path, dataset_name, T_list, acc_key="predict_acc"):
    """Collect the accuracy for each T value of one experiment/dataset pair.

    For each ``t`` in *T_list*, reads
    ``{exp_path}/{dataset_name}_T{t}/all_results.json``.  When *acc_key*
    is absent from that JSON, the accuracy is recomputed from the sibling
    ``generated_predictions.jsonl`` via ``evaluate_jsonl()`` and written
    back so future runs can skip the recomputation.

    Args:
        exp_path: experiment output directory.
        dataset_name: dataset sub-directory name prefix.
        T_list: iterable of T values to look up.
        acc_key: JSON key holding the accuracy; percentage strings such
            as "95.2%" are converted to floats.

    Returns:
        Tuple ``(results, missing_files)``: a dict mapping ``t`` to its
        accuracy, and the list of all_results.json paths that were absent.
    """
    missing_files = []
    results = {}

    for t in T_list:
        file_path = os.path.join(exp_path, f"{dataset_name}_T{t}")
        filename = os.path.join(file_path, "all_results.json")
        if not os.path.exists(filename):
            missing_files.append(filename)
            continue
        try:
            with open(filename, 'r') as f:
                data = json.load(f)
            if acc_key not in data:
                print(f"警告: (unknown) 中未找到键 '{acc_key}'")
                src_filename = os.path.join(file_path, "generated_predictions.jsonl")
                data[acc_key] = evaluate_jsonl(src_filename)
                # Cache the recomputed accuracy back into all_results.json.
                with open(filename, 'w', encoding='utf-8') as f:
                    json.dump(data, f, ensure_ascii=False, indent=4)

            acc = data[acc_key]
            # The key may be stored as a percentage string, e.g. "95.2%".
            if isinstance(acc, str) and acc.endswith('%'):
                acc = float(acc.strip('%')) / 100.0
            results[t] = acc
        except Exception as e:
            print(f"处理文件 (unknown) 时出错: {str(e)}")

    return results, missing_files


# experiments = [
#     "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250702-123750",
#     "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-021344"
# ]
# dataset_names = ["hard_test", "sudoku_test"]

# Experiment output directories whose all_results.json should be backfilled.
experiments = [
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-073900",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep100-20250703-075910",
    "output/sudoku/gpt2-model-bs1024-lr1e-3-ep300-20250618-082232"
]
# NOTE(review): all three entries are identical, so each experiment below is
# processed three times with the same dataset — confirm this is intentional.
dataset_names = ["sudoku_test", "sudoku_test", "sudoku_test"]


# JSON key expected to hold the per-cell accuracy in all_results.json.
acc_key = "cell_acc"

# Backfill: ensure every (experiment, dataset) run has `acc_key` recorded in
# its all_results.json, recomputing it from generated_predictions.jsonl when
# the key is missing.
for experiment in experiments:
    for dataset_name in dataset_names:
        file_path = os.path.join(experiment, dataset_name)
        filename = os.path.join(file_path, "all_results.json")
        # Skip runs with no results file, matching load_experiment_data()
        # instead of crashing the whole backfill on the first missing path.
        if not os.path.exists(filename):
            continue
        with open(filename, 'r') as f:
            data = json.load(f)
        if acc_key not in data:
            print(f"警告: (unknown) 中未找到键 '{acc_key}'")
            src_filename = os.path.join(file_path, "generated_predictions.jsonl")
            data[acc_key] = evaluate_jsonl(src_filename)
            # Bug fix: the original opened the same file for writing inside
            # the read `with` block and shadowed `f`; write only after the
            # read handle is closed.
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=4)