|
|
import collections |
|
|
import collections.abc |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio.functional as F |
|
|
from torch import Tensor |
|
|
from torch.nn.functional import scaled_dot_product_attention |
|
|
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union, cast |
|
|
|
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
|
|
_Tuple2 = Union[int, Tuple[int, int], Sequence[int]] |
|
|
|
|
|
|
|
|
def _resolve_tuple2(x: _Tuple2) -> Tuple[int, int]: |
|
|
if isinstance(x, collections.abc.Sequence): |
|
|
assert len(x) == 2, ( |
|
|
f"Expected a sequence of length 2, got {x} with length {len(x)}" |
|
|
) |
|
|
return cast(Tuple[int, int], tuple(x)) |
|
|
return (x, x) |
|
|
|
|
|
|
|
|
|
|
|
class DashengConfig(PretrainedConfig):
    """HF-style configuration for the Dasheng audio encoder.

    Bundles the mel-spectrogram front-end settings (sample_rate, n_fft,
    win_length, hop_length, n_mels, f_min/f_max, center) with the ViT-style
    transformer settings (embed_dim, depth, num_heads, mlp_ratio, patch
    geometry, dropout rates).
    """

    model_type = "midashenglm_dasheng_encoder"

    def __init__(
        self,
        embed_dim: int = 768,
        outputdim: int = 527,
        patch_size: Union[int, Tuple[int, int]] = 16,
        patch_stride: Union[int, Tuple[int, int]] = 16,
        input_channels: int = 1,
        target_length: int = 1012,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        init_values: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        f_min: float = 0.0,
        f_max: float = 8000.0,
        center: bool = True,
        win_length: int = 512,
        hop_length: int = 160,
        sample_rate: int = 16000,
        n_fft: int = 512,
        n_mels: int = 64,
        **kwargs,
    ):
        # Record every encoder hyperparameter on the instance, then let
        # PretrainedConfig consume the remaining HF-generic kwargs.
        encoder_params = dict(
            embed_dim=embed_dim,
            outputdim=outputdim,
            patch_size=patch_size,
            patch_stride=patch_stride,
            input_channels=input_channels,
            target_length=target_length,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            init_values=init_values,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            f_min=f_min,
            f_max=f_max,
            center=center,
            win_length=win_length,
            hop_length=hop_length,
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
        )
        for name, value in encoder_params.items():
            setattr(self, name, value)
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
class AudioPatchEmbed(nn.Module):
    """Turn a (B, C, freq, time) spectrogram into patch embeddings via a strided Conv2d."""

    def __init__(
        self,
        input_size: _Tuple2 = 64,
        patch_size: _Tuple2 = 16,
        patch_stride: _Tuple2 = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = False,
    ):
        super().__init__()
        self.input_size = _resolve_tuple2(input_size)
        self.patch_size = _resolve_tuple2(patch_size)
        self.patch_stride = _resolve_tuple2(patch_stride)
        # The stride (not the patch size) determines how many patches fit.
        freq_patches = self.input_size[0] // self.patch_stride[0]
        time_patches = self.input_size[1] // self.patch_stride[1]
        self.grid_size = (freq_patches, time_patches)
        self.num_patches = freq_patches * time_patches
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_stride,
        )
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches; if `flatten`, collapse the spatial grid to (B, N, C)."""
        embedded = self.proj(x)
        if self.flatten:
            # (B, C, H, W) -> (B, C, H*W) -> (B, N, C)
            embedded = embedded.flatten(2).permute(0, 2, 1)
        return self.norm(embedded)
|
|
|
|
|
|
|
|
class LayerScale(nn.Module):
    """Learnable per-channel scaling of the residual branch (optionally in-place)."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # One learnable scale per channel, all starting at init_values.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
