Spaces:
Paused
Paused
| import logging | |
| import librosa | |
| import soundfile as sf | |
| import torch | |
| def load_waveform( | |
| fname: str, | |
| sample_rate: int, | |
| dtype: str = "float32", | |
| device: torch.device = torch.device("cpu"), | |
| return_numpy: bool = False, | |
| max_seconds: float = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Load an audio file, preprocess it, and convert to a PyTorch tensor. | |
| Args: | |
| fname (str): Path to the audio file. | |
| sample_rate (int): Target sample rate for resampling. | |
| dtype (str, optional): Data type to load audio as (default: "float32"). | |
| device (torch.device, optional): Device to place the resulting tensor | |
| on (default: CPU). | |
| return_numpy (bool): If True, returns a NumPy array instead of a | |
| PyTorch tensor. | |
| max_seconds (int): Maximum length (seconds) of the audio tensor. | |
| If the audio is longer than this, it will be truncated. | |
| Returns: | |
| torch.Tensor: Processed audio waveform as a PyTorch tensor, | |
| with shape (num_samples,). | |
| Notes: | |
| - If the audio is stereo, it will be converted to mono by averaging channels. | |
| - If the audio's sample rate differs from the target, it will be resampled. | |
| """ | |
| # Load audio file with specified data type | |
| wav_data, sr = sf.read(fname, dtype=dtype) | |
| # Convert stereo to mono if necessary | |
| if len(wav_data.shape) == 2: | |
| wav_data = wav_data.mean(1) | |
| # Resample to target sample rate if needed | |
| if sr != sample_rate: | |
| wav_data = librosa.resample(wav_data, orig_sr=sr, target_sr=sample_rate) | |
| if max_seconds is not None: | |
| # Trim to max length | |
| max_length = sample_rate * max_seconds | |
| if len(wav_data) > max_length: | |
| wav_data = wav_data[:max_length] | |
| logging.warning( | |
| f"Wav file {fname} is longer than 2 minutes, " | |
| f"truncated to 2 minutes to avoid OOM." | |
| ) | |
| if return_numpy: | |
| return wav_data | |
| else: | |
| wav_data = torch.from_numpy(wav_data) | |
| return wav_data.to(device) | |