# milo-asr / generate_plots.py
# (Hugging Face page residue preserved as comments: pluttodk's picture / merge2 / d38720b)
#!/usr/bin/env python
"""
Generate comparison plots for ASR model benchmarks.
Creates publication-quality visualizations comparing hvisketiske-v2
against other Danish ASR models on accuracy and performance metrics.
Usage:
python huggingface/generate_plots.py
# Specify custom result files:
python huggingface/generate_plots.py \
--coral-results ./results/full_comparison2.json \
--cv-results ./results/common_voice_comparison.json
Output:
huggingface/plots/
├── wer_comparison.png
├── cer_comparison.png
├── rtf_comparison.png
└── accuracy_vs_speed.png
"""
import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
# Use a clean style for all figures (whitegrid variant of the seaborn theme).
plt.style.use("seaborn-v0_8-whitegrid")

# Color palette - distinct colors for models.
# Keys are looked up by get_model_color() via substring rules on the
# lowercased model name; values are hex color strings passed to matplotlib.
COLORS = {
    "hvisketiske": "#2ecc71",  # Green for our model (best)
    "qwen3-base": "#27ae60",  # Darker green for base Qwen
    "hviske-v2": "#3498db",  # Blue for hviske-v2
    "hviske-v3": "#2980b9",  # Darker blue for hviske-v3
    "faster": "#e74c3c",  # Red for faster-whisper models
    "turbo": "#e67e22",  # Orange for turbo
    "default": "#95a5a6",  # Gray for others
}

# Model display names mapping: raw benchmark-result model name ->
# multi-line label used on plot axes/legends (consumed by get_display_name()).
MODEL_DISPLAY_NAMES = {
    "Qwen3-ASR (checkpoint-23448)": "hvisketiske-v2\n(Qwen3-ASR finetuned)",
    "hviske-v3-conversation (Whisper Large v3)": "hviske-v3\n(Whisper v3)",
    "hviske-v2 (Whisper Large v2)": "hviske-v2\n(Whisper v2)",
    "faster-hviske-v2 (CT2 distilled)": "faster-hviske-v2\n(CT2 distilled)",
    "Whisper Large v3 Turbo": "Whisper v3 Turbo\n(faster-whisper)",
    "Qwen3-ASR-1.7B (base)": "Qwen3-ASR-1.7B\n(base, not finetuned)",
}
def get_model_color(model_name: str) -> str:
    """Return the plot color for a model, matched by ordered substring rules.

    Rules are evaluated top-down, so more specific patterns (e.g. the
    finetuned checkpoint, or base Qwen) win over generic ones like "v2"/"v3".
    Falls back to the neutral "default" color when nothing matches.
    """
    lowered = model_name.lower()
    # (predicate, COLORS key) pairs, in priority order.
    rules = (
        # Our finetuned model (highest priority).
        (lambda s: "hvisketiske" in s or "checkpoint" in s, "hvisketiske"),
        # Base Qwen3-ASR (not finetuned).
        (lambda s: "qwen3-asr-1.7b" in s and "base" in s, "qwen3-base"),
        # Any other Qwen variant shares the finetuned model's color.
        (lambda s: "qwen" in s, "hvisketiske"),
        # Turbo model.
        (lambda s: "turbo" in s, "turbo"),
        # Faster-whisper / CTranslate2 models.
        (lambda s: "faster" in s or "ct2" in s, "faster"),
        # hviske-v3 family.
        (lambda s: "hviske-v3" in s or "v3" in s, "hviske-v3"),
        # hviske-v2 family.
        (lambda s: "hviske-v2" in s or "v2" in s, "hviske-v2"),
    )
    for matches, color_key in rules:
        if matches(lowered):
            return COLORS[color_key]
    return COLORS["default"]
def get_display_name(model_name: str) -> str:
    """Return the human-friendly plot label for *model_name*.

    Unknown models fall back to their raw name unchanged.
    """
    try:
        return MODEL_DISPLAY_NAMES[model_name]
    except KeyError:
        return model_name
def load_results(path: Path) -> Optional[dict]:
    """Load benchmark results from a JSON file.

    Returns the parsed dict, or None (after printing a warning) when the
    file does not exist.
    """
    if path.exists():
        with path.open("r", encoding="utf-8") as handle:
            return json.load(handle)
    print(f"Warning: Results file not found: {path}")
    return None
def extract_metrics(results: dict) -> Tuple[List[str], List[float], List[float], List[float], List[str]]:
    """Flatten a results dict into parallel per-model metric lists.

    Returns:
        Tuple of (names, wer_values, cer_values, rtf_values, colors), where
        WER/CER are converted from fractions to percentages and colors come
        from get_model_color().
    """
    names: List[str] = []
    wers: List[float] = []
    cers: List[float] = []
    rtfs: List[float] = []
    bar_colors: List[str] = []
    for raw_name, entry in results["models"].items():
        accuracy = entry["accuracy"]
        names.append(get_display_name(raw_name))
        wers.append(accuracy["wer"] * 100)  # fraction -> percentage
        cers.append(accuracy["cer"] * 100)
        rtfs.append(entry["performance"]["real_time_factor"])
        bar_colors.append(get_model_color(raw_name))
    return names, wers, cers, rtfs, bar_colors
