"""Dasheng audio encoder — a ViT-style audio transformer with a log-mel
frontend — plus a subsampling audio projector, packaged in Hugging Face
``transformers`` style (``PretrainedConfig`` / ``PreTrainedModel``)."""

import collections
import collections.abc
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)

import torch
import torch.nn as nn
import torchaudio.functional as F
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention
from transformers import PretrainedConfig, PreTrainedModel

# A 2-axis size spec: either one int (used for both axes) or an explicit pair.
_Tuple2 = Union[int, Tuple[int, int], Sequence[int]]


def _resolve_tuple2(x: _Tuple2) -> Tuple[int, int]:
    """Normalize *x* to an explicit ``(a, b)`` tuple, duplicating a scalar.

    Raises:
        AssertionError: if *x* is a sequence whose length is not exactly 2.
    """
    if isinstance(x, collections.abc.Sequence):
        assert len(x) == 2, (
            f"Expected a sequence of length 2, got {x} with length {len(x)}"
        )
        return cast(Tuple[int, int], tuple(x))
    return (x, x)


class DashengConfig(PretrainedConfig):
    """Configuration for :class:`DashengAudioTransformer`.

    Groups the transformer hyper-parameters (``embed_dim`` … ``attn_drop_rate``)
    with the mel-spectrogram frontend settings (``f_min`` … ``n_mels``).
    """

    model_type = "midashenglm_dasheng_encoder"

    def __init__(
        self,
        embed_dim: int = 768,
        outputdim: int = 527,
        patch_size: Union[int, Tuple[int, int]] = 16,
        patch_stride: Union[int, Tuple[int, int]] = 16,
        input_channels: int = 1,
        target_length: int = 1012,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        init_values: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        f_min: float = 0.0,
        f_max: float = 8000.0,
        center: bool = True,
        win_length: int = 512,
        hop_length: int = 160,
        sample_rate: int = 16000,
        n_fft: int = 512,
        n_mels: int = 64,
        **kwargs,
    ):
        # Transformer hyper-parameters.
        self.embed_dim = embed_dim
        self.outputdim = outputdim
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.input_channels = input_channels
        self.target_length = target_length
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.init_values = init_values
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        # Mel-spectrogram frontend settings.
        self.f_min = f_min
        self.f_max = f_max
        self.center = center
        self.win_length = win_length
        self.hop_length = hop_length
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.n_mels = n_mels
        super().__init__(**kwargs)


