File size: 24,747 Bytes
2cba492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
import math
from dataclasses import dataclass
from typing import Literal

import jsonargparse
import torch
import torch.nn as nn
import torch.nn.functional as F

from .module.fsq import FiniteScalarQuantizer
from .module.global_encoder import GlobalEncoder
from .module.postnet import PostNet
from .module.ssl_extractor import SSLFeatureExtractor
from .module.transformer import Transformer
from .util import freeze_modules, get_logger

logger = get_logger()


@dataclass
class KanadeModelConfig:
    """Configuration values read by ``KanadeModel``.

    Covers which SSL layers feed each branch, how local features are down/up-sampled
    in time, the mel spectrogram framing parameters and the vocoder selection.
    """

    # SSL Feature settings
    # Layer indices are 1-based: the model selects ``features[i - 1]`` for each index ``i``,
    # and averages the selected layers when more than one index is given.
    local_ssl_layers: tuple[int, ...] = (6, 9)  # Indices of SSL layers for local branch
    global_ssl_layers: tuple[int, ...] = (1, 2)  # Indices of SSL layers for global branch
    normalize_ssl_features: bool = True  # Whether to normalize local SSL features before encoding

    # Down/up-sampling settings
    downsample_factor: int = 2  # Temporal downsampling factor for local features
    mel_upsample_factor: int = 4  # Conv1DTranspose upsampling factor for mel features before interpolation
    use_conv_downsample: bool = True  # Whether to use Conv1D for downsampling instead of average pooling
    local_interpolation_mode: str = "linear"  # Interpolation mode for local upsampling ("linear", "nearest")
    mel_interpolation_mode: str = "linear"  # Interpolation mode for mel upsampling ("linear", "nearest")

    # Mel spectrogram settings
    sample_rate: int = 24000  # Sample rate of the waveforms given to the model
    n_fft: int = 1024  # FFT size; used for the mel length when padding is neither "center" nor "same"
    hop_length: int = 256  # Hop between mel frames, in samples
    n_mels: int = 100
    padding: str = "center"  # "center" and "same" are handled explicitly; any other value uses the n_fft-based length
    mel_fmin: int = 0  # Minimum frequency for mel spectrograms
    mel_fmax: int | None = None  # Maximum frequency for mel spectrograms
    bigvgan_style_mel: bool = False  # Whether to use BigVGAN-style mel spectrograms

    # Vocoder settings
    vocoder_name: Literal["vocos", "hift"] = "vocos"  # Vocoder to use for waveform synthesis


@dataclass
class KanadeFeatures:
    """Features produced by ``KanadeModel.encode`` for a single (unbatched) waveform.

    Every field defaults to None; ``encode`` only fills the fields that were requested.
    """

    content_embedding: torch.Tensor | None = None  # (seq_len, dim)
    content_token_indices: torch.Tensor | None = None  # (seq_len,)
    global_embedding: torch.Tensor | None = None  # (dim,)


