| import math
|
| import typing as tp
|
| from functools import partial
|
| from dataclasses import dataclass, field
|
| from typing import Dict, List, Optional, Tuple, Union
|
| import copy
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
|
|
| from transformers.models.auto import AutoModel
|
|
|
| from transformers.configuration_utils import PretrainedConfig
|
| from transformers.utils import logging
|
| from transformers.modeling_utils import PreTrainedModel
|
| from transformers.activations import ACT2FN
|
|
|
| from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig
|
|
|
logger = logging.get_logger(__name__)

import os

# Optional dependency: APEX's fused RMSNorm kernel.
# APEX_AVAILABLE ends up True only when the import succeeds AND the
# OPTIMIZE_FOR_SPEED environment variable parses to a non-zero integer
# (it defaults to "0", i.e. the fused kernel is off unless requested).
try:
    from apex.normalization.fused_layer_norm import fused_rms_norm_affine
except ImportError:
    APEX_AVAILABLE = False
    logger.warning("APEX FusedRMSNorm not available, using native implementation")
else:
    APEX_AVAILABLE = True
    logger.info("APEX FusedRMSNorm is available and will be used for optimization")
    if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0:
        APEX_AVAILABLE = False
        logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0")
|
|
|
|
|
|
|
class ConvLayerNorm(nn.LayerNorm):
    """
    Convolution-friendly LayerNorm.

    The input has its dimensions 1 and 2 swapped before normalization (so the
    normalized axis becomes the last one) and swapped back afterwards. The
    normalization itself is computed in float32 and cast back to the input
    dtype.
    """

    def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs):
        super().__init__(normalized_shape, **kwargs)

    def forward(self, x):
        """Normalize ``x`` over dimension 1 and return a tensor of the same shape and dtype."""
        x = x.transpose(1, 2)
        # weight / bias are None when the layer is built without affine
        # parameters (e.g. elementwise_affine=False); only cast what exists.
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        x = nn.functional.layer_norm(
            x.float(), self.normalized_shape, weight, bias, self.eps
        ).type_as(x)
        return x.transpose(1, 2)
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Args:
        dim: Size recorded on the module and used as the default weight shape.
        eps: Constant added to the mean square before the reciprocal square root.
        elementwise_affine: When true, a learnable ``weight`` (initialized to
            ones) multiplies the normalized output.
        weight_shape: Optional explicit shape for ``weight``; defaults to ``(dim,)``.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            # Register the slot so `self.weight` exists and is None.
            self.register_parameter('weight', None)
        else:
            if weight_shape is None:
                weight_shape = (dim,)
            self.weight = nn.Parameter(torch.ones(weight_shape))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to the input dtype, then apply the weight if any."""
        normed = self._norm(x.float()).type_as(x)
        if self.weight is None:
            return normed
        return normed * self.weight

    def extra_repr(self) -> str:
        return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'
|
|
|
class ConvRMSNorm(RMSNorm):
    """RMSNorm applied across dimension 1 of the input.

    Dimensions 1 and 2 are swapped before normalization and swapped back
    afterwards. The APEX fused kernel is used only when ``APEX_AVAILABLE`` is
    set and the module has an affine weight; otherwise the native float32
    path is used.
    """

    def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None):
        super().__init__(dim, eps, elementwise_affine, weight_shape)

    def forward(self, x):
        x = x.transpose(1, 2)
        use_fused = APEX_AVAILABLE and self.elementwise_affine
        if use_fused:
            output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps)
        else:
            output = self._norm(x.float()).type_as(x)
            if self.weight is not None:
                output = output * self.weight
        return output.transpose(1, 2)
|
|
|
|
|
# Normalization names accepted by apply_parametrization_norm / get_norm_module.
CONV_NORMALIZATIONS = frozenset((
    'none',
    'weight_norm',
    'spectral_norm',
    'time_layer_norm',
    'layer_norm',
    'time_group_norm',
))
|
|
|
|
|
def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module:
    """Wrap ``module`` with a weight reparametrization when ``norm`` asks for one.

    ``'weight_norm'`` and ``'spectral_norm'`` return the module wrapped by the
    matching ``nn.utils`` helper; every other accepted name returns ``module``
    unchanged. ``norm`` must be a member of ``CONV_NORMALIZATIONS``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return nn.utils.weight_norm(module)
    if norm == 'spectral_norm':
        return nn.utils.spectral_norm(module)
    return module
