Spaces:
Running
Running
| from dataclasses import dataclass | |
| from typing import Literal | |
| import jsonargparse | |
| import lightning as L | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import yaml | |
| from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from .data.datamodule import AudioBatch | |
| from .model import KanadeModel, KanadeModelConfig | |
| from .module.audio_feature import MelSpectrogramFeature | |
| from .module.discriminator import SpectrogramDiscriminator | |
| from .module.fsq import FiniteScalarQuantizer | |
| from .module.global_encoder import GlobalEncoder | |
| from .module.postnet import PostNet | |
| from .module.ssl_extractor import SSLFeatureExtractor | |
| from .module.transformer import Transformer | |
| from .util import freeze_modules, get_logger, load_vocoder, vocode | |
# Module-level logger shared by this pipeline (provided by .util.get_logger).
logger = get_logger()
| class KanadePipelineConfig: | |
| # Training control | |
| train_feature: bool = True # Whether to train the feature reconstruction branch | |
| train_mel: bool = True # Whether to train the mel spectrogram generation branch | |
| # Audio settings | |
| audio_length: int = 138240 # Length of audio input in samples | |
| # Optimization settings | |
| lr: float = 2e-4 | |
| weight_decay: float = 1e-4 | |
| betas: tuple[float, float] = (0.9, 0.99) | |
| gradient_clip_val: float | None = 1.0 | |
| # LR scheduling parameters | |
| warmup_percent: float = 0.1 | |
| lr_div_factor: float = 10.0 | |
| lr_final_div_factor: float = 1.0 | |
| anneal_mode: str = "cos" | |
| # Loss weights | |
| feature_l1_weight: float = 30.0 | |
| feature_l2_weight: float = 0.0 | |
| mel_l1_weight: float = 30.0 | |
| mel_l2_weight: float = 0.0 | |
| adv_weight: float = 1.0 | |
| feature_matching_weight: float = 10.0 | |
| # GAN settings | |
| use_discriminator: bool = False | |
| adv_loss_type: Literal["hinge", "least_square"] = "hinge" # Type of adversarial loss | |
| discriminator_lr: float | None = None # Learning rate for discriminator | |
| discriminator_start_step: int = 0 # Step to start training discriminator | |
| discriminator_update_prob: float = 1.0 # Probability of updating discriminator at each step | |
| # Checkpoint loading | |
| ckpt_path: str | None = None # Path to checkpoint to load from | |
| skip_loading_modules: tuple[str, ...] = () # Modules to skip when loading checkpoint | |
| # Other settings | |
| log_mel_samples: int = 10 | |
| use_torch_compile: bool = True | |
class KanadePipeline(L.LightningModule):
    """LightningModule wrapper for KanadeModel, handling training (including GAN).

    Uses manual optimization: ``training_step`` drives the generator (and
    optional discriminator) optimizers and their OneCycleLR schedulers itself.
    """
    def __init__(
        self,
        model_config: KanadeModelConfig,
        pipeline_config: KanadePipelineConfig,
        ssl_feature_extractor: SSLFeatureExtractor,
        local_encoder: Transformer,
        local_quantizer: FiniteScalarQuantizer,
        feature_decoder: Transformer | None,
        global_encoder: GlobalEncoder,
        mel_prenet: Transformer,
        mel_decoder: Transformer,
        mel_postnet: PostNet,
        discriminator: SpectrogramDiscriminator | None = None,
    ):
        """Assemble the Kanade model, optional discriminator, and loss utilities.

        Args:
            model_config: Architecture/audio settings forwarded to KanadeModel.
            pipeline_config: Training settings (losses, GAN, LR schedule, ...).
            ssl_feature_extractor: Frontend producing SSL features from audio.
            local_encoder: Content (local) branch encoder.
            local_quantizer: FSQ quantizer for the content branch.
            feature_decoder: Content-branch decoder; required when train_feature=True.
            global_encoder: Global branch encoder.
            mel_prenet: Prenet of the mel generation stack.
            mel_decoder: Decoder of the mel generation stack.
            mel_postnet: Postnet of the mel generation stack.
            discriminator: Optional spectrogram discriminator for GAN training.
        """
        super().__init__()
        self.config = pipeline_config
        # Only the configs are saved as hyperparameters; the submodules are
        # passed in pre-built and excluded from hparams serialization.
        self.save_hyperparameters("model_config", "pipeline_config")
        # Allow partial checkpoint loads (e.g. no discriminator weights saved).
        self.strict_loading = False
        # GAN-style training: optimizers are stepped manually in training_step.
        self.automatic_optimization = False
        self.torch_compiled = False  # flipped to True in setup() if torch.compile runs
        # Validate components required for training
        assert not pipeline_config.train_feature or feature_decoder is not None, (
            "Feature decoder must be provided if training feature reconstruction"
        )
        logger.info(
            f"Training configuration: train_feature={pipeline_config.train_feature}, train_mel={pipeline_config.train_mel}"
        )
        # 1. Kanade model
        self.model = KanadeModel(
            config=model_config,
            ssl_feature_extractor=ssl_feature_extractor,
            local_encoder=local_encoder,
            local_quantizer=local_quantizer,
            feature_decoder=feature_decoder,
            global_encoder=global_encoder,
            mel_decoder=mel_decoder,
            mel_prenet=mel_prenet,
            mel_postnet=mel_postnet,
        )
        self._freeze_unused_modules(pipeline_config.train_feature, pipeline_config.train_mel)
        # Calculate padding for expected SSL output length
        self.padding = self.model._calculate_waveform_padding(pipeline_config.audio_length)
        logger.info(f"Input waveform padding for SSL feature extractor: {self.padding} samples")
        # Calculate target mel spectrogram length
        self.target_mel_length = self.model._calculate_target_mel_length(pipeline_config.audio_length)
        logger.info(f"Target mel spectrogram length: {self.target_mel_length} frames")
        # 2. Discriminator
        self._init_discriminator(pipeline_config, discriminator)
        # 3. Mel spectrogram feature extractor for loss computation
        if pipeline_config.train_mel:
            self.mel_spec = MelSpectrogramFeature(
                sample_rate=model_config.sample_rate,
                n_fft=model_config.n_fft,
                hop_length=model_config.hop_length,
                n_mels=model_config.n_mels,
                padding=model_config.padding,
                fmin=model_config.mel_fmin,
                fmax=model_config.mel_fmax,
                bigvgan_style_mel=model_config.bigvgan_style_mel,
            )
        # Mel sample storage for logging
        self.vocoder = None  # loaded lazily in on_validation_start / on_predict_start
        self.validation_examples = []  # (mel_real, mel_recon, audio_real, audio_gen) tuples
        self.log_mel_samples = pipeline_config.log_mel_samples
