# MIT License # # Copyright 2023 ByteDance Inc. # # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), # to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, # and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS # IN THE SOFTWARE. 
import random

import torch
import torchaudio
from einops import rearrange
from torch import einsum, nn
from torch.nn.common_types import _size_2_t
from transformers import PreTrainedModel

from .configuration_musicfm import MusicFMConfig, MusicFMInferenceConfig


class MusicFM25Hz(PreTrainedModel):
    """MusicFM (25 Hz token rate): masked token modeling for music audio.

    Training pipeline (see ``forward``): the raw waveform is turned into
    discrete target tokens via random-projection quantizers over
    mel-spectrogram patches, while a *masked* copy of the waveform is
    encoded by a conv front-end plus a Conformer; a linear head predicts
    the target tokens and cross-entropy is computed at masked positions.
    """

    config_class = MusicFMConfig

    def __init__(self, config: MusicFMConfig) -> None:
        super().__init__(config)

        # global variables (copied off the config for convenient access)
        self.num_codebooks = config.num_codebooks
        self.codebook_dim = config.codebook_dim
        self.codebook_size = config.codebook_size
        self.features = config.features
        self.hop_length = config.hop_length
        self.n_mels = config.n_mels
        self.conv_dim = config.conv_dim
        self.encoder_dim = config.encoder_dim
        self.encoder_depth = config.encoder_depth
        self.mask_hop = config.mask_hop
        self.mask_prob = config.mask_prob
        self.is_flash = config.is_flash
        self.stat = config.stat

        # feature extractor (dB-scaled mel spectrogram)
        self.preprocessor_melspec_2048 = MelSTFT(
            n_fft=2048, hop_length=self.hop_length, is_db=True
        )

        # random quantizer: one per (feature, codebook) pair, registered as
        # attributes "quantizer_<feature>_<i>". Each gets a fixed seed so the
        # (frozen, untrained) tokenizers are reproducible across runs.
        seed = 142
        for feature in self.features:
            for i in range(self.num_codebooks):
                setattr(
                    self,
                    f"quantizer_{feature}_{i}",
                    RandomProjectionQuantizer(
                        # input dim: 4 consecutive mel frames are flattened
                        # together before quantization (see `rearrange`)
                        self.n_mels * 4,
                        self.codebook_dim,
                        self.codebook_size,
                        seed=seed + i,
                    ),
                )

        # two residual convolution layers + one projection layer
        self.conv = Conv2dSubsampling(
            1, self.conv_dim, self.encoder_dim, strides=[2, 2], n_bands=self.n_mels
        )

        # Conformer: optionally swap in a flash-attention variant shipped
        # alongside this file; both expose the same config/encoder classes.
        if config.is_flash:
            from .flash_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )
        else:
            from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )
        conformer_config = Wav2Vec2ConformerConfig.from_pretrained(
            "facebook/wav2vec2-conformer-rope-large-960h-ft"
        )
        conformer_config.num_hidden_layers = self.encoder_depth
        conformer_config.hidden_size = self.encoder_dim
        self.conformer = Wav2Vec2ConformerEncoder(conformer_config)

        # projection
        # NOTE(review): the head is codebook_size wide in total, but
        # `encoder` slices it per-feature at [i*codebook_size:(i+1)*...];
        # for more than one feature the extra slices would be empty —
        # presumably only one feature ("melspec_2048") is used; confirm.
        self.linear = nn.Linear(self.encoder_dim, self.codebook_size)

        # loss function
        self.loss = nn.CrossEntropyLoss()

        # cls token (used for sequence classification)
        # NOTE(review): random.seed() seeds Python's RNG, not torch's, so
        # this line does not make the torch.randn init deterministic.
        random.seed(seed)
        self.cls_token = nn.Parameter(torch.randn(self.encoder_dim))

    def masking(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.LongTensor]:
        """random masking of 400ms with given probability

        Args:
            x: waveform batch of shape (batch, time); the 24000 constants
               assume 24 kHz audio — TODO confirm against the data pipeline.

        Returns:
            Tuple of (masked waveform, (index, position) pairs of the masked
            spans expressed in encoder-token resolution).
        """
        mx = x.clone()
        b, t = mx.shape
        # span length in raw samples, and in encoder tokens (mel hop followed
        # by two stride-2 subsampling stages, hence / 2 / 2)
        len_masking_raw = int(24000 * self.mask_hop)
        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)

        # get random mask indices: one Bernoulli draw per mask_hop-long span,
        # then expand the decision to every sample/token inside the span
        start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
        time_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_raw, dim=1)
        )
        token_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_token, dim=1)
        )

        # mask with random values
        masking_noise = (
            torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
        )  # 0 mean 0.1 std
        mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)

        return mx, token_domain_masked_indices

    @torch.no_grad()
    def preprocessing(
        self, x: torch.Tensor, features: list[str]
    ) -> dict[str, torch.Tensor]:
        """extract classic audio features

        Each name in ``features`` selects a registered ``preprocessor_<name>``
        module. Extraction always runs in float32 (cast back to half when the
        input was fp16); the trailing STFT frame is dropped.
        """
        # check precision
        if x.dtype == torch.float16:
            precision = 16
        else:
            precision = 32

        out = {}
        for key in features:
            layer = getattr(self, "preprocessor_%s" % key)
            out[key] = layer.float()(x.float())[..., :-1]
            if precision == 16:
                out[key] = out[key].half()
        return out

    def encoder(self, x: torch.Tensor) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """2-layer conv + w2v-conformer

        Returns per-feature logits (the linear head sliced into
        codebook_size-wide chunks, one per feature) and the tuple of all
        conformer hidden states.
        """
        x = self.conv(x)
        out = self.conformer(x, output_hidden_states=True)
        hidden_emb = out["hidden_states"]
        last_emb = out["last_hidden_state"]
        logits = self.linear(last_emb)
        logits = {
            key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
            for i, key in enumerate(self.features)
        }
        return logits, hidden_emb

    @torch.no_grad()
    def normalize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """normalize the input audio to have zero mean unit variance

        Uses pre-computed per-feature statistics from ``self.stat``
        (keys "<feature>_mean" / "<feature>_std"). Mutates ``x`` in place
        and returns it.
        """
        for key in x.keys():
            x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
        return x

    @torch.no_grad()
    def rearrange(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """rearrange the batch to flatten every 4 steps

        Mel features are grouped 4 time steps at a time, (b, f, 4t) ->
        (b, t, 4f), matching the quantizers' n_mels * 4 input dim;
        chromagrams are only transposed. Mutates ``x`` in place.
        """
        for key in x.keys():
            if key == "chromagram":
                x[key] = rearrange(x[key], "b f t -> b t f")
            else:
                x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
        return x

    @torch.no_grad()
    def tokenize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Map each feature tensor to discrete token ids via its quantizer.

        NOTE(review): looks up "quantizer_<feature>" without the codebook
        index suffix, so this resolves only when num_codebooks == 1 (the
        registered names are "quantizer_<feature>_<i>") — confirm.
        """
        out = {}
        for key in x.keys():
            layer = getattr(self, "quantizer_%s" % key)
            out[key] = layer(x[key])
        return out

    def get_targets(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Waveform -> normalized, patched features -> target token ids."""
        x = self.preprocessing(x, features=self.features)
        x = self.normalize(x)
        x = self.rearrange(x)
        target_tokens = self.tokenize(x)
        return target_tokens

    def get_predictions(
        self, x: torch.Tensor
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Waveform -> encoder logits and all hidden states."""
        # preprocessing
        x = self.preprocessing(x, features=["melspec_2048"])
        x = self.normalize(x)

        # encoding
        logits, hidden_emb = self.encoder(x["melspec_2048"])

        return logits, hidden_emb

    def get_latent(self, x: torch.Tensor, layer_ix: int = 12) -> torch.Tensor:
        """Return the hidden-state embedding of one encoder layer."""
        _, hidden_states = self.get_predictions(x)
        emb = hidden_states[layer_ix]
        return emb

    def get_loss(
        self,
        logits: dict[str, torch.Tensor],
        target_tokens: dict[str, torch.Tensor],
        masked_indices: torch.LongTensor,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Cross-entropy loss and token accuracy, restricted to masked positions."""
        losses = {}
        accuracies = {}
        for key in logits.keys():
            masked_logits = logits[key][tuple(masked_indices.t())]
            masked_tokens = target_tokens[key][tuple(masked_indices.t())]
            losses[key] = self.loss(masked_logits, masked_tokens)
            accuracies[key] = (
                torch.sum(masked_logits.argmax(-1) == masked_tokens)
                / masked_tokens.numel()
            )
        return losses, accuracies

    def forward(
        self, x: torch.Tensor
    ) -> tuple[
        dict[str, torch.Tensor],
        torch.Tensor,
        dict[str, torch.Tensor],
        dict[str, torch.Tensor],
    ]:
        """Full masked-token-modeling step on a waveform batch."""
        # get target feature tokens
        target_tokens = self.get_targets(x)

        # masking
        x, masked_indices = self.masking(x)

        # forward
        logits, hidden_emb = self.get_predictions(x)

        # get loss
        losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

        return logits, hidden_emb, losses, accuracies


class MusicFM25HzInference(MusicFM25Hz):
    """Inference-only wrapper: forward returns one hidden layer's embeddings."""

    config_class = MusicFMInferenceConfig

    def __init__(self, config: MusicFMInferenceConfig) -> None:
        super().__init__(config)
        # which conformer hidden state to expose
        self.layer_index = config.layer_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        layer_index = self.layer_index

        # forward
        _, hidden_emb = self.get_predictions(x)
        outputs = hidden_emb[layer_index]

        return outputs


class MelSTFT(nn.Module):
    """Mel spectrogram front-end, optionally converted to decibels."""

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 2048,
        hop_length: int = 240,
        n_mels: int = 128,
        is_db: bool = False,
    ) -> None:
        super().__init__()

        # spectrogram
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
        )

        # amplitude to decibel
        self.is_db = is_db
        if is_db:
            self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        if self.is_db:
            return self.amplitude_to_db(self.mel_stft(waveform))
        else:
            return self.mel_stft(waveform)


