"""
Generate comparison plots for ASR model benchmarks.

Creates publication-quality visualizations comparing hvisketiske-v2
against other Danish ASR models on accuracy and performance metrics.

Usage:
    python huggingface/generate_plots.py

    # Specify custom result files:
    python huggingface/generate_plots.py \
        --coral-results ./results/full_comparison2.json \
        --cv-results ./results/common_voice_comparison.json

Output:
    huggingface/plots/
    ├── wer_comparison.png
    ├── cer_comparison.png
    ├── rtf_comparison.png
    └── accuracy_vs_speed.png
"""
|
|
| import argparse |
| import json |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
|
|
| |
# Apply a white-grid theme globally so every figure below shares the same look.
plt.style.use("seaborn-v0_8-whitegrid")
|
|
| |
# Brand palette keyed by model family; "default" is the fallback shade.
COLORS = {
    "hvisketiske": "#2ecc71",
    "qwen3-base": "#27ae60",
    "hviske-v2": "#3498db",
    "hviske-v3": "#2980b9",
    "faster": "#e74c3c",
    "turbo": "#e67e22",
    "default": "#95a5a6",
}


# Raw benchmark identifiers -> two-line labels used on plot axes and legends.
MODEL_DISPLAY_NAMES = {
    "Qwen3-ASR (checkpoint-23448)": "hvisketiske-v2\n(Qwen3-ASR finetuned)",
    "hviske-v3-conversation (Whisper Large v3)": "hviske-v3\n(Whisper v3)",
    "hviske-v2 (Whisper Large v2)": "hviske-v2\n(Whisper v2)",
    "faster-hviske-v2 (CT2 distilled)": "faster-hviske-v2\n(CT2 distilled)",
    "Whisper Large v3 Turbo": "Whisper v3 Turbo\n(faster-whisper)",
    "Qwen3-ASR-1.7B (base)": "Qwen3-ASR-1.7B\n(base, not finetuned)",
}


def get_model_color(model_name: str) -> str:
    """Resolve the plot color for a model via ordered substring rules.

    The first matching rule wins, so the most specific tokens (the finetuned
    checkpoint, the base Qwen model, "turbo") are tested before the generic
    ones ("qwen", "v3", "v2").
    """
    name = model_name.lower()
    ordered_rules = (
        (("hvisketiske" in name) or ("checkpoint" in name), "hvisketiske"),
        (("qwen3-asr-1.7b" in name) and ("base" in name), "qwen3-base"),
        ("qwen" in name, "hvisketiske"),
        ("turbo" in name, "turbo"),
        (("faster" in name) or ("ct2" in name), "faster"),
        (("hviske-v3" in name) or ("v3" in name), "hviske-v3"),
        (("hviske-v2" in name) or ("v2" in name), "hviske-v2"),
    )
    for matched, palette_key in ordered_rules:
        if matched:
            return COLORS[palette_key]
    return COLORS["default"]


def get_display_name(model_name: str) -> str:
    """Return the human-friendly label for *model_name*, or the raw name if unmapped."""
    return MODEL_DISPLAY_NAMES.get(model_name, model_name)
|
|
|
|
def load_results(path: Path) -> Optional[dict]:
    """Read benchmark results JSON from *path*; warn and return None if missing."""
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    print(f"Warning: Results file not found: {path}")
    return None
|
|
|
|
def extract_metrics(results: dict) -> Tuple[List[str], List[float], List[float], List[float], List[str]]:
    """
    Flatten per-model metrics from a results dictionary into parallel lists.

    Returns:
        Tuple of (names, wer_values, cer_values, rtf_values, colors), where
        WER/CER are converted from fractions to percentages.
    """
    rows = [
        (
            get_display_name(model_name),
            data["accuracy"]["wer"] * 100,
            data["accuracy"]["cer"] * 100,
            data["performance"]["real_time_factor"],
            get_model_color(model_name),
        )
        for model_name, data in results["models"].items()
    ]
    if not rows:
        return [], [], [], [], []
    # Transpose the per-model rows into five parallel columns.
    names, wer_values, cer_values, rtf_values, colors = (list(col) for col in zip(*rows))
    return names, wer_values, cer_values, rtf_values, colors
