#!/usr/bin/env python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Bake Context Embedding into MagpieTTS Checkpoint
This script converts a MagpieTTS decoder_ce checkpoint by:
1. Loading a reference audio file
2. Running it through the context_encoder to get the embedding
3. Saving a new checkpoint with:
- The baked context embedding as a buffer
- All original weights EXCEPT context_encoder weights
4. Saving a modified config without context_encoder settings
Usage:
python scripts/magpietts/bake_context_embedding.py \
--input_checkpoint /path/to/original.ckpt \
--config_path /path/to/config.yaml \
--output_checkpoint /path/to/baked.ckpt \
--context_audio /path/to/reference.wav
The resulting checkpoint will be smaller (no context_encoder weights) and will
always use the baked reference audio embedding for voice cloning, regardless of
what context audio is provided at inference time.
A modified config file will be saved alongside the output checkpoint with
context_encoder settings removed.
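
Downstream loading sketch (illustrative only; it mirrors how this script itself
restores weights, and the exact inference entry points may differ):

    cfg = OmegaConf.load('/path/to/baked_config.yaml')
    model = MagpieTTSModel(cfg)
    state = torch.load('/path/to/baked.ckpt', map_location='cpu')['state_dict']
    model.load_state_dict(state, strict=False)
    # model.baked_context_embedding now carries the reference voice identity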
"""

from __future__ import annotations

import argparse
import os
from copy import deepcopy
from typing import Tuple

import soundfile as sf
import torch
from omegaconf import OmegaConf, open_dict

from examples.tts.magpietts.utils import update_config_for_inference
from nemo.collections.tts.models import MagpieTTSModel
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths
from nemo.utils import logging


def load_audio(audio_path: str, target_sample_rate: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load an audio file and return it as a tensor together with its length.

    Args:
        audio_path: Path to the audio file.
        target_sample_rate: Expected sample rate. Audio is resampled if needed.

    Returns:
        Tuple of (audio_tensor, audio_lens) with shapes (1, T) and (1,).

    Raises:
        ValueError: If the sample rate does not match and librosa is not installed.
    """
    audio, sr = sf.read(audio_path, dtype='float32')
    # sf.read returns (T,) for mono and (T, C) for multi-channel files;
    # downmix to mono so the shapes below stay consistent.
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != target_sample_rate:
        try:
            import librosa

            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)
            logging.info(f"Resampled audio from {sr}Hz to {target_sample_rate}Hz")
        except ImportError:
            raise ValueError(
                f"Audio sample rate {sr} does not match target {target_sample_rate}. "
                "Install librosa for automatic resampling: pip install librosa"
            )
    # Convert to a batched tensor: (T,) -> (1, T)
    audio_tensor = torch.from_numpy(audio).unsqueeze(0)
    audio_lens = torch.tensor([audio_tensor.shape[1]])
    return audio_tensor, audio_lens
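
# Example (illustrative):
#     audio, lens = load_audio('/path/to/reference.wav', target_sample_rate=24000)
#     # audio: (1, T) float32 tensor, lens: (1,) tensor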


def bake_model_context_embedding(
    model: MagpieTTSModel,
    context_audio: torch.Tensor,
    context_audio_lens: torch.Tensor,
) -> None:
    """Compute the context embedding from reference audio and store it on the model.

    This function runs the context audio through the model's context_encoder and stores
    the resulting embedding as a buffer. After baking, the context_encoder weights can
    be removed from the checkpoint to reduce model size.

    Only supported for the decoder_ce model type.

    Args:
        model: MagpieTTSModel instance with a context_encoder.
        context_audio: Reference audio waveform. Shape: (1, T_samples).
        context_audio_lens: Length of the audio in samples. Shape: (1,).

    Raises:
        ValueError: If the model type is not decoder_ce.
        RuntimeError: If the context_encoder is not available.
    """
    if model.model_type != 'decoder_ce':
        raise ValueError(
            f"Baking context embedding is only supported for decoder_ce model type, got {model.model_type}"
        )
    if not hasattr(model, 'context_encoder'):
        raise RuntimeError("context_encoder not found. Cannot bake embedding.")

    with torch.no_grad():
        # Convert audio to codec tokens
        context_audio_codes, context_audio_codes_lens = model.audio_to_codes(
            context_audio, context_audio_lens, audio_type='context'
        )
        context_audio_codes = model.pad_audio_codes(
            context_audio_codes, model.frame_stacking_factor, pad_token=0
        )
        context_audio_embedded = model.embed_audio_tokens(context_audio_codes)
        # Compute the context length after frame stacking
        context_input_lens = torch.ceil(
            context_audio_codes_lens / model.frame_stacking_factor
        ).to(context_audio_codes_lens.dtype)
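        # e.g. 100 codec frames with frame_stacking_factor=4 -> ceil(100 / 4) = 25 stacked frames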
        context_mask = get_mask_from_lengths(context_input_lens)
        # Run through the context encoder
        context_embedding = model.context_encoder(
            context_audio_embedded, context_mask, cond=None, cond_mask=None
        )['output']
        # Store on the model (squeeze the batch dim since we store a single embedding).
        # MagpieTTSModel is expected to register these names as buffers so that they
        # are included in the checkpoint's state_dict.
        model.baked_context_embedding = context_embedding.squeeze(0)  # (T, E)
        model.baked_context_embedding_len = context_input_lens.squeeze(0)  # scalar

    logging.info(
        f"Baked context embedding with shape {model.baked_context_embedding.shape}, "
        f"length {model.baked_context_embedding_len.item()}"
    )
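
# Note: at inference time the model is expected to use the baked buffers in place of
# running the context_encoder, along the lines of (illustrative):
#     context_embedding = model.baked_context_embedding.unsqueeze(0)  # (1, T, E)
#     context_lens = model.baked_context_embedding_len.unsqueeze(0)   # (1,)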


def bake_context_embedding(
    input_checkpoint: str,
    config_path: str,
    output_checkpoint: str,
    context_audio: str,
    device: str = 'cuda',
    codec_checkpoint: str = '/nemo_codec_checkpoints/21fps_causal_codecmodel.nemo',
) -> None:
    """Bake a context embedding into a checkpoint.

    Args:
        input_checkpoint: Path to the original MagpieTTS checkpoint (.ckpt).
        config_path: Path to the model config file (.yaml).
        output_checkpoint: Path to save the baked checkpoint.
        context_audio: Path to the reference audio file for baking.
        device: Device to run inference on ('cuda' or 'cpu').
        codec_checkpoint: Path to the audio codec .nemo checkpoint used when
            updating the config for inference.

    Raises:
        ValueError: If the model type is not decoder_ce.
        FileNotFoundError: If input files don't exist.
    """
    # Validate inputs
    if not os.path.exists(input_checkpoint):
        raise FileNotFoundError(f"Input checkpoint not found: {input_checkpoint}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    if not os.path.exists(context_audio):
        raise FileNotFoundError(f"Context audio not found: {context_audio}")

    logging.info(f"Loading model from {input_checkpoint}")
    logging.info(f"Using config from {config_path}")

    # Load config
    cfg = OmegaConf.load(config_path)
    if "cfg" in cfg:
        cfg = cfg.cfg
    logging.debug(f"Loaded config:\n{OmegaConf.to_yaml(cfg)}")
    with open_dict(cfg):
        cfg, _codec_sample_rate = update_config_for_inference(
            cfg,
            codec_checkpoint,
            False,
            False,
        )

    # Load model
    model = MagpieTTSModel(cfg)
    ckpt = torch.load(input_checkpoint, weights_only=False, map_location=device)
    # Support both Lightning checkpoints ({'state_dict': ...}) and raw state dicts
    state_dict = ckpt.get('state_dict', ckpt)
    # strict=False tolerates key mismatches between the checkpoint and the fresh model
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        logging.warning(f"Non-strict load: missing={missing}, unexpected={unexpected}")
    model = model.to(device)
    model.eval()

    # Validate model type
    if model.model_type != 'decoder_ce':
        raise ValueError(
            f"Baking context embedding is only supported for decoder_ce model type, "
            f"got {model.model_type}"
        )
    # Check that the context_encoder exists
    if not hasattr(model, 'context_encoder'):
        raise RuntimeError(
            "Model does not have a context_encoder. It may already have a baked embedding."
        )

    # Load reference audio at the model's sample rate
    logging.info(f"Loading reference audio from {context_audio}")
    sample_rate = model.sample_rate
    audio_tensor, audio_lens = load_audio(context_audio, target_sample_rate=sample_rate)
    audio_tensor = audio_tensor.to(device)
    audio_lens = audio_lens.to(device)
    logging.info(f"Reference audio duration: {audio_lens[0].item() / sample_rate:.2f}s")

    # Bake the embedding
    logging.info("Computing context embedding...")
    bake_model_context_embedding(model, audio_tensor, audio_lens)

    # Verify baking worked
    if not model.has_baked_context_embedding:
        raise RuntimeError("Failed to bake context embedding")
    logging.info(
        f"Baked embedding shape: {model.baked_context_embedding.shape}, "
        f"length: {model.baked_context_embedding_len.item()}"
    )

    # Save the model. state_dict() may already exclude context_encoder once
    # has_baked_context_embedding is set; remove any remaining keys explicitly.
    logging.info(f"Saving baked checkpoint to {output_checkpoint}")
    state_dict = model.state_dict()
    context_encoder_keys = [k for k in state_dict.keys() if 'context_encoder' in k]
    for key in context_encoder_keys:
        del state_dict[key]

    # Count the excluded keys for reporting
    original_ckpt = torch.load(input_checkpoint, weights_only=False, map_location='cpu')
    original_state_dict = original_ckpt.get('state_dict', original_ckpt)
    excluded_keys = [k for k in original_state_dict.keys() if 'context_encoder' in k]
    logging.info(f"Removed {len(excluded_keys)} context_encoder parameters")

    # Report the size reduction using the actual element sizes
    if excluded_keys:
        saved_bytes = sum(original_state_dict[k].numel() * original_state_dict[k].element_size() for k in excluded_keys)
        logging.info(f"Approximate size reduction: {saved_bytes / 1024 / 1024:.1f} MB")

    # Create a modified config without context_encoder
    logging.info("Creating modified config without context_encoder...")
    modified_cfg = deepcopy(cfg)
    with open_dict(modified_cfg):
        if 'model' in modified_cfg and 'context_encoder' in modified_cfg.model:
            del modified_cfg.model.context_encoder
            logging.info("Removed 'context_encoder' from config")
        # Flag that this checkpoint carries a baked embedding
        if 'model' in modified_cfg:
            modified_cfg.model.has_baked_context_embedding = True

    # Save checkpoint
    output_dir = os.path.dirname(output_checkpoint)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save({'state_dict': state_dict}, output_checkpoint)
    logging.info(f"Saved baked checkpoint to {output_checkpoint}")

    # Save the modified config next to the checkpoint; use splitext so that a
    # '.ckpt' substring elsewhere in the path is never touched
    base, ext = os.path.splitext(output_checkpoint)
    output_config_path = (base if ext == '.ckpt' else output_checkpoint) + '_config.yaml'
    OmegaConf.save(modified_cfg, output_config_path)
    logging.info(f"Saved modified config to {output_config_path}")

    # Verify the saved checkpoint
    logging.info("Verifying saved checkpoint...")
    loaded_state = torch.load(output_checkpoint, weights_only=False, map_location='cpu')['state_dict']
    assert 'baked_context_embedding' in loaded_state, "baked_context_embedding not in saved checkpoint"
    assert 'baked_context_embedding_len' in loaded_state, "baked_context_embedding_len not in saved checkpoint"
    assert not any(
        'context_encoder' in k for k in loaded_state.keys()
    ), "context_encoder keys should not be in saved checkpoint"

    # Verify the saved config
    logging.info("Verifying saved config...")
    loaded_cfg = OmegaConf.load(output_config_path)
    assert 'context_encoder' not in loaded_cfg.get('model', {}), "context_encoder should not be in saved config"
    logging.info("Verification successful!")


def main():
    parser = argparse.ArgumentParser(
        description="Bake context embedding into MagpieTTS checkpoint",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        '--input_checkpoint',
        type=str,
        required=True,
        help='Path to original MagpieTTS checkpoint (.ckpt)',
    )
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help='Path to model config file (.yaml)',
    )
    parser.add_argument(
        '--output_checkpoint',
        type=str,
        required=True,
        help='Path to save the baked checkpoint',
    )
    parser.add_argument(
        '--context_audio',
        type=str,
        required=True,
        help='Path to reference audio file for baking',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cuda', 'cpu'],
        help='Device to run inference on (default: cuda)',
    )
    parser.add_argument(
        '--codec_checkpoint',
        type=str,
        default='/nemo_codec_checkpoints/21fps_causal_codecmodel.nemo',
        help='Path to the audio codec .nemo checkpoint used to update the config',
    )
    args = parser.parse_args()
    bake_context_embedding(
        input_checkpoint=args.input_checkpoint,
        config_path=args.config_path,
        output_checkpoint=args.output_checkpoint,
        context_audio=args.context_audio,
        device=args.device,
        codec_checkpoint=args.codec_checkpoint,
    )


if __name__ == '__main__':
    main()