| | from einops import rearrange |
| | from einops.layers.torch import Rearrange |
| | import torchaudio.transforms as audio_transforms |
| | import torch |
| | import torch.nn as nn |
| | from typing import Optional, Type |
| |
|
class FrontEnd(nn.Sequential):
    """Waveform-to-feature front end: mel spectrogram followed by log (dB) scaling.

    Built as an ``nn.Sequential`` of torchaudio ``MelSpectrogram`` and
    ``AmplitudeToDB``. Input is raw audio of shape (batch_size, num_samples);
    output features have shape (batch_size, n_mels, num_frames).
    """

    def __init__(
        self,
        f_min: int = 0,
        sample_rate: int = 16000,
        win_size: int = 512,
        center: bool = True,
        n_fft: int = 512,
        f_max: Optional[int] = 8000,
        hop_size: int = 160,
        n_mels: int = 64,
    ):
        """Configure the STFT and mel filterbank parameters.

        Args:
            f_min: Lowest mel filter frequency in Hz.
            sample_rate: Expected audio sample rate in Hz.
            win_size: STFT window length in samples.
            center: Whether each frame is centered on its hop position (torchaudio pads).
            n_fft: FFT size.
            f_max: Highest mel filter frequency in Hz.
            hop_size: Hop length in samples between consecutive frames.
            n_mels: Number of mel bins in the output features.
        """
        # Plain config attributes (not Parameters/Modules); recorded so other
        # code (e.g. DashengEncoder reads n_mels, forward reads hop_size) can use them.
        self.f_min = f_min
        self.sample_rate = sample_rate
        self.win_size = win_size
        self.center = center
        self.n_fft = n_fft
        self.f_max = f_max
        self.hop_size = hop_size
        self.n_mels = n_mels

        # Build the transforms under a CPU device context — presumably to
        # force filterbank/window buffers onto CPU regardless of any active
        # default device; TODO confirm intent.
        with torch.device("cpu"):
            super().__init__(
                audio_transforms.MelSpectrogram(
                    f_min=self.f_min,
                    sample_rate=self.sample_rate,
                    win_length=self.win_size,
                    center=self.center,
                    n_fft=self.n_fft,
                    f_max=self.f_max,
                    hop_length=self.hop_size,
                    n_mels=self.n_mels,
                ),
                audio_transforms.AmplitudeToDB(top_db=120),
            )

    # Disable CUDA autocast so the spectrogram is computed in full precision
    # even when the caller runs under mixed precision — presumably for
    # numerical stability of the STFT/log; TODO confirm.
    @torch.autocast(enabled=False, device_type="cuda")
    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """
        Forward pass of the frontend.

        Args:
            x: Audio tensor of shape (batch_size, num_samples)
            attention_mask: Optional attention mask of shape (batch_size, num_samples)

        Returns:
            features: Mel spectrogram features of shape (batch_size, n_mels, num_frames)
            attention_mask: Downsampled attention mask of shape (batch_size, num_frames)
        """
        features = super().forward(x)
        if attention_mask is not None:
            # Valid frame count ~= valid samples // hop_size.
            # NOTE(review): with center=True torchaudio yields
            # 1 + num_samples // hop_size frames, so this undercounts by one
            # and masks the final frame — confirm whether that is intended.
            lengths = attention_mask.float().sum(-1) // self.hop_size
            attention_mask = (torch.arange(features.shape[-1], device=features.device) < lengths.unsqueeze(-1)).int()
        return features, attention_mask