class RandomProjectionQuantizer(nn.Module):
    """
    Random projection and codebook lookup module

    Some code is borrowed from:
     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    But I did normalization using pre-computed global mean & variance
    instead of using layer norm.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_dim: int,
        codebook_size: int,
        seed: int = 142,
    ) -> None:
        super().__init__()

        # random seed (makes both buffers below reproducible)
        torch.manual_seed(seed)

        # randomly initialized projection; a buffer, so it is saved with the
        # model but never trained
        random_projection = torch.empty(input_dim, codebook_dim)
        nn.init.xavier_normal_(random_projection)
        self.register_buffer("random_projection", random_projection)

        # randomly initialized codebook (also frozen)
        codebook = torch.empty(codebook_size, codebook_dim)
        nn.init.normal_(codebook)
        self.register_buffer("codebook", codebook)

    def codebook_lookup(self, x: torch.Tensor) -> torch.Tensor:
        """Nearest-codeword lookup by cosine similarity (via L2-normalized cdist)."""
        # reshape
        b = x.shape[0]
        x = rearrange(x, "b n e -> (b n) e")

        # L2 normalization
        normalized_x = nn.functional.normalize(x, dim=1, p=2)
        normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)

        # compute distances
        distances = torch.cdist(normalized_codebook, normalized_x)

        # get nearest
        nearest_indices = torch.argmin(distances, dim=0)

        # reshape
        xq = rearrange(nearest_indices, "(b n) -> b n", b=b)

        return xq

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project features and return nearest codebook indices per step."""
        # always eval
        self.eval()

        # random projection [batch, length, input_dim] -> [batch, length, codebook_dim]
        x = einsum("b n d, d e -> b n e", x, self.random_projection)

        # codebook lookup
        xq = self.codebook_lookup(x)

        return xq


