File size: 8,107 Bytes
75d43d2
fc7b4a9
75d43d2
fc7b4a9
 
75d43d2
fc7b4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75d43d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import torch
import torch.nn as nn
from pathlib import Path
from .transformer import Transformer
from .tokenizer import STTokenizer
from src.spectttra.feature import FeatureExtractor


class SpecTTTra(nn.Module):
    """
    Spectro-Temporal Transformer (SpecTTTra) for audio representation learning.

    The input spectrogram is split into temporal and spectral tokens by an
    STTokenizer; the combined token sequence is then encoded with a Transformer
    to capture spectro-temporal dependencies.
    """

    def __init__(
        self,
        input_spec_dim,
        input_temp_dim,
        embed_dim,
        t_clip,
        f_clip,
        num_heads,
        num_layers,
        pre_norm=False,
        pe_learnable=False,
        pos_drop_rate=0.0,
        attn_drop_rate=0.0,
        proj_drop_rate=0.0,
        mlp_ratio=4.0,
    ):
        """
        Build the tokenizer, positional dropout, and Transformer encoder.

        Args:
            input_spec_dim (int): Spectrogram frequency dimension (F).
            input_temp_dim (int): Spectrogram temporal dimension (T).
            embed_dim (int): Token embedding dimension.
            t_clip (int): Temporal clip size used by the tokenizer.
            f_clip (int): Spectral clip size used by the tokenizer.
            num_heads (int): Attention heads per transformer layer.
            num_layers (int): Number of transformer layers.
            pre_norm (bool, optional): Apply pre-normalization after
                tokenization, before the transformer (as in CLIP). Defaults to False.
            pe_learnable (bool, optional): Use learned positional embeddings
                instead of fixed ones. Defaults to False.
            pos_drop_rate (float, optional): Dropout on the token sequence
                after tokenization. Defaults to 0.0.
            attn_drop_rate (float, optional): Attention dropout. Defaults to 0.0.
            proj_drop_rate (float, optional): Projection dropout. Defaults to 0.0.
            mlp_ratio (float, optional): MLP hidden-dim expansion ratio. Defaults to 4.0.
        """
        super().__init__()

        # Keep the full configuration on the instance for introspection.
        self.input_spec_dim = input_spec_dim
        self.input_temp_dim = input_temp_dim
        self.embed_dim = embed_dim
        self.t_clip = t_clip
        self.f_clip = f_clip
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pre_norm = pre_norm  # applied after tokenization, before the transformer (used in CLIP)
        self.pe_learnable = pe_learnable  # learned positional encoding
        self.pos_drop_rate = pos_drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.proj_drop_rate = proj_drop_rate
        self.mlp_ratio = mlp_ratio

        # Turns the (F, T) spectrogram into a sequence of spectro-temporal tokens.
        self.st_tokenizer = STTokenizer(
            input_spec_dim,
            input_temp_dim,
            t_clip,
            f_clip,
            embed_dim,
            pre_norm=pre_norm,
            pe_learnable=pe_learnable,
        )

        # Dropout applied to the token sequence right after tokenization.
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Transformer encoder over the token sequence.
        self.transformer = Transformer(
            embed_dim,
            num_heads,
            num_layers,
            attn_drop=attn_drop_rate,
            proj_drop=proj_drop_rate,
            mlp_ratio=mlp_ratio,
        )

    def forward(self, x):
        """
        Encode a spectrogram into spectro-temporal tokens.

        Args:
            x (torch.Tensor): Spectrogram of shape (B, 1, F, T) or (B, F, T).

        Returns:
            torch.Tensor: Encoded tokens of shape (B, T/t + F/f, embed_dim).
        """
        # Drop the singleton channel axis for (B, 1, F, T) inputs.
        if x.dim() == 4:
            x = x.squeeze(1)

        tokens = self.st_tokenizer(x)      # spectro-temporal tokenization
        tokens = self.pos_drop(tokens)     # positional dropout
        return self.transformer(tokens)    # (B, T/t + F/f, embed_dim)


