# MagpieTTS_Internal_Demo / scripts / magpietts / inspect_checkpoint.py
# Hugging Face upload metadata (kept for provenance, commented out so the file
# remains valid Python): uploaded by subhankarg via huggingface_hub,
# revision 0558aa4 (verified).
#!/usr/bin/env python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inspect MagpieTTS Checkpoint
A diagnostic script to check the contents of a MagpieTTS checkpoint:
- Whether it has context_encoder weights
- Whether it has baked context embeddings
- Shape of baked embeddings if present
Usage:
python scripts/magpietts/inspect_checkpoint.py --checkpoint /path/to/checkpoint.ckpt
"""
from __future__ import annotations
import argparse
import os
import torch
def _report_context_encoder(state_dict) -> list:
    """Print the CONTEXT ENCODER WEIGHTS section and return the matching keys."""
    context_encoder_keys = [k for k in state_dict.keys() if 'context_encoder' in k]
    print(f"\n{'=' * 60}")
    print("CONTEXT ENCODER WEIGHTS")
    print('=' * 60)
    if context_encoder_keys:
        print(f"βœ“ Found {len(context_encoder_keys)} context_encoder parameters")
        total_params = sum(state_dict[k].numel() for k in context_encoder_keys)
        # Use the actual per-element size instead of assuming 4 bytes (float32):
        # checkpoints saved in fp16/bf16 would otherwise be reported at 2x size.
        total_bytes = sum(
            state_dict[k].numel() * state_dict[k].element_size() for k in context_encoder_keys
        )
        size_mb = total_bytes / 1024 / 1024
        print(f"  Total parameters: {total_params:,}")
        print(f"  Approximate size: {size_mb:.2f} MB")
        print("\n  Sample keys:")
        for key in context_encoder_keys[:5]:
            print(f"    - {key}: {state_dict[key].shape}")
        if len(context_encoder_keys) > 5:
            print(f"    ... and {len(context_encoder_keys) - 5} more")
    else:
        print("βœ— No context_encoder weights found")
    return context_encoder_keys


def _report_baked_embedding(state_dict):
    """Print the BAKED CONTEXT EMBEDDING section.

    Returns:
        True if a non-empty ``baked_context_embedding`` tensor is present.
    """
    # Fetch once so the "key exists but is None/empty" case is handled explicitly.
    embedding = state_dict.get('baked_context_embedding')
    has_valid_embedding = embedding is not None and embedding.numel() > 0
    print(f"\n{'=' * 60}")
    print("BAKED CONTEXT EMBEDDING")
    print('=' * 60)
    if 'baked_context_embedding' in state_dict:
        if has_valid_embedding:
            print(f"βœ“ Found baked_context_embedding")
            print(f"  Shape: {embedding.shape}")
            print(f"  Dtype: {embedding.dtype}")
            print(f"  Parameters: {embedding.numel():,}")
            size_mb = embedding.numel() * embedding.element_size() / 1024 / 1024
            print(f"  Size: {size_mb:.4f} MB")
        else:
            print("βœ— baked_context_embedding key exists but is None or empty")
    else:
        print("βœ— No baked_context_embedding found")
    if 'baked_context_embedding_len' in state_dict:
        embedding_len = state_dict['baked_context_embedding_len']
        if embedding_len is not None:
            # May be stored as a 0-d tensor or a plain int; handle both.
            length = embedding_len.item() if hasattr(embedding_len, 'item') else embedding_len
            print(f"βœ“ Found baked_context_embedding_len: {length}")
        else:
            print("βœ— baked_context_embedding_len key exists but is None")
    else:
        print("βœ— No baked_context_embedding_len found")
    return has_valid_embedding


def _report_summary(context_encoder_keys, has_valid_embedding) -> None:
    """Print the SUMMARY section classifying the checkpoint."""
    print(f"\n{'=' * 60}")
    print("SUMMARY")
    print('=' * 60)
    if context_encoder_keys and not has_valid_embedding:
        print("β†’ This is a STANDARD checkpoint with context_encoder")
        print("  Can be used for any voice cloning with dynamic context audio")
    elif has_valid_embedding:
        if context_encoder_keys:
            print("β†’ This checkpoint has BOTH context_encoder AND baked embedding")
            print("  This is unusual - consider removing context_encoder weights")
        else:
            print("β†’ This is a BAKED checkpoint")
            print("  Will always use the baked voice, ignoring input context audio")
    else:
        print("β†’ This checkpoint has NEITHER context_encoder NOR baked embedding")
        print("  This may indicate an issue or a different model type")


def inspect_checkpoint(checkpoint_path: str) -> None:
    """Inspect a MagpieTTS checkpoint for context_encoder and baked embeddings.

    Prints a diagnostic report with three sections: context_encoder parameter
    count/size, baked context embedding shape/size (plus its length entry),
    and a summary classifying the checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file (.ckpt).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects — only run this
    # on trusted checkpoints. Kept as-is because Lightning checkpoints carry
    # non-tensor metadata that weights_only=True may reject.
    ckpt = torch.load(checkpoint_path, weights_only=False, map_location='cpu')
    # Lightning-style checkpoints wrap weights under 'state_dict'; raw
    # state_dicts are used directly.
    if 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
        print("Found 'state_dict' key in checkpoint")
    else:
        state_dict = ckpt
        print("Checkpoint is a raw state_dict (no 'state_dict' wrapper)")
    print(f"\nTotal keys in state_dict: {len(state_dict)}")
    context_encoder_keys = _report_context_encoder(state_dict)
    has_valid_embedding = _report_baked_embedding(state_dict)
    _report_summary(context_encoder_keys, has_valid_embedding)
def main():
    """Command-line entry point: parse ``--checkpoint`` and run the inspection."""
    cli = argparse.ArgumentParser(
        description="Inspect MagpieTTS checkpoint for context_encoder and baked embeddings",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    cli.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to the checkpoint file (.ckpt)',
    )
    opts = cli.parse_args()
    inspect_checkpoint(opts.checkpoint)
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()