|
|
import os |
|
|
from dataclasses import dataclass, field |
|
|
from io import BytesIO |
|
|
from pathlib import Path |
|
|
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
import pyarrow as pa |
|
|
|
|
|
from .. import config |
|
|
from ..download.download_config import DownloadConfig |
|
|
from ..table import array_cast |
|
|
from ..utils.file_utils import is_local_path, xopen |
|
|
from ..utils.py_utils import no_op_if_value_is_null, string_to_dict |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from torchcodec.decoders import AudioDecoder |
|
|
|
|
|
from .features import FeatureType |
|
|
|
|
|
|
|
|
@dataclass
class Audio:
    """Audio [`Feature`] to extract audio data from an audio file.

    Input: The Audio feature accepts as input:

    - A `str`: Absolute path to the audio file (i.e. random access is allowed).
    - A `pathlib.Path`: path to the audio file (i.e. random access is allowed).
    - A `dict` with the keys:

        - `path`: String with relative path of the audio file to the archive file.
        - `bytes`: Bytes content of the audio file.

      This is useful for parquet or webdataset files which embed audio files.

    - A `dict` with the keys:

        - `array`: Array containing the audio sample
        - `sampling_rate`: Integer corresponding to the sampling rate of the audio sample.

    - A `torchcodec.decoders.AudioDecoder`: torchcodec audio decoder object.

    Output: The Audio features output data as `torchcodec.decoders.AudioDecoder` objects, with additional keys:

    - `array`: Array containing the audio sample
    - `sampling_rate`: Integer corresponding to the sampling rate of the audio sample.

    Args:
        sampling_rate (`int`, *optional*):
            Target sampling rate. If `None`, the native sampling rate is used.
        decode (`bool`, defaults to `True`):
            Whether to decode the audio data. If `False`,
            returns the underlying dictionary in the format `{"path": audio_path, "bytes": audio_bytes}`.
        stream_index (`int`, *optional*):
            The streaming index to use from the file. If `None` defaults to the "best" index.

    Example:

    ```py
    >>> from datasets import load_dataset, Audio
    >>> ds = load_dataset("PolyAI/minds14", name="en-US", split="train")
    >>> ds = ds.cast_column("audio", Audio(sampling_rate=44100))
    >>> ds[0]["audio"]
    <datasets.features._torchcodec.AudioDecoder object at 0x11642b6a0>
    >>> audio = ds[0]["audio"]
    >>> audio.get_samples_played_in_range(0, 10)
    AudioSamples:
      data (shape): torch.Size([2, 110592])
      pts_seconds: 0.0
      duration_seconds: 2.507755102040816
      sample_rate: 44100
    ```
    """

    sampling_rate: Optional[int] = None
    decode: bool = True
    stream_index: Optional[int] = None
    id: Optional[str] = field(default=None, repr=False)

    # Feature-type metadata used by the `datasets` type system / Arrow schema.
    dtype: ClassVar[str] = "dict"
    pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
    _type: str = field(default="Audio", init=False, repr=False)

    def __call__(self):
        """Return the Arrow storage type of this feature."""
        return self.pa_type

    @staticmethod
    def _encode_wav_bytes(array: "np.ndarray", sampling_rate: int) -> bytes:
        """Encode a float32 sample array to WAV bytes using torchcodec.

        Raises:
            ImportError: If 'torchcodec' (or 'torch') is not installed.
        """
        try:
            import torch
            from torchcodec.encoders import AudioEncoder
        except ImportError as err:
            raise ImportError("To support encoding audio data, please install 'torchcodec'.") from err

        buffer = BytesIO()
        AudioEncoder(torch.from_numpy(array), sample_rate=sampling_rate).to_file_like(buffer, format="wav")
        return buffer.getvalue()

    def encode_example(self, value: Union[str, bytes, bytearray, dict, "AudioDecoder"]) -> dict:
        """Encode example into a format for Arrow.

        Args:
            value (`str`, `bytes`, `bytearray`, `dict`, `AudioDecoder`):
                Data passed as input to Audio feature.

        Returns:
            `dict` with `"bytes"` and `"path"` keys.

        Raises:
            ValueError: If `value` is `None`, or has neither usable 'bytes' nor 'path'.
            KeyError: If a PCM file is given without a 'sampling_rate'.
            ImportError: If raw samples must be re-encoded but 'torchcodec' is not installed.
        """
        if value is None:
            raise ValueError("value must be provided")

        # torchcodec is only needed to recognize AudioDecoder inputs and to re-encode
        # raw sample arrays; plain path/bytes inputs must keep working without it.
        if config.TORCHCODEC_AVAILABLE:
            from torchcodec.decoders import AudioDecoder
        else:
            AudioDecoder = None

        if isinstance(value, str):
            return {"bytes": None, "path": value}
        elif isinstance(value, Path):
            return {"bytes": None, "path": str(value.absolute())}
        elif isinstance(value, (bytes, bytearray)):
            return {"bytes": value, "path": None}
        elif AudioDecoder is not None and isinstance(value, AudioDecoder):
            return encode_torchcodec_audio(value)
        elif "array" in value:
            # Raw samples: re-encode as WAV so the audio can be stored as bytes.
            wav_bytes = self._encode_wav_bytes(value["array"].astype(np.float32), value["sampling_rate"])
            return {"bytes": wav_bytes, "path": None}
        elif value.get("path") is not None and os.path.isfile(value["path"]):
            # The file exists locally; PCM is headerless so it needs special handling.
            if value["path"].endswith("pcm"):
                if value.get("sampling_rate") is None:
                    raise KeyError("To use PCM files, please specify a 'sampling_rate' in Audio object")
                if value.get("bytes"):
                    # int16 PCM -> float32 in [-1, 1]
                    bytes_value = np.frombuffer(value["bytes"], dtype=np.int16).astype(np.float32) / 32767
                else:
                    bytes_value = np.memmap(value["path"], dtype="h", mode="r").astype(np.float32) / 32767
                wav_bytes = self._encode_wav_bytes(bytes_value, value["sampling_rate"])
                return {"bytes": wav_bytes, "path": None}
            else:
                return {"bytes": None, "path": value.get("path")}
        elif value.get("bytes") is not None or value.get("path") is not None:
            # Store the audio bytes, and path is used to infer the audio format using the file extension.
            return {"bytes": value.get("bytes"), "path": value.get("path")}
        else:
            raise ValueError(
                f"An audio sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
            )

    def decode_example(
        self, value: dict, token_per_repo_id: Optional[dict[str, Union[str, bool, None]]] = None
    ) -> "AudioDecoder":
        """Decode example audio file into audio data.

        Args:
            value (`dict`):
                A dictionary with keys:

                - `path`: String with relative audio file path.
                - `bytes`: Bytes of the audio file.
            token_per_repo_id (`dict`, *optional*):
                To access and decode
                audio files from private repositories on the Hub, you can pass
                a dictionary repo_id (`str`) -> token (`bool` or `str`)

        Returns:
            `torchcodec.decoders.AudioDecoder`

        Raises:
            ImportError: If 'torchcodec' is not installed.
            RuntimeError: If this feature was built with `decode=False`.
            ValueError: If both 'path' and 'bytes' are `None`.
        """
        if config.TORCHCODEC_AVAILABLE:
            from ._torchcodec import AudioDecoder
        else:
            raise ImportError("To support decoding audio data, please install 'torchcodec'.")

        if not self.decode:
            raise RuntimeError("Decoding is disabled for this feature. Please use Audio(decode=True) instead.")

        path, bytes_ = value["path"], value["bytes"]
        if path is None and bytes_ is None:
            raise ValueError(f"An audio sample should have one of 'path' or 'bytes' but both are None in {value}.")

        if bytes_ is None and is_local_path(path):
            audio = AudioDecoder(path, stream_index=self.stream_index, sample_rate=self.sampling_rate)
        elif bytes_ is None:
            # Remote file: resolve the Hub token for the source repo (if any) and stream it.
            token_per_repo_id = token_per_repo_id or {}
            source_url = path.split("::")[-1]
            pattern = (
                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
            )
            source_url_fields = string_to_dict(source_url, pattern)
            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None

            download_config = DownloadConfig(token=token)
            # NOTE: the file handle is intentionally left open — AudioDecoder reads from it lazily.
            f = xopen(path, "rb", download_config=download_config)
            audio = AudioDecoder(f, stream_index=self.stream_index, sample_rate=self.sampling_rate)
        else:
            audio = AudioDecoder(bytes_, stream_index=self.stream_index, sample_rate=self.sampling_rate)
        # Cache the raw encoded payload so a later re-encode is lossless (see encode_torchcodec_audio).
        audio._hf_encoded = {"path": path, "bytes": bytes_}
        audio.metadata.path = path
        return audio

    def flatten(self) -> Union["FeatureType", dict[str, "FeatureType"]]:
        """If in the decodable state, raise an error, otherwise flatten the feature into a dictionary."""
        from .features import Value

        if self.decode:
            raise ValueError("Cannot flatten a decoded Audio feature.")
        return {
            "bytes": Value("binary"),
            "path": Value("string"),
        }

    def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.StructArray:
        """Cast an Arrow array to the Audio arrow storage type.
        The Arrow types that can be converted to the Audio pyarrow storage type are:

        - `pa.string()` - it must contain the "path" data
        - `pa.binary()` - it must contain the audio bytes
        - `pa.struct({"bytes": pa.binary()})`
        - `pa.struct({"path": pa.string()})`
        - `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter

        Args:
            storage (`Union[pa.StringArray, pa.StructArray]`):
                PyArrow array to cast.

        Returns:
            `pa.StructArray`: Array in the Audio arrow storage type, that is
                `pa.struct({"bytes": pa.binary(), "path": pa.string()})`
        """
        if pa.types.is_string(storage.type):
            bytes_array = pa.array([None] * len(storage), type=pa.binary())
            storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
        elif pa.types.is_binary(storage.type):
            path_array = pa.array([None] * len(storage), type=pa.string())
            storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
        elif pa.types.is_struct(storage.type) and storage.type.get_all_field_indices("array"):
            # Legacy "array"/"sampling_rate" structs: re-encode each sample to bytes.
            storage = pa.array(
                [Audio().encode_example(x) if x is not None else None for x in storage.to_numpy(zero_copy_only=False)]
            )
        elif pa.types.is_struct(storage.type):
            if storage.type.get_field_index("bytes") >= 0:
                bytes_array = storage.field("bytes")
            else:
                bytes_array = pa.array([None] * len(storage), type=pa.binary())
            if storage.type.get_field_index("path") >= 0:
                path_array = storage.field("path")
            else:
                path_array = pa.array([None] * len(storage), type=pa.string())
            storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
        return array_cast(storage, self.pa_type)

    def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
        """Embed audio files into the Arrow array.

        Args:
            storage (`pa.StructArray`):
                PyArrow array to embed.

        Returns:
            `pa.StructArray`: Array in the Audio arrow storage type, that is
                `pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
        """
        if token_per_repo_id is None:
            token_per_repo_id = {}

        @no_op_if_value_is_null
        def path_to_bytes(path):
            # Resolve the Hub token for the file's source repo (if any) before downloading.
            source_url = path.split("::")[-1]
            pattern = (
                config.HUB_DATASETS_URL if source_url.startswith(config.HF_ENDPOINT) else config.HUB_DATASETS_HFFS_URL
            )
            source_url_fields = string_to_dict(source_url, pattern)
            token = token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
            download_config = DownloadConfig(token=token)
            with xopen(path, "rb", download_config=download_config) as f:
                return f.read()

        bytes_array = pa.array(
            [
                (path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None
                for x in storage.to_pylist()
            ],
            type=pa.binary(),
        )
        # Only keep the file name: the bytes are embedded, so the full path is no longer meaningful.
        path_array = pa.array(
            [os.path.basename(path) if path is not None else None for path in storage.field("path").to_pylist()],
            type=pa.string(),
        )
        storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
        return array_cast(storage, self.pa_type)
|
|
|
|
|
|
|
|
def encode_torchcodec_audio(audio: "AudioDecoder") -> dict:
    """Encode a torchcodec `AudioDecoder` into the `{"bytes": ..., "path": ...}` storage format.

    Decoders produced by `Audio.decode_example` carry the original encoded payload in
    `_hf_encoded`; that cached dict is returned as-is to avoid a lossy re-encode.
    Otherwise the decoder's samples are re-encoded to WAV bytes.
    """
    # Fast path: reuse the payload cached by Audio.decode_example.
    if hasattr(audio, "_hf_encoded"):
        return audio._hf_encoded

    try:
        from torchcodec.encoders import AudioEncoder
    except ImportError as err:
        raise ImportError("To support encoding audio data, please install 'torchcodec'.") from err

    all_samples = audio.get_all_samples()
    wav_buffer = BytesIO()
    encoder = AudioEncoder(all_samples.data.cpu(), sample_rate=all_samples.sample_rate)
    encoder.to_file_like(wav_buffer, format="wav")
    return {"bytes": wav_buffer.getvalue(), "path": None}
|
|
|