# Copyright (c) (Mddct: Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from einops import rearrange

from s3tokenizer.model import Conv1d, LayerNorm, Linear, MultiHeadAttention
from s3tokenizer.utils import (make_non_pad_mask, mask_to_bias, onnx2torch,
                               merge_tokenized_segments)


@dataclass
class ModelConfig:
    n_mels: int = 128
    n_audio_ctx: int = 1500
    n_audio_state: int = 1280
    n_audio_head: int = 20
    n_audio_layer: int = 6
    n_codebook_size: int = 3**8
    use_sdpa: bool = False


def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0,
                         scaling=None):
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    if scaling is not None:
        t = t * scaling
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return torch.cat((freqs_cis, freqs_cis), dim=-1)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    real = torch.view_as_real(freqs_cis)
    cos, sin = real[:, :, 0], real[:, :, 1]
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    D = xq.shape[-1]
    half_l, half_r = xq[:, :, :, :D // 2], xq[:, :, :, D // 2:]
    xq_r = torch.cat((-half_r, half_l), dim=-1)

    D = xk.shape[-1]
    half_l, half_r = xk[:, :, :, :D // 2], xk[:, :, :, D // 2:]
    xk_r = torch.cat((-half_r, half_l), dim=-1)

    return xq * cos + xq_r * sin, xk * cos + xk_r * sin


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [
        d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)
    ]
    return freqs_cis.view(*shape)


class FSQCodebook(torch.nn.Module):

    def __init__(self, dim: int, level: int = 3):
        super().__init__()
        self.project_down = torch.nn.Linear(dim, 8)
        self.level = level
        self.embed = None

    @torch.inference_mode()
    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        x = rearrange(x, "... d -> (...) d")
        return x

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x_shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        h = self.project_down(x).float()
        h = h.tanh()
        h = h * 0.9990000128746033
        h = h.round() + 1
        # h = ((self.level - 1) * h).round()  # range [-k, k]
        powers = torch.pow(
            self.level,
            torch.arange(2**self.level, device=x.device, dtype=h.dtype))
        mu = torch.sum(h * powers.unsqueeze(0), dim=-1)
        ind = mu.reshape(x_shape[0], x_shape[1]).int()
        return ind

    @torch.inference_mode()
    def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError(
            'There is no official up project component provided')
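
# Illustrative sketch (hypothetical helper, not part of the released model):
# how an FSQ index is formed by ``FSQCodebook.encode``. Each 8-dim projected
# vector is squashed into (-1, 1), rounded to one of ``level`` integers per
# dimension, and read out as a base-3 number, giving an index in [0, 3**8).
# The function name and input values below are made up for the example.
def _fsq_index_example() -> torch.Tensor:
    level = 3
    h = torch.tensor([[0.7, -0.9, 0.1, 0.4, -0.2, 0.99, -0.6, 0.0]])
    h = (h.tanh() * 0.9990000128746033).round() + 1  # per-dim values in {0, 1, 2}
    powers = torch.pow(level, torch.arange(8, dtype=h.dtype))
    index = torch.sum(h * powers, dim=-1)  # base-3 readout, range [0, 6561)
    return index.int()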

class FSQVectorQuantization(torch.nn.Module):
    """Vector quantization implementation (inference-only).

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
    ):
        super().__init__()
        assert 3**8 == codebook_size
        self._codebook = FSQCodebook(dim=dim, level=3)
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self._codebook.encode(x)

    @torch.inference_mode()
    def decode(self, embed_ind: torch.Tensor) -> torch.Tensor:
        quantize = self._codebook.decode(embed_ind)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize


class FSMNMultiHeadAttention(MultiHeadAttention):

    def __init__(
        self,
        n_state: int,
        n_head: int,
        kernel_size: int = 31,
        use_sdpa: bool = False,
    ):
        super().__init__(n_state, n_head)

        self.fsmn_block = torch.nn.Conv1d(n_state,
                                          n_state,
                                          kernel_size,
                                          stride=1,
                                          padding=0,
                                          groups=n_state,
                                          bias=False)
        self.left_padding = (kernel_size - 1) // 2
        self.right_padding = kernel_size - 1 - self.left_padding
        self.pad_fn = torch.nn.ConstantPad1d(
            (self.left_padding, self.right_padding), 0.0)
        self.use_sdpa = use_sdpa

    def forward_fsmn(self,
                     inputs: torch.Tensor,
                     mask: Optional[torch.Tensor] = None):
        b, t, _, _ = inputs.size()
        inputs = inputs.view(b, t, -1)
        if mask is not None and mask.size(2) > 0:  # time2 > 0
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        return x * mask

    def qkv_attention(self,
                      q: torch.Tensor,
                      k: torch.Tensor,
                      v: torch.Tensor,
                      mask: Optional[torch.Tensor] = None,
                      mask_pad: Optional[torch.Tensor] = None,
                      freqs_cis: Optional[torch.Tensor] = None):
        _, _, D = q.shape
        scale = (D // self.n_head)**-0.25
        q = q.view(*q.shape[:2], self.n_head, -1)
        k = k.view(*k.shape[:2], self.n_head, -1)
        v = v.view(*v.shape[:2], self.n_head, -1)

        if freqs_cis is not None:
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        fsm_memory = self.forward_fsmn(v, mask_pad)

        q = q.permute(0, 2, 1, 3) * scale
        v = v.permute(0, 2, 1, 3)
        if not self.use_sdpa:
            k = k.permute(0, 2, 3, 1) * scale
            qk = q @ k  # (B, n_head, T, T)
            if mask is not None:
                qk = qk + mask
            qk = qk.float()
            w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
            return (w @ v).permute(
                0, 2, 1, 3).flatten(start_dim=2), qk.detach(), fsm_memory
        else:
            k = k.permute(0, 2, 1, 3) * scale
            assert mask is not None
            output = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=mask,
                dropout_p=0.,
                scale=1.,
            )
            output = (output.transpose(1, 2).contiguous().view(
                q.size(0), -1, D))  # (batch, time1, d_model)
            return output, None, fsm_memory

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                mask_pad: Optional[torch.Tensor] = None,
                freqs_cis: Optional[torch.Tensor] = None):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        wv, qk, fsm_memory = self.qkv_attention(q, k, v, mask, mask_pad,
                                                freqs_cis)
        return self.out(wv) + fsm_memory, qk
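
# Illustrative sketch (hypothetical helper, not used by the model): the
# asymmetric ConstantPad1d in FSMNMultiHeadAttention adds (k - 1) // 2 frames
# on the left and the remainder on the right, so the depthwise Conv1d with
# padding=0 returns a sequence of the original length and the FSMN memory can
# be added residually to its input.
def _fsmn_length_check(t: int = 100, kernel_size: int = 31) -> bool:
    left = (kernel_size - 1) // 2
    right = kernel_size - 1 - left
    x = torch.zeros(1, 8, t)  # (batch, channels, time); channel count arbitrary
    conv = torch.nn.Conv1d(8, 8, kernel_size, padding=0, groups=8, bias=False)
    y = conv(torch.nn.ConstantPad1d((left, right), 0.0)(x))
    return y.size(-1) == t  # padded length t + k - 1 shrinks back to t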

class ResidualAttentionBlock(torch.nn.Module):

    def __init__(
        self,
        n_state: int,
        n_head: int,
        kernel_size: int = 31,
        use_sdpa: bool = False,
    ):
        super().__init__()

        self.attn = FSMNMultiHeadAttention(n_state,
                                           n_head,
                                           kernel_size,
                                           use_sdpa=use_sdpa)
        self.attn_ln = LayerNorm(n_state, eps=1e-6)

        n_mlp = n_state * 4
        self.mlp = torch.nn.Sequential(Linear(n_state, n_mlp), torch.nn.GELU(),
                                       Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        mask_pad: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ):
        x = x + self.attn(self.attn_ln(x),
                          mask=mask,
                          mask_pad=mask_pad,
                          freqs_cis=freqs_cis)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class AudioEncoderV2(torch.nn.Module):

    def __init__(
        self,
        n_mels: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        stride: int,
        use_sdpa: bool,
    ):
        super().__init__()
        self.stride = stride
        self.conv1 = Conv1d(n_mels,
                            n_state,
                            kernel_size=3,
                            stride=stride,
                            padding=1)
        self.conv2 = Conv1d(n_state,
                            n_state,
                            kernel_size=3,
                            stride=2,
                            padding=1)
        self.freqs_cis = precompute_freqs_cis(64, 1024 * 2)
        self.blocks = torch.nn.ModuleList([
            ResidualAttentionBlock(n_state, n_head, use_sdpa=use_sdpa)
            for _ in range(n_layer)
        ])

    def forward(self, x: torch.Tensor,
                x_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        x : torch.Tensor, shape = (batch_size, n_mels, T)
            the mel spectrogram of the audio
        x_len: torch.Tensor, shape = (batch_size,)
            length of each audio in x
        """
        mask = make_non_pad_mask(x_len).unsqueeze(1)
        x = torch.nn.functional.gelu(self.conv1(x * mask))
        x_len = (x_len + 2 - 1 * (3 - 1) - 1) // self.stride + 1
        mask = make_non_pad_mask(x_len).unsqueeze(1)
        x = torch.nn.functional.gelu(self.conv2(x * mask))
        x_len = (x_len + 2 - 1 * (3 - 1) - 1) // 2 + 1
        mask = make_non_pad_mask(x_len).unsqueeze(1)

        x = x.permute(0, 2, 1)  # (B, T // (stride * 2), n_state)
        freqs_cis = self.freqs_cis.to(x.device)
        mask_pad = mask.transpose(1, 2)
        mask = mask_to_bias(mask, x.dtype)

        tmp = torch.view_as_real(freqs_cis)
        cos, sin = tmp[:, :, 0], tmp[:, :, 1]
        cos = torch.cat((cos, cos), dim=-1)
        sin = torch.cat((sin, sin), dim=-1)
        cos = cos.unsqueeze(0).unsqueeze(2)
        sin = sin.unsqueeze(0).unsqueeze(2)

        for block in self.blocks:
            x = block(x, mask.unsqueeze(1), mask_pad, freqs_cis[:x.size(1)])
        return x, x_len
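
# Illustrative sketch (hypothetical helper, not part of the model): the two
# stride-2 convolutions in AudioEncoderV2 shrink T mel frames (100 frames per
# second at 16 kHz with hop_length=160) to roughly T // 4 encoder frames, i.e.
# a 25 Hz token rate. The recursion below mirrors the length formula used in
# AudioEncoderV2.forward.
def _encoder_output_length(t: int) -> int:
    t = (t + 2 - 1 * (3 - 1) - 1) // 2 + 1  # after conv1 (stride 2)
    t = (t + 2 - 1 * (3 - 1) - 1) // 2 + 1  # after conv2 (stride 2)
    return t  # e.g. 3000 mel frames (30 s) -> 750 tokens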

class S3TokenizerV2(torch.nn.Module):
    """S3 tokenizer v2 implementation (inference-only).

    Args:
        config (ModelConfig): Config
    """

    def __init__(self, name: str, config: ModelConfig = ModelConfig()):
        super().__init__()
        self.name = name  # Store model name for token_rate determination
        if 'v1' not in name:
            assert 'v2' in name
            # TODO(Mddct): make it configurable
            config.n_codebook_size = 3**8
        self.config = config
        self.encoder = AudioEncoderV2(
            self.config.n_mels,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
            2,
            self.config.use_sdpa,
        )
        self.quantizer = FSQVectorQuantization(
            self.config.n_audio_state,
            self.config.n_codebook_size,
        )

    def forward(self, mel: torch.Tensor,
                mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.quantize(mel, mel_len)

    @torch.inference_mode()
    def quantize(self, mel: torch.Tensor,
                 mel_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize mel spectrogram to tokens, with automatic long audio handling.

        Args:
            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
            mel_len: mel length tensor, shape (batch_size,)

        Returns:
            code: quantized tokens, shape (batch_size, T')
            code_len: token length, shape (batch_size,)
        """
        # Check if any audio in the batch exceeds 30 seconds
        # Assuming 16kHz sample rate and hop_length=160,
        # 30s = 30 * 16000 / 160 = 3000 frames
        max_frames = 3000

        # Check which samples are long audio
        long_audio_mask = mel_len > max_frames

        if long_audio_mask.any():
            # Has long audio - need special processing
            return self._quantize_mixed_batch(mel, mel_len, long_audio_mask,
                                              max_frames)
        else:
            # All short audio - use original method
            hidden, code_len = self.encoder(mel, mel_len)
            code = self.quantizer.encode(hidden)
            return code, code_len
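
    # Worked example of the sliding-window split used below (illustrative
    # numbers, assuming 16 kHz audio and hop_length=160): a 100 s clip has
    # 100 * 16000 / 160 = 10000 mel frames. With a 30 s window (3000 frames)
    # and 4 s overlap (400 frames) the stride is 2600 frames, so windows start
    # at frames 0, 2600, 5200 and 7800, and the clip is tokenized as 4
    # segments that are later merged by merge_tokenized_segments.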
    @torch.inference_mode()
    def _quantize_mixed_batch(
            self, mel: torch.Tensor, mel_len: torch.Tensor,
            long_audio_mask: torch.Tensor,
            max_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Handle mixed batch with both short and long audio using unified
        batch processing.

        Args:
            mel: mel spectrogram tensor, shape (batch_size, n_mels, T)
            mel_len: mel length tensor, shape (batch_size,)
            long_audio_mask: boolean mask for long audio, shape (batch_size,)
            max_frames: maximum frames for short audio

        Returns:
            code: quantized tokens, shape (batch_size, T')
            code_len: token length, shape (batch_size,)
        """
        batch_size = mel.size(0)

        # Parameters for sliding window
        sample_rate = 16000
        hop_length = 160  # Default hop length for mel spectrogram
        window_size = 30  # seconds
        overlap = 4  # seconds

        # Calculate frame-based parameters
        frames_per_window = window_size * sample_rate // hop_length  # 3000 frames
        frames_per_overlap = overlap * sample_rate // hop_length  # 400 frames
        frames_per_stride = frames_per_window - frames_per_overlap  # 2600 frames

        # Collect all segments to process
        # (including short and long audio segments)
        all_segments = []
        all_segments_len = []
        # Record which audio each segment belongs to and whether it's long audio
        segment_info = []

        # Process all audio in the batch
        for batch_idx in range(batch_size):
            audio_mel = mel[batch_idx]
            audio_mel_len = mel_len[batch_idx]
            is_long_audio = long_audio_mask[batch_idx].item()

            if not is_long_audio:
                # Short audio: process directly as a single segment
                segment = audio_mel[:, :audio_mel_len]
                seg_len = audio_mel_len.item()
                # Pad to max_frames if necessary
                if seg_len < frames_per_window:
                    pad_size = frames_per_window - seg_len
                    segment = torch.nn.functional.pad(segment, (0, pad_size))
                all_segments.append(segment)
                all_segments_len.append(
                    torch.tensor(seg_len, device=mel.device))
                segment_info.append({
                    'batch_idx': batch_idx,
                    'is_long_audio': False,
                    'segment_idx': 0,
                    'total_segments': 1
                })
            else:
                # Long audio: split into multiple segments
                start = 0
                segment_idx = 0
                while start < audio_mel_len:
                    end = min(start + frames_per_window, audio_mel_len)
                    segment = audio_mel[:, start:end]
                    seg_len = segment.size(1)

                    # Pad if necessary
                    if seg_len < frames_per_window:
                        pad_size = frames_per_window - seg_len
                        segment = torch.nn.functional.pad(
                            segment, (0, pad_size))

                    all_segments.append(segment)
                    all_segments_len.append(
                        torch.tensor(seg_len, device=mel.device))
                    segment_info.append({
                        'batch_idx': batch_idx,
                        'is_long_audio': True,
                        'segment_idx': segment_idx,
                        'total_segments': None  # Will be filled later
                    })

                    segment_idx += 1
                    start += frames_per_stride

                # Update total_segments info
                total_segments = segment_idx
                for info in segment_info:
                    if info['batch_idx'] == batch_idx and info['is_long_audio']:
                        info['total_segments'] = total_segments

        if not all_segments:
            # Fallback if no segments
            return torch.zeros(batch_size,
                               0,
                               dtype=torch.long,
                               device=mel.device), torch.zeros(
                                   batch_size,
                                   dtype=torch.long,
                                   device=mel.device)

        # Unified batch processing for all segments
        unified_batch_mel = torch.stack(all_segments)
        unified_batch_lens = torch.stack(all_segments_len)

        # Process all segments at once
        hidden, code_len = self.encoder(unified_batch_mel, unified_batch_lens)
        codes = self.quantizer.encode(hidden)

        # Reorganize results based on segment_info
        results = {}  # batch_idx -> (code_tensor, code_len)

        for seg_idx, info in enumerate(segment_info):
            batch_idx = info['batch_idx']
            is_long_audio = info['is_long_audio']
            segment_idx = info['segment_idx']

            # Get codes for current segment
            segment_code = codes[
                seg_idx, :code_len[seg_idx].item()].cpu().numpy().tolist()

            if not is_long_audio:
                # Short audio: use directly
                code_tensor = torch.tensor(segment_code,
                                           dtype=torch.long,
                                           device=mel.device)
                results[batch_idx] = (code_tensor, len(segment_code))
            else:
                # Long audio: collect all segments
                if batch_idx not in results:
                    results[batch_idx] = []
                results[batch_idx].append(segment_code)

        # Process long audio segment merging
        for batch_idx in range(batch_size):
            if long_audio_mask[batch_idx].item():
                # Merge long audio segments
                audio_codes = results[batch_idx]
                # V2 models use 25Hz token rate
                token_rate = 25
                merged_codes = merge_tokenized_segments(audio_codes,
                                                        overlap=overlap,
                                                        token_rate=token_rate)
                # Convert to tensor
                merged_codes_tensor = torch.tensor(merged_codes,
                                                   dtype=torch.long,
                                                   device=mel.device)
                results[batch_idx] = (merged_codes_tensor, len(merged_codes))

        # Construct final output
        max_code_len = max(code_info[1] for code_info in results.values())
        output_codes = torch.zeros(batch_size,
                                   max_code_len,
                                   dtype=torch.long,
                                   device=mel.device)
        output_codes_len = torch.zeros(batch_size,
                                       dtype=torch.long,
                                       device=mel.device)

        for batch_idx, (code_tensor, code_len) in results.items():
            output_codes[batch_idx, :code_len] = code_tensor
            output_codes_len[batch_idx] = code_len

        return output_codes, output_codes_len

    @property
    def device(self):
        return next(self.parameters()).device

    def init_from_onnx(self, onnx_path: str):
        ckpt = onnx2torch(onnx_path, None, False)
        self.load_state_dict(ckpt, strict=True)

    def init_from_pt(self, ckpt_path: str):
        ckpt = torch.load(ckpt_path, map_location="cpu", mmap=True)
        self.load_state_dict(ckpt, strict=True)

    def freeze(self):
        for _, param in self.named_parameters():
            param.requires_grad = False
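
# Minimal usage sketch (illustrative only): build the tokenizer with the
# default config and quantize a random mel batch using uninitialized weights.
# The model name string and tensor shapes are assumptions made for the
# example, not a documented entry point; real usage would first load weights
# via init_from_onnx or init_from_pt.
if __name__ == "__main__":
    tokenizer = S3TokenizerV2("dummy_v2", ModelConfig())
    mel = torch.randn(2, 128, 500)  # (batch, n_mels, frames), ~5 s each
    mel_len = torch.tensor([500, 430])
    codes, codes_len = tokenizer.quantize(mel, mel_len)
    print(codes.shape, codes_len)  # token ids lie in [0, 3**8)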