Spaces:
Running
Running
import math
from dataclasses import dataclass
from typing import Literal

import jsonargparse
import torch
import torch.nn as nn
import torch.nn.functional as F

from .module.fsq import FiniteScalarQuantizer
from .module.global_encoder import GlobalEncoder
from .module.postnet import PostNet
from .module.ssl_extractor import SSLFeatureExtractor
from .module.transformer import Transformer
from .util import freeze_modules, get_logger
logger = get_logger()  # Module-level logger obtained from the project's .util.get_logger helper.
| class KanadeModelConfig: | |
| # SSL Feature settings | |
| local_ssl_layers: tuple[int, ...] = (6, 9) # Indices of SSL layers for local branch | |
| global_ssl_layers: tuple[int, ...] = (1, 2) # Indices of SSL layers for global branch | |
| normalize_ssl_features: bool = True # Whether to normalize local SSL features before encoding | |
| # Down/up-sampling settings | |
| downsample_factor: int = 2 # Temporal downsampling factor for local features | |
| mel_upsample_factor: int = 4 # Conv1DTranspose upsampling factor for mel features before interpolation | |
| use_conv_downsample: bool = True # Whether to use Conv1D for downsampling instead average pooling | |
| local_interpolation_mode: str = "linear" # Interpolation mode for local upsampling ("linear", "nearest") | |
| mel_interpolation_mode: str = "linear" # Interpolation mode for mel upsampling ("linear", "nearest") | |
| # Mel spectrogram settings | |
| sample_rate: int = 24000 | |
| n_fft: int = 1024 | |
| hop_length: int = 256 | |
| n_mels: int = 100 | |
| padding: str = "center" | |
| mel_fmin: int = 0 # Minimum frequency for mel spectrograms | |
| mel_fmax: int | None = None # Maximum frequency for mel spectrograms | |
| bigvgan_style_mel: bool = False # Whether to use BigVGAN-style mel spectrograms | |
| # Vocoder settings | |
| vocoder_name: Literal["vocos", "hift"] = "vocos" # Vocoder to use for waveform synthesis | |
| class KanadeFeatures: | |
| content_embedding: torch.Tensor | None = None # (seq_len, dim) | |
| content_token_indices: torch.Tensor | None = None # (seq_len,) | |
| global_embedding: torch.Tensor | None = None # (dim,) | |
class KanadeModel(nn.Module):
    """Model architecture and forward pass logic for Kanade tokenizer.

    Data flow, as implemented by the methods below:

    * A frozen SSL feature extractor produces a list of per-layer features.
    * Local (content) branch: averaged local SSL layers -> optional normalization over
      time -> ``local_encoder`` -> optional temporal downsampling -> ``local_quantizer``.
    * Global branch: averaged global SSL layers -> ``global_encoder`` ((B, C) output).
    * Mel decoder: ``mel_prenet`` -> upsampling to the mel frame count -> ``mel_decoder``
      conditioned on the global embedding -> ``mel_postnet``.
    """

    def __init__(
        self,
        config: KanadeModelConfig,
        ssl_feature_extractor: SSLFeatureExtractor,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        global_encoder: GlobalEncoder,
        mel_prenet: Transformer,
        mel_decoder: Transformer,
        mel_postnet: PostNet,
        feature_decoder: Transformer | None = None,
    ):
        """Assemble the model from its submodules.

        Args:
            config: Model hyper-parameters.
            ssl_feature_extractor: SSL feature extractor; frozen here.
            local_encoder: Encoder of the local (content) branch.
            local_quantizer: Quantizer producing content tokens / embeddings.
            global_encoder: Encoder of the global branch.
            mel_prenet: Transformer applied to content embeddings before upsampling.
            mel_decoder: Transformer decoding mel frames, conditioned on global embeddings.
            mel_postnet: Post-network applied to the decoded mel spectrogram.
            feature_decoder: Optional decoder reconstructing SSL features from quantized content.
                When given, ``forward_content`` also returns that reconstruction.
        """
        super().__init__()
        self.config = config
        self._init_ssl_extractor(config, ssl_feature_extractor)
        self._init_local_branch(config, local_encoder, local_quantizer, feature_decoder)
        self._init_global_branch(global_encoder)
        self._init_mel_decoder(config, mel_prenet, mel_decoder, mel_postnet)

    def _init_ssl_extractor(self, config: KanadeModelConfig, ssl_feature_extractor: SSLFeatureExtractor):
        """Initialize and freeze the SSL feature extractor; record which SSL layers each branch uses."""
        self.ssl_feature_extractor = ssl_feature_extractor
        freeze_modules([self.ssl_feature_extractor])
        logger.debug(
            f"SSL feature extractor initialized and frozen, feature dim: {self.ssl_feature_extractor.feature_dim}"
        )

        # Configure local SSL layers
        self.local_ssl_layers = list(config.local_ssl_layers)
        if len(self.local_ssl_layers) > 1:
            logger.debug(
                f"Using average of {len(self.local_ssl_layers)} SSL layers for local branch: {self.local_ssl_layers}"
            )
        else:
            logger.debug(f"Using single SSL layer {self.local_ssl_layers[0]} for local branch")
        if config.normalize_ssl_features:
            logger.debug("Normalizing local SSL features before encoding")

        # Configure global SSL layers
        self.global_ssl_layers = list(config.global_ssl_layers)
        if len(self.global_ssl_layers) > 1:
            logger.debug(
                f"Using average of {len(self.global_ssl_layers)} SSL layers for global branch: {self.global_ssl_layers}"
            )
        else:
            logger.debug(f"Using single SSL layer {self.global_ssl_layers[0]} for global branch")

    def _init_local_branch(
        self,
        config: KanadeModelConfig,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        feature_decoder: Transformer | None,
    ):
        """Initialize local branch components (encoder, downsampling, quantizer, decoder).

        With ``downsample_factor > 1`` and ``use_conv_downsample``, a strided Conv1d is created
        for downsampling, plus a matching ConvTranspose1d for the feature-reconstruction path.
        Otherwise no resampling layers are created, and pooling / interpolation is used at
        forward time.
        """
        self.local_encoder = local_encoder
        self.local_quantizer = local_quantizer
        self.feature_decoder = feature_decoder

        # Configure downsampling
        self.downsample_factor = config.downsample_factor
        if self.downsample_factor > 1:
            logger.debug(f"Using temporal downsampling with factor {self.downsample_factor}")
            if config.use_conv_downsample:
                # Create Conv1d layers for downsampling and upsampling local embeddings
                feature_dim = local_encoder.output_dim
                self.conv_downsample = nn.Conv1d(
                    feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
                )
                self.conv_upsample = nn.ConvTranspose1d(
                    feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
                )  # won't be used unless training feature reconstruction
                logger.debug(f"Using Conv1d downsampling/upsampling with kernel size {config.downsample_factor}")
            else:
                self.conv_downsample = None
                self.conv_upsample = None
                # Report the interpolation mode actually configured, not a hard-coded "linear".
                logger.debug(
                    f"Using average pooling and {config.local_interpolation_mode} interpolation "
                    "for downsampling/upsampling"
                )
        else:
            self.conv_downsample = None
            self.conv_upsample = None

    def _init_global_branch(self, global_encoder: GlobalEncoder):
        """Initialize global branch components."""
        self.global_encoder = global_encoder

    def _init_mel_decoder(
        self, config: KanadeModelConfig, mel_prenet: Transformer, mel_decoder: Transformer, mel_postnet: PostNet
    ):
        """Initialize mel decoder components (prenet, upsampling, decoder, postnet)."""
        self.mel_prenet = mel_prenet
        self.mel_decoder = mel_decoder
        self.mel_postnet = mel_postnet

        # Configure mel upsampling
        self.mel_conv_upsample = None
        if config.mel_upsample_factor > 1:
            # Create Conv1DTranspose layer for mel upsampling
            input_dim = mel_prenet.output_dim
            self.mel_conv_upsample = nn.ConvTranspose1d(
                input_dim, input_dim, kernel_size=config.mel_upsample_factor, stride=config.mel_upsample_factor
            )
            logger.debug(f"Using Conv1DTranspose for mel upsampling with factor {config.mel_upsample_factor}")

    def _calculate_waveform_padding(self, audio_length: int, ensure_recon_length: bool = False) -> int:
        """Calculate per-side padding for the input waveform to get consistent SSL feature lengths.

        Args:
            audio_length: Number of input samples at ``config.sample_rate``.
            ensure_recon_length: If True, round the expected SSL frame count up to a multiple of
                ``downsample_factor``.

        Returns:
            Padding in samples (at ``config.sample_rate``) to add on each side. It may be zero or
            negative when no padding is needed. ``forward_ssl_features`` only pads when it is positive.
        """
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate

        # SSL may resample the input to its own sample rate, so calculate the number of samples after resampling
        num_samples_after_resampling = audio_length / sample_rate * extractor.ssl_sample_rate
        # We expect the SSL feature extractor to be consistent with its hop size
        expected_ssl_output_length = math.ceil(num_samples_after_resampling / extractor.hop_size)
        # If ensure_recon_length is True, we want to make sure the output length is exactly divisible by downsample factor
        if ensure_recon_length and (remainder := expected_ssl_output_length % self.downsample_factor) != 0:
            expected_ssl_output_length += self.downsample_factor - remainder
        # But it may require more input samples to produce that output length, so calculate the required input length
        num_samples_required_after_resampling = extractor.get_minimum_input_length(expected_ssl_output_length)
        # That number of samples is at the SSL sample rate, so convert back to our original sample rate
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        # Calculate padding needed on each side
        padding = math.ceil((num_samples_required - audio_length) / 2)
        return padding

    def _calculate_original_audio_length(self, token_length: int) -> int:
        """Estimate the audio length (samples at ``config.sample_rate``) that produced ``token_length`` tokens."""
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate

        # Calculate the feature length before downsampling
        feature_length = token_length * self.downsample_factor
        num_samples_required_after_resampling = extractor.get_minimum_input_length(feature_length)
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        return math.ceil(num_samples_required)

    def _calculate_target_mel_length(self, audio_length: int) -> int:
        """Calculate the mel frame count for ``audio_length`` samples under ``config.padding``.

        ``"center"``: ``len // hop + 1``; ``"same"``: ``len // hop``; otherwise (unpadded):
        ``(len - n_fft) // hop + 1``.
        """
        if self.config.padding == "center":
            return audio_length // self.config.hop_length + 1
        elif self.config.padding == "same":
            return audio_length // self.config.hop_length
        else:
            return (audio_length - self.config.n_fft) // self.config.hop_length + 1

    def _process_ssl_features(self, features: list[torch.Tensor], layers: list[int]) -> torch.Tensor:
        """Select SSL layers by 1-based index (layer ``i`` is ``features[i - 1]``) and average them."""
        if len(layers) > 1:
            # Get features from multiple layers and average them
            selected_features = [features[i - 1] for i in layers]
            mixed_features = torch.stack(selected_features, dim=0).mean(dim=0)
        else:
            # Just take the single specified layer
            mixed_features = features[layers[0] - 1]
        return mixed_features

    def _normalize_ssl_features(self, features: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Standardize features over the time axis (dim 1) if ``config.normalize_ssl_features`` is set.

        The input is treated as (B, T, C). A single-frame input (T == 1) has no defined sample std
        (``torch.std`` would return NaN), so its spread is treated as zero and the output is all zeros.
        """
        if not self.config.normalize_ssl_features:
            return features

        # Compute mean and std across time steps for each sample and feature dimension
        mean = torch.mean(features, dim=1, keepdim=True)  # (B, 1, C)
        if features.size(1) > 1:
            std = torch.std(features, dim=1, keepdim=True)  # (B, 1, C)
        else:
            std = torch.zeros_like(mean)  # avoid NaN from the unbiased std of one frame
        return (features - mean) / (std + eps)

    def forward_ssl_features(
        self, waveform: torch.Tensor, padding: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to extract SSL features. (B, T, C)

        Args:
            waveform: Input waveform tensor of shape (B, channels, samples) or (B, samples).
                A 3-D input is squeezed on dim 1, so ``channels`` is expected to be 1.
            padding: Optional padding to apply on both sides of the waveform. This is useful to ensure
                that the SSL feature extractor produces consistent output lengths. ``None`` or a
                non-positive value means no padding.

        Returns:
            local_ssl_features: Local SSL features for local branch. (B, T, C)
            global_ssl_features: Global SSL features for global branch. (B, T, C)
        """
        # Prepare input waveform
        if waveform.dim() == 3:
            waveform = waveform.squeeze(1)

        # 1. Extract SSL features
        # `padding` defaults to None; comparing None > 0 would raise TypeError.
        if padding is not None and padding > 0:
            waveform = F.pad(waveform, (padding, padding), mode="constant")
        with torch.no_grad():
            ssl_features = self.ssl_feature_extractor(waveform)

        local_ssl_features = self._process_ssl_features(ssl_features, self.local_ssl_layers)
        local_ssl_features = self._normalize_ssl_features(local_ssl_features)
        global_ssl_features = self._process_ssl_features(ssl_features, self.global_ssl_layers)
        return local_ssl_features, global_ssl_features

    def forward_content(
        self, local_ssl_features: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
        """Forward pass to extract content embeddings from the local branch.

        Args:
            local_ssl_features: Local SSL features tensor of shape (B, T, C)

        Returns:
            local_quantized: Quantized local embeddings. (B, T/factor, C)
            indices: Content token indices. (B, T/factor)
            ssl_recon: Reconstructed SSL features if a feature decoder is present, else None. (B, T, C)
                With conv upsampling, the reconstructed length is ``(T // factor) * factor``, which
                equals T only when T is divisible by the factor.
            perplexity: Quantizer perplexity if a feature decoder is present, else a zero scalar tensor.
        """
        local_encoded = self.local_encoder(local_ssl_features)

        # Downsample temporally if needed: (B, T, C) -> (B, T/factor, C)
        if self.downsample_factor > 1:
            if self.config.use_conv_downsample:
                local_encoded = self.conv_downsample(local_encoded.transpose(1, 2)).transpose(1, 2)
            else:
                local_encoded = F.avg_pool1d(
                    local_encoded.transpose(1, 2), kernel_size=self.downsample_factor, stride=self.downsample_factor
                ).transpose(1, 2)

        # If training feature reconstruction, decode local embeddings
        ssl_recon = None
        # Create the placeholder on the same device as the activations so callers can combine it freely.
        perplexity = torch.tensor(0.0, device=local_encoded.device)
        if self.feature_decoder is not None:
            local_quantized, local_quantize_info = self.local_quantizer(local_encoded)
            indices = local_quantize_info["indices"]
            perplexity = torch.mean(local_quantize_info["perplexity"])

            local_latent_for_ssl = local_quantized
            # Upsample if needed
            if self.downsample_factor > 1:
                if self.config.use_conv_downsample:
                    # Use conv transpose for upsampling: (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T) -> (B, T, C)
                    local_latent_for_ssl = self.conv_upsample(local_latent_for_ssl.transpose(1, 2)).transpose(1, 2)
                else:
                    # (B, T/factor, C) -> (B, T, C)
                    local_latent_for_ssl = F.interpolate(
                        local_latent_for_ssl.transpose(1, 2),
                        size=local_ssl_features.shape[1],
                        mode=self.config.local_interpolation_mode,
                    ).transpose(1, 2)

            ssl_recon = self.feature_decoder(local_latent_for_ssl)
        else:
            # If not training feature reconstruction, just get quantized local embeddings
            local_quantized, indices = self.local_quantizer.encode(local_encoded)

        return local_quantized, indices, ssl_recon, perplexity

    def forward_global(self, global_ssl_features: torch.Tensor) -> torch.Tensor:
        """Forward pass to extract global embeddings from the global branch.

        Args:
            global_ssl_features: Global SSL features tensor of shape (B, T, C)

        Returns:
            global_encoded: Global embeddings. (B, C)
        """
        global_encoded = self.global_encoder(global_ssl_features)
        return global_encoded

    def forward_mel(
        self, content_embeddings: torch.Tensor, global_embeddings: torch.Tensor, mel_length: int
    ) -> torch.Tensor:
        """Forward pass to generate mel spectrogram from content and global embeddings.

        Args:
            content_embeddings: Content embeddings tensor of shape (B, T, C)
            global_embeddings: Global embeddings tensor of shape (B, C)
            mel_length: Target mel spectrogram length (T_mel)

        Returns:
            mel_recon: Reconstructed mel spectrogram tensor of shape (B, n_mels, T_mel)
        """
        local_latent = self.mel_prenet(content_embeddings)

        # Upsample local latent to match mel spectrogram length
        # First use Conv1DTranspose if configured
        if self.mel_conv_upsample is not None:
            # (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T*upsample_factor) -> (B, T*upsample_factor, C)
            local_latent = self.mel_conv_upsample(local_latent.transpose(1, 2)).transpose(1, 2)
        # Then interpolate to the exact target frame count
        local_latent = F.interpolate(
            local_latent.transpose(1, 2), size=mel_length, mode=self.config.mel_interpolation_mode
        ).transpose(1, 2)  # (B, T_current, C) -> (B, T_mel, C)

        # Generate mel spectrogram, conditioned on global embeddings
        mel_recon = self.mel_decoder(local_latent, condition=global_embeddings.unsqueeze(1))
        mel_recon = mel_recon.transpose(1, 2)  # (B, n_mels, T)

        mel_recon = self.mel_postnet(mel_recon)
        return mel_recon

    # ======== Inference methods ========

    def weights_to_save(self, *, include_modules: list[str]) -> dict[str, torch.Tensor]:
        """Get model weights for saving. Excludes certain modules not needed for inference.

        ``ssl_feature_extractor``, ``feature_decoder`` and ``conv_upsample`` are dropped unless named
        in ``include_modules``. Only ``named_parameters()`` are collected, so registered buffers are
        not included.
        """
        excluded_modules = [
            m for m in ["ssl_feature_extractor", "feature_decoder", "conv_upsample"] if m not in include_modules
        ]
        state_dict = {
            name: param
            for name, param in self.named_parameters()
            if not any(name.startswith(excl) for excl in excluded_modules)
        }
        return state_dict

    @classmethod
    def from_hparams(cls, config_path: str) -> "KanadeModel":
        """Instantiate KanadeModel from config file.

        Args:
            config_path (str): Path to model configuration file (.yaml).

        Returns:
            KanadeModel: Instantiated KanadeModel (or the subclass this is called on).
        """
        parser = jsonargparse.ArgumentParser(exit_on_error=False)
        parser.add_argument("--model", type=cls)
        cfg = parser.parse_path(config_path)
        cfg = parser.instantiate_classes(cfg)
        return cfg.model

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str | None = None,
        revision: str | None = None,
        config_path: str | None = None,
        weights_path: str | None = None,
    ) -> "KanadeModel":
        """Load KanadeModel either from HuggingFace Hub or local config and weights files.

        Args:
            repo_id (str, optional): HuggingFace Hub repository ID. If provided, loads config and weights from the hub.
            revision (str, optional): Revision (branch, tag, commit) for the HuggingFace Hub repo.
            config_path (str, optional): Path to model configuration file (.yaml). Required if repo_id is not provided.
            weights_path (str, optional): Path to model weights file (.safetensors). Required if repo_id is not provided.

        Returns:
            KanadeModel: Loaded KanadeModel instance.

        Raises:
            ValueError: If neither ``repo_id`` nor both local paths are given.
        """
        if repo_id is not None:
            # Load from HuggingFace Hub
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(repo_id, "config.yaml", revision=revision)
            weights_path = hf_hub_download(repo_id, "model.safetensors", revision=revision)
        else:
            # Check local paths
            if config_path is None or weights_path is None:
                raise ValueError(
                    "Please provide either HuggingFace Hub repo_id or both config_path and weights_path for model loading."
                )

        # Load model from config
        model = cls.from_hparams(config_path)

        # Load weights
        from safetensors.torch import load_file

        state_dict = load_file(weights_path, device="cpu")
        # strict=False because weights_to_save() omits some modules and all buffers; still report
        # anything that indicates a real mismatch instead of ignoring it silently.
        incompatible = model.load_state_dict(state_dict, strict=False)
        expected_absent = ("ssl_feature_extractor", "feature_decoder", "conv_upsample")
        parameter_names = {name for name, _ in model.named_parameters()}
        missing = [
            key
            for key in incompatible.missing_keys
            if key in parameter_names and not key.startswith(expected_absent)
        ]
        if missing:
            logger.warning(f"Parameters missing from weights file (left at initialization): {missing}")
        if incompatible.unexpected_keys:
            logger.warning(f"Unexpected keys in weights file (ignored): {incompatible.unexpected_keys}")
        logger.info(f"Loaded weights from safetensors file: {weights_path}")
        return model

    def encode(self, waveform: torch.Tensor, return_content: bool = True, return_global: bool = True) -> KanadeFeatures:
        """Extract content and/or global features from audio using Kanade model.

        Args:
            waveform (torch.Tensor): Input audio waveform tensor (samples,). The sample rate should match model config.
            return_content (bool): Whether to extract content features.
            return_global (bool): Whether to extract global features.

        Returns:
            KanadeFeatures: Extracted features; fields that were not requested are None.
        """
        audio_length = waveform.size(0)
        padding = self._calculate_waveform_padding(audio_length)
        local_ssl_features, global_ssl_features = self.forward_ssl_features(waveform.unsqueeze(0), padding=padding)

        result = KanadeFeatures()
        # bf16 autocast only applies to CUDA tensors; enabling it elsewhere just emits a warning.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=local_ssl_features.is_cuda):
            if return_content:
                content_embedding, token_indices, _, _ = self.forward_content(local_ssl_features)
                result.content_embedding = content_embedding.squeeze(0)  # (seq_len, dim)
                result.content_token_indices = token_indices.squeeze(0)  # (seq_len,)

            if return_global:
                global_embedding = self.forward_global(global_ssl_features)
                result.global_embedding = global_embedding.squeeze(0)  # (dim,)

        return result

    def decode_token_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Get content embeddings from content token indices. (..., seq_len) -> (..., seq_len, dim)"""
        content_embedding = self.local_quantizer.decode(indices)
        return content_embedding

    def decode(
        self,
        global_embedding: torch.Tensor,
        content_token_indices: torch.Tensor | None = None,
        content_embedding: torch.Tensor | None = None,
        target_audio_length: int | None = None,
    ) -> torch.Tensor:
        """Generate a mel spectrogram from content and global features (vocoding happens elsewhere).

        Args:
            global_embedding (torch.Tensor): Global embedding tensor (dim,).
            content_token_indices (torch.Tensor, optional): Optional content token indices tensor (seq_len).
            content_embedding (torch.Tensor, optional): Optional content embedding tensor (seq_len, dim).
                If both content_token_indices and content_embedding are provided, content_embedding takes precedence.
            target_audio_length (int, optional): Target length of the output audio in samples.
                If None, uses the original audio length estimated from the sequence length of content tokens.

        Returns:
            torch.Tensor: Generated mel spectrogram tensor (n_mels, T).

        Raises:
            ValueError: If neither content_token_indices nor content_embedding is given.
        """
        # Obtain content embedding if not provided
        if content_embedding is None:
            if content_token_indices is None:
                raise ValueError("Either content_token_indices or content_embedding must be provided.")
            content_embedding = self.decode_token_indices(content_token_indices)

        if target_audio_length is None:
            # Estimate original audio length from content token sequence length
            seq_len = content_embedding.size(0)
            target_audio_length = self._calculate_original_audio_length(seq_len)

        # bf16 autocast only applies to CUDA tensors; enabling it elsewhere just emits a warning.
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=content_embedding.is_cuda):
            mel_length = self._calculate_target_mel_length(target_audio_length)
            content_embedding = content_embedding.unsqueeze(0)  # (1, seq_len, dim)
            global_embedding = global_embedding.unsqueeze(0)  # (1, dim)
            mel_spectrogram = self.forward_mel(content_embedding, global_embedding, mel_length=mel_length)
            return mel_spectrogram.squeeze(0)  # (n_mels, T)

    def voice_conversion(self, source_waveform: torch.Tensor, reference_waveform: torch.Tensor) -> torch.Tensor:
        """Produce a voice-converted mel spectrogram: content from source, global characteristics from reference.

        Only supports single audio input. Just a convenient wrapper around encode and decode methods;
        a vocoder is still needed to turn the result into audio.

        Args:
            source_waveform (torch.Tensor): Source audio waveform tensor (samples,).
            reference_waveform (torch.Tensor): Reference audio waveform tensor (samples_ref,).

        Returns:
            torch.Tensor: Converted mel spectrogram tensor (n_mels, T).
        """
        # Extract source content features and reference global features
        source_features = self.encode(source_waveform, return_content=True, return_global=False)
        reference_features = self.encode(reference_waveform, return_content=False, return_global=True)

        # Synthesize mel spectrogram using source content and reference global features
        mel_spectrogram = self.decode(
            content_embedding=source_features.content_embedding,
            global_embedding=reference_features.global_embedding,
            target_audio_length=source_waveform.size(0),
        )
        return mel_spectrogram