audiomae / model.py
jhansss's picture
Update model.py
0ce5642 verified
from typing import Tuple
import torch
import torchaudio
import torchaudio.transforms as transforms
from torchaudio.compliance import kaldi
from transformers import PretrainedConfig
from einops import rearrange
from timm.models.vision_transformer import VisionTransformer
from transformers import PreTrainedModel
# The Config class and the Model class apparently need to live in the same file; otherwise, model loading seems to break after pushing to the HF Hub.
class AudioMAEConfig(PretrainedConfig):
    """Configuration holding the constructor arguments of the AudioMAE encoder.

    Attributes stored on the config:
        img_size: input spectrogram size; defaults to (1024, 128).
        in_chans: number of input channels; defaults to 1.
        num_classes: number of classes; defaults to 0.
    """

    model_type = "audiomae"

    def __init__(
        self,
        img_size: Tuple[int, int] = (1024, 128),
        in_chans: int = 1,
        num_classes: int = 0,
        **kwargs,
    ) -> None:
        # Let the base config consume the remaining keyword arguments first,
        # then record the encoder-specific fields on the instance.
        super().__init__(**kwargs)

        self.img_size = img_size
        self.in_chans = in_chans
        self.num_classes = num_classes
class AudioMAEEncoder(VisionTransformer):
    """Vision-Transformer encoder that maps an audio file to AudioMAE latents.

    - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
    - AudioMAE accepts a mono-channel input (i.e., in_chans=1)
    """

    # Audio / spectrogram constants shared by the methods below.
    _SAMPLE_RATE = 16000      # Hz; AudioMAE accepts 16 kHz audio
    _MAX_SECONDS = 10.0       # clips longer than this are split into windows
    _WINDOW_SIZE = 160000     # 10 seconds at 16 kHz
    _STRIDE = 80000           # 5 seconds at 16 kHz (50% overlap)
    _N_MEL_BINS = 128
    _EXPECTED_FRAMES = 1024   # as described in the paper

    def __init__(self, *args, **kwargs):
        """Build the underlying VisionTransformer and set the normalization stats."""
        super().__init__(*args, **kwargs)
        self.MEAN = -4.2677393  # written on the paper
        self.STD = 4.5689974  # written on the paper

    def load_wav_file(self, file_path: str):
        """Load an audio file as mono 16 kHz audio, windowed if longer than 10 s.

        To use this, `torchaudio` and `ffmpeg` must be installed.
        - `ffmpeg` version must be >=4.4 and <7.
        - `ffmpeg` installation by `conda install -c conda-forge ffmpeg==6.1.1`

        Args:
            file_path: path of the audio file to read.

        Returns:
            A tensor of shape (1, length) when the clip is at most 10 seconds
            long; otherwise a tensor of shape (n_windows, 160000) holding
            10-second windows taken with a 5-second stride. The tail is
            zero-padded so that the last window is complete.
        """
        audio, sample_rate = torchaudio.load(file_path)  # (n_channels, length)

        # AudioMAE accepts a mono channel: average multi-channel audio.
        if audio.shape[0] > 1:
            audio = torch.mean(audio, dim=0, keepdim=True)

        # AudioMAE accepts 16 kHz: resample anything else.
        if sample_rate != self._SAMPLE_RATE:
            converter = transforms.Resample(orig_freq=sample_rate, new_freq=self._SAMPLE_RATE)
            audio = converter(audio)

        audio_len = audio.shape[-1] / self._SAMPLE_RATE
        if audio_len <= self._MAX_SECONDS:
            return audio

        print(f'{file_path} has audio length {audio_len}s, which is longer than 10s. It will be segmented into 10-second windows with a 5-second stride (50% overlap)')

        # Zero-pad the tail so (length - window) is a multiple of the stride;
        # otherwise `unfold` would silently drop the trailing samples.
        remainder = (audio.shape[-1] - self._WINDOW_SIZE) % self._STRIDE
        if remainder != 0:
            audio = torch.nn.functional.pad(audio, (0, self._STRIDE - remainder), "constant", 0)

        return audio.squeeze(0).unfold(0, self._WINDOW_SIZE, self._STRIDE)

    def waveform_to_melspec(self, waveform: torch.FloatTensor):
        """Convert a waveform into a normalized log-mel filterbank of 1024 frames.

        The filterbank parameters are chosen as described in the AudioMAE
        paper (4.2 implementation details).

        Args:
            waveform: audio tensor passed to `kaldi.fbank`; `encode` passes a
                (1, length) tensor.

        Returns:
            A tensor of shape (length, n_freq_bins) = (1024, 128).
        """
        mel_spectrogram = kaldi.fbank(
            waveform,
            num_mel_bins=self._N_MEL_BINS,
            frame_length=25.0,
            frame_shift=10.0,
            htk_compat=True,
            use_energy=False,
            sample_frequency=self._SAMPLE_RATE,
            window_type='hanning',
            dither=0.0,
        )

        # Trim or zero-pad the time dimension so the output has exactly 1024 frames.
        current_frames = mel_spectrogram.shape[0]
        if current_frames > self._EXPECTED_FRAMES:
            mel_spectrogram = mel_spectrogram[:self._EXPECTED_FRAMES, :]
        elif current_frames < self._EXPECTED_FRAMES:
            missing = self._EXPECTED_FRAMES - current_frames
            # F.pad lists pairs starting from the LAST dim:
            # (0, 0) leaves the frequency dim alone, (0, missing) appends frames.
            mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0, 0, missing))

        # scale
        # as in the AudioMAE implementation [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
        mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)  # (1024, 128)
        return mel_spectrogram

    @torch.no_grad()
    def encode(self, file_path: str, device, reshape=True):
        """Encode an audio file into a latent averaged over its windows.

        Puts the module in eval mode, encodes every window returned by
        `load_wav_file` separately, drops the class token and averages the
        per-window latents.

        Args:
            file_path: path of the audio file to encode.
            device: device the mel spectrogram is moved to before the forward
                pass; the result is moved back to the CPU.
            reshape: if True, restore the (freq, time) patch grid; otherwise
                keep the flat patch dimension.

        Returns:
            A CPU tensor of shape (1, d, h', w') if `reshape` else (1, d, n),
            where h' / w' are the frequency / temporal sizes in the latent
            space and n is the number of patches.
        """
        self.eval()
        waveform = self.load_wav_file(file_path)  # (n_windows, length)

        zs = []
        for window in waveform:
            melspec = self.waveform_to_melspec(window.unsqueeze(0))  # (1024, 128)
            melspec = melspec[None, None, :, :]  # (1, 1, 1024, 128)

            z = self.forward_features(melspec.to(device)).cpu()  # (b, 1+n, d)
            z = z[:, 1:, :]  # (b n d); remove [CLS], the class token

            if reshape:
                # height (frequency) in the latent space
                hprime = round(melspec.shape[-1] / self.patch_embed.patch_size[1])
                # reconstruct the temporal and freq dims
                z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)  # (b d h' w')
            else:
                # put the patch dim in the back, i.e., (b n d) -> (b d n)
                z = z.transpose(1, 2)

            # remove the batch dim
            zs.append(z[0])  # (d h' w') if reshape else (d n)

        # Average over the windows, keeping a leading dim of size 1.
        return torch.stack(zs, dim=0).mean(dim=0, keepdim=True)
class PretrainedAudioMAEEncoder(PreTrainedModel):
    """`PreTrainedModel` wrapper that owns an `AudioMAEEncoder` built from a config."""

    config_class = AudioMAEConfig

    def __init__(self, config):
        """Create the wrapped encoder from the fields of `config`."""
        super().__init__(config)
        encoder_kwargs = {
            "img_size": config.img_size,
            "in_chans": config.in_chans,
            "num_classes": config.num_classes,
        }
        self.encoder = AudioMAEEncoder(**encoder_kwargs)

    def forward(self, file_path: str, reshape=True):
        """Encode the audio file at `file_path` on this model's device.

        Delegates to `AudioMAEEncoder.encode` and returns its result unchanged.
        """
        return self.encoder.encode(file_path, self.device, reshape)