File size: 13,424 Bytes
838f737 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 |
"""TorchCodec integration for TorchAudio."""
import os
from typing import BinaryIO, Optional, Tuple, Union
import torch
def load_with_torchcodec(
    uri: Union[BinaryIO, str, os.PathLike],
    frame_offset: int = 0,
    num_frames: int = -1,
    normalize: bool = True,
    channels_first: bool = True,
    format: Optional[str] = None,
    buffer_size: int = 4096,
    backend: Optional[str] = None,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from source using TorchCodec's AudioDecoder.

    .. note::

        This function supports the same API as :func:`~torchaudio.load`, and
        relies on TorchCodec's decoding capabilities under the hood. It is
        provided for convenience, but we do recommend that you port your code to
        natively use ``torchcodec``'s ``AudioDecoder`` class for better
        performance:
        https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.
        As of TorchAudio 2.9, :func:`~torchaudio.load` relies on
        :func:`~torchaudio.load_with_torchcodec`. Note that some parameters of
        :func:`~torchaudio.load`, like ``normalize``, ``format``,
        ``buffer_size``, and ``backend``, are ignored by
        :func:`~torchaudio.load_with_torchcodec`.
        To install torchcodec, follow the instructions at https://github.com/pytorch/torchcodec#installing-torchcodec.

    Args:
        uri (path-like object or file-like object):
            Source of audio data. The following types are accepted:

            * ``path-like``: File path or URL.
            * ``file-like``: Object with ``read(size: int) -> bytes`` method.

        frame_offset (int, optional):
            Number of samples to skip before start reading data. Must be
            non-negative. Default: ``0``.
        num_frames (int, optional):
            Maximum number of samples to read. ``-1`` reads all the remaining
            samples, starting from ``frame_offset``; ``0`` returns an empty
            tensor. Must be ``-1`` or non-negative. Default: ``-1``.
        normalize (bool, optional):
            TorchCodec always returns normalized float32 samples. This parameter
            is ignored and a warning is issued if set to False.
            Default: ``True``.
        channels_first (bool, optional):
            When True, the returned Tensor has dimension `[channel, time]`.
            Otherwise, the returned Tensor's dimension is `[time, channel]`.
        format (str or None, optional):
            Not used by TorchCodec AudioDecoder. Provided for API compatibility.
            A warning is issued if provided. (Default: ``None``)
        buffer_size (int, optional):
            Not used by TorchCodec AudioDecoder. Provided for API compatibility.
            A warning is issued if not the default value.
        backend (str or None, optional):
            Not used by TorchCodec AudioDecoder. Provided for API compatibility.
            A warning is issued if provided.

    Returns:
        (torch.Tensor, int): Resulting Tensor and sample rate.
        Always returns float32 tensors. If ``channels_first=True``, shape is
        `[channel, time]`, otherwise `[time, channel]`.

    Raises:
        ValueError: If ``frame_offset`` is negative or ``num_frames`` is less
            than ``-1``.
        ImportError: If torchcodec is not available.
        RuntimeError: If TorchCodec fails to decode the audio.

    Note:
        - TorchCodec always returns normalized float32 samples, so the ``normalize``
          parameter has no effect.
        - The ``format``, ``buffer_size`` and ``backend`` parameters are ignored.
        - The whole source is decoded before ``frame_offset`` / ``num_frames``
          are applied.
        - Not all audio formats supported by torchaudio backends may be supported
          by TorchCodec.
    """
    # Validate the sample-range arguments first so that invalid values fail
    # fast (and independently of whether torchcodec is installed). Previously
    # they were silently treated as "read everything".
    if frame_offset < 0:
        raise ValueError(f"frame_offset must be non-negative, got {frame_offset}")
    if num_frames < -1:
        raise ValueError(f"num_frames must be -1 (read all) or non-negative, got {num_frames}")

    # Import torchcodec here to provide clear error if not available
    try:
        from torchcodec.decoders import AudioDecoder
    except ImportError as e:
        raise ImportError(
            "TorchCodec is required for load_with_torchcodec. " "Please install torchcodec to use this function."
        ) from e

    import warnings

    # Warn about parameters that have no effect with TorchCodec.
    if not normalize:
        warnings.warn(
            "TorchCodec AudioDecoder always returns normalized float32 samples. "
            "The 'normalize=False' parameter is ignored.",
            UserWarning,
            stacklevel=2,
        )
    if buffer_size != 4096:
        warnings.warn("The 'buffer_size' parameter is not used by TorchCodec AudioDecoder.", UserWarning, stacklevel=2)
    if backend is not None:
        warnings.warn("The 'backend' parameter is not used by TorchCodec AudioDecoder.", UserWarning, stacklevel=2)
    if format is not None:
        warnings.warn("The 'format' parameter is not supported by TorchCodec AudioDecoder.", UserWarning, stacklevel=2)

    try:
        decoder = AudioDecoder(uri)
    except Exception as e:
        raise RuntimeError(f"Failed to create AudioDecoder for {uri}: {e}") from e

    sample_rate = decoder.metadata.sample_rate
    if sample_rate is None:
        raise RuntimeError("Unable to determine sample rate from audio metadata")

    # TorchCodec indexes by time, so decode everything and then select the
    # requested sample range by index.
    try:
        audio_samples = decoder.get_all_samples()
    except Exception as e:
        raise RuntimeError(f"Failed to decode audio samples: {e}") from e
    data = audio_samples.data  # TorchCodec returns [channel, time]

    num_channels, num_samples = data.shape[0], data.shape[1]
    if frame_offset >= num_samples or num_frames == 0:
        # Nothing to return. Build a fresh empty tensor instead of an empty view,
        # so the full decoded buffer is not kept alive by the result.
        empty_shape = (num_channels, 0) if channels_first else (0, num_channels)
        return torch.zeros(empty_shape, dtype=torch.float32), sample_rate

    # Slicing clamps `end` to the available samples, so no extra bound check is needed.
    end = num_samples if num_frames == -1 else frame_offset + num_frames
    data = data[:, frame_offset:end]

    if not channels_first:
        data = data.transpose(0, 1)  # [channel, time] -> [time, channel]
    return data, sample_rate
def save_with_torchcodec(
    uri: Union[str, os.PathLike],
    src: torch.Tensor,
    sample_rate: int,
    channels_first: bool = True,
    format: Optional[str] = None,
    encoding: Optional[str] = None,
    bits_per_sample: Optional[int] = None,
    buffer_size: int = 4096,
    backend: Optional[str] = None,
    compression: Optional[Union[float, int]] = None,
) -> None:
    """Save audio data to file using TorchCodec's AudioEncoder.

    .. note::

        This function supports the same API as :func:`~torchaudio.save`, and
        relies on TorchCodec's encoding capabilities under the hood. It is
        provided for convenience, but we do recommend that you port your code to
        natively use ``torchcodec``'s ``AudioEncoder`` class for better
        performance:
        https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.encoders.AudioEncoder.
        As of TorchAudio 2.9, :func:`~torchaudio.save` relies on
        :func:`~torchaudio.save_with_torchcodec`. Note that some parameters of
        :func:`~torchaudio.save`, like ``format``, ``encoding``,
        ``bits_per_sample``, ``buffer_size``, and ``backend``, are ignored by
        :func:`~torchaudio.save_with_torchcodec`.
        To install torchcodec, follow the instructions at https://github.com/pytorch/torchcodec#installing-torchcodec.

    This function provides a TorchCodec-based alternative to torchaudio.save
    with the same API. TorchCodec's AudioEncoder provides efficient encoding
    with FFmpeg under the hood.

    Args:
        uri (path-like object):
            Path to save the audio file. The file extension determines the format.
        src (torch.Tensor):
            Audio data to save. Must be a 1D or 2D tensor with float32 values
            in the range [-1, 1]. If 2D, shape should be [channel, time] when
            channels_first=True, or [time, channel] when channels_first=False.
            A 1D tensor is always treated as a single channel.
        sample_rate (int):
            Sample rate of the audio data. Must be positive.
        channels_first (bool, optional):
            Indicates whether the input tensor has channels as the first dimension.
            If True, expects [channel, time]. If False, expects [time, channel].
            Default: True.
        format (str or None, optional):
            Audio format hint. Not used by TorchCodec (format is determined by
            file extension). A warning is issued if provided.
            Default: None.
        encoding (str or None, optional):
            Audio encoding. Not fully supported by TorchCodec AudioEncoder.
            A warning is issued if provided. Default: None.
        bits_per_sample (int or None, optional):
            Bits per sample. Not directly supported by TorchCodec AudioEncoder.
            A warning is issued if provided. Default: None.
        buffer_size (int, optional):
            Not used by TorchCodec AudioEncoder. Provided for API compatibility.
            A warning is issued if not default value. Default: 4096.
        backend (str or None, optional):
            Not used by TorchCodec AudioEncoder. Provided for API compatibility.
            A warning is issued if provided. Default: None.
        compression (float, int or None, optional):
            Compression level or bit rate. Converted with ``int()`` and passed
            unchanged as the ``bit_rate`` argument of TorchCodec's
            ``AudioEncoder.to_file``; no unit conversion is applied.
            Default: None.

    Raises:
        ValueError: If ``src`` is not a tensor, is not 1D/2D, or if
            ``sample_rate`` is not positive.
        ImportError: If torchcodec is not available.
        RuntimeError: If TorchCodec fails to encode the audio.

    Note:
        - TorchCodec AudioEncoder expects float32 samples in [-1, 1] range.
          Non-float32 tensors are cast with ``Tensor.float()``, which does not
          rescale integer PCM values into that range.
        - Some parameters (format, encoding, bits_per_sample, buffer_size, backend)
          are not used by TorchCodec but are provided for API compatibility.
        - The output format is determined by the file extension in the uri.
        - TorchCodec uses FFmpeg under the hood for encoding.
    """
    # Validate and normalize the input before touching torchcodec, so that
    # invalid arguments fail fast and independently of the optional dependency.
    if not isinstance(src, torch.Tensor):
        raise ValueError(f"Expected src to be a torch.Tensor, got {type(src)}")
    if src.dtype != torch.float32:
        src = src.float()
    if sample_rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate}")

    # Bring the data into the [channel, time] layout the encoder is given.
    if src.ndim == 1:
        # A 1D tensor is a single channel regardless of channels_first.
        data = src.unsqueeze(0)  # [time] -> [1, time]
    elif src.ndim == 2:
        data = src if channels_first else src.transpose(0, 1)  # [time, channel] -> [channel, time]
    else:
        raise ValueError(f"Expected 1D or 2D tensor, got {src.ndim}D tensor")

    # Import torchcodec here to provide clear error if not available
    try:
        from torchcodec.encoders import AudioEncoder
    except ImportError as e:
        raise ImportError(
            "TorchCodec is required for save_with_torchcodec. " "Please install torchcodec to use this function."
        ) from e

    import warnings

    # Warn about parameters that have no (or only partial) effect with TorchCodec.
    if format is not None:
        warnings.warn(
            "The 'format' parameter is not used by TorchCodec AudioEncoder. "
            "Format is determined by the file extension.",
            UserWarning,
            stacklevel=2,
        )
    if encoding is not None:
        warnings.warn(
            "The 'encoding' parameter is not fully supported by TorchCodec AudioEncoder.", UserWarning, stacklevel=2
        )
    if bits_per_sample is not None:
        warnings.warn(
            "The 'bits_per_sample' parameter is not directly supported by TorchCodec AudioEncoder.",
            UserWarning,
            stacklevel=2,
        )
    if buffer_size != 4096:
        warnings.warn("The 'buffer_size' parameter is not used by TorchCodec AudioEncoder.", UserWarning, stacklevel=2)
    if backend is not None:
        warnings.warn("The 'backend' parameter is not used by TorchCodec AudioEncoder.", UserWarning, stacklevel=2)

    try:
        encoder = AudioEncoder(data, sample_rate=sample_rate)
    except Exception as e:
        raise RuntimeError(f"Failed to create AudioEncoder: {e}") from e

    # `compression` is forwarded as-is (truncated to int) as the encoder bit rate.
    bit_rate = None
    if compression is not None:
        if isinstance(compression, (int, float)):
            bit_rate = int(compression)
        else:
            warnings.warn(
                f"Unsupported compression type {type(compression)}. "
                "TorchCodec AudioEncoder expects int or float for bit_rate.",
                UserWarning,
                stacklevel=2,
            )

    try:
        encoder.to_file(uri, bit_rate=bit_rate)
    except Exception as e:
        raise RuntimeError(f"Failed to save audio to {uri}: {e}") from e
|