Dalzymodderever
Initial Commit
2cba492
import math
from dataclasses import dataclass
from typing import Literal
import jsonargparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from .module.fsq import FiniteScalarQuantizer
from .module.global_encoder import GlobalEncoder
from .module.postnet import PostNet
from .module.ssl_extractor import SSLFeatureExtractor
from .module.transformer import Transformer
from .util import freeze_modules, get_logger
# Module-level logger for this file, obtained from the project's get_logger() helper.
logger = get_logger()
@dataclass
class KanadeModelConfig:
# SSL Feature settings
local_ssl_layers: tuple[int, ...] = (6, 9) # Indices of SSL layers for local branch
global_ssl_layers: tuple[int, ...] = (1, 2) # Indices of SSL layers for global branch
normalize_ssl_features: bool = True # Whether to normalize local SSL features before encoding
# Down/up-sampling settings
downsample_factor: int = 2 # Temporal downsampling factor for local features
mel_upsample_factor: int = 4 # Conv1DTranspose upsampling factor for mel features before interpolation
use_conv_downsample: bool = True # Whether to use Conv1D for downsampling instead average pooling
local_interpolation_mode: str = "linear" # Interpolation mode for local upsampling ("linear", "nearest")
mel_interpolation_mode: str = "linear" # Interpolation mode for mel upsampling ("linear", "nearest")
# Mel spectrogram settings
sample_rate: int = 24000
n_fft: int = 1024
hop_length: int = 256
n_mels: int = 100
padding: str = "center"
mel_fmin: int = 0 # Minimum frequency for mel spectrograms
mel_fmax: int | None = None # Maximum frequency for mel spectrograms
bigvgan_style_mel: bool = False # Whether to use BigVGAN-style mel spectrograms
# Vocoder settings
vocoder_name: Literal["vocos", "hift"] = "vocos" # Vocoder to use for waveform synthesis
@dataclass
class KanadeFeatures:
content_embedding: torch.Tensor | None = None # (seq_len, dim)
content_token_indices: torch.Tensor | None = None # (seq_len,)
global_embedding: torch.Tensor | None = None # (dim,)
class KanadeModel(nn.Module):
    """Model architecture and forward pass logic for Kanade tokenizer.

    A frozen SSL feature extractor feeds two branches:

    * local (content) branch: selected SSL layers -> ``local_encoder`` -> optional temporal
      downsampling -> ``local_quantizer`` (content embeddings + token indices), with an optional
      ``feature_decoder`` that reconstructs SSL features from the quantized latents;
    * global branch: selected SSL layers -> ``global_encoder`` -> one embedding per input.

    The mel decoder (``mel_prenet`` -> upsampling -> ``mel_decoder`` conditioned on the global
    embedding -> ``mel_postnet``) maps content + global embeddings to a mel spectrogram.
    """

    def __init__(
        self,
        config: KanadeModelConfig,
        ssl_feature_extractor: SSLFeatureExtractor,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        global_encoder: GlobalEncoder,
        mel_prenet: Transformer,
        mel_decoder: Transformer,
        mel_postnet: PostNet,
        feature_decoder: Transformer | None = None,
    ):
        """Assemble the model from its submodules.

        Args:
            config: Model hyperparameters.
            ssl_feature_extractor: SSL feature extractor; frozen during initialization.
            local_encoder: Encoder for the local (content) branch.
            local_quantizer: Quantizer producing content embeddings and token indices.
            global_encoder: Encoder for the global branch.
            mel_prenet: Network applied to content embeddings before mel upsampling.
            mel_decoder: Mel decoder, conditioned on the global embedding.
            mel_postnet: Post-network applied to the decoded mel spectrogram.
            feature_decoder: Optional decoder that reconstructs SSL features from quantized
                content latents (used by ``forward_content`` to produce ``ssl_recon``).
        """
        super().__init__()
        self.config = config
        self._init_ssl_extractor(config, ssl_feature_extractor)
        self._init_local_branch(config, local_encoder, local_quantizer, feature_decoder)
        self._init_global_branch(global_encoder)
        self._init_mel_decoder(config, mel_prenet, mel_decoder, mel_postnet)

    def _init_ssl_extractor(self, config: KanadeModelConfig, ssl_feature_extractor: SSLFeatureExtractor):
        """Store and freeze the SSL feature extractor and record the SSL layers used by each branch.

        Layer indices are 1-based (``_process_ssl_features`` reads ``features[i - 1]``).

        Raises:
            ValueError: If ``config.local_ssl_layers`` or ``config.global_ssl_layers`` is empty.
        """
        self.ssl_feature_extractor = ssl_feature_extractor
        freeze_modules([self.ssl_feature_extractor])
        logger.debug(
            f"SSL feature extractor initialized and frozen, feature dim: {self.ssl_feature_extractor.feature_dim}"
        )

        # Configure local SSL layers
        self.local_ssl_layers = list(config.local_ssl_layers)
        if not self.local_ssl_layers:
            raise ValueError("config.local_ssl_layers must contain at least one SSL layer index")
        if len(self.local_ssl_layers) > 1:
            logger.debug(
                f"Using average of {len(self.local_ssl_layers)} SSL layers for local branch: {self.local_ssl_layers}"
            )
        else:
            logger.debug(f"Using single SSL layer {self.local_ssl_layers[0]} for local branch")
        if config.normalize_ssl_features:
            logger.debug("Normalizing local SSL features before encoding")

        # Configure global SSL layers
        self.global_ssl_layers = list(config.global_ssl_layers)
        if not self.global_ssl_layers:
            raise ValueError("config.global_ssl_layers must contain at least one SSL layer index")
        if len(self.global_ssl_layers) > 1:
            logger.debug(
                f"Using average of {len(self.global_ssl_layers)} SSL layers for global branch: {self.global_ssl_layers}"
            )
        else:
            logger.debug(f"Using single SSL layer {self.global_ssl_layers[0]} for global branch")

    def _init_local_branch(
        self,
        config: KanadeModelConfig,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        feature_decoder: Transformer | None,
    ):
        """Initialize local branch components (encoder, downsampling, quantizer, decoder).

        When ``downsample_factor > 1`` and ``use_conv_downsample`` is set, a strided Conv1d
        (downsampling) and a matching ConvTranspose1d (upsampling) are created; otherwise both
        are ``None`` and average pooling / interpolation are used instead.
        """
        self.local_encoder = local_encoder
        self.local_quantizer = local_quantizer
        self.feature_decoder = feature_decoder

        # Configure downsampling
        self.downsample_factor = config.downsample_factor
        if self.downsample_factor > 1:
            logger.debug(f"Using temporal downsampling with factor {self.downsample_factor}")
            if config.use_conv_downsample:
                # Create Conv1d layers for downsampling and upsampling local embeddings
                feature_dim = local_encoder.output_dim
                self.conv_downsample = nn.Conv1d(
                    feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
                )
                self.conv_upsample = nn.ConvTranspose1d(
                    feature_dim, feature_dim, kernel_size=config.downsample_factor, stride=config.downsample_factor
                )  # won't be used unless training feature reconstruction
                logger.debug(f"Using Conv1d downsampling/upsampling with kernel size {config.downsample_factor}")
            else:
                self.conv_downsample = None
                self.conv_upsample = None
                logger.debug("Using average pooling and linear interpolation for downsampling/upsampling")
        else:
            self.conv_downsample = None
            self.conv_upsample = None

    def _init_global_branch(self, global_encoder: GlobalEncoder):
        """Initialize global branch components."""
        self.global_encoder = global_encoder

    def _init_mel_decoder(
        self, config: KanadeModelConfig, mel_prenet: Transformer, mel_decoder: Transformer, mel_postnet: PostNet
    ):
        """Initialize mel decoder components (prenet, upsampling, decoder, postnet).

        ``mel_conv_upsample`` is a ConvTranspose1d when ``mel_upsample_factor > 1``, else ``None``.
        """
        self.mel_prenet = mel_prenet
        self.mel_decoder = mel_decoder
        self.mel_postnet = mel_postnet

        # Configure mel upsampling
        self.mel_conv_upsample = None
        if config.mel_upsample_factor > 1:
            # Create Conv1DTranspose layer for mel upsampling
            input_dim = mel_prenet.output_dim
            self.mel_conv_upsample = nn.ConvTranspose1d(
                input_dim, input_dim, kernel_size=config.mel_upsample_factor, stride=config.mel_upsample_factor
            )
            logger.debug(f"Using Conv1DTranspose for mel upsampling with factor {config.mel_upsample_factor}")

    def _calculate_waveform_padding(self, audio_length: int, ensure_recon_length: bool = False) -> int:
        """Calculate required padding for input waveform to ensure consistent SSL feature lengths.

        Args:
            audio_length: Number of input samples at ``config.sample_rate``.
            ensure_recon_length: If True, round the expected SSL frame count up to a multiple of
                ``downsample_factor`` so down- then up-sampling returns the same length.

        Returns:
            Padding (in samples, at ``config.sample_rate``) to add on EACH side. A non-positive
            value means no padding is required.
        """
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate

        # SSL may resample the input to its own sample rate, so calculate the number of samples after resampling
        num_samples_after_resampling = audio_length / sample_rate * extractor.ssl_sample_rate
        # We expect the SSL feature extractor to be consistent with its hop size
        expected_ssl_output_length = math.ceil(num_samples_after_resampling / extractor.hop_size)
        # If ensure_recon_length is True, we want to make sure the output length is exactly divisible by downsample factor
        if ensure_recon_length and (remainder := expected_ssl_output_length % self.downsample_factor) != 0:
            expected_ssl_output_length += self.downsample_factor - remainder
        # But it may require more input samples to produce that output length, so calculate the required input length
        num_samples_required_after_resampling = extractor.get_minimum_input_length(expected_ssl_output_length)
        # That number of samples is at the SSL sample rate, so convert back to our original sample rate
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        # Calculate padding needed on each side
        padding = math.ceil((num_samples_required - audio_length) / 2)
        return padding

    def _calculate_original_audio_length(self, token_length: int) -> int:
        """Estimate the audio length (samples at ``config.sample_rate``) that yields ``token_length`` tokens."""
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate

        # Calculate the feature length before downsampling
        feature_length = token_length * self.downsample_factor
        num_samples_required_after_resampling = extractor.get_minimum_input_length(feature_length)
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        return math.ceil(num_samples_required)

    def _calculate_target_mel_length(self, audio_length: int) -> int:
        """Calculate the number of mel frames for ``audio_length`` samples under ``config.padding``.

        "center" -> ``T // hop + 1``; "same" -> ``T // hop``; anything else (no padding)
        -> ``(T - n_fft) // hop + 1``.
        """
        if self.config.padding == "center":
            return audio_length // self.config.hop_length + 1
        elif self.config.padding == "same":
            return audio_length // self.config.hop_length
        else:
            return (audio_length - self.config.n_fft) // self.config.hop_length + 1

    def _process_ssl_features(self, features: list[torch.Tensor], layers: list[int]) -> torch.Tensor:
        """Select the given 1-based SSL layers from ``features`` and average them if there are several."""
        if len(layers) > 1:
            # Get features from multiple layers and average them
            selected_features = [features[i - 1] for i in layers]
            mixed_features = torch.stack(selected_features, dim=0).mean(dim=0)
        else:
            # Just take the single specified layer
            mixed_features = features[layers[0] - 1]
        return mixed_features

    def _normalize_ssl_features(self, features: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Standardize features over the time axis (dim 1) when ``config.normalize_ssl_features`` is set.

        NOTE: ``torch.std`` uses the unbiased estimator here, so a single-frame input yields NaN.
        """
        if not self.config.normalize_ssl_features:
            return features

        # Compute mean and std across time steps for each sample and feature dimension
        mean = torch.mean(features, dim=1, keepdim=True)  # (B, 1, C)
        std = torch.std(features, dim=1, keepdim=True)  # (B, 1, C)
        return (features - mean) / (std + eps)

    def forward_ssl_features(
        self, waveform: torch.Tensor, padding: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to extract SSL features. (B, T, C)

        Args:
            waveform: Input waveform tensor of shape (B, channels, samples) or (B, samples).
                A 3-D input has its channel dim squeezed (a no-op unless channels == 1).
            padding: Optional padding to apply on both sides of the waveform. This is useful to ensure
                that the SSL feature extractor produces consistent output lengths. ``None`` or a
                non-positive value means no padding.

        Returns:
            local_ssl_features: Local SSL features for local branch. (B, T, C)
            global_ssl_features: Global SSL features for global branch. (B, T, C)
        """
        # Prepare input waveform
        if waveform.dim() == 3:
            waveform = waveform.squeeze(1)

        # 1. Extract SSL features
        # `padding` defaults to None, which must not be compared with `>`.
        if padding is not None and padding > 0:
            waveform = F.pad(waveform, (padding, padding), mode="constant")

        # The extractor is frozen; no gradients are needed through it.
        with torch.no_grad():
            ssl_features = self.ssl_feature_extractor(waveform)

        local_ssl_features = self._process_ssl_features(ssl_features, self.local_ssl_layers)
        local_ssl_features = self._normalize_ssl_features(local_ssl_features)
        global_ssl_features = self._process_ssl_features(ssl_features, self.global_ssl_layers)

        return local_ssl_features, global_ssl_features

    def forward_content(
        self, local_ssl_features: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor]:
        """Forward pass to extract content embeddings from the local branch.

        Args:
            local_ssl_features: Local SSL features tensor of shape (B, T, C)

        Returns:
            local_quantized: Quantized local embeddings. (B, T/factor, C)
            indices: Content token indices. (B, T/factor)
            ssl_recon: Reconstructed SSL features if a feature decoder is present, else None. (B, T, C)
            perplexity: Quantizer perplexity if a feature decoder is present, else 0. Scalar tensor.
        """
        local_encoded = self.local_encoder(local_ssl_features)

        # Downsample temporally if needed: (B, T, C) -> (B, T/factor, C)
        if self.downsample_factor > 1:
            if self.config.use_conv_downsample:
                local_encoded = self.conv_downsample(local_encoded.transpose(1, 2)).transpose(1, 2)
            else:
                local_encoded = F.avg_pool1d(
                    local_encoded.transpose(1, 2), kernel_size=self.downsample_factor, stride=self.downsample_factor
                ).transpose(1, 2)

        # If training feature reconstruction, decode local embeddings
        ssl_recon = None
        # Placeholder on the same device as the activations so callers can combine it with them.
        perplexity = torch.tensor(0.0, device=local_encoded.device)
        if self.feature_decoder is not None:
            local_quantized, local_quantize_info = self.local_quantizer(local_encoded)
            indices = local_quantize_info["indices"]
            perplexity = torch.mean(local_quantize_info["perplexity"])

            local_latent_for_ssl = local_quantized
            # Upsample if needed
            if self.downsample_factor > 1:
                if self.config.use_conv_downsample:
                    # Use conv transpose for upsampling: (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T) -> (B, T, C)
                    # NOTE: output length is (T // factor) * factor; matches T only when T is divisible by factor
                    # (see `_calculate_waveform_padding(..., ensure_recon_length=True)`).
                    local_latent_for_ssl = self.conv_upsample(local_latent_for_ssl.transpose(1, 2)).transpose(1, 2)
                else:
                    # (B, T/factor, C) -> (B, T, C)
                    local_latent_for_ssl = F.interpolate(
                        local_latent_for_ssl.transpose(1, 2),
                        size=local_ssl_features.shape[1],
                        mode=self.config.local_interpolation_mode,
                    ).transpose(1, 2)

            ssl_recon = self.feature_decoder(local_latent_for_ssl)
        else:
            # If not training feature reconstruction, just get quantized local embeddings
            local_quantized, indices = self.local_quantizer.encode(local_encoded)

        return local_quantized, indices, ssl_recon, perplexity

    def forward_global(self, global_ssl_features: torch.Tensor) -> torch.Tensor:
        """Forward pass to extract global embeddings from the global branch.

        Args:
            global_ssl_features: Global SSL features tensor of shape (B, T, C)

        Returns:
            global_encoded: Global embeddings. (B, C)
        """
        global_encoded = self.global_encoder(global_ssl_features)
        return global_encoded

    def forward_mel(
        self, content_embeddings: torch.Tensor, global_embeddings: torch.Tensor, mel_length: int
    ) -> torch.Tensor:
        """Forward pass to generate mel spectrogram from content and global embeddings.

        Args:
            content_embeddings: Content embeddings tensor of shape (B, T, C)
            global_embeddings: Global embeddings tensor of shape (B, C)
            mel_length: Target mel spectrogram length (T_mel)

        Returns:
            mel_recon: Reconstructed mel spectrogram tensor of shape (B, n_mels, T_mel)
        """
        local_latent = self.mel_prenet(content_embeddings)

        # Upsample local latent to match mel spectrogram length
        # First use Conv1DTranspose if configured
        if self.mel_conv_upsample is not None:
            # (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T*upsample_factor) -> (B, T*upsample_factor, C)
            local_latent = self.mel_conv_upsample(local_latent.transpose(1, 2)).transpose(1, 2)
        # Then interpolate to the exact target length
        local_latent = F.interpolate(
            local_latent.transpose(1, 2), size=mel_length, mode=self.config.mel_interpolation_mode
        ).transpose(1, 2)  # (B, T_current, C) -> (B, T_mel, C)

        # Generate mel spectrogram, conditioned on global embeddings
        mel_recon = self.mel_decoder(local_latent, condition=global_embeddings.unsqueeze(1))
        mel_recon = mel_recon.transpose(1, 2)  # (B, n_mels, T)

        mel_recon = self.mel_postnet(mel_recon)
        return mel_recon

    # ======== Inference methods ========

    def weights_to_save(self, *, include_modules: list[str]) -> dict[str, torch.Tensor]:
        """Get model weights for saving. Excludes certain modules not needed for inference.

        ``ssl_feature_extractor``, ``feature_decoder`` and ``conv_upsample`` parameters are dropped
        unless named in ``include_modules``. Only parameters are returned (``named_parameters``);
        buffers are not included.
        """
        excluded_modules = [
            m for m in ["ssl_feature_extractor", "feature_decoder", "conv_upsample"] if m not in include_modules
        ]
        state_dict = {
            name: param
            for name, param in self.named_parameters()
            if not any(name.startswith(excl) for excl in excluded_modules)
        }
        return state_dict

    @classmethod
    def from_hparams(cls, config_path: str) -> "KanadeModel":
        """Instantiate KanadeModel from config file.

        Args:
            config_path (str): Path to model configuration file (.yaml).

        Returns:
            KanadeModel: Instantiated model (an instance of ``cls``, so subclasses are supported).
        """
        parser = jsonargparse.ArgumentParser(exit_on_error=False)
        parser.add_argument("--model", type=cls)
        cfg = parser.parse_path(config_path)
        cfg = parser.instantiate_classes(cfg)
        return cfg.model

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str | None = None,
        revision: str | None = None,
        config_path: str | None = None,
        weights_path: str | None = None,
    ) -> "KanadeModel":
        """Load KanadeModel either from HuggingFace Hub or local config and weights files.

        Args:
            repo_id (str, optional): HuggingFace Hub repository ID. If provided, loads config and weights from the hub.
            revision (str, optional): Revision (branch, tag, commit) for the HuggingFace Hub repo.
            config_path (str, optional): Path to model configuration file (.yaml). Required if repo_id is not provided.
            weights_path (str, optional): Path to model weights file (.safetensors). Required if repo_id is not provided.

        Returns:
            KanadeModel: Loaded KanadeModel instance.

        Raises:
            ValueError: If neither ``repo_id`` nor both local paths are provided.
        """
        if repo_id is not None:
            # Load from HuggingFace Hub
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(repo_id, "config.yaml", revision=revision)
            weights_path = hf_hub_download(repo_id, "model.safetensors", revision=revision)
        else:
            # Check local paths
            if config_path is None or weights_path is None:
                raise ValueError(
                    "Please provide either HuggingFace Hub repo_id or both config_path and weights_path for model loading."
                )

        # Load model from config
        model = cls.from_hparams(config_path)

        # Load weights. strict=False because saved weights intentionally omit some modules
        # (see `weights_to_save`); report mismatches instead of silently discarding them.
        from safetensors.torch import load_file

        state_dict = load_file(weights_path, device="cpu")
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.unexpected_keys:
            logger.warning(f"Ignoring unexpected keys in weights file: {load_result.unexpected_keys}")
        if load_result.missing_keys:
            logger.debug(f"{len(load_result.missing_keys)} model keys not present in weights file (left as initialized)")
        logger.info(f"Loaded weights from safetensors file: {weights_path}")

        return model

    @torch.inference_mode()
    def encode(self, waveform: torch.Tensor, return_content: bool = True, return_global: bool = True) -> KanadeFeatures:
        """Extract content and/or global features from audio using Kanade model.

        Args:
            waveform (torch.Tensor): Input audio waveform tensor (samples,). The sample rate should match model config.
            return_content (bool): Whether to extract content features.
            return_global (bool): Whether to extract global features.

        Returns:
            KanadeFeatures: Extracted features; fields that were not requested are left as None.
        """
        audio_length = waveform.size(0)
        padding = self._calculate_waveform_padding(audio_length)
        local_ssl_features, global_ssl_features = self.forward_ssl_features(waveform.unsqueeze(0), padding=padding)

        result = KanadeFeatures()
        # The bf16 autocast context targets CUDA only; requesting it without CUDA just warns and disables
        # itself, so enable it only when the features actually live on a CUDA device.
        use_autocast = local_ssl_features.device.type == "cuda"
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_autocast):
            if return_content:
                content_embedding, token_indices, _, _ = self.forward_content(local_ssl_features)
                result.content_embedding = content_embedding.squeeze(0)  # (seq_len, dim)
                result.content_token_indices = token_indices.squeeze(0)  # (seq_len,)

            if return_global:
                global_embedding = self.forward_global(global_ssl_features)
                result.global_embedding = global_embedding.squeeze(0)  # (dim,)

        return result

    def decode_token_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Get content embeddings from content token indices. (..., seq_len) -> (..., seq_len, dim)"""
        content_embedding = self.local_quantizer.decode(indices)
        return content_embedding

    @torch.inference_mode()
    def decode(
        self,
        global_embedding: torch.Tensor,
        content_token_indices: torch.Tensor | None = None,
        content_embedding: torch.Tensor | None = None,
        target_audio_length: int | None = None,
    ) -> torch.Tensor:
        """Synthesize a mel spectrogram from content and global features (vocoding is done separately).

        Args:
            global_embedding (torch.Tensor): Global embedding tensor (dim,).
            content_token_indices (torch.Tensor, optional): Optional content token indices tensor (seq_len).
            content_embedding (torch.Tensor, optional): Optional content embedding tensor (seq_len, dim).
                If both content_token_indices and content_embedding are provided, content_embedding takes precedence.
            target_audio_length (int, optional): Target length of the output audio in samples.
                If None, uses the original audio length estimated from the sequence length of content tokens.

        Returns:
            torch.Tensor: Generated mel spectrogram tensor (n_mels, T).

        Raises:
            ValueError: If neither content_token_indices nor content_embedding is provided.
        """
        # Obtain content embedding if not provided
        if content_embedding is None:
            if content_token_indices is None:
                raise ValueError("Either content_token_indices or content_embedding must be provided.")
            content_embedding = self.decode_token_indices(content_token_indices)

        if target_audio_length is None:
            # Estimate original audio length from content token sequence length
            seq_len = content_embedding.size(0)
            target_audio_length = self._calculate_original_audio_length(seq_len)

        # See `encode`: only enable the CUDA bf16 autocast when running on CUDA.
        use_autocast = content_embedding.device.type == "cuda"
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_autocast):
            mel_length = self._calculate_target_mel_length(target_audio_length)
            content_embedding = content_embedding.unsqueeze(0)  # (1, seq_len, dim)
            global_embedding = global_embedding.unsqueeze(0)  # (1, dim)
            mel_spectrogram = self.forward_mel(content_embedding, global_embedding, mel_length=mel_length)
            return mel_spectrogram.squeeze(0)  # (n_mels, T)

    @torch.inference_mode()
    def voice_conversion(self, source_waveform: torch.Tensor, reference_waveform: torch.Tensor) -> torch.Tensor:
        """Convert voice using Kanade model, keeping content from source and global characteristics from reference.

        Only supports single audio input. Just a convenient wrapper around encode and decode methods.

        Args:
            source_waveform (torch.Tensor): Source audio waveform tensor (samples,).
            reference_waveform (torch.Tensor): Reference audio waveform tensor (samples_ref,).

        Returns:
            torch.Tensor: Converted mel spectrogram tensor (n_mels, T), sized for the source length.
        """
        # Extract source content features and reference global features
        source_features = self.encode(source_waveform, return_content=True, return_global=False)
        reference_features = self.encode(reference_waveform, return_content=False, return_global=True)

        # Synthesize mel spectrogram using source content and reference global features
        mel_spectrogram = self.decode(
            content_embedding=source_features.content_embedding,
            global_embedding=reference_features.global_embedding,
            target_audio_length=source_waveform.size(0),
        )
        return mel_spectrogram