| def _freeze_unused_modules(self, train_feature: bool, train_mel: bool): | |
| model = self.model | |
| if not train_feature: | |
| # Freeze local branch components if not training feature reconstruction | |
| freeze_modules([model.local_encoder, model.local_quantizer, model.feature_decoder]) | |
| if model.conv_downsample is not None: | |
| freeze_modules([model.conv_downsample, model.conv_upsample]) | |
| logger.info("Feature reconstruction branch frozen: local_encoder, local_quantizer, feature_decoder") | |
| if not train_mel: | |
| # Freeze global branch and mel generation components if not training mel generation | |
| freeze_modules( | |
| [model.global_encoder, model.mel_prenet, model.mel_conv_upsample, model.mel_decoder, model.mel_postnet] | |
| ) | |
| logger.info( | |
| "Mel generation branch frozen: global_encoder, mel_prenet, mel_conv_upsample, mel_decoder, mel_postnet" | |
| ) | |
| def _init_discriminator(self, config: KanadePipelineConfig, discriminator: SpectrogramDiscriminator | None): | |
| # Setup discriminator if provided | |
| self.discriminator = discriminator | |
| self.use_discriminator = config.use_discriminator and discriminator is not None and config.train_mel | |
| if config.use_discriminator and discriminator is None: | |
| logger.error( | |
| "Discriminator is enabled in config but no discriminator model provided. Disabling GAN training." | |
| ) | |
| if config.use_discriminator and discriminator is not None and not config.train_mel: | |
| logger.warning( | |
| "Discriminator is enabled but train_mel=False. Discriminator will not be effective without mel training." | |
| ) | |
| self.discriminator_start_step = config.discriminator_start_step | |
| self.discriminator_update_prob = config.discriminator_update_prob | |
| if self.use_discriminator: | |
| logger.info("Discriminator initialized for GAN training") | |
| logger.info(f"Discriminator start step: {self.discriminator_start_step}") | |
| logger.info(f"Discriminator update probability: {self.discriminator_update_prob}") | |
| def setup(self, stage: str): | |
| # Torch compile model if enabled | |
| if torch.__version__ >= "2.0" and self.config.use_torch_compile: | |
| self.model = torch.compile(self.model) | |
| if self.discriminator is not None: | |
| self.discriminator = torch.compile(self.discriminator) | |
| self.torch_compiled = True | |
| # Load checkpoint if provided | |
| if self.config.ckpt_path: | |
| ckpt_path = self.config.ckpt_path | |
| # Download weights from HuggingFace Hub if needed | |
| if ckpt_path.startswith("hf:"): | |
| from huggingface_hub import hf_hub_download | |
| repo_id = ckpt_path[len("hf:") :] | |
| # Separate out revision if specified | |
| revision = None | |
| if "@" in repo_id: | |
| repo_id, revision = repo_id.split("@", 1) | |
| ckpt_path = hf_hub_download(repo_id, filename="model.safetensors", revision=revision) | |
| self._load_weights(ckpt_path) | |
| def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: | |
| """ | |
| Returns: | |
| ssl_real: Extracted SSL features for local branch (B, T, C) | |
| ssl_recon: Reconstructed SSL features (B, T, C) - only if train_feature=True | |
| mel_recon: Generated mel spectrogram (B, n_mels, T) - only if train_mel=True | |
| loss_dict: Dictionary with auxiliary information (codes, losses, etc.) | |
| """ | |
| loss_dict = {} | |
| # 1. Extract SSL features | |
| local_ssl_features, global_ssl_features = self.model.forward_ssl_features(waveform, padding=self.padding) | |
| # 2. Content branch processing | |
| content_embeddings, _, ssl_recon, perplexity = self.model.forward_content(local_ssl_features) | |
| loss_dict["local/perplexity"] = perplexity | |
| # 3. Global branch processing and mel reconstruction | |
| mel_recon = None | |
| if self.config.train_mel: | |
| global_embeddings = self.model.forward_global(global_ssl_features) | |
| mel_recon = self.model.forward_mel(content_embeddings, global_embeddings, mel_length=self.target_mel_length) | |
| return local_ssl_features, ssl_recon, mel_recon, loss_dict | |
| def _get_reconstruction_loss( | |
| self, audio_real: torch.Tensor, ssl_real: torch.Tensor, ssl_recon: torch.Tensor, mel_recon: torch.Tensor | |
| ) -> tuple[torch.Tensor, dict, torch.Tensor]: | |
| """Compute L1 + L2 loss for SSL feature and mel spectrogram reconstruction. | |
| Returns: | |
| total_loss: Combined reconstruction loss | |
| loss_dict: Dictionary with individual loss components | |
| mel_real: Real mel spectrogram for reference | |
| """ | |
| if audio_real.dim() == 3: | |
| audio_real = audio_real.squeeze(1) | |
| loss_dict = {} | |
| feature_loss, mel_loss = 0, 0 | |
| # Compute SSL feature reconstruction losses if training features | |
| if self.config.train_feature and self.model.feature_decoder is not None: | |
| assert ssl_real is not None and ssl_recon is not None, ( | |
| "SSL features must be provided for training feature reconstruction" | |
| ) | |
| ssl_l1 = F.l1_loss(ssl_recon, ssl_real) | |
| ssl_l2 = F.mse_loss(ssl_recon, ssl_real) | |
| feature_loss = self.config.feature_l1_weight * ssl_l1 + self.config.feature_l2_weight * ssl_l2 | |
| loss_dict.update({"ssl_l1": ssl_l1, "ssl_l2": ssl_l2, "feature_loss": feature_loss}) | |
| # Compute mel spectrogram reconstruction losses if training mel | |
| mel_real = None | |
| if self.config.train_mel: | |
| assert mel_recon is not None, "Mel reconstruction must be provided for training mel generation" | |
| # Extract reference mel spectrogram from audio | |
| mel_real = self.mel_spec(audio_real) | |
| mel_l1 = F.l1_loss(mel_recon, mel_real) | |
| mel_l2 = F.mse_loss(mel_recon, mel_real) | |
| mel_loss = self.config.mel_l1_weight * mel_l1 + self.config.mel_l2_weight * mel_l2 | |
| loss_dict.update({"mel_l1": mel_l1, "mel_l2": mel_l2, "mel_loss": mel_loss}) | |
| total_loss = feature_loss + mel_loss | |
| return total_loss, loss_dict, mel_real | |
| def _get_discriminator_loss( | |
| self, real_outputs: torch.Tensor, fake_outputs: torch.Tensor | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """Compute the adversarial loss for discriminator. | |
| Returns: | |
| disc_loss: Total discriminator loss | |
| real_loss: Loss component from real samples | |
| fake_loss: Loss component from fake samples | |
| """ | |
| if self.config.adv_loss_type == "hinge": | |
| real_loss = torch.mean(torch.clamp(1 - real_outputs, min=0)) | |
| fake_loss = torch.mean(torch.clamp(1 + fake_outputs, min=0)) | |
| elif self.config.adv_loss_type == "least_square": | |
| real_loss = torch.mean((real_outputs - 1) ** 2) | |
| fake_loss = torch.mean(fake_outputs**2) | |
| else: | |
| raise ValueError(f"Unknown adversarial loss type: {self.config.adv_loss_type}") | |
| disc_loss = real_loss + fake_loss | |
| return disc_loss, real_loss, fake_loss | |
| def _get_generator_loss(self, fake_outputs: torch.Tensor) -> torch.Tensor: | |
| """Compute the adversarial loss for generator.""" | |
| if self.config.adv_loss_type == "hinge": | |
| return torch.mean(torch.clamp(1 - fake_outputs, min=0)) | |
| elif self.config.adv_loss_type == "least_square": | |
| return torch.mean((fake_outputs - 1) ** 2) | |
| else: | |
| raise ValueError(f"Unknown adversarial loss type: {self.config.adv_loss_type}") | |
| def _get_feature_matching_loss( | |
| self, real_intermediates: list[torch.Tensor], fake_intermediates: list[torch.Tensor] | |
| ) -> torch.Tensor: | |
| losses = [] | |
| for real_feat, fake_feat in zip(real_intermediates, fake_intermediates): | |
| losses.append(torch.mean(torch.abs(real_feat.detach() - fake_feat))) | |
| fm_loss = torch.mean(torch.stack(losses)) | |
| return fm_loss | |
    def _discriminator_step(
        self, batch: AudioBatch, optimizer_disc: torch.optim.Optimizer
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor], list[torch.Tensor]]:
        """Run one discriminator update: forward, loss, backward, clip, step.

        Returns:
            ssl_real: Real SSL features
            ssl_recon: Reconstructed SSL features from generator
            mel_recon: Generated mel spectrogram
            loss_dict: Dictionary with auxiliary information
            real_intermediates: Intermediate feature maps from discriminator for real mel
        """
        assert self.use_discriminator, "Discriminator step called but discriminator is not enabled"
        ssl_real, ssl_recon, mel_recon, loss_dict = self(batch.waveform)
        assert mel_recon is not None, "Mel reconstruction must be available for discriminator step"
        # Get true mel spectrogram (always use original waveform)
        mel_real = self.mel_spec(batch.waveform)
        # Get discriminator outputs and intermediates for real mel; the
        # intermediates are handed back so the generator step can reuse them
        # for feature matching without a second discriminator pass.
        real_outputs, real_intermediates = self.discriminator(mel_real)
        # detach() so discriminator gradients do not flow into the generator.
        fake_outputs, _ = self.discriminator(mel_recon.detach())
        # Compute discriminator loss
        disc_loss, real_loss, fake_loss = self._get_discriminator_loss(real_outputs, fake_outputs)
        # Log discriminator losses
        batch_size = batch.waveform.size(0)
        self.log("train/disc/real", real_loss, batch_size=batch_size)
        self.log("train/disc/fake", fake_loss, batch_size=batch_size)
        self.log("train/disc/loss", disc_loss, batch_size=batch_size, prog_bar=True)
        for name, value in loss_dict.items():
            self.log(f"train/{name}", value, batch_size=batch_size)
        # Optimize discriminator (manual optimization: explicit zero/backward/step)
        optimizer_disc.zero_grad()
        self.manual_backward(disc_loss)
        # Log gradient norm (max_norm=inf effectively disables clipping)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.discriminator.parameters(), max_norm=self.config.gradient_clip_val or torch.inf
        )
        self.log("train/disc/grad_norm", grad_norm, batch_size=batch_size)
        optimizer_disc.step()
        return ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates
    def _generator_step(
        self,
        batch: AudioBatch,
        optimizer_gen: torch.optim.Optimizer,
        ssl_real: torch.Tensor | None = None,
        ssl_recon: torch.Tensor | None = None,
        mel_recon: torch.Tensor | None = None,
        loss_dict: dict | None = None,
        real_intermediates: list[torch.Tensor] | None = None,
        training_disc: bool = False,
    ) -> torch.Tensor:
        """Run one generator update: reconstruction (+ optional GAN) losses, backward, step.

        Args:
            batch: Audio batch with waveform and augmented_waveform
            optimizer_gen: Generator optimizer
            ssl_real: Real SSL features (optional)
            ssl_recon: Reconstructed SSL features (optional)
            mel_recon: Generated mel spectrogram (optional)
            loss_dict: Dictionary with auxiliary information (optional)
            real_intermediates: Intermediate feature maps from discriminator for real mel (optional)
            training_disc: Whether discriminator is being trained in this step
        Returns:
            gen_loss: Total generator loss
        """
        # Forward pass through the model if not already done in discriminator step
        # (loss_dict is None exactly when the discriminator step did not run).
        if loss_dict is None:
            ssl_real, ssl_recon, mel_recon, loss_dict = self(batch.waveform)
        # Compute reconstruction loss (always use original waveform for mel target)
        recon_loss, recon_dict, mel_real = self._get_reconstruction_loss(batch.waveform, ssl_real, ssl_recon, mel_recon)
        gen_loss = recon_loss
        # Compute adversarial and feature matching losses if using discriminator
        batch_size = batch.waveform.size(0)
        if training_disc:
            assert mel_real is not None and mel_recon is not None, "Mel spectrograms must be provided for GAN training"
            # Recompute real intermediates only when the discriminator step
            # did not already supply them.
            if real_intermediates is None:
                _, real_intermediates = self.discriminator(mel_real)
            fake_outputs, fake_intermediates = self.discriminator(mel_recon)
            # Compute adversarial loss
            adv_loss = self._get_generator_loss(fake_outputs)
            gen_loss += self.config.adv_weight * adv_loss
            self.log("train/gen/adv_loss", adv_loss, batch_size=batch_size)
            # Compute feature matching loss
            feature_matching_loss = self._get_feature_matching_loss(real_intermediates, fake_intermediates)
            gen_loss += self.config.feature_matching_weight * feature_matching_loss
            self.log("train/gen/feature_matching_loss", feature_matching_loss, batch_size=batch_size)
        # Log reconstruction losses
        for name, value in loss_dict.items():
            self.log(f"train/{name}", value, batch_size=batch_size)
        for name, value in recon_dict.items():
            self.log(f"train/gen/{name}", value, batch_size=batch_size)
        self.log("train/loss", gen_loss, batch_size=batch_size, prog_bar=True)
        # Optimize generator (manual optimization: explicit zero/backward/step)
        optimizer_gen.zero_grad()
        self.manual_backward(gen_loss)
        # Log gradient norm (max_norm=inf effectively disables clipping)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), max_norm=self.config.gradient_clip_val or torch.inf
        )
        self.log("train/gen/grad_norm", grad_norm, batch_size=batch_size)
        optimizer_gen.step()
        return gen_loss
    def training_step(self, batch: AudioBatch, batch_idx: int):
        """One manual-optimization step: optional discriminator update, then generator update.

        Schedulers are stepped every batch, matching the "step" interval
        declared in configure_optimizers.
        """
        # Optimizer/scheduler unpack order matches configure_optimizers: [disc, gen].
        if self.use_discriminator:
            optimizer_disc, optimizer_gen = self.optimizers()
            scheduler_disc, scheduler_gen = self.lr_schedulers()
        else:
            optimizer_gen = self.optimizers()
            scheduler_gen = self.lr_schedulers()
        # Determine if discriminator should be trained in this step:
        # only after the configured start step, and then stochastically
        # with probability discriminator_update_prob.
        training_disc = (
            self.use_discriminator
            and self.global_step >= self.discriminator_start_step
            and torch.rand(1).item() < self.discriminator_update_prob
        )
        if self.global_step == self.discriminator_start_step and self.use_discriminator:
            logger.info(f"Discriminator training starts at step {self.global_step}")
        ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates = None, None, None, None, None
        # Train discriminator if conditions are met; its forward results are
        # passed to the generator step to avoid a redundant model forward.
        if training_disc:
            ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates = self._discriminator_step(
                batch, optimizer_disc
            )
            scheduler_disc.step()
        elif self.use_discriminator:
            # Step the discriminator scheduler even when not training discriminator
            scheduler_disc.step()
        # Train generator
        self._generator_step(
            batch, optimizer_gen, ssl_real, ssl_recon, mel_recon, loss_dict, real_intermediates, training_disc
        )
        scheduler_gen.step()
    def validation_step(self, batch: AudioBatch, batch_idx: int):
        """Compute validation losses and stash a few examples for end-of-epoch logging."""
        audio_real = batch.waveform
        ssl_real, ssl_recon, mel_recon, loss_dict = self(audio_real)
        # Convert to waveform using vocoder for logging
        batch_size = audio_real.size(0)
        # Compute reconstruction loss
        recon_loss, recon_dict, mel_real = self._get_reconstruction_loss(audio_real, ssl_real, ssl_recon, mel_recon)
        gen_loss = recon_loss
        # Log reconstruction losses
        for name, value in loss_dict.items():
            self.log(f"val/{name}", value, batch_size=batch_size)
        for name, value in recon_dict.items():
            self.log(f"val/gen/{name}", value, batch_size=batch_size)
        self.log("val/loss", gen_loss, batch_size=batch_size)
        # Save first few samples for visualization at end of epoch if training mel generation
        # (capped at log_mel_samples examples per validation run).
        if self.config.train_mel and len(self.validation_examples) < self.log_mel_samples:
            assert mel_real is not None and mel_recon is not None, (
                "Mel spectrograms must be provided for validation logging"
            )
            # Keep only the first item of the batch, moved to CPU.
            audio_real = audio_real[0].cpu()
            audio_gen = None
            if self.vocoder is not None:
                # Vocode the first generated mel so the example can be listened to.
                audio_gen = self.vocode(mel_recon[0:1])[0].cpu()
            self.validation_examples.append((mel_real[0].cpu(), mel_recon[0].detach().cpu(), audio_real, audio_gen))
