campp-mlx-converter / inspect_model.py
BMP's picture
feat: Add batch conversion scripts for CAM++ models
656e7f6
#!/usr/bin/env python3
"""
Inspect ModelScope CAM++ model structure
This script downloads a model and displays its parameter structure
to help debug parameter mapping issues.
"""
import torch
from modelscope import snapshot_download
import tempfile
from pathlib import Path
from collections import defaultdict
def _find_checkpoint(model_dir):
    """Return the first ``*.bin`` (else ``*.pt``) file under *model_dir*, or None."""
    root = Path(model_dir)
    for pattern in ("*.bin", "*.pt"):
        # rglob order depends on the filesystem; sort so the pick is reproducible
        # when a snapshot contains more than one matching file.
        candidates = sorted(root.rglob(pattern))
        if candidates:
            return candidates[0]
    return None


def _component_of(name):
    """Return the display group for one parameter name."""
    parts = name.split('.')
    # Use the segment that follows the "...xvector" segment wherever it sits,
    # so prefixed keys (e.g. "module.xvector.block1.w") group under "block1"
    # instead of all collapsing into "xvector".
    for idx, part in enumerate(parts[:-1]):
        if part.endswith('xvector'):
            return parts[idx + 1]
    if 'head.' in name:
        return 'head'
    return 'other'


def _shape_label(value):
    """Return ``str(tuple(value.shape))``, or ``<TypeName>`` if there is no shape."""
    shape = getattr(value, 'shape', None)
    if shape is None:
        # Checkpoint dicts may carry non-tensor entries; show their type
        # rather than crashing on a missing ``.shape``.
        return f"<{type(value).__name__}>"
    return str(tuple(shape))


def _print_banner(title):
    """Print *title* framed by two 70-character ``=`` rules."""
    print(f"\n{'='*70}")
    print(title)
    print(f"{'='*70}")


def _print_param_lines(weights, names):
    """Print one ``shape name`` line for every entry of *names*."""
    for param_name in names:
        shape_str = _shape_label(weights[param_name]).ljust(20)
        print(f" {shape_str} {param_name}")


def _print_keyword_section(weights, title, keywords, empty_message):
    """Print the parameters whose lower-cased name contains any of *keywords*."""
    _print_banner(title)
    matches = sorted(
        name for name in weights
        if any(keyword in name.lower() for keyword in keywords)
    )
    if matches:
        _print_param_lines(weights, matches)
    else:
        print(empty_message)


def _report_weights(weights):
    """Print the grouped overview followed by the keyword-filtered sections."""
    _print_banner(f"TOTAL PARAMETERS: {len(weights)}")

    # Group by component
    groups = defaultdict(list)
    for name in weights:
        groups[_component_of(name)].append(name)

    # Display each group
    for component in sorted(groups):
        params = sorted(groups[component])
        print(f"\n{component.upper()} ({len(params)} parameters)")
        print("-"*70)
        _print_param_lines(weights, params)

    _print_keyword_section(
        weights, "CAM LAYER DETAILED STRUCTURE",
        ('cam',), " No CAM parameters found!",
    )
    _print_keyword_section(
        weights, "DENSE/FC LAYER DETAILED STRUCTURE",
        ('dense', 'fc'), " No dense/fc parameters found!",
    )
    _print_keyword_section(
        weights, "OUTPUT/POOLING LAYER DETAILED STRUCTURE",
        ('output', 'pool', 'attention'), " No output/pooling parameters found!",
    )


def inspect_model(repo_id: str) -> None:
    """Download a model snapshot and print its parameter names and shapes.

    The snapshot is fetched into a temporary directory that is removed when
    the function returns. The first ``*.bin`` file found (or, failing that,
    the first ``*.pt`` file) is loaded on CPU, and its entries are printed
    grouped by component, followed by CAM, dense/fc and output/pooling
    sections selected by substring match on the parameter name.

    Args:
        repo_id: Model identifier passed to ``snapshot_download`` as ``model_id``.

    Returns:
        None. All results are written to stdout; if no checkpoint file is
        found, or the loaded object is neither a dict nor exposes
        ``state_dict()``, an ``ERROR:`` line is printed and the function
        returns early.
    """
    print(f"Downloading model: {repo_id}")
    print("="*70)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Download model
        model_dir = snapshot_download(
            model_id=repo_id,
            local_dir=f"{temp_dir}/model"
        )
        print(f"\nModel downloaded to: {model_dir}")

        # Find model file
        model_path = _find_checkpoint(model_dir)
        if model_path is None:
            print("ERROR: No model file found!")
            return

        print(f"Loading model from: {model_path}")

        # Load weights
        # SECURITY: torch.load unpickles the file, and the file comes from a
        # remote repository. Only inspect repositories you trust, or pass
        # weights_only=True if the checkpoint is a plain tensor dict.
        weights = torch.load(model_path, map_location='cpu')

        # Get state dict if needed
        if not isinstance(weights, dict):
            if not hasattr(weights, 'state_dict'):
                # Nothing below can enumerate such an object; stop with a
                # clear message instead of failing on len()/keys().
                print(f"ERROR: Unsupported checkpoint type: {type(weights).__name__}")
                return
            weights = weights.state_dict()

        _report_weights(weights)
if __name__ == "__main__":
    import sys

    # The first command-line argument, when present, names the repository to
    # inspect; otherwise fall back to the built-in default repo id.
    cli_args = sys.argv[1:]
    repo = cli_args[0] if cli_args else "iic/speech_campplus_sv_zh-cn_16k-common"

    print(f"\nInspecting CAM++ Model Structure")
    print(f"{'='*70}\n")
    inspect_model(repo)