|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
from .Patcher import PatchStrategy |
|
|
from .mwmae import MWMHABlock |
|
|
from .pos_embed import get_2d_sincos_pos_embed |
|
|
from .utils import PatchEmbed, create_pretrained_model, repeat_token |
|
|
|
|
|
from einops import rearrange |
|
|
|
|
|
from typing import List |
|
|
|
|
|
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        stride: Stride forwarded to ``nn.Conv2d`` (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    conv_options = {
        "kernel_size": 3,
        "stride": stride,
        "padding": 1,
        "bias": False,
    }
    return nn.Conv2d(in_channels, out_channels, **conv_options)
|
|
|
|
|
|
|
|
class GRAMT(nn.Module):
    """Patch-based transformer encoder with a multi-window attention decoder.

    The module owns:
      * a ``PatchStrategy`` describing how the ``(num_mel_bins, input_length)``
        input is cut into patches,
      * a convolutional patch embedding followed by the encoder returned by
        ``create_pretrained_model``,
      * a stack of ``MWMHABlock`` decoder blocks with a linear head that
        predicts ``fshape * tshape * in_channels`` values per patch.

    ``get_audio_representation`` is the inference entry point: it splits an
    input into ``input_length``-frame segments and encodes each one without
    gradient tracking.
    """

    def __init__(
        self,
        model_size="base",
        in_channels=2,
        decoder_mlp_ratio: float = 4.0,
        decoder_depth: int = 8,
        decoder_num_heads: int = 8,
        decoder_embedding_dim: int = 512,
        decoder_window_sizes: List[int] = [2, 5, 10, 25, 50, 100, 0, 0],
        encoder_num_layers=12,
        encoder_num_heads=12,
        encoder_hidden_dim=768,
        encoder_mlp_ratio=4.0,
        encoder_dropout=0.0,
        encoder_attention_dropout=0.0,
        encoder_norm_layer_eps=1e-6,
        patch_size=(16, 8),
        frequency_stride=16,
        time_stride=8,
        input_length=200,
        num_mel_bins=128,
        **kwargs,
    ):
        """Build the patching strategy, encoder, decoder and normalisation layers.

        Args:
            model_size: Forwarded as the first argument of ``create_pretrained_model``.
            in_channels: Number of input channels; also the channel count of the
                patch-embedding convolution and of the input LayerNorm.
            decoder_mlp_ratio: ``mlp_ratio`` of every ``MWMHABlock``.
            decoder_depth: Number of ``MWMHABlock`` decoder blocks.
            decoder_num_heads: ``num_heads`` of every decoder block.
            decoder_embedding_dim: Width of the decoder tokens.
            decoder_window_sizes: ``window_sizes`` of every decoder block. The
                value is copied, so the (mutable) default is never shared.
            encoder_num_layers: Forwarded to ``create_pretrained_model``.
            encoder_num_heads: Forwarded to ``create_pretrained_model``.
            encoder_hidden_dim: Forwarded to ``create_pretrained_model``; also
                scales ``encoder_mlp_dim``.
            encoder_mlp_ratio: ``encoder_mlp_dim`` is
                ``int(encoder_hidden_dim * encoder_mlp_ratio)``.
            encoder_dropout: Forwarded to ``create_pretrained_model``.
            encoder_attention_dropout: Forwarded to ``create_pretrained_model``.
            encoder_norm_layer_eps: Forwarded to ``create_pretrained_model``.
            patch_size: ``(fshape, tshape)`` of one patch.
            frequency_stride: Patch stride along the frequency axis.
            time_stride: Patch stride along the time axis.
            input_length: Number of time frames of one segment.
            num_mel_bins: Number of frequency bins of the input.
            **kwargs: Only ``compile_modules`` is read; when truthy,
                ``_compile_operations`` is called at the end of construction.
        """
        super().__init__()
        self.in_channels = in_channels
        self.input_length = input_length

        # Patch geometry: how many patches fit along frequency and time.
        self.patch_strategy = PatchStrategy(
            tstride=time_stride,
            tshape=patch_size[1],
            fstride=frequency_stride,
            fshape=patch_size[0],
            input_fdim=num_mel_bins,
            input_tdim=self.input_length,
        )
        self.p_f_dim, self.p_t_dim = self.patch_strategy.get_patch_size()
        self.num_patches = self.p_f_dim * self.p_t_dim
        self.grid_size = (self.p_f_dim, self.p_t_dim)

        # Encoder backbone and its embedding width.
        self.encoder, self.encoder_embedding_dim = create_pretrained_model(
            model_size,
            encoder_num_layers=encoder_num_layers,
            encoder_num_heads=encoder_num_heads,
            encoder_hidden_dim=encoder_hidden_dim,
            encoder_mlp_dim=int(encoder_hidden_dim * encoder_mlp_ratio),
            encoder_dropout=encoder_dropout,
            encoder_attention_dropout=encoder_attention_dropout,
            encoder_norm_layer_eps=encoder_norm_layer_eps,
        )
        self.encoder_cls_token_num = 1

        # Patch embedding whose projection is replaced to match our geometry.
        self.patch_embed = PatchEmbed()
        self._update_patch_embed_layers(self.patch_embed)

        # NOTE(review): the token is an nn.Parameter stored through
        # register_buffer, so it is listed under buffers() rather than
        # parameters(). Kept exactly as before; confirm this is intentional.
        self.register_buffer(
            "cls_token",
            nn.Parameter(
                torch.zeros([1, 1, self.encoder_embedding_dim]), requires_grad=True
            ),
        )
        torch.nn.init.normal_(self.cls_token, std=0.02)

        # Decoder configuration.
        self.decoder_depth = decoder_depth
        self.decoder_num_heads = decoder_num_heads
        self.decoder_embedding_dim = decoder_embedding_dim
        # Copy so that neither the shared default list nor a caller's list is
        # aliased by this instance.
        self.decoder_window_sizes = list(decoder_window_sizes)
        # pass_through_decoder branches on this flag. The decoder positional
        # embedding below has exactly num_patches positions (no cls slot), so
        # the cls-free path is the one that matches this configuration.
        self.use_mwmae_decoder = True
        self.decoder_embed = nn.Linear(
            self.encoder_embedding_dim, self.decoder_embedding_dim, bias=True
        )

        # Same buffer-registered Parameter pattern as cls_token above.
        self.register_buffer(
            "mask_token",
            nn.Parameter(
                torch.zeros(1, 1, self.decoder_embedding_dim, requires_grad=True)
            ),
        )
        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.decoder_blocks = nn.ModuleList(
            [
                MWMHABlock(
                    dim=decoder_embedding_dim,
                    num_heads=decoder_num_heads,
                    window_sizes=self.decoder_window_sizes,
                    shift_windows=False,
                    mlp_ratio=decoder_mlp_ratio,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                )
                for _ in range(self.decoder_depth)
            ]
        )

        # Fixed (non-trainable) sin-cos positional embeddings for the encoder
        # (with cls slot) and the decoder (without cls slot).
        cls_token_num = 0
        self.encoder.pos_embedding = self._get_pos_embed_params()

        self.register_buffer(
            "decoder_pos_embed",
            nn.Parameter(
                torch.zeros(1, self.num_patches, decoder_embedding_dim),
                requires_grad=False,
            ),
        )
        pos_embed = get_2d_sincos_pos_embed(
            decoder_embedding_dim, self.grid_size, cls_token_num=cls_token_num
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0)
        )

        # Prediction head: one flattened patch (all channels) per decoder token.
        self.spec_pred = nn.Sequential(
            nn.Linear(
                decoder_embedding_dim,
                self.patch_strategy.fshape
                * self.patch_strategy.tshape
                * self.in_channels,
                bias=True,
            ),
        )
        self.decoder_norm = nn.LayerNorm(decoder_embedding_dim)

        # Parameter-free normalisation over a whole (C, F, T) segment.
        self.spectrogram_normalize = nn.LayerNorm(
            [self.in_channels, num_mel_bins, self.input_length],
            elementwise_affine=False,
        )
        self.input_shape = [num_mel_bins, self.input_length]

        if kwargs.get("compile_modules"):
            self._compile_operations()

    def forward(self, x, strategy="mean"):
        """Alias for ``get_audio_representation``.

        ``_compile_operations`` installs a compiled ``get_audio_representation``
        as ``forward``; this method gives the uncompiled module the same
        behaviour instead of leaving ``forward`` unimplemented.
        """
        return self.get_audio_representation(x, strategy=strategy)

    def _compile_operations(self):
        """Replace ``self.forward`` with a ``torch.compile``-d ``get_audio_representation``.

        On success ``self.use_compiled_forward`` is set to ``True``. If
        ``torch.compile`` raises, a warning is printed, ``forward`` is left
        untouched and ``self.use_compiled_forward`` is set to ``False``.
        """
        try:
            self.forward = torch.compile(
                self.get_audio_representation, mode="reduce-overhead"
            )
            self.use_compiled_forward = True
        except Exception as e:
            print(f"Warning: Could not compile operations: {e}")
            self.use_compiled_forward = False

    def _get_pos_embed_params(self):
        """Return the frozen sin-cos positional embedding for the encoder.

        Returns:
            An ``nn.Parameter`` (``requires_grad=False``) of shape
            ``(1, num_patches + encoder_cls_token_num, encoder_embedding_dim)``
            filled from ``get_2d_sincos_pos_embed``.
        """
        pos_embed = nn.Parameter(
            torch.zeros(
                1,
                self.num_patches + self.encoder_cls_token_num,
                self.encoder_embedding_dim,
            ),
            requires_grad=False,
        )
        pos_embed_data = get_2d_sincos_pos_embed(
            self.encoder_embedding_dim,
            self.grid_size,
            cls_token_num=self.encoder_cls_token_num,
        )
        pos_embed.data.copy_(torch.from_numpy(pos_embed_data).float().unsqueeze(0))
        return pos_embed

    def _update_patch_embed_layers(self, patch_embed):
        """Install a projection on ``patch_embed`` that matches the patch geometry.

        ``patch_embed.proj`` becomes a ``Conv2d`` from ``in_channels`` to
        ``encoder_embedding_dim`` whose kernel is one patch
        (``fshape x tshape``) and whose stride is ``(fstride, tstride)``;
        ``patch_embed.num_patch`` is set to ``self.num_patches``.
        """
        patch_embed.proj = torch.nn.Conv2d(
            self.in_channels,
            self.encoder_embedding_dim,
            kernel_size=(self.patch_strategy.fshape, self.patch_strategy.tshape),
            stride=(self.patch_strategy.fstride, self.patch_strategy.tstride),
        )
        patch_embed.num_patch = self.num_patches

    def pass_through_encoder(self, x, non_mask_index, B):
        """Run the kept patch tokens (plus class token) through the encoder.

        Args:
            x: Patch embeddings; assumed ``(B, num_patches, D)`` -- TODO confirm
                against ``PatchStrategy.embed``.
            non_mask_index: Boolean index selecting the tokens to keep; every
                batch item must keep the same number of tokens because the
                selection is reshaped back to ``(B, -1, D)``.
            B: Batch size.

        Returns:
            The output of ``self.encoder.ln`` over ``[cls, (dist,) kept tokens]``.
        """
        # Positional information is added before masking so that the kept
        # tokens retain their original positions.
        x = x + self.encoder.pos_embedding[:, self.encoder_cls_token_num :, :]
        x = x[non_mask_index, :].reshape((B, -1, x.shape[-1]))
        cls_token = (
            self.cls_token.expand(B, -1, -1) + self.encoder.pos_embedding[:, :1, :]
        )

        # Encoders that expose a distillation token get it inserted after the
        # class token; an explicit lookup replaces the former blanket
        # ``except Exception`` so unrelated errors are no longer hidden.
        dist_token = getattr(self.encoder, "dist_token", None)
        if dist_token is not None:
            dist_token = (
                dist_token.expand(B, -1, -1) + self.encoder.pos_embedding[:, 1:2, :]
            )
            x = torch.cat((cls_token, dist_token, x), dim=1)
        else:
            x = torch.cat((cls_token, x), dim=1)

        x = self.encoder.dropout(x)
        for block in self.encoder.layers:
            x = block(x)
        return self.encoder.ln(x)

    def pass_through_decoder(self, encoder_output, non_mask_index, B):
        """Decode encoder tokens into per-patch predictions.

        Args:
            encoder_output: Output of ``pass_through_encoder``.
            non_mask_index: The same boolean index that was given to the encoder.
            B: Batch size.

        Returns:
            ``self.spec_pred`` applied to the normalised decoder output, with
            any leading class-token positions removed.
        """
        encoder_output = self.decoder_embed(encoder_output)

        # Start from mask tokens everywhere, then write the encoded tokens back
        # into the positions that were visible to the encoder.
        x_ = repeat_token(self.mask_token, (B, self.num_patches)).type_as(
            encoder_output
        )
        x_[non_mask_index, :] = encoder_output[
            :, self.encoder_cls_token_num :, :
        ].reshape((-1, encoder_output.shape[-1]))
        x_ = x_.reshape((B, -1, encoder_output.shape[-1]))

        if self.use_mwmae_decoder:
            # Patch tokens only; nothing to strip from the prediction.
            x = x_
            return_cut = 0
        else:
            # Prepend the class token(s). NOTE(review): decoder_pos_embed has
            # num_patches positions only, so this branch does not line up with
            # the embedding built in __init__.
            x = torch.cat(
                [encoder_output[:, : self.encoder_cls_token_num, :], x_], dim=1
            )
            return_cut = self.encoder_cls_token_num
        x = x + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        pred = self.spec_pred(x)
        return pred[:, return_cut:, :]

    def _get_segment_representation(self, x, strategy="mean"):
        """Encode one ``input_length``-frame segment.

        Args:
            x: Tensor whose dim 1 is the channel axis and whose dims 2 and 3 are
                swapped before normalisation, i.e. ``(B, C, T, F)`` on entry.
            strategy: How encoder tokens are reduced:
                ``"mean"`` / ``"sum"`` reduce over the patch tokens,
                ``"cls"`` returns token 0, and ``"raw"`` rearranges the patch
                tokens from ``b (f t) d`` to ``b t (f d)``.

        Returns:
            The segment representation for the chosen strategy.

        Raises:
            AssertionError: If ``x.shape[1]`` differs from ``self.in_channels``.
            ValueError: If ``strategy`` is not one of the values above.
        """
        assert x.shape[1] == self.in_channels, f"The GRAM has in channels {self.in_channels}, but the feature has shape {x.shape} which the channels are incompatible"
        B = x.shape[0]
        x = x.transpose(2, 3)
        x = self.spectrogram_normalize(x)
        encoded_patches = self.patch_strategy.embed(x, self.patch_embed)

        # Nothing is masked at inference time: every patch is kept.
        mask = torch.zeros((B, self.num_patches), dtype=torch.bool, device=x.device)
        x = self.pass_through_encoder(encoded_patches, ~mask, B)

        if strategy == "mean":
            return x[:, self.encoder_cls_token_num :, :].mean(axis=1)
        if strategy == "sum":
            return x[:, self.encoder_cls_token_num :, :].sum(axis=1)
        if strategy == "cls":
            return x[:, 0, :]
        if strategy == "raw":
            x = x[:, self.encoder_cls_token_num :, :]
            f, _ = self.grid_size
            # Fold the frequency patches into the feature axis so that dim 1
            # indexes time patches.
            return rearrange(
                x, "b (f t) d -> b t (f d)", f=f, d=self.encoder_embedding_dim
            )
        raise ValueError(f"Strategy '{strategy}' is unrecognized.")

    def get_audio_representation(self, x, strategy="mean"):
        """Encode an input of arbitrary length segment by segment.

        The frame axis (dim 2) is zero-padded at the end up to the next multiple
        of ``self.input_length``; each ``input_length``-frame slice is encoded
        with ``_get_segment_representation`` under ``torch.no_grad()``.

        Args:
            x: 4-D tensor with frames on dim 2 (sliced as ``x[:, :, a:b, :]``).
            strategy: Forwarded to ``_get_segment_representation``.

        Returns:
            For ``"raw"``: the per-segment outputs concatenated along dim 1,
            with the share that corresponds to padded frames trimmed off.
            Otherwise: the per-segment outputs stacked along a new dim 1.
        """
        unit_frames = self.input_length
        cur_frames = x.shape[2]

        # Frames missing to reach the next multiple of unit_frames. This is 0
        # for already aligned inputs, so no spurious all-padding segment is
        # appended (the previous formula yielded a full unit in that case).
        pad_frames = (-cur_frames) % unit_frames
        if cur_frames == 0:
            # Keep producing one (all-padding) segment for empty inputs so the
            # stacking below still has something to work with.
            pad_frames = unit_frames

        if pad_frames > 0:
            # F.pad consumes (left, right) pairs starting from the last dim:
            # leave dim 3 alone and extend the end of dim 2.
            x = torch.nn.functional.pad(x, (0, 0, 0, pad_frames), mode="constant")

        embeddings = []
        with torch.no_grad():
            for i in range(x.shape[2] // unit_frames):
                x_inp = x[:, :, i * unit_frames : (i + 1) * unit_frames, :]
                embeddings.append(
                    self._get_segment_representation(x_inp, strategy=strategy)
                )

        if strategy == "raw":
            x = torch.hstack(embeddings)
            # Convert padded input frames into the number of output frames
            # they produced, then drop those from the end.
            pad_emb_frames = int(embeddings[0].shape[1] * pad_frames / unit_frames)
            if pad_emb_frames > 0:
                x = x[:, :-pad_emb_frames]
            return x

        return torch.stack(embeddings, dim=1)
|
|
|