File size: 3,994 Bytes
656e7f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
"""
Inspect ModelScope CAM++ model structure

This script downloads a model and displays its parameter structure
to help debug parameter mapping issues.
"""

import torch
from modelscope import snapshot_download
import tempfile
from pathlib import Path
from collections import defaultdict

def _shape_label(value) -> str:
    """Return a 20-column, left-justified shape label for *value*.

    Entries that have no ``shape`` attribute (for example ints or strings
    stored alongside the tensors in a checkpoint dict) are labelled with
    their type name instead of raising ``AttributeError``.
    """
    shape = getattr(value, "shape", None)
    if shape is None:
        label = f"<{type(value).__name__}>"
    else:
        label = str(tuple(shape))
    return label.ljust(20)


def _component_of(name: str) -> str:
    """Return the display group for a parameter name."""
    if 'xvector.' in name:
        # A name containing 'xvector.' always splits into at least two parts,
        # so indexing [1] is safe.
        return name.split('.')[1]
    if 'head.' in name:
        return 'head'
    return 'other'


def _print_matching(weights, title: str, keywords, empty_message: str) -> None:
    """Print a banner, then every entry whose lowercased name contains a keyword.

    Args:
        weights: Mapping of parameter name to loaded value.
        title: Heading printed between the two rule lines.
        keywords: Lowercase substrings; a name matches if it contains any.
        empty_message: Line printed when no name matches.
    """
    rule = "=" * 70
    print(f"\n{rule}")
    print(title)
    print(rule)

    matches = sorted(
        name for name in weights
        if any(keyword in name.lower() for keyword in keywords)
    )
    if not matches:
        print(empty_message)
        return
    for name in matches:
        print(f"  {_shape_label(weights[name])} {name}")


def inspect_model(repo_id: str) -> None:
    """Download a ModelScope model and print the structure of its weights.

    The model is downloaded into a temporary directory that is removed when
    the function returns. The first ``*.bin`` file found (or, if there is
    none, the first ``*.pt`` file), in sorted path order, is loaded on CPU.
    Its entries are printed grouped by component, followed by detail sections
    for CAM, dense/fc and output/pooling/attention entries.

    Args:
        repo_id: ModelScope model id passed to ``snapshot_download``.

    Returns:
        None. All results are printed to stdout. An error line is printed and
        the function returns early when no weight file is found or when the
        loaded object is neither a dict nor exposes ``state_dict()``.
    """

    print(f"Downloading model: {repo_id}")
    print("="*70)

    with tempfile.TemporaryDirectory() as temp_dir:
        # Download model
        model_dir = snapshot_download(
            model_id=repo_id,
            local_dir=f"{temp_dir}/model"
        )

        print(f"\nModel downloaded to: {model_dir}")

        # Find model file. Sorting makes the choice deterministic; rglob
        # yields paths in filesystem order, which varies between machines.
        model_files = sorted(Path(model_dir).rglob("*.bin"))
        if not model_files:
            model_files = sorted(Path(model_dir).rglob("*.pt"))

        if not model_files:
            print("ERROR: No model file found!")
            return

        model_path = model_files[0]
        print(f"Loading model from: {model_path}")

        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only inspect repositories you trust. The call is
        # left unchanged (no weights_only=True) because the branch below
        # also supports checkpoints that store a full module object.
        weights = torch.load(model_path, map_location='cpu')

        # Get state dict if needed
        if not isinstance(weights, dict):
            if not hasattr(weights, 'state_dict'):
                # Nothing below can work without a name -> value mapping.
                print(f"ERROR: Unsupported checkpoint type: {type(weights).__name__}")
                return
            weights = weights.state_dict()

        print(f"\n{'='*70}")
        print(f"TOTAL PARAMETERS: {len(weights)}")
        print(f"{'='*70}")

        # Group by component
        groups = defaultdict(list)
        for name in weights:
            groups[_component_of(name)].append(name)

        # Display each group
        for component in sorted(groups):
            params = sorted(groups[component])
            print(f"\n{component.upper()} ({len(params)} parameters)")
            print("-"*70)

            for param_name in params:
                print(f"  {_shape_label(weights[param_name])} {param_name}")

        # Show CAM layer details
        _print_matching(
            weights,
            "CAM LAYER DETAILED STRUCTURE",
            ('cam',),
            "  No CAM parameters found!",
        )

        # Show dense/fc layer details
        _print_matching(
            weights,
            "DENSE/FC LAYER DETAILED STRUCTURE",
            ('dense', 'fc'),
            "  No dense/fc parameters found!",
        )

        # Show output/pooling layer details
        _print_matching(
            weights,
            "OUTPUT/POOLING LAYER DETAILED STRUCTURE",
            ('output', 'pool', 'attention'),
            "  No output/pooling parameters found!",
        )

if __name__ == "__main__":
    import sys

    # The first command-line argument, when given, overrides the default
    # ModelScope repo id.
    cli_args = sys.argv[1:]
    repo = cli_args[0] if cli_args else "iic/speech_campplus_sv_zh-cn_16k-common"

    print("\nInspecting CAM++ Model Structure")
    print("=" * 70 + "\n")

    inspect_model(repo)