# gr00t1.5_starforce/tests/plot_loss.py
# Uploaded by nnh-pbbb via the upload-large-folder tool (commit 11235a4, verified).
#!/usr/bin/env python3
"""
plot_loss.py — plot the training loss curve from a Trainer JSON log.
Usage:
    python plot_loss.py trainer_log.json --output loss_curve.png
"""
import argparse
import json
from pathlib import Path

import plotly.express as px
def load_log(path):
    """Load a Trainer JSON log and return its list of log entries.

    Args:
        path: Path to a JSON file. Normally the top level is an object with a
            ``"log_history"`` list of dicts; a file whose top level is already
            that list is accepted as well.

    Returns:
        The list of log entries, or an empty list when the object has no
        ``"log_history"`` key (or it is null).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        ValueError: If the top-level JSON value is neither an object nor a list.
    """
    # Explicit encoding: the platform default is locale-dependent.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # A bare list is treated as the history itself.
    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        raise ValueError(
            f"expected a JSON object or list in {path}, got {type(data).__name__}"
        )
    # `or []` also covers an explicit `"log_history": null`.
    return data.get("log_history") or []
def build_dataframe(history):
    """Convert Trainer log entries into column lists suitable for plotting.

    Only entries with a non-null ``"loss"`` value are kept. Other entries in a
    Trainer log (for instance evaluation or end-of-training summary records)
    may lack a ``"loss"`` key. Keeping them would insert ``None`` points and
    duplicate steps into the curve.

    Args:
        history: Iterable of dicts (log entries).

    Returns:
        A dict with equal-length ``"step"``, ``"epoch"`` and ``"loss"`` lists.
        ``"step"`` and ``"epoch"`` are ``None`` where an entry lacks them.
        Extend here for other columns (lr, grad_norm, ...).
    """
    rows = [entry for entry in history if entry.get("loss") is not None]
    return {
        "step": [entry.get("step") for entry in rows],
        "epoch": [entry.get("epoch") for entry in rows],
        "loss": [entry["loss"] for entry in rows],
    }
def plot_loss(curve_data, output_png, label, show=True):
    """Plot the loss curve with Plotly Express and save it as a static image.

    Args:
        curve_data: Mapping with ``"step"``, ``"epoch"`` and ``"loss"`` lists,
            as produced by ``build_dataframe``.
        output_png: Destination path passed to ``fig.write_image``
            (requires ``pip install -U kaleido``).
        label: Text shown on a second line of the chart title.
        show: If true (the default), also open the interactive figure after
            saving. Pass ``False`` for headless runs.
    """
    fig = px.line(
        curve_data,
        x="step",
        y="loss",
        hover_data={"epoch": True, "loss": True},
        # Plotly renders titles as HTML-like text: a line break must be
        # written as <br>; a literal newline character is not displayed.
        title=f"Training Loss Curve<br>{label}",
        labels={"step": "Global Step", "loss": "Loss"},
    )
    # Save first so the PNG is written even if the interactive display fails
    # (e.g. no browser available).
    fig.write_image(output_png)
    print(f"✔️ Saved loss curve PNG to {output_png}")
    if show:
        fig.show()
def main():
    """Command-line entry point: read a Trainer JSON log and plot its loss curve.

    The log file is a positional argument, and ``--output``/``-o`` names the
    PNG (default ``loss_curve.png``). The chart subtitle is the two directory
    names directly above the input file, joined with ``-``.
    """
    parser = argparse.ArgumentParser(description="Plot training loss curve from Trainer JSON logs.")
    parser.add_argument(
        "input",
        help="Path to the JSON file output by the Trainer (with log_history).",
    )
    parser.add_argument(
        "--output", "-o",
        default="loss_curve.png",
        help="Filename for the saved loss curve PNG.",
    )
    args = parser.parse_args()

    try:
        history = load_log(args.input)
    except (OSError, ValueError) as err:
        # json.JSONDecodeError is a ValueError subclass; report cleanly
        # instead of dumping a traceback (parser.error exits the process).
        parser.error(f"could not read {args.input}: {err}")

    if not history:
        print("⚠️ No entries found under 'log_history'. Exiting.")
        return

    curve_data = build_dataframe(history)
    if not any(loss is not None for loss in curve_data["loss"]):
        print("⚠️ No 'loss' values found in 'log_history'. Exiting.")
        return

    # Subtitle from the two parent directories, e.g.
    # "runs/exp1/ckpt-500/trainer_state.json" -> "exp1-ckpt-500".
    # Path.parts splits on the platform's separators, unlike str.split('/').
    label = "-".join(Path(args.input).parts[-3:-1])
    plot_loss(curve_data, args.output, label)


if __name__ == "__main__":
    main()