AaronWu901225 committed on
Commit
3a0fe79
·
verified ·
1 Parent(s): 2d271a4

Upload LoRA adapter folder

Browse files
plot_loss_from_trainer_state.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Usage:
4
+ python plot_loss_from_trainer_state.py --input trainer_state.json --outdir ./plots \
5
+ --checkpoint_steps 263,526,789,1052
6
+
7
+ 功能:
8
+ - Curve: 黃橘色實線
9
+ - Grid: x,y 虛線
10
+ - Epoch markers: 藍色虛線 + EpochN 標籤(含最後一個 epoch)
11
+ - Checkpoints: 藍色小圓點(線性插值;超出範圍時使用端點值,並自動擴張 x 軸確保能看見)
12
+ """
13
+ import json, argparse
14
+ from pathlib import Path
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
+
18
+ YELLOW_ORANGE = "#d58f00"
19
+ BLUE = "#1f77b4"
20
+
21
def find_epoch_boundaries(log_items):
    """Return ``(step, epoch_number)`` markers for epoch boundaries.

    Args:
        log_items: Iterable of log dicts. Entries whose "step" or "epoch"
            value is missing (``None``) are skipped.

    Returns:
        A list of ``(step, epoch_number)`` tuples sorted by step. A marker is
        emitted at the first logged step whose integer epoch differs from the
        previous entry's integer epoch (only for epoch numbers >= 1), and one
        additional marker closes the final epoch at the last logged step.
    """
    boundaries = []
    seen = set()
    prev_epoch_int = None
    last_step = last_epoch = None

    for item in log_items:
        step = item.get("step")
        epoch = item.get("epoch")
        if step is None or epoch is None:
            continue
        last_step, last_epoch = step, epoch
        epoch_int = int(epoch)
        if prev_epoch_int is None:
            # First usable entry only establishes the starting epoch.
            prev_epoch_int = epoch_int
            continue
        if epoch_int == prev_epoch_int:
            continue
        mark = (step, epoch_int)
        if epoch_int >= 1 and mark not in seen:
            boundaries.append(mark)
            seen.add(mark)
        prev_epoch_int = epoch_int

    # Close the final epoch at the last logged step.
    if last_step is not None and last_epoch is not None:
        # Round the last epoch value UP to the epoch it completes. When the
        # value is already a whole number (e.g. 3.0 at the end of a run) the
        # loop above has normally emitted that marker already; adding 1
        # unconditionally would produce a bogus extra "Epoch 4" marker.
        last_epoch_value = float(last_epoch)
        final_epoch = int(last_epoch_value)
        if last_epoch_value > final_epoch:
            final_epoch += 1
        already_marked = {epoch_no for _, epoch_no in boundaries}
        if final_epoch >= 1 and final_epoch not in already_marked:
            boundaries.append((last_step, final_epoch))

    boundaries.sort(key=lambda mark: mark[0])
    return boundaries
49
+
50
def plot_series(x, y, xlabel, ylabel, title, outpath,
                epoch_marks=None, checkpoint_steps=None,
                color=YELLOW_ORANGE, linestyle='-'):
    """Draw a single curve with optional markers and save the figure.

    Args:
        x, y: Data for the curve; ``x`` is also used as the interpolation
            grid for checkpoint markers.
        xlabel, ylabel, title: Axis labels and figure title.
        outpath: Destination passed to ``fig.savefig``.
        epoch_marks: Optional iterable of ``(step, epoch_number)`` pairs drawn
            as dashed vertical lines labelled ``Epoch<N>``.
        checkpoint_steps: Optional iterable of step values drawn as dots on
            the curve (linearly interpolated, clamped to the end values).
        color, linestyle: Style of the main curve.
    """
    fig = plt.figure(figsize=(10, 6))
    ax = fig.add_subplot(111)
    ax.plot(x, y, color=color, linestyle=linestyle, linewidth=2)

    # Checkpoint dots: interpolate on the curve; outside the data range the
    # first/last y value is used.
    marked_steps = []
    for ckpt in checkpoint_steps or ():
        ckpt_y = np.interp(ckpt, x, y, left=y[0], right=y[-1])
        ax.plot(ckpt, ckpt_y, marker='o', color=BLUE, markersize=6)
        marked_steps.append(ckpt)

    # The x range must cover the curve, every checkpoint dot and every epoch
    # line, so take the largest of all of them.
    x_left = 0
    x_right = max(x)
    if marked_steps:
        x_right = max(x_right, max(marked_steps))
    if epoch_marks:
        x_right = max(x_right, max(mark_step for (mark_step, _) in epoch_marks))

    # Pad the right edge (at least one step, or 2% of the width) so a marker
    # sitting exactly on the boundary remains visible.
    width = max(x_right - x_left, 1.0)
    padding = max(1.0, 0.02 * width)
    ax.set_xlim(left=x_left, right=x_right + padding)

    # The y axis always starts at zero.
    ax.set_ylim(bottom=0)

    # Dashed grid on both axes.
    ax.grid(True, which='major', axis='both', linestyle='--', linewidth=0.8, alpha=0.6)

    # Epoch markers: dashed vertical line plus a rotated label near the top.
    for mark_step, epoch_no in epoch_marks or ():
        ax.axvline(x=mark_step, color=BLUE, linestyle='--', linewidth=1.2)
        y_top = ax.get_ylim()[1]
        ax.text(mark_step, y_top * 0.98, f'Epoch{epoch_no}', rotation=90,
                va='top', ha='right', fontsize=8, color=BLUE)

    # Labels and spine styling are applied last, after the limits are set.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(2)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)

    fig.savefig(outpath, bbox_inches="tight")
    plt.close(fig)
104
+
105
+
106
def _parse_checkpoint_steps(raw):
    """Parse a comma-separated list of steps into ints, skipping bad tokens."""
    import sys  # local import: only needed to report skipped tokens

    parsed = []
    # Accept the full-width comma (U+FF0C) as a separator as well, and ignore
    # surrounding whitespace and empty tokens.
    for token in raw.replace("\uff0c", ",").split(","):
        token = token.strip()
        if not token:
            continue
        try:
            parsed.append(int(float(token)))
        except (ValueError, OverflowError):
            # Best effort: an invalid token is reported and skipped rather
            # than silently dropped or aborting the whole run.
            print(f"Warning: ignoring invalid checkpoint step {token!r}",
                  file=sys.stderr)
    return parsed


def main():
    """Command-line entry point: read a trainer state JSON and save plots.

    Reads the log list stored under "log_history" (or "logs") in the input
    JSON, collects the "loss", "eval_loss" and "learning_rate" values by
    step, and writes one image per non-empty series into the output
    directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Path to trainer_state.json")
    ap.add_argument("--outdir", default="./plots", help="Directory to save PNGs")
    ap.add_argument("--no_epoch_marks", action="store_true", help="Disable vertical epoch markers")
    ap.add_argument("--checkpoint_steps", default="", help="Comma-separated steps (e.g., 100,200,500)")
    args = ap.parse_args()

    src = Path(args.input)
    with open(src, "r", encoding="utf-8") as f:
        state = json.load(f)

    log = state.get("log_history", state.get("logs", []))

    steps, train_losses = [], []
    eval_steps, eval_losses = [], []
    lr_steps, lrs = [], []

    for item in log:
        step = item.get("step")
        if step is None:
            continue
        if "loss" in item:
            steps.append(step)
            train_losses.append(item["loss"])
        if "eval_loss" in item:
            eval_steps.append(step)
            eval_losses.append(item["eval_loss"])
        if "learning_rate" in item:
            lr_steps.append(step)
            lrs.append(item["learning_rate"])

    outdir = Path(args.outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    epoch_marks = None if args.no_epoch_marks else find_epoch_boundaries(log)
    checkpoint_steps = _parse_checkpoint_steps(args.checkpoint_steps)

    # (x values, y values, y-axis label, title, output file name)
    curves = (
        (steps, train_losses, "Training Loss", "Training Loss vs Step", "loss_curve.png"),
        (eval_steps, eval_losses, "Eval Loss", "Eval Loss vs Step", "eval_loss_curve.png"),
        (lr_steps, lrs, "Learning Rate", "Learning Rate vs Step", "lr_curve.png"),
    )
    for xs, ys, ylabel, title, filename in curves:
        if xs and ys:
            plot_series(xs, ys, "Step", ylabel, title, outdir / filename,
                        epoch_marks=epoch_marks, checkpoint_steps=checkpoint_steps)

    print(f"Saved plots to: {outdir.resolve()}")


if __name__ == "__main__":
    main()
plots/eval_loss_curve.png ADDED
plots/loss_curve.png ADDED
plots/lr_curve.png ADDED