"""Plot training / eval loss curves from a checkpoint's ``trainer_state.json``.

Reads the ``log_history`` list stored in ``<ckpt_dir>/trainer_state.json``,
extracts the per-step ``loss`` and ``eval_loss`` entries, and saves a line
plot as ``<ckpt_dir>/loss_curve.png``.

Usage::

    python plot_loss.py [CKPT_DIR] [--ymax 10] [--autoscale]
"""
import argparse
import json
from pathlib import Path

# TODO: replace with your own checkpoint path (or pass it on the command line).
DEFAULT_CKPT_DIR = Path(
    r"D:\Task_design\Topic\strategy_train\outputs\qwen7b-lora-topic_strategy\checkpoint-26250"
)
# Default y-axis range; losses above the upper bound are clipped from view.
DEFAULT_YLIM = (0, 10)


def load_log_history(state_path):
    """Return the ``"log_history"`` list from a ``trainer_state.json`` file.

    Args:
        state_path: path to the JSON state file.

    Raises:
        FileNotFoundError: if ``state_path`` does not exist.
        KeyError: if the JSON object has no ``"log_history"`` key.
    """
    with open(state_path, "r", encoding="utf-8") as f:
        state = json.load(f)
    return state["log_history"]


def extract_losses(log_history):
    """Split a log history into train-loss and eval-loss series.

    Args:
        log_history: iterable of dicts. Entries without a ``"step"`` key are
            skipped; an entry may contribute to both series if it carries both
            ``"loss"`` and ``"eval_loss"``.

    Returns:
        ``(train_steps, train_loss, eval_steps, eval_loss)`` as lists, in the
        order the entries appear.
    """
    train_steps, train_loss = [], []
    eval_steps, eval_loss = [], []
    for log in log_history:
        if "step" not in log:
            continue
        # Loss records logged during training.
        if "loss" in log:
            train_steps.append(log["step"])
            train_loss.append(log["loss"])
        # eval_loss records computed on the dev/validation set.
        if "eval_loss" in log:
            eval_steps.append(log["step"])
            eval_loss.append(log["eval_loss"])
    return train_steps, train_loss, eval_steps, eval_loss


def plot_loss_curve(train_steps, train_loss, eval_steps, eval_loss, out_path,
                    ylim=DEFAULT_YLIM, dpi=200):
    """Draw the train/eval loss series and save the figure to ``out_path``.

    Args:
        train_steps, train_loss: x/y values for the training-loss line
            (omitted from the plot when empty).
        eval_steps, eval_loss: x/y values for the eval-loss line
            (omitted from the plot when empty).
        out_path: destination image path.
        ylim: ``(low, high)`` y-axis limits, or ``None`` to let matplotlib
            autoscale.
        dpi: resolution of the saved image.
    """
    # Imported lazily so the JSON-parsing helpers above stay usable in
    # environments without matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    try:
        if train_steps:
            ax.plot(train_steps, train_loss, label="train_loss")
        if eval_steps:
            ax.plot(eval_steps, eval_loss, label="eval_loss")
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        ax.set_title("Training / Eval Loss Curve")
        # Only add a legend when something was plotted; otherwise matplotlib
        # warns "No artists with labels found to put in legend".
        if train_steps or eval_steps:
            ax.legend()
        ax.grid(True)
        if ylim is not None:
            ax.set_ylim(*ylim)
        fig.savefig(out_path, dpi=dpi)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)


def main(argv=None):
    """Command-line entry point: read the state file and save the loss plot."""
    parser = argparse.ArgumentParser(
        description="Plot train/eval loss from <ckpt_dir>/trainer_state.json."
    )
    parser.add_argument(
        "ckpt_dir", nargs="?", type=Path, default=DEFAULT_CKPT_DIR,
        help="checkpoint directory containing trainer_state.json",
    )
    parser.add_argument(
        "--ymax", type=float, default=DEFAULT_YLIM[1],
        help="upper y-axis limit (default: %(default)s)",
    )
    parser.add_argument(
        "--autoscale", action="store_true",
        help="let matplotlib choose the y-axis range instead of (0, ymax)",
    )
    args = parser.parse_args(argv)

    state_path = args.ckpt_dir / "trainer_state.json"
    series = extract_losses(load_log_history(state_path))

    # Save the figure next to the state file.
    out_path = args.ckpt_dir / "loss_curve.png"
    ylim = None if args.autoscale else (DEFAULT_YLIM[0], args.ymax)
    plot_loss_curve(*series, out_path, ylim=ylim)
    print(f"保存训练曲线到: {out_path}")


if __name__ == "__main__":
    main()