File size: 4,866 Bytes
43a4147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Read a trainer_log.jsonl where each JSON line has:
  - "current_steps"
  - "loss"

Plot and save a loss curve with:
  - the original (raw) loss
  - the smoothed loss (TensorBoard-style EMA whose weight grows with the
    number of points)

If a current_steps value appears more than once (any number of times),
only its LAST occurrence is used.

Usage:
  python plot_loss.py \
    --input /workspace/v125rc_exp1_Markie/D/trainer_log.jsonl \
    --outdir /workspace/v125rc_exp1_Markie/D \
    --outfile training_loss.png
"""

import json
import math
import os
import sys


def smooth(scalars: list[float]) -> list[float]:
    r"""Return a TensorBoard-style exponential moving average of *scalars*.

    The smoothing weight is not a fixed constant: it is derived from the
    series length ``n`` as ``1.8 * (sigmoid(0.05 * n) - 0.5)``, so longer
    series are smoothed more heavily. The first value seeds the average.

    Args:
        scalars: raw values to smooth, in plotting order.

    Returns:
        A new list with one smoothed value per input value; ``[]`` for an
        empty input.
    """
    count = len(scalars)
    if count == 0:
        return []

    # Scaled sigmoid of the series length (see docstring).
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * count)) - 0.5)
    keep = 1 - weight

    out: list[float] = []
    running = scalars[0]
    for value in scalars:
        running = running * weight + keep * value
        out.append(running)
    return out


def _import_matplotlib():
    try:
        import matplotlib.pyplot as plt  # type: ignore
    except Exception as e:
        raise RuntimeError(
            "matplotlib is required to plot. Please install/enable matplotlib."
        ) from e
    return plt


def read_steps_and_loss_from_jsonl(jsonl_path: str) -> tuple[list[int], list[float]]:
    """
    Read ``(current_steps, loss)`` pairs from a JSON-lines trainer log.

    Each non-blank line is parsed as JSON. A line is skipped when it is
    malformed, is not a JSON object, lacks either key, or has values that
    ``int()`` / ``float()`` cannot convert. Non-finite losses (NaN, +/-inf,
    which Python's ``json`` module accepts) are skipped as well: a single
    NaN would otherwise propagate through the EMA in ``smooth`` and blank
    out every later smoothed point.

    When a step appears more than once, the LAST valid occurrence wins.

    Args:
        jsonl_path: path to the ``.jsonl`` log file.

    Returns:
        ``(steps, losses)`` as two parallel lists sorted by ascending step;
        both are empty when no valid record was found.

    Raises:
        FileNotFoundError: if *jsonl_path* is not an existing regular file.
    """
    if not os.path.isfile(jsonl_path):
        raise FileNotFoundError(f"Input file not found: {jsonl_path}")

    # step -> loss (last valid one wins)
    step_to_loss: dict[int, float] = {}

    # errors="replace": a corrupted/truncated log with invalid UTF-8 bytes
    # should cost at most the affected line, not abort the whole read.
    with open(jsonl_path, "r", encoding="utf-8", errors="replace") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue  # skip malformed lines
            if not isinstance(obj, dict):
                continue

            step = obj.get("current_steps")
            loss = obj.get("loss")
            if step is None or loss is None:
                continue

            # Must be numeric-ish (ints/floats or numeric strings).
            try:
                step_int = int(step)
                loss_float = float(loss)
            except (TypeError, ValueError, OverflowError):
                continue

            if not math.isfinite(loss_float):
                continue

            step_to_loss[step_int] = loss_float

    steps_sorted = sorted(step_to_loss)
    losses_sorted = [step_to_loss[s] for s in steps_sorted]
    return steps_sorted, losses_sorted


def plot_and_save_loss_curve(
    jsonl_path: str,
    outdir: str,
    outfile: str = "training_loss.png",
) -> str:
    """
    Plot the original and smoothed loss curves and save them as a PNG.

    Args:
        jsonl_path: trainer log to read (parsed by
            ``read_steps_and_loss_from_jsonl``).
        outdir: directory to write into; created if missing. An empty
            string means the current working directory.
        outfile: image file name, joined onto *outdir*.

    Returns:
        The full path of the saved image (``os.path.join(outdir, outfile)``).

    Raises:
        RuntimeError: if matplotlib cannot be imported, or if the log holds
            no valid ``(current_steps, loss)`` records.
        FileNotFoundError: if *jsonl_path* does not exist.
    """
    plt = _import_matplotlib()
    plt.close("all")
    plt.switch_backend("agg")  # non-interactive backend: render straight to file

    steps, losses = read_steps_and_loss_from_jsonl(jsonl_path)
    if not losses:
        raise RuntimeError("No valid (current_steps, loss) records found to plot.")

    outpath = os.path.join(outdir, outfile)
    target_dir = os.path.dirname(outpath)
    # Create the image's parent directory (covers outdir and any
    # sub-directory inside outfile); os.makedirs("") would raise.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
        ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
        ax.legend()
        ax.set_xlabel("step")
        ax.set_ylabel("loss")
        # Title names the output directory, independent of the file name
        # and of the platform's path separator.
        ax.set_title(f"training loss of {target_dir or os.curdir}")
        fig.savefig(outpath, format="png", dpi=100)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print("Figure saved at:", outpath)
    return outpath


def _parse_args(argv: list[str]) -> dict[str, str]:
    """
    Minimal arg parser (no external imports).
    Supports:
      --input PATH
      --outdir DIR
      --outfile NAME
    """
    args = {"--input": "", "--outdir": "", "--outfile": "training_loss.png"}
    i = 0
    while i < len(argv):
        a = argv[i]
        if a in ("-h", "--help"):
            print(__doc__.strip())
            sys.exit(0)
        if a in args:
            if i + 1 >= len(argv):
                raise ValueError(f"Missing value for {a}")
            args[a] = argv[i + 1]
            i += 2
        else:
            i += 1
    if not args["--input"] or not args["--outdir"]:
        raise ValueError("Required: --input PATH and --outdir DIR")
    return args


def main() -> None:
    """Command-line entry point: parse ``sys.argv`` and render the loss plot."""
    options = _parse_args(sys.argv[1:])
    # An explicitly empty --outfile value falls back to the default name.
    target_name = options["--outfile"] or "training_loss.png"
    plot_and_save_loss_curve(
        jsonl_path=options["--input"],
        outdir=options["--outdir"],
        outfile=target_name,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()