File size: 2,488 Bytes
661c54a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import glob
import json
import re
import matplotlib.pyplot as plt

# ===== Root evaluation directory: scanned for valid_score_*.json, and where valid_accuracy.png is written =====
BASE_DIR: str = "/pfs/lichenyi/work/evaluation"

def collect_accuracies(base_dir: str):
    """
    Read ``summary.accuracy`` from the evaluation result files in *base_dir*.

    Only files named exactly ``valid_score_in_<step>.json`` (in-domain) or
    ``valid_score_ood_<step>.json`` (out-of-domain) are used; any other
    ``valid_score_*.json`` file is ignored. The scan is not recursive.

    Args:
        base_dir: Directory to scan. Glob metacharacters in it are treated literally.

    Returns:
        A tuple ``(in_acc, ood_acc)`` of dicts mapping ``step -> accuracy``.
        A file is skipped (with a printed warning if it cannot be read/parsed)
        when its JSON is invalid, its top level or ``summary`` is not an
        object, or it has no ``accuracy`` value.
    """
    # Escape the directory so names like "run[1]" are not interpreted as glob patterns.
    pattern = os.path.join(glob.escape(base_dir), "valid_score_*.json")

    in_acc = {}
    ood_acc = {}
    targets = {"in": in_acc, "ood": ood_acc}

    # Matches e.g. valid_score_in_100.json / valid_score_ood_100.json.
    # fullmatch (not match) so that glob hits such as
    # "valid_score_in_100.json.bak.json" are not mistaken for step 100.
    regex = re.compile(r"valid_score_(in|ood)_(\d+)\.json")

    for path in sorted(glob.glob(pattern)):
        m = regex.fullmatch(os.path.basename(path))
        if m is None:
            continue

        split = m.group(1)  # 'in' or 'ood'
        step = int(m.group(2))

        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError; a
            # half-written or unreadable result file should not abort the run.
            print(f"warning: skipping unreadable file {path}: {err}")
            continue

        summary = data.get("summary") if isinstance(data, dict) else None
        if not isinstance(summary, dict):
            continue

        acc = summary.get("accuracy")
        if acc is None:
            continue

        targets[split][step] = acc

    return in_acc, ood_acc


def plot_accuracies(in_acc, ood_acc, out_path="valid_accuracy.png", show=False):
    """
    Plot in-domain and out-of-domain accuracy against checkpoint step and save it.

    Args:
        in_acc: dict[int, float] mapping step -> accuracy for the in-domain split.
            May be empty, in which case that series is not drawn.
        ood_acc: dict[int, float] mapping step -> accuracy for the out-of-domain
            split. May be empty, in which case that series is not drawn.
        out_path: Path the figure is saved to (at 300 dpi).
        show: If True, also display the figure interactively before it is closed.

    The figure is always closed afterwards so repeated calls do not accumulate
    open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        # (data, line style) per split; drawn in this order so default colors
        # are assigned the same way as before.
        series = (
            (in_acc, {"marker": "o", "label": "in (ID)"}),
            (ood_acc, {"marker": "s", "linestyle": "--", "label": "ood (OOD)"}),
        )
        plotted_any = False
        for acc, style in series:
            if not acc:
                continue
            steps = sorted(acc)
            ax.plot(steps, [acc[s] for s in steps], **style)
            plotted_any = True

        ax.set_xlabel("checkpoint / step")
        ax.set_ylabel("accuracy")
        ax.set_title("Validation Accuracy (in vs ood)")
        ax.grid(True, linestyle=":")
        # Calling legend() with no labeled lines only produces a warning.
        if plotted_any:
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
        if show:
            plt.show()
    finally:
        plt.close(fig)


def main():
    """Collect accuracies from BASE_DIR, print them, and save the plot into BASE_DIR."""
    accuracies_in, accuracies_ood = collect_accuracies(BASE_DIR)

    print("in-domain checkpoints and accuracies:", accuracies_in)
    print("ood checkpoints and accuracies:", accuracies_ood)

    figure_path = os.path.join(BASE_DIR, "valid_accuracy.png")
    plot_accuracies(accuracies_in, accuracies_ood, out_path=figure_path)


# Run the collect-and-plot pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()