# Audio preprocessing module: loading, mel-spectrogram feature extraction.
| import torch | |
| import torchaudio | |
| import torch.nn.functional as F | |
| from core.config import SAMPLE_RATE, DEVICE, N_MELS, TARGET_LEN | |
| from pydub import AudioSegment | |
| import numpy as np | |
# Shared feature-extraction transforms, built once at import time and moved
# to DEVICE so each call only pays for the forward pass.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=400,        # window size in samples — presumably 25 ms at 16 kHz; confirm against SAMPLE_RATE
    hop_length=256,   # frame step in samples
    n_mels=N_MELS,
).to(DEVICE)
# Converts the power mel spectrogram to a decibel (log) scale.
amp_to_db = torchaudio.transforms.AmplitudeToDB().to(DEVICE)
# NOTE(review): earlier draft of load_audio, kept commented out; it is
# superseded by the implementation below. (Its Resample call passed the
# waveform as the first positional argument instead of the sample rates,
# which is incorrect torchaudio usage.)
# def load_audio(path: str) -> torch.Tensor:
#     wav, sr = torchaudio.load(path)
#     if sr != SAMPLE_RATE:
#         wav = torchaudio.transforms.Resample(wav, sr, SAMPLE_RATE)
#     if wav.shape[0] > 1:
#         wav = wav.mean(dim = 0)
#     return wav.to(DEVICE)
# NOTE(review): removed a duplicated import run — torch, torchaudio,
# torch.nn.functional as F, numpy as np and pydub.AudioSegment are already
# imported at the top of this file, so these repeats were redundant no-ops.
class AudioLoadError(Exception):
    """Raised when an audio file cannot be decoded or fails validation."""
def load_audio(path: str) -> torch.Tensor:
    """Load an audio file as a mono float32 waveform of exactly TARGET_LEN samples.

    Decoding is attempted with torchaudio first; on failure, pydub (ffmpeg)
    is used as a fallback. The waveform is down-mixed to mono, resampled to
    SAMPLE_RATE, and truncated to TARGET_LEN samples.

    Raises:
        AudioLoadError: if the file cannot be decoded, decodes to no
            samples, or is shorter than TARGET_LEN samples.
    """
    waveform = None
    sr = None
    # --- primary loader ---
    try:
        waveform, sr = torchaudio.load(path)
    except Exception:
        # --- fallback loader (pydub / ffmpeg) ---
        try:
            audio = AudioSegment.from_file(path)
            audio = audio.set_channels(1).set_frame_rate(SAMPLE_RATE)
            samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
            if samples.size == 0:
                raise AudioLoadError("Empty audio file")
            # BUG FIX: pydub yields raw integer PCM values (e.g. +/-32768 for
            # 16-bit audio) while torchaudio returns floats in [-1, 1]; scale
            # by the sample width so both code paths agree on amplitude.
            samples /= float(1 << (8 * audio.sample_width - 1))
            waveform = torch.from_numpy(samples)
            sr = SAMPLE_RATE
        except Exception as e2:
            raise AudioLoadError(
                f"Failed to decode audio file: {str(e2)}"
            ) from e2
    # ---- sanity checks ----
    if waveform is None or waveform.numel() == 0:
        raise AudioLoadError("Loaded audio is empty")
    # mono: torchaudio returns [channels, T]
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    # resample to the model's expected rate
    if sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)
    # duration control: reject clips that are too short, truncate long ones.
    # (The original also padded, but that branch was unreachable dead code:
    # anything shorter than TARGET_LEN raises above, so only truncation or
    # an exact-length pass-through can occur.)
    if waveform.numel() < TARGET_LEN:
        raise AudioLoadError("Audio too short for analysis")
    if waveform.numel() > TARGET_LEN:
        waveform = waveform[:TARGET_LEN]
    return waveform.float()
def waveform_to_mel(waveform: torch.Tensor):
    """Convert a mono waveform into a log-mel spectrogram.

    waveform: [T]
    returns: [1, frames, N_MELS]
    """
    batched = waveform.unsqueeze(0)           # [1, T]
    spec = amp_to_db(mel_transform(batched))  # [1, n_mels, frames]
    return spec.transpose(1, 2)               # [1, frames, n_mels]
def pad_time_dim(mel):
    """Zero-pad the time axis (dim 1) of `mel` up to the next multiple of 8.

    Tensors whose frame count is already a multiple of 8 are returned
    unchanged (same object, no copy).
    """
    remainder = mel.shape[1] % 8
    if remainder:
        mel = F.pad(mel, (0, 0, 0, 8 - remainder))
    return mel
def extract_features(wav: torch.Tensor) -> torch.Tensor:
    """Turn a waveform into a padded log-mel feature map.

    wav: [T] mono waveform (assumed — TODO confirm against callers).
    returns: [B, T_frames, N_MELS] log-mel features with the time axis
    zero-padded to a multiple of 8.
    """
    mel = mel_transform(wav.unsqueeze(0))
    mel = amp_to_db(mel)
    # BUG FIX: `mel.dim` without parentheses is the bound method object,
    # which is always truthy, so the old `if mel.dim == 4:` ran its branch
    # unconditionally; call the method instead.
    if mel.dim() == 4:
        mel = mel.squeeze(1)
    # BUG FIX: transpose is not in-place — the old code discarded the
    # result, leaving mel as [B, n_mels, frames] so pad_time_dim padded the
    # mel axis instead of the time axis.
    mel = mel.transpose(1, 2)  # [B, T, N_MELS]
    mel = pad_time_dim(mel)
    return mel