|
|
|
|
|
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module:
    """Build the normalization layer that follows a convolution.

    Args:
        module: The convolution whose ``out_channels`` sizes the norm layer.
        causal: Whether the surrounding network is causal.
        norm: One of ``CONV_NORMALIZATIONS``.
        **norm_kwargs: Forwarded to the norm layer constructor.

    Returns:
        ``ConvLayerNorm`` for ``'layer_norm'``, a single-group ``nn.GroupNorm``
        for ``'time_group_norm'``, and ``nn.Identity`` for every other name.

    Raises:
        ValueError: If ``'time_group_norm'`` is requested with ``causal=True``.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'layer_norm':
        assert isinstance(module, nn.modules.conv._ConvNd)
        return ConvLayerNorm(module.out_channels, **norm_kwargs)
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    return nn.Identity()
|
|
|
|
|
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many extra samples must be appended so the last frame is complete.

    The number of frames is computed from the last dimension of ``x`` (possibly
    fractional), rounded up, and converted back to the input length that would
    produce exactly that many frames; the difference to the current length is
    returned.
    """
    input_length = x.shape[-1]
    frame_count = math.ceil((input_length - kernel_size + padding_total) / stride + 1)
    target_length = (frame_count - 1) * stride + (kernel_size - padding_total)
    return target_length - input_length
|
|
|
|
|
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.):
    """Pad the last dimension of ``x``, with special handling for short inputs in reflect mode.

    Args:
        x: Tensor to pad along its last dimension.
        paddings: ``(left, right)`` amounts; both must be non-negative.
        mode: Padding mode forwarded to ``F.pad``. ``'zero'`` (the default) and
            ``'zeros'`` are accepted as aliases of ``'constant'``, because
            ``F.pad`` itself does not recognize those names.
        value: Fill value for constant padding.

    Returns:
        The padded tensor.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)

    # Translate the zero-padding aliases to the name F.pad understands;
    # without this the function's own default mode raises.
    if mode in ('zero', 'zeros'):
        mode = 'constant'

    if mode == 'reflect':
        # Reflect padding needs the input to be longer than the pad amount.
        # Temporarily extend short inputs with zeros on the right, then cut
        # the same number of samples off the end of the result.
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]

    return F.pad(x, paddings, mode, value)
|
|
|
|
|
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip ``(left, right)`` samples from the last dimension of ``x``.

    Both amounts must be non-negative and together no larger than the length
    of the last dimension. A right amount of zero is handled correctly because
    the end index is computed explicitly rather than with a negative slice.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total_length = x.shape[-1]
    assert (left + right) <= total_length
    return x[..., left: total_length - right]
|
|
|
|
|
class NormConv1d(nn.Module):
    """``nn.Conv1d`` followed by the normalization selected by ``norm``.

    Positional and remaining keyword arguments are forwarded to ``nn.Conv1d``.
    ``norm`` selects both the weight reparametrization applied to the conv and
    the norm layer applied to its output; ``norm_kwargs`` go to that layer.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        base_conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(base_conv, norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.conv(x))