|
|
|
|
|
|
|
class DashengMlp(nn.Module):
    """ViT-style feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
|
|
|
class DashengAttention(nn.Module):
    """Multi-head self-attention built on torch's fused scaled_dot_product_attention.

    Args:
        dim: Total embedding dimension, split evenly across heads.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the fused QKV projection has a bias term.
        attn_drop: Dropout probability on the attention weights
            (applied only in training mode).
        proj_drop: Dropout probability after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Kept for interface compatibility; SDPA applies the same
        # 1/sqrt(head_dim) scaling internally by default.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Apply self-attention.

        Args:
            x: Input of shape (B, N, C).
            mask: Optional boolean key mask of shape (B, N); True marks
                positions that may be attended to.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        # (B, N, 3C) -> three tensors of shape (B, num_heads, N, head_dim).
        q, k, v = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        x = scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=mask[:, None, None, :] if mask is not None else None,
            # Fix: the configured attention dropout was never applied on the
            # SDPA path; pass it through (active only in training mode).
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
|
|
|
|
|
|
|
class DashengBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sublayers, each with a
    residual connection and optional LayerScale."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = DashengAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = DashengMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
        )
        # LayerScale is enabled only when init_values is truthy
        # (None or 0 both disable it).
        if init_values:
            self.ls1 = LayerScale(dim, init_values=init_values)
            self.ls2 = LayerScale(dim, init_values=init_values)
        else:
            self.ls1 = nn.Identity()
            self.ls2 = nn.Identity()

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply attention then MLP, each as a scaled residual update."""
        attn_out = self.attn(self.norm1(x), mask)
        x = x + self.ls1(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.ls2(mlp_out)
|
|
|
|
|
|
|
|
class DashengFrontend(nn.Module):
    """Waveform -> log-mel-spectrogram front end.

    The Hann window and mel filterbank are registered as non-persistent
    buffers so they move with the module across devices but are not stored
    in checkpoints.
    """

    def __init__(self, config: DashengConfig):
        super().__init__()
        self.config = config

        # STFT analysis window.
        self.register_buffer(
            "spectrogram_window",
            torch.hann_window(config.win_length),
            persistent=False,
        )
        self.spectrogram_window: torch.Tensor

        # Triangular filterbank mapping linear frequency bins to mel bins.
        self.register_buffer(
            "melscale_fbanks",
            F.melscale_fbanks(
                n_freqs=config.n_fft // 2 + 1,
                f_min=config.f_min,
                f_max=config.f_max,
                n_mels=config.n_mels,
                sample_rate=config.sample_rate,
            ),
            persistent=False,
        )
        self.melscale_fbanks: torch.Tensor

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return log-mel features; computed in float32, cast back to the input dtype."""
        power_spec = F.spectrogram(
            waveform=waveform.to(torch.float32),
            pad=0,
            window=self.spectrogram_window,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.win_length,
            power=2,
            normalized=False,
            center=self.config.center,
        )
        # Apply the mel filterbank along the frequency axis via matmul on
        # the transposed (.mT) spectrogram.
        mel_spec = (power_spec.mT @ self.melscale_fbanks.to(torch.float32)).mT
        log_mel = F.amplitude_to_DB(
            mel_spec.unsqueeze(1),
            multiplier=10,
            amin=1e-10,
            db_multiplier=0,
            top_db=120,
        ).squeeze(1)
        return log_mel.to(waveform.dtype)
