Luigi's picture
Restore complete zipvoice package with all source files
2967cdb
import logging
import librosa
import soundfile as sf
import torch
def load_waveform(
fname: str,
sample_rate: int,
dtype: str = "float32",
device: torch.device = torch.device("cpu"),
return_numpy: bool = False,
max_seconds: float = None,
) -> torch.Tensor:
"""
Load an audio file, preprocess it, and convert to a PyTorch tensor.
Args:
fname (str): Path to the audio file.
sample_rate (int): Target sample rate for resampling.
dtype (str, optional): Data type to load audio as (default: "float32").
device (torch.device, optional): Device to place the resulting tensor
on (default: CPU).
return_numpy (bool): If True, returns a NumPy array instead of a
PyTorch tensor.
max_seconds (int): Maximum length (seconds) of the audio tensor.
If the audio is longer than this, it will be truncated.
Returns:
torch.Tensor: Processed audio waveform as a PyTorch tensor,
with shape (num_samples,).
Notes:
- If the audio is stereo, it will be converted to mono by averaging channels.
- If the audio's sample rate differs from the target, it will be resampled.
"""
# Load audio file with specified data type
wav_data, sr = sf.read(fname, dtype=dtype)
# Convert stereo to mono if necessary
if len(wav_data.shape) == 2:
wav_data = wav_data.mean(1)
# Resample to target sample rate if needed
if sr != sample_rate:
wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate)
if max_seconds is not None:
# Trim to max length
max_length = sample_rate * max_seconds
if len(wav_data) > max_length:
wav_data = wav_data[:max_length]
logging.warning(
f"Wav file {fname} is longer than 2 minutes, "
f"truncated to 2 minutes to avoid OOM."
)
if return_numpy:
return wav_data
else:
wav_data = torch.from_numpy(wav_data)
return wav_data.to(device)