#!/usr/bin/env python3
# Source snapshot: 11235a4 (2,220 bytes).
"""
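plot_loss_curve.py -- plot the training-loss curve from a Trainer JSON log.

Usage:
    python plot_loss_curve.py trainer_log.json --output loss_curve.png

Saving the PNG requires the kaleido package (`pip install -U kaleido`).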
"""
import json
import argparse
import plotly.express as px
def load_log(path):
    """Load a training log and return its list of logged entries.

    The file must be JSON whose top level is either an object holding a
    ``"log_history"`` list (the layout written by the Trainer) or a bare list
    of entry dicts.

    Args:
        path: Path to the JSON log file.

    Returns:
        The list of log entries. Returns an empty list when the object has no
        ``"log_history"`` key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Accept a bare list of entries as well as the {"log_history": [...]} layout.
    if isinstance(data, list):
        return data
    return data.get("log_history", [])
def build_dataframe(history):
    """Collect step/epoch/loss columns from the loss-bearing log entries.

    Only entries with a non-null ``"loss"`` value are kept. Entries without
    one (for example evaluation or summary records, if the log contains
    them) would otherwise put ``None`` into the loss column and leave gaps in
    the plotted line.

    Args:
        history: Iterable of log-entry dicts.

    Returns:
        Dict of equal-length lists keyed by ``"step"``, ``"epoch"`` and
        ``"loss"``, suitable for ``plotly.express``. You can extend it with
        other keys (lr, grad_norm, ...) the same way.
    """
    # `is not None` (not truthiness) so a legitimate loss of 0.0 is kept.
    loss_entries = [entry for entry in history if entry.get("loss") is not None]
    return {
        "step": [entry.get("step") for entry in loss_entries],
        "epoch": [entry.get("epoch") for entry in loss_entries],
        "loss": [entry.get("loss") for entry in loss_entries],
    }
def plot_loss(curve_data, output_png, label, *, show=True):
    """Plot the loss curve interactively and save it as a static image.

    Args:
        curve_data: Mapping with ``"step"``, ``"epoch"`` and ``"loss"`` columns
            (as returned by ``build_dataframe``).
        output_png: Destination path for the static image.
        label: Run label shown on a second title line; omitted when empty.
        show: If True (default), also display the interactive figure via
            ``fig.show()``. Pass False for headless/batch runs.
    """
    title = "Training Loss Curve"
    if label:
        # Plotly titles are HTML-like text: "\n" does not break the line, "<br>" does.
        title = f"{title}<br>{label}"

    fig = px.line(
        curve_data,
        x="step",
        y="loss",
        hover_data={"epoch": True, "loss": True},
        title=title,
        labels={"step": "Global Step", "loss": "Loss"},
    )
    if show:
        # Interactive display only works in environments that support it.
        fig.show()
    # Static export requires the kaleido package (`pip install -U kaleido`).
    fig.write_image(output_png)
    print(f"✔️ Saved loss curve PNG to {output_png}")
def main():
    """Command-line entry point: parse arguments, load the log, plot the loss."""
    from pathlib import Path  # only needed here, to derive the run label

    parser = argparse.ArgumentParser(description="Plot training loss curve from Trainer JSON logs.")
    parser.add_argument(
        "input",
        help="Path to the JSON file output by the Trainer (with log_history).",
    )
    parser.add_argument(
        "--output", "-o",
        default="loss_curve.png",
        help="Filename for the saved loss curve PNG.",
    )
    args = parser.parse_args()

    history = load_log(args.input)
    if not history:
        print("⚠️ No entries found under 'log_history'. Exiting.")
        return

    curve_data = build_dataframe(history)
    # Nothing to plot if no entry carries a loss value (e.g. an eval-only log).
    if all(loss is None for loss in curve_data["loss"]):
        print("⚠️ No 'loss' values found in 'log_history'. Exiting.")
        return

    # Label the plot with the (up to) two directories that contain the log file,
    # e.g. runs/exp1/seed0/trainer_log.json -> "exp1-seed0". pathlib handles
    # OS-specific separators; the root/drive anchor is not part of the label.
    parent = Path(args.input).parent
    run_dirs = [part for part in parent.parts if part != parent.anchor]
    label = "-".join(run_dirs[-2:])
    plot_loss(curve_data, args.output, label)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|