Upload LoRA adapter folder
Browse files- plot_loss_from_trainer_state.py +160 -0
- plots/eval_loss_curve.png +0 -0
- plots/loss_curve.png +0 -0
- plots/lr_curve.png +0 -0
plot_loss_from_trainer_state.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Usage:
|
| 4 |
+
python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
|
| 5 |
+
--checkpoint_steps 263,526,789,1052
|
| 6 |
+
|
| 7 |
+
功能:
|
| 8 |
+
- Curve: 黃橘色實線
|
| 9 |
+
- Grid: x,y 虛線
|
| 10 |
+
- Epoch markers: 藍色虛線 + EpochN 標籤(含最後一個 epoch)
|
| 11 |
+
- Checkpoints: 藍色小圓點(線性插值;超出範圍時使用端點值,並自動擴張 x 軸確保能看見)
|
| 12 |
+
"""
|
| 13 |
+
import json, argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import matplotlib.pyplot as plt
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
# Hex color of the plotted curve; used as the default `color` of plot_series.
YELLOW_ORANGE = "#d58f00"
# Hex color of the checkpoint dots and of the epoch marker lines and labels.
BLUE = "#1f77b4"
|
| 20 |
+
|
| 21 |
+
def find_epoch_boundaries(log_items):
    """Locate the steps at which epoch boundaries are observed in a training log.

    ``log_items`` is scanned in order; only the ``"step"`` and ``"epoch"`` keys
    of each entry are read, and entries missing either one are skipped.
    Whenever the integer part of the epoch differs from the previous entry's
    and is at least 1, that entry's step is recorded as the boundary of that
    epoch number.  A closing marker for the epoch in progress at the last
    entry is then added at the last step.

    Args:
        log_items: Iterable of dict-like log entries.

    Returns:
        A list of ``(step, epoch_number)`` tuples sorted by step.
    """
    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step = last_epoch = None

    for it in log_items:
        step = it.get("step")
        ep = it.get("epoch")
        if step is None or ep is None:
            continue
        last_step, last_epoch = step, ep
        # Go through float() so the conversion matches the closing-marker code below.
        ep_int = int(float(ep))
        if prev_epoch_int is None:
            # The first usable entry only establishes the starting epoch.
            prev_epoch_int = ep_int
            continue
        if ep_int != prev_epoch_int:
            if (step, ep_int) not in seen and ep_int >= 1:
                boundaries.append((step, ep_int))
                seen.add((step, ep_int))
            prev_epoch_int = ep_int

    # Closing marker: round the last epoch value UP.  Unconditionally adding 1
    # would label a run that ends exactly on epoch 2.0 with both "2" (from the
    # loop above) and a spurious "3" at the same step.
    if last_step is not None:
        final_epoch = float(last_epoch)
        ep_final = int(final_epoch)
        if ep_final != final_epoch:
            ep_final += 1
        already_marked = {epoch for _, epoch in boundaries}
        if ep_final >= 1 and ep_final not in already_marked:
            boundaries.append((last_step, ep_final))

    boundaries.sort(key=lambda x: x[0])
    return boundaries
|
| 49 |
+
|
| 50 |
+
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Draw one x/y series as a line chart and save it to ``outpath``.

    Args:
        x, y: Sequences of equal length holding the curve's coordinates.
        xlabel, ylabel, title: Axis labels and figure title.
        outpath: Destination passed to ``Figure.savefig``.
        epoch_marks: Optional iterable of ``(step, epoch_number)`` pairs; each
            is drawn as a dashed vertical line with an ``EpochN`` label.
        checkpoint_steps: Optional iterable of step values; each is drawn as a
            dot on the curve at a linearly interpolated height (the first or
            last ``y`` value is used when the step lies outside ``x``).
        color, linestyle: Styling of the main curve.
    """
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

    # Checkpoint dots, one artist per checkpoint.
    dot_steps = []
    for ckpt in (checkpoint_steps or ()):
        dot_height = np.interp(ckpt, x, y, left=y[0], right=y[-1])
        ax.plot(ckpt, dot_height, marker='o', color=BLUE, markersize=6)
        dot_steps.append(ckpt)

    # The right edge must cover the curve, every checkpoint dot and every
    # epoch line, so none of them falls outside the visible range.
    right_candidates = [max(x)]
    if dot_steps:
        right_candidates.append(max(dot_steps))
    if epoch_marks:
        line_steps = [mark_step for mark_step, _ in epoch_marks]
        if line_steps:
            right_candidates.append(max(line_steps))

    left_edge = 0
    right_edge = max(right_candidates)
    width = max(right_edge - left_edge, 1.0)
    # Pad by at least one step or 2% of the width so that a marker sitting
    # exactly on the right edge remains visible.
    padding = max(1.0, 0.02 * width)
    ax.set_xlim(left=left_edge, right=right_edge + padding)

    # The y axis always starts at zero.
    ax.set_ylim(bottom=0)

    # Dashed major grid on both axes.
    ax.grid(True, which='major', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)

    # Dashed vertical epoch lines, labelled just below the top of the axes.
    for step, ep in (epoch_marks or ()):
        ax.axvline(x=step, color=BLUE, linestyle='--', linewidth=1.2)
        label_height = ax.get_ylim()[1] * 0.98
        ax.text(step, label_height, f'Epoch{ep}', rotation=90,
                va='top', ha='right', fontsize=8, color=BLUE)

    # Labels and frame styling.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(2)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _parse_checkpoint_steps(raw):
    """Parse a comma-separated string of step numbers into a list of ints.

    Blank tokens are dropped.  Tokens that cannot be converted to a number are
    skipped with a warning on stderr rather than being discarded silently.
    """
    import sys

    parsed = []
    # Accept the full-width comma (U+FF0C) as a separator in addition to ",".
    for token in raw.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            # Go through float() so values such as "100.0" are accepted.
            parsed.append(int(float(token)))
        except (ValueError, OverflowError):
            # OverflowError covers "inf"; ValueError covers non-numbers and "nan".
            print(f"Warning: ignoring invalid checkpoint step {token!r}", file=sys.stderr)
    return parsed