| def predict_step(self, batch: AudioBatch, batch_idx: int): | |
| audio_real = batch.waveform | |
| _, _, mel_gen, _ = self(audio_real) | |
| audio_gen = self.vocode(mel_gen) | |
| if audio_gen.dim() == 2: | |
| audio_gen = audio_gen.unsqueeze(1) | |
| return {"audio_ids": batch.audio_ids, "audio_real": audio_real, "audio_gen": audio_gen} | |
| def configure_optimizers(self): | |
| # Generator optimizer | |
| optimizer_gen = AdamW( | |
| self.model.parameters(), lr=self.config.lr, betas=self.config.betas, weight_decay=self.config.weight_decay | |
| ) | |
| # Generator LR scheduler | |
| scheduler_gen = OneCycleLR( | |
| optimizer_gen, | |
| max_lr=self.config.lr, | |
| div_factor=self.config.lr_div_factor, | |
| final_div_factor=self.config.lr_final_div_factor, | |
| pct_start=self.config.warmup_percent, | |
| anneal_strategy=self.config.anneal_mode, | |
| total_steps=self.trainer.estimated_stepping_batches, | |
| ) | |
| if not self.use_discriminator: | |
| return ([optimizer_gen], [{"scheduler": scheduler_gen, "interval": "step"}]) | |
| # If using discriminator, also configure discriminator optimizer and scheduler | |
| optimizer_disc = AdamW( | |
| self.discriminator.parameters(), | |
| lr=self.config.discriminator_lr or self.config.lr, | |
| betas=self.config.betas, | |
| weight_decay=self.config.weight_decay, | |
| ) | |
| # Discriminator LR scheduler | |
| scheduler_disc = OneCycleLR( | |
| optimizer_disc, | |
| max_lr=self.config.discriminator_lr or self.config.lr, | |
| div_factor=self.config.lr_div_factor, | |
| final_div_factor=self.config.lr_final_div_factor, | |
| pct_start=self.config.warmup_percent, | |
| anneal_strategy=self.config.anneal_mode, | |
| total_steps=self.trainer.estimated_stepping_batches, | |
| ) | |
| # Load optimizer state | |
| if self.config.ckpt_path: | |
| if self.config.ckpt_path.endswith(".ckpt"): | |
| checkpoint = torch.load(self.config.ckpt_path) | |
| optimizer_states = checkpoint["optimizer_states"] | |
| if len(optimizer_states) > 1 and self.use_discriminator: | |
| optimizer_disc.load_state_dict(optimizer_states[0]) | |
| optimizer_gen.load_state_dict(optimizer_states[1]) | |
| logger.info("Loaded discriminator and generator's optimizer states from checkpoint") | |
| elif len(optimizer_states) == 1 and not self.use_discriminator: | |
| # Load generator optimizer state only | |
| optimizer_gen.load_state_dict(optimizer_states[0]) | |
| logger.info("Loaded generator's optimizer state from checkpoint") | |
| else: | |
| logger.info("No optimizer state loaded since checkpoint is not a .ckpt file") | |
| return ( | |
| [optimizer_disc, optimizer_gen], | |
| [{"scheduler": scheduler_disc, "interval": "step"}, {"scheduler": scheduler_gen, "interval": "step"}], | |
| ) | |
| def _setup_vocoder(self): | |
| try: | |
| return load_vocoder(name=self.model.config.vocoder_name) | |
| except ImportError: | |
| logger.error("Vocoder could not be loaded. Please install the required dependencies.") | |
| return None | |
| def vocode(self, mel: torch.Tensor) -> torch.Tensor: | |
| self.vocoder = self.vocoder.to(mel.device) | |
| waveform = vocode(self.vocoder, mel) | |
| return waveform.cpu().float() | |
    def on_validation_start(self):
        # (Re)load the vocoder so validation can log synthesized audio examples.
        self.vocoder = self._setup_vocoder()

    def on_predict_start(self):
        # predict_step always vocodes generated mels, so load the vocoder up front.
        self.vocoder = self._setup_vocoder()