|
|
|
|
|
class NormConvTranspose1d(nn.Module):
    """``nn.ConvTranspose1d`` followed by the normalization selected by ``norm``.

    Positional and remaining keyword arguments are forwarded to
    ``nn.ConvTranspose1d``. ``norm`` selects both the weight reparametrization
    applied to the layer and the norm applied to its output.
    """

    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        base_convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(base_convtr, norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        return self.norm(self.convtr(x))
|
|
|
|
|
class VibeVoiceTokenizerStreamingCache:
    """Cache for streaming convolution, similar to a KV cache in attention.

    States are stored per ``(layer_id, sample_index)`` key in ``self.cache``.
    ``sample_indices`` arguments are iterated via ``.tolist()``, so they are
    expected to be 1-D tensors of hashable values.
    """

    def __init__(self):
        self.cache = {}

    def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]:
        """Return the stacked states for ``layer_id`` and the given samples.

        Returns ``None`` if any requested sample has no cached state, or if
        ``sample_indices`` is empty (there is nothing to stack). States with at
        least two dimensions are left-padded with zeros on their last dimension
        to the longest cached length before stacking.
        """
        states = []
        for idx in sample_indices.tolist():
            key = (layer_id, idx)
            if key not in self.cache:
                return None
            states.append(self.cache[key])

        if not states:
            return None

        if states[0].dim() >= 2:
            max_length = max(state.shape[-1] for state in states)
            states = [
                state if state.shape[-1] >= max_length
                else F.pad(state, (max_length - state.shape[-1], 0), mode='constant', value=0)
                for state in states
            ]
        return torch.stack(states, dim=0)

    def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor):
        """Store ``states[i]`` (detached) for the i-th entry of ``sample_indices``."""
        for i, idx in enumerate(sample_indices.tolist()):
            self.cache[(layer_id, idx)] = states[i].detach()

    def set_to_zero(self, sample_indices: torch.Tensor):
        """Replace every cached state of the given samples with zeros of the same shape."""
        # Convert once instead of once per cached key.
        targets = sample_indices.tolist()
        for key in list(self.cache.keys()):
            _, sample_idx = key
            if sample_idx in targets:
                self.cache[key] = torch.zeros_like(self.cache[key])

    def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None):
        """Clear cache entries.

        * no arguments: remove everything;
        * ``layer_id`` only: remove every sample of that layer;
        * ``sample_indices`` only: remove those samples from every layer;
        * both: remove those samples from that layer.
        """
        if layer_id is None and sample_indices is None:
            self.cache.clear()
            return

        if sample_indices is None:
            for key in [k for k in self.cache if k[0] == layer_id]:
                del self.cache[key]
            return

        targets = sample_indices.tolist()
        if layer_id is None:
            # Previously this combination was silently ignored.
            for key in [k for k in self.cache if k[1] in targets]:
                del self.cache[key]
            return

        for idx in targets:
            self.cache.pop((layer_id, idx), None)
