Linksome's picture
Add files using upload-large-folder tool
43a4147 verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Read a trainer_log.jsonl where each JSON line has:
- "current_steps"
- "loss"
Plot and save a loss curve with:
- original loss
- smoothed loss (same EMA logic as in the provided code)
If there are duplicate current_steps (including "more than two identical"),
the LAST occurrence is used.
Usage:
python plot_loss.py \
--input /workspace/v125rc_exp1_Markie/D/trainer_log.jsonl \
--outdir /workspace/v125rc_exp1_Markie/D \
--outfile training_loss.png
"""
import json
import math
import os
import sys
def smooth(scalars: list[float]) -> list[float]:
    r"""Smooth a series with a TensorBoard-style exponential moving average.

    The decay weight is computed once from the length of the series through
    a rescaled sigmoid, so longer series are smoothed more heavily. The
    running average is seeded with the first element of the series.

    Args:
        scalars: Raw values, in plotting order.

    Returns:
        A new list with one smoothed value per input value; an empty list
        when ``scalars`` is empty. The input is not modified.
    """
    count = len(scalars)
    if count == 0:
        return []

    # Sigmoid of the series length, rescaled so the weight grows from 0
    # towards 0.9 as the series gets longer.
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * count)) - 0.5)

    running = scalars[0]
    result = []
    for value in scalars:
        running = running * weight + (1 - weight) * value
        result.append(running)
    return result
def _import_matplotlib():
try:
import matplotlib.pyplot as plt # type: ignore
except Exception as e:
raise RuntimeError(
"matplotlib is required to plot. Please install/enable matplotlib."
) from e
return plt
def read_steps_and_loss_from_jsonl(jsonl_path: str) -> tuple[list[int], list[float]]:
    """Read ``(current_steps, loss)`` pairs from a JSON-lines trainer log.

    Parsing is best-effort; a line is silently skipped when it:

    - is blank or is not valid JSON,
    - does not decode to a JSON object,
    - lacks the ``"current_steps"`` or ``"loss"`` key,
    - has a step that ``int()`` rejects or a loss that ``float()`` rejects,
    - has a non-finite loss (NaN or +/-Infinity).

    Non-finite losses are dropped because a single NaN fed into an
    exponential moving average turns every later smoothed value into NaN.

    When the same step appears more than once, the last usable record for
    that step wins.

    Args:
        jsonl_path: Path to the ``.jsonl`` log file (read as UTF-8).

    Returns:
        ``(steps, losses)`` sorted by ascending step, with ``losses[i]``
        belonging to ``steps[i]``. Both lists are empty when the file holds
        no usable record.

    Raises:
        FileNotFoundError: If ``jsonl_path`` is not an existing file.
    """
    if not os.path.isfile(jsonl_path):
        raise FileNotFoundError(f"Input file not found: {jsonl_path}")

    # step -> loss; later lines overwrite earlier ones (last one wins).
    step_to_loss: dict[int, float] = {}
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                # Malformed line: skip it and keep going.
                continue
            if not isinstance(obj, dict):
                continue
            if "current_steps" not in obj or "loss" not in obj:
                continue
            try:
                step = int(obj["current_steps"])
                loss = float(obj["loss"])
            except (TypeError, ValueError, OverflowError):
                # TypeError: null / list / object values.
                # ValueError: non-numeric strings, NaN step.
                # OverflowError: infinite step, or an int too large for float.
                continue
            if not math.isfinite(loss):
                # json.loads accepts NaN / Infinity literals; they would
                # poison the smoothed curve, so drop them here.
                continue
            step_to_loss[step] = loss

    steps_sorted = sorted(step_to_loss)
    losses_sorted = [step_to_loss[s] for s in steps_sorted]
    return steps_sorted, losses_sorted
def plot_and_save_loss_curve(
    jsonl_path: str,
    outdir: str,
    outfile: str = "training_loss.png",
) -> str:
    """Plot the raw and smoothed loss curves and save them as a PNG.

    Note that this closes all existing pyplot figures and switches the
    pyplot backend to ``agg`` before drawing, so it affects global
    matplotlib state in the calling process.

    Args:
        jsonl_path: Path to the trainer log in JSON-lines format.
        outdir: Directory to write the image into; created if missing.
        outfile: File name of the image inside ``outdir``.

    Returns:
        The full path of the saved image (``outdir`` joined with ``outfile``).

    Raises:
        RuntimeError: If matplotlib cannot be imported, or if the log holds
            no valid ``(current_steps, loss)`` record.
        FileNotFoundError: If ``jsonl_path`` does not exist.
    """
    plt = _import_matplotlib()
    plt.close("all")
    plt.switch_backend("agg")

    steps, losses = read_steps_and_loss_from_jsonl(jsonl_path)
    if not losses:
        raise RuntimeError("No valid (current_steps, loss) records found to plot.")

    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, outfile)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
        ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
        ax.legend()
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        # Title names the directory holding the image. Derive it from the
        # path rather than stripping a hard-coded "/training_loss.png", so
        # custom file names and non-"/" separators work too.
        ax.set_title(f"training loss of {os.path.dirname(outpath)}")
        # Save this figure explicitly instead of pyplot's "current" figure.
        fig.savefig(outpath, format="png", dpi=100)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print("Figure saved at:", outpath)
    return outpath
def _parse_args(argv: list[str]) -> dict[str, str]:
"""
Minimal arg parser (no external imports).
Supports:
--input PATH
--outdir DIR
--outfile NAME
"""
args = {"--input": "", "--outdir": "", "--outfile": "training_loss.png"}
i = 0
while i < len(argv):
a = argv[i]
if a in ("-h", "--help"):
print(__doc__.strip())
sys.exit(0)
if a in args:
if i + 1 >= len(argv):
raise ValueError(f"Missing value for {a}")
args[a] = argv[i + 1]
i += 2
else:
i += 1
if not args["--input"] or not args["--outdir"]:
raise ValueError("Required: --input PATH and --outdir DIR")
return args
def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and write the loss plot."""
    options = _parse_args(sys.argv[1:])
    input_path = options["--input"]
    output_dir = options["--outdir"]

    # Fall back to the default file name when an empty one was supplied.
    output_name = options["--outfile"]
    if not output_name:
        output_name = "training_loss.png"

    plot_and_save_loss_curve(
        jsonl_path=input_path,
        outdir=output_dir,
        outfile=output_name,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()