File size: 3,699 Bytes
c5c9261
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Backbone Loader — provides a unified interface for loading self-supervised (SSL) audio models.
Supports Wav2Vec2, HuBERT, and WavLM backbones, with configurable layer freezing.
"""
import logging
from typing import Optional

import torch
import torch.nn as nn
from transformers import (
    AutoModel,
    AutoFeatureExtractor,
    Wav2Vec2Model,
    HubertModel,
    WavLMModel,
)

logger = logging.getLogger(__name__)

# Supported backbone families. Each key maps to a dict holding the default
# HuggingFace checkpoint ("default") and the transformers model class
# associated with that family ("class"). Insertion order follows the table.
MODEL_REGISTRY = {
    family: {"default": checkpoint, "class": model_cls}
    for family, checkpoint, model_cls in (
        ("wav2vec2", "facebook/wav2vec2-large-xlsr-53", Wav2Vec2Model),
        ("hubert", "facebook/hubert-large-ls960", HubertModel),
        ("wavlm", "microsoft/wavlm-large", WavLMModel),
    )
}


class BackboneLoader:
    """Load SSL audio backbones from HuggingFace and optionally freeze early layers.

    Only static methods are defined; the class acts as a namespace.
    """

    @staticmethod
    def load(model_type: str, model_name: Optional[str] = None,
             freeze_layers: int = 0, device: str = "cpu") -> tuple:
        """
        Load a backbone model and its feature extractor.

        Args:
            model_type: One of "wav2vec2", "hubert", "wavlm" (the keys of
                MODEL_REGISTRY).
            model_name: HuggingFace model ID. If None (or empty), the registry
                default checkpoint for ``model_type`` is used.
            freeze_layers: Number of initial transformer layers to freeze.
                Any value > 0 also freezes the CNN feature extractor and the
                feature projection; values <= 0 skip freezing entirely.
            device: Target device, passed to ``model.to``.

        Returns:
            (model, feature_extractor, hidden_size), where hidden_size is
            ``model.config.hidden_size``.

        Raises:
            ValueError: If ``model_type`` is not a MODEL_REGISTRY key.
        """
        if model_type not in MODEL_REGISTRY:
            raise ValueError(f"Unknown model type: {model_type}. "
                             f"Choose from {list(MODEL_REGISTRY.keys())}")

        reg = MODEL_REGISTRY[model_type]
        name = model_name or reg["default"]

        logger.info(f"Loading backbone: {name} (type={model_type})")

        # Load model and feature extractor. AutoModel resolves the concrete
        # class from the checkpoint's own config, not from model_type.
        model = AutoModel.from_pretrained(name)
        feature_extractor = AutoFeatureExtractor.from_pretrained(name)

        # A model_name from a different family than model_type would otherwise
        # load silently; compare against the registry's expected class.
        expected_cls = reg["class"]
        if not isinstance(model, expected_cls):
            logger.warning(
                f"Checkpoint {name} loaded as {type(model).__name__}, "
                f"expected {expected_cls.__name__} for model_type={model_type!r}"
            )

        # Get hidden size from config
        hidden_size = model.config.hidden_size
        logger.info(f"Hidden size: {hidden_size}")

        # Freeze layers for efficient fine-tuning
        if freeze_layers > 0:
            BackboneLoader._freeze_layers(model, freeze_layers, model_type)

        model = model.to(device)
        return model, feature_extractor, hidden_size

    @staticmethod
    def _freeze_module(module: nn.Module) -> None:
        """Set ``requires_grad = False`` on every parameter of *module*."""
        for param in module.parameters():
            param.requires_grad = False

    @staticmethod
    def _freeze_layers(model, num_layers: int, model_type: str):
        """Freeze the CNN front end, feature projection and first N encoder layers.

        Args:
            model: Loaded backbone. Each submodule (``feature_extractor``,
                ``feature_projection``, ``encoder.layers``) is frozen only if
                the attribute exists on the model.
            num_layers: Number of leading encoder layers to freeze; clamped to
                the number of layers present.
            model_type: Registry key of the backbone (not used here).
        """
        # Unconditionally freeze the CNN feature extractor (low-level features).
        if hasattr(model, "feature_extractor"):
            BackboneLoader._freeze_module(model.feature_extractor)
            logger.info("  Froze CNN feature extractor")

        if hasattr(model, "feature_projection"):
            BackboneLoader._freeze_module(model.feature_projection)
            logger.info("  Froze feature projection")

        # Freeze the first N encoder layers
        if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
            total_layers = len(model.encoder.layers)
            freeze_count = min(num_layers, total_layers)
            for layer in model.encoder.layers[:freeze_count]:
                BackboneLoader._freeze_module(layer)
            logger.info(f"  Froze {freeze_count}/{total_layers} transformer layers")

        # Count trainable params
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        # Guard against a parameter-less model, which would divide by zero.
        pct = trainable / total * 100 if total else 0.0
        logger.info(f"  Trainable params: {trainable:,} / {total:,} "
                    f"({pct:.1f}%)")