def main():
    """Command-line entry point: plot loss and learning-rate curves.

    Loads the JSON file given by ``--input``, reads its ``log_history`` list
    (falling back to ``logs``), and collects the ``loss``, ``eval_loss`` and
    ``learning_rate`` values of every entry that has a ``step``.  Each
    non-empty series is written to ``--outdir`` as ``loss_curve.png``,
    ``eval_loss_curve.png`` and ``lr_curve.png`` respectively.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)

    log = state.get("log_history", state.get("logs", []))

    steps, train_losses = [], []
    eval_steps, eval_losses = [], []
    lr_steps, lrs = [], []

    # One log entry may contribute to several series at once.
    for item in log:
        step = item.get("step")
        if step is None:
            continue
        if "loss" in item:
            steps.append(step)
            train_losses.append(item["loss"])
        if "eval_loss" in item:
            eval_steps.append(step)
            eval_losses.append(item["eval_loss"])
        if "learning_rate" in item:
            lr_steps.append(step)
            lrs.append(item["learning_rate"])

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    if steps and train_losses:
        plot_series(steps, train_losses, "Step", "Training Loss", "Training Loss vs Step",
                    outdir / "loss_curve.png", epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)
    if eval_steps and eval_losses:
        plot_series(eval_steps, eval_losses, "Step", "Eval Loss", "Eval Loss vs Step",
                    outdir / "eval_loss_curve.png", epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)
    if lr_steps and lrs:
        plot_series(lr_steps, lrs, "Step", "Learning Rate", "Learning Rate vs Step",
                    outdir / "lr_curve.png", epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)

    print(f"Saved plots to: {outdir.resolve()}")
|
| 158 |
+
|
| 159 |
+
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
plots/eval_loss_curve.png
ADDED
|
plots/loss_curve.png
ADDED
|
plots/lr_curve.png
ADDED
|