# dashengtokenizer / modeling_dasheng_tokenizer.py
# NOTE: the following lines are Hugging Face Hub page artifacts kept for
# provenance ("richermans's picture", "Upload folder using huggingface_hub",
# commit b0489bd verified); commented out so the module is importable.
from .configuration_dasheng_tokenizer import DashengTokenizerConfig
from .modeling_dasheng_encoder import DashengEncoder
from .vocos import VocosModel
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from einops import rearrange
import torchaudio
from transformers import PreTrainedModel
class VocosMelSpec(torch.nn.Module):
    """Log-mel spectrogram frontend matching the Vocos feature extractor.

    Computes a magnitude (power=1) mel spectrogram and returns its natural
    logarithm, clamped from below to avoid -inf on silent frames.

    Args:
        sample_rate: Expected audio sample rate in Hz.
        n_fft: FFT size of the STFT.
        hop_length: Hop size between STFT frames, in samples.
        n_mels: Number of mel filterbank channels.
        padding: Either "center" (torchaudio centered frames) or "same"
            (manual reflect padding so frame count tracks num_samples / hop).
    """

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        # NOTE(review): transform is deliberately constructed on CPU —
        # presumably to sidestep a non-CPU default-device context during
        # model init; confirm before removing.
        with torch.device("cpu"):
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                center=(padding == "center"),
                power=1,
            )

    def forward(self, audio, **kwargs):
        """Return log-mel features for `audio` of shape (batch, num_samples)."""
        if self.padding == "same":
            # Reflect-pad half the frame overlap on each side so "same"
            # framing yields one frame per hop.
            total_pad = self.mel_spec.win_length - self.mel_spec.hop_length
            half = total_pad // 2
            audio = torch.nn.functional.pad(audio, (half, half), mode="reflect")
        spectrogram = self.mel_spec(audio)
        # Clamp before the log so all-zero input stays finite.
        return torch.log(torch.clip(spectrogram, min=1e-7))