| |
|
| |
|
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear,
    with dropout applied after the activation and after the final linear.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Type[torch.nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        """Build the MLP.

        Args:
            in_features: Input feature dimension.
            hidden_features: Hidden width; falls back to ``in_features``.
            out_features: Output width; falls back to ``in_features``.
            act_layer: Activation module class, instantiated with no arguments.
            drop: Dropout probability shared by both dropout applications.
        """
        super().__init__()
        # Fall back to the input width when hidden/output widths are not given.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply fc1 -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention with optional causal masking and optional
    key-padding masking. Fully-masked rows produce zero attention weights
    (NaNs from the softmax are replaced with 0).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        causal: bool = False,
    ):
        """Build projections and dropout layers.

        Args:
            dim: Embedding dimension (must be divisible by ``num_heads``).
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused QKV projection has a bias.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the output projection.
            causal: If True, each position may only attend to itself and earlier positions.
        """
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.causal = causal

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Apply self-attention over a (batch, seq, dim) input.

        Args:
            x: Input of shape (batch, seq_len, dim).
            mask: Optional (batch, seq_len) key mask. A bool mask marks
                padded positions with True; any other dtype marks padding
                with zeros.
        """
        batch, seq_len, dim = x.shape
        head_dim = dim // self.num_heads

        # Fused QKV projection, split into per-head q/k/v of shape
        # (batch, heads, seq, head_dim).
        q, k, v = (
            self.qkv(x)
            .reshape(batch, seq_len, 3, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if self.causal:
            # Strict upper triangle = future positions; exclude them.
            future = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(future, float("-inf"))

        if mask is not None:
            # Normalize to a bool padding mask (True = padded key).
            padding = mask if mask.dtype == torch.bool else (mask == 0)
            scores = scores.masked_fill(padding.view(batch, 1, 1, seq_len), float("-inf"))

        # Rows where every key is masked softmax to NaN; zero them out.
        weights = scores.softmax(dim=-1).nan_to_num()
        weights = self.attn_drop(weights)

        out = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj_drop(self.proj(out))
| |
|
| |
|
class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm -> self-attention and
    LayerNorm -> MLP, each wrapped in a residual connection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        """Build the block's attention and MLP sub-layers.

        Args:
            dim: Embedding dimension.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
            qkv_bias: Whether the attention QKV projection has a bias.
            drop: Dropout probability for the MLP and the attention output projection.
            attn_drop: Dropout probability on the attention weights.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=nn.GELU,
            drop=drop,
        )

    def forward(self, x, mask=None):
        """Apply attention then MLP, each with a residual add."""
        attn_out = self.attn(self.norm1(x), mask=mask)
        x = x + attn_out
        mlp_out = self.mlp(self.norm2(x))
        return x + mlp_out
