# dasheng-denoiser / modeling_dasheng_encoder.py
# Uploaded via huggingface_hub by richermans (commit e23a1f0, verified).
from einops import rearrange
from einops.layers.torch import Rearrange
import torchaudio.transforms as audio_transforms
import torch
import torch.nn as nn
from typing import Optional, Type
class FrontEnd(nn.Sequential):
    """Audio front end: raw waveform -> log-mel spectrogram in dB.

    Implemented as an ``nn.Sequential`` of ``MelSpectrogram`` followed by
    ``AmplitudeToDB(top_db=120)``.
    """

    def __init__(
        self,
        f_min: int = 0,
        sample_rate: int = 16000,
        win_size: int = 512,
        center: bool = True,
        n_fft: int = 512,
        f_max: Optional[int] = 8000,
        hop_size: int = 160,
        n_mels: int = 64,
    ):
        # Attributes are stored so forward() can downsample the attention mask
        # (hop_size) and so other modules can read the config (e.g. n_mels is
        # used by the encoder's BatchNorm).
        self.f_min = f_min
        self.sample_rate = sample_rate
        self.win_size = win_size
        self.center = center
        self.n_fft = n_fft
        self.f_max = f_max
        self.hop_size = hop_size
        self.n_mels = n_mels
        # Force construction on CPU (torch>=2.0 device context manager) so the
        # precomputed mel filterbanks are created there regardless of any
        # default-device override.
        with torch.device("cpu"):
            super().__init__(
                audio_transforms.MelSpectrogram(
                    f_min=self.f_min,
                    sample_rate=self.sample_rate,
                    win_length=self.win_size,
                    center=self.center,
                    n_fft=self.n_fft,
                    f_max=self.f_max,
                    hop_length=self.hop_size,
                    n_mels=self.n_mels,
                ),
                audio_transforms.AmplitudeToDB(top_db=120),
            )

    # Spectrogram extraction is forced to full precision under CUDA autocast.
    @torch.autocast(enabled=False, device_type="cuda")
    def forward(self, x, attention_mask=None):
        """
        Forward pass of the frontend.
        Args:
            x: Audio tensor of shape (batch_size, num_samples)
            attention_mask: Optional attention mask of shape (batch_size, num_samples)
        Returns:
            features: Mel spectrogram features of shape (batch_size, n_mels, num_frames)
            attention_mask: Downsampled attention mask of shape (batch_size, num_frames)
        """
        features = super().forward(x)
        if attention_mask is not None:
            # Downsample the sample-level mask to frame resolution.
            # NOTE(review): with center=True the STFT produces
            # num_samples // hop + 1 frames, so this floor division may mark
            # the final frame as padding — confirm this is intended.
            lengths = attention_mask.float().sum(-1) // self.hop_size
            attention_mask = (torch.arange(features.shape[-1], device=features.device) < lengths.unsqueeze(-1)).int()
        return features, attention_mask
class Mlp(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    Dropout is applied after the activation and again after the second
    projection. Hidden and output widths fall back to ``in_features`` when
    not given (or given as a falsy value, matching the original ``or``
    semantics).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Type[torch.nn.Module] = nn.GELU,
        drop: float = 0.0,
    ):
        super().__init__()
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional causal and padding masks.

    NOTE(review): the padding-mask convention is mixed. An integer/float mask
    treats nonzero entries as *valid* (zeros are masked out), while a boolean
    mask is applied directly, i.e. ``True`` entries are masked out. This file
    only ever passes integer masks; confirm what boolean callers expect.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        causal: bool = False,
    ):
        super().__init__()
        self.num_heads = num_heads
        # 1 / sqrt(head_dim) scaling for the dot-product scores.
        self.scale = (dim // num_heads) ** -0.5
        self.causal = causal
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        batch, seq_len, dim = x.shape
        head_dim = dim // self.num_heads
        # One projection, then split into per-head q/k/v:
        # [3, B, heads, N, head_dim].
        projected = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Causal mask: forbid attention to future positions.
        if self.causal:
            future = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(future, float("-inf"))

        # Padding mask, broadcast as (B, 1, 1, N) over heads and queries.
        if mask is not None:
            blocked = mask if mask.dtype == torch.bool else mask == 0
            scores = scores.masked_fill(blocked.view(batch, 1, 1, seq_len), float("-inf"))

        # Fully-masked rows softmax to NaN; replace those with zeros.
        weights = self.attn_drop(scores.softmax(dim=-1).nan_to_num())

        out = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, dim)
        return self.proj_drop(self.proj(out))
class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm -> Attention and
    LayerNorm -> MLP, each wrapped in a residual connection."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=nn.GELU,
            drop=drop,
        )

    def forward(self, x, mask=None):
        # Residual around attention, then residual around the MLP.
        x = x + self.attn(self.norm1(x), mask=mask)
        return x + self.mlp(self.norm2(x))
class AudioPatchEmbed(torch.nn.Module):
    """2-D patch embedding over (freq, time) spectrogram patches.

    Wraps ``nn.Conv2d``; additionally downsamples a frame-level attention
    mask by the conv's time stride so the mask tracks the conv output length.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        # Time-axis stride, used to downsample the attention mask. Conv2d's
        # ``stride`` may be an int or a (freq, time) tuple/list; the previous
        # code did ``kwargs.get('stride', [None, 4])[-1]`` and crashed with a
        # TypeError on a plain-int stride.
        # NOTE(review): when ``stride`` is omitted or passed positionally this
        # falls back to 4 (the historical default here) even though Conv2d
        # itself defaults to stride 1 — confirm callers always pass it.
        stride = kwargs.get("stride", 4)
        if isinstance(stride, (tuple, list)):
            stride = stride[-1]
        self.stride = stride
        self.proj = nn.Conv2d(*args, **kwargs)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: Input of shape (batch, channels, freq, time).
            attention_mask: Optional mask of shape (batch, time); nonzero
                entries mark valid frames.
        Returns:
            (embedded, attention_mask): the conv output, plus the mask
            downsampled to the conv output's time length (None if no mask
            was given).
        """
        x = self.proj(x)
        if attention_mask is not None:
            # Number of valid tokens after striding along time.
            lengths = attention_mask.float().sum(-1) // self.stride
            attention_mask = (torch.arange(x.shape[-1], device=x.device) < lengths.unsqueeze(-1)).int()
        return x, attention_mask
