pluttodk committed on
Commit
d38720b
·
1 Parent(s): 0b6085c
README.md CHANGED
@@ -1,3 +1,434 @@
- ---
- license: openrail
- ---
---
license: openrail
language:
- da
base_model: Qwen/Qwen3-ASR-1.7B
tags:
- automatic-speech-recognition
- danish
- qwen
- asr
- speech-to-text
- coral
- streaming
datasets:
- alexandrainst/coral
- mozilla-foundation/common_voice_17_0
library_name: transformers
pipeline_tag: automatic-speech-recognition
metrics:
- wer
- cer
model-index:
- name: hvisketiske-v2
  results:
  - task:
      type: automatic-speech-recognition
      name: Speech Recognition
    dataset:
      type: alexandrainst/coral
      name: CoRal v2 Test
      split: test
    metrics:
    - type: wer
      value: 18.47
      name: WER
    - type: cer
      value: 7.86
      name: CER
---

# hvisketiske-v2: Danish ASR Model

**hvisketiske-v2** is a state-of-the-art Danish automatic speech recognition (ASR) model based on [Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B), finetuned on the [CoRal v2 dataset](https://huggingface.co/datasets/alexandrainst/coral) for improved Danish transcription accuracy.

## Key Highlights

| Feature | Value |
|---------|-------|
| **WER on CoRal v2** | 18.47% (14% lower than Whisper Large v3) |
| **CER on CoRal v2** | 7.86% (11% lower than Whisper Large v3) |
| **Real-Time Factor** | 0.086 (45% faster than Whisper Large v3) |
| **Model Size** | ~1.7B parameters |

### Inherited Features from Qwen3-ASR

- **Streaming/real-time transcription** via the vLLM backend
- **Singing detection** - can transcribe singing voice and songs with background music
- **Word-level timestamps** via forced alignment
- **30+ language support** (optimized for Danish)
- **Long audio support** - up to 20 minutes per request

---

## Performance Comparison

### CoRal v2 Test Set (9,123 samples, 17.3 hours)

| Model | WER | CER | RTF | Throughput | Parameters |
|-------|-----|-----|-----|------------|------------|
| **hvisketiske-v2** | **18.47%** | **7.86%** | **0.086** | 1.71 samples/s | ~1.7B |
| hviske-v3 (Whisper Large v3) | 21.47% | 8.79% | 0.156 | 0.94 samples/s | ~2B |

**Improvements over Whisper Large v3:**
- **14% relative reduction** in Word Error Rate
- **11% relative reduction** in Character Error Rate
- **45% faster** inference
- **15% fewer** parameters

### Comparison Plots

![WER Comparison](plots/wer_comparison.png)
![Speed Comparison](plots/rtf_comparison.png)
![Accuracy vs Speed](plots/accuracy_vs_speed.png)

---

## Quick Start

### Installation

```bash
pip install qwen-asr transformers torch
```

### Basic Usage

```python
from qwen_asr import Qwen3ASRModel

# Load the model
model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    dtype="bfloat16",
    device_map="cuda:0",
)

# Transcribe an audio file
results = model.transcribe(
    audio="path/to/danish_audio.wav",
    language="Danish",
)

print(results[0].text)
```

---

## Advanced Usage

### Batch Transcription (Fast Processing)

Process multiple audio files efficiently in a single call:

```python
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    dtype="bfloat16",
    device_map="cuda:0",
    max_inference_batch_size=16,  # Process up to 16 files at once
)

# Batch transcribe multiple files
audio_files = ["audio1.wav", "audio2.wav", "audio3.wav"]
results = model.transcribe(
    audio=audio_files,
    language="Danish",
)

for i, result in enumerate(results):
    print(f"File {i+1}: {result.text}")
```

### Transcription with Timestamps

Get word-level timestamps using the forced aligner:

```python
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
    dtype="bfloat16",
    device_map="cuda:0",
)

results = model.transcribe(
    audio="path/to/audio.wav",
    language="Danish",
    return_time_stamps=True,
)

# Access word-level timestamps
for item in results[0].time_stamps.items:
    print(f"{item.start_time:.2f}s - {item.end_time:.2f}s: {item.text}")
```

### Streaming/Real-time Transcription (vLLM Backend)

For real-time streaming transcription, use the vLLM backend:

```python
from qwen_asr import Qwen3ASRModel
import numpy as np

# Initialize with the vLLM backend for streaming
model = Qwen3ASRModel.LLM(
    model="pluttodk/hvisketiske-v2",
    gpu_memory_utilization=0.8,
)

# Initialize streaming state
state = model.init_streaming_state(
    language="Danish",
    chunk_size_sec=2.0,  # Process audio in 2-second chunks
)

# Simulate streaming audio (16 kHz mono float32)
def audio_stream():
    """Replace with an actual audio stream, e.g. from a microphone."""
    for chunk in audio_chunks:  # `audio_chunks` is a placeholder
        yield np.array(chunk, dtype=np.float32)

# Process streaming audio
for audio_chunk in audio_stream():
    state = model.streaming_transcribe(audio_chunk, state)
    print(f"Current transcription: {state.text}")

# Finalize the stream
state = model.finish_streaming_transcribe(state)
print(f"Final transcription: {state.text}")
```

### Using with Transformers Directly

For more control, use the model directly with transformers:

```python
from transformers import AutoModel, AutoProcessor
import torch
import librosa

# Load model and processor
model = AutoModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
processor = AutoProcessor.from_pretrained(
    "pluttodk/hvisketiske-v2",
    trust_remote_code=True,
)

# Load and preprocess audio
audio, sr = librosa.load("path/to/audio.wav", sr=16000, mono=True)

# Build input using the chat template
messages = [
    {"role": "system", "content": ""},
    {"role": "user", "content": [{"type": "audio", "audio": audio}]},
]

text = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
text = text + "language Danish<asr_text>"

# Process and generate
inputs = processor(text=[text], audio=[audio], return_tensors="pt", padding=True)
inputs = inputs.to(model.device).to(model.dtype)

output_ids = model.generate(**inputs, max_new_tokens=512)
transcription = processor.batch_decode(
    output_ids[:, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)[0]

print(transcription)
```

### Singing Detection & Multi-Audio Support

The model inherits Qwen3-ASR's ability to handle singing and background music:

```python
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    dtype="bfloat16",
    device_map="cuda:0",
)

# Transcribe audio with singing or background music
results = model.transcribe(
    audio="path/to/song.wav",
    language="Danish",  # or None for auto-detection
)

print(results[0].text)
```

---

## Model Details

### Model Description

hvisketiske-v2 is a Danish-specialized automatic speech recognition model created by finetuning [Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B) on the [CoRal v2 dataset](https://huggingface.co/datasets/alexandrainst/coral). The model achieves state-of-the-art performance on Danish speech recognition while maintaining fast inference speeds.

- **Developed by:** Mathias Oliver Valdbjørn Rønnelund
- **Model type:** Encoder-decoder speech recognition model
- **Language:** Danish (primary), with inherited multilingual capabilities
- **License:** OpenRAIL
- **Finetuned from:** [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)

### Architecture

The model inherits the Qwen3-ASR architecture:

| Component | Specification |
|-----------|--------------|
| Audio Encoder | 24-layer transformer (1024 hidden dim, 16 attention heads) |
| Text Decoder | 28-layer transformer (2048 hidden dim, 16 attention heads) |
| Total Parameters | ~1.7 billion |
| Precision | bfloat16 |
| Audio Input | 16 kHz mono WAV |

---

## Training Details

### Training Data

The model was finetuned on the [CoRal v2 dataset](https://huggingface.co/datasets/alexandrainst/coral), a comprehensive Danish speech corpus containing:
- Diverse Danish speakers across demographics
- Various recording conditions and audio qualities
- Natural conversational speech
- Read-aloud speech

### Training Procedure

**Training Approach:** Supervised Fine-Tuning (SFT) with chat-template formatting

**Preprocessing:**
- Audio resampled to 16 kHz mono
- Chat template applied with system prompt, audio input, and target transcription
- Prefix masking so the loss is computed only on transcription tokens

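The prefix-masking step can be sketched in a few lines. This is an illustrative simplification, not the actual training code: the names `build_labels`, `prompt_ids`, and `target_ids` are placeholders, and the only assumption taken from the text above is that prompt tokens receive the ignore index so cross-entropy is computed only on transcription tokens.

```python
# Illustrative sketch of prefix masking for SFT (placeholder names, not the
# project's training code). Prompt tokens are labeled IGNORE_INDEX so the
# cross-entropy loss is computed only on the target transcription tokens.
IGNORE_INDEX = -100  # the label value PyTorch's CrossEntropyLoss ignores by default

def build_labels(prompt_ids, target_ids):
    """Concatenate prompt and target token ids; mask the prompt in the labels."""
    input_ids = list(prompt_ids) + list(target_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(target_ids)
    return input_ids, labels

input_ids, labels = build_labels([101, 102, 103], [7, 8, 9])
print(labels)  # → [-100, -100, -100, 7, 8, 9]
```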
**Training Hyperparameters:**

| Parameter | Value |
|-----------|-------|
| Base model | Qwen/Qwen3-ASR-1.7B |
| Learning rate | 2e-5 |
| Batch size (per device) | 8 |
| Gradient accumulation steps | 4 |
| Effective batch size | 32 |
| Epochs | 3 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Max gradient norm | 1.0 |
| Precision | bfloat16 |
| Optimizer | AdamW |
| LR scheduler | Linear decay |
| Total training steps | 23,448 |

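As a sanity check, the effective batch size follows directly from the table (per-device batch × gradient accumulation steps), and, assuming a single device where each optimizer step consumes one full effective batch, the step count implies roughly how many samples were seen per epoch:

```python
# Arithmetic check of the hyperparameter table (single-device assumption).
per_device_batch = 8
grad_accum = 4
effective_batch = per_device_batch * grad_accum  # 32, matching the table

total_steps = 23_448
epochs = 3
steps_per_epoch = total_steps // epochs                # 7,816 optimizer steps per epoch
samples_per_epoch = steps_per_epoch * effective_batch  # ~250k samples seen per epoch
print(effective_batch, steps_per_epoch, samples_per_epoch)
```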
**Hardware:** Training was performed on NVIDIA GPUs (~25 GB of GPU memory per device)

---

## Evaluation

### Test Data

Evaluated on the CoRal v2 test split:
- **9,123 samples**
- **17.3 hours** of audio
- Diverse Danish speakers and recording conditions

### Metrics

| Metric | Description |
|--------|-------------|
| **WER** | Word Error Rate: percentage of words incorrectly transcribed (lower is better) |
| **CER** | Character Error Rate: percentage of characters incorrectly transcribed (lower is better) |
| **RTF** | Real-Time Factor: ratio of processing time to audio duration (< 1.0 = faster than real-time) |

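For readers who want to reproduce these numbers, here is a minimal pure-Python sketch of WER (the evaluation scripts in this repo use `jiwer` instead). CER is the same edit distance computed over characters rather than words, and RTF is simply total inference time divided by total audio duration.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences (one-row DP)."""
    dp = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, dp[0] = dp[0], i
        for j, h in enumerate(hyp, 1):
            prev, dp[j] = dp[j], min(
                dp[j] + 1,        # deletion
                dp[j - 1] + 1,    # insertion
                prev + (r != h),  # substitution (free if tokens match)
            )
    return dp[len(hyp)]

def wer(reference: str, hypothesis: str) -> float:
    """Word Error Rate: word-level edit distance / number of reference words."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)

print(wer("det er en god dag", "det var en god dag"))  # → 0.2 (1 substitution / 5 words)
```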
### Results Summary

| Model | WER | CER | RTF | Throughput |
|-------|-----|-----|-----|------------|
| **hvisketiske-v2** | **18.47%** | **7.86%** | **0.086** | 1.71 samples/s |
| hviske-v3 (Whisper v3) | 21.47% | 8.79% | 0.156 | 0.94 samples/s |

---

## Limitations

- **Language:** Optimized for Danish; other languages may perform worse than with the base Qwen3-ASR
- **Audio quality:** Best results with clear speech; noisy environments may reduce accuracy
- **Domain:** CoRal v2 is primarily conversational and read-aloud speech, so specialized domains (medical, legal, technical) may see higher error rates
- **Streaming:** Real-time streaming requires the vLLM backend

## Intended Use

### Primary Use Cases
- Danish speech-to-text transcription
- Subtitle generation for Danish content
- Voice assistant backends
- Meeting transcription
- Accessibility applications

### Out-of-Scope Use
- Non-Danish languages (use the base Qwen3-ASR instead)
- Real-time speaker diarization (not supported)
- Emotion/sentiment detection from speech

---

## Citation

If you use this model, please cite:

```bibtex
@misc{hvisketiske-v2,
  author    = {Rønnelund, Mathias Oliver Valdbjørn},
  title     = {hvisketiske-v2: Danish ASR Model based on Qwen3-ASR},
  year      = {2025},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/pluttodk/hvisketiske-v2}
}
```

Also consider citing the base model and dataset:

```bibtex
@article{qwen3asr,
  title   = {Qwen3-ASR Technical Report},
  author  = {Qwen Team},
  journal = {arXiv preprint arXiv:2601.21337},
  year    = {2025}
}

@dataset{coral,
  title  = {CoRal: A Danish Speech Corpus},
  author = {Alexandra Institute},
  year   = {2024},
  url    = {https://huggingface.co/datasets/alexandrainst/coral}
}
```

---

## Acknowledgements

- [Qwen Team](https://github.com/QwenLM) for the excellent Qwen3-ASR base model
- [Alexandra Institute](https://alexandra.dk/) for the CoRal v2 Danish speech corpus
__pycache__/evaluate_common_voice.cpython-310.pyc ADDED
Binary file (11.1 kB)

__pycache__/generate_plots.cpython-310.pyc ADDED
Binary file (11.2 kB)
 
evaluate_common_voice.py ADDED
@@ -0,0 +1,404 @@
#!/usr/bin/env python
"""
Benchmark ASR models on the Common Voice Danish dataset.

This script evaluates hvisketiske-v2 (Qwen3-ASR) and hviske-v3 (Whisper)
on the Mozilla Common Voice Danish test set for comparison.

IMPORTANT: Common Voice requires authentication and agreement to its terms of use.
Before running this script:
1. Create a HuggingFace account at https://huggingface.co
2. Visit https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0
3. Agree to the dataset terms of use
4. Create an access token at https://huggingface.co/settings/tokens
5. Log in via CLI: `huggingface-cli login`

Usage:
    # After logging in:
    python huggingface/evaluate_common_voice.py \
        --hvisketiske-path ./outputs/hvisketiske-v2/checkpoint-23448 \
        --max-samples 1000 \
        --output-file ./results/common_voice_comparison.json

    # Quick test with fewer samples:
    python huggingface/evaluate_common_voice.py --max-samples 100

    # Use a specific token:
    python huggingface/evaluate_common_voice.py --hf-token YOUR_TOKEN
"""

import argparse
import json
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import soundfile as sf
from datasets import load_dataset
from jiwer import cer, wer
from tqdm import tqdm

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from hvisketiske.evaluation.model_adapters import (
    ASRModelAdapter,
    HviskeV3Adapter,
    Qwen3ASRAdapter,
    TranscriptionResult,
)
from hvisketiske.evaluation.timing import AggregatedTimingStats


@dataclass
class CommonVoiceSample:
    """A single Common Voice sample."""

    audio_path: str
    reference: str
    audio_duration: float


def load_common_voice_danish(
    split: str = "test",
    max_samples: Optional[int] = None,
    cache_dir: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> List[CommonVoiceSample]:
    """
    Load the Common Voice Danish dataset and prepare samples.

    Args:
        split: Dataset split to load (test, validation, train).
        max_samples: Maximum number of samples to load.
        cache_dir: Directory to cache audio files.
        hf_token: HuggingFace API token for authentication.

    Returns:
        List of CommonVoiceSample objects.
    """
    print(f"Loading Common Voice Danish ({split} split)...")
    print("Note: This requires HuggingFace authentication and agreement to the dataset terms.")
    print("Visit: https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0")
    print()

    try:
        ds = load_dataset(
            "mozilla-foundation/common_voice_17_0",
            "da",
            split=split,
            trust_remote_code=True,
            token=hf_token,
        )
    except Exception as e:
        error_msg = str(e)
        if "EmptyDatasetError" in error_msg or "doesn't contain any data" in error_msg:
            print("\n" + "=" * 70)
            print("ERROR: Cannot access the Common Voice dataset.")
            print("=" * 70)
            print("\nThis dataset requires authentication. Please:")
            print("1. Visit https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0")
            print("2. Log in and agree to the terms of use")
            print("3. Run: huggingface-cli login")
            print("4. Or pass --hf-token YOUR_TOKEN to this script")
            print("=" * 70 + "\n")
        raise

    if max_samples:
        ds = ds.select(range(min(max_samples, len(ds))))

    print(f"Loaded {len(ds)} samples")

    # Create a temp directory for audio files if not provided
    if cache_dir is None:
        cache_dir = tempfile.mkdtemp(prefix="cv_danish_")

    cache_path = Path(cache_dir)
    cache_path.mkdir(parents=True, exist_ok=True)

    samples = []
    print("Preparing audio files...")
    for i, item in enumerate(tqdm(ds, desc="Preparing samples")):
        # Extract audio array and sample rate
        audio_array = item["audio"]["array"]
        sample_rate = item["audio"]["sampling_rate"]

        # Save to a temp file
        audio_path = cache_path / f"sample_{i:06d}.wav"
        sf.write(str(audio_path), audio_array, sample_rate)

        # Calculate duration
        duration = len(audio_array) / sample_rate

        samples.append(
            CommonVoiceSample(
                audio_path=str(audio_path),
                reference=item["sentence"],
                audio_duration=duration,
            )
        )

    return samples


def normalize_text(text: str) -> str:
    """Normalize text for a fair comparison."""
    text = text.lower()
    text = " ".join(text.split())
    return text


def evaluate_model(
    model: ASRModelAdapter,
    samples: List[CommonVoiceSample],
    warmup_samples: int = 3,
) -> dict:
    """
    Evaluate a model on the Common Voice samples.

    Args:
        model: Model adapter to evaluate.
        samples: List of samples to evaluate.
        warmup_samples: Number of warmup iterations.

    Returns:
        Dictionary with evaluation results.
    """
    print(f"\nEvaluating: {model.model_name}")
    print("Loading model...")
    model.load()

    # Warmup
    if warmup_samples > 0 and samples:
        print(f"Running {warmup_samples} warmup iterations...")
        model.warmup(samples[0].audio_path, num_runs=warmup_samples)

    # Transcribe all samples
    predictions = []
    individual_times = []
    total_audio_duration = 0.0
    total_inference_time = 0.0

    print(f"Transcribing {len(samples)} samples...")
    for sample in tqdm(samples, desc=f"Evaluating {model.model_name[:30]}"):
        result = model.transcribe(sample.audio_path)
        predictions.append(result.text)
        individual_times.append(result.inference_time_seconds)
        total_audio_duration += sample.audio_duration
        total_inference_time += result.inference_time_seconds

    # Normalize text
    predictions_norm = [normalize_text(p) for p in predictions]
    references_norm = [normalize_text(s.reference) for s in samples]

    # Calculate metrics
    word_error_rate = wer(references_norm, predictions_norm)
    char_error_rate = cer(references_norm, predictions_norm)

    timing_stats = AggregatedTimingStats(
        total_inference_time_seconds=total_inference_time,
        total_audio_duration_seconds=total_audio_duration,
        num_samples=len(samples),
        individual_times=individual_times,
    )

    return {
        "model_name": model.model_name,
        "model_size": model.model_size_params,
        "accuracy": {
            "wer": word_error_rate,
            "cer": char_error_rate,
        },
        "performance": {
            "total_inference_time_seconds": timing_stats.total_inference_time_seconds,
            "total_audio_duration_seconds": timing_stats.total_audio_duration_seconds,
            "real_time_factor": timing_stats.real_time_factor,
            "throughput_samples_per_second": timing_stats.throughput_samples_per_second,
            "mean_time_per_sample_seconds": timing_stats.mean_time_per_sample,
            "std_time_per_sample_seconds": timing_stats.std_time_per_sample,
        },
        "num_samples": len(samples),
    }


def print_summary(results: dict) -> None:
    """Print a formatted comparison summary."""
    print("\n" + "=" * 80)
    print("COMMON VOICE DANISH - ASR MODEL COMPARISON")
    print("=" * 80)
    print("Dataset: mozilla-foundation/common_voice_17_0 (Danish)")
    print(f"Number of models: {len(results['models'])}")

    sample_count = next(iter(results["models"].values()))["num_samples"]
    print(f"Samples evaluated: {sample_count}")

    # Accuracy comparison table
    print("\n" + "-" * 80)
    print("ACCURACY METRICS (lower is better)")
    print("-" * 80)
    print(f"{'Model':<45} {'WER':>12} {'CER':>12}")
    print("-" * 80)
    for name, result in sorted(
        results["models"].items(), key=lambda x: x[1]["accuracy"]["wer"]
    ):
        print(
            f"{result['model_name'][:45]:<45} "
            f"{result['accuracy']['wer']:>11.2%} "
            f"{result['accuracy']['cer']:>11.2%}"
        )

    # Performance comparison table
    print("\n" + "-" * 80)
    print("PERFORMANCE METRICS (RTF < 1.0 = faster than real-time)")
    print("-" * 80)
    print(f"{'Model':<35} {'RTF':>8} {'Throughput':>12} {'Mean Time':>12}")
    print(f"{'':35} {'':>8} {'(samples/s)':>12} {'(s/sample)':>12}")
    print("-" * 80)
    for name, result in sorted(
        results["models"].items(), key=lambda x: x[1]["performance"]["real_time_factor"]
    ):
        perf = result["performance"]
        print(
            f"{result['model_name'][:35]:<35} "
            f"{perf['real_time_factor']:>8.3f} "
            f"{perf['throughput_samples_per_second']:>12.2f} "
            f"{perf['mean_time_per_sample_seconds']:>12.3f}"
        )

    print("=" * 80)


def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Benchmark ASR models on Common Voice Danish"
    )

    parser.add_argument(
        "--output-file",
        type=Path,
        default=Path("results/common_voice_comparison.json"),
        help="Path to save the comparison report (JSON)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum samples to evaluate (for quick testing)",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=3,
        help="Number of warmup iterations per model (default: 3)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device for inference (default: cuda:0)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory to cache audio files",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="HuggingFace API token for authentication (or use huggingface-cli login)",
    )

    # Model selection
    parser.add_argument(
        "--skip-hviske-v3",
        action="store_true",
        help="Skip the hviske-v3-conversation model",
    )
    parser.add_argument(
        "--skip-hvisketiske",
        action="store_true",
        help="Skip the hvisketiske-v2 model",
    )
    parser.add_argument(
        "--hvisketiske-path",
        type=str,
        default="./outputs/hvisketiske-v2/checkpoint-23448",
        help="Path to a local hvisketiske checkpoint",
    )

    return parser.parse_args()


def main() -> None:
    """Main entry point for the Common Voice evaluation."""
    args = parse_args()

    # Load dataset
    samples = load_common_voice_danish(
        split="test",
        max_samples=args.max_samples,
        cache_dir=args.cache_dir,
        hf_token=args.hf_token,
    )

    # Configure models to evaluate
    models = []

    if not args.skip_hviske_v3:
        models.append(
            HviskeV3Adapter(
                model_id="syvai/hviske-v3-conversation",
                device=args.device,
            )
        )

    if not args.skip_hvisketiske:
        models.append(
            Qwen3ASRAdapter(
                model_path=args.hvisketiske_path,
                device=args.device,
            )
        )

    if not models:
        print("Error: No models selected for evaluation")
        sys.exit(1)

    print("=" * 60)
    print("Common Voice Danish ASR Evaluation")
    print("=" * 60)
    print("Dataset: mozilla-foundation/common_voice_17_0")
    print(f"Samples: {len(samples)}")
    print(f"Device: {args.device}")
    print(f"Warmup iterations: {args.warmup}")
    print(f"Models to evaluate: {len(models)}")
    for m in models:
        print(f"  - {m.model_name} ({m.model_size_params})")
    print("=" * 60)

    # Evaluate all models
    results = {"dataset": "mozilla-foundation/common_voice_17_0", "models": {}}

    for model in models:
        model_results = evaluate_model(model, samples, warmup_samples=args.warmup)
        results["models"][model.model_name] = model_results

    # Print summary
    print_summary(results)

    # Save results
    args.output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {args.output_file}")


if __name__ == "__main__":
    main()
generate_plots.py ADDED
@@ -0,0 +1,478 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Generate comparison plots for ASR model benchmarks.
4
+
5
+ Creates publication-quality visualizations comparing hvisketiske-v2
6
+ against other Danish ASR models on accuracy and performance metrics.
7
+
8
+ Usage:
9
+ python huggingface/generate_plots.py
10
+
11
+ # Specify custom result files:
12
+ python huggingface/generate_plots.py \
13
+ --coral-results ./results/full_comparison2.json \
14
+ --cv-results ./results/common_voice_comparison.json
15
+
16
+ Output:
17
+ huggingface/plots/
18
+ ├── wer_comparison.png
19
+ ├── cer_comparison.png
20
+ ├── rtf_comparison.png
21
+ └── accuracy_vs_speed.png
22
+ """
23
+
24
+ import argparse
25
+ import json
26
+ from pathlib import Path
27
+ from typing import Dict, List, Optional, Tuple
28
+
29
+ import matplotlib.pyplot as plt
30
+ import numpy as np
31
+
32
+ # Use a clean style
33
+ plt.style.use("seaborn-v0_8-whitegrid")
34
+
35
+ # Color palette - distinct colors for models
36
+ COLORS = {
37
+ "hvisketiske": "#2ecc71", # Green for our model (best)
38
+ "qwen3-base": "#27ae60", # Darker green for base Qwen
39
+ "hviske-v2": "#3498db", # Blue for hviske-v2
40
+ "hviske-v3": "#2980b9", # Darker blue for hviske-v3
41
+ "faster": "#e74c3c", # Red for faster-whisper models
42
+ "turbo": "#e67e22", # Orange for turbo
43
+ "default": "#95a5a6", # Gray for others
44
+ }
45
+
46
+ # Model display names mapping
47
+ MODEL_DISPLAY_NAMES = {
48
+ "Qwen3-ASR (checkpoint-23448)": "hvisketiske-v2\n(Qwen3-ASR finetuned)",
49
+ "hviske-v3-conversation (Whisper Large v3)": "hviske-v3\n(Whisper v3)",
50
+ "hviske-v2 (Whisper Large v2)": "hviske-v2\n(Whisper v2)",
51
+ "faster-hviske-v2 (CT2 distilled)": "faster-hviske-v2\n(CT2 distilled)",
52
+ "Whisper Large v3 Turbo": "Whisper v3 Turbo\n(faster-whisper)",
53
+ "Qwen3-ASR-1.7B (base)": "Qwen3-ASR-1.7B\n(base, not finetuned)",
54
+ }
55
+
+
+ def get_model_color(model_name: str) -> str:
+     """Get color for a model based on its name."""
+     name_lower = model_name.lower()
+
+     # Our finetuned model (highest priority)
+     if "hvisketiske" in name_lower or "checkpoint" in name_lower:
+         return COLORS["hvisketiske"]
+     # Base Qwen3-ASR (not finetuned)
+     elif "qwen3-asr-1.7b" in name_lower and "base" in name_lower:
+         return COLORS["qwen3-base"]
+     elif "qwen" in name_lower:
+         return COLORS["hvisketiske"]
+     # Turbo model
+     elif "turbo" in name_lower:
+         return COLORS["turbo"]
+     # Faster-whisper models
+     elif "faster" in name_lower or "ct2" in name_lower:
+         return COLORS["faster"]
+     # hviske-v3
+     elif "hviske-v3" in name_lower or "v3" in name_lower:
+         return COLORS["hviske-v3"]
+     # hviske-v2
+     elif "hviske-v2" in name_lower or "v2" in name_lower:
+         return COLORS["hviske-v2"]
+     return COLORS["default"]
+
+
+ def get_display_name(model_name: str) -> str:
+     """Get display name for a model."""
+     return MODEL_DISPLAY_NAMES.get(model_name, model_name)
+
+
+ def load_results(path: Path) -> Optional[dict]:
+     """Load benchmark results from JSON file."""
+     if not path.exists():
+         print(f"Warning: Results file not found: {path}")
+         return None
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
+ def extract_metrics(results: dict) -> Tuple[List[str], List[float], List[float], List[float], List[str]]:
+     """
+     Extract metrics from results dictionary.
+
+     Returns:
+         Tuple of (names, wer_values, cer_values, rtf_values, colors)
+     """
+     names = []
+     wer_values = []
+     cer_values = []
+     rtf_values = []
+     colors = []
+
+     for model_name, data in results["models"].items():
+         display_name = get_display_name(model_name)
+         names.append(display_name)
+         wer_values.append(data["accuracy"]["wer"] * 100)  # Convert to percentage
+         cer_values.append(data["accuracy"]["cer"] * 100)
+         rtf_values.append(data["performance"]["real_time_factor"])
+         colors.append(get_model_color(model_name))
+
+     return names, wer_values, cer_values, rtf_values, colors
+
+
+ def plot_wer_comparison(
+     results: dict,
+     output_path: Path,
+     dataset_name: str = "CoRal v2",
+ ) -> None:
+     """Generate WER comparison bar chart."""
+     names, wer_values, _, _, colors = extract_metrics(results)
+
+     fig, ax = plt.subplots(figsize=(8, 5))
+
+     bars = ax.bar(names, wer_values, color=colors, edgecolor="white", linewidth=1.5)
+
+     # Add value labels on bars
+     for bar, val in zip(bars, wer_values):
+         height = bar.get_height()
+         ax.annotate(
+             f"{val:.1f}%",
+             xy=(bar.get_x() + bar.get_width() / 2, height),
+             xytext=(0, 5),
+             textcoords="offset points",
+             ha="center",
+             va="bottom",
+             fontsize=12,
+             fontweight="bold",
+         )
+
+     ax.set_ylabel("Word Error Rate (%)", fontsize=12)
+     ax.set_title(f"WER Comparison on {dataset_name}", fontsize=14, fontweight="bold")
+     ax.set_ylim(0, max(wer_values) * 1.2)
+
+     # Add grid
+     ax.yaxis.grid(True, linestyle="--", alpha=0.7)
+     ax.set_axisbelow(True)
+
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
+     plt.close()
+     print(f"Saved: {output_path}")
+
+
+ def plot_cer_comparison(
+     results: dict,
+     output_path: Path,
+     dataset_name: str = "CoRal v2",
+ ) -> None:
+     """Generate CER comparison bar chart."""
+     names, _, cer_values, _, colors = extract_metrics(results)
+
+     fig, ax = plt.subplots(figsize=(8, 5))
+
+     bars = ax.bar(names, cer_values, color=colors, edgecolor="white", linewidth=1.5)
+
+     # Add value labels on bars
+     for bar, val in zip(bars, cer_values):
+         height = bar.get_height()
+         ax.annotate(
+             f"{val:.1f}%",
+             xy=(bar.get_x() + bar.get_width() / 2, height),
+             xytext=(0, 5),
+             textcoords="offset points",
+             ha="center",
+             va="bottom",
+             fontsize=12,
+             fontweight="bold",
+         )
+
+     ax.set_ylabel("Character Error Rate (%)", fontsize=12)
+     ax.set_title(f"CER Comparison on {dataset_name}", fontsize=14, fontweight="bold")
+     ax.set_ylim(0, max(cer_values) * 1.2)
+
+     # Add grid
+     ax.yaxis.grid(True, linestyle="--", alpha=0.7)
+     ax.set_axisbelow(True)
+
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
+     plt.close()
+     print(f"Saved: {output_path}")
+
+
+ def plot_rtf_comparison(
+     results: dict,
+     output_path: Path,
+     dataset_name: str = "CoRal v2",
+ ) -> None:
+     """Generate RTF/speed comparison bar chart."""
+     names, _, _, rtf_values, colors = extract_metrics(results)
+
+     fig, ax = plt.subplots(figsize=(8, 5))
+
+     bars = ax.bar(names, rtf_values, color=colors, edgecolor="white", linewidth=1.5)
+
+     # Add value labels on bars
+     for bar, val in zip(bars, rtf_values):
+         height = bar.get_height()
+         ax.annotate(
+             f"{val:.3f}",
+             xy=(bar.get_x() + bar.get_width() / 2, height),
+             xytext=(0, 5),
+             textcoords="offset points",
+             ha="center",
+             va="bottom",
+             fontsize=12,
+             fontweight="bold",
+         )
+
+     # Add reference line at RTF=1.0 (real-time)
+     ax.axhline(y=1.0, color="red", linestyle="--", linewidth=1.5, label="Real-time (RTF=1.0)")
+
+     ax.set_ylabel("Real-Time Factor (lower is faster)", fontsize=12)
+     ax.set_title(f"Speed Comparison on {dataset_name}", fontsize=14, fontweight="bold")
+     ax.set_ylim(0, max(max(rtf_values) * 1.3, 1.1))
+     ax.legend(loc="upper right")
+
+     # Add grid
+     ax.yaxis.grid(True, linestyle="--", alpha=0.7)
+     ax.set_axisbelow(True)
+
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
+     plt.close()
+     print(f"Saved: {output_path}")
+
+
+ def plot_accuracy_vs_speed(
+     results: dict,
+     output_path: Path,
+     dataset_name: str = "CoRal v2",
+ ) -> None:
+     """Generate accuracy vs speed scatter plot."""
+     fig, ax = plt.subplots(figsize=(9, 6))
+
+     for model_name, data in results["models"].items():
+         wer = data["accuracy"]["wer"] * 100
+         rtf = data["performance"]["real_time_factor"]
+         color = get_model_color(model_name)
+         display_name = get_display_name(model_name)
+
+         # Extract parameter count for bubble size
+         size_str = data["model_size"]
+         if "1.7B" in size_str:
+             size = 400
+         elif "2B" in size_str:
+             size = 500
+         else:
+             size = 300
+
+         ax.scatter(
+             rtf,
+             wer,
+             s=size,
+             c=color,
+             alpha=0.7,
+             edgecolors="white",
+             linewidth=2,
+             label=display_name.replace("\n", " "),
+         )
+
+         # Add label
+         ax.annotate(
+             display_name.replace("\n", " "),
+             xy=(rtf, wer),
+             xytext=(10, 10),
+             textcoords="offset points",
+             fontsize=10,
+             ha="left",
+         )
+
+     # Add reference line at RTF=1.0
+     ax.axvline(x=1.0, color="red", linestyle="--", linewidth=1, alpha=0.5, label="Real-time")
+
+     ax.set_xlabel("Real-Time Factor (lower is faster)", fontsize=12)
+     ax.set_ylabel("Word Error Rate (%)", fontsize=12)
+     ax.set_title(
+         f"Accuracy vs Speed Trade-off on {dataset_name}\n(bubble size = model parameters)",
+         fontsize=14,
+         fontweight="bold",
+     )
+
+     # Set axis limits with padding
+     all_wer = [d["accuracy"]["wer"] * 100 for d in results["models"].values()]
+     all_rtf = [d["performance"]["real_time_factor"] for d in results["models"].values()]
+     ax.set_xlim(0, max(all_rtf) * 1.5)
+     ax.set_ylim(min(all_wer) * 0.8, max(all_wer) * 1.2)
+
+     # Add grid
+     ax.grid(True, linestyle="--", alpha=0.7)
+
+     # Add annotation for best region
+     ax.annotate(
+         "Better",
+         xy=(0.02, min(all_wer) * 0.85),
+         fontsize=10,
+         color="green",
+         fontweight="bold",
+     )
+     ax.annotate(
+         "Faster & More Accurate",
+         xy=(0.02, min(all_wer) * 0.9),
+         fontsize=8,
+         color="gray",
+     )
+
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
+     plt.close()
+     print(f"Saved: {output_path}")
+
+
+ def plot_multi_dataset_comparison(
+     coral_results: dict,
+     cv_results: Optional[dict],
+     output_path: Path,
+ ) -> None:
+     """Generate multi-dataset WER comparison plot."""
+     fig, ax = plt.subplots(figsize=(10, 6))
+
+     # Prepare data
+     datasets = ["CoRal v2"]
+     if cv_results:
+         datasets.append("Common Voice")
+
+     # Get model names from coral results
+     model_names = list(coral_results["models"].keys())
+     x = np.arange(len(datasets))
+     # Scale bar width to the number of models so grouped bars do not overlap
+     width = 0.8 / max(len(model_names), 1)
+
+     for i, model_name in enumerate(model_names):
+         display_name = get_display_name(model_name)
+         color = get_model_color(model_name)
+
+         wer_values = [coral_results["models"][model_name]["accuracy"]["wer"] * 100]
+         if cv_results and model_name in cv_results["models"]:
+             wer_values.append(cv_results["models"][model_name]["accuracy"]["wer"] * 100)
+         elif cv_results:
+             wer_values.append(0)  # Model not evaluated on this dataset
+
+         offset = (i - len(model_names) / 2 + 0.5) * width
+         bars = ax.bar(
+             x + offset,
+             wer_values,
+             width,
+             label=display_name.replace("\n", " "),
+             color=color,
+             edgecolor="white",
+             linewidth=1.5,
+         )
+
+         # Add value labels
+         for bar, val in zip(bars, wer_values):
+             if val > 0:
+                 height = bar.get_height()
+                 ax.annotate(
+                     f"{val:.1f}%",
+                     xy=(bar.get_x() + bar.get_width() / 2, height),
+                     xytext=(0, 3),
+                     textcoords="offset points",
+                     ha="center",
+                     va="bottom",
+                     fontsize=10,
+                     fontweight="bold",
+                 )
+
+     ax.set_ylabel("Word Error Rate (%)", fontsize=12)
+     ax.set_title("WER Comparison Across Datasets", fontsize=14, fontweight="bold")
+     ax.set_xticks(x)
+     ax.set_xticklabels(datasets, fontsize=11)
+     ax.legend(loc="upper right")
+     ax.yaxis.grid(True, linestyle="--", alpha=0.7)
+     ax.set_axisbelow(True)
+
+     plt.tight_layout()
+     plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
+     plt.close()
+     print(f"Saved: {output_path}")
+
+
+ def parse_args() -> argparse.Namespace:
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(description="Generate ASR comparison plots")
+
+     parser.add_argument(
+         "--coral-results",
+         type=Path,
+         default=Path("results/full_comparison2.json"),
+         help="Path to CoRal benchmark results",
+     )
+     parser.add_argument(
+         "--cv-results",
+         type=Path,
+         default=Path("results/common_voice_comparison.json"),
+         help="Path to Common Voice benchmark results",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=Path,
+         default=Path(__file__).parent / "plots",
+         help="Output directory for plots",
+     )
+
+     return parser.parse_args()
+
+
+ def main() -> None:
+     """Main entry point for plot generation."""
+     args = parse_args()
+
+     # Create output directory
+     args.output_dir.mkdir(parents=True, exist_ok=True)
+
+     # Load results
+     coral_results = load_results(args.coral_results)
+     cv_results = load_results(args.cv_results)
+
+     if coral_results is None:
+         print("Error: CoRal results file is required")
+         return
+
+     print("=" * 60)
+     print("Generating ASR Comparison Plots")
+     print("=" * 60)
+     print(f"Output directory: {args.output_dir}")
+     print()
+
+     # Generate CoRal plots
+     print("Generating CoRal v2 plots...")
+     plot_wer_comparison(coral_results, args.output_dir / "wer_comparison.png", "CoRal v2")
+     plot_cer_comparison(coral_results, args.output_dir / "cer_comparison.png", "CoRal v2")
+     plot_rtf_comparison(coral_results, args.output_dir / "rtf_comparison.png", "CoRal v2")
+     plot_accuracy_vs_speed(coral_results, args.output_dir / "accuracy_vs_speed.png", "CoRal v2")
+
+     # Generate Common Voice plots if available
+     if cv_results:
+         print("\nGenerating Common Voice plots...")
+         plot_wer_comparison(
+             cv_results, args.output_dir / "wer_comparison_cv.png", "Common Voice Danish"
+         )
+         plot_cer_comparison(
+             cv_results, args.output_dir / "cer_comparison_cv.png", "Common Voice Danish"
+         )
+         plot_rtf_comparison(
+             cv_results, args.output_dir / "rtf_comparison_cv.png", "Common Voice Danish"
+         )
+
+     # Multi-dataset comparison
+     print("\nGenerating multi-dataset comparison...")
+     plot_multi_dataset_comparison(
+         coral_results, cv_results, args.output_dir / "multi_dataset_wer.png"
+     )
+
+     print("\n" + "=" * 60)
+     print("Plot generation complete!")
+     print("=" * 60)
+
+
+ if __name__ == "__main__":
+     main()
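For reference, the plotting helpers above all read one benchmark JSON layout. The sketch below shows roughly what that layout must look like, inferred only from the dictionary accesses in `extract_metrics` and `plot_accuracy_vs_speed` (`["models"][name]["accuracy"]["wer"/"cer"]`, `["performance"]["real_time_factor"]`, `["model_size"]`); the model name and numbers are illustrative, not actual benchmark output.

```python
# Minimal results payload matching the key paths read by extract_metrics()
# and plot_accuracy_vs_speed(); values here are illustrative only.
sample_results = {
    "models": {
        "Qwen3-ASR (checkpoint-23448)": {
            "model_size": "1.7B",
            "accuracy": {"wer": 0.1847, "cer": 0.0786},
            "performance": {"real_time_factor": 0.42},
        }
    }
}

# The same extraction the plotting helpers perform (fractions -> percentages):
for name, data in sample_results["models"].items():
    wer_pct = data["accuracy"]["wer"] * 100
    rtf = data["performance"]["real_time_factor"]
    print(f"{name}: WER {wer_pct:.2f}%, RTF {rtf:.3f}")
# → Qwen3-ASR (checkpoint-23448): WER 18.47%, RTF 0.420
```

A file with this shape, saved as e.g. `results/full_comparison2.json`, is what `--coral-results` and `--cv-results` are expected to point at.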