| import json |
| import matplotlib.pyplot as plt |
| import matplotlib.ticker as ticker |
| import numpy as np |
|
|
# Read the training run summary; its "history" list carries one record per epoch.
with open("summary.json") as f:
    summary = json.load(f)

# Unpack each per-epoch metric into its own list, in epoch-record order, so the
# panels below can plot them directly against `epochs`.
history = summary["history"]
epochs = [rec["epoch"] for rec in history]
train_loss = [rec["train_loss"] for rec in history]
val_loss = [rec["val_loss"] for rec in history]
# Scaled by 100 for the "EER (%)" axis; presumably stored as a fraction in the
# JSON — confirm against the training script that writes summary.json.
dev_eer = [100 * rec["dev_eer"] for rec in history]
dev_tdcf = [rec["dev_tdcf"] for rec in history]
|
|
# Best-EER checkpoint comes straight from the summary. EER is scaled by 100 so the
# marker sits on the same percent scale as the plotted `dev_eer` series.
best_eer_epoch = summary["best_by_dev_eer"]["epoch"]
best_eer_val = summary["best_by_dev_eer"]["dev_eer"] * 100

# Best-t-DCF point: the epoch with the lowest dev t-DCF (it is a cost, so lower
# is better). Previously both values were read from summary["best_by_dev_eer"]
# (a copy-paste slip). The t-DCF panel's "Best" marker therefore showed the t-DCF
# at the best-*EER* epoch, which need not be the t-DCF minimum. `min` keeps the
# earliest epoch on ties. With an empty history there is nothing to search, so
# fall back to the summary's best-EER entry, matching the old behaviour.
_best_tdcf_entry = min(
    history,
    key=lambda h: h["dev_tdcf"],
    default=summary["best_by_dev_eer"],
)
best_tdcf_epoch = _best_tdcf_entry["epoch"]
best_tdcf_val = _best_tdcf_entry["dev_tdcf"]
|
|
# Reference EER (in percent) drawn as a dashed line on the EER panel.
# NOTE(review): origin of 8.09 is not shown here — presumably a published
# baseline system; confirm and cite.
BASELINE_EER = 8.09

# 2x2 grid: top row holds the loss curves (linear and log y-axis), bottom row the
# dev-set EER and t-DCF. The title is nudged slightly above the axes area.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(14, 10))
fig.suptitle(
    "Wav2Vec2 Training Results",
    fontsize=14,
    fontweight="bold",
    y=1.01,
)
|
|
| |
# Top-left: train vs. validation loss per epoch on a linear y-axis.
ax = axes[0][0]
ax.plot(epochs, train_loss, label="Train", color="steelblue")
ax.plot(epochs, val_loss, label="Validation", color="tomato")
ax.set(title="Training and Validation Loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
ax.grid(True, alpha=0.3)
|
|
| |
# Top-right: the same two loss curves on a logarithmic y-axis, which separates
# small late-training values more clearly. Non-positive losses cannot be placed
# on a log axis.
ax = axes[0][1]
ax.plot(epochs, train_loss, label="Train", color="steelblue")
ax.plot(epochs, val_loss, label="Validation", color="tomato")
ax.set(
    yscale="log",
    title="Training and Validation Loss (Log Scale)",
    xlabel="Epoch",
    ylabel="Loss (log scale)",
)
ax.legend()
ax.grid(True, alpha=0.3)
|
|
| |
# Bottom-left: dev-set EER per epoch, the baseline as a dashed horizontal line,
# and a large marker on the best-EER checkpoint taken from the summary.
ax = axes[1][0]
ax.plot(epochs, dev_eer, marker="o", markersize=3, color="green")
ax.axhline(
    y=BASELINE_EER,
    linestyle="--",
    color="red",
    label=f"Baseline ({BASELINE_EER}%)",
)
ax.plot(
    best_eer_epoch,
    best_eer_val,
    "o",
    markersize=10,
    color="darkred",
    label=f"Best: {best_eer_val:.4f}%",
)
ax.set(title="Dev Set Equal Error Rate", xlabel="Epoch", ylabel="EER (%)")
ax.legend(loc="upper right")
ax.grid(True, alpha=0.3)
|
|
| |
# Bottom-right: dev-set t-DCF per epoch with a large marker at the point stored in
# best_tdcf_epoch / best_tdcf_val.
ax = axes[1][1]
ax.plot(epochs, dev_tdcf, marker="o", markersize=3, color="purple")
ax.plot(
    best_tdcf_epoch,
    best_tdcf_val,
    "o",
    markersize=10,
    color="darkred",
    label=f"Best: {best_tdcf_val:.4f}",
)
ax.set(
    title="Dev Set Tandem Detection Cost Function",
    xlabel="Epoch",
    ylabel="t-DCF",
)
ax.legend(loc="upper right")
ax.grid(True, alpha=0.3)
|
|
# Pack the four panels, then write the PNG. bbox_inches="tight" keeps the raised
# suptitle inside the saved image.
fig.tight_layout()
plt.savefig("training_results.png", bbox_inches="tight", dpi=150)
print("Saved training_results.png")
|
|