|
|
|
class SConv1d(nn.Module):
    """Conv1d with built-in handling of asymmetric or causal padding and normalization.

    Two execution paths are provided: a padded one-shot path and a streaming
    path that prepends cached input samples instead of padding.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

        # Layer hyper-parameters kept for the padding / streaming arithmetic.
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # The same quantity serves both as the number of past samples kept in
        # the streaming cache and as the total padding in the one-shot path.
        overhang = (kernel_size - 1) * dilation - (stride - 1)
        self.context_size = overhang
        self.padding_total = overhang

        # Cache key, created lazily on first access.
        self._layer_id = None

    @property
    def layer_id(self):
        """Cache key for this layer, derived from ``id(self)`` on first use."""
        if self._layer_id is None:
            self._layer_id = f"sconv1d_{id(self)}"
        return self._layer_id

    def forward(self, x: torch.Tensor,
                cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Run the convolution, optionally in streaming mode.

        Args:
            x: Input tensor [batch_size, channels, time]
            cache: VibeVoiceTokenizerStreamingCache object for maintaining states
            sample_indices: Indices identifying each sample for cache management
            use_cache: Whether to use cached states for streaming
            debug: Whether to print debug information

        Returns:
            Output tensor
        """
        batch_size, _, _ = x.shape

        streaming = use_cache and cache is not None
        if not streaming:
            return self._forward_non_streaming(x, debug=debug)

        assert self.causal, "Streaming mode is only supported for causal convolutions"
        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == batch_size, "sample_indices must match batch size"

        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(self, x: torch.Tensor,
                           cache: VibeVoiceTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming path: prepend cached samples, convolve, store the new tail."""
        batch_size, channels, _ = x.shape

        cached_states = cache.get(self.layer_id, sample_indices)

        if cached_states is None:
            # First chunk (or a missing entry): start from an all-zero context.
            needs_context = self.context_size > 0
            cached_states = torch.zeros(
                batch_size, channels, self.context_size if needs_context else 0,
                device=x.device, dtype=x.dtype,
            )
            if debug:
                if needs_context:
                    print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}")
                else:
                    print(f"[DEBUG] No context needed (kernel_size=stride)")

        if cached_states.shape[2] > 0:
            input_with_context = torch.cat([cached_states, x], dim=2)
        else:
            input_with_context = x

        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}")

        # No padding here: the cached context plays that role.
        output = self.conv(input_with_context)

        if debug:
            print(f"[DEBUG] Output shape: {output.shape}")

        if self.context_size > 0:
            total_input_length = input_with_context.shape[2]
            if total_input_length < self.context_size:
                # Not enough samples yet: keep everything we have.
                new_cache = input_with_context
            else:
                tail_start = total_input_length - self.context_size
                new_cache = input_with_context[:, :, tail_start:]

            if debug:
                print(f"[DEBUG] New cache shape: {new_cache.shape}")

            cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """One-shot path: pad according to causality, then convolve."""
        _batch, _channels, _time = x.shape  # requires a 3-D input
        padding_total = self.padding_total

        extra_padding = get_extra_padding_for_conv1d(x, self.kernel_size, self.stride, padding_total)

        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}")

        if self.causal:
            # All regular padding goes on the left; only the length-completing
            # extra padding goes on the right.
            pad_kwargs = {'value': 0} if self.pad_mode == 'constant' else {}
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, **pad_kwargs)
        else:
            # Split the padding, giving the left side the odd sample.
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)

        if debug:
            print(f"[DEBUG NON-STREAMING] After padding: {x.shape}")

        output = self.conv(x)

        if debug:
            print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}")

        return output
|
|
|
|
|
class SConvTranspose1d(nn.Module):
    """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization.

    Args:
        in_channels, out_channels, kernel_size, stride, bias: Forwarded to the
            underlying transposed convolution.
        causal: When true, trimming of the output is biased to the right
            according to ``trim_right_ratio``.
        norm, norm_kwargs: Normalization selection forwarded to the wrapped layer.
        trim_right_ratio: Fraction (0..1) of the surplus output removed on the
            right in causal mode; must be 1.0 for non-causal layers.
    """

    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Surplus output samples that are trimmed after the transposed conv.
        self.padding_total = kernel_size - stride

        # Number of past input samples kept in the streaming cache.
        self.context_size = kernel_size - 1

        # Cache key, created lazily on first access.
        self._layer_id = None

    @property
    def layer_id(self):
        """Cache key for this layer, derived from ``id(self)`` on first use."""
        if self._layer_id is None:
            self._layer_id = f"sconvtr1d_{id(self)}"
        return self._layer_id

    def _trim_amounts(self) -> tp.Tuple[int, int]:
        """Return the ``(left, right)`` number of output samples to trim."""
        if self.causal:
            padding_right = math.ceil(self.padding_total * self.trim_right_ratio)
        else:
            padding_right = self.padding_total // 2
        padding_left = self.padding_total - padding_right
        return padding_left, padding_right

    def forward(self, x: torch.Tensor,
                cache: Optional[VibeVoiceTokenizerStreamingCache] = None,
                sample_indices: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                debug: bool = False) -> torch.Tensor:
        """
        Forward pass with optional streaming support via cache.

        Args:
            x: Input tensor; must be 3-D (unpacked as B, C, T).
            cache: Streaming cache; streaming is used only when this is given
                and ``use_cache`` is true.
            sample_indices: One cache index per batch element (required when streaming).
            use_cache: Whether to use the streaming path.
            debug: Whether to print shape information.

        Returns:
            Output tensor.
        """
        B, C, T = x.shape

        if not use_cache or cache is None:
            return self._forward_non_streaming(x, debug=debug)

        assert sample_indices is not None, "sample_indices must be provided for streaming mode"
        assert len(sample_indices) == B, "sample_indices must match batch size"

        return self._forward_streaming(x, cache, sample_indices, debug)

    def _forward_streaming(self, x: torch.Tensor,
                           cache: VibeVoiceTokenizerStreamingCache,
                           sample_indices: torch.Tensor,
                           debug: bool = False) -> torch.Tensor:
        """Streaming path: re-run the layer on cached + new input and keep only the new output."""
        B, C, T = x.shape

        cached_input = cache.get(self.layer_id, sample_indices)

        if cached_input is None:
            # First chunk: nothing cached yet.
            cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype)
            if debug:
                print(f"[DEBUG] Initialized empty cache for transposed conv")

        full_input = torch.cat([cached_input, x], dim=2)

        if debug:
            print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}")

        full_output = self.convtr(full_input)

        if debug:
            print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}")

        padding_left, padding_right = self._trim_amounts()
        if padding_left + padding_right > 0:
            full_output = unpad1d(full_output, (padding_left, padding_right))

        if debug:
            print(f"[DEBUG] After unpadding: {full_output.shape}")

        if cached_input.shape[2] == 0:
            # First chunk: everything produced is new.
            output = full_output
        else:
            # Keep only the samples that correspond to the new input frames.
            expected_new_output = T * self.stride
            produced = full_output.shape[2]
            if produced >= expected_new_output:
                # Explicit start index: a `-n:` slice would return the whole
                # tensor when n == 0.
                output = full_output[:, :, produced - expected_new_output:]
            else:
                output = full_output

        if debug:
            print(f"[DEBUG] Final streaming output shape: {output.shape}")

        # Keep the last `context_size` input samples for the next call. The
        # start index is computed explicitly so that context_size == 0 yields
        # an empty cache instead of the full history.
        total_input = full_input.shape[2]
        if total_input > self.context_size:
            new_cache = full_input[:, :, total_input - self.context_size:]
        else:
            new_cache = full_input

        if debug:
            print(f"[DEBUG] New cache shape: {new_cache.shape}")

        cache.set(self.layer_id, sample_indices, new_cache)

        return output

    def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor:
        """One-shot path: transposed conv followed by trimming of the surplus samples."""
        if debug:
            print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}")

        y = self.convtr(x)

        if debug:
            print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}")

        padding_left, padding_right = self._trim_amounts()
        if padding_left + padding_right > 0:
            y = unpad1d(y, (padding_left, padding_right))

        if debug:
            print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}")

        return y
|
|
|
|
|
class FFN(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        embed_dim: Input and output feature size.
        ffn_dim: Hidden feature size.
        bias: Whether both linear layers have a bias term.
    """

    def __init__(self, embed_dim, ffn_dim, bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias)
        self.gelu = ACT2FN["gelu"]
        self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias)

    def forward(self, x):
        hidden = self.gelu(self.linear1(x))
        return self.linear2(hidden)
