#!/usr/bin/env python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| Bake Context Embedding into MagpieTTS Checkpoint | |
| This script converts a MagpieTTS decoder_ce checkpoint by: | |
| 1. Loading a reference audio file | |
| 2. Running it through the context_encoder to get the embedding | |
| 3. Saving a new checkpoint with: | |
| - The baked context embedding as a buffer | |
| - All original weights EXCEPT context_encoder weights | |
| 4. Saving a modified config without context_encoder settings | |
| Usage: | |
| python scripts/magpietts/bake_context_embedding.py \ | |
| --input_checkpoint /path/to/original.ckpt \ | |
| --config_path /path/to/config.yaml \ | |
| --output_checkpoint /path/to/baked.ckpt \ | |
| --context_audio /path/to/reference.wav | |
| The resulting checkpoint will be smaller (no context_encoder weights) and will | |
| always use the baked reference audio embedding for voice cloning, regardless of | |
| what context audio is provided at inference time. | |
| A modified config file will be saved alongside the output checkpoint with | |
| context_encoder settings removed. | |
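
After baking, the checkpoint can be loaded the same way this script loads the
original one. A minimal sketch (paths are placeholders; this assumes the model
class honors the has_baked_context_embedding flag written into the saved config):

    cfg = OmegaConf.load('/path/to/baked_config.yaml')
    model = MagpieTTSModel(cfg)
    ckpt = torch.load('/path/to/baked.ckpt', map_location='cpu')
    model.load_state_dict(ckpt['state_dict'], strict=False)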
| """ | |

from __future__ import annotations

import argparse
import os
from copy import deepcopy
from typing import Tuple

import soundfile as sf
import torch
from omegaconf import OmegaConf, open_dict

from examples.tts.magpietts.utils import update_config_for_inference
from nemo.collections.tts.models import MagpieTTSModel
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths
from nemo.utils import logging


def load_audio(audio_path: str, target_sample_rate: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an audio file and return it as a tensor together with its length.

    Args:
        audio_path: Path to audio file.
        target_sample_rate: Expected sample rate. Audio will be resampled if needed.

    Returns:
        Tuple of (audio_tensor, audio_lens) with shapes (1, T) and (1,).

    Raises:
        ValueError: If the sample rate differs from the target and librosa is
            not installed for resampling.
    """
    audio, sr = sf.read(audio_path, dtype='float32')
    # soundfile returns (T, channels) for multi-channel files; downmix to mono
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != target_sample_rate:
        try:
            import librosa

            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)
            logging.info(f"Resampled audio from {sr}Hz to {target_sample_rate}Hz")
        except ImportError:
            raise ValueError(
                f"Audio sample rate {sr} does not match target {target_sample_rate}. "
                "Install librosa for automatic resampling: pip install librosa"
            )
    # Convert to tensor: (T,) -> (1, T)
    audio_tensor = torch.tensor(audio).unsqueeze(0)
    audio_lens = torch.tensor([audio_tensor.shape[1]])
    return audio_tensor, audio_lens
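

# Example (hypothetical path): load_audio('/path/to/reference.wav', target_sample_rate=24000)
# returns a float32 waveform of shape (1, T) and a length tensor of shape (1,);
# the caller is responsible for moving both to the target device.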


def bake_model_context_embedding(
    model: MagpieTTSModel,
    context_audio: torch.Tensor,
    context_audio_lens: torch.Tensor,
) -> None:
    """Compute and store the context embedding from reference audio into the model.

    This function runs the context audio through the model's context_encoder and
    stores the resulting embedding as a buffer. After baking, the context_encoder
    weights can be removed from the checkpoint to reduce model size.

    Only supported for the decoder_ce model type.

    Args:
        model: MagpieTTSModel instance with a context_encoder.
        context_audio: Reference audio waveform. Shape: (1, T_samples).
        context_audio_lens: Length of audio in samples. Shape: (1,).

    Raises:
        ValueError: If the model type is not decoder_ce.
        RuntimeError: If context_encoder is not available.
    """
    if model.model_type != 'decoder_ce':
        raise ValueError(
            f"Baking context embedding is only supported for decoder_ce model type, got {model.model_type}"
        )
    if not hasattr(model, 'context_encoder'):
        raise RuntimeError("context_encoder not found. Cannot bake embedding.")
    with torch.no_grad():
        # Convert audio to codec tokens
        context_audio_codes, context_audio_codes_lens = model.audio_to_codes(
            context_audio, context_audio_lens, audio_type='context'
        )
        context_audio_codes = model.pad_audio_codes(
            context_audio_codes, model.frame_stacking_factor, pad_token=0
        )
        context_audio_embedded = model.embed_audio_tokens(context_audio_codes)
        # Compute context length after frame stacking
        context_input_lens = torch.ceil(
            context_audio_codes_lens / model.frame_stacking_factor
        ).to(context_audio_codes_lens.dtype)
        context_mask = get_mask_from_lengths(context_input_lens)
        # Run through the context encoder
        context_embedding = model.context_encoder(
            context_audio_embedded, context_mask, cond=None, cond_mask=None
        )['output']
    # Register as buffers so the embedding ends up in the model's state_dict
    # (squeeze the batch dim since a single embedding is stored).
    model.register_buffer('baked_context_embedding', context_embedding.squeeze(0))  # (T, E)
    model.register_buffer('baked_context_embedding_len', context_input_lens.squeeze(0))  # scalar
    logging.info(
        f"Baked context embedding with shape {model.baked_context_embedding.shape}, "
        f"length {model.baked_context_embedding_len.item()}"
    )
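

# How the baked buffers are consumed at inference is up to the model class; a
# rough sketch of the expected pattern (an assumption, not part of this script):
#
#     context_embedding = model.baked_context_embedding.unsqueeze(0).expand(batch_size, -1, -1)
#     context_lens = model.baked_context_embedding_len.unsqueeze(0).expand(batch_size)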


def bake_context_embedding(
    input_checkpoint: str,
    config_path: str,
    output_checkpoint: str,
    context_audio: str,
    device: str = 'cuda',
) -> None:
    """Bake a context embedding into a checkpoint.

    Args:
        input_checkpoint: Path to original MagpieTTS checkpoint (.ckpt).
        config_path: Path to model config file (.yaml).
        output_checkpoint: Path to save the baked checkpoint.
        context_audio: Path to reference audio file for baking.
        device: Device to run inference on ('cuda' or 'cpu').

    Raises:
        ValueError: If the model type is not decoder_ce.
        FileNotFoundError: If input files don't exist.
    """
    # Validate inputs
    if not os.path.exists(input_checkpoint):
        raise FileNotFoundError(f"Input checkpoint not found: {input_checkpoint}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    if not os.path.exists(context_audio):
        raise FileNotFoundError(f"Context audio not found: {context_audio}")
    logging.info(f"Loading model from {input_checkpoint}")
    logging.info(f"Using config from {config_path}")
    # Load config
    cfg = OmegaConf.load(config_path)
    if "cfg" in cfg:
        cfg = cfg.cfg
    logging.debug(f"Loaded config:\n{OmegaConf.to_yaml(cfg)}")
    with open_dict(cfg):
        # NOTE: the codec checkpoint path is currently hardcoded
        cfg, cfg_sample_rate = update_config_for_inference(
            cfg,
            "/nemo_codec_checkpoints/21fps_causal_codecmodel.nemo",
            False,
            False,
        )
    # Load model
    model = MagpieTTSModel(cfg)
    ckpt = torch.load(input_checkpoint, weights_only=False, map_location=device)
    state_dict = ckpt.get('state_dict', ckpt)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device)
    model.eval()
    # Validate model type
    if model.model_type != 'decoder_ce':
        raise ValueError(
            f"Baking context embedding is only supported for decoder_ce model type, "
            f"got {model.model_type}"
        )
    # Check that the context_encoder exists
    if not hasattr(model, 'context_encoder'):
        raise RuntimeError(
            "Model does not have context_encoder. It may already have a baked embedding."
        )
    # Load reference audio
    logging.info(f"Loading reference audio from {context_audio}")
    sample_rate = model.sample_rate
    audio_tensor, audio_lens = load_audio(context_audio, target_sample_rate=sample_rate)
    audio_tensor = audio_tensor.to(device)
    audio_lens = audio_lens.to(device)
    logging.info(f"Reference audio duration: {audio_lens[0].item() / sample_rate:.2f}s")
    # Bake the embedding
    logging.info("Computing context embedding...")
    bake_model_context_embedding(model, audio_tensor, audio_lens)
    # Verify baking worked
    if not model.has_baked_context_embedding:
        raise RuntimeError("Failed to bake context embedding")
    logging.info(
        f"Baked embedding shape: {model.baked_context_embedding.shape}, "
        f"length: {model.baked_context_embedding_len.item()}"
    )
    # Save the model. The state dict may already exclude context_encoder once
    # has_baked_context_embedding is set; drop any remaining keys explicitly.
    logging.info(f"Saving baked checkpoint to {output_checkpoint}")
    state_dict = model.state_dict()
    context_encoder_keys = [k for k in state_dict.keys() if 'context_encoder' in k]
    for key in context_encoder_keys:
        del state_dict[key]
    # Count excluded keys for reporting
    original_ckpt = torch.load(input_checkpoint, weights_only=False, map_location='cpu')
    original_state_dict = original_ckpt.get('state_dict', original_ckpt)
    excluded_keys = [k for k in original_state_dict.keys() if 'context_encoder' in k]
    logging.info(f"Removed {len(excluded_keys)} context_encoder parameters")
    # Report the size reduction using each tensor's actual element size
    if excluded_keys:
        original_size = sum(
            original_state_dict[k].numel() * original_state_dict[k].element_size() for k in excluded_keys
        )
        logging.info(f"Approximate size reduction: {original_size / 1024 / 1024:.1f} MB")
    # Create a modified config without context_encoder
    logging.info("Creating modified config without context_encoder...")
    modified_cfg = deepcopy(cfg)
    with open_dict(modified_cfg):
        if 'model' in modified_cfg and 'context_encoder' in modified_cfg.model:
            del modified_cfg.model.context_encoder
            logging.info("Removed 'context_encoder' from config")
        # Add a flag to indicate this checkpoint has a baked embedding
        if 'model' in modified_cfg:
            modified_cfg.model.has_baked_context_embedding = True
    # Save checkpoint
    output_dir = os.path.dirname(output_checkpoint)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save({'state_dict': state_dict}, output_checkpoint)
    logging.info(f"Saved baked checkpoint to {output_checkpoint}")
    # Save the modified config next to the checkpoint
    output_config_path = output_checkpoint.replace('.ckpt', '_config.yaml')
    if output_config_path == output_checkpoint:
        output_config_path = output_checkpoint + '_config.yaml'
    OmegaConf.save(modified_cfg, output_config_path)
    logging.info(f"Saved modified config to {output_config_path}")
    # Verify the saved checkpoint
    logging.info("Verifying saved checkpoint...")
    loaded_state = torch.load(output_checkpoint, weights_only=False, map_location='cpu')['state_dict']
    assert 'baked_context_embedding' in loaded_state, "baked_context_embedding not in saved checkpoint"
    assert 'baked_context_embedding_len' in loaded_state, "baked_context_embedding_len not in saved checkpoint"
    assert not any(
        'context_encoder' in k for k in loaded_state.keys()
    ), "context_encoder keys should not be in saved checkpoint"
    # Verify the saved config
    logging.info("Verifying saved config...")
    loaded_cfg = OmegaConf.load(output_config_path)
    assert 'context_encoder' not in loaded_cfg.get('model', {}), "context_encoder should not be in saved config"
    logging.info("Verification successful!")


def main():
    parser = argparse.ArgumentParser(
        description="Bake context embedding into MagpieTTS checkpoint",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        '--input_checkpoint',
        type=str,
        required=True,
        help='Path to original MagpieTTS checkpoint (.ckpt)',
    )
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help='Path to model config file (.yaml)',
    )
    parser.add_argument(
        '--output_checkpoint',
        type=str,
        required=True,
        help='Path to save the baked checkpoint',
    )
    parser.add_argument(
        '--context_audio',
        type=str,
        required=True,
        help='Path to reference audio file for baking',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cuda', 'cpu'],
        help='Device to run inference on (default: cuda)',
    )
    args = parser.parse_args()
    bake_context_embedding(
        input_checkpoint=args.input_checkpoint,
        config_path=args.config_path,
        output_checkpoint=args.output_checkpoint,
        context_audio=args.context_audio,
        device=args.device,
    )


if __name__ == '__main__':
    main()