#!/usr/bin/env python3
"""
Inspect ModelScope CAM++ model structure.

This script downloads a model and displays its parameter structure to help
debug parameter mapping issues.

Usage:
    <this_script> [repo_id]

If ``repo_id`` is omitted, "iic/speech_campplus_sv_zh-cn_16k-common" is used.
"""
import sys
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Optional, Union

import torch

# Checkpoint file patterns, searched in priority order. "*.bin" then "*.pt"
# is the original lookup order; "*.pth" and "*.ckpt" are extra fallbacks.
_CHECKPOINT_PATTERNS = ("*.bin", "*.pt", "*.pth", "*.ckpt")

# Keys inspected when a loaded checkpoint is an outer dict that wraps the
# real state dict instead of holding tensors directly.
_WRAPPER_KEYS = ("state_dict", "model_state_dict", "model")


def _find_checkpoint(model_dir: Union[str, Path]) -> Optional[Path]:
    """Return the first checkpoint file under *model_dir*, or None.

    Patterns in ``_CHECKPOINT_PATTERNS`` are tried in order. Within one
    pattern, matches are sorted so the choice is deterministic rather than
    depending on filesystem iteration order.
    """
    root = Path(model_dir)
    for pattern in _CHECKPOINT_PATTERNS:
        matches = sorted(root.rglob(pattern))
        if matches:
            return matches[0]
    return None


def _extract_state_dict(obj: object) -> Optional[dict]:
    """Return a name -> value mapping from a loaded checkpoint object.

    - A dict that holds at least one tensor is returned unchanged.
    - A dict without tensors is searched for a wrapped state dict under
      ``_WRAPPER_KEYS``; if none is found, the dict itself is returned.
    - Any other object with a ``state_dict`` method (e.g. a pickled
      ``nn.Module``) is converted via ``state_dict()``.
    - Anything else yields None.
    """
    if isinstance(obj, dict):
        if any(torch.is_tensor(value) for value in obj.values()):
            return obj
        for key in _WRAPPER_KEYS:
            if key in obj:
                inner = _extract_state_dict(obj[key])
                if inner is not None:
                    return inner
        return obj
    if hasattr(obj, "state_dict"):
        return obj.state_dict()
    return None


def _component_of(name: str) -> str:
    """Map a parameter name to its display group.

    The group is the path segment that follows an ``xvector`` segment (so a
    wrapper prefix such as ``module.`` is tolerated). Otherwise it is
    ``head`` if a ``head`` segment is present, else ``other``.
    """
    parts = name.split(".")
    # parts[:-1] guarantees a following segment exists.
    for index, part in enumerate(parts[:-1]):
        if part == "xvector":
            return parts[index + 1]
    if "head" in parts[:-1]:
        return "head"
    return "other"


def _shape_str(value: object) -> str:
    """Render a tensor shape as a tuple string; fall back to the type name."""
    shape = getattr(value, "shape", None)
    if shape is None:
        # Non-tensor checkpoint entries (ints, dicts, ...) have no shape.
        return type(value).__name__
    return str(tuple(shape))


def _print_section(title: str, names: list, weights: dict, empty_message: str) -> None:
    """Print a titled section listing (shape, name) for each of *names*."""
    print(f"\n{'='*70}")
    print(title)
    print(f"{'='*70}")
    if not names:
        print(f"  {empty_message}")
        return
    for name in sorted(names):
        print(f"  {_shape_str(weights[name]).ljust(20)} {name}")


def inspect_model(repo_id: str):
    """Download *repo_id* from ModelScope and print its parameter structure.

    The model is downloaded into a temporary directory, which is deleted when
    inspection finishes. Errors (no checkpoint file, unsupported checkpoint
    object) are reported on stdout and the function returns early.
    """
    # Imported lazily so the helpers above can be used without modelscope
    # installed; only the download step needs it.
    from modelscope import snapshot_download

    print(f"Downloading model: {repo_id}")
    print("="*70)

    with tempfile.TemporaryDirectory() as temp_dir:
        model_dir = snapshot_download(
            model_id=repo_id,
            local_dir=f"{temp_dir}/model"
        )
        print(f"\nModel downloaded to: {model_dir}")

        model_path = _find_checkpoint(model_dir)
        if model_path is None:
            print("ERROR: No model file found!")
            return
        print(f"Loading model from: {model_path}")

        # NOTE(review): torch.load is pickle-based; only point this script at
        # repositories you trust. On torch >= 2.6, weights_only defaults to
        # True, which rejects pickled nn.Module objects.
        loaded = torch.load(model_path, map_location='cpu')
        weights = _extract_state_dict(loaded)
        if weights is None:
            print(f"ERROR: Unsupported checkpoint object: {type(loaded).__name__}")
            return

        print(f"\n{'='*70}")
        print(f"TOTAL PARAMETERS: {len(weights)}")
        print(f"{'='*70}")

        # Group parameters by component and list each group.
        groups = defaultdict(list)
        for name in weights:
            groups[_component_of(name)].append(name)

        for component in sorted(groups):
            params = sorted(groups[component])
            print(f"\n{component.upper()} ({len(params)} parameters)")
            print("-"*70)
            for param_name in params:
                print(f"  {_shape_str(weights[param_name]).ljust(20)} {param_name}")

        # Focused views on the layers most often involved in mapping issues.
        _print_section(
            "CAM LAYER DETAILED STRUCTURE",
            [n for n in weights if 'cam' in n.lower()],
            weights,
            "No CAM parameters found!",
        )
        _print_section(
            "DENSE/FC LAYER DETAILED STRUCTURE",
            [n for n in weights if 'dense' in n.lower() or 'fc' in n.lower()],
            weights,
            "No dense/fc parameters found!",
        )
        _print_section(
            "OUTPUT/POOLING LAYER DETAILED STRUCTURE",
            [
                n for n in weights
                if 'output' in n.lower() or 'pool' in n.lower() or 'attention' in n.lower()
            ],
            weights,
            "No output/pooling parameters found!",
        )


if __name__ == "__main__":
    repo = sys.argv[1] if len(sys.argv) > 1 else "iic/speech_campplus_sv_zh-cn_16k-common"

    print("\nInspecting CAM++ Model Structure")
    print(f"{'='*70}\n")

    inspect_model(repo)