"""Report layer and parameter counts for a set of model checkpoints.

For each checkpoint listed in ``models_config`` this prints how many layers
(unique module prefixes) it contains, and how many of its layers/parameters
are "custom", i.e. not under one of the configured pretrained prefixes.
"""

import os
from collections.abc import Mapping

import torch

try:
    from safetensors.torch import load_file
except ImportError:  # optional: only required for .safetensors checkpoints
    load_file = None

# Keys under which training scripts commonly nest the real state dict.
_WRAPPER_KEYS = ("model", "model_state_dict", "state_dict")
# Prefix added to every key by DataParallel / DistributedDataParallel.
_DP_PREFIX = "module."
_RULE = "=" * 100


def _is_pretrained(key, exclude_prefixes):
    """Return True if *key* equals one of *exclude_prefixes* or lies under one.

    Matching is on whole dotted components: prefix ``enc`` matches ``enc`` and
    ``enc.x`` but not ``encoder.x``.
    """
    return any(key == p or key.startswith(p + ".") for p in exclude_prefixes)


def _tensor_items(state_dict):
    """Yield ``(name, tensor)`` pairs, skipping non-tensor entries (e.g. ``epoch``)."""
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            yield name, value


def count_layers(state_dict, exclude_prefixes=None):
    """Count unique layers in a state dict.

    A "layer" is a module prefix: everything before the last dot of a key, so
    ``enc.conv.weight`` and ``enc.conv.bias`` both belong to ``enc.conv``. A key
    without a dot is treated as its own layer. Non-tensor entries are ignored.

    Args:
        state_dict: Mapping of parameter/buffer names to tensors.
        exclude_prefixes: Prefixes considered pretrained. A key matches prefix
            ``p`` if it equals ``p`` or starts with ``p + "."``.

    Returns:
        ``(total_layers, custom_layers)``. ``custom_layers`` counts layers that
        have at least one key outside every excluded prefix.
    """
    exclude_prefixes = tuple(exclude_prefixes or ())
    total_layers = set()
    custom_layers = set()
    for key, _ in _tensor_items(state_dict):
        # rsplit keeps the whole key when there is no dot.
        module_name = key.rsplit(".", 1)[0]
        total_layers.add(module_name)
        if not _is_pretrained(key, exclude_prefixes):
            custom_layers.add(module_name)
    return len(total_layers), len(custom_layers)


def count_parameters(state_dict, exclude_prefixes=None):
    """Count tensor elements in a state dict.

    Every tensor entry is counted. If the checkpoint also stores buffers
    (e.g. normalization running statistics), those are included. Tied weights
    saved under several names are counted once per name. Non-tensor entries
    are ignored.

    Args:
        state_dict: Mapping of parameter/buffer names to tensors.
        exclude_prefixes: Prefixes considered pretrained (same matching rule as
            ``count_layers``).

    Returns:
        ``(total_params, custom_params)`` as element counts.
    """
    exclude_prefixes = tuple(exclude_prefixes or ())
    total_params = 0
    custom_params = 0
    for name, tensor in _tensor_items(state_dict):
        n = tensor.numel()
        total_params += n
        if not _is_pretrained(name, exclude_prefixes):
            custom_params += n
    return total_params, custom_params


def _unwrap_state_dict(data):
    """Extract the flat ``name -> tensor`` mapping from a loaded checkpoint.

    Handles:
    - a whole pickled ``nn.Module``;
    - a state dict nested under one of ``_WRAPPER_KEYS`` (only when the nested
      value is itself a mapping or module);
    - a DataParallel ``module.`` prefix on every key, which is stripped so
      exclude prefixes still match.

    Raises:
        TypeError: if no mapping can be extracted.
    """
    if isinstance(data, torch.nn.Module):
        data = data.state_dict()
    if isinstance(data, Mapping):
        for wrapper_key in _WRAPPER_KEYS:
            inner = data.get(wrapper_key)
            if isinstance(inner, torch.nn.Module):
                inner = inner.state_dict()
            if isinstance(inner, Mapping):
                data = inner
                break
    if not isinstance(data, Mapping):
        raise TypeError(f"unsupported checkpoint type: {type(data).__name__}")
    keys = list(data)
    if keys and all(isinstance(k, str) and k.startswith(_DP_PREFIX) for k in keys):
        data = {k[len(_DP_PREFIX):]: v for k, v in data.items()}
    return data


def _load_checkpoint(filename):
    """Load *filename* onto the CPU, choosing the loader by extension."""
    if filename.endswith(".safetensors"):
        if load_file is None:
            raise ImportError("safetensors is not installed")
        return load_file(filename, device="cpu")
    # SECURITY: weights_only=False unpickles arbitrary objects and can execute
    # code embedded in the file. Only run this on checkpoints you trust.
    return torch.load(filename, map_location="cpu", weights_only=False)


def _shorten(text, width):
    """Truncate *text* to *width* characters, adding '...' only if cut."""
    return text if len(text) <= width else text[:width] + "..."


# Define models and their pretrained prefixes
models_config = {
    "Bass": {"file": "bass_sota.pth", "exclude": ["audio_encoder"]},
    "Drums": {"file": "drums.safetensors", "exclude": ["wavlm"]},
    "Vocals": {"file": "vocals.pt", "exclude": []},
}


def main():
    """Print the layer/parameter table for every entry in ``models_config``."""
    print(_RULE)
    print(f"{'MODEL':<10} | {'TOTAL LAYERS':<15} | {'CUSTOM LAYERS':<15} | {'CUSTOM PARAMS':<15} | {'FILE'}")
    print(_RULE)
    for model_name, cfg in models_config.items():
        filename = cfg["file"]
        exclude = cfg["exclude"]
        if not os.path.exists(filename):
            print(f"{model_name:<10} | {'MISSING':<15} | {'N/A':<15} | {'N/A':<15} | {filename}")
            continue
        try:
            state_dict = _unwrap_state_dict(_load_checkpoint(filename))
            total_l, custom_l = count_layers(state_dict, exclude)
            _, custom_p = count_parameters(state_dict, exclude)  # total not shown in table
        except Exception as e:  # report-level boundary: log the failure, keep going
            reason = f"{type(e).__name__}: {_shorten(str(e), 30)}"
            print(f"{model_name:<10} | {'ERROR':<15} | {'N/A':<15} | {'N/A':<15} | {filename} - {reason}")
            continue
        print(f"{model_name:<10} | {total_l:<15} | {custom_l:<15} | {custom_p:<15,} | {filename}")
    print(_RULE)


if __name__ == "__main__":
    main()