File size: 2,591 Bytes
035e51c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8117ca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import matplotlib.pyplot as plt
import os
from pathlib import Path
import re

'''
Find loss values in the latest log file, in lines like the following:

2026-02-19 12:18:13 INFO: Finished STEP 6120/50000, loss = 0.491907 (0.263 sec/batch), lr: 2.000000

Extract them and plot over time.

We also find the LAS scores in lines like:

2026-02-19 12:25:34 INFO: Evaluating on dev set...
2026-02-19 12:25:42 INFO: LAS   MLAS    BLEX
2026-02-19 12:25:42 INFO: 94.27 93.26   94.90

Each set of scores is attributed to the most recent training step seen before it, and the LAS value is plotted over the training steps as well.
'''

# Directory holding the training logs. Only entries whose name starts with
# "log" are treated as log files.
LOGS = Path("logs/")
logs = [name for name in os.listdir(LOGS) if name.startswith("log")]
if not logs:
    # Fail with an explicit message instead of an IndexError from indexing
    # an empty list.
    raise FileNotFoundError(f"No log files (names starting with 'log') found in {LOGS}")

# Select the entry with the greatest ctime.
# NOTE: os.path.getctime is the last metadata change time on Unix and the
# creation time on Windows, so "latest" is platform-dependent.
latest_log = max(logs, key=lambda name: os.path.getctime(LOGS / name))
print(f"Latest log: {latest_log}")

# Maps training step -> loss reported on that step's "Finished STEP" line.
step_loss_dict = {}
# Maps training step -> (LAS, MLAS, BLEX) of the score line logged after it.
step_las_dict = {}

# Patterns are compiled once, outside the per-line loop.
step_pattern = re.compile(r"STEP (\d+)/\d+")
loss_pattern = re.compile(r"loss = ([\d.]+)")
# A score line is an INFO message consisting of exactly three decimal numbers.
# Anchoring to the end of the line keeps other INFO lines that merely contain
# three numbers from being recorded as scores; the "LAS MLAS BLEX" header line
# has no numbers and therefore never matches.
score_pattern = re.compile(r"INFO:\s+(\d+\.\d+)\s+(\d+\.\d+)\s+(\d+\.\d+)\s*$")

# Highest step seen so far; scores are attributed to this step.
latest_step = 0
# Decode explicitly so the result does not depend on the platform's default
# encoding; undecodable bytes are replaced rather than aborting the parse.
with open(LOGS / latest_log, "r", encoding="utf-8", errors="replace") as f:
    for line in f:
        # Nothing after the end-of-training marker is considered.
        if "Training ended" in line:
            break

        if "Finished STEP" in line and "loss =" in line:
            step_match = step_pattern.search(line)
            loss_match = loss_pattern.search(line)
            if step_match and loss_match:
                step = int(step_match.group(1))
                step_loss_dict[step] = float(loss_match.group(1))
                latest_step = max(latest_step, step)
            continue

        score_match = score_pattern.search(line)
        if score_match:
            las, mlas, blex = (float(value) for value in score_match.groups())
            step_las_dict[latest_step] = (las, mlas, blex)


# Order the loss curve by training step (dict keys are unique, so sorting the
# (step, loss) pairs orders them by step alone).
loss_points = sorted(step_loss_dict.items())
steps = [point[0] for point in loss_points]
losses = [point[1] for point in loss_points]

# Same ordering for the scores; only the first score of each tuple (LAS) is kept.
score_points = sorted(step_las_dict.items())
las_steps = [point[0] for point in score_points]
las_values = [point[1][0] for point in score_points]

# Plot loss and LAS against the training step on a single figure: both curves
# share the x-axis, with loss on the left y-axis and LAS on a twin right y-axis.
fig, ax1 = plt.subplots(figsize=(10, 6))
ax2 = ax1.twinx()

# Left axis: training loss.
ax1.plot(steps, losses, label="Loss", color="tab:blue")
ax1.set_xlabel("Training Step")
ax1.set_ylabel("Loss", color="tab:blue")
ax1.tick_params(axis="y", labelcolor="tab:blue")

# Right axis: LAS, coloured to match its axis label and ticks.
ax2.plot(las_steps, las_values, label="LAS", color="red")
ax2.set_ylabel("LAS", color="red")
ax2.tick_params(axis="y", labelcolor="red")

ax1.set_title("Loss and LAS over Training Steps")
ax1.grid()
fig.tight_layout()

# savefig does not create missing directories, so make sure the output
# directory exists before writing the image.
output_dir = Path("plot")
output_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(output_dir / f"plot_loss_las_{latest_log}.png", dpi=400)