class AudioPatchEmbed(nn.Module):
    """Conv2d patch embedding over a (freq, time) spectrogram.

    Args:
        input_size: spectrogram size as (n_mels, n_frames).
        patch_size / patch_stride: conv kernel and stride, int or (freq, time).
        in_chans: input channels (1 for a mono spectrogram).
        embed_dim: output embedding dimension.
        norm_layer: optional normalization constructor applied after projection.
        flatten: if True, flatten (freq, time) patches into a token sequence.
    """

    def __init__(
        self,
        input_size: _Tuple2 = 64,
        patch_size: _Tuple2 = 16,
        patch_stride: _Tuple2 = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten: bool = False,
    ):
        super().__init__()
        self.input_size = _resolve_tuple2(input_size)
        self.patch_size = _resolve_tuple2(patch_size)
        self.patch_stride = _resolve_tuple2(patch_stride)
        # Patch grid: (freq patches, time patches).
        self.grid_size = (
            self.input_size[0] // self.patch_stride[0],
            self.input_size[1] // self.patch_stride[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_stride,
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        if self.flatten:
            x = torch.permute(
                torch.flatten(x, 2, 3), (0, 2, 1)
            )  # rearrange(x, "b c f t -> b (f t) c")
        x = self.norm(x)
        return x


class LayerScale(nn.Module):
    """Per-channel learnable scaling (ViT LayerScale)."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # ``inplace`` mutates x to save memory; callers must not need the input.
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class DashengMlp(nn.Module):
    """Transformer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class DashengAttention(nn.Module):
    """Multi-head self-attention using fused scaled_dot_product_attention."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE: SDPA already applies 1/sqrt(head_dim) by default; kept for
        # interface compatibility.
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        # (B, N, 3C) -> 3 x (B, heads, N, head_dim)
        q, k, v = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )
        x = scaled_dot_product_attention(
            q,
            k,
            v,
            # Boolean key-padding mask, broadcast over heads and queries.
            attn_mask=mask[:, None, None, :] if mask is not None else None,
            # Fix: the configured attention-dropout rate was previously never
            # applied; route it through SDPA (active in training mode only).
            dropout_p=self.attn_drop.p if self.training else 0.0,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DashengBlock(nn.Module):
    """Pre-norm transformer block: attention + MLP, each with LayerScale."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = DashengAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = DashengMlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The mask is forwarded to attention as a key-padding mask.
        x = x + self.ls1(self.attn(self.norm1(x), mask))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x


class DashengFrontend(nn.Module):
    """Waveform -> log-mel spectrogram frontend.

    Buffers (window, mel filterbank) are non-persistent so they are rebuilt
    from config rather than stored in checkpoints.
    """

    def __init__(self, config: DashengConfig):
        super().__init__()
        self.config = config

        spectrogram_window = torch.hann_window(self.config.win_length)
        self.register_buffer(
            "spectrogram_window",
            spectrogram_window,
            persistent=False,
        )
        self.spectrogram_window: torch.Tensor

        melscale_fbanks = F.melscale_fbanks(
            n_freqs=self.config.n_fft // 2 + 1,
            f_min=self.config.f_min,
            f_max=self.config.f_max,
            n_mels=self.config.n_mels,
            sample_rate=self.config.sample_rate,
        )
        self.register_buffer("melscale_fbanks", melscale_fbanks, persistent=False)
        self.melscale_fbanks: torch.Tensor

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return a log-mel spectrogram shaped [batch, n_mels, time].

        Computed in float32 for numerical stability, then cast back to the
        input dtype.
        """
        spectrogram = F.spectrogram(
            waveform=waveform.to(torch.float32),
            pad=0,
            window=self.spectrogram_window,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.win_length,
            power=2,
            normalized=False,
            center=self.config.center,
        )
        # [B, freq, time] -> mel bins via the filterbank, back to [B, mel, time].
        mel_spectrogram = (spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
        # F.amplitude_to_DB accepts [freq, time], [channel, freq, time], or
        # [..., channel, freq, time]; insert a size-1 channel dim, then drop it.
        log_mel_spectrogram = F.amplitude_to_DB(
            mel_spectrogram.unsqueeze(1),
            multiplier=10,
            amin=1e-10,
            db_multiplier=0,
            top_db=120,
        ).squeeze(1)
        return log_mel_spectrogram.to(waveform.dtype)


class DashengAudioTransformer(PreTrainedModel):
    """Dasheng audio transformer: frontend + patch embed + transformer blocks.

    ``forward`` returns ``(features, mask)`` where features are shaped
    [batch, tokens, embed_dim] and mask (or None) marks valid time positions.
    """

    config_class = DashengConfig
    supports_gradient_checkpointing = True

    def __init__(self, config: DashengConfig):
        super().__init__(config)

        self.target_length = config.target_length
        self.embed_dim = config.embed_dim
        self.hop_length = config.hop_length
        self.gradient_checkpointing = False

        self.front_end = DashengFrontend(config)

        self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)

        self.patch_embed = AudioPatchEmbed(
            input_size=(config.n_mels, config.target_length),
            embed_dim=config.embed_dim,
            in_chans=config.input_channels,
            patch_size=config.patch_size,
            flatten=False,
            patch_stride=config.patch_stride,
        )
        # Factorized positional embeddings: one along time, one along frequency.
        self.time_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, 1, self.patch_embed.grid_size[1]) * 0.02
        )
        self.freq_pos_embed = nn.Parameter(
            torch.randn(1, config.embed_dim, self.patch_embed.grid_size[0], 1) * 0.02
        )
        self.pos_drop = nn.Dropout(p=config.drop_rate)
        self.blocks = nn.ModuleList(
            DashengBlock(
                dim=config.embed_dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                init_values=config.init_values,
                drop=config.drop_rate,
                attn_drop=config.attn_drop_rate,
            )
            for _ in range(config.depth)
        )
        self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
        self.post_init()

    def forward_features(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one patch chunk [B, C, F, T] through the transformer stack."""
        t = x.shape[-1]
        x = x + self.time_pos_embed[:, :, :, :t]
        x = (
            x + self.freq_pos_embed[:, :, :, :]
        )  # Just to support __getitem__ in posembed
        x = torch.permute(
            torch.flatten(x, 2, 3), (0, 2, 1)
        )  # rearrange(x, "b c f t -> b (f t) c")
        x = self.pos_drop(x)
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                x = self._gradient_checkpointing_func(block, x, mask)
            else:
                x = block(x, mask)
        x = self.norm(x)
        return x

    def _to_mask(self, lengths: torch.Tensor, max_length: int) -> torch.Tensor:
        """Boolean mask [B, max_length]: True where position < length."""
        idx = torch.arange(max_length, device=lengths.device)
        return (idx.unsqueeze(0) < lengths.unsqueeze(-1)).bool()

    def forward(
        self,
        x: torch.Tensor,
        x_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode waveforms [B, samples] (optionally with per-item lengths).

        Returns:
            features [B, tokens, embed_dim] and the time-patch mask (or None).
        """
        x = self.front_end(x)

        # Fix: the time patch stride was hard-coded as 4 (only valid for
        # patch_stride=[64, 4]); derive it from the patch embed so the chunking
        # and mask scaling stay consistent for any configured stride.
        time_stride = self.patch_embed.patch_stride[1]
        target_length_in_patches = self.target_length // time_stride

        # [B, mel, T] -> [B, 1, mel, T]; BatchNorm runs over the mel axis.
        x = x.unsqueeze(1)
        x = torch.permute(x, (0, 2, 1, 3))
        x = self.init_bn(x)
        x = torch.permute(x, (0, 2, 1, 3))

        x = self.patch_embed(x)
        t = x.shape[-1]
        # Process in fixed-size time chunks matching the trained target length.
        input_splits = x.split(target_length_in_patches, dim=-1)

        if x_length is not None:
            assert len(x_length) == len(x), (
                "batchsizes of input x and x_length need to be same"
            )
            assert x_length.ndim == 1, "Lengths are of size (B,)"
            # samples -> mel frames (/hop) -> time patches (/stride).
            scaled_lengths = (x_length / (self.hop_length * time_stride)).long()
            mask = self._to_mask(max_length=t, lengths=scaled_lengths)
            split_masks: Sequence[Optional[torch.Tensor]] = mask.split(
                target_length_in_patches, dim=-1
            )
        else:
            mask = None
            split_masks = [None] * len(input_splits)

        outputs = []
        for split_x, split_mask in zip(input_splits, split_masks):
            outputs.append(self.forward_features(split_x, mask=split_mask))
        x = torch.cat(outputs, dim=1)
        return x, mask


class AudioProjectorSubsample(nn.Module):
    """Project audio features to an LLM dimension, subsampling time by ``k``.

    Concatenates every ``k`` consecutive frames and maps them through a 2-layer
    MLP; trailing frames that do not fill a group of ``k`` are dropped.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        downsample_rate=5,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.k = downsample_rate
        self.out_dim = out_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim * self.k, out_dim, dtype=dtype),
            nn.GELU(),
            nn.Linear(out_dim, out_dim, dtype=dtype),
        )

    def forward(self, x, mask=None):
        batch_size, seq_len, dim = x.shape
        # Drop trailing frames so the sequence length is a multiple of k.
        num_frames_to_discard = seq_len % self.k
        if num_frames_to_discard > 0:
            x = x[:, :-num_frames_to_discard, :]
            if mask is not None:
                mask = mask[:, :-num_frames_to_discard]
        if mask is None:
            mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)

        x = x.reshape(
            batch_size, -1, self.k * dim
        )  # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
        x = self.net(x)

        mask = mask.reshape(
            batch_size, -1, self.k
        )  # rearrange(mask, "b (s k) -> b s k", k=self.k)
        # A subsampled position is valid if any of its k source frames was.
        mask = mask.any(dim=-1).long()
        return x, mask


