| | |
"""
Usage:
    python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
        --checkpoint_steps 263,526,789,1052

Features:
- Curve: solid yellow-orange line
- Grid: dashed lines along both x and y
- Epoch markers: dashed blue vertical lines with an EpochN label (including the last epoch)
- Checkpoints: small blue dots (linearly interpolated; steps outside the data range use the
  endpoint value, and the x-axis is widened automatically so they stay visible)
"""
| | import json, argparse |
| | from pathlib import Path |
| | import matplotlib.pyplot as plt |
| | import numpy as np |
| |
|
# Default colour of the plotted curve (the default `color` argument of plot_series).
YELLOW_ORANGE = "#d58f00"
# Colour used for the epoch marker lines/labels and the checkpoint dots.
BLUE = "#1f77b4"
| |
|
def find_epoch_boundaries(log_items):
    """Locate the step at which each epoch ends, including the last one.

    The log entries are walked in order. A ``(step, epoch_number)`` pair is
    recorded whenever the integer part of ``epoch`` differs from the previous
    entry's, and one more pair is added for the last logged step so the final
    (possibly partial) epoch is marked as well.

    Args:
        log_items: Iterable of dict-like log entries. Entries without a
            ``"step"`` or ``"epoch"`` value are skipped. ``epoch`` may be any
            value accepted by ``float()``.

    Returns:
        A list of ``(step, epoch_number)`` tuples sorted by step. Only epoch
        numbers >= 1 are reported.
    """
    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step = last_epoch = None

    for item in log_items:
        step = item.get("step")
        epoch = item.get("epoch")
        if step is None or epoch is None:
            continue
        epoch = float(epoch)
        last_step, last_epoch = step, epoch
        epoch_int = int(epoch)
        if prev_epoch_int is None:
            # The first usable entry only establishes the starting epoch.
            prev_epoch_int = epoch_int
            continue
        if epoch_int == prev_epoch_int:
            continue
        prev_epoch_int = epoch_int
        mark = (step, epoch_int)
        if epoch_int >= 1 and mark not in seen:
            boundaries.append(mark)
            seen.add(mark)

    if last_step is not None:
        # Label the last step with the number of epochs started so far, i.e.
        # the epoch value rounded up. When the run ended exactly on an epoch
        # boundary (e.g. 2.0) this equals the marker emitted by the loop, so
        # no extra "epoch + 1" marker is stacked on the same step.
        final_epoch = int(last_epoch)
        if last_epoch > final_epoch:
            final_epoch += 1
        final_mark = (last_step, final_epoch)
        if final_epoch >= 1 and final_mark not in seen:
            boundaries.append(final_mark)

    boundaries.sort(key=lambda mark: mark[0])
    return boundaries
| |
|
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Plot ``y`` against ``x`` as a single curve and save the figure.

    Args:
        x: Sequence of x values (training steps). Must be non-empty.
            ``np.interp`` is used for the checkpoint dots and expects
            increasing x values; this is not checked here.
        y: Sequence of y values, same length as ``x``.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis.
        title: Title of the plot.
        outpath: Destination passed to ``Figure.savefig``.
        epoch_marks: Optional iterable of ``(step, epoch_number)`` pairs; each
            is drawn as a dashed vertical line labelled ``Epoch<N>``.
        checkpoint_steps: Optional iterable of steps drawn as dots on the
            curve. The y value is linearly interpolated; steps outside the
            range of ``x`` use the first/last y value.
        color: Colour of the curve.
        linestyle: Line style of the curve.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        ax = fig.add_subplot(111)
        ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

        # Checkpoint dots: interpolate onto the curve, clamping to the
        # endpoint values for steps outside the logged range.
        marked_steps = []
        for s in checkpoint_steps or ():
            y_interp = np.interp(s, x, y, left=y[0], right=y[-1])
            ax.plot(s, y_interp, marker='o', color=BLUE, markersize=6)
            marked_steps.append(s)

        # The x axis starts at 0 and must reach far enough to show the data,
        # every checkpoint dot and every epoch marker.
        xmin = 0
        x_candidates = [max(x)]
        if marked_steps:
            x_candidates.append(max(marked_steps))
        if epoch_marks:
            x_candidates.append(max(step for step, _ in epoch_marks))
        xmax_base = max(x_candidates)

        # Small right-hand padding so the right-most marker is not drawn on
        # the edge of the axes.
        span = max(xmax_base - xmin, 1.0)
        right_pad = max(1.0, 0.02 * span)
        ax.set_xlim(left=xmin, right=xmax_base + right_pad)

        ax.set_ylim(bottom=0)

        ax.grid(True, which='major', axis='both', linestyle='--',
                linewidth=0.8, alpha=0.6)

        if epoch_marks:
            # The y limits no longer change from here on, so read the top
            # once instead of once per marker.
            ymax = ax.get_ylim()[1]
            for step, ep in epoch_marks:
                ax.axvline(x=step, color=BLUE, linestyle='--', linewidth=1.2)
                ax.text(step, ymax * 0.98, f'Epoch{ep}', rotation=90,
                        va='top', ha='right', fontsize=8, color=BLUE)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        fig.savefig(outpath, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails;
        # pyplot otherwise keeps it registered.
        plt.close(fig)
| |
|
| |
|
def _parse_checkpoint_steps(spec):
    """Turn a comma-separated string of step numbers into a list of ints."""
    steps = []
    # Accept the full-width comma (U+FF0C) as a separator as well as ",".
    for token in spec.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            # int(float(...)) also accepts inputs such as "100.0" or "1e3".
            steps.append(int(float(token)))
        except (ValueError, OverflowError):
            # Best effort: report the bad token and keep the valid ones.
            print(f"Warning: ignoring invalid checkpoint step {token!r}")
    return steps


def main():
    """Command-line entry point: read a trainer state file and save plots.

    Reads the JSON file given by ``--input``, collects the ``loss``,
    ``eval_loss`` and ``learning_rate`` values of every log entry that has a
    ``step``, and writes one PNG per non-empty series into ``--outdir``
    (``loss_curve.png``, ``eval_loss_curve.png``, ``lr_curve.png``).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)

    # The log entries are read from "log_history", falling back to "logs".
    log = state.get("log_history", state.get("logs", []))

    steps, train_losses = [], []
    eval_steps, eval_losses = [], []
    lr_steps, lrs = [], []

    for item in log:
        step = item.get("step")
        if step is None:
            continue
        # One entry may carry several metrics, so test each key independently.
        if "loss" in item:
            steps.append(step)
            train_losses.append(item["loss"])
        if "eval_loss" in item:
            eval_steps.append(step)
            eval_losses.append(item["eval_loss"])
        if "learning_rate" in item:
            lr_steps.append(step)
            lrs.append(item["learning_rate"])

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    if steps and train_losses:
        plot_series(steps, train_losses, "Step", "Training Loss", "Training Loss vs Step",
                    outdir / "loss_curve.png", epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)
    if eval_steps and eval_losses:
        plot_series(eval_steps, eval_losses, "Step", "Eval Loss", "Eval Loss vs Step",
                    outdir / "eval_loss_curve.png", epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)
    if lr_steps and lrs:
        plot_series(lr_steps, lrs, "Step", "Learning Rate", "Learning Rate vs Step",
                    outdir / "lr_curve.png", epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)

    print(f"Saved plots to: {outdir.resolve()}")
| |
|
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
| |
|