from .configuration_dasheng_tokenizer import DashengTokenizerConfig
from .modeling_dasheng_encoder import DashengEncoder
from .vocos import VocosModel
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from einops import rearrange
import torchaudio
from transformers import PreTrainedModel
class VocosMelSpec(torch.nn.Module):
"""MelSpectrogram frontend for Vocos."""
def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
super().__init__()
if padding not in ["center", "same"]:
raise ValueError("Padding must be 'center' or 'same'.")
self.padding = padding
self.sample_rate = sample_rate
self.n_fft = n_fft
self.hop_length = hop_length
self.n_mels = n_mels
with torch.device("cpu"):
self.mel_spec = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
center=self.padding == "center",
power=1,
)
def forward(self, audio, **kwargs):
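# "same" padding: reflect-pad half a window overhang on each side so the
# resulting frame count is approximately num_samples / hop_length.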
if self.padding == "same":
pad = self.mel_spec.win_length - self.mel_spec.hop_length
audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
mel = self.mel_spec(audio)
return torch.log(torch.clip(mel, min=1e-7))
class DashengTokenizerEncoder(torch.nn.Module):
def __init__(
self,
embed_dim: int = 1280,
depth: int = 32,
num_heads: int = 16,
n_mels_patch: int = 128,
hop_length: int = 160,
**kwargs,
):
super().__init__()
self.model = DashengEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads)
self.embed_dim = int(self.model.embed_dim)
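# Replace the encoder's output layer with identity so the raw transformer
# features are returned instead of a projected output.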
self.model.outputlayer = torch.nn.Identity()
self.front_end = VocosMelSpec(hop_length=hop_length, n_mels=n_mels_patch)
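# Acoustic branch: a Conv2d patch embedding that collapses all n_mels_patch
# mel bins and groups 4 spectrogram frames into one token.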
self.patch_embed = torch.nn.Conv2d(
1, self.model.embed_dim, (n_mels_patch, 4), (n_mels_patch, 4)
)
self.norm = torch.nn.LayerNorm(self.model.embed_dim)
# Store the mel frontend parameters for reference
self.n_fft = self.front_end.n_fft
self.hop_size = self.front_end.hop_length
@torch.no_grad()
def forward(
self,
input: torch.Tensor,
input_attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Forward pass of the encoder.
Args:
input: Audio tensor of shape (batch_size, num_samples)
input_attn_mask: Optional attention mask
Returns:
Combined embeddings of shape (batch_size, num_tokens, embed_dim)
"""
semantic_emb = self.model(input, input_attn_mask)
# Acoustic branch: log-mel features patched into token embeddings.
mel = self.front_end(input).unsqueeze(1)
mel_emb = self.patch_embed(mel)
acoustic_emb = rearrange(mel_emb, "b c f t -> b (f t) c")
acoustic_emb = self.norm(acoustic_emb)
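# Truncate the semantic sequence to the acoustic token count, then fuse the
# two branches by summation.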
semantic_emb = semantic_emb[:, : acoustic_emb.shape[1], :]
emb = semantic_emb + acoustic_emb
return emb
class DashengTokenizerPreTrainedModel(PreTrainedModel):
config_class = DashengTokenizerConfig
supports_gradient_checkpointing = True
class DashengTokenizerModel(DashengTokenizerPreTrainedModel):
"""
HuggingFace-compatible Dasheng Tokenizer Model (encoder + decoder).
This model includes both the encoder and decoder for end-to-end audio processing.
"""
def __init__(self, config: DashengTokenizerConfig):
super().__init__(config)
self.config = config
self.encoder = DashengTokenizerEncoder(
embed_dim=config.embed_dim,
depth=config.depth,
num_heads=config.num_heads,
n_mels_patch=config.n_mels_patch,
hop_length=config.hop_length,
)
self.embed_dim = self.encoder.embed_dim
# Optional temporal upsampler: expands the token sequence by a factor of
# config.upsample_tokens before decoding.
self.upsampler = None
if config.upsample_tokens > 1:
self.upsampler = torch.nn.ConvTranspose1d(
self.embed_dim, self.embed_dim,
kernel_size=config.upsample_tokens,
stride=config.upsample_tokens
)
# Decoder
self.decoder = VocosModel(
input_channels=self.embed_dim,
hidden_dim=config.decoder_embed_dim,
intermediate_dim=config.decoder_intermediate_size,
vocos_istft_hop=config.istft_hop,
vocos_n_fft=config.istft_n_fft,
num_layers=config.decoder_depth,
)
self.post_init()
def encode(
self,
audio: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Encode audio into embeddings."""
return self.encoder(audio, attention_mask)
def decode(self, embeddings: torch.Tensor) -> torch.Tensor:
"""Decode embeddings back to audio."""
if self.upsampler is not None:
embeddings = self.upsampler(embeddings.transpose(-2, -1)).transpose(-2, -1)
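# The decoder expects channels-first input (batch, embed_dim, num_tokens).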
output = self.decoder(embeddings.transpose(-2, -1))
return output
def forward(
self,
audio: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], dict]:
"""
Forward pass of the Dasheng tokenizer.
Args:
audio: Audio tensor of shape (batch_size, num_samples)
attention_mask: Optional attention mask
return_dict: Whether to return a dict instead of a tuple
Returns:
Reconstructed audio of shape (batch_size, num_samples); when a dict is
returned it also contains the encoder embeddings
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode
embeddings = self.encoder(audio, attention_mask)
# Decode
audio_reconstructed = self.decode(embeddings)
if not return_dict:
return (audio_reconstructed,)
return {
"audio": audio_reconstructed,
"embeddings": embeddings,
}
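# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). Assumptions: the default
# DashengTokenizerConfig() yields a buildable model and the input is 16 kHz
# mono audio; for real checkpoints use DashengTokenizerModel.from_pretrained().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DashengTokenizerConfig()
    model = DashengTokenizerModel(config).eval()
    wav = torch.randn(1, 16000)  # one second of random 16 kHz audio
    with torch.no_grad():
        outputs = model(wav, return_dict=True)
    print(outputs["embeddings"].shape)  # (batch, num_tokens, embed_dim)
    print(outputs["audio"].shape)       # (batch, num_reconstructed_samples)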