vineetshukla.work@gmail.com
final commit
c5c9261
"""
Backbone Loader — Provides a unified interface for loading SSL audio models.
Supports: Wav2Vec2, HuBERT, WavLM with configurable layer freezing.
"""
import torch
import torch.nn as nn
from transformers import (
AutoModel,
AutoFeatureExtractor,
Wav2Vec2Model,
HubertModel,
WavLMModel,
)
import logging
logger = logging.getLogger(__name__)

# Supported backbone families. Each row below is
# (registry key, default HuggingFace checkpoint, HuggingFace model class);
# the comprehension expands every row into the
# {"default": ..., "class": ...} entry shape that BackboneLoader reads.
MODEL_REGISTRY = {
    family: {"default": checkpoint, "class": hf_class}
    for family, checkpoint, hf_class in (
        ("wav2vec2", "facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model),
        ("hubert", "facebook/hubert-large-ls960", HubertModel),
        ("wavlm", "microsoft/wavlm-large", WavLMModel),
    )
}
class BackboneLoader:
    """Loads and configures SSL audio backbones with layer freezing.

    Both methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def load(model_type: str, model_name: str = None, freeze_layers: int = 0,
             device: str = "cpu") -> tuple:
        """
        Load a backbone model and its feature extractor.

        Args:
            model_type: One of the MODEL_REGISTRY keys
                ("wav2vec2", "hubert", "wavlm").
            model_name: HuggingFace model ID. The registry default for
                ``model_type`` is used when this is None or empty.
            freeze_layers: Number of initial transformer layers to freeze.
                Any value > 0 also freezes the CNN feature extractor and the
                feature projection; values <= 0 leave every parameter
                trainable.
            device: Target device the model is moved to before returning.

        Returns:
            (model, feature_extractor, hidden_size)

        Raises:
            ValueError: If ``model_type`` is not a MODEL_REGISTRY key.
        """
        if model_type not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model type: {model_type}. "
                             f"Choose from {list(MODEL_REGISTRY)}")

        reg = MODEL_REGISTRY[model_type]
        name = model_name or reg["default"]
        logger.info(f"Loading backbone: {name} (type={model_type})")

        # AutoModel resolves the architecture from the checkpoint's own
        # config, so a custom model_name always loads with matching weights.
        model = AutoModel.from_pretrained(name)
        feature_extractor = AutoFeatureExtractor.from_pretrained(name)

        # The registry records which class each model_type should produce.
        # A mismatch usually means model_name belongs to another family, so
        # surface it instead of continuing silently.
        expected_cls = reg["class"]
        if not isinstance(model, expected_cls):
            logger.warning(
                f"Loaded {type(model).__name__} for model_type="
                f"{model_type!r}, expected {expected_cls.__name__}; "
                f"check that model_name matches model_type"
            )

        hidden_size = model.config.hidden_size
        logger.info(f"Hidden size: {hidden_size}")

        # Freeze layers for efficient fine-tuning
        if freeze_layers > 0:
            BackboneLoader._freeze_layers(model, freeze_layers, model_type)

        model = model.to(device)
        return model, feature_extractor, hidden_size

    @staticmethod
    def _freeze_layers(model, num_layers: int, model_type: str):
        """
        Freeze the feature extractor and first N transformer layers.

        Sub-modules are looked up by attribute name (``feature_extractor``,
        ``feature_projection``, ``encoder.layers``); any that the model does
        not have are skipped.

        Args:
            model: The backbone whose parameters are frozen in place.
            num_layers: Number of leading encoder layers to freeze. Clamped
                to the range [0, number of encoder layers].
            model_type: Registry key of the backbone. Currently unused; kept
                so the call signature stays stable.
        """
        # Always freeze the CNN front-end and its projection (low-level
        # features), independent of how many encoder layers are requested.
        for attr, label in (
            ("feature_extractor", "CNN feature extractor"),
            ("feature_projection", "feature projection"),
        ):
            submodule = getattr(model, attr, None)
            if submodule is not None:
                submodule.requires_grad_(False)
                logger.info(f" Froze {label}")

        # Freeze the first N encoder layers
        layers = getattr(getattr(model, "encoder", None), "layers", None)
        if layers is not None:
            total_layers = len(layers)
            if num_layers > total_layers:
                logger.warning(
                    f" Requested {num_layers} frozen layers but the encoder "
                    f"has only {total_layers}; freezing all of them"
                )
            # Clamp on both sides so a negative request freezes nothing and
            # the log line never reports a negative count.
            freeze_count = max(0, min(num_layers, total_layers))
            for idx in range(freeze_count):
                layers[idx].requires_grad_(False)
            logger.info(
                f" Froze {freeze_count}/{total_layers} transformer layers")

        # Count trainable vs. total parameters in a single pass
        trainable = 0
        total = 0
        for param in model.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count

        # A model without parameters would otherwise divide by zero.
        percent = trainable / total * 100 if total else 0.0
        logger.info(f" Trainable params: {trainable:,} / {total:,} "
                    f"({percent:.1f}%)")