|
|
|
|
|
|
|
|
class DashengAudioTransformer(PreTrainedModel):
    """Dasheng audio encoder: waveform -> log-mel -> patch embedding -> ViT blocks.

    Long inputs are processed in fixed-size chunks along the time axis (see
    ``forward``), so audio longer than ``target_length`` patches is supported.
    """

    config_class = DashengConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: DashengConfig):
        super().__init__(config)

        self.target_length = config.target_length
        self.embed_dim = config.embed_dim
        self.hop_length = config.hop_length
        self.gradient_checkpointing = False

        # Waveform -> log-mel spectrogram.
        self.front_end = DashengFrontend(config)

        # BatchNorm over the mel-bin axis; forward() permutes the input so
        # n_mels sits in the channel slot before this is applied.
        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)

        self.patch_embed = AudioPatchEmbed(
            input_size=(config.n_mels, config.target_length),
            embed_dim=config.embed_dim,
            in_chans=config.input_channels,
            patch_size=config.patch_size,
            flatten=False,
            patch_stride=config.patch_stride,
        )

        # Separable positional embeddings: one along time, one along frequency.
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
        )

        self.pos_drop = nn.Dropout(p=config.drop_rate)
        self.blocks = nn.ModuleList(
            DashengBlock(
                dim=config.embed_dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                init_values=config.init_values,
                drop=config.drop_rate,
                attn_drop=config.attn_drop_rate,
            )
            for _ in range(config.depth)
        )
        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

        # HF hook: weight init and final setup.
        self.post_init()

    def forward_features(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one chunk of shape (B, C, F, T') through pos-embeds + transformer.

        Returns a (B, F*T', C) sequence of layer-normalized token embeddings.
        """
        t = x.shape[-1]
        # Truncate the time positional embedding to this chunk's actual length.
        x = x + self.time_pos_embed[:, :, :, :t]
        x = (
            x + self.freq_pos_embed[:, :, :, :]
        )
        # (B, C, F, T') -> (B, F*T', C): flatten the spatial grid into tokens.
        x = torch.permute(
            torch.flatten(x, 2, 3), (0, 2, 1)
        )
        x = self.pos_drop(x)
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Trade recompute for activation memory during training.
                x = self._gradient_checkpointing_func(block, x, mask)
            else:
                x = block(x, mask)
        x = self.norm(x)
        return x

    def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
        """Convert per-sample lengths (B,) into a (B, max_length) boolean mask.

        True marks valid (non-padding) positions.
        """
        batch_size = len(lengths)
        idx = torch.arange(max_length, device=lengths.device)
        idx = idx.repeat(batch_size).view(batch_size, max_length)
        mask = (idx < lengths.unsqueeze(-1)).bool()
        return mask

    def forward(
        self,
        x: torch.Tensor,
        x_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode a batch of raw waveforms.

        Args:
            x: Waveform batch of shape (B, num_samples).
            x_length: Optional per-sample lengths in samples, shape (B,).

        Returns:
            Tuple of (token embeddings of shape (B, num_patches, embed_dim),
            boolean patch-validity mask, or None when x_length is omitted).
        """
        x = self.front_end(x)
        # NOTE(review): the // 4 here (and * 4 below) assumes a time-axis
        # patch stride of 4; matches the shipped config — confirm if changed.
        target_length_in_patches = self.target_length // 4
        x = x.unsqueeze(1)
        # Move mel bins into the channel slot for BatchNorm2d, then back.
        x = torch.permute(x, (0, 2, 1, 3))
        x = self.init_bn(x)
        x = torch.permute(x, (0, 2, 1, 3))

        x = self.patch_embed(x)
        t = x.shape[-1]

        # Chunk along time so each piece fits the positional embedding table.
        input_splits = x.split(target_length_in_patches, dim=-1)

        if x_length is not None:
            assert len(x_length) == len(x), (
                "batchsizes of input x and x_length need to be same"
            )
            assert x_length.ndim == 1, "Lengths are of size (B,)"
            # Samples -> patches: hop_length samples per frame, 4 frames per patch.
            scaled_lengths = (x_length / (self.hop_length * 4)).long()
            mask = self._to_mask(max_length=t, lengths=scaled_lengths)
            split_masks = mask.split(target_length_in_patches, dim=-1)
        else:
            mask = None
            split_masks = [None] * len(input_splits)

        outputs = []

        # Encode each chunk independently, then concatenate along time.
        for split_x, split_mask in zip(input_splits, split_masks):
            forward_kwargs = {}
            forward_kwargs["mask"] = split_mask
            split_x = self.forward_features(split_x, **forward_kwargs)
            outputs.append(split_x)
        x = torch.cat(outputs, dim=1)
        return x, mask
|
|
|
|
|
|
|
|
class AudioProjectorSubsample(nn.Module):
    """Downsample audio features in time by stacking k consecutive frames,
    then project them with a two-layer GELU MLP."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        downsample_rate=5,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.k = downsample_rate
        self.out_dim = out_dim
        # k frames are concatenated feature-wise before the first projection.
        self.net = nn.Sequential(
            nn.Linear(in_dim * self.k, out_dim, dtype=dtype),
            nn.GELU(),
            nn.Linear(out_dim, out_dim, dtype=dtype),
        )

    def forward(self, x, mask=None):
        """Map (B, T, D) -> (B, T // k, out_dim); the mask is pooled with `any` per group."""
        batch_size, seq_len, feat_dim = x.shape
        # Drop trailing frames so the sequence length is a multiple of k.
        remainder = seq_len % self.k
        if remainder > 0:
            x = x[:, :-remainder, :]
            if mask is not None:
                mask = mask[:, :-remainder]
        if mask is None:
            mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)
        # Fold each group of k frames into the feature axis and project.
        stacked = x.reshape(batch_size, -1, self.k * feat_dim)
        projected = self.net(stacked)
        # A downsampled position is valid if any of its k source frames was.
        grouped = mask.reshape(batch_size, -1, self.k)
        pooled_mask = grouped.any(dim=-1).long()
        return projected, pooled_mask
|
|
|
|
|
|
|
|
# Hyperparameters for the encoder and projector built by the loader helpers
# below; consumed as DashengConfig(**...) and AudioProjectorSubsample(**...).
config = {
    "audio_encoder_config": {
        "attn_drop_rate": 0.0,
        "center": True,
        "depth": 32,  # number of DashengBlock layers
        "drop_rate": 0.0,
        "embed_dim": 1280,
        "f_max": 8000.0,
        "f_min": 0.0,
        "hop_length": 160,  # 10 ms frames at 16 kHz
        "init_values": None,  # LayerScale disabled
        "input_channels": 1,
        "mlp_ratio": 4.0,
        "model_type": "midashenglm_dasheng_encoder",
        "n_fft": 512,
        "n_mels": 64,
        "num_heads": 16,
        "outputdim": 527,
        # Patches span all 64 mel bins and 4 time frames.
        "patch_size": [
            64,
            4
        ],
        "patch_stride": [
            64,
            4
        ],
        "qkv_bias": True,
        "sample_rate": 16000,
        "target_length": 1008,
        "win_length": 512
    },

    "audio_projector_config": {
        "in_dim": 1280,  # matches the encoder's embed_dim
        "downsample_rate": 5,
        "out_dim": 3584,
    }
}
|
|
|
|
|
|
|
|
def load_dasheng_encoder(ckpt_path, device='cuda'):
    """Build a DashengAudioTransformer from the module-level config and load weights.

    Args:
        ckpt_path: Path to a state-dict checkpoint saved with torch.save.
        device: Device to move the model to (default 'cuda').

    Returns:
        The encoder in eval mode on `device`.
    """
    encoder = DashengAudioTransformer(
        DashengConfig(**config["audio_encoder_config"])
    )
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True on newer torch).
    weights = torch.load(ckpt_path, map_location="cpu")
    encoder.load_state_dict(weights, strict=True)
    encoder.eval()
    return encoder.to(device)
|
|
|
|
|
|
|
|
def load_dasheng_proj(ckpt_path, device='cuda'):
    """Build an AudioProjectorSubsample from the module-level config and load weights.

    Args:
        ckpt_path: Path to a state-dict checkpoint saved with torch.save.
        device: Device to move the projector to (default 'cuda').

    Returns:
        The projector in eval mode on `device`.
    """
    projector = AudioProjectorSubsample(**config["audio_projector_config"])
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True on newer torch).
    weights = torch.load(ckpt_path, map_location="cpu")
    projector.load_state_dict(weights, strict=True)
    projector.eval()
    return projector.to(device)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: build the encoder from the shipped config and load weights
    # from a hard-coded local checkpoint path.
    audio_encoder_config = DashengConfig(**config["audio_encoder_config"])
    audio_encoder = DashengAudioTransformer(audio_encoder_config)

    state_dict = torch.load(
        "/mnt/localssd/dasheng_lm/audio_encoder.pt",
        map_location="cpu")

    audio_encoder.load_state_dict(state_dict, strict=True)

    # Dummy batch: 4 waveforms of 20 s at 16 kHz.
    # NOTE(review): `audio` is created but never fed through the encoder —
    # this smoke test appears unfinished.
    audio = torch.randn(4, 16000*20)

    # NOTE(review): the projector state dict is loaded but never applied to
    # an AudioProjectorSubsample instance here.
    state_dict = torch.load(
        "/mnt/localssd/dasheng_lm/audio_projector.pt",
        map_location="cpu")
|