Spaces:
Sleeping
Sleeping
"""
Backbone Loader — Provides a unified interface for loading SSL audio models.
Supports: Wav2Vec2, HuBERT, WavLM with configurable layer freezing.
"""
| import torch | |
| import torch.nn as nn | |
| from transformers import ( | |
| AutoModel, | |
| AutoFeatureExtractor, | |
| Wav2Vec2Model, | |
| HubertModel, | |
| WavLMModel, | |
| ) | |
| import logging | |
logger = logging.getLogger(__name__)

# Supported backbone families. Each entry maps the family key to:
#   "default" - the Hugging Face checkpoint ID used when the caller gives none
#   "class"   - the Hugging Face model class associated with the family
MODEL_REGISTRY = {
    family: {"default": checkpoint, "class": hf_class}
    for family, checkpoint, hf_class in (
        ("wav2vec2", "facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model),
        ("hubert", "facebook/hubert-large-ls960", HubertModel),
        ("wavlm", "microsoft/wavlm-large", WavLMModel),
    )
}
class BackboneLoader:
    """Loads and configures SSL audio backbones with layer freezing.

    The class holds no state; both methods are static, so they can be called
    either on the class (``BackboneLoader.load(...)``) or on an instance.
    """

    @staticmethod
    def load(model_type: str, model_name: str = None, freeze_layers: int = 0,
             device: str = "cpu") -> tuple:
        """
        Load a backbone model and its feature extractor.

        Args:
            model_type: One of the keys of MODEL_REGISTRY
                ("wav2vec2", "hubert", "wavlm").
            model_name: HuggingFace model ID. When None (or empty), the
                registry default for ``model_type`` is used.
            freeze_layers: Number of initial transformer layers to freeze.
                Freezing is skipped entirely when this is <= 0, in which
                case the CNN feature extractor also stays trainable.
            device: Target device the model is moved to before returning.

        Returns:
            (model, feature_extractor, hidden_size)

        Raises:
            ValueError: If ``model_type`` is not a key of MODEL_REGISTRY.
        """
        if model_type not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model type: {model_type}. "
                             f"Choose from {list(MODEL_REGISTRY)}")

        name = model_name or MODEL_REGISTRY[model_type]["default"]
        logger.info("Loading backbone: %s (type=%s)", name, model_type)

        # Both objects are loaded through the Auto* factories, keyed only by
        # the checkpoint name.
        model = AutoModel.from_pretrained(name)
        feature_extractor = AutoFeatureExtractor.from_pretrained(name)

        # The hidden size is read from the loaded model's config so callers
        # can size downstream heads without inspecting the model themselves.
        hidden_size = model.config.hidden_size
        logger.info("Hidden size: %s", hidden_size)

        # Freeze layers for efficient fine-tuning.
        if freeze_layers > 0:
            BackboneLoader._freeze_layers(model, freeze_layers, model_type)

        model = model.to(device)
        return model, feature_extractor, hidden_size

    @staticmethod
    def _freeze_layers(model, num_layers: int, model_type: str):
        """Freeze the feature extractor and first N transformer layers.

        Sub-modules are looked up by attribute name and silently skipped when
        the model does not have them.

        Args:
            model: Model whose parameters are modified in place.
            num_layers: Number of leading ``model.encoder.layers`` entries to
                freeze; clamped to the number of layers available.
            model_type: Backbone family key. Currently unused; kept so the
                signature stays stable for existing callers.
        """
        # Always freeze the CNN feature extractor (low-level features).
        if hasattr(model, "feature_extractor"):
            for param in model.feature_extractor.parameters():
                param.requires_grad = False
            logger.info(" Froze CNN feature extractor")

        if hasattr(model, "feature_projection"):
            for param in model.feature_projection.parameters():
                param.requires_grad = False
            logger.info(" Froze feature projection")

        # Freeze the first N encoder layers.
        if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
            layers = model.encoder.layers
            total_layers = len(layers)
            freeze_count = min(num_layers, total_layers)
            for index in range(freeze_count):
                for param in layers[index].parameters():
                    param.requires_grad = False
            logger.info(" Froze %s/%s transformer layers",
                        freeze_count, total_layers)

        # Count trainable vs. total parameters in a single pass.
        trainable = 0
        total = 0
        for param in model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count

        # Guard the ratio: a model without parameters must not raise
        # ZeroDivisionError from a purely informational log line.
        percent = trainable / total * 100 if total else 0.0
        logger.info(" Trainable params: %s / %s (%.1f%%)",
                    f"{trainable:,}", f"{total:,}", percent)