wav2vec2 / plot_results.py
rmachado23's picture
Upload folder using huggingface_hub
8760971 verified
import json
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
# Load the training summary from the current working directory. The code below
# reads two top-level keys from it: "history" (a list with one record per
# epoch) and "best_by_dev_eer" (a single record).
with open("summary.json", encoding="utf-8") as f:
    summary = json.load(f)

history = summary["history"]

# Per-epoch series, in the order the records appear in the history.
epochs = [h["epoch"] for h in history]
train_loss = [h["train_loss"] for h in history]
val_loss = [h["val_loss"] for h in history]
dev_eer = [h["dev_eer"] * 100 for h in history]  # fraction → %
dev_tdcf = [h["dev_tdcf"] for h in history]

# Best dev EER exactly as recorded in the summary, converted to percent so it
# lines up with the dev_eer series.
best_eer_epoch = summary["best_by_dev_eer"]["epoch"]
best_eer_val = summary["best_by_dev_eer"]["dev_eer"] * 100

# Best dev t-DCF. Previously this reused the best-by-EER record, which marks
# the wrong point whenever the lowest t-DCF falls on a different epoch. Pick
# the history record with the smallest t-DCF instead (assumes lower is better,
# as for any cost). With an empty history, fall back to the best-by-EER record
# so the script still runs as it did before.
best_tdcf_record = min(
    history,
    key=lambda h: h["dev_tdcf"],
    default=summary["best_by_dev_eer"],
)
best_tdcf_epoch = best_tdcf_record["epoch"]
best_tdcf_val = best_tdcf_record["dev_tdcf"]

# Reference EER in percent, drawn as a horizontal line on the EER panel.
# NOTE(review): the origin of 8.09 is not stated in this file — confirm.
BASELINE_EER = 8.09
# One figure with a 2x2 grid of panels; `axes` is indexed as axes[row, col].
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 10))
# y=1.01 puts the overall title slightly above the top edge of the figure.
fig.suptitle(
    "Wav2Vec2 Training Results",
    y=1.01,
    fontweight="bold",
    fontsize=14,
)
# — top-left: train and validation loss on a linear y-axis —
ax = axes[0, 0]
# Draw both loss curves against the same epoch axis, train first so the
# legend order stays Train, Validation.
for loss_values, line_color, line_label in (
    (train_loss, "steelblue", "Train"),
    (val_loss, "tomato", "Validation"),
):
    ax.plot(epochs, loss_values, color=line_color, label=line_label)
ax.set(title="Training and Validation Loss", xlabel="Epoch", ylabel="Loss")
ax.grid(True, alpha=0.3)
ax.legend()
# — top-right: the same two loss curves on a logarithmic y-axis —
ax = axes[0, 1]
for loss_values, line_color, line_label in (
    (train_loss, "steelblue", "Train"),
    (val_loss, "tomato", "Validation"),
):
    ax.plot(epochs, loss_values, color=line_color, label=line_label)
# Switch the y-axis to log scale after plotting, then label the panel.
ax.set(
    yscale="log",
    title="Training and Validation Loss (Log Scale)",
    xlabel="Epoch",
    ylabel="Loss (log scale)",
)
ax.grid(True, alpha=0.3)
ax.legend()
# — bottom-left: dev-set EER per epoch, in percent —
ax = axes[1, 0]
ax.plot(epochs, dev_eer, marker="o", markersize=3, color="green")
# Dashed horizontal reference line at the fixed baseline value.
ax.axhline(
    BASELINE_EER,
    linestyle="--",
    color="red",
    label=f"Baseline ({BASELINE_EER}%)",
)
# One large marker at (best_eer_epoch, best_eer_val).
ax.plot(
    best_eer_epoch,
    best_eer_val,
    "o",
    markersize=10,
    color="darkred",
    label=f"Best: {best_eer_val:.4f}%",
)
ax.set(title="Dev Set Equal Error Rate", xlabel="Epoch", ylabel="EER (%)")
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right")
# — bottom-right: dev-set t-DCF per epoch —
ax = axes[1, 1]
ax.plot(epochs, dev_tdcf, marker="o", markersize=3, color="purple")
# One large marker at (best_tdcf_epoch, best_tdcf_val).
ax.plot(
    best_tdcf_epoch,
    best_tdcf_val,
    "o",
    markersize=10,
    color="darkred",
    label=f"Best: {best_tdcf_val:.4f}",
)
ax.set(
    title="Dev Set Tandem Detection Cost Function",
    xlabel="Epoch",
    ylabel="t-DCF",
)
ax.grid(True, alpha=0.3)
ax.legend(loc="upper right")
# Tighten the spacing between panels, then write the figure to the current
# working directory. bbox_inches="tight" makes the saved bounding box fit the
# drawn artists.
fig.tight_layout()
plt.savefig(
    "training_results.png",
    bbox_inches="tight",
    dpi=150,
)
print("Saved training_results.png")