File size: 2,488 Bytes
661c54a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import glob
import json
import re
import matplotlib.pyplot as plt
# ===== Root directory configuration =====
# main() scans this directory for valid_score_*.json files and also writes
# valid_accuracy.png into it.
BASE_DIR = "/pfs/lichenyi/work/evaluation"
def collect_accuracies(base_dir: str):
    """Read per-checkpoint validation accuracies from ``base_dir``.

    Scans ``base_dir`` (non-recursively) for files named exactly
    ``valid_score_in_<step>.json`` or ``valid_score_ood_<step>.json`` and
    reads ``summary.accuracy`` from each one.

    Args:
        base_dir: Directory containing the ``valid_score_*.json`` files.
            Glob metacharacters in it (``[``, ``*``, ``?``) are treated
            literally.

    Returns:
        A tuple ``(in_acc, ood_acc)`` of dicts, each mapping the integer step
        parsed from the file name to the accuracy value stored in the file.
        A file is skipped if its name does not match the pattern exactly, its
        top-level JSON value is not an object, its ``summary`` is not an
        object, or its ``summary.accuracy`` is missing or null. If two names
        parse to the same step (e.g. ``_100`` and ``_0100``), the one that
        sorts last wins.

    Raises:
        OSError: If a matching file cannot be opened.
        json.JSONDecodeError: If a matching file is not valid JSON.
    """
    # Escape the directory so characters like "[" in it are not read as
    # glob patterns; only the file-name part is meant to be a wildcard.
    pattern = os.path.join(glob.escape(base_dir), "valid_score_*.json")
    # Expected names: valid_score_in_100.json / valid_score_ood_100.json
    regex = re.compile(r"valid_score_(in|ood)_(\d+)\.json")
    results = {"in": {}, "ood": {}}
    for path in sorted(glob.glob(pattern)):
        # fullmatch, not match: the glob also yields names such as
        # "valid_score_in_100.json.json", which match() would accept.
        m = regex.fullmatch(os.path.basename(path))
        if not m:
            continue
        split, step = m.group(1), int(m.group(2))
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # A non-object JSON document or "summary": null would make .get()
        # raise AttributeError; treat those like a missing accuracy.
        summary = data.get("summary") if isinstance(data, dict) else None
        acc = summary.get("accuracy") if isinstance(summary, dict) else None
        if acc is None:
            continue
        results[split][step] = acc
    return results["in"], results["ood"]
def plot_accuracies(in_acc, ood_acc, out_path="valid_accuracy.png", *, show=False):
    """Plot in-domain vs. out-of-domain accuracy curves and save the figure.

    Args:
        in_acc: Mapping of step -> accuracy for the in-domain split
            (e.g. ``dict[int, float]``); may be empty, in which case no
            in-domain line is drawn.
        ood_acc: Mapping of step -> accuracy for the out-of-domain split;
            may be empty, in which case no OOD line is drawn.
        out_path: File path the figure is saved to, at 300 dpi.
        show: If True, also display the figure with ``plt.show()`` before it
            is closed (replaces the old "uncomment plt.show()" instruction).

    The figure is always closed afterwards, so repeated calls do not leave
    open figures registered with pyplot.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        # in-domain
        if in_acc:
            steps_in = sorted(in_acc)
            vals_in = [in_acc[s] for s in steps_in]
            ax.plot(steps_in, vals_in, marker="o", label="in (ID)")
        # out-of-domain
        if ood_acc:
            steps_ood = sorted(ood_acc)
            vals_ood = [ood_acc[s] for s in steps_ood]
            ax.plot(steps_ood, vals_ood, marker="s", linestyle="--", label="ood (OOD)")
        ax.set_xlabel("checkpoint / step")
        ax.set_ylabel("accuracy")
        ax.set_title("Validation Accuracy (in vs ood)")
        ax.grid(True, linestyle=":")
        # legend() with no labelled artists only emits a warning, so skip it
        # when neither series was plotted.
        if in_acc or ood_acc:
            ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=300)
        if show:
            plt.show()
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # close it even if saving failed.
        plt.close(fig)
def main():
    """Collect ID/OOD accuracies from BASE_DIR, print them, and save the plot there."""
    id_scores, ood_scores = collect_accuracies(BASE_DIR)
    report = (
        ("in-domain checkpoints and accuracies:", id_scores),
        ("ood checkpoints and accuracies:", ood_scores),
    )
    for heading, scores in report:
        print(heading, scores)
    figure_path = os.path.join(BASE_DIR, "valid_accuracy.png")
    plot_accuracies(id_scores, ood_scores, out_path=figure_path)


if __name__ == "__main__":
    main()
|