def build_spectttra_from_cfg(cfg, device):
    """
    Construct the SpecTTTra model and its FeatureExtractor from a configuration.

    Args:
        cfg (SimpleNamespace): Configuration object. Expected attributes:
            - cfg.melspec.n_mels: Number of mel frequency bins (128 for the
              published checkpoint).
            - cfg.model: Model hyperparameters (embed_dim, t_clip, f_clip,
              num_heads, num_layers, pre_norm, pe_learnable, pos_drop_rate,
              attn_drop_rate, proj_drop_rate, mlp_ratio). May optionally
              define n_frames to override the default temporal dimension.
        device (torch.device): Device for the model and feature extractor
            (e.g. 'cpu' or 'cuda').

    Returns:
        tuple:
            FeatureExtractor: Feature extraction module moved to `device`.
            SpecTTTra: Constructed model moved to `device`.
    """
    feat_ext = FeatureExtractor(cfg).to(device)

    # The pre-trained checkpoint expects specific, fixed input dimensions, so
    # the architecture must match the checkpoint weights exactly. 3744 frames
    # is the value the checkpoint was trained with (taken from the
    # RuntimeError raised on mismatch). Configs may override it via
    # cfg.model.n_frames; existing configs without that attribute keep the
    # old hard-coded behavior.
    n_mels = cfg.melspec.n_mels
    n_frames = getattr(cfg.model, "n_frames", 3744)

    print(f"[INFO] Initializing SpecTTTra with fixed dimensions: n_mels={n_mels}, n_frames={n_frames}")

    model_cfg = cfg.model
    model = SpecTTTra(
        input_spec_dim=n_mels,
        input_temp_dim=n_frames,
        embed_dim=model_cfg.embed_dim,
        t_clip=model_cfg.t_clip,
        f_clip=model_cfg.f_clip,
        num_heads=model_cfg.num_heads,
        num_layers=model_cfg.num_layers,
        pre_norm=model_cfg.pre_norm,
        pe_learnable=model_cfg.pe_learnable,
        pos_drop_rate=model_cfg.pos_drop_rate,
        attn_drop_rate=model_cfg.attn_drop_rate,
        proj_drop_rate=model_cfg.proj_drop_rate,
        mlp_ratio=model_cfg.mlp_ratio,
    ).to(device)

    return feat_ext, model


def load_frozen_spectttra(model, ckpt_path, device):
    """
    Load pretrained SpecTTTra weights from a frozen checkpoint file.

    Keys prefixed with "encoder." in the checkpoint are remapped onto the
    bare model; all other keys pass through unchanged. Loading is non-strict,
    so extra checkpoint keys (e.g. classifier heads) are reported but ignored.

    Args:
        model (torch.nn.Module): Initialized SpecTTTra instance to load into.
        ckpt_path (str or Path): Path to the checkpoint (e.g. 'spectttra_frozen.pth').
        device (torch.device): Device to map the loaded weights to.

    Returns:
        torch.nn.Module: `model` with the pretrained weights, in eval mode.

    Raises:
        FileNotFoundError: If no file exists at `ckpt_path`.
    """
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        raise FileNotFoundError(
            f"Pre-trained model not found at {ckpt_path}. "
            "Please download 'pytorch_model.bin', rename to 'spectttra_frozen.pth', "
            "and place it in the correct directory."
        )

    print(f"[INFO] Found SpecTTTra checkpoint at {ckpt_path}. Loading weights...")
    state = torch.load(ckpt_path, map_location=device)

    # Strip the "encoder." prefix the checkpoint uses for the backbone weights.
    prefix = "encoder."
    remapped = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }

    # Shapes should now line up, so this loads without size-mismatch errors.
    missing_keys, unexpected_keys = model.load_state_dict(remapped, strict=False)
    if missing_keys:
        # A few missing keys are possible if the local SpecTTTra class differs
        # slightly; the core weights should still load.
        print(f"[WARNING] Missing keys in model: {missing_keys}")
    if unexpected_keys:
        # 'classifier' or 'ft_extractor' keys here are NORMAL and SAFE to skip.
        print(f"[INFO] Unused keys in checkpoint: {unexpected_keys}")

    print("[INFO] Successfully loaded pre-trained SpecTTTra weights.")

    model.eval()
    return model