| def on_validation_end(self): | |
| if len(self.validation_examples) > 0: | |
| for i, (mel_real, mel_recon, audio_real, audio_gen) in enumerate(self.validation_examples): | |
| # Log spectrograms | |
| fig_real = self._get_spectrogram_plot(mel_real) | |
| fig_gen = self._get_spectrogram_plot(mel_recon) | |
| self._log_figure(f"val/{i}_mel_real", fig_real) | |
| self._log_figure(f"val/{i}_mel_gen", fig_gen) | |
| # Log audio samples | |
| if audio_gen is not None: | |
| audio_real = audio_real.cpu().numpy() | |
| audio_gen = audio_gen.cpu().numpy() | |
| self._log_audio(f"val/{i}_audio_real", audio_real) | |
| self._log_audio(f"val/{i}_audio_gen", audio_gen) | |
| self.validation_examples = [] | |
| # Clear vocoder to free memory | |
| self.vocoder = None | |
| def _log_figure(self, tag: str, fig): | |
| """Log a matplotlib figure to the logger.""" | |
| if isinstance(self.logger, TensorBoardLogger): | |
| self.logger.experiment.add_figure(tag, fig, self.global_step) | |
| elif isinstance(self.logger, WandbLogger): | |
| import PIL.Image as Image | |
| fig.canvas.draw() | |
| image = Image.frombytes("RGBa", fig.canvas.get_width_height(), fig.canvas.buffer_rgba()) | |
| image = image.convert("RGB") | |
| self.logger.log_image(tag, [image], step=self.global_step) | |
| def _log_audio(self, tag: str, audio: np.ndarray): | |
| """Log an audio sample to the logger.""" | |
| if isinstance(self.logger, TensorBoardLogger): | |
| self.logger.experiment.add_audio(tag, audio, self.global_step, sample_rate=self.model.config.sample_rate) | |
| elif isinstance(self.logger, WandbLogger): | |
| self.logger.log_audio( | |
| tag, [audio.flatten()], sample_rate=[self.model.config.sample_rate], step=self.global_step | |
| ) | |
| def _get_spectrogram_plot(self, mel: torch.Tensor): | |
| from matplotlib import pyplot as plt | |
| mel = mel.detach().cpu().numpy() | |
| fig, ax = plt.subplots(figsize=(10, 4)) | |
| im = ax.imshow(mel, aspect="auto", origin="lower", cmap="magma", vmin=-8.0, vmax=5.0) | |
| fig.colorbar(im, ax=ax) | |
| ax.set_ylabel("Mel bins") | |
| ax.set_xlabel("Time steps") | |
| fig.tight_layout() | |
| return fig | |
    def _load_weights(self, ckpt_path: str | None, model_state_dict: dict[str, torch.Tensor] | None = None):
        """Load model and discriminator weights from checkpoint. Supports .ckpt (Lightning), .safetensors, .pt/.pth formats.
        If model_state_dict is provided, load weights from it instead of ckpt_path."""

        def select_keys(state_dict: dict, prefix: str) -> dict:
            """Select keys from state_dict that start with the given prefix. Remove the prefix from keys."""
            return {k[len(prefix) :]: v for k, v in state_dict.items() if k.startswith(prefix)}

        def remove_prefix(state_dict: dict, prefix: str) -> dict:
            """Remove a prefix from keys that start with that prefix."""
            return {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}

        def add_prefix(state_dict: dict, prefix: str) -> dict:
            """Add a prefix to keys that do not start with that prefix."""
            return {f"{prefix}{k}" if not k.startswith(prefix) else k: v for k, v in state_dict.items()}

        # Load state dict from one of the supported sources
        if model_state_dict is not None:
            # Load from provided state dict (no discriminator weights available)
            disc_state_dict = {}
        elif ckpt_path.endswith(".ckpt"):
            # Lightning checkpoint: model/discriminator live under prefixed keys
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            model_state_dict = select_keys(checkpoint["state_dict"], "model.")
            disc_state_dict = select_keys(checkpoint["state_dict"], "discriminator.")
        elif ckpt_path.endswith(".safetensors"):
            # Safetensors checkpoint (model weights only)
            from safetensors.torch import load_file

            checkpoint = load_file(ckpt_path, device="cpu")
            model_state_dict = checkpoint
            disc_state_dict = {}
        elif ckpt_path.endswith(".pt") or ckpt_path.endswith(".pth"):
            # Standard PyTorch checkpoint
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            model_state_dict = checkpoint
            disc_state_dict = {}
        else:
            raise ValueError(f"Unsupported checkpoint format: {ckpt_path}")
        # Load model weights. Strip any torch.compile "_orig_mod." prefix first,
        # filter skipped modules, then re-add the prefix if this instance is
        # itself compiled, so keys line up with the wrapped module.
        model_state_dict = remove_prefix(model_state_dict, "_orig_mod.")
        # Drop modules the config asked to skip (prefix match on key names)
        model_state_dict = {
            k: v
            for k, v in model_state_dict.items()
            if not any(k.startswith(module) for module in self.config.skip_loading_modules)
        }
        if self.torch_compiled:
            model_state_dict = add_prefix(model_state_dict, "_orig_mod.")
        if len(model_state_dict) > 0:
            # strict=False: partial loads are expected (e.g. skipped modules)
            result = self.model.load_state_dict(model_state_dict, strict=False)
            logger.info(f"Loaded model weights from {ckpt_path or 'provided state_dict'}.")
            if result.missing_keys:
                logger.debug(f"Missing keys in model state_dict: {result.missing_keys}")
            if result.unexpected_keys:
                logger.debug(f"Unexpected keys in model state_dict: {result.unexpected_keys}")
        # Load discriminator weights if available (only relevant for GAN training)
        if self.use_discriminator:
            disc_state_dict = remove_prefix(disc_state_dict, "_orig_mod.")
            if self.torch_compiled:
                disc_state_dict = add_prefix(disc_state_dict, "_orig_mod.")
            if len(disc_state_dict) > 0:
                result = self.discriminator.load_state_dict(disc_state_dict, strict=False)
                logger.info(f"Loaded discriminator weights from {ckpt_path}.")
                if result.missing_keys:
                    logger.debug(f"Missing keys in discriminator state_dict: {result.missing_keys}")
                if result.unexpected_keys:
                    logger.debug(f"Unexpected keys in discriminator state_dict: {result.unexpected_keys}")
