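"""
Usage:
    python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
        --checkpoint_steps 263,526,789,1052

Features:
- Curve: solid yellow-orange line
- Grid: dashed grid lines on both the x and y axes
- Epoch markers: dashed blue vertical lines with "EpochN" labels (including the final epoch)
- Checkpoints: small blue dots (linearly interpolated; steps outside the logged range use the
  endpoint value, and the x-axis is widened automatically so they stay visible)
"""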
import argparse
import json
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
|
|
# Default color of the plotted metric curve (see the `color` default of plot_series).
YELLOW_ORANGE = "#d58f00"
# Color shared by the epoch marker lines/labels and the checkpoint dots.
BLUE = "#1f77b4"
|
|
def find_epoch_boundaries(log_items):
    """Return ``(step, epoch_number)`` markers for every epoch end in the log.

    Only entries that carry both ``"step"`` and ``"epoch"`` are considered.
    A marker is emitted at the first logged step whose whole-epoch count
    differs from the previous entry's (and is at least 1), labelled with that
    count. One more marker is emitted at the last logged step so the final
    epoch is always labelled: a fractional final epoch (e.g. 2.4) is rounded
    up (3), while a final epoch that is already whole (e.g. 3.0) keeps its
    own number instead of being labelled as a non-existent next epoch.

    Args:
        log_items: Iterable of dict-like log entries.

    Returns:
        List of ``(step, epoch_number)`` tuples sorted by step.
    """
    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step = last_epoch = None

    for item in log_items:
        step = item.get("step")
        epoch = item.get("epoch")
        if step is None or epoch is None:
            continue
        # Normalise once so numeric strings and floats are handled the same
        # way both in the loop and for the final marker.
        epoch = float(epoch)
        last_step, last_epoch = step, epoch
        epoch_int = int(epoch)
        if prev_epoch_int is None:
            # First usable entry only establishes the starting epoch.
            prev_epoch_int = epoch_int
            continue
        if epoch_int == prev_epoch_int:
            continue
        prev_epoch_int = epoch_int
        if epoch_int >= 1 and (step, epoch_int) not in seen:
            boundaries.append((step, epoch_int))
            seen.add((step, epoch_int))

    if last_step is not None:
        whole = int(last_epoch)
        # Round a partially completed epoch up; a run that stopped exactly on
        # an epoch boundary must not get an extra "+1" marker.
        final_epoch = whole if last_epoch - whole < 1e-9 else whole + 1
        if final_epoch >= 1 and (last_step, final_epoch) not in seen:
            boundaries.append((last_step, final_epoch))

    boundaries.sort(key=lambda mark: mark[0])
    return boundaries
|
|
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Plot one metric against training step and save the figure to ``outpath``.

    Args:
        x: Non-empty sequence of step values.
        y: Sequence of metric values, same length as ``x``.
        xlabel, ylabel, title: Axis labels and figure title.
        outpath: Destination passed to ``Figure.savefig``.
        epoch_marks: Optional iterable of ``(step, epoch_number)`` pairs; each
            is drawn as a dashed vertical line labelled ``Epoch<N>``.
        checkpoint_steps: Optional iterable of steps drawn as dots on the
            curve. Their y values are linearly interpolated from ``x``/``y``;
            steps outside the data range use the nearest endpoint value.
        color: Line color of the curve.
        linestyle: Line style of the curve.

    Raises:
        ValueError: If ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("plot_series() needs at least one data point")

    # Materialise once: both are iterated more than once below.
    epoch_marks = list(epoch_marks or [])
    checkpoint_steps = list(checkpoint_steps or [])

    fig = plt.figure(figsize=(10, 6))
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

        if checkpoint_steps:
            # np.interp silently returns nonsense unless the sample points are
            # increasing, so interpolate on a step-sorted copy of the data.
            xs = np.asarray(x, dtype=float)
            ys = np.asarray(y, dtype=float)
            order = np.argsort(xs, kind="stable")
            xs, ys = xs[order], ys[order]
            checkpoint_y = np.interp(checkpoint_steps, xs, ys,
                                     left=ys[0], right=ys[-1])
            ax.plot(checkpoint_steps, checkpoint_y, linestyle='none',
                    marker='o', color=BLUE, markersize=6)

        # Widen the x-axis so that checkpoints and epoch markers lying beyond
        # the last data point remain visible, plus a small right-hand margin.
        xmin = 0
        xmax_base = max([max(x), *checkpoint_steps,
                         *(step for step, _ in epoch_marks)])
        span = max(xmax_base - xmin, 1.0)
        right_pad = max(1.0, 0.02 * span)
        ax.set_xlim(left=xmin, right=xmax_base + right_pad)

        ax.set_ylim(bottom=0)

        ax.grid(True, which='major', axis='both', linestyle='--',
                linewidth=0.8, alpha=0.6)

        # The y-limits are fixed at this point, so the label height is the
        # same for every marker.
        label_y = ax.get_ylim()[1] * 0.98
        for step, ep in epoch_marks:
            ax.axvline(x=step, color=BLUE, linestyle='--', linewidth=1.2)
            ax.text(step, label_y, f'Epoch{ep}', rotation=90,
                    va='top', ha='right', fontsize=8, color=BLUE)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        fig.savefig(outpath, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
|
|
|
|
def _parse_checkpoint_steps(spec):
    """Parse a comma-separated list of checkpoint steps into integers.

    Both the ASCII comma and the full-width comma (U+FF0C) are accepted as
    separators. Blank items are ignored. Items that are not finite numbers
    are skipped with a warning on stderr rather than aborting the run.
    """
    steps = []
    # Normalise full-width commas typed with a CJK input method.
    for token in spec.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            steps.append(int(float(token)))
        except (ValueError, OverflowError):
            print(f"Warning: ignoring invalid checkpoint step {token!r}",
                  file=sys.stderr)
    return steps


def main():
    """CLI entry point: read a trainer-state JSON file and write the plots.

    Reads the log list from the ``"log_history"`` key (falling back to
    ``"logs"``), collects the ``loss``, ``eval_loss`` and ``learning_rate``
    series keyed by step, and writes one PNG per non-empty series into the
    output directory, which is created if needed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    with open(Path(args.input), "r", encoding="utf-8") as f:
        state = json.load(f)

    log = state.get("log_history", state.get("logs", []))

    # metric key -> (steps, values)
    series = {"loss": ([], []), "eval_loss": ([], []), "learning_rate": ([], [])}
    for item in log:
        step = item.get("step")
        if step is None:
            continue
        for key, (xs, ys) in series.items():
            value = item.get(key)
            # Skip missing/null values; they cannot be plotted or interpolated.
            if value is not None:
                xs.append(step)
                ys.append(value)

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    plots = (
        ("loss", "Training Loss", "Training Loss vs Step", "loss_curve.png"),
        ("eval_loss", "Eval Loss", "Eval Loss vs Step", "eval_loss_curve.png"),
        ("learning_rate", "Learning Rate", "Learning Rate vs Step", "lr_curve.png"),
    )
    for key, ylabel, title, filename in plots:
        xs, ys = series[key]
        if xs:
            plot_series(xs, ys, "Step", ylabel, title, outdir / filename,
                        epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)

    print(f"Saved plots to: {outdir.resolve()}")
|
|
# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|