Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Inspect MagpieTTS Checkpoint | |
| A diagnostic script to check the contents of a MagpieTTS checkpoint: | |
| - Whether it has context_encoder weights | |
| - Whether it has baked context embeddings | |
| - Shape of baked embeddings if present | |
| Usage: | |
| python scripts/magpietts/inspect_checkpoint.py --checkpoint /path/to/checkpoint.ckpt | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import torch | |
def inspect_checkpoint(checkpoint_path: str) -> None:
    """Inspect a MagpieTTS checkpoint for context_encoder and baked embeddings.

    Prints a diagnostic report covering:
      - presence and approximate size of ``context_encoder`` parameters
      - presence, shape, and dtype of a baked context embedding
      - an overall classification of the checkpoint (standard / baked / neither)

    Args:
        checkpoint_path: Path to the checkpoint file (.ckpt).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    print(f"Loading checkpoint: {checkpoint_path}")
    # SECURITY NOTE(review): weights_only=False unpickles arbitrary objects;
    # only run this script on checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, weights_only=False, map_location='cpu')
    # Lightning-style checkpoints wrap weights under 'state_dict'; a raw
    # state_dict has the parameter tensors at the top level.
    if 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
        print("Found 'state_dict' key in checkpoint")
    else:
        state_dict = ckpt
        print("Checkpoint is a raw state_dict (no 'state_dict' wrapper)")
    print(f"\nTotal keys in state_dict: {len(state_dict)}")
    # Check for context_encoder weights
    context_encoder_keys = [k for k in state_dict if 'context_encoder' in k]
    print(f"\n{'=' * 60}")
    print("CONTEXT ENCODER WEIGHTS")
    print('=' * 60)
    if context_encoder_keys:
        print(f"β Found {len(context_encoder_keys)} context_encoder parameters")
        total_params = sum(state_dict[k].numel() for k in context_encoder_keys)
        # Use each tensor's actual element size rather than assuming float32
        # (4 bytes) -- checkpoints are frequently saved in fp16/bf16.
        size_mb = (
            sum(state_dict[k].numel() * state_dict[k].element_size() for k in context_encoder_keys)
            / 1024
            / 1024
        )
        print(f"  Total parameters: {total_params:,}")
        print(f"  Approximate size: {size_mb:.2f} MB")
        print("\n  Sample keys:")
        for key in context_encoder_keys[:5]:
            print(f"    - {key}: {state_dict[key].shape}")
        if len(context_encoder_keys) > 5:
            print(f"    ... and {len(context_encoder_keys) - 5} more")
    else:
        print("β No context_encoder weights found")
    # Check for baked context embedding
    print(f"\n{'=' * 60}")
    print("BAKED CONTEXT EMBEDDING")
    print('=' * 60)
    has_baked_embedding = 'baked_context_embedding' in state_dict
    has_baked_embedding_len = 'baked_context_embedding_len' in state_dict
    # .get() leaves embedding as None when absent, so the summary section
    # below can reference it unconditionally (the original relied on
    # short-circuit evaluation to avoid a NameError).
    embedding = state_dict.get('baked_context_embedding')
    if has_baked_embedding:
        if embedding is not None and embedding.numel() > 0:
            print("β Found baked_context_embedding")
            print(f"  Shape: {embedding.shape}")
            print(f"  Dtype: {embedding.dtype}")
            print(f"  Parameters: {embedding.numel():,}")
            size_mb = embedding.numel() * embedding.element_size() / 1024 / 1024
            print(f"  Size: {size_mb:.4f} MB")
        else:
            print("β baked_context_embedding key exists but is None or empty")
    else:
        print("β No baked_context_embedding found")
    if has_baked_embedding_len:
        embedding_len = state_dict['baked_context_embedding_len']
        if embedding_len is not None:
            print(f"β Found baked_context_embedding_len: {embedding_len.item()}")
        else:
            print("β baked_context_embedding_len key exists but is None")
    else:
        print("β No baked_context_embedding_len found")
    # Summary
    print(f"\n{'=' * 60}")
    print("SUMMARY")
    print('=' * 60)
    if context_encoder_keys and not has_baked_embedding:
        print("β This is a STANDARD checkpoint with context_encoder")
        print("  Can be used for any voice cloning with dynamic context audio")
    elif embedding is not None and embedding.numel() > 0:
        if context_encoder_keys:
            print("β This checkpoint has BOTH context_encoder AND baked embedding")
            print("  This is unusual - consider removing context_encoder weights")
        else:
            print("β This is a BAKED checkpoint")
            print("  Will always use the baked voice, ignoring input context audio")
    else:
        print("β This checkpoint has NEITHER context_encoder NOR baked embedding")
        print("  This may indicate an issue or a different model type")
def main() -> None:
    """Parse command-line arguments and run the checkpoint inspection."""
    arg_parser = argparse.ArgumentParser(
        description="Inspect MagpieTTS checkpoint for context_encoder and baked embeddings",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    arg_parser.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to the checkpoint file (.ckpt)',
    )
    parsed_args = arg_parser.parse_args()
    inspect_checkpoint(parsed_args.checkpoint)


if __name__ == '__main__':
    main()