|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Dataset loading utilities for Vietnamese ASR. |
|
|
|
|
|
Supports: |
|
|
- Common Voice 17.0 (Vietnamese) |
|
|
- VIVOS |
|
|
""" |
|
|
|
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
from datasets import load_dataset, Audio |
|
|
|
|
|
|
|
|
def load_common_voice(
    split: str = "train",
    streaming: bool = False,
    cache_dir: Optional[str] = None,
) -> "Dataset":
    """Fetch the Vietnamese ("vi") configuration of Common Voice 17.0.

    Args:
        split: Name of the split to load ("train", "validation", "test").
        streaming: Forwarded to ``load_dataset`` to toggle streaming mode.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.

    Returns:
        The loaded dataset with its "audio" column cast to 16 kHz and every
        column other than "audio" and "sentence" removed.
    """
    wanted = ("audio", "sentence")

    raw = load_dataset(
        "mozilla-foundation/common_voice_17_0",
        "vi",
        split=split,
        streaming=streaming,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    resampled = raw.cast_column("audio", Audio(sampling_rate=16000))

    # Drop everything except the audio and its transcript.
    unwanted = [name for name in resampled.column_names if name not in wanted]
    return resampled.remove_columns(unwanted)
|
|
|
|
|
|
|
|
def load_vivos(
    split: str = "train",
    cache_dir: Optional[str] = None,
) -> "Dataset":
    """Load the VIVOS Vietnamese speech dataset.

    Args:
        split: Dataset split ("train", "test").
        cache_dir: Custom cache directory forwarded to ``load_dataset``.

    Returns:
        The loaded dataset with its "audio" column cast to 16 kHz. All other
        columns are returned as loaded.
    """
    # NOTE(review): trust_remote_code=True allows the dataset repository's
    # loading script to run locally -- confirm this is acceptable here.
    ds = load_dataset(
        "vivos",
        split=split,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )

    # The transcript column is used under its existing name "sentence".
    # The previous rename_column("sentence", "sentence") call was dropped:
    # renaming a column to its own name changes nothing at best, and
    # `datasets` is expected to reject a new name that already exists
    # (TODO confirm against the installed `datasets` version).
    return ds.cast_column("audio", Audio(sampling_rate=16000))
|
|
|
|
|
|
|
|
def prepare_dataset(batch, processor):
    """Turn one dataset example into Whisper training inputs.

    Args:
        batch: Mapping with an 'audio' entry (providing "array" and
            "sampling_rate") and a 'sentence' entry.
        processor: WhisperProcessor instance; its ``feature_extractor`` and
            ``tokenizer`` are both called here.

    Returns:
        The same ``batch`` object, with "input_features" and "labels" added.
    """
    sample = batch["audio"]

    extracted = processor.feature_extractor(
        sample["array"],
        sampling_rate=sample["sampling_rate"],
    )
    # Only the first entry of the extractor output is stored.
    batch["input_features"] = extracted.input_features[0]

    encoded = processor.tokenizer(batch["sentence"])
    batch["labels"] = encoded.input_ids

    return batch
|
|
|