# Source: commit b843574 by MiaMao — "Add LoRA checkpoints (without PNG loss curves)"
import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
# Plot the training / eval loss curves stored in a Trainer checkpoint's
# trainer_state.json (its "log_history" list) and save the figure as
# loss_curve.png inside that checkpoint directory.
#
# Usage:
#     python <this_script>.py [CKPT_DIR] [--ymax YMAX]

# TODO: replace with your own checkpoint path (or pass it on the command line).
DEFAULT_CKPT_DIR = Path(
    r"D:\Task_design\Topic\strategy_train\outputs\qwen7b-lora-topic_strategy\checkpoint-26250"
)
# Upper y-axis limit kept from the original script (lower limit is 0), so that
# large early-step losses do not flatten the rest of the curve.
# A non-positive value (or None) lets matplotlib autoscale instead.
DEFAULT_YMAX = 10.0


def load_log_history(ckpt_dir):
    """Return the ``log_history`` list from ``ckpt_dir/trainer_state.json``.

    Args:
        ckpt_dir: checkpoint directory (str or Path).

    Raises:
        FileNotFoundError: if the directory has no trainer_state.json.
        KeyError: if the state file has no ``"log_history"`` key.
    """
    state_path = Path(ckpt_dir) / "trainer_state.json"
    with open(state_path, "r", encoding="utf-8") as f:
        state = json.load(f)
    return state["log_history"]


def extract_losses(log_history):
    """Split log entries into a training-loss series and an eval-loss series.

    An entry contributes to the training series when it has both ``"loss"``
    and ``"step"``, and to the eval series when it has both ``"eval_loss"``
    and ``"step"``. A single entry may contribute to both series.

    Returns:
        ``((train_steps, train_loss), (eval_steps, eval_loss))``. These are four
        lists; each (steps, losses) pair is aligned index-by-index and keeps the
        order in which the entries appear in ``log_history``.
    """
    train_steps, train_loss = [], []
    eval_steps, eval_loss = [], []
    for log in log_history:
        # Without a step there is no x coordinate to plot against.
        if "step" not in log:
            continue
        # Training loss recorded during training.
        if "loss" in log:
            train_steps.append(log["step"])
            train_loss.append(log["loss"])
        # eval_loss recorded on the dev/validation set.
        if "eval_loss" in log:
            eval_steps.append(log["step"])
            eval_loss.append(log["eval_loss"])
    return (train_steps, train_loss), (eval_steps, eval_loss)


def plot_loss_curve(train_series, eval_series, out_path, ymax=DEFAULT_YMAX):
    """Draw both loss series on one figure and save it to ``out_path``.

    Args:
        train_series: ``(steps, losses)`` for the training loss; skipped if empty.
        eval_series: ``(steps, losses)`` for the eval loss; skipped if empty.
        out_path: destination image path.
        ymax: upper y-axis limit (the lower limit is 0). ``None`` or a
            non-positive value leaves the y range to matplotlib.
    """
    fig, ax = plt.subplots()
    try:
        plotted = False
        for (steps, losses), label in (
            (train_series, "train_loss"),
            (eval_series, "eval_loss"),
        ):
            if steps:
                ax.plot(steps, losses, label=label)
                plotted = True
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        ax.set_title("Training / Eval Loss Curve")
        # legend() with no labelled lines only warns and draws nothing.
        if plotted:
            ax.legend()
        ax.grid(True)
        if ymax is not None and ymax > 0:
            ax.set_ylim(0, ymax)
        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def main(argv=None):
    """Command-line entry point: load the log, extract the losses, plot, and report the path."""
    parser = argparse.ArgumentParser(
        description="Plot train/eval loss curves from a Trainer checkpoint."
    )
    parser.add_argument(
        "ckpt_dir",
        nargs="?",
        type=Path,
        default=DEFAULT_CKPT_DIR,
        help="checkpoint directory containing trainer_state.json",
    )
    parser.add_argument(
        "--ymax",
        type=float,
        default=DEFAULT_YMAX,
        help="upper y-axis limit (<= 0 to autoscale)",
    )
    args = parser.parse_args(argv)

    train_series, eval_series = extract_losses(load_log_history(args.ckpt_dir))

    # Save the figure as an image.
    out_path = args.ckpt_dir / "loss_curve.png"
    plot_loss_curve(train_series, eval_series, out_path, ymax=args.ymax)
    print(f"保存训练曲线到: {out_path}")


if __name__ == "__main__":
    main()