# stanza-digphil/plot/loss.py
# Author: al1808th
# Last commit message: "finished final model comparison"
# Commit: 8117ca9
import matplotlib.pyplot as plt
import os
from pathlib import Path
import re
'''
Find the training loss values in the most recent log file, which appear in lines like:
2026-02-19 12:18:13 INFO: Finished STEP 6120/50000, loss = 0.491907 (0.263 sec/batch), lr: 2.000000
Extract them and plot them against the training step.
We also extract the dev-set LAS scores, which appear in lines like:
2026-02-19 12:25:34 INFO: Evaluating on dev set...
2026-02-19 12:25:42 INFO: LAS MLAS BLEX
2026-02-19 12:25:42 INFO: 94.27 93.26 94.90
Each set of scores is assigned to the most recent training step seen before it, and LAS is plotted on the same figure as the loss.
'''
LOGS = Path("logs/")

# Patterns are compiled once and reused for every line of the log.
STEP_RE = re.compile(r"STEP (\d+)/\d+")
LOSS_RE = re.compile(r"loss = ([\d.]+)")
# Three whitespace-separated decimals, e.g. "94.27 93.26 94.90" -> (LAS, MLAS, BLEX).
SCORES_RE = re.compile(r"(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)")


def find_latest_log(log_dir=LOGS) -> str:
    """Return the name of the most recent file in *log_dir* whose name starts with "log".

    Recency is judged by ``os.path.getctime`` (metadata-change time on Unix,
    creation time on Windows).

    Raises:
        FileNotFoundError: if *log_dir* contains no file whose name starts with "log".
    """
    log_dir = Path(log_dir)
    candidates = [name for name in os.listdir(log_dir) if name.startswith("log")]
    if not candidates:
        # Without this check, indexing an empty sorted list gives an opaque IndexError.
        raise FileNotFoundError(f"No log files starting with 'log' found in {log_dir}")
    return max(candidates, key=lambda name: os.path.getctime(log_dir / name))


def parse_log(lines) -> tuple[dict[int, float], dict[int, tuple[float, float, float]]]:
    """Extract training losses and dev-set scores from training log lines.

    Args:
        lines: any iterable of log lines, e.g. an open text file.

    Returns:
        ``(step_loss, step_scores)``. ``step_loss`` maps each step number from a
        "Finished STEP ..., loss = ..." line to its loss. ``step_scores`` maps the
        highest step seen so far to the ``(LAS, MLAS, BLEX)`` triple found on an
        "INFO: " line containing three whitespace-separated decimals (step 0 if
        no step has been seen yet). Parsing stops at the first line containing
        "Training ended".
    """
    step_loss = {}
    step_scores = {}
    latest_step = 0
    for line in lines:
        if "Training ended" in line:
            break
        if "Finished STEP" in line and "loss =" in line:
            step_match = STEP_RE.search(line)
            loss_match = LOSS_RE.search(line)
            if step_match and loss_match:
                step = int(step_match.group(1))
                step_loss[step] = float(loss_match.group(1))
                latest_step = max(latest_step, step)
            # Step lines never carry the three-decimal score pattern.
            continue
        # The "LAS MLAS BLEX" header has no decimals, so it never matches below.
        if "INFO: " not in line:
            continue
        scores_match = SCORES_RE.search(line)
        if scores_match:
            las, mlas, blex = (float(value) for value in scores_match.groups())
            step_scores[latest_step] = (las, mlas, blex)
    return step_loss, step_scores


def plot_loss_las(step_loss, step_scores, out_path) -> None:
    """Plot loss (left y-axis) and LAS (right y-axis) against training step and save to *out_path*.

    Args:
        step_loss: mapping step -> loss, as returned by ``parse_log``.
        step_scores: mapping step -> (LAS, MLAS, BLEX), as returned by ``parse_log``.
        out_path: destination image path; its parent directory is created if missing.
    """
    steps = sorted(step_loss)
    losses = [step_loss[step] for step in steps]
    las_steps = sorted(step_scores)
    las_values = [step_scores[step][0] for step in las_steps]

    # Shared x-axis, separate y-axes: loss and LAS have very different scales.
    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax2 = ax1.twinx()

    (loss_line,) = ax1.plot(steps, losses, label="Loss", color="tab:blue")
    ax1.set_xlabel("Training Step")
    ax1.set_ylabel("Loss", color="tab:blue")
    ax1.tick_params(axis="y", labelcolor="tab:blue")

    (las_line,) = ax2.plot(las_steps, las_values, label="LAS", color="red")
    ax2.set_ylabel("LAS", color="red")
    ax2.tick_params(axis="y", labelcolor="red")

    # The two lines live on different Axes, so collect both handles into one
    # legend, placed on ax2 so it is drawn on top of both lines.
    ax2.legend(handles=[loss_line, las_line], loc="best")

    ax1.set_title("Loss and LAS over Training Steps")
    ax1.grid()
    fig.tight_layout()

    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=400)
    plt.close(fig)


def main() -> None:
    """Parse the latest log in ``LOGS`` and save its loss/LAS plot under ``plot/``."""
    latest_log = find_latest_log(LOGS)
    print(f"Latest log: {latest_log}")
    # errors="replace": a stray non-UTF-8 byte should not abort parsing of numeric lines.
    with open(LOGS / latest_log, "r", encoding="utf-8", errors="replace") as f:
        step_loss, step_scores = parse_log(f)
    plot_loss_las(step_loss, step_scores, f"plot/plot_loss_las_{latest_log}.png")


if __name__ == "__main__":
    main()