#!/usr/bin/env python3
"""
plot_loss_curve.py

Plot the training-loss curve stored in a Trainer JSON log and save it as a PNG.

Usage:
    python plot_loss_curve.py trainer_log.json --output loss_curve.png
"""
import argparse
import json
from pathlib import Path


def load_log(path):
    """Load a JSON training log and return its ``log_history`` list.

    Args:
        path: Path to the JSON file written by the Trainer.

    Returns:
        The value stored under the top-level ``"log_history"`` key, or an
        empty list when that key is absent.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    # Expecting a top-level "log_history" list of dicts
    return data.get("log_history", [])


def build_dataframe(history):
    """Collect step/epoch/loss columns from the loss-bearing log entries.

    Only entries with a non-None ``"loss"`` value are kept. Other records
    in the history (for example evaluation or end-of-run summary entries, if
    the log contains any) would otherwise contribute ``None`` losses that
    split the plotted line into gaps.

    Args:
        history: Sequence of dicts, as returned by ``load_log``.

    Returns:
        A dict of equal-length lists keyed by ``"step"``, ``"epoch"`` and
        ``"loss"``, suitable as Plotly Express input. Missing ``step`` or
        ``epoch`` values are recorded as ``None``.
    """
    # You can extend this to lr, grad_norm, etc.
    train_entries = [entry for entry in history if entry.get("loss") is not None]
    return {
        "step": [entry.get("step") for entry in train_entries],
        "epoch": [entry.get("epoch") for entry in train_entries],
        "loss": [entry["loss"] for entry in train_entries],
    }


def plot_loss(curve_data, output_png, label):
    """Plot an interactive loss curve and save a static PNG.

    Args:
        curve_data: Column dict as produced by ``build_dataframe``.
        output_png: Destination path for the static image.
        label: Text shown on the second line of the plot title.
    """
    # Imported lazily: plotly is only needed for rendering, so the parsing
    # helpers can be imported (and `--help` shown) without it installed.
    import plotly.express as px

    fig = px.line(
        curve_data,
        x="step",
        y="loss",
        hover_data={"epoch": True, "loss": True},
        # Plotly renders title text as HTML: "\n" does not break the line, <br> does.
        title=f"Training Loss Curve<br>{label}",
        labels={"step": "Global Step", "loss": "Loss"},
    )
    # Show interactive window (in environments that support it)
    fig.show()
    # Save a static image (requires `pip install -U kaleido`)
    fig.write_image(output_png)
    print(f"✔️ Saved loss curve PNG to {output_png}")


def _run_label(input_path):
    """Join the names of the two directories enclosing *input_path* with '-'."""
    # e.g. "runs/exp1/checkpoint-500/trainer_state.json" -> "exp1-checkpoint-500".
    # pathlib splits on the platform's separators, unlike str.split('/').
    return "-".join(Path(input_path).parent.parts[-2:])


def main():
    """Parse CLI arguments, load the log, and plot/save the loss curve."""
    parser = argparse.ArgumentParser(
        description="Plot training loss curve from Trainer JSON logs."
    )
    parser.add_argument(
        "input",
        help="Path to the JSON file output by the Trainer (with log_history).",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="loss_curve.png",
        help="Filename for the saved loss curve PNG.",
    )
    args = parser.parse_args()

    history = load_log(args.input)
    if not history:
        print("⚠️ No entries found under 'log_history'. Exiting.")
        return

    curve_data = build_dataframe(history)
    if not curve_data["loss"]:
        print("⚠️ No training 'loss' entries found in 'log_history'. Exiting.")
        return

    plot_loss(curve_data, args.output, _run_label(args.input))


if __name__ == "__main__":
    main()