class Res2dModule(nn.Module):
    """Residual 2D conv block: two 3x3 convs with BN, plus a strided shortcut
    projection whenever the shape changes (channel count or stride > 1)."""

    def __init__(self, idim: int, odim: int, stride: _size_2_t = (2, 2)) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(odim)
        self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(odim)
        self.relu = nn.ReLU()

        # residual: project the shortcut when main path changes shape
        self.diff = False
        if (idim != odim) or (stride[0] > 1):
            self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
            self.bn3 = nn.BatchNorm2d(odim)
            self.diff = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        if self.diff:
            x = self.bn3(self.conv3(x))
        out = x + out
        out = self.relu(out)
        return out
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Two residual conv blocks, each with stride 2 on the frequency axis and
    ``strides[k]`` on the time axis, followed by a linear projection of the
    flattened (channel x frequency) features.

    Args:
        idim (int): Input dimension (input channels of the spectrogram "image").
        hdim (int): Hidden dimension (conv channels).
        odim (int): Output dimension.
        strides (tuple/list): Time-axis strides of the two residual blocks.
        n_bands (int): Number of frequency bands.
    """

    def __init__(
        self,
        idim: int,
        hdim: int,
        odim: int,
        strides: tuple[int, int] = (2, 2),
        n_bands: int = 64,
    ) -> None:
        """Construct an Conv2dSubsampling object."""
        # NOTE: default was the mutable list [2, 2]; an immutable tuple is
        # equivalent here (only indexed, never mutated) and avoids the
        # shared-mutable-default pitfall (flake8-bugbear B006).
        super().__init__()
        self.conv = nn.Sequential(
            Res2dModule(idim, hdim, (2, strides[0])),
            Res2dModule(hdim, hdim, (2, strides[1])),
        )
        # frequency axis is halved by each block, hence n_bands // 2 // 2
        self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, idim, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
            where time' = time // 4.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # (b, c, f, t)
        x = self.conv(x)
        x = rearrange(x, "b c f t -> b t (c f)")
        x = self.linear(x)
        return x