|
|
|
|
def plot_wer_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Save a bar chart of per-model word error rate (%) to *output_path*."""
    names, wer_values, _, _, bar_colors = extract_metrics(results)

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(names, wer_values, color=bar_colors, edgecolor="white", linewidth=1.5)

    # Print the exact percentage just above each bar.
    for rect, pct in zip(bars, wer_values):
        ax.annotate(
            f"{pct:.1f}%",
            xy=(rect.get_x() + rect.get_width() / 2, rect.get_height()),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title(f"WER Comparison on {dataset_name}", fontsize=14, fontweight="bold")
    # Leave 20% headroom above the tallest bar for the labels.
    ax.set_ylim(0, max(wer_values) * 1.2)

    # Horizontal grid only, drawn beneath the bars.
    ax.set_axisbelow(True)
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
|
|
|
def plot_cer_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Save a bar chart of per-model character error rate (%) to *output_path*."""
    names, _, cer_values, _, bar_colors = extract_metrics(results)

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(names, cer_values, color=bar_colors, edgecolor="white", linewidth=1.5)

    # Print the exact percentage just above each bar.
    for rect, pct in zip(bars, cer_values):
        ax.annotate(
            f"{pct:.1f}%",
            xy=(rect.get_x() + rect.get_width() / 2, rect.get_height()),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Character Error Rate (%)", fontsize=12)
    ax.set_title(f"CER Comparison on {dataset_name}", fontsize=14, fontweight="bold")
    # Leave 20% headroom above the tallest bar for the labels.
    ax.set_ylim(0, max(cer_values) * 1.2)

    # Horizontal grid only, drawn beneath the bars.
    ax.set_axisbelow(True)
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
|
|
|
|
def plot_rtf_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Save a bar chart of per-model real-time factor to *output_path*."""
    names, _, _, rtf_values, bar_colors = extract_metrics(results)

    fig, ax = plt.subplots(figsize=(8, 5))
    bars = ax.bar(names, rtf_values, color=bar_colors, edgecolor="white", linewidth=1.5)

    # Print the exact RTF just above each bar.
    for rect, rtf in zip(bars, rtf_values):
        ax.annotate(
            f"{rtf:.3f}",
            xy=(rect.get_x() + rect.get_width() / 2, rect.get_height()),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    # Reference line at RTF 1.0 — the boundary between faster/slower than real time.
    ax.axhline(y=1.0, color="red", linestyle="--", linewidth=1.5, label="Real-time (RTF=1.0)")

    ax.set_ylabel("Real-Time Factor (lower is faster)", fontsize=12)
    ax.set_title(f"Speed Comparison on {dataset_name}", fontsize=14, fontweight="bold")
    # Always show the RTF=1.0 line even when every model is far below it.
    ax.set_ylim(0, max(max(rtf_values) * 1.3, 1.1))
    ax.legend(loc="upper right")

    # Horizontal grid only, drawn beneath the bars.
    ax.set_axisbelow(True)
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
|
|
|
|
def plot_accuracy_vs_speed(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Save a WER-vs-RTF scatter plot; bubble size reflects model parameter count."""
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name, data in results["models"].items():
        wer_pct = data["accuracy"]["wer"] * 100
        rtf = data["performance"]["real_time_factor"]
        label = get_display_name(model_name).replace("\n", " ")

        # Coarse bubble sizing from the "model_size" string; the more specific
        # "1.7B" token is checked before "2B".
        size_str = data["model_size"]
        bubble = next(
            (area for token, area in (("1.7B", 400), ("2B", 500)) if token in size_str),
            300,
        )

        ax.scatter(
            rtf,
            wer_pct,
            s=bubble,
            c=get_model_color(model_name),
            alpha=0.7,
            edgecolors="white",
            linewidth=2,
            label=label,
        )

        # Name each point next to its bubble.
        ax.annotate(
            label,
            xy=(rtf, wer_pct),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=10,
            ha="left",
        )

    # Vertical marker at real-time transcription speed.
    ax.axvline(x=1.0, color="red", linestyle="--", linewidth=1, alpha=0.5, label="Real-time")

    ax.set_xlabel("Real-Time Factor (lower is faster)", fontsize=12)
    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title(
        f"Accuracy vs Speed Trade-off on {dataset_name}\n(bubble size = model parameters)",
        fontsize=14,
        fontweight="bold",
    )

    # Pad the axes around the full spread of points.
    all_wer = [d["accuracy"]["wer"] * 100 for d in results["models"].values()]
    all_rtf = [d["performance"]["real_time_factor"] for d in results["models"].values()]
    ax.set_xlim(0, max(all_rtf) * 1.5)
    ax.set_ylim(min(all_wer) * 0.8, max(all_wer) * 1.2)

    ax.grid(True, linestyle="--", alpha=0.7)

    # Hint at the desirable corner (low RTF, low WER).
    ax.annotate(
        "Better",
        xy=(0.02, min(all_wer) * 0.85),
        fontsize=10,
        color="green",
        fontweight="bold",
    )
    ax.annotate(
        "Faster & More Accurate",
        xy=(0.02, min(all_wer) * 0.9),
        fontsize=8,
        color="gray",
    )

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
|
|
|
|
def plot_multi_dataset_comparison(
    coral_results: dict,
    cv_results: Optional[dict],
    output_path: Path,
) -> None:
    """Generate a grouped-bar WER comparison across datasets.

    Args:
        coral_results: CoRal benchmark results (required; defines the model set).
        cv_results: Optional Common Voice results; models missing from it get a
            zero-height (unlabeled) bar so group positions stay aligned.
        output_path: Destination PNG path.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    datasets = ["CoRal v2"]
    if cv_results:
        datasets.append("Common Voice")

    model_names = list(coral_results["models"].keys())
    x = np.arange(len(datasets))
    # Size each bar so the whole group of models fits inside one category slot.
    # A fixed 0.35 width made bars from neighbouring datasets overlap as soon
    # as three or more models were plotted (n * 0.35 exceeds the 1.0 spacing).
    width = 0.8 / max(len(model_names), 1)

    for i, model_name in enumerate(model_names):
        display_name = get_display_name(model_name)
        color = get_model_color(model_name)

        wer_values = [coral_results["models"][model_name]["accuracy"]["wer"] * 100]
        if cv_results and model_name in cv_results["models"]:
            wer_values.append(cv_results["models"][model_name]["accuracy"]["wer"] * 100)
        elif cv_results:
            # Keep a placeholder bar so remaining models stay aligned.
            wer_values.append(0)

        # Center the group of model bars on each dataset tick.
        offset = (i - len(model_names) / 2 + 0.5) * width
        bars = ax.bar(
            x + offset,
            wer_values,
            width,
            label=display_name.replace("\n", " "),
            color=color,
            edgecolor="white",
            linewidth=1.5,
        )

        # Label real bars with their WER; skip zero-height placeholders.
        for bar, val in zip(bars, wer_values):
            if val > 0:
                height = bar.get_height()
                ax.annotate(
                    f"{val:.1f}%",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )

    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title("WER Comparison Across Datasets", fontsize=14, fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels(datasets, fontsize=11)
    ax.legend(loc="upper right")
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Optional explicit argument list. Defaults to None, in which case
            argparse reads ``sys.argv[1:]`` as before — passing a list makes the
            parser testable and usable programmatically without touching process
            state.

    Returns:
        Namespace with ``coral_results``, ``cv_results`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(description="Generate ASR comparison plots")

    parser.add_argument(
        "--coral-results",
        type=Path,
        default=Path("results/full_comparison2.json"),
        help="Path to CoRal benchmark results",
    )
    parser.add_argument(
        "--cv-results",
        type=Path,
        default=Path("results/common_voice_comparison.json"),
        help="Path to Common Voice benchmark results",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path(__file__).parent / "plots",
        help="Output directory for plots",
    )

    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Entry point: load benchmark results and render every comparison plot."""
    args = parse_args()

    # Make sure the destination exists before any savefig call.
    args.output_dir.mkdir(parents=True, exist_ok=True)

    coral_results = load_results(args.coral_results)
    cv_results = load_results(args.cv_results)

    if coral_results is None:
        print("Error: CoRal results file is required")
        return

    separator = "=" * 60
    print(separator)
    print("Generating ASR Comparison Plots")
    print(separator)
    print(f"Output directory: {args.output_dir}")
    print()

    # CoRal plots are always produced.
    print("Generating CoRal v2 plots...")
    for plot_fn, filename in (
        (plot_wer_comparison, "wer_comparison.png"),
        (plot_cer_comparison, "cer_comparison.png"),
        (plot_rtf_comparison, "rtf_comparison.png"),
        (plot_accuracy_vs_speed, "accuracy_vs_speed.png"),
    ):
        plot_fn(coral_results, args.output_dir / filename, "CoRal v2")

    # Common Voice plots only when that results file was found.
    if cv_results:
        print("\nGenerating Common Voice plots...")
        for plot_fn, filename in (
            (plot_wer_comparison, "wer_comparison_cv.png"),
            (plot_cer_comparison, "cer_comparison_cv.png"),
            (plot_rtf_comparison, "rtf_comparison_cv.png"),
        ):
            plot_fn(cv_results, args.output_dir / filename, "Common Voice Danish")

    print("\nGenerating multi-dataset comparison...")
    plot_multi_dataset_comparison(
        coral_results, cv_results, args.output_dir / "multi_dataset_wer.png"
    )

    print("\n" + separator)
    print("Plot generation complete!")
    print(separator)


if __name__ == "__main__":
    main()
|
|