"""Plot the training / eval loss curves stored in a checkpoint's ``trainer_state.json``.

Run as a script: reads ``<checkpoint>/trainer_state.json``, plots every logged
``loss`` and ``eval_loss`` value against its ``step``, and saves the figure as
``<checkpoint>/loss_curve.png``. With no arguments it uses ``DEFAULT_CKPT_DIR``
and a fixed y-range of ``DEFAULT_YLIM``.
"""

import argparse
import json
from pathlib import Path

# Checkpoint analysed when no directory is given on the command line.
DEFAULT_CKPT_DIR = Path(
    r"D:\Task_design\Topic\strategy_train\outputs\qwen7b-lora-topic_strategy\checkpoint-26250"
)
# Fixed y-axis range; losses outside it are clipped from view (use --autoscale to avoid).
DEFAULT_YLIM = (0, 10)


def load_log_history(ckpt_dir):
    """Return the ``log_history`` list from ``<ckpt_dir>/trainer_state.json``.

    Args:
        ckpt_dir: Checkpoint directory (``str`` or ``Path``).

    Raises:
        FileNotFoundError: if ``trainer_state.json`` does not exist.
        KeyError: if the JSON object has no ``log_history`` key.
    """
    state_path = Path(ckpt_dir) / "trainer_state.json"
    with open(state_path, "r", encoding="utf-8") as f:
        state = json.load(f)
    return state["log_history"]


def extract_losses(log_history):
    """Split log entries into training and evaluation loss series.

    An entry contributes a training point when it has both ``loss`` and ``step``
    keys, and an evaluation point when it has both ``eval_loss`` and ``step``;
    one entry may contribute to both. Input order is preserved.

    Args:
        log_history: Iterable of dict log entries.

    Returns:
        ``(train_steps, train_loss, eval_steps, eval_loss)`` as four lists.
    """
    train_steps, train_loss = [], []
    eval_steps, eval_loss = [], []
    for log in log_history:
        # Entries without a step cannot be placed on the x-axis.
        if "step" not in log:
            continue
        if "loss" in log:
            train_steps.append(log["step"])
            train_loss.append(log["loss"])
        if "eval_loss" in log:
            eval_steps.append(log["step"])
            eval_loss.append(log["eval_loss"])
    return train_steps, train_loss, eval_steps, eval_loss


def plot_loss_curves(log_history, out_path, ylim=DEFAULT_YLIM, dpi=200):
    """Plot train/eval loss against step and save the figure to *out_path*.

    Args:
        log_history: Iterable of log-entry dicts (see ``extract_losses``).
        out_path: Destination image path passed to ``savefig``.
        ylim: ``(bottom, top)`` for the y-axis, or ``None`` to let matplotlib
            autoscale. Defaults to ``DEFAULT_YLIM``.
        dpi: Resolution passed to ``savefig``.
    """
    # Imported lazily so the parsing helpers above can be used without matplotlib.
    import matplotlib.pyplot as plt

    train_steps, train_loss, eval_steps, eval_loss = extract_losses(log_history)

    fig, ax = plt.subplots()
    try:
        if train_steps:
            ax.plot(train_steps, train_loss, label="train_loss")
        if eval_steps:
            ax.plot(eval_steps, eval_loss, label="eval_loss")

        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        ax.set_title("Training / Eval Loss Curve")
        # With no labelled lines, legend() only emits a "No artists with labels" warning.
        if train_steps or eval_steps:
            ax.legend()
        ax.grid(True)
        if ylim is not None:
            ax.set_ylim(*ylim)

        fig.savefig(out_path, dpi=dpi)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main(argv=None):
    """Parse command-line arguments, plot the loss curves and report the output path.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Plot train/eval loss curves from a checkpoint's trainer_state.json."
    )
    parser.add_argument(
        "ckpt_dir",
        nargs="?",
        type=Path,
        default=DEFAULT_CKPT_DIR,
        help="checkpoint directory containing trainer_state.json",
    )
    parser.add_argument("--ymin", type=float, default=DEFAULT_YLIM[0], help="y-axis lower bound")
    parser.add_argument("--ymax", type=float, default=DEFAULT_YLIM[1], help="y-axis upper bound")
    parser.add_argument(
        "--autoscale",
        action="store_true",
        help="let matplotlib choose the y-range (overrides --ymin/--ymax)",
    )
    args = parser.parse_args(argv)

    ckpt_dir = args.ckpt_dir
    log_history = load_log_history(ckpt_dir)
    out_path = ckpt_dir / "loss_curve.png"
    ylim = None if args.autoscale else (args.ymin, args.ymax)
    plot_loss_curves(log_history, out_path, ylim=ylim)
    # Message reads "saved training curve to: <path>".
    print(f"保存训练曲线到: {out_path}")


if __name__ == "__main__":
    main()