# MusicFMInference / modeling_musicfm.py
# Uploaded by tky823 via huggingface_hub (commit e28e259, verified)
# MIT License
#
# Copyright 2023 ByteDance Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”),
# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
import random
from collections.abc import Sequence

import torch
import torchaudio
from einops import rearrange
from torch import einsum, nn
from torch.nn.common_types import _size_2_t
from transformers import PreTrainedModel

from .configuration_musicfm import MusicFMConfig, MusicFMInferenceConfig
class MusicFM25Hz(PreTrainedModel):
    """MusicFM (25 Hz token rate) self-supervised audio model.

    Masked-token-prediction setup: classic audio features are tokenized by
    frozen random-projection quantizers to form training targets, while a
    conv frontend + Wav2Vec2-Conformer encoder predicts those tokens from a
    partially masked waveform.

    The hard-coded ``24000`` constants in :meth:`masking` assume 24 kHz mono
    input of shape ``(batch, time)``.
    """

    config_class = MusicFMConfig

    def __init__(self, config: MusicFMConfig) -> None:
        super().__init__(config)
        # global variables
        self.num_codebooks = config.num_codebooks
        self.codebook_dim = config.codebook_dim
        self.codebook_size = config.codebook_size
        self.features = config.features
        self.hop_length = config.hop_length
        self.n_mels = config.n_mels
        self.conv_dim = config.conv_dim
        self.encoder_dim = config.encoder_dim
        self.encoder_depth = config.encoder_depth
        self.mask_hop = config.mask_hop
        self.mask_prob = config.mask_prob
        self.is_flash = config.is_flash
        self.stat = config.stat
        # feature extractor
        self.preprocessor_melspec_2048 = MelSTFT(
            n_fft=2048, hop_length=self.hop_length, is_db=True
        )
        # random quantizer
        # one frozen quantizer per (feature, codebook); the per-codebook seed
        # offset keeps codebooks distinct while staying reproducible
        seed = 142
        for feature in self.features:
            for i in range(self.num_codebooks):
                setattr(
                    self,
                    f"quantizer_{feature}_{i}",
                    RandomProjectionQuantizer(
                        self.n_mels * 4,
                        self.codebook_dim,
                        self.codebook_size,
                        seed=seed + i,
                    ),
                )
        # two residual convolution layers + one projection layer
        self.conv = Conv2dSubsampling(
            1, self.conv_dim, self.encoder_dim, strides=[2, 2], n_bands=self.n_mels
        )
        # Conformer
        if config.is_flash:
            from .flash_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )
        else:
            from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )
        # NOTE(review): fetches a config from the HF hub at construction time
        # (network access or local cache required)
        conformer_config = Wav2Vec2ConformerConfig.from_pretrained(
            "facebook/wav2vec2-conformer-rope-large-960h-ft"
        )
        conformer_config.num_hidden_layers = self.encoder_depth
        conformer_config.hidden_size = self.encoder_dim
        self.conformer = Wav2Vec2ConformerEncoder(conformer_config)
        # projection
        # NOTE(review): output width is codebook_size, yet encoder() slices
        # logits as i*codebook_size:(i+1)*codebook_size per feature — only
        # consistent when len(features) == 1; confirm before adding features.
        self.linear = nn.Linear(self.encoder_dim, self.codebook_size)
        # loss function
        self.loss = nn.CrossEntropyLoss()
        # cls token (used for sequence classification)
        # NOTE(review): random.seed does not seed torch's RNG, so this call
        # does not make cls_token deterministic
        random.seed(seed)
        self.cls_token = nn.Parameter(torch.randn(self.encoder_dim))

    def masking(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.LongTensor]:
        """random masking of 400ms with given probability

        Draws mask-span start flags per ``mask_hop``-second window with
        probability ``mask_prob`` and replaces the chosen waveform samples
        with low-amplitude Gaussian noise (assumes 24 kHz input).

        Returns:
            tuple: (masked copy of ``x``,
            indices of masked positions in encoder-token coordinates).
        """
        mx = x.clone()
        b, t = mx.shape
        # span length in raw samples, and in encoder tokens after the two
        # stride-2 conv stages (hence the /2/2)
        len_masking_raw = int(24000 * self.mask_hop)
        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)
        # get random mask indices
        start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
        time_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_raw, dim=1)
        )
        token_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_token, dim=1)
        )
        # mask with random values
        masking_noise = (
            torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
        )  # 0 mean 0.1 std
        mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)
        return mx, token_domain_masked_indices

    @torch.no_grad()
    def preprocessing(
        self, x: torch.Tensor, features: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """extract classic audio features

        Runs each requested ``preprocessor_<key>`` module in float32 (then
        casts back to half if the input was fp16) and drops the final frame.
        """
        # check precision
        if x.dtype == torch.float16:
            precision = 16
        else:
            precision = 32

        out = {}
        for key in features:
            layer = getattr(self, "preprocessor_%s" % key)
            # compute in fp32 for numerical stability; [..., :-1] trims the
            # trailing STFT frame (presumably to keep the length a multiple
            # of 4 for rearrange — TODO confirm)
            out[key] = layer.float()(x.float())[..., :-1]
            if precision == 16:
                out[key] = out[key].half()
        return out

    def encoder(self, x: torch.Tensor) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """2-layer conv + w2v-conformer

        Returns:
            tuple: (per-feature logits dict sliced from the projection head,
            tuple of all conformer hidden states).
        """
        x = self.conv(x)
        out = self.conformer(x, output_hidden_states=True)
        hidden_emb = out["hidden_states"]
        last_emb = out["last_hidden_state"]
        logits = self.linear(last_emb)
        logits = {
            key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
            for i, key in enumerate(self.features)
        }
        return logits, hidden_emb

    @torch.no_grad()
    def normalize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """normalize the input audio to have zero mean unit variance

        Uses the pre-computed per-feature mean/std stored in ``self.stat``;
        mutates and returns the input dict.
        """
        for key in x.keys():
            x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
        return x

    @torch.no_grad()
    def rearrange(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """rearrange the batch to flatten every 4 steps

        Folds groups of 4 time steps into the feature axis (except for
        chromagram, which is only transposed); mutates and returns the dict.
        """
        for key in x.keys():
            if key == "chromagram":
                x[key] = rearrange(x[key], "b f t -> b t f")
            else:
                x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)
        return x

    @torch.no_grad()
    def tokenize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Quantize each feature with its ``quantizer_<key>`` module.

        NOTE(review): looks up ``quantizer_%s % key`` while __init__ registers
        ``quantizer_{feature}_{i}`` — these only match if config names align;
        verify against the checkpoint's attribute names.
        """
        out = {}
        for key in x.keys():
            layer = getattr(self, "quantizer_%s" % key)
            out[key] = layer(x[key])
        return out

    def get_targets(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Compute target token ids from the raw waveform (no grad path)."""
        x = self.preprocessing(x, features=self.features)
        x = self.normalize(x)
        x = self.rearrange(x)
        target_tokens = self.tokenize(x)
        return target_tokens

    def get_predictions(
        self, x: torch.Tensor
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Encode a waveform into per-feature logits and hidden states."""
        # preprocessing
        x = self.preprocessing(x, features=["melspec_2048"])
        x = self.normalize(x)

        # encoding
        logits, hidden_emb = self.encoder(x["melspec_2048"])

        return logits, hidden_emb

    def get_latent(self, x: torch.Tensor, layer_ix: int = 12) -> torch.Tensor:
        """Return the hidden states of conformer layer ``layer_ix``."""
        _, hidden_states = self.get_predictions(x)
        emb = hidden_states[layer_ix]
        return emb

    def get_loss(
        self,
        logits: dict[str, torch.Tensor],
        target_tokens: dict[str, torch.Tensor],
        masked_indices: torch.LongTensor,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Cross-entropy loss and accuracy over the masked positions only."""
        losses = {}
        accuracies = {}
        for key in logits.keys():
            masked_logits = logits[key][tuple(masked_indices.t())]
            masked_tokens = target_tokens[key][tuple(masked_indices.t())]
            losses[key] = self.loss(masked_logits, masked_tokens)
            accuracies[key] = (
                torch.sum(masked_logits.argmax(-1) == masked_tokens)
                / masked_tokens.numel()
            )
        return losses, accuracies

    def forward(
        self, x: torch.Tensor
    ) -> tuple[
        dict[str, torch.Tensor],
        torch.Tensor,
        dict[str, torch.Tensor],
        dict[str, torch.Tensor],
    ]:
        """Full training step: targets, masking, prediction, loss.

        Returns:
            tuple: (logits dict, hidden states, losses dict, accuracies dict).
        """
        # get target feature tokens
        target_tokens = self.get_targets(x)

        # masking
        x, masked_indices = self.masking(x)

        # forward
        logits, hidden_emb = self.get_predictions(x)

        # get loss
        losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

        return logits, hidden_emb, losses, accuracies
class MusicFM25HzInference(MusicFM25Hz):
    """Inference-only wrapper around :class:`MusicFM25Hz`.

    Instead of the full training outputs, ``forward`` returns the hidden
    states of a single conformer layer selected by the config.
    """

    config_class = MusicFMInferenceConfig

    def __init__(self, config: MusicFMInferenceConfig) -> None:
        super().__init__(config)
        # which conformer hidden state to expose at inference time
        self.layer_index = config.layer_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode *x* and return the configured hidden-layer embeddings."""
        _, hidden_states = self.get_predictions(x)
        return hidden_states[self.layer_index]
class MelSTFT(nn.Module):
    """Mel-spectrogram frontend, optionally converted to decibels.

    Args:
        sample_rate: Audio sample rate expected by the mel filterbank.
        n_fft: FFT window size.
        hop_length: Hop between successive frames.
        n_mels: Number of mel bands.
        is_db: If True, apply amplitude-to-dB conversion to the output.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 2048,
        hop_length: int = 240,
        n_mels: int = 128,
        is_db: bool = False,
    ) -> None:
        super().__init__()
        # mel-scaled STFT transform
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )
        # optional amplitude-to-decibel conversion
        self.is_db = is_db
        if is_db:
            self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally dB-scaled) mel spectrogram of *waveform*."""
        spec = self.mel_stft(waveform)
        if self.is_db:
            spec = self.amplitude_to_db(spec)
        return spec
class RandomProjectionQuantizer(nn.Module):
    """
    Random projection and codebook lookup module
    Some code is borrowed from:
     https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    But I did normalization using pre-computed global mean & variance instead of using layer norm.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_dim: int,
        codebook_size: int,
        seed: int = 142,
    ) -> None:
        super().__init__()

        # fixed seed so projection and codebook are reproducible
        torch.manual_seed(seed)

        # frozen random projection (buffer, not a parameter)
        projection = torch.empty(input_dim, codebook_dim)
        nn.init.xavier_normal_(projection)
        self.register_buffer("random_projection", projection)

        # frozen random codebook (buffer, not a parameter)
        codes = torch.empty(codebook_size, codebook_dim)
        nn.init.normal_(codes)
        self.register_buffer("codebook", codes)

    def codebook_lookup(self, x: torch.Tensor) -> torch.Tensor:
        """Return the index of the nearest codebook entry per input vector."""
        batch = x.shape[0]
        flat = rearrange(x, "b n e -> (b n) e")

        # cosine-style matching: L2-normalize both sides before cdist
        flat = nn.functional.normalize(flat, dim=1, p=2)
        normed_codes = nn.functional.normalize(self.codebook, dim=1, p=2)

        # nearest codebook row for every flattened vector
        nearest = torch.argmin(torch.cdist(normed_codes, flat), dim=0)

        return rearrange(nearest, "(b n) -> b n", b=batch)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project [batch, length, input_dim] and return token ids [batch, length]."""
        # the quantizer is frozen; force eval mode on every call
        self.eval()
        projected = einsum("b n d, d e -> b n e", x, self.random_projection)
        return self.codebook_lookup(projected)
class Res2dModule(nn.Module):
    """Residual 2-D conv block: conv-bn-relu-conv-bn plus a shortcut.

    When the channel count or spatial size changes (``idim != odim`` or the
    first stride component exceeds 1), the shortcut is projected through an
    extra conv+bn so it can be added to the main path.
    """

    def __init__(self, idim: int, odim: int, stride: _size_2_t = (2, 2)) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(odim)
        self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(odim)
        self.relu = nn.ReLU()

        # project the shortcut only when shapes would otherwise mismatch
        self.diff = False
        if (idim != odim) or (stride[0] > 1):
            self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
            self.bn3 = nn.BatchNorm2d(odim)
            self.diff = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the main path, add the (possibly projected) shortcut, ReLU."""
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        shortcut = self.bn3(self.conv3(x)) if self.diff else x
        return self.relu(shortcut + main)
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Two residual conv stages, each with stride 2 on the frequency axis and
    ``strides[i]`` on the time axis, followed by a linear projection of the
    flattened channel/frequency axes.

    Args:
        idim (int): Number of input channels (1 for a single spectrogram).
        hdim (int): Hidden (conv) channel dimension.
        odim (int): Output embedding dimension.
        strides (Sequence[int]): Time-axis strides of the two conv stages.
        n_bands (int): Number of frequency bands in the input.
    """

    def __init__(
        self,
        idim: int,
        hdim: int,
        odim: int,
        # immutable default instead of the mutable-list anti-pattern;
        # callers may still pass a list — strides is only ever indexed
        strides: Sequence[int] = (2, 2),
        n_bands: int = 64,
    ) -> None:
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        # each stage halves the frequency axis and subsamples time by strides[i]
        self.conv = nn.Sequential(
            Res2dModule(idim, hdim, (2, strides[0])),
            Res2dModule(hdim, hdim, (2, strides[1])),
        )
        # after two stride-2 stages the frequency axis is n_bands // 4,
        # flattened together with the hdim channels before projection
        self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.
        Args:
            x (torch.Tensor): Input tensor (#batch, idim, time).
        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
            where time' = time // 4.
        """
        if x.dim() == 3:
            # add a singleton channel axis: (b, f, t) -> (b, c, f, t)
            x = x.unsqueeze(1)
        x = self.conv(x)
        x = rearrange(x, "b c f t -> b t (c f)")
        x = self.linear(x)
        return x