| import matplotlib.pyplot as plt |
| import os |
| from pathlib import Path |
| import re |
|
|
'''
Find loss values in the latest log file, in lines like the one below:

2026-02-19 12:18:13 INFO: Finished STEP 6120/50000, loss = 0.491907 (0.263 sec/batch), lr: 2.000000

Extract them and plot them over training steps.

We also find the LAS scores in lines like:

2026-02-19 12:25:34 INFO: Evaluating on dev set...
2026-02-19 12:25:42 INFO: LAS MLAS BLEX
2026-02-19 12:25:42 INFO: 94.27 93.26 94.90

Each set of scores is attributed to the latest training step seen so far, and LAS is plotted over training steps as well.
'''
|
|
# Directory holding the training logs. Only regular files whose name starts
# with "log" are considered; a sub-directory with such a name would otherwise
# be picked up and fail later when opened.
LOGS = Path("logs/")
logs = [
    name
    for name in os.listdir(LOGS)
    if name.startswith("log") and (LOGS / name).is_file()
]
if not logs:
    # Fail with an explicit message instead of an IndexError from indexing
    # an empty list.
    raise FileNotFoundError(f"No log files starting with 'log' found in {LOGS}")

# Pick the newest log by ctime.
# NOTE: on Unix, ctime is the time of the last metadata change, not the
# creation time; on Windows it is the creation time.
latest_log = max(logs, key=lambda name: os.path.getctime(LOGS / name))
print(f"Latest log: {latest_log}")
|
|
# Patterns are compiled once instead of on every line of the log.
# Step counter in e.g. "Finished STEP 6120/50000, loss = 0.491907 (...)".
_STEP_RE = re.compile(r"STEP (\d+)/\d+")
# Loss value; also accepts a sign and scientific notation (e.g. 1e-05), which
# a plain digits-and-dots pattern would silently truncate to "1".
_LOSS_RE = re.compile(r"loss = ([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)")
# A score line is an INFO message consisting of exactly three decimals
# (LAS, MLAS, BLEX). Anchoring to the end of the line keeps unrelated INFO
# lines that merely contain three decimals from being recorded as scores.
_SCORES_RE = re.compile(r"INFO:\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s*$")


def _parse_training_log(lines):
    """Extract per-step losses and dev scores from training log lines.

    Parsing stops at the first line containing "Training ended".

    Args:
        lines: Iterable of log lines (e.g. an open text file).

    Returns:
        A tuple ``(step_loss, step_scores, last_step)`` where ``step_loss``
        maps step -> loss, ``step_scores`` maps step -> (LAS, MLAS, BLEX),
        and ``last_step`` is the highest step seen (0 if none).
    """
    step_loss = {}
    step_scores = {}
    last_step = 0
    for line in lines:
        if "Training ended" in line:
            break
        if "Finished STEP" in line and "loss =" in line:
            step_match = _STEP_RE.search(line)
            loss_match = _LOSS_RE.search(line)
            if step_match and loss_match:
                step = int(step_match.group(1))
                step_loss[step] = float(loss_match.group(1))
                last_step = max(last_step, step)
            continue
        scores_match = _SCORES_RE.search(line)
        if scores_match:
            # Scores carry no step of their own, so attribute them to the
            # most recent training step seen so far.
            las, mlas, blex = (float(v) for v in scores_match.groups())
            step_scores[last_step] = (las, mlas, blex)
    return step_loss, step_scores, last_step


# Decode explicitly as UTF-8 (replacing undecodable bytes) so parsing does not
# depend on the platform's default encoding.
with open(LOGS / latest_log, "r", encoding="utf-8", errors="replace") as f:
    step_loss_dict, step_las_dict, latest_step = _parse_training_log(f)
|
|
|
|
| |
# Loss series: steps in ascending order with the loss recorded at each step.
steps = list(step_loss_dict)
steps.sort()
losses = list(map(step_loss_dict.__getitem__, steps))

# LAS series: only the first element (LAS) of each (LAS, MLAS, BLEX) tuple is
# kept, ordered by the step the scores were attributed to.
las_steps = list(step_las_dict)
las_steps.sort()
las_values = [scores[0] for scores in map(step_las_dict.__getitem__, las_steps)]
|
|
| |
# One figure with two y-axes sharing the training-step x-axis:
# loss on the left axis, LAS on the right axis.
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()

# Left axis: training loss.
ax1.plot(steps, losses, label="Loss", color="tab:blue")
ax1.set_xlabel("Training Step")
ax1.set_ylabel("Loss", color="tab:blue")
ax1.tick_params(axis="y", labelcolor="tab:blue")

# Right axis: LAS on the dev set.
ax2.plot(las_steps, las_values, label="LAS", color="red")
ax2.set_ylabel("LAS", color="red")
ax2.tick_params(axis="y", labelcolor="red")

ax1.set_title("Loss and LAS over Training Steps")
ax1.grid()
fig.tight_layout()

# savefig does not create missing directories, so make sure the output
# directory exists before writing the image.
Path("plot").mkdir(parents=True, exist_ok=True)
fig.savefig(f"plot/plot_loss_las_{latest_log}.png", dpi=400)