#!/usr/bin/env python3 # -*- coding: utf-8 -*- import os import glob import json import re import matplotlib.pyplot as plt # ===== 配置根目录 ===== BASE_DIR = "/pfs/lichenyi/work/evaluation" def collect_accuracies(base_dir: str): """ 从 base_dir 下面的 valid_score_in_*.json 和 valid_score_ood_*.json 中 读取 summary.accuracy,返回两个 dict: in_acc[step] = accuracy ood_acc[step] = accuracy """ pattern = os.path.join(base_dir, "valid_score_*.json") files = glob.glob(pattern) in_acc = {} ood_acc = {} # 匹配文件名:valid_score_in_100.json / valid_score_ood_100.json regex = re.compile(r"valid_score_(in|ood)_(\d+)\.json") for path in sorted(files): fname = os.path.basename(path) m = regex.match(fname) if not m: continue split = m.group(1) # 'in' or 'ood' step = int(m.group(2)) with open(path, "r", encoding="utf-8") as f: data = json.load(f) acc = data.get("summary", {}).get("accuracy", None) if acc is None: continue if split == "in": in_acc[step] = acc else: ood_acc[step] = acc return in_acc, ood_acc def plot_accuracies(in_acc, ood_acc, out_path="valid_accuracy.png"): """ 根据 in_acc 和 ood_acc 画图并保存为 out_path。 in_acc / ood_acc: dict[int, float] """ plt.figure(figsize=(8, 5)) # in-domain if in_acc: steps_in = sorted(in_acc.keys()) vals_in = [in_acc[s] for s in steps_in] plt.plot(steps_in, vals_in, marker="o", label="in (ID)") # out-of-domain if ood_acc: steps_ood = sorted(ood_acc.keys()) vals_ood = [ood_acc[s] for s in steps_ood] plt.plot(steps_ood, vals_ood, marker="s", linestyle="--", label="ood (OOD)") plt.xlabel("checkpoint / step") plt.ylabel("accuracy") plt.title("Validation Accuracy (in vs ood)") plt.grid(True, linestyle=":") plt.legend() plt.tight_layout() plt.savefig(out_path, dpi=300) # 如需在终端弹出窗口查看,可取消下一行注释 # plt.show() def main(): in_acc, ood_acc = collect_accuracies(BASE_DIR) print("in-domain checkpoints and accuracies:", in_acc) print("ood checkpoints and accuracies:", ood_acc) plot_accuracies(in_acc, ood_acc, out_path=os.path.join(BASE_DIR, "valid_accuracy.png")) if __name__ == "__main__": main()