# muse-archive/scripts/analyze_layers.py
# Uploaded by lamooon with huggingface_hub ("Upload scripts/analyze_layers.py with huggingface_hub"), commit ff74bc4 (verified).
import torch
import os
from safetensors.torch import load_file
def count_layers(state_dict, exclude_prefixes=None):
    """Count distinct module names in a state dict, overall and outside excluded prefixes.

    A key's module name is everything before its final ``.``; a key without a
    dot is its own module name. A key counts as pretrained when it equals one
    of ``exclude_prefixes`` or starts with ``prefix + '.'``.

    Args:
        state_dict: Mapping whose keys are dotted parameter names (values are
            not inspected).
        exclude_prefixes: Iterable of module prefixes to treat as pretrained;
            ``None`` means nothing is excluded.

    Returns:
        Tuple ``(total_layers, custom_layers)``: the number of distinct module
        names, and the number of distinct module names that own at least one
        key not under an excluded prefix.
    """
    prefixes = [] if exclude_prefixes is None else exclude_prefixes

    def _is_pretrained(name):
        # Match on a dotted boundary so "enc" does not swallow "encoder.*".
        return any(name.startswith(p + '.') or name == p for p in prefixes)

    all_modules = {key.rsplit('.', 1)[0] for key in state_dict.keys()}
    custom_modules = {
        key.rsplit('.', 1)[0]
        for key in state_dict.keys()
        if not _is_pretrained(key)
    }
    return len(all_modules), len(custom_modules)
def count_parameters(state_dict, exclude_prefixes=None):
    """Sum element counts over a state dict, overall and outside excluded prefixes.

    A key counts as pretrained when it equals one of ``exclude_prefixes`` or
    starts with ``prefix + '.'``.

    Args:
        state_dict: Mapping from parameter name to an object exposing
            ``numel()`` (a tensor, in this script's usage).
        exclude_prefixes: Iterable of module prefixes to treat as pretrained;
            ``None`` means nothing is excluded.

    Returns:
        Tuple ``(total_params, custom_params)``: the total of ``numel()`` over
        every entry, and the total over entries not under an excluded prefix.
    """
    prefixes = exclude_prefixes if exclude_prefixes is not None else []
    total = custom = 0
    for key, tensor in state_dict.items():
        n_elems = tensor.numel()
        total += n_elems
        if any(key.startswith(p + '.') or key == p for p in prefixes):
            continue  # pretrained weight: counted in total only
        custom += n_elems
    return total, custom
# Models to analyse: checkpoint path (resolved relative to the current working
# directory) and the state-dict prefixes whose parameters belong to a
# pretrained component and are therefore excluded from the "custom" counts.
models_config = {
    "Bass": {
        "file": "bass_sota.pth",
        "exclude": ["audio_encoder"],
    },
    "Drums": {
        "file": "drums.safetensors",
        "exclude": ["wavlm"],
    },
    "Vocals": {
        "file": "vocals.pt",
        "exclude": [],
    },
}

# Keys checked, in priority order, when a checkpoint dict stores the weights
# under a sub-key instead of being the state dict itself.
_WRAPPER_KEYS = ("model", "model_state_dict", "state_dict")


def extract_state_dict(checkpoint):
    """Return the tensor entries of a loaded checkpoint as a name -> tensor dict.

    Accepts a raw state dict, a dict wrapping one under ``"model"``,
    ``"model_state_dict"`` or ``"state_dict"`` (first match wins), or a
    pickled ``torch.nn.Module`` (directly or under one of those keys).
    Non-tensor entries are dropped so they are neither counted as layers nor
    passed to ``numel()``.

    Raises:
        TypeError: if the unwrapped checkpoint is neither a dict nor a module.
        ValueError: if no tensor entries remain after unwrapping.
    """
    if isinstance(checkpoint, dict):
        for key in _WRAPPER_KEYS:
            if key in checkpoint:
                checkpoint = checkpoint[key]
                break
    # weights_only=False loads can yield whole module objects, not just dicts.
    if isinstance(checkpoint, torch.nn.Module):
        checkpoint = checkpoint.state_dict()
    if not isinstance(checkpoint, dict):
        raise TypeError(f"unsupported checkpoint type: {type(checkpoint).__name__}")
    tensors = {
        name: value
        for name, value in checkpoint.items()
        if isinstance(value, torch.Tensor)
    }
    if not tensors:
        raise ValueError("no tensors found in checkpoint")
    return tensors


def _shorten(text, limit=60):
    """Truncate ``text`` to ``limit`` characters, appending '...' only when cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _format_row(name, total_layers, custom_layers, custom_params, file_cell):
    """Format one table row; shared by the header and all data rows."""
    return f"{name:<10} | {total_layers:<15} | {custom_layers:<15} | {custom_params:<15} | {file_cell}"


print("=" * 100)
print(_format_row("MODEL", "TOTAL LAYERS", "CUSTOM LAYERS", "CUSTOM PARAMS", "FILE"))
print("=" * 100)

for model_name, cfg in models_config.items():
    filename = cfg["file"]
    exclude = cfg["exclude"]

    if not os.path.exists(filename):
        print(_format_row(model_name, "MISSING", "N/A", "N/A", filename))
        continue

    try:
        if filename.endswith(".safetensors"):
            raw = load_file(filename, device="cpu")
        else:
            # SECURITY: weights_only=False lets torch.load unpickle arbitrary
            # objects (and thus execute arbitrary code). Only point this at
            # checkpoints you trust.
            raw = torch.load(filename, map_location="cpu", weights_only=False)
        data = extract_state_dict(raw)
        total_l, custom_l = count_layers(data, exclude)
        _, custom_p = count_parameters(data, exclude)
    except Exception as e:  # best-effort report: one bad file must not abort the table
        reason = _shorten(f"{type(e).__name__}: {e}")
        print(_format_row(model_name, "ERROR", "N/A", "N/A", f"{filename} - {reason}"))
        continue

    print(_format_row(model_name, total_l, custom_l, f"{custom_p:,}", filename))

print("=" * 100)