File size: 3,047 Bytes
36e0dea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torchaudio
import torch.nn.functional as F
from core.config import SAMPLE_RATE, DEVICE, N_MELS, TARGET_LEN
from pydub import AudioSegment
import numpy as np

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate = SAMPLE_RATE,
    n_fft = 400,
    hop_length = 256,
    n_mels = N_MELS    
).to(DEVICE)

amp_to_db = torchaudio.transforms.AmplitudeToDB().to(DEVICE)

# def load_audio(path: str) -> torch.Tensor:
#     wav, sr = torchaudio.load(path)
#     if sr != SAMPLE_RATE:
#         wav = torchaudio.transforms.Resample(wav, sr, SAMPLE_RATE)
#     if wav.shape[0] > 1:
#         wav = wav.mean(dim = 0)
#     return wav.to(DEVICE)  

import torch
import torchaudio
import torch.nn.functional as F
import numpy as np
from pydub import AudioSegment

class AudioLoadError(Exception):
    pass

def load_audio(path: str) -> torch.Tensor:
    waveform = None
    sr = None

    # --- primary loader ---
    try:
        waveform, sr = torchaudio.load(path)
    except Exception as e1:
        # --- fallback loader ---
        try:
            audio = AudioSegment.from_file(path)
            audio = audio.set_channels(1).set_frame_rate(SAMPLE_RATE)
            samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
            if samples.size == 0:
                raise AudioLoadError("Empty audio file")

            waveform = torch.from_numpy(samples)
            sr = SAMPLE_RATE

        except Exception as e2:
            raise AudioLoadError(
                f"Failed to decode audio file: {str(e2)}"
            ) from e2

    # ---- sanity checks ----
    if waveform is None or waveform.numel() == 0:
        raise AudioLoadError("Loaded audio is empty")

    # mono
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)

    # resample
    if sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(waveform)

    # duration control
    if waveform.numel() < TARGET_LEN:
        raise AudioLoadError("Audio too short for analysis")

    if waveform.numel() > TARGET_LEN:
        waveform = waveform[:TARGET_LEN]
    else:
        waveform = F.pad(waveform, (0, TARGET_LEN - waveform.numel()))

    return waveform.float()



def waveform_to_mel(waveform: torch.Tensor):
    """

    waveform: [T]

    returns: [1, T, N_MELS]

    """
    mel = mel_transform(waveform.unsqueeze(0))   # [1, n_mels, frames]
    mel = amp_to_db(mel)
    mel = mel.transpose(1, 2)                     # [1, frames, n_mels]
    return mel

def pad_time_dim(mel):
    T = mel.shape[1]
    pad_len = (8 - (T % 8)) % 8
    if pad_len > 0:
        mel = F.pad(mel, (0, 0, 0, pad_len))
    return mel


def extract_features(wav: torch.Tensor) -> torch.Tensor:
    mel = mel_transform(wav.unsqueeze(0))
    mel = amp_to_db(mel)
    if mel.dim == 4:
        mel = mel.squeeze(1)

    mel.transpose(1, 2)  # [B, T, N_MELS]
    mel = pad_time_dim(mel)
    return mel