import torch
import os
from safetensors.torch import load_file
def count_layers(state_dict, exclude_prefixes=None):
    """
    Count unique layers in a state dict.

    A "layer" here is the module prefix of a parameter key — everything
    before the final dot (keys with no dot count as their own module).

    Returns a tuple ``(total_layers, custom_layers)`` where custom layers
    are those whose keys are not covered by any entry in
    ``exclude_prefixes`` (the pretrained components).
    """
    prefixes = exclude_prefixes if exclude_prefixes is not None else []
    all_modules = set()
    non_pretrained_modules = set()
    for name in state_dict:
        module, dot, _ = name.rpartition('.')
        if not dot:
            # No dot at all: the key itself is the module name.
            module = name
        all_modules.add(module)
        # A key belongs to a pretrained component when it equals an excluded
        # prefix exactly or lives underneath one (prefix followed by a dot).
        pretrained = any(name == p or name.startswith(p + '.') for p in prefixes)
        if not pretrained:
            non_pretrained_modules.add(module)
    return len(all_modules), len(non_pretrained_modules)
def count_parameters(state_dict, exclude_prefixes=None):
    """
    Count parameters in a state dict.

    Returns a tuple ``(total_params, custom_params)`` where custom params
    exclude any entry whose key matches (or lies under) a prefix in
    ``exclude_prefixes``.

    Non-tensor entries are skipped: checkpoints frequently carry
    bookkeeping values ('epoch', 'step', '_metadata', ...) alongside the
    weights, and calling ``.numel()`` on those raised AttributeError,
    making the caller report the whole model as ERROR.
    """
    if exclude_prefixes is None:
        exclude_prefixes = []
    total_params = 0
    custom_params = 0
    for name, param in state_dict.items():
        # Only tensors have an element count; ignore checkpoint metadata.
        if not torch.is_tensor(param):
            continue
        p_count = param.numel()
        total_params += p_count
        is_pretrained = any(name.startswith(p + '.') or name == p for p in exclude_prefixes)
        if not is_pretrained:
            custom_params += p_count
    return total_params, custom_params
# Define models and their pretrained prefixes.
# For each model: "file" is the checkpoint path (loaded with safetensors or
# torch.load depending on extension), and "exclude" lists the key prefixes
# of pretrained/frozen submodules that should not count as custom layers
# or custom parameters in the report below.
models_config = {
    "Bass": {
        "file": "bass_sota.pth",
        "exclude": ["audio_encoder"]  # pretrained audio encoder backbone
    },
    "Drums": {
        "file": "drums.safetensors",
        "exclude": ["wavlm"]  # pretrained WavLM feature extractor
    },
    "Vocals": {
        "file": "vocals.pt",
        "exclude": []  # fully custom model, nothing excluded
    }
}
# Print a summary table of layer/parameter counts for each configured model.
print(f"{'='*100}")
print(f"{'MODEL':<10} | {'TOTAL LAYERS':<15} | {'CUSTOM LAYERS':<15} | {'CUSTOM PARAMS':<15} | {'FILE'}")
print(f"{'='*100}")
for model_name, cfg in models_config.items():
    filename = cfg["file"]
    exclude = cfg["exclude"]
    if not os.path.exists(filename):
        # Bug fix: the FILE column printed a literal placeholder instead of
        # the checkpoint filename in all three branches below.
        print(f"{model_name:<10} | {'MISSING':<15} | {'N/A':<15} | {'N/A':<15} | {filename}")
        continue
    try:
        if filename.endswith(".safetensors"):
            data = load_file(filename, device='cpu')
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            data = torch.load(filename, map_location='cpu', weights_only=False)
        # Checkpoints often wrap the state dict under a well-known key;
        # unwrap the first one found.
        if isinstance(data, dict):
            for wrapper_key in ("model", "model_state_dict", "state_dict"):
                if wrapper_key in data:
                    data = data[wrapper_key]
                    break
        total_l, custom_l = count_layers(data, exclude)
        total_p, custom_p = count_parameters(data, exclude)
        print(f"{model_name:<10} | {total_l:<15} | {custom_l:<15} | {custom_p:<15,} | {filename}")
    except Exception as e:
        # Truncate long exception messages so the table stays readable.
        print(f"{model_name:<10} | {'ERROR':<15} | {'N/A':<15} | {'N/A':<15} | {filename} - {str(e)[:30]}...")
print(f"{'='*100}")
|