| |
|
| |
|
| | class AudioPatchEmbed(torch.nn.Module): |
| |
|
| | def __init__(self, *args, **kwargs) -> None: |
| | super().__init__() |
| | self.stride = kwargs.get('stride', [None, 4])[-1] |
| | self.proj = nn.Conv2d(*args, **kwargs) |
| |
|
| | def forward(self, x:torch.Tensor, attention_mask:torch.Tensor | None =None): |
| | x = self.proj(x) |
| | if attention_mask is not None: |
| | lengths = attention_mask.float().sum(-1) // self.stride |
| | attention_mask = (torch.arange(x.shape[-1], device=x.device) < lengths.unsqueeze(-1)).int() |
| | return x, attention_mask |
| |
|
| |
|
| |
|
| |
|
class DashengEncoder(nn.Module):
    """Audio transformer encoder: raw waveform -> sequence of token embeddings.

    Pipeline: mel front end -> per-mel-bin batch norm -> conv patch embedding
    -> additive time/frequency positional embeddings -> transformer blocks
    -> final LayerNorm. Inputs longer than ``target_length`` mel frames are
    processed in chunks and the token sequences are concatenated.
    """

    def __init__(
        self,
        embed_dim: int = 1280,
        depth: int = 32,
        num_heads: int = 20,
        patch_size=[64, 4],  # NOTE(review): mutable default — shared across instances if mutated
        patch_stride=[64, 4],
        target_length=1008,
    ):
        """Build the encoder.

        Args:
            embed_dim: Transformer embedding dimension.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            patch_size: Conv patch kernel as [freq, time]; the freq size of 64
                matches the front end's n_mels, so each patch spans all mel bins.
            patch_stride: Conv patch stride as [freq, time].
            target_length: Maximum number of mel frames processed per chunk.
        """
        super().__init__()
        self.embed_dim = embed_dim
        # Time-axis stride of the patch conv (frames per token).
        self.time_patches = patch_stride[-1]
        self.front_end = FrontEnd()
        self.target_length = target_length
        # Maximum number of time tokens per chunk.
        self.max_t_tokens = target_length // patch_stride[-1]
        self.patch_embed = AudioPatchEmbed(1, embed_dim, kernel_size=patch_size, stride=patch_stride)
        # BatchNorm over the mel-bin axis: swap channel and freq dims so
        # BatchNorm2d normalizes each of the n_mels bins, then swap back.
        self.init_bn = nn.Sequential(
            Rearrange("b c f t -> b f c t"),
            torch.nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
            Rearrange("b f c t -> b c f t"),
        )

        # Learned positional embeddings: one per time token (sliced to the
        # actual length in _forward_main) plus a single broadcast freq term.
        self.time_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, target_length // self.time_patches) * 0.02)
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)

        self.blocks = nn.ModuleList([Block(embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def _forward_main(self, x, attention_mask, mask_to_zero: bool = False):
        """Run patch embedding and the transformer stack on one chunk.

        Args:
            x: Normalized spectrogram chunk of shape (batch, 1, n_mels, frames).
            attention_mask: Optional (batch, frames) mask for this chunk, or None.
            mask_to_zero: When True, zero the output embeddings at padded positions.

        Returns:
            Token embeddings of shape (batch, num_tokens, embed_dim).
        """
        x, attention_mask = self.patch_embed(x, attention_mask)
        t = x.shape[-1]
        # Slice the learned time positions to the actual token count;
        # freq_pos_embed broadcasts over all positions.
        x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
        # Flatten (freq, time) patches into a token sequence.
        # NOTE(review): the mask from patch_embed is per time frame; it lines
        # up with the token count only when the freq axis collapses to 1
        # (patch kernel spanning all mel bins) — confirm.
        x = rearrange(x, "b c f t -> b (f t) c")
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        x = self.norm(x)
        if attention_mask is not None and mask_to_zero:
            x = x * attention_mask.unsqueeze(-1)
        return x

    def forward(self, x: torch.Tensor, attention_mask=None):
        """
        Forward pass of the AudioTransformer.

        Args:
            x: Audio tensor of shape (batch_size, num_samples)
            attention_mask: Optional attention mask of shape (batch_size, num_samples)
                where True indicates valid samples and False indicates padding

        Returns:
            embeddings: Token embeddings of shape (batch_size, num_tokens, embed_dim)
        """
        # Waveform -> log-mel features; the mask is downsampled to frame rate.
        x, attention_mask = self.front_end(x, attention_mask)

        # Add a channel axis and normalize each mel bin.
        x = rearrange(x, "b f t -> b 1 f t")
        x = self.init_bn(x)

        # Split long inputs into chunks of at most target_length frames; the
        # frame-rate mask splits along the same axis, so chunks stay aligned.
        input_splits = x.split(self.target_length, dim=-1)
        masks = [None for _ in range(len(input_splits))]
        if attention_mask is not None:
            masks = attention_mask.split(self.target_length, dim=-1)

        # Encode each chunk independently and concatenate the token sequences.
        # NOTE(review): mask_to_zero is False only for the first chunk, so
        # padded tokens in chunk 0 are never zeroed — confirm intended.
        outputs = []
        for i, (input_split_x, mask) in enumerate(zip(input_splits, masks)):
            output = self._forward_main(input_split_x, attention_mask=mask, mask_to_zero=i != 0)
            outputs.append(output)
        x = torch.cat(outputs, dim=1)
        return x
| |
|