# Deployed checkpoint configuration (Dasheng-1.2B encoder + Qwen-size projector).
config = {
    "audio_encoder_config": {
        "attn_drop_rate": 0.0,
        "center": True,
        "depth": 32,
        "drop_rate": 0.0,
        "embed_dim": 1280,
        "f_max": 8000.0,
        "f_min": 0.0,
        "hop_length": 160,
        "init_values": None,
        "input_channels": 1,
        "mlp_ratio": 4.0,
        "model_type": "midashenglm_dasheng_encoder",
        "n_fft": 512,
        "n_mels": 64,
        "num_heads": 16,
        "outputdim": 527,
        "patch_size": [64, 4],
        "patch_stride": [64, 4],
        "qkv_bias": True,
        "sample_rate": 16000,
        "target_length": 1008,
        "win_length": 512,
    },
    "audio_projector_config": {
        "in_dim": 1280,
        "downsample_rate": 5,
        "out_dim": 3584,
    },
}


def load_dasheng_encoder(ckpt_path, device="cuda"):
    """Build the encoder from the module-level config and load its weights."""
    audio_encoder_config = DashengConfig(**config["audio_encoder_config"])
    audio_encoder = DashengAudioTransformer(audio_encoder_config)
    # weights_only=True: the checkpoint is a plain state dict; avoid executing
    # arbitrary pickled code.
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    audio_encoder.load_state_dict(state_dict, strict=True)
    audio_encoder.eval()
    return audio_encoder.to(device)


def load_dasheng_proj(ckpt_path, device="cuda"):
    """Build the audio projector from the module-level config and load weights."""
    audio_projector = AudioProjectorSubsample(**config["audio_projector_config"])
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    audio_projector.load_state_dict(state_dict, strict=True)
    audio_projector.eval()
    return audio_projector.to(device)


if __name__ == "__main__":
    # Smoke test: load the encoder checkpoint and a dummy batch of audio.
    audio_encoder_config = DashengConfig(**config["audio_encoder_config"])
    audio_encoder = DashengAudioTransformer(audio_encoder_config)
    state_dict = torch.load(
        "/mnt/localssd/dasheng_lm/audio_encoder.pt", map_location="cpu"
    )
    audio_encoder.load_state_dict(state_dict, strict=True)
    audio = torch.randn(4, 16000 * 20)
    # NOTE(review): the script appears truncated here — the projector state
    # dict is loaded but never applied, and `audio` is never encoded.
    state_dict = torch.load(
        "/mnt/localssd/dasheng_lm/audio_projector.pt", map_location="cpu"
    )