def _plot_metric_bars(
    names: List[str],
    values: List[float],
    colors: List[str],
    output_path: Path,
    *,
    ylabel: str,
    title: str,
    value_fmt: str,
    ymax: float,
    reference_line: Optional[Tuple[float, str]] = None,
) -> None:
    """Render one labeled bar chart and save it as a PNG.

    Shared backend for the WER/CER/RTF comparison plots (the three public
    functions below were previously triplicated copy-paste of this code).

    Args:
        names: Bar labels (model display names).
        values: Bar heights, in the same order as *names*.
        colors: Bar colors, in the same order as *names*.
        output_path: Destination PNG path.
        ylabel: Y-axis label.
        title: Figure title.
        value_fmt: str.format() template for the per-bar value label,
            e.g. "{:.1f}%" or "{:.3f}".
        ymax: Upper Y-axis limit.
        reference_line: Optional (y, legend label) horizontal reference line;
            when given, a legend is drawn as well.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(names, values, color=colors, edgecolor="white", linewidth=1.5)
    # Add value labels on bars.
    for bar, val in zip(bars, values):
        height = bar.get_height()
        ax.annotate(
            value_fmt.format(val),
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )
    if reference_line is not None:
        ref_y, ref_label = reference_line
        ax.axhline(y=ref_y, color="red", linestyle="--", linewidth=1.5, label=ref_label)
        ax.legend(loc="upper right")
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_ylim(0, ymax)
    # Add grid behind the bars.
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")


def plot_wer_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate WER comparison bar chart."""
    names, wer_values, _, _, colors = extract_metrics(results)
    _plot_metric_bars(
        names,
        wer_values,
        colors,
        output_path,
        ylabel="Word Error Rate (%)",
        title=f"WER Comparison on {dataset_name}",
        value_fmt="{:.1f}%",
        ymax=max(wer_values) * 1.2,
    )


def plot_cer_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate CER comparison bar chart."""
    names, _, cer_values, _, colors = extract_metrics(results)
    _plot_metric_bars(
        names,
        cer_values,
        colors,
        output_path,
        ylabel="Character Error Rate (%)",
        title=f"CER Comparison on {dataset_name}",
        value_fmt="{:.1f}%",
        ymax=max(cer_values) * 1.2,
    )


def plot_rtf_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate RTF/speed comparison bar chart with a real-time reference line."""
    names, _, _, rtf_values, colors = extract_metrics(results)
    _plot_metric_bars(
        names,
        rtf_values,
        colors,
        output_path,
        ylabel="Real-Time Factor (lower is faster)",
        title=f"Speed Comparison on {dataset_name}",
        value_fmt="{:.3f}",
        # Always show at least up to 1.1 so the RTF=1.0 line is visible.
        ymax=max(max(rtf_values) * 1.3, 1.1),
        reference_line=(1.0, "Real-time (RTF=1.0)"),
    )
