#!/usr/bin/env python3
"""
Inspect ModelScope CAM++ model structure
This script downloads a model and displays its parameter structure
to help debug parameter mapping issues.
"""
import torch
from modelscope import snapshot_download
import tempfile
from pathlib import Path
from collections import defaultdict
def _find_model_file(model_dir):
    """Return the first ``*.bin`` (else ``*.pt``) file under *model_dir*, or None.

    Candidates are sorted so the same file is picked on every run;
    ``Path.rglob`` yields entries in filesystem-dependent order.
    """
    for pattern in ("*.bin", "*.pt"):
        candidates = sorted(Path(model_dir).rglob(pattern))
        if candidates:
            return candidates[0]
    return None


def _as_state_dict(loaded):
    """Coerce whatever ``torch.load`` returned into a name -> value mapping.

    Returns None when *loaded* is neither a dict nor an object exposing
    ``state_dict()``.
    """
    if not isinstance(loaded, dict):
        if not hasattr(loaded, 'state_dict'):
            return None
        return loaded.state_dict()

    # Some checkpoints wrap the real state dict under a single key; a dict
    # value under one of these keys is treated as that wrapper.
    for wrapper_key in ('state_dict', 'model'):
        nested = loaded.get(wrapper_key)
        if isinstance(nested, dict):
            print(f"Using nested '{wrapper_key}' entry as state dict")
            return nested
    return loaded


def _shape_text(value) -> str:
    """Return ``str(tuple(value.shape))``, or ``<TypeName>`` for shapeless values."""
    shape = getattr(value, 'shape', None)
    if shape is None:
        return f"<{type(value).__name__}>"
    return str(tuple(shape))


def _component_of(name: str) -> str:
    """Map a parameter name to the group it is listed under."""
    if 'xvector.' in name:
        parts = name.split('.')
        # Use the segment right after "xvector", even when the name carries
        # a leading prefix such as "module.".
        if 'xvector' in parts[:-1]:
            return parts[parts.index('xvector') + 1]
        # "xvector." only matched inside a longer segment: keep the
        # historical choice of the second segment (always present here,
        # because the name contains at least one dot).
        return parts[1]
    if 'head.' in name:
        return 'head'
    return 'other'


def _print_matching(weights, title: str, keywords, empty_message: str):
    """Print a titled list of entries whose lowercased name contains any keyword."""
    print(f"\n{'='*70}")
    print(title)
    print(f"{'='*70}")
    matches = sorted(
        name for name in weights
        if any(keyword in name.lower() for keyword in keywords)
    )
    if not matches:
        print(empty_message)
        return
    for name in matches:
        print(f" {_shape_text(weights[name]).ljust(20)} {name}")


def inspect_model(repo_id: str):
    """Download a ModelScope model and print the layout of its weights.

    The model is fetched into a temporary directory, the first ``*.bin``
    (or, failing that, ``*.pt``) file is loaded on CPU, and the entries are
    printed grouped by component, followed by dedicated listings for
    CAM, dense/fc and output/pooling/attention entries.

    Args:
        repo_id: Model identifier passed to ``snapshot_download`` as
            ``model_id``.

    Returns:
        None. All results are written to stdout; when no weight file is
        found or the loaded object cannot be read as a state dict, an
        ``ERROR:`` line is printed and the function returns early.
    """
    print(f"Downloading model: {repo_id}")
    print("="*70)

    with tempfile.TemporaryDirectory() as temp_dir:
        model_dir = snapshot_download(
            model_id=repo_id,
            local_dir=f"{temp_dir}/model"
        )
        print(f"\nModel downloaded to: {model_dir}")

        model_path = _find_model_file(model_dir)
        if model_path is None:
            print("ERROR: No model file found!")
            return
        print(f"Loading model from: {model_path}")

        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only inspect repositories you trust, or consider
        # passing weights_only=True on PyTorch versions that support it.
        loaded = torch.load(model_path, map_location='cpu')

    # Everything is in memory now, so the temporary download can be removed
    # before reporting.
    weights = _as_state_dict(loaded)
    if weights is None:
        print(f"ERROR: Loaded object of type {type(loaded).__name__} is not a state dict!")
        return

    print(f"\n{'='*70}")
    print(f"TOTAL PARAMETERS: {len(weights)}")
    print(f"{'='*70}")

    # Group entry names by component.
    groups = defaultdict(list)
    for name in weights:
        groups[_component_of(name)].append(name)

    for component in sorted(groups):
        params = sorted(groups[component])
        print(f"\n{component.upper()} ({len(params)} parameters)")
        print("-"*70)
        for param_name in params:
            shape_str = _shape_text(weights[param_name]).ljust(20)
            print(f" {shape_str} {param_name}")

    _print_matching(
        weights,
        "CAM LAYER DETAILED STRUCTURE",
        ('cam',),
        " No CAM parameters found!",
    )
    _print_matching(
        weights,
        "DENSE/FC LAYER DETAILED STRUCTURE",
        ('dense', 'fc'),
        " No dense/fc parameters found!",
    )
    _print_matching(
        weights,
        "OUTPUT/POOLING LAYER DETAILED STRUCTURE",
        ('output', 'pool', 'attention'),
        " No output/pooling parameters found!",
    )
if __name__ == "__main__":
    import sys

    # The first command-line argument, when given, names the repo to
    # inspect; otherwise fall back to the default CAM++ repo id.
    repo = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "iic/speech_campplus_sv_zh-cn_16k-common"
    )

    print(f"\nInspecting CAM++ Model Structure")
    print(f"{'='*70}\n")
    inspect_model(repo)
|