merge2
- README.md +434 -3
- __pycache__/evaluate_common_voice.cpython-310.pyc +0 -0
- __pycache__/generate_plots.cpython-310.pyc +0 -0
- evaluate_common_voice.py +404 -0
- generate_plots.py +478 -0
README.md CHANGED
@@ -1,3 +1,434 @@

Removed (the old README was a license-only YAML stub; the closing `---` on line 3 is reconstructed from the frontmatter format):

    - ---
    - license: openrail
    - ---

Added (the new 434-line model card, reproduced below):
---
license: openrail
language:
- da
base_model: Qwen/Qwen3-ASR-1.7B
tags:
- automatic-speech-recognition
- danish
- qwen
- asr
- speech-to-text
- coral
- streaming
datasets:
- alexandrainst/coral
- mozilla-foundation/common_voice_17_0
library_name: transformers
pipeline_tag: automatic-speech-recognition
metrics:
- wer
- cer
model-index:
- name: hvisketiske-v2
  results:
  - task:
      type: automatic-speech-recognition
      name: Speech Recognition
    dataset:
      type: alexandrainst/coral
      name: CoRal v2 Test
      split: test
    metrics:
    - type: wer
      value: 18.47
      name: WER
    - type: cer
      value: 7.86
      name: CER
---

# hvisketiske-v2: Danish ASR Model

**hvisketiske-v2** is a state-of-the-art Danish automatic speech recognition (ASR) model based on [Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B), finetuned on the [CoRal v2 dataset](https://huggingface.co/datasets/alexandrainst/coral) for improved Danish transcription accuracy.

## Key Highlights

| Feature | Value |
|---------|-------|
| **WER on CoRal v2** | 18.47% (14% relative improvement over Whisper v3) |
| **CER on CoRal v2** | 7.86% (11% relative improvement over Whisper v3) |
| **Real-Time Factor** | 0.086 (45% lower than Whisper v3) |
| **Model Size** | ~1.7B parameters |

### Inherited Features from Qwen3-ASR

- **Streaming/real-time transcription** via the vLLM backend
- **Singing detection** - can transcribe singing voice and songs with BGM
- **Word-level timestamps** via forced alignment
- **30+ language support** (Danish optimized)
- **Long audio support** - up to 20 minutes per request

---

## Performance Comparison

### CoRal v2 Test Set (9,123 samples, 17.3 hours)

| Model | WER | CER | RTF | Throughput | Parameters |
|-------|-----|-----|-----|------------|------------|
| **hvisketiske-v2** | **18.47%** | **7.86%** | **0.086** | 1.71 samples/s | ~1.7B |
| hviske-v3 (Whisper Large v3) | 21.47% | 8.79% | 0.156 | 0.94 samples/s | ~2B |

**Improvements over Whisper Large v3:**

- **14% relative reduction** in Word Error Rate (21.47% → 18.47%)
- **11% relative reduction** in Character Error Rate (8.79% → 7.86%)
- **45% lower** Real-Time Factor (0.156 → 0.086), i.e. roughly 1.8× the throughput
- **15% fewer** parameters (~1.7B vs ~2B)

### Comparison Plots





---

## Quick Start

### Installation

```bash
pip install qwen-asr transformers torch
```
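The streaming examples further down additionally require the vLLM backend (see the Limitations section); assuming the standard PyPI package name, it can be installed with `pip install vllm`.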
### Basic Usage

```python
from qwen_asr import Qwen3ASRModel

# Load the model
model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    dtype="bfloat16",
    device_map="cuda:0",
)

# Transcribe an audio file
results = model.transcribe(
    audio="path/to/danish_audio.wav",
    language="Danish",
)

print(results[0].text)
```

---

## Advanced Usage

### Batch Transcription (Fast Processing)

Process multiple audio files efficiently in a single call:

```python
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    dtype="bfloat16",
    device_map="cuda:0",
    max_inference_batch_size=16,  # process up to 16 files at once
)

# Batch transcribe multiple files
audio_files = ["audio1.wav", "audio2.wav", "audio3.wav"]
results = model.transcribe(
    audio=audio_files,
    language="Danish",
)

for i, result in enumerate(results):
    print(f"File {i + 1}: {result.text}")
```

### Transcription with Timestamps

Get word-level timestamps using the forced aligner:

```python
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    forced_aligner="Qwen/Qwen3-ForcedAligner-0.6B",
    dtype="bfloat16",
    device_map="cuda:0",
)

results = model.transcribe(
    audio="path/to/audio.wav",
    language="Danish",
    return_time_stamps=True,
)

# Access word-level timestamps
for item in results[0].time_stamps.items:
    print(f"{item.start_time:.2f}s - {item.end_time:.2f}s: {item.text}")
```

### Streaming/Real-time Transcription (vLLM Backend)

For real-time streaming transcription, use the vLLM backend:

```python
from qwen_asr import Qwen3ASRModel

# Initialize with the vLLM backend for streaming
model = Qwen3ASRModel.LLM(
    model="pluttodk/hvisketiske-v2",
    gpu_memory_utilization=0.8,
)

# Initialize streaming state
state = model.init_streaming_state(
    language="Danish",
    chunk_size_sec=2.0,  # process audio in 2-second chunks
)

# Simulate streaming audio (16kHz mono float32)
import numpy as np

def audio_stream():
    """Replace with an actual audio stream from a microphone."""
    for chunk in audio_chunks:
        yield np.array(chunk, dtype=np.float32)

# Process streaming audio
for audio_chunk in audio_stream():
    state = model.streaming_transcribe(audio_chunk, state)
    print(f"Current transcription: {state.text}")

# Finalize the stream
state = model.finish_streaming_transcribe(state)
print(f"Final transcription: {state.text}")
```

### Using with Transformers Directly

For more control, use the model directly with transformers:

```python
from transformers import AutoModel, AutoProcessor
import torch
import librosa

# Load model and processor
model = AutoModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
processor = AutoProcessor.from_pretrained(
    "pluttodk/hvisketiske-v2",
    trust_remote_code=True,
)

# Load and preprocess audio
audio, sr = librosa.load("path/to/audio.wav", sr=16000, mono=True)

# Build input using the chat template
messages = [
    {"role": "system", "content": ""},
    {"role": "user", "content": [{"type": "audio", "audio": audio}]},
]

text = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)
text = text + "language Danish<asr_text>"

# Process and generate
inputs = processor(text=[text], audio=[audio], return_tensors="pt", padding=True)
inputs = inputs.to(model.device).to(model.dtype)

output_ids = model.generate(**inputs, max_new_tokens=512)
transcription = processor.batch_decode(
    output_ids[:, inputs["input_ids"].shape[1]:],
    skip_special_tokens=True,
)[0]

print(transcription)
```

### Singing Detection & Multi-Audio Support

The model inherits Qwen3-ASR's ability to handle singing and background music:

```python
from qwen_asr import Qwen3ASRModel

model = Qwen3ASRModel.from_pretrained(
    "pluttodk/hvisketiske-v2",
    dtype="bfloat16",
    device_map="cuda:0",
)

# Transcribe audio with singing or background music
results = model.transcribe(
    audio="path/to/song.wav",
    language="Danish",  # or None for auto-detection
)

print(results[0].text)
```

---

## Model Details

### Model Description

hvisketiske-v2 is a Danish-specialized automatic speech recognition model created by finetuning [Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B) on the [CoRal v2 dataset](https://huggingface.co/datasets/alexandrainst/coral). The model achieves state-of-the-art performance on Danish speech recognition while maintaining fast inference speeds.

- **Developed by:** Mathias Oliver Valdbjørn Rønnelund
- **Model type:** Encoder-decoder speech recognition model
- **Language:** Danish (primary), with inherited multilingual capabilities
- **License:** Apache 2.0
- **Finetuned from:** [Qwen/Qwen3-ASR-1.7B](https://huggingface.co/Qwen/Qwen3-ASR-1.7B)

### Architecture

The model inherits the Qwen3-ASR architecture:

| Component | Specification |
|-----------|--------------|
| Audio Encoder | 24-layer transformer (1024 hidden dim, 16 attention heads) |
| Text Decoder | 28-layer transformer (2048 hidden dim, 16 attention heads) |
| Total Parameters | ~1.7 billion |
| Precision | bfloat16 |
| Audio Input | 16kHz mono WAV |
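As a quick sanity check of the parameter count in the table above, a minimal sketch (assuming the model loads via `AutoModel` with `trust_remote_code=True`, as in the transformers example earlier; this downloads the full weights):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("pluttodk/hvisketiske-v2", trust_remote_code=True)

# Sum parameter counts across all weight tensors; should land near 1.7B.
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e9:.2f}B parameters")
```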
---

## Training Details

### Training Data

The model was finetuned on the [CoRal v2 dataset](https://huggingface.co/datasets/alexandrainst/coral), a comprehensive Danish speech corpus containing:
- Diverse Danish speakers across demographics
- Various recording conditions and audio qualities
- Natural conversational speech
- Read-aloud speech

### Training Procedure

**Training Approach:** Supervised Fine-Tuning (SFT) with chat template formatting

**Preprocessing:**
- Audio resampled to 16kHz mono
- Chat template applied with system prompt, audio input, and target transcription
- Prefix masking to train only on transcription tokens (see the sketch below)
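The training code itself is not part of this upload; as a rough illustration of the prefix-masking step, here is a minimal sketch assuming Hugging Face-style label construction, where `-100` is the standard ignore index for the cross-entropy loss:

```python
import torch

IGNORE_INDEX = -100  # positions with this label are excluded from the loss


def build_labels(input_ids: torch.Tensor, prompt_length: int) -> torch.Tensor:
    """Train only on transcription tokens: mask the chat/audio prefix."""
    labels = input_ids.clone()
    labels[:prompt_length] = IGNORE_INDEX  # system prompt, audio tokens, template
    return labels
```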
**Training Hyperparameters:**

| Parameter | Value |
|-----------|-------|
| Base model | Qwen/Qwen3-ASR-1.7B |
| Learning rate | 2e-5 |
| Batch size (per device) | 8 |
| Gradient accumulation steps | 4 |
| Effective batch size | 32 |
| Epochs | 3 |
| Warmup ratio | 0.1 |
| Weight decay | 0.01 |
| Max gradient norm | 1.0 |
| Precision | bfloat16 |
| Optimizer | AdamW |
| LR scheduler | Linear decay |
| Total training steps | 23,448 |

**Hardware:** Training performed on NVIDIA GPUs (~25GB of GPU memory per device)
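For orientation, a minimal sketch of how these hyperparameters map onto standard Hugging Face `TrainingArguments` (the argument names are stock `transformers` ones; the actual training script is not included here, and the output path is illustrative):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./outputs/hvisketiske-v2",  # illustrative path
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,  # effective batch size: 8 * 4 = 32
    num_train_epochs=3,
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    bf16=True,
    lr_scheduler_type="linear",
    optim="adamw_torch",
)
```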
---

## Evaluation

### Test Data

Evaluated on the CoRal v2 test split:
- **9,123 samples**
- **17.3 hours** of audio (about 6.8 seconds per clip on average)
- Diverse Danish speakers and recording conditions

### Metrics

| Metric | Description |
|--------|-------------|
| **WER** | Word Error Rate - percentage of words incorrectly transcribed (lower is better) |
| **CER** | Character Error Rate - percentage of characters incorrectly transcribed (lower is better) |
| **RTF** | Real-Time Factor - ratio of processing time to audio duration (< 1.0 = faster than real-time) |
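For reference, a minimal sketch of how these metrics are computed with the `jiwer` package (the same package the evaluation script below uses); the sentences are made-up examples:

```python
from jiwer import cer, wer

references = ["det er en dejlig dag"]
predictions = ["det er en dejlig tag"]

print(f"WER: {wer(references, predictions):.2%}")  # 1 wrong word out of 5 -> 20.00%
print(f"CER: {cer(references, predictions):.2%}")  # 1 wrong character out of 20 -> 5.00%

# RTF is simply processing time divided by audio duration,
# e.g. 1.5 s of compute for a 17.4 s clip:
print(f"RTF: {1.5 / 17.4:.3f}")  # < 1.0 means faster than real-time
```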
### Results Summary

| Model | WER | CER | RTF | Throughput |
|-------|-----|-----|-----|------------|
| **hvisketiske-v2** | **18.47%** | **7.86%** | **0.086** | 1.71 samples/s |
| hviske-v3 (Whisper v3) | 21.47% | 8.79% | 0.156 | 0.94 samples/s |

---

## Limitations

- **Language:** Optimized for Danish; other languages may show degraded performance compared to the base Qwen3-ASR
- **Audio quality:** Best results with clear speech; noisy environments may reduce accuracy
- **Domain:** Trained on CoRal v2, which is primarily conversational/read-aloud speech; specialized domains (medical, legal, technical) may have higher error rates
- **Streaming:** Real-time streaming requires the vLLM backend to be installed

## Intended Use

### Primary Use Cases
- Danish speech-to-text transcription
- Subtitle generation for Danish content
- Voice assistant backends
- Meeting transcription
- Accessibility applications

### Out-of-Scope Use
- Non-Danish languages (use the base Qwen3-ASR instead)
- Real-time speaker diarization (not supported)
- Emotion/sentiment detection from speech

---

## Citation

If you use this model, please cite:

```bibtex
@misc{hvisketiske-v2,
  author    = {Rønnelund, Mathias Oliver Valdbjørn},
  title     = {hvisketiske-v2: Danish ASR Model based on Qwen3-ASR},
  year      = {2025},
  publisher = {HuggingFace},
  url       = {https://huggingface.co/pluttodk/hvisketiske-v2}
}
```

Also consider citing the base model and dataset:

```bibtex
@article{qwen3asr,
  title   = {Qwen3-ASR Technical Report},
  author  = {Qwen Team},
  journal = {arXiv preprint arXiv:2601.21337},
  year    = {2025}
}

@dataset{coral,
  title  = {CoRal: A Danish Speech Corpus},
  author = {Alexandra Institute},
  year   = {2024},
  url    = {https://huggingface.co/datasets/alexandrainst/coral}
}
```

---

## Acknowledgements

- [Qwen Team](https://github.com/QwenLM) for the excellent Qwen3-ASR base model
- [Alexandra Institute](https://alexandra.dk/) for the CoRal v2 Danish speech corpus
__pycache__/evaluate_common_voice.cpython-310.pyc ADDED
Binary file (11.1 kB).

__pycache__/generate_plots.cpython-310.pyc ADDED
Binary file (11.2 kB).

evaluate_common_voice.py ADDED
@@ -0,0 +1,404 @@
```python
#!/usr/bin/env python
"""
Benchmark ASR models on the Common Voice Danish dataset.

This script evaluates hvisketiske-v2 (Qwen3-ASR) and hviske-v3 (Whisper)
on the Mozilla Common Voice Danish test set for comparison.

IMPORTANT: Common Voice requires authentication and agreement to its terms of use.
Before running this script:
1. Create a HuggingFace account at https://huggingface.co
2. Visit https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0
3. Agree to the dataset terms of use
4. Create an access token at https://huggingface.co/settings/tokens
5. Log in via CLI: `huggingface-cli login`

Usage:
    # After logging in:
    python huggingface/evaluate_common_voice.py \
        --hvisketiske-path ./outputs/hvisketiske-v2/checkpoint-23448 \
        --max-samples 1000 \
        --output-file ./results/common_voice_comparison.json

    # Quick test with fewer samples:
    python huggingface/evaluate_common_voice.py --max-samples 100

    # Use a specific token:
    python huggingface/evaluate_common_voice.py --hf-token YOUR_TOKEN
"""

import argparse
import json
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import soundfile as sf
from datasets import load_dataset
from jiwer import cer, wer
from tqdm import tqdm

# Add src to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from hvisketiske.evaluation.model_adapters import (
    ASRModelAdapter,
    HviskeV3Adapter,
    Qwen3ASRAdapter,
    TranscriptionResult,
)
from hvisketiske.evaluation.timing import AggregatedTimingStats


@dataclass
class CommonVoiceSample:
    """A single Common Voice sample."""

    audio_path: str
    reference: str
    audio_duration: float


def load_common_voice_danish(
    split: str = "test",
    max_samples: Optional[int] = None,
    cache_dir: Optional[str] = None,
    hf_token: Optional[str] = None,
) -> List[CommonVoiceSample]:
    """
    Load the Common Voice Danish dataset and prepare samples.

    Args:
        split: Dataset split to load (test, validation, train).
        max_samples: Maximum number of samples to load.
        cache_dir: Directory to cache audio files.
        hf_token: HuggingFace API token for authentication.

    Returns:
        List of CommonVoiceSample objects.
    """
    print(f"Loading Common Voice Danish ({split} split)...")
    print("Note: This requires HuggingFace authentication and agreement to dataset terms.")
    print("Visit: https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0")
    print()

    try:
        ds = load_dataset(
            "mozilla-foundation/common_voice_17_0",
            "da",
            split=split,
            trust_remote_code=True,
            token=hf_token,
        )
    except Exception as e:
        error_msg = str(e)
        if "EmptyDatasetError" in error_msg or "doesn't contain any data" in error_msg:
            print("\n" + "=" * 70)
            print("ERROR: Cannot access Common Voice dataset.")
            print("=" * 70)
            print("\nThis dataset requires authentication. Please:")
            print("1. Visit https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0")
            print("2. Log in and agree to the terms of use")
            print("3. Run: huggingface-cli login")
            print("4. Or pass --hf-token YOUR_TOKEN to this script")
            print("=" * 70 + "\n")
        raise

    if max_samples:
        ds = ds.select(range(min(max_samples, len(ds))))

    print(f"Loaded {len(ds)} samples")

    # Create temp directory for audio files if not provided
    if cache_dir is None:
        cache_dir = tempfile.mkdtemp(prefix="cv_danish_")

    cache_path = Path(cache_dir)
    cache_path.mkdir(parents=True, exist_ok=True)

    samples = []
    print("Preparing audio files...")
    for i, item in enumerate(tqdm(ds, desc="Preparing samples")):
        # Extract audio array and sample rate
        audio_array = item["audio"]["array"]
        sample_rate = item["audio"]["sampling_rate"]

        # Save to temp file
        audio_path = cache_path / f"sample_{i:06d}.wav"
        sf.write(str(audio_path), audio_array, sample_rate)

        # Calculate duration
        duration = len(audio_array) / sample_rate

        samples.append(
            CommonVoiceSample(
                audio_path=str(audio_path),
                reference=item["sentence"],
                audio_duration=duration,
            )
        )

    return samples


def normalize_text(text: str) -> str:
    """Normalize text for fair comparison."""
    text = text.lower()
    text = " ".join(text.split())
    return text


def evaluate_model(
    model: ASRModelAdapter,
    samples: List[CommonVoiceSample],
    warmup_samples: int = 3,
) -> dict:
    """
    Evaluate a model on the Common Voice samples.

    Args:
        model: Model adapter to evaluate.
        samples: List of samples to evaluate.
        warmup_samples: Number of warmup iterations.

    Returns:
        Dictionary with evaluation results.
    """
    print(f"\nEvaluating: {model.model_name}")
    print("Loading model...")
    model.load()

    # Warmup
    if warmup_samples > 0 and samples:
        print(f"Running {warmup_samples} warmup iterations...")
        model.warmup(samples[0].audio_path, num_runs=warmup_samples)

    # Transcribe all samples
    predictions = []
    individual_times = []
    total_audio_duration = 0.0
    total_inference_time = 0.0

    print(f"Transcribing {len(samples)} samples...")
    for sample in tqdm(samples, desc=f"Evaluating {model.model_name[:30]}"):
        result = model.transcribe(sample.audio_path)
        predictions.append(result.text)
        individual_times.append(result.inference_time_seconds)
        total_audio_duration += sample.audio_duration
        total_inference_time += result.inference_time_seconds

    # Normalize text
    predictions_norm = [normalize_text(p) for p in predictions]
    references_norm = [normalize_text(s.reference) for s in samples]

    # Calculate metrics
    word_error_rate = wer(references_norm, predictions_norm)
    char_error_rate = cer(references_norm, predictions_norm)

    timing_stats = AggregatedTimingStats(
        total_inference_time_seconds=total_inference_time,
        total_audio_duration_seconds=total_audio_duration,
        num_samples=len(samples),
        individual_times=individual_times,
    )

    return {
        "model_name": model.model_name,
        "model_size": model.model_size_params,
        "accuracy": {
            "wer": word_error_rate,
            "cer": char_error_rate,
        },
        "performance": {
            "total_inference_time_seconds": timing_stats.total_inference_time_seconds,
            "total_audio_duration_seconds": timing_stats.total_audio_duration_seconds,
            "real_time_factor": timing_stats.real_time_factor,
            "throughput_samples_per_second": timing_stats.throughput_samples_per_second,
            "mean_time_per_sample_seconds": timing_stats.mean_time_per_sample,
            "std_time_per_sample_seconds": timing_stats.std_time_per_sample,
        },
        "num_samples": len(samples),
    }


def print_summary(results: dict) -> None:
    """Print a formatted comparison summary."""
    print("\n" + "=" * 80)
    print("COMMON VOICE DANISH - ASR MODEL COMPARISON")
    print("=" * 80)
    print("Dataset: mozilla-foundation/common_voice_17_0 (Danish)")
    print(f"Number of models: {len(results['models'])}")

    sample_count = next(iter(results["models"].values()))["num_samples"]
    print(f"Samples evaluated: {sample_count}")

    # Accuracy comparison table
    print("\n" + "-" * 80)
    print("ACCURACY METRICS (lower is better)")
    print("-" * 80)
    print(f"{'Model':<45} {'WER':>12} {'CER':>12}")
    print("-" * 80)
    for name, result in sorted(
        results["models"].items(), key=lambda x: x[1]["accuracy"]["wer"]
    ):
        print(
            f"{result['model_name'][:45]:<45} "
            f"{result['accuracy']['wer']:>11.2%} "
            f"{result['accuracy']['cer']:>11.2%}"
        )

    # Performance comparison table
    print("\n" + "-" * 80)
    print("PERFORMANCE METRICS (RTF < 1.0 = faster than real-time)")
    print("-" * 80)
    print(f"{'Model':<35} {'RTF':>8} {'Throughput':>12} {'Mean Time':>12}")
    print(f"{'':35} {'':>8} {'(samples/s)':>12} {'(s/sample)':>12}")
    print("-" * 80)
    for name, result in sorted(
        results["models"].items(), key=lambda x: x[1]["performance"]["real_time_factor"]
    ):
        perf = result["performance"]
        print(
            f"{result['model_name'][:35]:<35} "
            f"{perf['real_time_factor']:>8.3f} "
            f"{perf['throughput_samples_per_second']:>12.2f} "
            f"{perf['mean_time_per_sample_seconds']:>12.3f}"
        )

    print("=" * 80)


def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Benchmark ASR models on Common Voice Danish"
    )

    parser.add_argument(
        "--output-file",
        type=Path,
        default=Path("results/common_voice_comparison.json"),
        help="Path to save comparison report (JSON)",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum samples to evaluate (for quick testing)",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=3,
        help="Number of warmup iterations per model (default: 3)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device for inference (default: cuda:0)",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default=None,
        help="Directory to cache audio files",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="HuggingFace API token for authentication (or use huggingface-cli login)",
    )

    # Model selection
    parser.add_argument(
        "--skip-hviske-v3",
        action="store_true",
        help="Skip hviske-v3-conversation model",
    )
    parser.add_argument(
        "--skip-hvisketiske",
        action="store_true",
        help="Skip hvisketiske-v2 model",
    )
    parser.add_argument(
        "--hvisketiske-path",
        type=str,
        default="./outputs/hvisketiske-v2/checkpoint-23448",
        help="Path to local hvisketiske checkpoint",
    )

    return parser.parse_args()


def main() -> None:
    """Main entry point for Common Voice evaluation."""
    args = parse_args()

    # Load dataset
    samples = load_common_voice_danish(
        split="test",
        max_samples=args.max_samples,
        cache_dir=args.cache_dir,
        hf_token=args.hf_token,
    )

    # Configure models to evaluate
    models = []

    if not args.skip_hviske_v3:
        models.append(
            HviskeV3Adapter(
                model_id="syvai/hviske-v3-conversation",
                device=args.device,
            )
        )

    if not args.skip_hvisketiske:
        models.append(
            Qwen3ASRAdapter(
                model_path=args.hvisketiske_path,
                device=args.device,
            )
        )

    if not models:
        print("Error: No models selected for evaluation")
        sys.exit(1)

    print("=" * 60)
    print("Common Voice Danish ASR Evaluation")
    print("=" * 60)
    print("Dataset: mozilla-foundation/common_voice_17_0")
    print(f"Samples: {len(samples)}")
    print(f"Device: {args.device}")
    print(f"Warmup iterations: {args.warmup}")
    print(f"Models to evaluate: {len(models)}")
    for m in models:
        print(f"  - {m.model_name} ({m.model_size_params})")
    print("=" * 60)

    # Evaluate all models
    results = {"dataset": "mozilla-foundation/common_voice_17_0", "models": {}}

    for model in models:
        model_results = evaluate_model(model, samples, warmup_samples=args.warmup)
        results["models"][model.model_name] = model_results

    # Print summary
    print_summary(results)

    # Save results
    args.output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\nResults saved to: {args.output_file}")


if __name__ == "__main__":
    main()
```
generate_plots.py ADDED
@@ -0,0 +1,478 @@
```python
#!/usr/bin/env python
"""
Generate comparison plots for ASR model benchmarks.

Creates publication-quality visualizations comparing hvisketiske-v2
against other Danish ASR models on accuracy and performance metrics.

Usage:
    python huggingface/generate_plots.py

    # Specify custom result files:
    python huggingface/generate_plots.py \
        --coral-results ./results/full_comparison2.json \
        --cv-results ./results/common_voice_comparison.json

Output:
    huggingface/plots/
    ├── wer_comparison.png
    ├── cer_comparison.png
    ├── rtf_comparison.png
    └── accuracy_vs_speed.png
"""

import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np

# Use a clean style
plt.style.use("seaborn-v0_8-whitegrid")

# Color palette - distinct colors for models
COLORS = {
    "hvisketiske": "#2ecc71",  # green for our model (best)
    "qwen3-base": "#27ae60",   # darker green for base Qwen
    "hviske-v2": "#3498db",    # blue for hviske-v2
    "hviske-v3": "#2980b9",    # darker blue for hviske-v3
    "faster": "#e74c3c",       # red for faster-whisper models
    "turbo": "#e67e22",        # orange for turbo
    "default": "#95a5a6",      # gray for others
}

# Model display names mapping
MODEL_DISPLAY_NAMES = {
    "Qwen3-ASR (checkpoint-23448)": "hvisketiske-v2\n(Qwen3-ASR finetuned)",
    "hviske-v3-conversation (Whisper Large v3)": "hviske-v3\n(Whisper v3)",
    "hviske-v2 (Whisper Large v2)": "hviske-v2\n(Whisper v2)",
    "faster-hviske-v2 (CT2 distilled)": "faster-hviske-v2\n(CT2 distilled)",
    "Whisper Large v3 Turbo": "Whisper v3 Turbo\n(faster-whisper)",
    "Qwen3-ASR-1.7B (base)": "Qwen3-ASR-1.7B\n(base, not finetuned)",
}


def get_model_color(model_name: str) -> str:
    """Get the color for a model based on its name."""
    name_lower = model_name.lower()

    # Our finetuned model (highest priority)
    if "hvisketiske" in name_lower or "checkpoint" in name_lower:
        return COLORS["hvisketiske"]
    # Base Qwen3-ASR (not finetuned)
    elif "qwen3-asr-1.7b" in name_lower and "base" in name_lower:
        return COLORS["qwen3-base"]
    elif "qwen" in name_lower:
        return COLORS["hvisketiske"]
    # Turbo model
    elif "turbo" in name_lower:
        return COLORS["turbo"]
    # Faster-whisper models
    elif "faster" in name_lower or "ct2" in name_lower:
        return COLORS["faster"]
    # hviske-v3
    elif "hviske-v3" in name_lower or "v3" in name_lower:
        return COLORS["hviske-v3"]
    # hviske-v2
    elif "hviske-v2" in name_lower or "v2" in name_lower:
        return COLORS["hviske-v2"]
    return COLORS["default"]


def get_display_name(model_name: str) -> str:
    """Get the display name for a model."""
    return MODEL_DISPLAY_NAMES.get(model_name, model_name)


def load_results(path: Path) -> Optional[dict]:
    """Load benchmark results from a JSON file."""
    if not path.exists():
        print(f"Warning: Results file not found: {path}")
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def extract_metrics(results: dict) -> Tuple[List[str], List[float], List[float], List[float], List[str]]:
    """
    Extract metrics from a results dictionary.

    Returns:
        Tuple of (names, wer_values, cer_values, rtf_values, colors)
    """
    names = []
    wer_values = []
    cer_values = []
    rtf_values = []
    colors = []

    for model_name, data in results["models"].items():
        display_name = get_display_name(model_name)
        names.append(display_name)
        wer_values.append(data["accuracy"]["wer"] * 100)  # convert to percentage
        cer_values.append(data["accuracy"]["cer"] * 100)
        rtf_values.append(data["performance"]["real_time_factor"])
        colors.append(get_model_color(model_name))

    return names, wer_values, cer_values, rtf_values, colors


def plot_wer_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate a WER comparison bar chart."""
    names, wer_values, _, _, colors = extract_metrics(results)

    fig, ax = plt.subplots(figsize=(8, 5))

    bars = ax.bar(names, wer_values, color=colors, edgecolor="white", linewidth=1.5)

    # Add value labels on bars
    for bar, val in zip(bars, wer_values):
        height = bar.get_height()
        ax.annotate(
            f"{val:.1f}%",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title(f"WER Comparison on {dataset_name}", fontsize=14, fontweight="bold")
    ax.set_ylim(0, max(wer_values) * 1.2)

    # Add grid
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")


def plot_cer_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate a CER comparison bar chart."""
    names, _, cer_values, _, colors = extract_metrics(results)

    fig, ax = plt.subplots(figsize=(8, 5))

    bars = ax.bar(names, cer_values, color=colors, edgecolor="white", linewidth=1.5)

    # Add value labels on bars
    for bar, val in zip(bars, cer_values):
        height = bar.get_height()
        ax.annotate(
            f"{val:.1f}%",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    ax.set_ylabel("Character Error Rate (%)", fontsize=12)
    ax.set_title(f"CER Comparison on {dataset_name}", fontsize=14, fontweight="bold")
    ax.set_ylim(0, max(cer_values) * 1.2)

    # Add grid
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")


def plot_rtf_comparison(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate an RTF/speed comparison bar chart."""
    names, _, _, rtf_values, colors = extract_metrics(results)

    fig, ax = plt.subplots(figsize=(8, 5))

    bars = ax.bar(names, rtf_values, color=colors, edgecolor="white", linewidth=1.5)

    # Add value labels on bars
    for bar, val in zip(bars, rtf_values):
        height = bar.get_height()
        ax.annotate(
            f"{val:.3f}",
            xy=(bar.get_x() + bar.get_width() / 2, height),
            xytext=(0, 5),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=12,
            fontweight="bold",
        )

    # Add reference line at RTF=1.0 (real-time)
    ax.axhline(y=1.0, color="red", linestyle="--", linewidth=1.5, label="Real-time (RTF=1.0)")

    ax.set_ylabel("Real-Time Factor (lower is faster)", fontsize=12)
    ax.set_title(f"Speed Comparison on {dataset_name}", fontsize=14, fontweight="bold")
    ax.set_ylim(0, max(max(rtf_values) * 1.3, 1.1))
    ax.legend(loc="upper right")

    # Add grid
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")


def plot_accuracy_vs_speed(
    results: dict,
    output_path: Path,
    dataset_name: str = "CoRal v2",
) -> None:
    """Generate an accuracy vs speed scatter plot."""
    fig, ax = plt.subplots(figsize=(9, 6))

    for model_name, data in results["models"].items():
        wer = data["accuracy"]["wer"] * 100
        rtf = data["performance"]["real_time_factor"]
        color = get_model_color(model_name)
        display_name = get_display_name(model_name)

        # Extract parameter count for bubble size
        size_str = data["model_size"]
        if "1.7B" in size_str:
            size = 400
        elif "2B" in size_str:
            size = 500
        else:
            size = 300

        ax.scatter(
            rtf,
            wer,
            s=size,
            c=color,
            alpha=0.7,
            edgecolors="white",
            linewidth=2,
            label=display_name.replace("\n", " "),
        )

        # Add label
        ax.annotate(
            display_name.replace("\n", " "),
            xy=(rtf, wer),
            xytext=(10, 10),
            textcoords="offset points",
            fontsize=10,
            ha="left",
        )

    # Add reference line at RTF=1.0
    ax.axvline(x=1.0, color="red", linestyle="--", linewidth=1, alpha=0.5, label="Real-time")

    ax.set_xlabel("Real-Time Factor (lower is faster)", fontsize=12)
    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title(
        f"Accuracy vs Speed Trade-off on {dataset_name}\n(bubble size = model parameters)",
        fontsize=14,
        fontweight="bold",
    )

    # Set axis limits with padding
    all_wer = [d["accuracy"]["wer"] * 100 for d in results["models"].values()]
    all_rtf = [d["performance"]["real_time_factor"] for d in results["models"].values()]
    ax.set_xlim(0, max(all_rtf) * 1.5)
    ax.set_ylim(min(all_wer) * 0.8, max(all_wer) * 1.2)

    # Add grid
    ax.grid(True, linestyle="--", alpha=0.7)

    # Add annotation for the best region
    ax.annotate(
        "Better",
        xy=(0.02, min(all_wer) * 0.85),
        fontsize=10,
        color="green",
        fontweight="bold",
    )
    ax.annotate(
        "Faster & More Accurate",
        xy=(0.02, min(all_wer) * 0.9),
        fontsize=8,
        color="gray",
    )

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")


def plot_multi_dataset_comparison(
    coral_results: dict,
    cv_results: Optional[dict],
    output_path: Path,
) -> None:
    """Generate a multi-dataset WER comparison plot."""
    fig, ax = plt.subplots(figsize=(10, 6))

    # Prepare data
    datasets = ["CoRal v2"]
    if cv_results:
        datasets.append("Common Voice")

    # Get model names from the CoRal results
    model_names = list(coral_results["models"].keys())
    x = np.arange(len(datasets))
    width = 0.35

    for i, model_name in enumerate(model_names):
        display_name = get_display_name(model_name)
        color = get_model_color(model_name)

        wer_values = [coral_results["models"][model_name]["accuracy"]["wer"] * 100]
        if cv_results and model_name in cv_results["models"]:
            wer_values.append(cv_results["models"][model_name]["accuracy"]["wer"] * 100)
        elif cv_results:
            wer_values.append(0)  # model not evaluated on this dataset

        offset = (i - len(model_names) / 2 + 0.5) * width
        bars = ax.bar(
            x + offset,
            wer_values,
            width,
            label=display_name.replace("\n", " "),
            color=color,
            edgecolor="white",
            linewidth=1.5,
        )

        # Add value labels
        for bar, val in zip(bars, wer_values):
            if val > 0:
                height = bar.get_height()
                ax.annotate(
                    f"{val:.1f}%",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )

    ax.set_ylabel("Word Error Rate (%)", fontsize=12)
    ax.set_title("WER Comparison Across Datasets", fontsize=14, fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels(datasets, fontsize=11)
    ax.legend(loc="upper right")
    ax.yaxis.grid(True, linestyle="--", alpha=0.7)
    ax.set_axisbelow(True)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight", facecolor="white")
    plt.close()
    print(f"Saved: {output_path}")


def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Generate ASR comparison plots")

    parser.add_argument(
        "--coral-results",
        type=Path,
        default=Path("results/full_comparison2.json"),
        help="Path to CoRal benchmark results",
    )
    parser.add_argument(
        "--cv-results",
        type=Path,
        default=Path("results/common_voice_comparison.json"),
        help="Path to Common Voice benchmark results",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path(__file__).parent / "plots",
        help="Output directory for plots",
    )

    return parser.parse_args()


def main() -> None:
    """Main entry point for plot generation."""
    args = parse_args()

    # Create output directory
    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Load results
    coral_results = load_results(args.coral_results)
    cv_results = load_results(args.cv_results)

    if coral_results is None:
        print("Error: CoRal results file is required")
        return

    print("=" * 60)
    print("Generating ASR Comparison Plots")
    print("=" * 60)
    print(f"Output directory: {args.output_dir}")
    print()

    # Generate CoRal plots
    print("Generating CoRal v2 plots...")
    plot_wer_comparison(coral_results, args.output_dir / "wer_comparison.png", "CoRal v2")
    plot_cer_comparison(coral_results, args.output_dir / "cer_comparison.png", "CoRal v2")
    plot_rtf_comparison(coral_results, args.output_dir / "rtf_comparison.png", "CoRal v2")
    plot_accuracy_vs_speed(coral_results, args.output_dir / "accuracy_vs_speed.png", "CoRal v2")

    # Generate Common Voice plots if available
    if cv_results:
        print("\nGenerating Common Voice plots...")
        plot_wer_comparison(
            cv_results, args.output_dir / "wer_comparison_cv.png", "Common Voice Danish"
        )
        plot_cer_comparison(
            cv_results, args.output_dir / "cer_comparison_cv.png", "Common Voice Danish"
        )
        plot_rtf_comparison(
            cv_results, args.output_dir / "rtf_comparison_cv.png", "Common Voice Danish"
        )

        # Multi-dataset comparison
        print("\nGenerating multi-dataset comparison...")
        plot_multi_dataset_comparison(
            coral_results, cv_results, args.output_dir / "multi_dataset_wer.png"
        )

    print("\n" + "=" * 60)
    print("Plot generation complete!")
    print("=" * 60)


if __name__ == "__main__":
    main()
```