class DashengTokenizerEncoder(torch.nn.Module):
    """Audio encoder combining a semantic and an acoustic embedding branch.

    The semantic branch is a DashengEncoder transformer (output projection
    replaced with Identity so raw hidden states are used). The acoustic
    branch is a Vocos-style log-mel frontend followed by a Conv2d patch
    embedding and LayerNorm. The two branches are summed token-wise.

    Args:
        embed_dim: Transformer embedding dimension.
        depth: Number of transformer layers.
        num_heads: Number of attention heads.
        n_mels_patch: Mel bins of the acoustic frontend; also the frequency
            extent of each Conv2d patch (so frequency collapses to one token).
        hop_length: Hop size of the acoustic mel frontend, in samples.
        **kwargs: Ignored; accepted for config forward-compatibility.
    """

    def __init__(
        self,
        embed_dim: int = 1280,
        depth: int = 32,
        num_heads: int = 16,
        n_mels_patch: int = 128,
        hop_length: int = 160,
        **kwargs,
    ):
        super().__init__()
        self.model = DashengEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads)
        self.embed_dim = int(self.model.embed_dim)
        # Drop the encoder's output projection; downstream consumes raw
        # hidden states.
        self.model.outputlayer = torch.nn.Identity()
        self.front_end = VocosMelSpec(hop_length=hop_length, n_mels=n_mels_patch)
        # Each patch spans all mel bins and 4 frames -> one acoustic token
        # per 4 frames.
        self.patch_embed = torch.nn.Conv2d(
            1, self.model.embed_dim, (n_mels_patch, 4), (n_mels_patch, 4)
        )
        self.norm = torch.nn.LayerNorm(self.model.embed_dim)
        # Store parameters for reference.
        # NOTE(review): these read the *inner* DashengEncoder's frontend
        # (self.model.front_end), not the VocosMelSpec created above —
        # confirm that frontend exposes `n_fft` and `hop_size`.
        self.n_fft = self.model.front_end.n_fft
        self.hop_size = self.model.front_end.hop_size

    @torch.no_grad()
    def forward(
        self,
        input: torch.Tensor,
        input_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the encoder.

        Args:
            input: Audio tensor of shape (batch_size, num_samples)
            input_attn_mask: Optional attention mask

        Returns:
            Combined embeddings of shape (batch_size, num_tokens, embed_dim)
        """
        # Semantic branch. (The @torch.no_grad() decorator already disables
        # autograd for the whole method; the previous nested no_grad context
        # was redundant and has been removed.)
        semantic_emb = self.model(input, input_attn_mask)

        # Acoustic branch: log-mel -> patch embed -> flatten time/freq ->
        # LayerNorm.
        mel = self.front_end(input).unsqueeze(1)
        mel_emb = self.patch_embed(mel)
        acoustic_emb = rearrange(mel_emb, "b c f t -> b (f t) c")
        acoustic_emb = self.norm(acoustic_emb)

        # Trim the semantic sequence to the acoustic token count so the two
        # branches can be summed element-wise.
        semantic_emb = semantic_emb[:, : acoustic_emb.shape[1], :]
        return semantic_emb + acoustic_emb
class DashengTokenizerPreTrainedModel(PreTrainedModel):
    """Base class tying DashengTokenizerConfig into the HF PreTrainedModel
    machinery (weight init, from_pretrained, gradient checkpointing)."""

    supports_gradient_checkpointing = True
    config_class = DashengTokenizerConfig
class DashengTokenizerModel(DashengTokenizerPreTrainedModel):
    """
    HuggingFace-compatible DashEng Tokenizer Model (Encoder + Decoder).

    This model includes both the encoder and decoder for end-to-end audio
    processing: audio -> embeddings -> (optional temporal upsampling) ->
    reconstructed audio via a Vocos decoder.
    """

    def __init__(self, config: DashengTokenizerConfig):
        super().__init__(config)
        self.config = config
        self.encoder = DashengTokenizerEncoder(
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            n_mels_patch=config.n_mels_patch,
            hop_length=config.hop_length,
        )
        self.embed_dim = self.encoder.embed_dim

        # Optional temporal upsampler: expands each encoder token into
        # `config.upsample_tokens` tokens before decoding.
        self.upsampler = None
        if config.upsample_tokens > 1:
            self.upsampler = torch.nn.ConvTranspose1d(
                self.embed_dim,
                self.embed_dim,
                kernel_size=config.upsample_tokens,
                stride=config.upsample_tokens,
            )

        # Vocos decoder mapping embedding sequences back to waveform.
        self.decoder = VocosModel(
            input_channels=self.embed_dim,
            hidden_dim=config.decoder_embed_dim,
            intermediate_dim=config.decoder_intermediate_size,
            vocos_istft_hop=config.istft_hop,
            vocos_n_fft=config.istft_n_fft,
            num_layers=config.decoder_depth,
        )
        self.post_init()

    def encode(
        self,
        audio: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode audio of shape (batch, num_samples) into embeddings of
        shape (batch, num_tokens, embed_dim)."""
        return self.encoder(audio, attention_mask)

    def decode(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Decode embeddings of shape (batch, num_tokens, embed_dim) back
        to audio.

        Transposes to channel-first once up front; the previous version
        transposed back and forth around the upsampler for no effect.
        """
        features = embeddings.transpose(-2, -1)  # (batch, embed_dim, tokens)
        if self.upsampler is not None:
            features = self.upsampler(features)
        return self.decoder(features)

    def forward(
        self,
        audio: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], dict]:
        """
        Forward pass of the DashEng tokenizer.

        Args:
            audio: Audio tensor of shape (batch_size, num_samples)
            attention_mask: Optional attention mask
            return_dict: Whether to return a dict; defaults to
                `config.use_return_dict`

        Returns:
            If `return_dict`, a dict with keys "audio" (reconstructed
            waveform) and "embeddings"; otherwise a 1-tuple with the
            reconstructed audio.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Encode then decode.
        embeddings = self.encoder(audio, attention_mask)
        audio_reconstructed = self.decode(embeddings)
        if not return_dict:
            return (audio_reconstructed,)
        return {
            "audio": audio_reconstructed,
            "embeddings": embeddings,
        }