"""Plot Wav2Vec2 training curves (loss, dev EER, dev t-DCF) from summary.json.

Reads ``summary.json`` from the current directory and saves a 2x2 figure to
``training_results.png``.
"""

import json

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

SUMMARY_PATH = "summary.json"
OUTPUT_PATH = "training_results.png"
# Reference EER (in %) drawn as a dashed horizontal line on the EER panel.
BASELINE_EER = 8.09


def _best_point(epochs, values):
    """Return ``(epoch, value)`` at the minimum of *values*, ignoring NaNs.

    Args:
        epochs: epoch labels, parallel to *values*.
        values: metric values where lower is better.

    Raises:
        ValueError: if *values* is empty or contains only NaNs.
    """
    if len(values) == 0:
        raise ValueError("cannot pick a best point from an empty series")
    idx = int(np.nanargmin(values))
    return epochs[idx], values[idx]


def _finish_axis(ax, title, ylabel, legend_loc="best"):
    """Apply the title, axis labels, legend and grid shared by every panel."""
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    # Epoch numbers are assumed to be whole numbers; keep tick labels integral.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax.legend(loc=legend_loc)
    ax.grid(True, alpha=0.3)


def _plot_loss(ax, epochs, train_loss, val_loss, *, log_scale=False):
    """Draw train/validation loss curves, optionally on a log-scaled y-axis."""
    ax.plot(epochs, train_loss, color="steelblue", label="Train")
    ax.plot(epochs, val_loss, color="tomato", label="Validation")
    if log_scale:
        ax.set_yscale("log")
        _finish_axis(
            ax, "Training and Validation Loss (Log Scale)", "Loss (log scale)"
        )
    else:
        _finish_axis(ax, "Training and Validation Loss", "Loss")


def main():
    """Load the training summary, draw the four panels and save the figure.

    Raises:
        ValueError: if the summary's ``history`` list is empty.
    """
    with open(SUMMARY_PATH, encoding="utf-8") as f:
        summary = json.load(f)

    history = summary["history"]
    if not history:
        raise ValueError(f"{SUMMARY_PATH} contains an empty 'history'")

    epochs = [h["epoch"] for h in history]
    train_loss = [h["train_loss"] for h in history]
    val_loss = [h["val_loss"] for h in history]
    dev_eer = [h["dev_eer"] * 100 for h in history]  # fraction -> %
    dev_tdcf = [h["dev_tdcf"] for h in history]

    best_eer = summary["best_by_dev_eer"]
    best_eer_epoch = best_eer["epoch"]
    best_eer_val = best_eer["dev_eer"] * 100  # fraction -> %
    # The best-EER epoch is not necessarily the best-t-DCF epoch, so locate
    # the t-DCF minimum in the plotted series itself; the marker then always
    # lies on the curve it labels as "Best".
    best_tdcf_epoch, best_tdcf_val = _best_point(epochs, dev_tdcf)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle(
        "Wav2Vec2 Training Results", fontsize=14, fontweight="bold", y=1.01
    )

    # Top row: loss on linear and log scales.
    _plot_loss(axes[0, 0], epochs, train_loss, val_loss)
    _plot_loss(axes[0, 1], epochs, train_loss, val_loss, log_scale=True)

    # Bottom-left: dev EER against the fixed baseline.
    ax = axes[1, 0]
    ax.plot(epochs, dev_eer, color="green", marker="o", markersize=3)
    ax.axhline(
        BASELINE_EER,
        color="red",
        linestyle="--",
        label=f"Baseline ({BASELINE_EER}%)",
    )
    ax.plot(
        best_eer_epoch,
        best_eer_val,
        "o",
        color="darkred",
        markersize=10,
        label=f"Best: {best_eer_val:.4f}%",
    )
    _finish_axis(ax, "Dev Set Equal Error Rate", "EER (%)", legend_loc="upper right")

    # Bottom-right: dev t-DCF.
    ax = axes[1, 1]
    ax.plot(epochs, dev_tdcf, color="purple", marker="o", markersize=3)
    ax.plot(
        best_tdcf_epoch,
        best_tdcf_val,
        "o",
        color="darkred",
        markersize=10,
        label=f"Best: {best_tdcf_val:.4f}",
    )
    _finish_axis(
        ax,
        "Dev Set Tandem Detection Cost Function",
        "t-DCF",
        legend_loc="upper right",
    )

    fig.tight_layout()
    fig.savefig(OUTPUT_PATH, dpi=150, bbox_inches="tight")
    plt.close(fig)  # release the figure's memory once it is on disk
    print(f"Saved {OUTPUT_PATH}")


if __name__ == "__main__":
    main()