| def from_hparams(cls, config_path: str) -> "KanadePipeline": | |
| """Instantiate KanadePipeline from config file. | |
| Args: | |
| config_path (str): Path to model configuration file (.yaml). | |
| Returns: | |
| KanadePipeline: Instantiated KanadePipeline. | |
| """ | |
| # Load config | |
| with open(config_path, "r") as f: | |
| config = yaml.safe_load(f) | |
| # Remove related fields to prevent loading actual weights here | |
| new_config = {"model": config["model"]} | |
| pipeline_config = new_config["model"]["init_args"]["pipeline_config"] | |
| if "ckpt_path" in pipeline_config: | |
| del pipeline_config["ckpt_path"] | |
| if "skip_loading_modules" in pipeline_config: | |
| del pipeline_config["skip_loading_modules"] | |
| # Instantiate model using jsonargparse | |
| parser = jsonargparse.ArgumentParser(exit_on_error=False) | |
| parser.add_argument("--model", type=KanadePipeline) | |
| cfg = parser.parse_object(new_config) | |
| cfg = parser.instantiate_classes(cfg) | |
| return cfg.model | |
| def from_pretrained(config_path: str, ckpt_path: str) -> "KanadePipeline": | |
| """Load KanadePipeline from training configuration and checkpoint files. | |
| Args: | |
| config_path: Path to pipeline configuration file (YAML). | |
| ckpt_path: Path to checkpoint file (.ckpt) or model weights (.safetensors). | |
| Returns: | |
| KanadePipeline: Instantied KanadePipeline with loaded weights. | |
| """ | |
| # Load pipeline from config | |
| model = KanadePipeline.from_hparams(config_path) | |
| # Load the weights | |
| model._load_weights(ckpt_path) | |
| return model | |