# Automatic Speech Recognition
# Transformers
# Vietnamese / vietnamese
# whisper
# speech-to-text
# asr-1 / src / data.py
# rain1024
# Initial commit: ASR-1 Vietnamese speech recognition model
# 5763d9e
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "datasets>=2.14.0",
# "torchaudio>=2.0.0",
# "transformers>=4.36.0",
# ]
# ///
"""
Dataset loading utilities for Vietnamese ASR.
Supports:
- Common Voice 17.0 (Vietnamese)
- VIVOS
"""
from pathlib import Path
from typing import Optional
from datasets import load_dataset, Audio
def load_common_voice(
    split: str = "train",
    streaming: bool = False,
    cache_dir: Optional[str] = None,
) -> "Dataset":
    """Load Common Voice 17.0 Vietnamese dataset.

    Args:
        split: Dataset split ("train", "validation", "test")
        streaming: Whether to use streaming mode (returns an IterableDataset)
        cache_dir: Custom cache directory

    Returns:
        HuggingFace Dataset (or IterableDataset when streaming) with the
        audio column cast to 16kHz and only "audio"/"sentence" columns kept
        (when the column names are known).
    """
    ds = load_dataset(
        "mozilla-foundation/common_voice_17_0",
        "vi",
        split=split,
        streaming=streaming,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    # Whisper feature extractors expect 16kHz input.
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    # Keep only relevant columns. In streaming mode `column_names` can be
    # None (features not yet resolved), so guard before filtering to avoid
    # iterating over None.
    columns = ds.column_names
    if columns:
        ds = ds.remove_columns(
            [c for c in columns if c not in ("audio", "sentence")]
        )
    return ds
def load_vivos(
    split: str = "train",
    cache_dir: Optional[str] = None,
) -> "Dataset":
    """Load VIVOS Vietnamese speech dataset.

    Args:
        split: Dataset split ("train", "test")
        cache_dir: Custom cache directory

    Returns:
        HuggingFace Dataset with the audio column cast to 16kHz. The
        transcript column is already named "sentence" in VIVOS, matching
        load_common_voice, so no renaming is needed.
    """
    ds = load_dataset(
        "vivos",
        split=split,
        cache_dir=cache_dir,
        trust_remote_code=True,
    )
    # Whisper feature extractors expect 16kHz input.
    ds = ds.cast_column("audio", Audio(sampling_rate=16000))
    # NOTE: the original code called ds.rename_column("sentence", "sentence"),
    # which raises ValueError in `datasets` because the target column name
    # already exists. The transcript column is "sentence" as-is.
    return ds
def prepare_dataset(batch, processor):
    """Prepare a single example for Whisper training.

    Args:
        batch: Dataset example with 'audio' and 'sentence' columns, where
            'audio' is a dict with 'array' and 'sampling_rate' keys.
        processor: WhisperProcessor instance (provides feature_extractor
            and tokenizer).

    Returns:
        The same example dict, extended with 'input_features' (log-mel
        features for the audio) and 'labels' (token ids of the transcript).
    """
    waveform = batch["audio"]["array"]
    rate = batch["audio"]["sampling_rate"]

    # Extract log-mel input features; take the first (only) item of the batch.
    features = processor.feature_extractor(waveform, sampling_rate=rate)
    batch["input_features"] = features.input_features[0]

    # Tokenize the transcript into label ids.
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch