Automatic Speech Recognition
Transformers
Vietnamese
vietnamese
whisper
speech-to-text
File size: 2,328 Bytes
5763d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets>=2.14.0",
#     "torchaudio>=2.0.0",
#     "transformers>=4.36.0",
# ]
# ///
"""
Dataset loading utilities for Vietnamese ASR.

Supports:
- Common Voice 17.0 (Vietnamese)
- VIVOS
"""

from pathlib import Path
from typing import Optional

from datasets import load_dataset, Audio


def load_common_voice(
    split: str = "train",
    streaming: bool = False,
    cache_dir: Optional[str] = None,
) -> "Dataset":
    """Fetch the Vietnamese ("vi") subset of Common Voice 17.0.

    The audio column is re-cast so it decodes at 16 kHz, and every column
    except "audio" and "sentence" is dropped.

    Args:
        split: Which split to load ("train", "validation", "test").
        streaming: Forwarded to ``load_dataset``; when True the result is
            streamed instead of fully downloaded.
        cache_dir: Optional override for the datasets cache location.

    Returns:
        The loaded dataset restricted to the "audio" and "sentence" columns.
    """
    wanted = ("audio", "sentence")

    # trust_remote_code=True permits running code shipped with the dataset
    # repository, if it has any.
    dataset = load_dataset(
        "mozilla-foundation/common_voice_17_0",
        "vi",
        split=split,
        streaming=streaming,
        cache_dir=cache_dir,
        trust_remote_code=True,
    ).cast_column("audio", Audio(sampling_rate=16000))

    # Everything that is not the waveform or its transcript gets removed.
    unwanted = [name for name in dataset.column_names if name not in wanted]
    return dataset.remove_columns(unwanted)


def load_vivos(
    split: str = "train",
    cache_dir: Optional[str] = None,
) -> "Dataset":
    """Load VIVOS Vietnamese speech dataset.

    The audio column is cast to 16 kHz and only the "audio" and "sentence"
    columns are kept. This matches the output of ``load_common_voice`` and
    the columns ``prepare_dataset`` reads.

    Args:
        split: Dataset split ("train", "test")
        cache_dir: Custom cache directory

    Returns:
        HuggingFace Dataset with audio column cast to 16kHz, restricted to
        the "audio" and "sentence" columns.
    """
    ds = load_dataset(
        "vivos",
        split=split,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    # The transcript column is already called "sentence", so no rename is
    # needed. Calling rename_column("sentence", "sentence") raises ValueError
    # in `datasets`, because the target name already exists.
    ds = ds.remove_columns(
        [c for c in ds.column_names
         if c not in ("audio", "sentence")]
    )
    return ds


def prepare_dataset(batch, processor):
    """Convert a single example into Whisper training inputs.

    Runs the processor's feature extractor on the example's decoded audio and
    its tokenizer on the transcript. The given mapping is updated in place
    and returned, so this can be handed straight to ``Dataset.map``.

    Args:
        batch: One example mapping with an "audio" entry (a dict holding
            "array" and "sampling_rate") and a "sentence" string.
        processor: WhisperProcessor instance (anything exposing
            ``feature_extractor`` and ``tokenizer`` works).

    Returns:
        The same mapping, now also carrying "input_features" and "labels".
    """
    clip = batch["audio"]
    extracted = processor.feature_extractor(
        clip["array"],
        sampling_rate=clip["sampling_rate"],
    )
    # The extractor returns batched features; exactly one clip was passed,
    # so keep only its entry.
    batch["input_features"] = extracted.input_features[0]

    encoded = processor.tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids
    return batch