class DashengBlocks(torch.nn.Module):
    """A stack of pre-norm transformer ``Block``s plus a final LayerNorm.

    Args:
        embedding_dims (int): Hidden dimension of the model.
        num_heads (int): Number of heads in Attention layers.
        num_layers (int): Number of Attention layers in Block.
    """

    def __init__(
        self,
        embedding_dims: int,
        num_heads: int,
        num_layers: int,
    ):
        super().__init__()
        self.blocks = nn.Sequential(*[
            Block(
                dim=embedding_dims,
                num_heads=num_heads,
                mlp_ratio=2.0,
                qkv_bias=True,
                drop=0.0,
                attn_drop=0.0,
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(embedding_dims)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # Xavier-uniform linear weights; zero biases; unit LayerNorm scale.
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        x: (B, T, D) — the final ``LayerNorm(embedding_dims)`` normalizes the
           last dimension before the transpose, so the embedding axis must be
           last on input (the old docstring's "(B, D, T)" was incorrect).
        return: (B, D, T) — transposed after the final norm.

        Note: ``kwargs`` (e.g. an attention mask) are accepted but ignored;
        the blocks run through ``nn.Sequential`` with no mask.
        """
        x = self.blocks(x)
        x = self.norm(x)
        return x.transpose(-2, -1)
class DashengEncoder(nn.Module):
    """Dasheng audio transformer encoder.

    Pipeline: waveform -> log-mel frontend -> BatchNorm over mel bins ->
    conv patch embedding -> time/freq positional embeddings -> transformer
    blocks -> LayerNorm. Inputs longer than ``target_length`` frames are
    processed in chunks and the token outputs concatenated along time.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        patch_size=(64, 4),    # (freq, time) conv kernel; tuple avoids a mutable list default
        patch_stride=(64, 4),  # (freq, time) conv stride
        target_length=1008,    # max frames per chunk (= positional-embedding span)
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Time-axis patch stride, i.e. frames per token.
        self.time_patches = patch_stride[-1]
        self.front_end = FrontEnd()
        self.target_length = target_length
        self.max_t_tokens = target_length // patch_stride[-1]
        self.patch_embed = AudioPatchEmbed(1, embed_dim, kernel_size=patch_size, stride=patch_stride)
        # BatchNorm over the mel-frequency axis: move mels into the channel
        # slot, normalize, move them back.
        self.init_bn = nn.Sequential(
            Rearrange("b c f t -> b f c t"),
            torch.nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
            Rearrange("b f c t -> b c f t"),
        )
        # Learned positional embeddings: one per time token, one shared
        # embedding broadcast over the frequency axis.
        self.time_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, target_length // self.time_patches) * 0.02)
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)
        self.blocks = nn.ModuleList([Block(embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def _forward_main(self, x, attention_mask, mask_to_zero: bool = False):
        """Run one chunk of spectrogram patches through the transformer.

        Args:
            x: Spectrogram chunk of shape (batch, 1, n_mels, frames).
            attention_mask: Optional frame-level mask (batch, frames) or None.
            mask_to_zero: When True, zero out the embeddings of padded tokens.
        Returns:
            Token embeddings of shape (batch, num_tokens, embed_dim).
        """
        x, attention_mask = self.patch_embed(x, attention_mask)
        t = x.shape[-1]
        # Slice the time positional embedding to the actual token count.
        x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
        x = rearrange(x, "b c f t -> b (f t) c")
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        x = self.norm(x)
        if attention_mask is not None and mask_to_zero:
            # Zero out all samples that were masked, but only after first chunk
            x = x * attention_mask.unsqueeze(-1)
        return x

    def forward(self, x: torch.Tensor, attention_mask=None):
        """
        Forward pass of the AudioTransformer.
        Args:
            x: Audio tensor of shape (batch_size, num_samples)
            attention_mask: Optional attention mask of shape (batch_size, num_samples)
                where True indicates valid samples and False indicates padding
        Returns:
            embeddings: Token embeddings of shape (batch_size, num_tokens, embed_dim)
        """
        # Frontend returns log-mel features and a frame-level mask.
        x, attention_mask = self.front_end(x, attention_mask)
        # (b f t) -> (b 1 f t) for the 2-D conv patch embedding.
        x = rearrange(x, "b f t -> b 1 f t")
        x = self.init_bn(x)
        # Chunk along time; the frame-level mask is chunked identically.
        input_splits = x.split(self.target_length, dim=-1)
        masks = [None] * len(input_splits)
        if attention_mask is not None:
            masks = attention_mask.split(self.target_length, dim=-1)
        outputs = []
        for i, (input_split_x, mask) in enumerate(zip(input_splits, masks)):
            # Padded tokens are zeroed only for chunks after the first.
            output = self._forward_main(input_split_x, attention_mask=mask, mask_to_zero=i != 0)
            outputs.append(output)
        x = torch.cat(outputs, dim=1)
        return x