from typing import Tuple

import torch
import torchaudio
import torchaudio.transforms as transforms
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig, PreTrainedModel


class AudioMAEConfig(PretrainedConfig):
    model_type = "audiomae"

    def __init__(self,
                 img_size: Tuple[int, int] = (1024, 128),
                 in_chans: int = 1,
                 num_classes: int = 0,
                 **kwargs):
        super().__init__(**kwargs)
        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes
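
# Usage sketch (illustrative, not part of the original module): the defaults
# reproduce AudioMAE's fixed input geometry, so a config can be created and
# persisted with the standard HF API, e.g.
#   config = AudioMAEConfig()                 # img_size=(1024, 128), in_chans=1
#   config.save_pretrained("./audiomae")      # "./audiomae" is a placeholder path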


class AudioMAEEncoder(VisionTransformer):
    def __init__(self, *args, **kwargs):
        """
        - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper.
        - AudioMAE accepts mono-channel input only (i.e., in_chans=1).
        """
        super().__init__(*args, **kwargs)
        # Normalization statistics for the log-mel spectrograms, matching the
        # values used for AudioMAE pre-training.
        self.MEAN = -4.2677393
        self.STD = 4.5689974

    def load_wav_file(self, file_path: str):
        """
        Load an audio file and return a mono 16 kHz waveform tensor.

        To use this, `torchaudio` and `ffmpeg` must be installed:
        - the `ffmpeg` version must be >=4.4 and <7
        - e.g., install it with `conda install -c conda-forge ffmpeg==6.1.1`
        """
        audio, sample_rate = torchaudio.load(file_path)

        # Downmix multi-channel audio to mono, since AudioMAE expects a
        # single input channel.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # Resample to the 16 kHz rate the model was trained on.
        if sample_rate != 16000:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = converter(audio)

        audio_len = audio.shape[-1] / 16000
        if audio_len > 10.0:
            print(f'{file_path} has audio length {audio_len}s, which is longer than 10s. '
                  f'It will be segmented into 10-second windows with a 5-second stride (50% overlap).')
            window_size = 160000  # 10 s at 16 kHz
            stride = 80000        # 5 s at 16 kHz
            # Zero-pad so that the last window is full-length before unfolding.
            remainder = (audio.shape[-1] - window_size) % stride
            if remainder != 0:
                padding = (0, stride - remainder)
                audio = torch.nn.functional.pad(audio, padding, "constant", 0)
            audio = audio.squeeze(0).unfold(0, window_size, stride)  # (n_windows, window_size)
            return audio
        else:
            return audio  # (1, n_samples)
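
    # Shape sketch (illustrative): for a 25 s mono file at 16 kHz
    # (400,000 samples), the windowing above gives
    #   remainder = (400000 - 160000) % 80000 = 0   -> no padding needed
    #   unfold(0, 160000, 80000)                    -> shape (4, 160000)
    # i.e., four 10-second windows with 50% overlap; clips of <=10 s are
    # returned unchanged with shape (1, n_samples).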

    def waveform_to_melspec(self, waveform: torch.FloatTensor):
        # Kaldi-style log-mel filterbank with AudioMAE's preprocessing
        # settings: 128 mel bins, 25 ms window, 10 ms shift, no dithering.
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=128,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=16000,
            window_type='hanning',
            dither=0.0,
        )

        # Truncate or zero-pad along the time axis to the fixed 1024 frames.
        expected_frames = 1024
        current_frames = mel_spectrogram.shape[0]
        if current_frames > expected_frames:
            mel_spectrogram = mel_spectrogram[:expected_frames, :]
        elif current_frames < expected_frames:
            padding = expected_frames - current_frames
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, padding))

        # Normalize with the pre-training statistics; the 2*std scaling
        # follows the AudioMAE reference implementation.
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)
        return mel_spectrogram
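
    # Frame-count arithmetic (illustrative): with a 25 ms window (400 samples)
    # and a 10 ms shift (160 samples) at 16 kHz, a 10 s segment of 160,000
    # samples yields 1 + (160000 - 400) // 160 = 998 frames, so the padding
    # above brings every segment up to the 1024 frames that the fixed
    # img_size of (1024, 128) expects.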

    @torch.no_grad()
    def encode(self, file_path: str, device, reshape=True):
        self.eval()

        waveform = self.load_wav_file(file_path)

        zs = []
        for i in range(waveform.shape[0]):
            melspec = self.waveform_to_melspec(waveform[i].unsqueeze(0))
            melspec = melspec[None, None, :, :]  # (1, 1, n_frames, n_freq_bins)
            z = self.forward_features(melspec.to(device)).cpu()
            z = z[:, 1:, :]  # drop the [CLS] token, keeping only the patch tokens
            b, c, w, h = melspec.shape
            if reshape:
                # Fold the token sequence back into a 2-D (freq, time) patch grid.
                wprime = round(w / self.patch_embed.patch_size[0])
                hprime = round(h / self.patch_embed.patch_size[1])
                z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)
            else:
                z = z.transpose(1, 2)  # (b, d, n_tokens)
            z = z[0]
            zs.append(z)
        # Average the per-window representations into a single clip-level embedding.
        z = torch.stack(zs, dim=0)
        z = z.mean(dim=0).unsqueeze(0)
        return z
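
    # Output-shape note (illustrative, assuming the default 16x16 patches):
    # melspec is (1, 1, 1024, 128), so w' = 1024/16 = 64 and h' = 128/16 = 8;
    # `encode` then returns (1, d, 8, 64) with reshape=True and (1, d, 512)
    # otherwise, where d is the transformer embedding dimension.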


class PretrainedAudioMAEEncoder(PreTrainedModel):
    config_class = AudioMAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = AudioMAEEncoder(img_size=config.img_size,
                                       in_chans=config.in_chans,
                                       num_classes=config.num_classes)

    def forward(self, file_path: str, reshape=True):
        device = self.device
        return self.encoder.encode(file_path, device, reshape)
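

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). "sample.wav" is
    # a placeholder file name; to load trained weights, point from_pretrained
    # at a directory previously produced by save_pretrained (also a placeholder
    # path below). Without a checkpoint, the randomly initialized encoder
    # still runs end-to-end.
    config = AudioMAEConfig()
    model = PretrainedAudioMAEEncoder(config)
    # model = PretrainedAudioMAEEncoder.from_pretrained("./audiomae")  # placeholder
    z = model("sample.wav", reshape=True)
    print(z.shape)  # e.g., torch.Size([1, d, 8, 64]) with 16x16 patches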