class KanadeModel(nn.Module):
    """Model architecture and forward pass logic for Kanade tokenizer."""

    def __init__(
        self,
        config: KanadeModelConfig,
        ssl_feature_extractor: SSLFeatureExtractor,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        global_encoder: GlobalEncoder,
        mel_prenet: Transformer,
        mel_decoder: Transformer,
        mel_postnet: PostNet,
        feature_decoder: Transformer | None = None,
    ):
        """Assemble the model from pre-built submodules.
        Args:
            config: Model hyperparameters.
            ssl_feature_extractor: SSL feature extractor; it is passed to ``freeze_modules`` during setup.
            local_encoder: Encoder applied to the local SSL features.
            local_quantizer: Quantizer applied to the (optionally downsampled) local encoding.
            global_encoder: Encoder that turns global SSL features into the global embedding.
            mel_prenet: Network applied to content embeddings before mel upsampling.
            mel_decoder: Decoder that produces mel frames, conditioned on the global embedding.
            mel_postnet: Network applied to the decoded mel spectrogram.
            feature_decoder: Optional decoder that reconstructs SSL features from the quantized
                local embeddings. When None, ``forward_content`` skips the reconstruction.
        """
        super().__init__()
        self.config = config
        # The helpers below register submodules; their call order fixes the registration order.
        self._init_ssl_extractor(config, ssl_feature_extractor)
        self._init_local_branch(config, local_encoder, local_quantizer, feature_decoder)
        self._init_global_branch(global_encoder)
        self._init_mel_decoder(config, mel_prenet, mel_decoder, mel_postnet)

    def _init_ssl_extractor(self, config: KanadeModelConfig, ssl_feature_extractor: SSLFeatureExtractor):
        """Register the SSL feature extractor, freeze it and record the layer selection.
        Args:
            config: Model configuration providing the local/global SSL layer indices.
            ssl_feature_extractor: Extractor module to register and pass to ``freeze_modules``.
        """
        extractor = ssl_feature_extractor
        self.ssl_feature_extractor = extractor
        freeze_modules([extractor])
        logger.debug(f"SSL feature extractor initialized and frozen, feature dim: {extractor.feature_dim}")

        # Layers feeding the local (content) branch
        local_layers = list(config.local_ssl_layers)
        self.local_ssl_layers = local_layers
        local_message = (
            f"Using average of {len(local_layers)} SSL layers for local branch: {local_layers}"
            if len(local_layers) > 1
            else f"Using single SSL layer {local_layers[0]} for local branch"
        )
        logger.debug(local_message)

        if config.normalize_ssl_features:
            logger.debug("Normalizing local SSL features before encoding")

        # Layers feeding the global branch
        global_layers = list(config.global_ssl_layers)
        self.global_ssl_layers = global_layers
        global_message = (
            f"Using average of {len(global_layers)} SSL layers for global branch: {global_layers}"
            if len(global_layers) > 1
            else f"Using single SSL layer {global_layers[0]} for global branch"
        )
        logger.debug(global_message)

    def _init_local_branch(
        self,
        config: KanadeModelConfig,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        feature_decoder: Transformer | None,
    ):
        """Register the local branch modules and set up temporal down/up-sampling.
        Args:
            config: Model configuration (downsample factor and conv/pooling choice).
            local_encoder: Encoder for local SSL features; its ``output_dim`` sizes the conv layers.
            local_quantizer: Quantizer for the encoded local features.
            feature_decoder: Optional SSL feature decoder (may be None).
        """
        self.local_encoder = local_encoder
        self.local_quantizer = local_quantizer
        self.feature_decoder = feature_decoder

        factor = config.downsample_factor
        self.downsample_factor = factor

        # No temporal downsampling at all: no conv layers needed
        if factor <= 1:
            self.conv_downsample = None
            self.conv_upsample = None
            return

        logger.debug(f"Using temporal downsampling with factor {factor}")

        # Pooling/interpolation variant: nothing to learn, so no conv layers either
        if not config.use_conv_downsample:
            self.conv_downsample = None
            self.conv_upsample = None
            logger.debug("Using average pooling and linear interpolation for downsampling/upsampling")
            return

        # Strided conv for downsampling and its transposed counterpart for upsampling
        channels = local_encoder.output_dim
        self.conv_downsample = nn.Conv1d(channels, channels, kernel_size=factor, stride=factor)
        # The upsampling layer is only exercised when training feature reconstruction
        self.conv_upsample = nn.ConvTranspose1d(channels, channels, kernel_size=factor, stride=factor)
        logger.debug(f"Using Conv1d downsampling/upsampling with kernel size {factor}")

    def _init_global_branch(self, global_encoder: GlobalEncoder):
        """Initialize global branch components.
        Args:
            global_encoder: Encoder used by ``forward_global``; stored as ``self.global_encoder``.
        """
        self.global_encoder = global_encoder

    def _init_mel_decoder(
        self, config: KanadeModelConfig, mel_prenet: Transformer, mel_decoder: Transformer, mel_postnet: PostNet
    ):
        """Register the mel decoding modules and the optional learned mel upsampler.
        Args:
            config: Model configuration providing ``mel_upsample_factor``.
            mel_prenet: Network applied to content embeddings; its ``output_dim`` sizes the upsampler.
            mel_decoder: Decoder producing mel frames.
            mel_postnet: Network applied to the decoded mel spectrogram.
        """
        self.mel_prenet = mel_prenet
        self.mel_decoder = mel_decoder
        self.mel_postnet = mel_postnet

        factor = config.mel_upsample_factor
        self.mel_conv_upsample = None
        if factor <= 1:
            # Interpolation alone handles the upsampling in forward_mel
            return

        channels = mel_prenet.output_dim
        self.mel_conv_upsample = nn.ConvTranspose1d(channels, channels, kernel_size=factor, stride=factor)
        logger.debug(f"Using Conv1DTranspose for mel upsampling with factor {config.mel_upsample_factor}")

    def _calculate_waveform_padding(self, audio_length: int, ensure_recon_length: bool = False) -> int:
        """Calculate required padding for input waveform to ensure consistent SSL feature lengths."""
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate
        # SSL may resample the input to its own sample rate, so calculate the number of samples after resampling
        num_samples_after_resampling = audio_length / sample_rate * extractor.ssl_sample_rate
        # We expect the SSL feature extractor to be consistent with its hop size
        expected_ssl_output_length = math.ceil(num_samples_after_resampling / extractor.hop_size)
        # If ensure_recon_length is True, we want to make sure the output length is exactly divisible by downsample factor
        if ensure_recon_length and (remainder := expected_ssl_output_length % self.downsample_factor) != 0:
            expected_ssl_output_length += self.downsample_factor - remainder
        # But it may require more input samples to produce that output length, so calculate the required input length
        num_samples_required_after_resampling = extractor.get_minimum_input_length(expected_ssl_output_length)
        # That number of samples is at the SSL sample rate, so convert back to our original sample rate
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        # Calculate padding needed on each side
        padding = math.ceil((num_samples_required - audio_length) / 2)
        return padding

    def _calculate_original_audio_length(self, token_length: int) -> int:
        """Calculate the original audio length based on token length."""
        extractor = self.ssl_feature_extractor
        sample_rate = self.config.sample_rate
        # Calculate the feature length before downsampling
        feature_length = token_length * self.downsample_factor
        num_samples_required_after_resampling = extractor.get_minimum_input_length(feature_length)
        num_samples_required = num_samples_required_after_resampling / extractor.ssl_sample_rate * sample_rate
        return math.ceil(num_samples_required)

    def _calculate_target_mel_length(self, audio_length: int) -> int:
        """Calculate the target mel spectrogram length based on audio length."""
        if self.config.padding == "center":
            return audio_length // self.config.hop_length + 1
        elif self.config.padding == "same":
            return audio_length // self.config.hop_length
        else:
            return (audio_length - self.config.n_fft) // self.config.hop_length + 1

    def _process_ssl_features(self, features: list[torch.Tensor], layers: list[int]) -> torch.Tensor:
        if len(layers) > 1:
            # Get features from multiple layers and average them
            selected_features = [features[i - 1] for i in layers]
            mixed_features = torch.stack(selected_features, dim=0).mean(dim=0)
        else:
            # Just take the single specified layer
            mixed_features = features[layers[0] - 1]
        return mixed_features

    def _normalize_ssl_features(self, features: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        if not self.config.normalize_ssl_features:
            return features

        # Compute mean and std across time steps for each sample and feature dimension
        mean = torch.mean(features, dim=1, keepdim=True)  # (B, 1, C)
        std = torch.std(features, dim=1, keepdim=True)  # (B, 1, C)
        return (features - mean) / (std + eps)

    def forward_ssl_features(
        self, waveform: torch.Tensor, padding: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass to extract SSL features. (B, T, C)
        Args:
            waveform: Input waveform tensor of shape (B, channels, samples)
            padding: Optional padding to apply on both sides of the waveform. This is useful to ensure
                     that the SSL feature extractor produces consistent output lengths.
        Returns:
            local_ssl_features: Local SSL features for local branch. (B, T, C)
            global_ssl_features: Global SSL features for global branch. (B, T, C)
        """
        # Prepare input waveform
        if waveform.dim() == 3:
            waveform = waveform.squeeze(1)

        # 1. Extract SSL features
        if padding > 0:
            waveform = F.pad(waveform, (padding, padding), mode="constant")

        with torch.no_grad():
            ssl_features = self.ssl_feature_extractor(waveform)

        local_ssl_features = self._process_ssl_features(ssl_features, self.local_ssl_layers)
        local_ssl_features = self._normalize_ssl_features(local_ssl_features)

        global_ssl_features = self._process_ssl_features(ssl_features, self.global_ssl_layers)

        return local_ssl_features, global_ssl_features

    def forward_content(
        self, local_ssl_features: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] | None:
        """Forward pass to extract content embeddings from the local branch.
        Args:
            local_ssl_features: Local SSL features tensor of shape (B, T, C)
        Returns:
            local_quantized: Quantized local embeddings. (B, T/factor, C)
            indices: Content token indices. (B, T/factor)
            ssl_recon: Reconstructed SSL features (if feature decoder is present). (B, T, C)
            perplexity: Quantizer perplexity (if feature decoder is present). Scalar tensor.
        """
        local_encoded = self.local_encoder(local_ssl_features)

        # Downsample temporally if needed: (B, T, C) -> (B, T/factor, C)
        if self.downsample_factor > 1:
            if self.config.use_conv_downsample:
                local_encoded = self.conv_downsample(local_encoded.transpose(1, 2)).transpose(1, 2)
            else:
                local_encoded = F.avg_pool1d(
                    local_encoded.transpose(1, 2), kernel_size=self.downsample_factor, stride=self.downsample_factor
                ).transpose(1, 2)

        # If training feature reconstruction, decode local embeddings
        ssl_recon = None
        perplexity = torch.tensor(0.0)
        if self.feature_decoder is not None:
            local_quantized, local_quantize_info = self.local_quantizer(local_encoded)
            indices = local_quantize_info["indices"]
            perplexity = torch.mean(local_quantize_info["perplexity"])

            local_latent_for_ssl = local_quantized
            # Upsample if needed
            if self.downsample_factor > 1:
                if self.config.use_conv_downsample:
                    # Use conv transpose for upsampling: (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T) -> (B, T, C)
                    local_latent_for_ssl = self.conv_upsample(local_latent_for_ssl.transpose(1, 2)).transpose(1, 2)
                else:
                    # (B, T/factor, C) -> (B, T, C)
                    local_latent_for_ssl = F.interpolate(
                        local_latent_for_ssl.transpose(1, 2),
                        size=local_ssl_features.shape[1],
                        mode=self.config.local_interpolation_mode,
                    ).transpose(1, 2)

            ssl_recon = self.feature_decoder(local_latent_for_ssl)
        else:
            # If not training feature reconstruction, just get quantized local embeddings
            local_quantized, indices = self.local_quantizer.encode(local_encoded)

        return local_quantized, indices, ssl_recon, perplexity

    def forward_global(self, global_ssl_features: torch.Tensor) -> torch.Tensor:
        """Forward pass to extract global embeddings from the global branch.
        Args:
            global_ssl_features: Global SSL features tensor of shape (B, T, C)
        Returns:
            global_encoded: Global embeddings. (B, C)
        """
        global_encoded = self.global_encoder(global_ssl_features)
        return global_encoded

    def forward_mel(
        self, content_embeddings: torch.Tensor, global_embeddings: torch.Tensor, mel_length: int
    ) -> torch.Tensor:
        """Forward pass to generate mel spectrogram from content and global embeddings.
        Args:
            content_embeddings: Content embeddings tensor of shape (B, T, C)
            global_embeddings: Global embeddings tensor of shape (B, C)
            mel_length: Target mel spectrogram length (T_mel)
        Returns:
            mel_recon: Reconstructed mel spectrogram tensor of shape (B, n_mels, T_mel)
        """
        local_latent = self.mel_prenet(content_embeddings)

        # Upsample local latent to match mel spectrogram length
        # First use Conv1DTranspose if configured
        if self.mel_conv_upsample is not None:
            # (B, T/factor, C) -> (B, C, T/factor) -> conv -> (B, C, T*upsample_factor) -> (B, T*upsample_factor, C)
            local_latent = self.mel_conv_upsample(local_latent.transpose(1, 2)).transpose(1, 2)
        local_latent = F.interpolate(
            local_latent.transpose(1, 2), size=mel_length, mode=self.config.mel_interpolation_mode
        ).transpose(1, 2)  # (B, T_current, C) -> (B, T_mel, C)

        # Generate mel spectrogram, conditioned on global embeddings
        mel_recon = self.mel_decoder(local_latent, condition=global_embeddings.unsqueeze(1))
        mel_recon = mel_recon.transpose(1, 2)  # (B, n_mels, T)

        mel_recon = self.mel_postnet(mel_recon)
        return mel_recon

    # ======== Inference methods ========

    def weights_to_save(self, *, include_modules: list[str]) -> dict[str, torch.Tensor]:
        """Get model weights for saving. Excludes certain modules not needed for inference."""
        excluded_modules = [
            m for m in ["ssl_feature_extractor", "feature_decoder", "conv_upsample"] if m not in include_modules
        ]
        state_dict = {
            name: param
            for name, param in self.named_parameters()
            if not any(name.startswith(excl) for excl in excluded_modules)
        }
        return state_dict

    @classmethod
    def from_hparams(cls, config_path: str) -> "KanadeModel":
        """Build a KanadeModel from a configuration file via jsonargparse.
        Args:
            config_path (str): Path to model configuration file (.yaml).
        Returns:
            KanadeModel: The model instantiated from the ``model`` entry of the configuration.
        """
        arg_parser = jsonargparse.ArgumentParser(exit_on_error=False)
        arg_parser.add_argument("--model", type=KanadeModel)
        # Parse the file first, then turn the parsed class specs into live objects
        parsed = arg_parser.parse_path(config_path)
        instantiated = arg_parser.instantiate_classes(parsed)
        return instantiated.model

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str | None = None,
        revision: str | None = None,
        config_path: str | None = None,
        weights_path: str | None = None,
    ) -> "KanadeModel":
        """Create a KanadeModel and load its weights from HuggingFace Hub or from local files.
        Args:
            repo_id (str, optional): HuggingFace Hub repository ID. If provided, "config.yaml" and
                "model.safetensors" are downloaded from the hub and any local paths are ignored.
            revision (str, optional): Revision (branch, tag, commit) for the HuggingFace Hub repo.
            config_path (str, optional): Path to model configuration file (.yaml). Required if repo_id is not provided.
            weights_path (str, optional): Path to model weights file (.safetensors). Required if repo_id is not provided.
        Returns:
            KanadeModel: Loaded KanadeModel instance.
        Raises:
            ValueError: If repo_id is not given and config_path or weights_path is missing.
        """
        if repo_id is None:
            # Local loading needs both files
            if config_path is None or weights_path is None:
                raise ValueError(
                    "Please provide either HuggingFace Hub repo_id or both config_path and weights_path for model loading."
                )
        else:
            # Hub loading: fetch both files (imported lazily so the dependency stays optional)
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(repo_id, "config.yaml", revision=revision)
            weights_path = hf_hub_download(repo_id, "model.safetensors", revision=revision)

        # Build the architecture from the configuration
        model = cls.from_hparams(config_path)

        from safetensors.torch import load_file

        # Non-strict loading: keys missing from the file keep their current values
        weights = load_file(weights_path, device="cpu")
        model.load_state_dict(weights, strict=False)
        logger.info(f"Loaded weights from safetensors file: {weights_path}")

        return model

    @torch.inference_mode()
    def encode(self, waveform: torch.Tensor, return_content: bool = True, return_global: bool = True) -> KanadeFeatures:
        """Extract content and/or global features from audio using Kanade model.
        Args:
            waveform (torch.Tensor): Input audio waveform tensor (samples,). The sample rate should match model config.
            return_content (bool): Whether to extract content features.
            return_global (bool): Whether to extract global features.
        Returns:
            KanadeFeatures: Extracted features. Fields that were not requested are left as None.
        """
        audio_length = waveform.size(0)
        padding = self._calculate_waveform_padding(audio_length)
        local_ssl_features, global_ssl_features = self.forward_ssl_features(waveform.unsqueeze(0), padding=padding)

        result = KanadeFeatures()
        # CUDA autocast only affects CUDA ops, so only request it for CUDA inputs; asking for it
        # on a CPU-only host just makes torch warn and disable it on every call.
        use_autocast = waveform.is_cuda
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_autocast):
            if return_content:
                content_embedding, token_indices, _, _ = self.forward_content(local_ssl_features)
                result.content_embedding = content_embedding.squeeze(0)  # (seq_len, dim)
                result.content_token_indices = token_indices.squeeze(0)  # (seq_len,)

            if return_global:
                global_embedding = self.forward_global(global_ssl_features)
                result.global_embedding = global_embedding.squeeze(0)  # (dim,)

        return result

    def decode_token_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Get content embeddings from content token indices. (..., seq_len) -> (..., seq_len, dim)"""
        content_embedding = self.local_quantizer.decode(indices)
        return content_embedding

    @torch.inference_mode()
    def decode(
        self,
        global_embedding: torch.Tensor,
        content_token_indices: torch.Tensor | None = None,
        content_embedding: torch.Tensor | None = None,
        target_audio_length: int | None = None,
    ) -> torch.Tensor:
        """Synthesize audio from content and global features using Kanade model and Vocos.
        Args:
            global_embedding (torch.Tensor): Global embedding tensor (dim,).
            content_token_indices (torch.Tensor, optional): Optional content token indices tensor (seq_len).
            content_embedding (torch.Tensor, optional): Optional content embedding tensor (seq_len, dim).
                If both content_token_indices and content_embedding are provided, content_embedding takes precedence.
            target_audio_length (int, optional): Target length of the output audio in samples.
                If None, uses the original audio length estimated from the sequence length of content tokens.
        Returns:
            torch.Tensor: Generated mel spectrogram tensor (n_mels, T).
        """
        # Obtain content embedding if not provided
        if content_embedding is None:
            if content_token_indices is None:
                raise ValueError("Either content_token_indices or content_embedding must be provided.")
            content_embedding = self.decode_token_indices(content_token_indices)

        if target_audio_length is None:
            # Estimate original audio length from content token sequence length
            seq_len = content_embedding.size(0)
            target_audio_length = self._calculate_original_audio_length(seq_len)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            mel_length = self._calculate_target_mel_length(target_audio_length)
            content_embedding = content_embedding.unsqueeze(0)  # (1, seq_len, dim)
            global_embedding = global_embedding.unsqueeze(0)  # (1, dim)
            mel_spectrogram = self.forward_mel(content_embedding, global_embedding, mel_length=mel_length)

        return mel_spectrogram.squeeze(0)  # (n_mels, T)

    @torch.inference_mode()
    def voice_conversion(self, source_waveform: torch.Tensor, reference_waveform: torch.Tensor) -> torch.Tensor:
        """Convert voice by combining the source's content with the reference's global characteristics.
        Only supports single audio input. Just a convenient wrapper around encode and decode methods.
        Args:
            source_waveform (torch.Tensor): Source audio waveform tensor (samples,).
            reference_waveform (torch.Tensor): Reference audio waveform tensor (samples_ref,).
        Returns:
            torch.Tensor: Converted mel spectrogram tensor (n_mels, T).
        """
        # Content comes from the source, global characteristics from the reference
        content_source = self.encode(source_waveform, return_content=True, return_global=False)
        global_reference = self.encode(reference_waveform, return_content=False, return_global=True)

        # Decode to a mel spectrogram sized for the source waveform's length
        return self.decode(
            content_embedding=content_source.content_embedding,
            global_embedding=global_reference.global_embedding,
            target_audio_length=source_waveform.size(0),
        )