|
|
|
|
|
class Convlayer(nn.Module):
    """Thin module wrapper that owns a single ``SConv1d`` as ``self.conv``.

    All constructor arguments are forwarded to ``SConv1d``; ``forward`` calls
    it with the input only (no streaming cache).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_mode='zeros',
        norm='weight_norm',
        causal=True,
    ):
        super().__init__()
        self.conv = SConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            pad_mode=pad_mode,
            norm=norm,
            causal=causal,
        )

    def forward(self, x):
        return self.conv(x)
|
|
|
class _DropPath(nn.Module):
    """Stochastic depth: randomly zero whole samples of a residual branch during training."""

    def __init__(self, drop_prob: float = 0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all other dimensions.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if keep_prob > 0.:
            # Rescale so the expected value of the branch is unchanged.
            mask.div_(keep_prob)
        return x * mask

    def extra_repr(self) -> str:
        return f'drop_prob={self.drop_prob}'


class Block1D(nn.Module):
    """Residual block: norm -> conv mixer -> residual, then norm -> FFN -> residual.

    Args:
        dim: Number of channels (dimension 1 of the input).
        kernel_size: Kernel size of the mixer convolution.
        drop_path: Stochastic-depth probability applied to both branches.
        mixer_layer: ``'conv'`` or ``'depthwise_conv'`` (groups = ``dim``).
        layer_scale_init_value: Initial value of the per-channel branch scales;
            values <= 0 disable layer scale.
        **kwargs: Optional ``layernorm`` (``'LN'`` or ``'RMSNorm'``), ``eps``,
            ``groups``, ``pad_mode``, ``norm``, ``causal``, ``bias`` and
            ``ffn_expansion``.

    Raises:
        ValueError: For an unsupported ``layernorm`` or ``mixer_layer``.
    """

    def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv',
                 layer_scale_init_value=1e-6, **kwargs):
        super().__init__()

        layernorm = kwargs.get('layernorm', 'LN')
        eps = kwargs.get('eps', 1e-6)
        if layernorm == 'LN':
            self.norm = ConvLayerNorm(dim, eps=eps)
            self.ffn_norm = ConvLayerNorm(dim, eps=eps)
        elif layernorm == 'RMSNorm':
            self.norm = ConvRMSNorm(dim, eps=eps)
            self.ffn_norm = ConvRMSNorm(dim, eps=eps)
        else:
            # Fail at construction instead of with an AttributeError in forward().
            raise ValueError(f"Unsupported norm type: {layernorm}")

        if mixer_layer == 'conv':
            groups = kwargs.get('groups', 1)
        elif mixer_layer == 'depthwise_conv':
            groups = dim
        else:
            raise ValueError(f"Unsupported mixer layer: {mixer_layer}")

        self.mixer = Convlayer(dim, dim, groups=groups,
                               kernel_size=kernel_size,
                               pad_mode=kwargs.get('pad_mode', 'reflect'),
                               norm=kwargs.get('norm', 'none'),
                               causal=kwargs.get('causal', True),
                               bias=kwargs.get('bias', True),
                               )

        self.ffn = FFN(
            dim,
            kwargs.get('ffn_expansion', 4) * dim,
            bias=kwargs.get('bias', False),
        )
        # PyTorch does not ship a DropPath module, so a local implementation is used.
        self.drop_path = nn.Identity() if drop_path <= 0. else _DropPath(drop_path)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None
            self.ffn_gamma = None

    def forward(self, x):
        # Mixer branch.
        residual = x
        x = self.norm(x)
        x = self.mixer(x)
        if self.gamma is not None:
            x = x * self.gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        # Feed-forward branch; the FFN acts on the last dimension, so the
        # channel dimension is moved there and back.
        residual = x
        x = self.ffn_norm(x)
        x = x.permute(0, 2, 1)
        x = self.ffn(x)
        x = x.permute(0, 2, 1)
        if self.ffn_gamma is not None:
            x = x * self.ffn_gamma.unsqueeze(-1)
        x = residual + self.drop_path(x)

        return x
|
|
|
|
|
class TokenizerDecoder(nn.Module):
    """
    Decoder component for the VibeVoice tokenizer that converts latent representations back to audio.

    The network is a stem convolution followed by one transposed-convolution
    upsampler per entry of ``config.ratios``, a stack of ``Block1D`` per entry
    of ``config.depths``, an optional final norm and an output convolution.

    Args:
        config: Configuration object with model parameters
    """

    def __init__(self, config):
        super().__init__()

        # Required settings.
        self.dimension = config.dimension
        self.channels = config.channels
        self.n_filters = config.n_filters
        self.ratios = config.ratios
        self.depths = config.depths

        self.n_residual_layers = getattr(config, "n_residual_layers", 1)
        self.hop_length = np.prod(self.ratios)
        self.causal = config.causal

        def option(name, default):
            """Read an optional setting from the config."""
            return getattr(config, name, default)

        kernel_size = option("kernel_size", 7)
        last_kernel_size = option("last_kernel_size", 7)
        norm = option("norm", "none")
        norm_params = option("norm_params", {})
        pad_mode = option("pad_mode", "reflect")
        bias = option("bias", True)
        layernorm = option("layernorm", "LN")
        layernorm_eps = option("layernorm_eps", 1e-6)
        trim_right_ratio = option("trim_right_ratio", 1.0)
        layernorm_elementwise_affine = option("layernorm_elementwise_affine", True)
        drop_path_rate = option("drop_path_rate", 0.0)
        mixer_layer = option("mixer_layer", "conv")
        layer_scale_init_value = option("layer_scale_init_value", 0)
        disable_last_norm = option("disable_last_norm", False)

        # Constructor for the final norm layer.
        if layernorm == 'LN':
            norm_type = ConvLayerNorm
        elif layernorm == 'RMSNorm':
            norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm type: {layernorm}")

        num_stages = len(self.depths)

        def width(level):
            """Channel count at a given stage; halves with every stage."""
            return self.n_filters * (2 ** (num_stages - 1 - level))

        # Stem: latent dimension -> widest channel count.
        stem = nn.Sequential(
            SConv1d(self.dimension, width(0), kernel_size, norm=norm,
                    norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias),
        )

        self.upsample_layers = nn.ModuleList()
        self.upsample_layers.append(stem)
        for i, ratio in enumerate(self.ratios):
            in_ch = width(i)
            out_ch = width(i + 1)
            self.upsample_layers.append(nn.Sequential(
                SConvTranspose1d(in_ch, out_ch,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=norm, norm_kwargs=norm_params, bias=bias,
                                 causal=self.causal, trim_right_ratio=trim_right_ratio),
            ))

        # Shared settings for every residual block.
        make_block = partial(
            Block1D,
            mixer_layer=mixer_layer,
            layernorm=layernorm,
            eps=layernorm_eps,
            causal=self.causal,
            pad_mode=pad_mode,
            norm=norm,
            bias=bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        self.stages = nn.ModuleList()
        # Drop-path rate grows linearly from 0 to drop_path_rate over all blocks.
        dp_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, sum(self.depths))]
        first_block = 0

        for i, depth in enumerate(self.depths):
            in_ch = width(i)
            blocks = [make_block(dim=in_ch, drop_path=dp_rates[first_block + j]) for j in range(depth)]
            self.stages.append(nn.Sequential(*blocks))
            first_block += depth

        # `in_ch` is the channel count of the last stage here.
        self.norm = nn.Identity() if disable_last_norm else norm_type(in_ch, eps=layernorm_eps)
        self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias)

    def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Run all upsample layers and stages, then the final norm.

        Blocks whose mixer wraps an ``SConv1d`` are unrolled here so that the
        streaming arguments reach the convolution; on this path the block's
        ``drop_path`` module is not applied.
        """
        stream_kwargs = dict(cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)

        for stage_idx, stage in enumerate(self.stages):
            for layer in self.upsample_layers[stage_idx]:
                if isinstance(layer, (SConv1d, SConvTranspose1d)):
                    x = layer(x, **stream_kwargs)
                else:
                    x = layer(x)

            for block in stage:
                mixer_conv = getattr(getattr(block, 'mixer', None), 'conv', None)
                if not isinstance(mixer_conv, SConv1d):
                    x = block(x)
                    continue

                # Mixer branch with cache support.
                skip = x
                x = block.norm(x)
                x = mixer_conv(x, **stream_kwargs)
                if block.gamma is not None:
                    x = x * block.gamma.unsqueeze(-1)
                x = skip + x

                # Feed-forward branch.
                skip = x
                x = block.ffn_norm(x)
                x = block.ffn(x.permute(0, 2, 1)).permute(0, 2, 1)
                if block.ffn_gamma is not None:
                    x = x * block.ffn_gamma.unsqueeze(-1)
                x = skip + x

        return self.norm(x)

    def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Decode ``x`` through the feature stack and the output convolution."""
        features = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
        return self.head(features, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
|
|
|
|
class VibeVoiceAcousticTokenizerModel(PreTrainedModel):
    """VibeVoice speech tokenizer model (only decoder) for acoustic tokens"""

    config_class = VibeVoiceAcousticTokenizerConfig
    base_model_prefix = "vibevoice_acoustic_tokenizer"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _no_split_modules = ["TokenizerDecoder"]

    def __init__(self, config):
        super().__init__(config)

        self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False)
        self.std_dist_type = getattr(config, "std_dist_type", "fix")

        # Encoder depths may be given as a dash-separated string such as "3-3-9".
        encoder_depths = config.encoder_depths
        if isinstance(encoder_depths, str):
            encoder_depths = [int(d) for d in encoder_depths.split('-')]

        # Decoder depths: parsed when given as a string, otherwise the reversed
        # encoder depths are used.
        # NOTE(review): a non-string decoder_depths (e.g. a list) also falls
        # back to the reversed encoder depths — confirm this is intended.
        if isinstance(config.decoder_depths, str):
            decoder_depths = [int(d) for d in config.decoder_depths.split('-')]
        else:
            decoder_depths = list(reversed(encoder_depths))

        # The decoder reads its settings from a private copy of the config
        # under the attribute names TokenizerDecoder expects.
        decoder_config = copy.deepcopy(config)
        decoder_settings = (
            ('dimension', config.vae_dim),
            ('n_filters', config.decoder_n_filters),
            ('ratios', config.decoder_ratios),
            ('depths', decoder_depths),
            ('norm', config.conv_norm),
            ('pad_mode', config.pad_mode),
            ('bias', config.conv_bias),
            ('layernorm_eps', config.layernorm_eps),
            ('layernorm_elementwise_affine', config.layernorm_elementwise_affine),
            ('mixer_layer', config.mixer_layer),
            ('layer_scale_init_value', config.layer_scale_init_value),
            ('disable_last_norm', config.disable_last_norm),
        )
        for attr_name, attr_value in decoder_settings:
            setattr(decoder_config, attr_name, attr_value)

        self.decoder = TokenizerDecoder(decoder_config)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights for the model.

        Linear and Conv1d weights are drawn from a normal distribution with
        std ``config.weight_init_value`` and their biases zeroed; LayerNorm is
        reset to weight 1 / bias 0. Other modules are left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, std=self.config.weight_init_value)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    @torch.no_grad()
    def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False):
        """Convert latent representations back to audio.

        If dimension 1 of ``latents`` does not equal ``config.vae_dim``, the
        last two dimensions are swapped before decoding.
        """
        if latents.shape[1] != self.config.vae_dim:
            latents = latents.permute(0, 2, 1)

        return self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug)
|
|
|
|
|
# Make the model discoverable through AutoModel for its config class.
AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel)

# Public API of this module.
__all__ = ["VibeVoiceTokenizerStreamingCache", "VibeVoiceAcousticTokenizerModel"]