VoiceFocus / hf_dataset_utils.py
mariesig
Add Gradio info message for processing indication to beginning
2f73944
raw
history blame contribute delete
870 Bytes
from __future__ import annotations
import numpy as np
from datasets import load_dataset, Audio
from constants import DATASET_NAME, DEFAULT_SPLIT
# Rate (Hz) every audio column is decoded at; passed to each Audio feature below.
TARGET_SAMPLING_RATE = 16000

# Load once (HF datasets handles caching; HF_TOKEN / login is used automatically if needed)
ds = load_dataset(DATASET_NAME, split=DEFAULT_SPLIT)
# cast_column returns a new Dataset, so ds is rebound after each cast.
ds = ds.cast_column("mix", Audio(sampling_rate=TARGET_SAMPLING_RATE, decode=True))
ds = ds.cast_column("speech", Audio(sampling_rate=TARGET_SAMPLING_RATE, decode=True))
# All values of the "id" column, in dataset order.
ALL_FILES = ds["id"]
# Maps sample id -> row position in ``ds``. Built on first lookup and then
# reused, so the id column is scanned once instead of on every call.
_id_to_index: dict[str, int] | None = None


def _row_for(sample_id: str) -> dict:
    """Return the row of ``ds`` whose ``id`` column equals *sample_id*.

    Rows are fetched by position (``ds[position]``). Only the requested row is
    read, rather than running ``Dataset.filter``. Filtering evaluates a
    predicate over every row on each call and may write an indices cache file.

    Raises:
        IndexError: if no row has this id.
    """
    global _id_to_index
    index_map = _id_to_index
    if index_map is None:
        index_map = {}
        for position, row_id in enumerate(ALL_FILES):
            # setdefault keeps the first row when an id is duplicated.
            index_map.setdefault(row_id, position)
        # Publish only the fully built map so a concurrent caller never sees
        # a partially filled one; building it twice is harmless.
        _id_to_index = index_map
    try:
        position = index_map[sample_id]
    except KeyError:
        raise IndexError(f"Unknown sample id: {sample_id!r}") from None
    return ds[position]


def get_audio(sample_id: str, prefix: str) -> tuple[np.ndarray, int]:
    """Return one audio column of a sample as ``(samples, sampling_rate)``.

    Args:
        sample_id: Value of the ``id`` column identifying the row.
        prefix: Name of the audio column to read (this module casts
            ``"mix"`` and ``"speech"`` as audio).

    Returns:
        The column's ``"array"`` copied into a float32 NumPy array, and the
        column's ``"sampling_rate"``.

    Raises:
        IndexError: if *sample_id* is not in the dataset.
        KeyError: if *prefix* is not a column of the row.
    """
    audio = _row_for(sample_id)[prefix]
    return np.array(audio["array"], dtype=np.float32), audio["sampling_rate"]


def get_transcript(sample_id: str) -> str:
    """Return the ``transcript`` of *sample_id*.

    Returns ``""`` when the row has no ``transcript`` key or its value is
    null, so the result is always a ``str``.

    Raises:
        IndexError: if *sample_id* is not in the dataset.
    """
    transcript = _row_for(sample_id).get("transcript")
    return "" if transcript is None else transcript