|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random
from collections.abc import Sequence

import torch
import torchaudio
from einops import rearrange
from torch import einsum, nn
from torch.nn.common_types import _size_2_t
from transformers import PreTrainedModel

from .configuration_musicfm import MusicFMConfig, MusicFMInferenceConfig
|
|
|
|
|
|
|
|
class MusicFM25Hz(PreTrainedModel):
    """MusicFM foundation model with a 25 Hz output token rate.

    BEST-RQ-style masked token prediction over music audio:
      1. ``get_targets``     - mel features -> frozen random-projection tokens.
      2. ``masking``         - replace random spans of the raw waveform with noise.
      3. ``get_predictions`` - conv subsampling + conformer -> token logits.
      4. ``get_loss``        - cross-entropy on the masked positions only.
    """

    config_class = MusicFMConfig

    def __init__(self, config: MusicFMConfig) -> None:
        """Build the preprocessor, tokenizers, conv subsampler, conformer
        encoder, and prediction head from ``config``."""
        super().__init__(config)

        # Mirror config fields as plain attributes for convenient access.
        self.num_codebooks = config.num_codebooks
        self.codebook_dim = config.codebook_dim
        self.codebook_size = config.codebook_size
        self.features = config.features
        self.hop_length = config.hop_length
        self.n_mels = config.n_mels
        self.conv_dim = config.conv_dim
        self.encoder_dim = config.encoder_dim
        self.encoder_depth = config.encoder_depth
        self.mask_hop = config.mask_hop
        self.mask_prob = config.mask_prob
        self.is_flash = config.is_flash
        # Pre-computed global feature statistics consumed by normalize().
        self.stat = config.stat

        # Log-mel front end; the attribute name must match the
        # "preprocessor_<feature>" lookup performed in preprocessing().
        self.preprocessor_melspec_2048 = MelSTFT(
            n_fft=2048, hop_length=self.hop_length, is_db=True
        )

        # One frozen random-projection tokenizer per (feature, codebook);
        # seeds are offset by the codebook index so codebooks differ.
        seed = 142
        for feature in self.features:
            for i in range(self.num_codebooks):
                setattr(
                    self,
                    f"quantizer_{feature}_{i}",
                    RandomProjectionQuantizer(
                        # n_mels * 4: rearrange() stacks 4 consecutive mel
                        # frames into one token frame.
                        self.n_mels * 4,
                        self.codebook_dim,
                        self.codebook_size,
                        seed=seed + i,
                    ),
                )

        # 2-layer residual conv stack: 4x subsampling in time and frequency.
        self.conv = Conv2dSubsampling(
            1, self.conv_dim, self.encoder_dim, strides=[2, 2], n_bands=self.n_mels
        )

        # Choose the flash-attention conformer variant when requested,
        # otherwise the stock HF Wav2Vec2 conformer.
        if config.is_flash:
            from .flash_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )
        else:
            from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )

        # Start from the released rope-conformer config, then resize depth and
        # width to this model (fetches the config from the hub on first use).
        conformer_config = Wav2Vec2ConformerConfig.from_pretrained(
            "facebook/wav2vec2-conformer-rope-large-960h-ft"
        )
        conformer_config.num_hidden_layers = self.encoder_depth
        conformer_config.hidden_size = self.encoder_dim
        self.conformer = Wav2Vec2ConformerEncoder(conformer_config)

        # Projection from encoder output to codebook logits.
        # NOTE(review): output width is codebook_size total, yet encoder()
        # slices one codebook_size chunk per feature — only the first feature
        # gets a full chunk. Confirm for multi-feature configs.
        self.linear = nn.Linear(self.encoder_dim, self.codebook_size)

        self.loss = nn.CrossEntropyLoss()

        # Seed Python's RNG for reproducibility of any downstream use.
        random.seed(seed)
        # NOTE(review): cls_token is not referenced elsewhere in this file.
        self.cls_token = nn.Parameter(torch.randn(self.encoder_dim))

    def masking(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.LongTensor]:
        """random masking of 400ms with given probability

        Replaces randomly chosen ``mask_hop``-second spans of the waveform
        with Gaussian noise (std 0.1). Returns the masked waveform and the
        (batch, token) indices of masked positions in token space.
        Assumes 24 kHz input audio — the sample rate is hard-coded below.
        """
        mx = x.clone()
        b, t = mx.shape
        # Span lengths in raw samples and in encoder tokens; the two /2
        # factors account for the conv stack's two stride-2 time reductions.
        len_masking_raw = int(24000 * self.mask_hop)
        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)

        # One Bernoulli(mask_prob) decision per span, expanded over each
        # span at both sample and token resolution.
        start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
        time_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_raw, dim=1)
        )
        token_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_token, dim=1)
        )

        # Overwrite masked samples with low-amplitude Gaussian noise.
        masking_noise = (
            torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
        )
        mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)

        return mx, token_domain_masked_indices

    @torch.no_grad()
    def preprocessing(
        self, x: torch.Tensor, features: list[str]
    ) -> dict[str, torch.Tensor]:
        """extract classic audio features

        Runs each requested "preprocessor_<name>" front end in float32 and
        casts the result back to half precision when the input was float16.
        """

        # Remember input precision so features can be cast back afterwards.
        if x.dtype == torch.float16:
            precision = 16
        else:
            precision = 32

        out = {}
        for key in features:
            layer = getattr(self, "preprocessor_%s" % key)
            # Compute in float32 for numerical stability; drop the trailing
            # STFT frame (presumably to keep length divisible by 4 for
            # rearrange() — confirm).
            out[key] = layer.float()(x.float())[..., :-1]
            if precision == 16:
                out[key] = out[key].half()
        return out

    def encoder(
        self, x: torch.Tensor
    ) -> tuple[dict[str, torch.Tensor], tuple[torch.Tensor, ...]]:
        """2-layer conv + w2v-conformer

        Returns per-feature logits and the tuple of all conformer hidden
        states (one entry per layer plus the input embedding).
        """
        x = self.conv(x)
        out = self.conformer(x, output_hidden_states=True)
        hidden_emb = out["hidden_states"]
        last_emb = out["last_hidden_state"]
        logits = self.linear(last_emb)
        # Split the logit vector into one codebook_size chunk per feature.
        # NOTE(review): self.linear outputs codebook_size units total, so
        # only the first feature receives a full slice — verify if more than
        # one feature is ever configured.
        logits = {
            key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
            for i, key in enumerate(self.features)
        }
        return logits, hidden_emb

    @torch.no_grad()
    def normalize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """normalize the input audio to have zero mean unit variance

        Uses the pre-computed global statistics from ``config.stat``
        ("<feature>_mean" / "<feature>_std"). Mutates ``x`` in place and
        returns it.
        """
        for key in x.keys():
            x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
        return x

    @torch.no_grad()
    def rearrange(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """rearrange the batch to flatten every 4 steps

        Stacks every 4 consecutive frames into one frame (chromagram is only
        transposed). The bare ``rearrange`` below resolves to the einops
        function, not this method. Mutates ``x`` in place.
        """
        for key in x.keys():
            if key == "chromagram":
                x[key] = rearrange(x[key], "b f t -> b t f")
            else:
                x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)

        return x

    @torch.no_grad()
    def tokenize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Quantize each feature into target token ids.

        NOTE(review): __init__ registers attributes named
        "quantizer_<feature>_<i>", but this lookup uses "quantizer_<key>"
        without the codebook index — as written this would raise
        AttributeError; verify against the upstream implementation.
        """
        out = {}
        for key in x.keys():
            layer = getattr(self, "quantizer_%s" % key)
            out[key] = layer(x[key])
        return out

    def get_targets(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Raw waveform -> per-feature target token ids (all steps no-grad)."""
        x = self.preprocessing(x, features=self.features)
        x = self.normalize(x)
        x = self.rearrange(x)
        target_tokens = self.tokenize(x)

        return target_tokens

    def get_predictions(
        self, x: torch.Tensor
    ) -> tuple[dict[str, torch.Tensor], tuple[torch.Tensor, ...]]:
        """Raw waveform -> (per-feature logits, all conformer hidden states)."""
        # Audio -> log-mel features, normalized with global statistics.
        x = self.preprocessing(x, features=["melspec_2048"])
        x = self.normalize(x)

        # Conv subsampling + conformer encoding.
        logits, hidden_emb = self.encoder(x["melspec_2048"])

        return logits, hidden_emb

    def get_latent(self, x: torch.Tensor, layer_ix: int = 12) -> torch.Tensor:
        """Return the hidden states of one conformer layer (default: 12)."""
        _, hidden_states = self.get_predictions(x)
        emb = hidden_states[layer_ix]
        return emb

    def get_loss(
        self,
        logits: dict[str, torch.Tensor],
        target_tokens: dict[str, torch.Tensor],
        masked_indices: torch.LongTensor,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Cross-entropy loss and token accuracy over masked positions only."""
        losses = {}
        accuracies = {}
        for key in logits.keys():
            # Select only the (batch, token) positions that were masked.
            masked_logits = logits[key][tuple(masked_indices.t())]
            masked_tokens = target_tokens[key][tuple(masked_indices.t())]
            losses[key] = self.loss(masked_logits, masked_tokens)
            accuracies[key] = (
                torch.sum(masked_logits.argmax(-1) == masked_tokens)
                / masked_tokens.numel()
            )
        return losses, accuracies

    def forward(
        self, x: torch.Tensor
    ) -> tuple[
        dict[str, torch.Tensor],
        tuple[torch.Tensor, ...],
        dict[str, torch.Tensor],
        dict[str, torch.Tensor],
    ]:
        """One masked-prediction training step.

        Args:
            x: raw waveform batch of shape (batch, samples); masking()
               assumes 24 kHz audio.

        Returns:
            (per-feature logits, conformer hidden states, losses, accuracies).
        """
        # Tokenize the clean audio to obtain prediction targets.
        target_tokens = self.get_targets(x)

        # Mask random spans of the waveform.
        x, masked_indices = self.masking(x)

        # Predict tokens from the masked audio.
        logits, hidden_emb = self.get_predictions(x)

        # Score only the masked positions.
        losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

        return logits, hidden_emb, losses, accuracies
|
|
|
|
|
|
|
|
class MusicFM25HzInference(MusicFM25Hz):
    """Inference-only MusicFM variant.

    Instead of the training tuple, ``forward`` returns the hidden states of
    a single conformer layer selected by ``config.layer_index``.
    """

    config_class = MusicFMInferenceConfig

    def __init__(self, config: MusicFMInferenceConfig) -> None:
        super().__init__(config)

        # Which conformer hidden-state layer to expose as the embedding.
        self.layer_index = config.layer_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the configured layer's hidden states for waveform ``x``."""
        _, hidden_states = self.get_predictions(x)
        return hidden_states[self.layer_index]
|
|
|
|
|
|
|
|
class MelSTFT(nn.Module):
    """Mel-spectrogram front end with optional decibel scaling.

    Wraps ``torchaudio.transforms.MelSpectrogram`` and, when ``is_db`` is
    set, follows it with ``AmplitudeToDB``.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 2048,
        hop_length: int = 240,
        n_mels: int = 128,
        is_db: bool = False,
    ) -> None:
        super().__init__()

        # Core mel-spectrogram transform.
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )

        # Optional amplitude -> dB conversion stage.
        self.is_db = is_db
        if is_db:
            self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Compute the (optionally dB-scaled) mel spectrogram of ``waveform``."""
        spec = self.mel_stft(waveform)
        return self.amplitude_to_db(spec) if self.is_db else spec
|
|
|
|
|
|
|
|
class RandomProjectionQuantizer(nn.Module):
    """Frozen random-projection tokenizer (BEST-RQ style).

    Each input frame is projected through a fixed random matrix and assigned
    the index of its nearest codebook entry (L2 distance between
    unit-normalized vectors, i.e. cosine ranking).

    Some code is borrowed from:
    https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    Normalization here relies on pre-computed global mean & variance applied
    upstream, rather than a layer norm.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_dim: int,
        codebook_size: int,
        seed: int = 142,
    ) -> None:
        super().__init__()

        # Seed the global RNG so both frozen buffers are reproducible.
        torch.manual_seed(seed)

        # Fixed (non-trainable) projection matrix, Xavier-normal initialized.
        projection = torch.empty(input_dim, codebook_dim)
        nn.init.xavier_normal_(projection)
        self.register_buffer("random_projection", projection)

        # Fixed random codebook with standard-normal entries.
        codes = torch.empty(codebook_size, codebook_dim)
        nn.init.normal_(codes)
        self.register_buffer("codebook", codes)

    def codebook_lookup(self, x: torch.Tensor) -> torch.Tensor:
        """Map projected frames (b, n, e) to codebook indices (b, n)."""
        batch = x.shape[0]
        flat = x.reshape(-1, x.shape[-1])

        # Unit-normalize both sides so L2 distance ranks like cosine similarity.
        flat = nn.functional.normalize(flat, dim=1, p=2)
        unit_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)

        # distances: (codebook_size, b * n); pick the closest code per frame.
        nearest = torch.argmin(torch.cdist(unit_codebook, flat), dim=0)

        return nearest.reshape(batch, -1)

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Tokenize frames (b, n, input_dim) into code indices (b, n)."""
        # The tokenizer is frozen: always run in eval mode, never with grads.
        self.eval()

        # Random projection into the codebook space.
        projected = einsum("b n d, d e -> b n e", x, self.random_projection)

        return self.codebook_lookup(projected)
|
|
|
|
|
|
|
|
class Res2dModule(nn.Module):
    """Residual 2-D conv block: two 3x3 convs with batch norm and a ReLU,
    plus a projection shortcut when the shapes of input and output differ.
    """

    def __init__(self, idim: int, odim: int, stride: _size_2_t = (2, 2)) -> None:
        super().__init__()
        # Main path: strided conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(odim)
        self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(odim)
        self.relu = nn.ReLU()

        # A projection shortcut is needed when channels change or the main
        # path downsamples. NOTE(review): only stride[0] is checked, matching
        # the original implementation.
        self.diff = (idim != odim) or (stride[0] > 1)
        if self.diff:
            self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
            self.bn3 = nn.BatchNorm2d(odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block to ``x`` of shape (b, idim, f, t)."""
        main = self.conv1(x)
        main = self.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        # Identity shortcut, or projected when shapes differ.
        shortcut = self.bn3(self.conv3(x)) if self.diff else x
        return self.relu(shortcut + main)
|
|
|
|
|
|
|
|
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        hdim (int): Hidden dimension.
        odim (int): Output dimension.
        strides (Sequence[int]): Time-axis strides of the two blocks.
        n_bands (int): Number of frequency bands.
    """

    def __init__(
        self,
        idim: int,
        hdim: int,
        odim: int,
        # Immutable default: the original `[2, 2]` list default is the
        # mutable-default-argument anti-pattern (shared across calls).
        strides: Sequence[int] = (2, 2),
        n_bands: int = 64,
    ) -> None:
        """Construct an Conv2dSubsampling object."""
        super().__init__()

        # Two residual blocks; each halves the frequency axis (fixed stride 2)
        # and subsamples time by strides[i].
        self.conv = nn.Sequential(
            Res2dModule(idim, hdim, (2, strides[0])),
            Res2dModule(hdim, hdim, (2, strides[1])),
        )
        # Frequency axis is reduced 4x by the two stride-2 blocks above.
        self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, idim, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
        """
        # Add a channel dimension for (batch, freq, time) spectrogram input.
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = self.conv(x)
        # Fold channels and remaining frequency bins into the feature axis.
        x = rearrange(x, "b c f t -> b t (c f)")
        x = self.linear(x)

        return x
|
|
|