|
|
|
|
|
""" |
|
|
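plot_loss_curve.py

Plot the training-loss curve stored under "log_history" in a Trainer JSON log,
show it interactively, and save it as a PNG.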
|
|
|
|
|
Usage: |
|
|
    python plot_loss_curve.py trainer_log.json --output loss_curve.png
|
|
""" |
|
|
|
|
|
import argparse
import json
from pathlib import Path

import plotly.express as px
|
|
|
|
|
def load_log(path):
    """Load a Trainer JSON log and return its ``log_history`` entries.

    Args:
        path: Path to a JSON file. Normally the top level is a JSON object
            with a ``"log_history"`` key; a file whose top level is already
            a list of log entries is accepted as the history itself.

    Returns:
        The value stored under ``"log_history"`` (or the top-level list).
        An empty list if the key is absent.
    """
    # JSON is UTF-8 by specification; relying on the platform's default
    # encoding (locale-dependent) can fail on non-ASCII content.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Some tools dump only the history list; treat that as the history.
    if isinstance(data, list):
        return data
    return data.get("log_history", [])
|
|
|
|
|
def build_dataframe(history):
    """Collect the training-loss points from *history* into column lists.

    Entries without a ``"loss"`` value are skipped. In Trainer logs these are
    typically evaluation or end-of-run summary records, which carry other
    metrics instead. Skipping them keeps ``None`` y-values (which break the
    plotted line into gaps) out of the curve.

    Args:
        history: Iterable of dict log entries.

    Returns:
        A dict of equal-length lists ``{"step": [...], "epoch": [...],
        "loss": [...]}`` suitable for ``plotly.express.line``. Missing
        ``"step"``/``"epoch"`` values are kept as ``None``.
    """
    # Single pass over history; also works if history is a one-shot iterator.
    points = [entry for entry in history if entry.get("loss") is not None]
    return {
        "step": [entry.get("step") for entry in points],
        "epoch": [entry.get("epoch") for entry in points],
        "loss": [entry["loss"] for entry in points],
    }
|
|
|
|
|
def plot_loss(curve_data, output_png, label): |
|
|
"""Plot an interactive loss curve and save a static PNG.""" |
|
|
|
|
|
fig = px.line( |
|
|
curve_data, |
|
|
x="step", |
|
|
y="loss", |
|
|
hover_data={"epoch":True, "loss":True}, |
|
|
title=f"Training Loss Curve\n{label}", |
|
|
labels={"step": "Global Step", "loss": "Loss"} |
|
|
) |
|
|
|
|
|
fig.show() |
|
|
|
|
|
fig.write_image(output_png) |
|
|
print(f"✔️ Saved loss curve PNG to {output_png}") |
|
|
|
|
|
def main(): |
|
|
parser = argparse.ArgumentParser(description="Plot training loss curve from Trainer JSON logs.") |
|
|
parser.add_argument( |
|
|
"input", |
|
|
help="Path to the JSON file output by the Trainer (with log_history)." |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output", "-o", |
|
|
default="loss_curve.png", |
|
|
help="Filename for the saved loss curve PNG." |
|
|
) |
|
|
args = parser.parse_args() |
|
|
|
|
|
history = load_log(args.input) |
|
|
if not history: |
|
|
print("⚠️ No entries found under 'log_history'. Exiting.") |
|
|
return |
|
|
|
|
|
curve_data = build_dataframe(history) |
|
|
l = args.input.split('/')[-3:-1] |
|
|
l = '-'.join(l) |
|
|
plot_loss(curve_data, args.output, l) |
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|
|