|
|
import collections.abc |
|
|
import math |
|
|
import sys |
|
|
from itertools import repeat |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
import timm |
|
|
import torch |
|
|
from torch import nn |
|
|
from torchvision.models.vision_transformer import Encoder |
|
|
|
|
|
|
|
|
from typing import Tuple |
|
|
from functools import partial |
|
|
from collections.abc import Iterable |
|
|
|
|
|
|
|
|
def plot_fbank(fbank, title=None, save_path=None, **kwargs):
    """Plot up to the first four channels of a filter bank as heat maps.

    Args:
        fbank: Array-like indexed as ``fbank[channel]``. Each channel is
            transposed before display and the y-axis is labelled "mel", so it
            presumably holds (time, mel) frames -- TODO confirm with callers.
        title: Suffix appended to every subplot title.
        save_path: When truthy, the figure is also written to this path.
        **kwargs: Only ``vmin`` and ``vmax`` are read. Both are forwarded to
            ``imshow``, so panels can share one colour scale.

    Returns:
        The matplotlib ``Figure``. It has already been closed with
        ``plt.close(fig)``.

    Raises:
        ValueError: If ``fbank`` has no channels to plot.
    """
    n_panels = min(4, fbank.shape[0])
    if n_panels < 1:
        raise ValueError("fbank must contain at least one channel to plot")

    # squeeze=False always yields a 2-D array of Axes, even for one panel,
    # so a lone Axes object never needs special-casing.
    fig, axs = plt.subplots(n_panels, 1, sharex=True, sharey=True, squeeze=False)
    axs = axs[:, 0]

    vmin, vmax = kwargs.get("vmin"), kwargs.get("vmax")

    for channel, ax in enumerate(axs):
        ax.set_title(f"Filter bank channel {channel}, {title}")
        im = ax.imshow(fbank[channel].T, aspect="auto", vmin=vmin, vmax=vmax)
        ax.set_ylabel("mel")
        ax.set_xlabel("time")

    # invert_yaxis() flips whatever orientation is current, and the y-limits
    # are shared across panels, so flip exactly once after all images exist.
    axs[-1].invert_yaxis()

    fig.tight_layout()
    fig.colorbar(im, ax=list(axs))

    # Save before show(): with interactive backends show() may block and tear
    # down the canvas, after which savefig() can write an empty image.
    if save_path:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)
    return fig
|
|
|
|
|
|
|
|
|
|
|
def _ntuple(n): |
|
|
def parse(x): |
|
|
|
|
|
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): |
|
|
return tuple(x) |
|
|
|
|
|
return tuple(repeat(x, n)) |
|
|
|
|
|
return parse |
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and project each one.

    A single ``nn.Conv2d`` whose kernel size equals its stride both tiles the
    input into ``patch_size`` patches and linearly maps every patch to
    ``embed_dim`` channels.

    Args:
        img_size: Nominal input size; an int means ``(n, n)``. Used only to
            compute ``num_patches``.
        patch_size: Patch (and stride) size; an int means ``(n, n)``.
        in_chans: Number of input channels.
        embed_dim: Embedding width produced for every patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        to_pair = _ntuple(2)
        self.img_size = to_pair(img_size)
        self.patch_size = to_pair(patch_size)
        # Patch grid size along each spatial axis (floor division: any partial
        # patch at the border is not counted).
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.num_patches = cols * rows
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        """Turn an image batch into a sequence of patch embeddings.

        Assumes ``x`` is (batch, in_chans, H, W) -- TODO confirm with callers.
        The result is (batch, number_of_patches, embed_dim).
        """
        feature_map = self.proj(x)
        tokens = feature_map.flatten(start_dim=2)
        return tokens.transpose(1, 2)
