from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig, PreTrainedModel


class AudioMAEConfig(PretrainedConfig):
    model_type = "audiomae"

    def __init__(self,
                 img_size: Tuple[int, int] = (1024, 128),
                 in_chans: int = 1,
                 num_classes: int = 0,
                 **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size  # (temporal_length, n_freq_bins)
        self.in_chans = in_chans  # mono-channel input
        self.num_classes = num_classes  # 0 disables the classification head


class AudioMAEEncoder(VisionTransformer):
    def __init__(self, *args, **kwargs):
        """
        - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
        - AudioMAE accepts mono-channel input (i.e., in_chans=1)
        """
        super().__init__(*args, **kwargs)
        # Precomputed mean/std used to standardize the log-mel spectrogram.
        self.MEAN = -4.2677393
        self.STD = 4.5689974
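
    # For reference: assuming the common 16x16 ViT patch size (the actual value
    # comes from the timm VisionTransformer kwargs), a (1024, 128) spectrogram
    # yields a 64 x 8 grid of patches, i.e. 512 tokens plus the CLS token.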
|
    def load_wav_file(self, file_path: str):
        """
        To use this, `torchaudio` and `ffmpeg` must be installed.
        - the `ffmpeg` version must be >=4.4 and <7
        - `ffmpeg` can be installed with `conda install -c conda-forge ffmpeg==6.1.1`
        """
        audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, n_samples)

        # Downmix multi-channel audio to mono by averaging over channels.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # AudioMAE expects 16 kHz input; resample if necessary.
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)

        audio_len = audio.shape[-1] / 16000
        if audio_len > 10.0:
            print(f'{file_path} has audio length {audio_len}s, which is longer than 10s. '
                  f'It will be segmented into 10-second windows with a 5-second stride (50% overlap).')
            window_size = 160000  # 10 s at 16 kHz
            stride = 80000  # 5 s at 16 kHz
            # Zero-pad so that the last window is full-length.
            remainder = (audio.shape[-1] - window_size) % stride
            if remainder != 0:
                padding = (0, stride - remainder)
                audio = torch.nn.functional.pad(audio, padding, "constant", 0)
            # (1, n_samples) -> (n_windows, window_size)
            audio = audio.squeeze(0).unfold(0, window_size, stride)
            return audio
        else:
            return audio
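
    # Windowing example: a 25 s clip at 16 kHz has 400,000 samples;
    # (400000 - 160000) % 80000 == 0, so no padding is added, and
    # unfold(0, 160000, 80000) yields 4 overlapping 10-second windows.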
|
    def waveform_to_melspec(self, waveform: torch.FloatTensor):
        # Kaldi-style log-mel filterbank features: 128 mel bins, 25 ms frames,
        # 10 ms frame shift, matching the AudioMAE preprocessing.
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=128,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=16000,
            window_type='hanning',
            dither=0.0,
        )

        # Crop or zero-pad along the time axis to exactly 1024 frames.
        expected_frames = 1024
        current_frames = mel_spectrogram.shape[0]
        if current_frames > expected_frames:
            mel_spectrogram = mel_spectrogram[:expected_frames, :]
        elif current_frames < expected_frames:
            padding = expected_frames - current_frames
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, padding))

        # Standardize with the precomputed mean/std (the std is doubled,
        # following the original AudioMAE implementation).
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)
        return mel_spectrogram
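
    # Frame-count example: a 10 s window (160,000 samples) with 25 ms frames and
    # a 10 ms shift yields (160000 - 400) // 160 + 1 = 998 frames, which the
    # padding step above extends to the fixed 1024-frame model input.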
|
    @torch.no_grad()
    def encode(self, file_path: str, device, reshape=True):
        self.eval()

        waveform = self.load_wav_file(file_path)  # (n_windows, window_size) or (1, n_samples)

        zs = []
        for i in range(waveform.shape[0]):
            melspec = self.waveform_to_melspec(waveform[i].unsqueeze(0))  # (1024, 128)
            melspec = melspec[None, None, :, :]  # (1, 1, 1024, 128) = (b, c, w, h)
            z = self.forward_features(melspec.to(device)).cpu()  # (b, 1 + n_patches, d)
            z = z[:, 1:, :]  # drop the CLS token
            b, c, w, h = melspec.shape
            if reshape:
                # Restore the 2D patch-grid layout: (b, d, h', w')
                wprime = round(w / self.patch_embed.patch_size[0])
                hprime = round(h / self.patch_embed.patch_size[1])
                z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)
            else:
                z = z.transpose(1, 2)  # (b, d, n_patches)
            z = z[0]
            zs.append(z)
        z = torch.stack(zs, dim=0)  # stack the per-window embeddings
        # Average across windows to obtain a single clip-level embedding.
        z = z.mean(dim=0).unsqueeze(0)
        return z
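
    # Output-shape example (assuming 16x16 patches and embedding dim d): with
    # reshape=True, encode returns (1, d, 8, 64); with reshape=False, (1, d, 512).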
|
|
class PretrainedAudioMAEEncoder(PreTrainedModel):
    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AudioMAEEncoder(img_size=config.img_size,
                                       in_chans=config.in_chans,
                                       num_classes=config.num_classes)

    def forward(self, file_path: str, reshape=True):
        device = self.device
        return self.encoder.encode(file_path, device, reshape)
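

# Minimal usage sketch ("example.wav" is a placeholder path; real AudioMAE
# weights would need to be loaded, e.g. via from_pretrained, for the
# embeddings to be meaningful):
if __name__ == "__main__":
    config = AudioMAEConfig()  # defaults: img_size=(1024, 128), in_chans=1, num_classes=0
    model = PretrainedAudioMAEEncoder(config)
    z = model("example.wav", reshape=True)
    print(z.shape)  # e.g. torch.Size([1, d, 8, 64]) with a 16x16 patch grid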
|
|
|