import torch
import torchaudio
import torch.nn.functional as F
from core.config import SAMPLE_RATE, DEVICE, N_MELS, TARGET_LEN
from pydub import AudioSegment
import numpy as np
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=400,
    hop_length=256,
    n_mels=N_MELS,
).to(DEVICE)
amp_to_db = torchaudio.transforms.AmplitudeToDB().to(DEVICE)
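# With the settings above, each STFT frame spans n_fft = 400 samples (25 ms
# if SAMPLE_RATE is 16 kHz -- an assumption; SAMPLE_RATE comes from core.config)
# with a hop of 256 samples. AmplitudeToDB then converts the power mel
# spectrogram to a log (dB) scale.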
class AudioLoadError(Exception):
    """Raised when an audio file cannot be decoded or fails validation."""
    pass
def load_audio(path: str) -> torch.Tensor:
    waveform = None
    sr = None

    # --- primary loader: torchaudio (returns float32 in [-1, 1]) ---
    try:
        waveform, sr = torchaudio.load(path)
    except Exception:
        # --- fallback loader: pydub (needs ffmpeg for most compressed formats) ---
        try:
            audio = AudioSegment.from_file(path)
            audio = audio.set_channels(1).set_frame_rate(SAMPLE_RATE)
            samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
            if samples.size == 0:
                raise AudioLoadError("Empty audio file")
            # pydub yields raw integer PCM; scale to [-1, 1] to match torchaudio
            samples /= float(1 << (8 * audio.sample_width - 1))
            waveform = torch.from_numpy(samples)
            sr = SAMPLE_RATE
        except Exception as e2:
            raise AudioLoadError(
                f"Failed to decode audio file: {e2}"
            ) from e2

    # ---- sanity checks ----
    if waveform is None or waveform.numel() == 0:
        raise AudioLoadError("Loaded audio is empty")

    # downmix to mono
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    # resample to the target rate
    if sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)

    # duration control: reject clips shorter than TARGET_LEN, truncate longer ones
    if waveform.numel() < TARGET_LEN:
        raise AudioLoadError("Audio too short for analysis")
    waveform = waveform[:TARGET_LEN]

    # move to DEVICE so downstream transforms (which live on DEVICE) can run
    return waveform.float().to(DEVICE)
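# Illustrative usage (assumes "clip.wav" is a decodable file on disk):
#     wav = load_audio("clip.wav")
#     wav.shape  # -> torch.Size([TARGET_LEN]); mono, SAMPLE_RATE Hz, on DEVICE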
def waveform_to_mel(waveform: torch.Tensor):
    """
    waveform: [T]
    returns: [1, T, N_MELS]
    """
    mel = mel_transform(waveform.unsqueeze(0))  # [1, n_mels, frames]
    mel = amp_to_db(mel)
    mel = mel.transpose(1, 2)  # [1, frames, n_mels]
    return mel
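# With torchaudio's default center=True framing, a TARGET_LEN-sample input
# produces TARGET_LEN // 256 + 1 frames, so the output shape is
# [1, TARGET_LEN // 256 + 1, N_MELS].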
def pad_time_dim(mel):
    """Zero-pad the time dimension of a [B, T, N_MELS] tensor to a multiple of 8."""
    T = mel.shape[1]
    pad_len = (8 - (T % 8)) % 8
    if pad_len > 0:
        mel = F.pad(mel, (0, 0, 0, pad_len))  # pad dim 1 (time); leave dim 2 (mels)
    return mel
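# Example: T = 126 -> pad_len = 2 -> padded time dim 128. The multiple-of-8
# constraint presumably matches the downsampling factor of the downstream
# model (an assumption; the model is defined elsewhere).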
def extract_features(wav: torch.Tensor) -> torch.Tensor:
    mel = mel_transform(wav.unsqueeze(0))
    mel = amp_to_db(mel)
    if mel.dim() == 4:
        mel = mel.squeeze(1)
    mel = mel.transpose(1, 2)  # [B, T, N_MELS]
    mel = pad_time_dim(mel)
    return mel
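# Minimal smoke test, a sketch: run `python <this_module>.py path/to/audio`
# with any decodable audio file (the CLI argument is illustrative, not part
# of the module's API).
if __name__ == "__main__":
    import sys

    wav = load_audio(sys.argv[1])
    feats = extract_features(wav)
    print("waveform:", tuple(wav.shape), "features:", tuple(feats.shape))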