|
|
|
|
|
|
|
|
def get_sinusoid_encoding(n_position, d_hid):
    """Build the fixed sinusoidal position-encoding table.

    Entry ``(p, j)`` starts as ``p / 10000 ** (2 * (j // 2) / d_hid)``. Even
    columns are then passed through ``sin`` and odd columns through ``cos``.

    Args:
        n_position: Number of positions (table rows). May be 0.
        d_hid: Encoding width (table columns).

    Returns:
        A float32 tensor of shape ``(1, n_position, d_hid)``.
    """
    # Vectorised over the whole grid instead of n_position * d_hid scalar
    # np.power calls from Python. A 2-D table is always produced, so
    # n_position == 0 no longer breaks the column slicing below.
    positions = np.arange(n_position, dtype=np.float64)[:, np.newaxis]
    exponents = 2 * (np.arange(d_hid) // 2) / d_hid
    table = positions / np.power(10000.0, exponents)[np.newaxis, :]

    table[:, 0::2] = np.sin(table[:, 0::2])
    table[:, 1::2] = np.cos(table[:, 1::2])

    return torch.as_tensor(table, dtype=torch.float32).unsqueeze(0)
|
|
|
|
|
|
|
|
def create_pretrained_model(model_size,
                            encoder_num_layers=12,
                            encoder_num_heads=12,
                            encoder_hidden_dim=768,
                            encoder_mlp_dim=3072,
                            encoder_dropout=0.0,
                            encoder_attention_dropout=0.0,
                            encoder_norm_layer_eps=1e-6):
    """Build a transformer backbone and report its embedding width.

    ``"tiny"``, ``"small"`` and ``"base_nokd"`` create timm DeiT architectures
    with ``pretrained=False``, so no pretrained weights are loaded here.
    ``"base"`` builds a torchvision ``Encoder`` from the ``encoder_*``
    arguments, which are ignored for every other size.

    Args:
        model_size: One of ``"tiny"``, ``"small"``, ``"base"``, ``"base_nokd"``.
        encoder_num_layers: Number of encoder layers (``"base"`` only).
        encoder_num_heads: Attention heads per layer (``"base"`` only).
        encoder_hidden_dim: Token width (``"base"`` only); also returned.
        encoder_mlp_dim: MLP hidden width (``"base"`` only).
        encoder_dropout: Dropout rate (``"base"`` only).
        encoder_attention_dropout: Attention dropout rate (``"base"`` only).
        encoder_norm_layer_eps: LayerNorm epsilon (``"base"`` only).

    Returns:
        ``(model, hidden_dim)``, where ``hidden_dim`` is the model's token
        embedding width.

    Raises:
        ValueError: If ``model_size`` is not one of the supported names.
    """
    if model_size == "tiny":
        v = timm.create_model("deit_tiny_distilled_patch16_224", pretrained=False)
        # DeiT-Tiny's embedding width is 192 (this was previously 182).
        hidden_dim = 192

    elif model_size == "small":
        v = timm.create_model("deit_small_distilled_patch16_224", pretrained=False)
        hidden_dim = 384

    elif model_size == "base":
        # NOTE(review): seq_length=0 -- presumably the caller supplies its own
        # positional embedding / sequence handling; confirm before relying on
        # Encoder.forward directly.
        v = Encoder(
            seq_length=0,
            num_layers=encoder_num_layers,
            num_heads=encoder_num_heads,
            hidden_dim=encoder_hidden_dim,
            mlp_dim=encoder_mlp_dim,
            dropout=encoder_dropout,
            attention_dropout=encoder_attention_dropout,
            norm_layer=partial(nn.LayerNorm, eps=encoder_norm_layer_eps),
        )
        hidden_dim = encoder_hidden_dim

    elif model_size == "base_nokd":
        v = timm.create_model("deit_base_patch16_384", pretrained=False)
        hidden_dim = 768

    else:
        # Raise instead of sys.exit(0): exiting with status 0 reported success
        # on a configuration error and killed the whole interpreter.
        raise ValueError(
            f"Wrong model size! Got {model_size!r}; expected one of "
            "'tiny', 'small', 'base', 'base_nokd'."
        )

    return v, hidden_dim
|
|
|
|
|
|
|
|
def _trunc_normal_(tensor, mean, std, a, b): |
|
|
|
|
|
|
|
|
def norm_cdf(x): |
|
|
|
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
left = norm_cdf((a - mean) / std) |
|
|
up = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
|
|
|
|
tensor.uniform_(2 * left - 1, 2 * up - 1) |
|
|
|
|
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.0)) |
|
|
tensor.add_(mean) |
|
|
|
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
|
return tensor |
|
|
|
|
|
|
|
|
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    r"""Fill ``tensor`` in place with a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted
    to :math:`[a, b]`. They are produced by inverse-CDF sampling, not by
    redrawing out-of-range values: uniform draws between the CDF values of
    the bounds are mapped back through the inverse CDF, then clamped to
    :math:`[a, b]`. Results are best when :math:`a \leq \text{mean} \leq b`.

    NOTE: similar to PyTorch's ``trunc_normal_``, ``a`` and ``b`` bound the
    final values (after mean/std are applied), not the standardized normal,
    so choose them in the same units as ``mean`` and ``std``.

    Args:
        tensor: an n-dimensional ``torch.Tensor`` to overwrite.
        mean: the mean of the normal distribution.
        std: the standard deviation of the normal distribution.
        a: the minimum cutoff value.
        b: the maximum cutoff value.

    Returns:
        ``tensor``, modified in place.

    Example:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    # Disable autograd so the in-place fill also works on tensors that
    # require grad (e.g. module parameters).
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled
|
|
|
|
|
|
|
|
def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Broadcast a per-position index across the feature axis of ``tokens``.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length); each entry is a
            position in [0, sequence_length).
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim). Only
            its last dimension size is read.

    Returns:
        Tensor with shape (batch_size, idx_length, dim) in which every index
        value is repeated ``dim`` times along the last axis. It is an
        ``expand`` view of ``index``, so no new memory is allocated.
    """
    feature_dim = tokens.size(-1)
    return index[..., None].expand(-1, -1, feature_dim)
|
|
|
|
|
def set_at_index(
    tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """Return a copy of ``tokens`` with the rows at ``index`` replaced by ``value``.

    The scatter is out-of-place: ``tokens`` itself is left unmodified.

    Args:
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length) holding
            positions along the sequence axis.
        value:
            Value tensor with shape (batch_size, index_length, dim).

    Returns:
        New tensor with shape (batch_size, sequence_length, dim) in which
        ``result[b, index[b, i]] == value[b, i]``.
    """
    scatter_index = expand_index_like(index, tokens)
    return tokens.scatter(1, scatter_index, value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Tile a single token into a full (batch, sequence) grid.

    Args:
        token:
            Token tensor with shape (1, 1, dim).
        size:
            ``(batch_size, sequence_length)`` tuple.

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) holding copies
        of ``token``.
    """
    n_batch, n_seq = size
    # repeat() materialises real copies (unlike expand()), so writing to the
    # result never aliases ``token``.
    return token.repeat(n_batch, n_seq, 1)