def plot_accuracy_vs_speed(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate a WER-vs-RTF scatter plot (bubble size ~ model parameters)."""
    fig, ax = plt.subplots(figsize=(9, 6))
    for raw_name, entry in results["models"].items():
        wer_pct = entry["accuracy"]["wer"] * 100
        rtf = entry["performance"]["real_time_factor"]
        point_color = get_model_color(raw_name)
        label = get_display_name(raw_name).replace("\n", " ")
        # Rough bubble size derived from the parameter-count string
        # (assumes "model_size" mentions e.g. "1.7B" — TODO confirm schema).
        size_str = entry["model_size"]
        if "1.7B" in size_str:
            bubble = 400
        elif "2B" in size_str:
            bubble = 500
        else:
            bubble = 300
        ax.scatter(
            rtf,
            wer_pct,
            s=bubble,
            c=point_color,
            alpha=0.7,
            edgecolors="white",
            linewidth=2,
            label=label,
        )
        # Text label next to each bubble.
        ax.annotate(
            label,
            xy=(rtf, wer_pct),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=10,
            ha="left",
        )
    # Vertical reference line at RTF=1.0 (real-time boundary).
    ax.axvline(x=1.0, color="red", linestyle="--", linewidth=1, alpha=0.5, label="Real-time")
    ax.set_xlabel("Real-Time Factor (lower is faster)", fontsize=12)
    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title(
        f"Accuracy vs Speed Trade-off on {dataset_name}\n(bubble size = model parameters)",
        fontsize=14,
        fontweight="bold",
    )
    # Axis limits with padding around the observed metric ranges.
    wer_list = [m["accuracy"]["wer"] * 100 for m in results["models"].values()]
    rtf_list = [m["performance"]["real_time_factor"] for m in results["models"].values()]
    ax.set_xlim(0, max(rtf_list) * 1.5)
    ax.set_ylim(min(wer_list) * 0.8, max(wer_list) * 1.2)
    ax.grid(True, linestyle="--", alpha=0.7)
    # Annotate the "good" corner (low RTF, low WER).
    ax.annotate(
        "Better",
        xy=(0.02, min(wer_list) * 0.85),
        fontsize=10,
        color="green",
        fontweight="bold",
    )
    ax.annotate(
        "Faster & More Accurate",
        xy=(0.02, min(wer_list) * 0.9),
        fontsize=8,
        color="gray",
    )
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
def plot_multi_dataset_comparison(
    coral_results: dict,
    cv_results: Optional[dict],
    output_path: Path,
) -> None:
    """Generate a grouped-bar WER comparison across datasets.

    Fix: the bar width was hard-coded to 0.35, which only lays out correctly
    for two models — with more models (the bar offset grows with the model
    index) the groups overflowed into neighbouring dataset slots. The width
    is now derived from the model count so each group spans at most 0.8 of
    its slot.

    Args:
        coral_results: CoRal benchmark results (required; defines the model set).
        cv_results: Optional Common Voice results; models missing from it get
            a zero-height (unlabeled) bar for that dataset.
        output_path: Destination PNG path.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # Prepare the dataset axis.
    datasets = ["CoRal v2"]
    if cv_results:
        datasets.append("Common Voice")
    # Model set comes from the CoRal results.
    model_names = list(coral_results["models"].keys())
    x = np.arange(len(datasets))
    # Scale bar width so the whole group of bars fits inside one dataset slot.
    width = 0.8 / max(len(model_names), 1)
    for i, model_name in enumerate(model_names):
        display_name = get_display_name(model_name)
        color = get_model_color(model_name)
        wer_values = [coral_results["models"][model_name]["accuracy"]["wer"] * 100]
        if cv_results and model_name in cv_results["models"]:
            wer_values.append(cv_results["models"][model_name]["accuracy"]["wer"] * 100)
        elif cv_results:
            wer_values.append(0)  # Model not evaluated on this dataset
        # Center the group of bars around each dataset tick.
        offset = (i - len(model_names) / 2 + 0.5) * width
        bars = ax.bar(
            x + offset,
            wer_values,
            width,
            label=display_name.replace("\n", " "),
            color=color,
            edgecolor="white",
            linewidth=1.5,
        )
        # Add value labels (skip the zero placeholders).
        for bar, val in zip(bars, wer_values):
            if val > 0:
                height = bar.get_height()
                ax.annotate(
                    f"{val:.1f}%",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )
    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title("WER Comparison Across Datasets", fontsize=14, fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels(datasets, fontsize=11)
    ax.legend(loc="upper right")
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for this script."""
    cli = argparse.ArgumentParser(description="Generate ASR comparison plots")
    cli.add_argument(
        "--coral-results",
        type=Path,
        default=Path("results/full_comparison2.json"),
        help="Path to CoRal benchmark results",
    )
    cli.add_argument(
        "--cv-results",
        type=Path,
        default=Path("results/common_voice_comparison.json"),
        help="Path to Common Voice benchmark results",
    )
    # Plots land next to this script by default.
    cli.add_argument(
        "--output-dir",
        type=Path,
        default=Path(__file__).parent / "plots",
        help="Output directory for plots",
    )
    return cli.parse_args()
def main() -> None:
    """Entry point: load benchmark results and render every comparison plot."""
    args = parse_args()
    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load results; CoRal is mandatory, Common Voice is optional.
    coral_results = load_results(args.coral_results)
    cv_results = load_results(args.cv_results)
    if coral_results is None:
        print("Error: CoRal results file is required")
        return

    banner = "=" * 60
    print(banner)
    print("Generating ASR Comparison Plots")
    print(banner)
    print(f"Output directory: {out_dir}")
    print()

    # CoRal plots.
    print("Generating CoRal v2 plots...")
    plot_wer_comparison(coral_results, out_dir / "wer_comparison.png", "CoRal v2")
    plot_cer_comparison(coral_results, out_dir / "cer_comparison.png", "CoRal v2")
    plot_rtf_comparison(coral_results, out_dir / "rtf_comparison.png", "CoRal v2")
    plot_accuracy_vs_speed(coral_results, out_dir / "accuracy_vs_speed.png", "CoRal v2")

    # Common Voice plots, only when those results were found.
    if cv_results:
        print("\nGenerating Common Voice plots...")
        plot_wer_comparison(
            cv_results, out_dir / "wer_comparison_cv.png", "Common Voice Danish"
        )
        plot_cer_comparison(
            cv_results, out_dir / "cer_comparison_cv.png", "Common Voice Danish"
        )
        plot_rtf_comparison(
            cv_results, out_dir / "rtf_comparison_cv.png", "Common Voice Danish"
        )

    # Cross-dataset grouped comparison.
    print("\nGenerating multi-dataset comparison...")
    plot_multi_dataset_comparison(
        coral_results, cv_results, out_dir / "multi_dataset_wer.png"
    )

    print("\n" + banner)
    print("Plot generation complete!")
    print(banner)


if __name__ == "__main__":
    main()