#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read a trainer_log.jsonl where each JSON line has:
- "current_steps"
- "loss"
Plot and save a loss curve with:
- original loss
- smoothed loss (same EMA logic as in the provided code)
If there are duplicate current_steps (including "more than two identical"),
the LAST occurrence is used.
Usage:
python plot_loss.py \
--input /workspace/v125rc_exp1_Markie/D/trainer_log.jsonl \
--outdir /workspace/v125rc_exp1_Markie/D \
--outfile training_loss.png
"""
import json
import math
import os
import sys
def smooth(scalars: list[float]) -> list[float]:
    r"""Exponentially smooth a series the way TensorBoard does.

    The decay factor is not fixed: it is derived from the series length
    through a shifted, scaled sigmoid, so longer series are smoothed more
    strongly (the factor approaches 0.9). The running average is seeded
    with the first value of the series.

    Args:
        scalars: Raw values in plotting order.

    Returns:
        A new list of the same length holding the smoothed values, or an
        empty list when ``scalars`` is empty.
    """
    count = len(scalars)
    if count == 0:
        return []

    # Sigmoid of the series length, shifted to start at 0 and scaled by 1.8.
    decay = 1.8 * (1 / (1 + math.exp(-0.05 * count)) - 0.5)

    running = scalars[0]
    result = []
    for value in scalars:
        running = running * decay + (1 - decay) * value
        result.append(running)
    return result
def _import_matplotlib():
    """Import ``matplotlib.pyplot`` at call time and return the module.

    The import lives inside the function, so this file can be loaded even
    when matplotlib is unavailable; the failure only surfaces when plotting
    is actually requested.

    Returns:
        The ``matplotlib.pyplot`` module.

    Raises:
        RuntimeError: If importing ``matplotlib.pyplot`` fails for any
            reason; the original exception is chained as the cause.
    """
    try:
        import matplotlib.pyplot as pyplot  # type: ignore
    except Exception as err:
        message = "matplotlib is required to plot. Please install/enable matplotlib."
        raise RuntimeError(message) from err
    else:
        return pyplot
def read_steps_and_loss_from_jsonl(jsonl_path: str) -> tuple[list[int], list[float]]:
    """Read a JSONL training log and return ``(steps, losses)`` sorted by step.

    Each non-blank line is parsed as JSON. A line contributes a data point
    only if it is a JSON object containing both ``"current_steps"`` and
    ``"loss"``, the step converts with ``int()``, the loss converts with
    ``float()``, and the loss is finite. Every other line is skipped silently.

    When the same step appears several times, the last *valid* record for
    that step wins.

    Args:
        jsonl_path: Path to the JSONL log file (read as UTF-8).

    Returns:
        A pair of equally long lists: the distinct steps in ascending order
        and the loss recorded for each of them. Both lists are empty when the
        file holds no usable record.

    Raises:
        FileNotFoundError: If ``jsonl_path`` is not an existing regular file.
    """
    if not os.path.isfile(jsonl_path):
        raise FileNotFoundError(f"Input file not found: {jsonl_path}")

    # step -> loss; later assignments overwrite earlier ones (last one wins).
    step_to_loss: dict[int, float] = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Malformed lines are tolerated and skipped.
                continue
            if not isinstance(record, dict):
                continue
            if "current_steps" not in record or "loss" not in record:
                continue
            try:
                step = int(record["current_steps"])
                loss = float(record["loss"])
            except (TypeError, ValueError, OverflowError):
                # Non-numeric values (None, lists, text, inf/nan steps, ...).
                continue
            # json.loads accepts NaN/Infinity; a single non-finite loss would
            # propagate through every later value of the EMA smoothing.
            if not math.isfinite(loss):
                continue
            step_to_loss[step] = loss

    steps_sorted = sorted(step_to_loss)
    return steps_sorted, [step_to_loss[s] for s in steps_sorted]
def plot_and_save_loss_curve(
    jsonl_path: str,
    outdir: str,
    outfile: str = "training_loss.png",
) -> str:
    """Plot the original and smoothed loss curves and save them as a PNG.

    Closes all open pyplot figures and switches pyplot to the "agg" backend
    before drawing. ``outdir`` is created if it does not exist.

    Args:
        jsonl_path: Path to the JSONL training log to read.
        outdir: Directory the image is written to.
        outfile: File name of the image inside ``outdir``.

    Returns:
        The full path of the saved image.

    Raises:
        RuntimeError: If matplotlib cannot be imported, or if the log holds
            no valid (current_steps, loss) record.
        FileNotFoundError: If ``jsonl_path`` does not exist.
    """
    plt = _import_matplotlib()
    plt.close("all")
    plt.switch_backend("agg")

    steps, losses = read_steps_and_loss_from_jsonl(jsonl_path)
    if not losses:
        raise RuntimeError("No valid (current_steps, loss) records found to plot.")

    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, outfile)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
        ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
        ax.legend()
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        # Title shows the output directory. Derive it from the path itself
        # instead of stripping a hard-coded "/training_loss.png", which breaks
        # for a custom outfile or a non-"/" path separator.
        ax.set_title(f"training loss of {os.path.dirname(outpath)}")
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(outpath, format="png", dpi=100)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)

    print("Figure saved at:", outpath)
    return outpath
def _parse_args(argv: list[str]) -> dict[str, str]:
    """Parse the command line with a minimal hand-rolled parser.

    Supported options, each accepted as ``--opt VALUE`` or ``--opt=VALUE``:
        --input PATH     (required)
        --outdir DIR     (required)
        --outfile NAME   (defaults to "training_loss.png")

    ``-h`` / ``--help`` prints the module docstring and exits with status 0.
    Unrecognized tokens are ignored. When an option is repeated, the last
    value wins.

    Args:
        argv: Command-line tokens without the program name.

    Returns:
        A dict keyed by the option names (including the leading dashes).

    Raises:
        ValueError: If an option in the space-separated form has no value
            after it, or if ``--input`` or ``--outdir`` ends up empty.
    """
    args = {"--input": "", "--outdir": "", "--outfile": "training_loss.png"}
    i = 0
    while i < len(argv):
        token = argv[i]
        if token in ("-h", "--help"):
            # __doc__ is None when docstrings are stripped (python -OO).
            print((__doc__ or "").strip())
            sys.exit(0)
        name, sep, inline_value = token.partition("=")
        if token in args:
            # "--opt VALUE": the value is the next token.
            if i + 1 >= len(argv):
                raise ValueError(f"Missing value for {token}")
            args[token] = argv[i + 1]
            i += 2
        elif sep and name in args:
            # "--opt=VALUE": the value is carried inside the same token.
            args[name] = inline_value
            i += 1
        else:
            i += 1
    if not args["--input"] or not args["--outdir"]:
        raise ValueError("Required: --input PATH and --outdir DIR")
    return args
def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and render the loss plot."""
    options = _parse_args(sys.argv[1:])
    # An explicitly empty --outfile falls back to the default file name.
    image_name = options["--outfile"] or "training_loss.png"
    plot_and_save_loss_curve(
        jsonl_path=options["--input"],
        outdir=options["--outdir"],
        outfile=image_name,
    )


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()