File size: 13,033 Bytes
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
4242909
0a95bc3
 
 
 
 
 
 
 
 
 
 
 
 
4242909
 
 
 
 
 
 
0a95bc3
 
 
 
 
 
 
 
4242909
 
 
 
 
 
 
0a95bc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
4242909
0a95bc3
 
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
4242909
0a95bc3
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
4242909
0a95bc3
 
4242909
 
 
 
 
 
 
 
 
 
 
0a95bc3
 
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
4242909
0a95bc3
 
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
0a95bc3
 
4242909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import glob
import importlib
import os

import torch
import torch.nn as nn
from dotenv import load_dotenv
from transformers import ASTModel, ASTConfig

# Load variables from a .env file into os.environ before the env lookups below.
load_dotenv()

# Target AST input size (time frames x mel bins). ASTEncoder._prep pads/crops the
# time axis to AST_TIME_DIM; AST_FREQ_DIM is not referenced by ASTEncoder._prep.
AST_TIME_DIM = 1024
AST_FREQ_DIM = 128
# Hugging Face repo id of the SSLAM/EAT checkpoint. Read at import time, so a
# missing SSLAM_MODEL environment variable raises KeyError on import.
SSLAM_HF_REPO = os.environ["SSLAM_MODEL"]
# Target SSLAM input size; SSLAMEncoder._prep pads/crops the time axis to
# SSLAM_TIME_DIM (SSLAM_FREQ_DIM is not referenced by _prep).
SSLAM_TIME_DIM = 1024
SSLAM_FREQ_DIM = 128
# Number of statistics produced by pair_summary_features; the detectors append
# them to the encoder features before the classification head.
PAIR_SUMMARY_DIM = 8


class ASTEncoder(nn.Module):
    """Wraps ASTModel and returns the [CLS] token embedding."""

    def __init__(self, model_name: str, freeze: bool = True):
        """Load a pretrained ASTModel.

        Args:
            model_name: Name or path passed to ``ASTModel.from_pretrained``.
            freeze: If True, set ``requires_grad = False`` on every AST parameter.
        """
        super().__init__()
        # ignore_mismatched_sizes tolerates checkpoint tensors whose shapes differ
        # from the instantiated model instead of raising.
        self.ast = ASTModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
        if freeze:
            for p in self.ast.parameters():
                p.requires_grad = False

    def unfreeze_last_n(self, n: int = 2):
        """Make the last ``n`` encoder layers and ``ast.layernorm`` trainable.

        Args:
            n: Number of trailing ``ast.encoder.layer`` entries to unfreeze.
                ``n <= 0`` unfreezes no encoder layers; ``ast.layernorm`` is
                unfrozen in every case.
        """
        # Guard n <= 0: the slice layer[-0:] would select *every* layer.
        if n > 0:
            for block in self.ast.encoder.layer[-n:]:
                for p in block.parameters():
                    p.requires_grad = True
        for p in self.ast.layernorm.parameters():
            p.requires_grad = True

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] => [B, AST_TIME_DIM, F].

        Drops the channel axis, then zero-pads (at the end) or truncates the time
        axis to exactly AST_TIME_DIM frames.
        """
        x = mel.squeeze(1)
        T = x.shape[1]
        if T < AST_TIME_DIM:
            pad = torch.zeros(x.shape[0], AST_TIME_DIM - T, x.shape[2], device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)
        elif T > AST_TIME_DIM:
            x = x[:, :AST_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the first-token ([CLS]) hidden state, shape [B, hidden_size]."""
        x = self._prep(mel)
        out = self.ast(input_values=x)
        return out.last_hidden_state[:, 0, :]


class PairMaskHead(nn.Module):
    """Beat-level similarity head comparing two mel spectrograms.

    Each input is adaptively pooled to ``beats_per_window * frames_per_beat`` time
    frames by ``n_mels`` bins and cut into ``beats_per_window`` equal chunks. Each
    chunk is embedded by a small CNN and L2-normalised. ``forward`` returns, for
    every (track beat, original beat) pair, the cosine similarity multiplied by
    ``exp(logit_scale)`` plus ``bias``.
    """

    def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64, frames_per_beat: int = 8):
        super().__init__()
        self.beats_per_window = beats_per_window
        self.frames_per_beat = frames_per_beat
        self.pool = nn.AdaptiveAvgPool2d((beats_per_window * frames_per_beat, n_mels))
        self.patch_encoder = self._build_patch_encoder(beat_dim)
        # Learnable temperature (used through exp) and additive offset for logits.
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.zeros(()))

    @staticmethod
    def _build_patch_encoder(beat_dim: int) -> nn.Sequential:
        """Per-beat CNN mapping [N, 1, frames, mels] patches to [N, beat_dim]."""
        conv_stages = [
            nn.Conv2d(1, 16, kernel_size=(3, 5), padding=(1, 2), bias=False),
            nn.GroupNorm(4, 16),
            nn.GELU(),
            # second conv is strided by 2 on both axes
            nn.Conv2d(16, 32, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2), bias=False),
            nn.GroupNorm(8, 32),
            nn.GELU(),
        ]
        projection = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, beat_dim),
            nn.GELU(),
            nn.Linear(beat_dim, beat_dim),
        ]
        return nn.Sequential(*conv_stages, *projection)

    def _beats(self, mel: torch.Tensor) -> torch.Tensor:
        """[B, 1, T, F] -> [B, beats_per_window, beat_dim], unit-norm on the last axis."""
        batch = mel.shape[0]
        pooled = self.pool(mel)  # [B, 1, beats * frames, n_mels]
        # With a single channel, each run of frames_per_beat consecutive rows is
        # one beat, so one reshape yields the per-beat patches in (batch, beat) order.
        patches = pooled.reshape(
            batch * self.beats_per_window, 1, self.frames_per_beat, pooled.shape[-1]
        )
        codes = self.patch_encoder(patches).reshape(batch, self.beats_per_window, -1)
        return torch.nn.functional.normalize(codes, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return pair logits of shape [B, beats_per_window, beats_per_window]."""
        track_beats = self._beats(track_mel)
        orig_beats = self._beats(orig_mel)
        cosine = torch.einsum("bih,bjh->bij", track_beats, orig_beats)
        return cosine * self.logit_scale.exp() + self.bias


def pair_summary_features(pair_logits: torch.Tensor) -> torch.Tensor:
    """Reduce each pair-logit matrix to 8 scalar statistics.

    Args:
        pair_logits: Logits indexed as [batch, i, j]; dims 1 and 2 are reduced.

    Returns:
        Tensor of shape [batch, 8]. Let p = sigmoid(pair_logits). The columns are:
        mean, max, population std (unbiased=False) and top-k mean
        (k = min(8, i * j)) of all entries of p; then the mean and max of the
        row-wise maxima; the mean of the column-wise maxima; and the mean of the
        main diagonal.
    """
    p = pair_logits.sigmoid()
    entries = torch.flatten(p, start_dim=1)
    k = min(8, entries.shape[1])
    best_per_row = torch.max(p, dim=2).values
    best_per_col = torch.max(p, dim=1).values
    diagonal = p.diagonal(dim1=1, dim2=2)

    stats = [
        entries.mean(dim=1),
        torch.max(entries, dim=1).values,
        entries.std(dim=1, unbiased=False),
        torch.topk(entries, k, dim=1).values.mean(dim=1),
        best_per_row.mean(dim=1),
        torch.max(best_per_row, dim=1).values,
        best_per_col.mean(dim=1),
        diagonal.mean(dim=1),
    ]
    return torch.stack(stats, dim=-1)


class SampleDetector(nn.Module):
    """Binary sample detector: one shared AST encoder applied to both inputs.

    The two [CLS] embeddings are fused as [t, o, |t - o|, t * o] and concatenated
    with beat-level pair statistics from a PairMaskHead. An MLP head maps the
    result to 2 logits.
    """

    def __init__(
        self,
        model_name: str = os.environ["AST_MODEL"],
        freeze_encoder: bool = True,
        dropout: float = 0.3,
        beats_per_window: int = 16,
        n_mels: int = 128,
    ):
        super().__init__()
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        hidden = self.encoder.ast.config.hidden_size
        # Four hidden-sized embedding views plus the pair summary statistics.
        fused_dim = 4 * hidden + PAIR_SUMMARY_DIM
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        self.head = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Delegate to ASTEncoder.unfreeze_last_n(n_blocks)."""
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, 2] class logits for the (track, original) pair."""
        track_emb = self.encoder(track_mel)
        orig_emb = self.encoder(orig_mel)
        pair_logits = self.pair_mask_head(track_mel, orig_mel)
        views = [
            track_emb,
            orig_emb,
            (track_emb - orig_emb).abs(),
            track_emb * orig_emb,
            pair_summary_features(pair_logits),
        ]
        return self.head(torch.cat(views, dim=-1))


class ConvBlock(nn.Module):
    """Two 3x3 Conv -> BatchNorm -> GELU stages; only the first conv uses ``stride``."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        downsample = [
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]
        refine = [
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]
        self.block = nn.Sequential(*downsample, *refine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages to ``x``."""
        y = self.block(x)
        return y


class CNNEncoder(nn.Module):
    """Four strided ConvBlocks, global average pooling, then a linear projection."""

    def __init__(self, embed_dim: int = 256):
        super().__init__()
        widths = (1, 32, 64, 128, 256)
        stages = [ConvBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        self.net = nn.Sequential(
            *stages,
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], embed_dim),
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Map a single-channel mel batch [B, 1, T, F] to [B, embed_dim]."""
        embedding = self.net(mel)
        return embedding


class CNNSampleDetector(nn.Module):
    """CNN counterpart of SampleDetector with the same (track_mel, orig_mel) -> 2-logit interface."""

    def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = CNNEncoder(embed_dim)
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # [t, o, |t - o|, t * o] embedding views plus pair summary statistics.
        fused_dim = 4 * embed_dim + PAIR_SUMMARY_DIM
        self.head = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, 2] class logits for the (track, original) pair."""
        track_emb = self.encoder(track_mel)
        orig_emb = self.encoder(orig_mel)
        pair_logits = self.pair_mask_head(track_mel, orig_mel)
        views = [
            track_emb,
            orig_emb,
            (track_emb - orig_emb).abs(),
            track_emb * orig_emb,
            pair_summary_features(pair_logits),
        ]
        return self.head(torch.cat(views, dim=-1))



class SSLAMEncoder(nn.Module):
    """Wraps the EAT (SSLAM) model and returns the CLS-like token embedding.

    Bypasses AutoModel.from_pretrained due to a transformers >= 5.5 incompatibility
    with EATModel's missing all_tied_weights_keys attribute. The constructor
    instead works in four steps:

    1. Load the config with ``trust_remote_code=True``.
    2. Import the remote ``eat_model`` module from ``transformers_modules``.
    3. Instantiate ``EAT`` directly.
    4. Load the repo's ``model.safetensors`` weights.
    """

    def __init__(self, freeze: bool = True):
        """Build the EAT model for ``SSLAM_HF_REPO`` and load its weights.

        Args:
            freeze: If True, set ``requires_grad = False`` on every EAT parameter.

        Raises:
            RuntimeError: If the config has no ``_commit_hash`` and no cached
                remote-code revision directory is found (see ``_find_sha``).
        """
        super().__init__()
        from transformers import AutoConfig
        import safetensors.torch
        from huggingface_hub import hf_hub_download

        cfg = AutoConfig.from_pretrained(SSLAM_HF_REPO, trust_remote_code=True)
        self.hidden_size = cfg.embed_dim
        sha = cfg._commit_hash or self._find_sha()
        # Derive the dotted package path from SSLAM_HF_REPO ("org/name" -> "org.name")
        # so it matches the repo used by _find_sha and hf_hub_download instead of
        # a hard-coded repo name.
        repo_pkg = SSLAM_HF_REPO.replace("/", ".")
        eat_mod = importlib.import_module(
            f"transformers_modules.{repo_pkg}.{sha}.eat_model"
        )
        self.eat = eat_mod.EAT(cfg)

        weights_path = hf_hub_download(SSLAM_HF_REPO, "model.safetensors")
        raw = safetensors.torch.load_file(weights_path)
        # Strip a leading "model." from checkpoint keys so they line up with the
        # bare EAT module for the strict load below.
        state = {k.removeprefix("model."): v for k, v in raw.items()}
        self.eat.load_state_dict(state, strict=True)
        if freeze:
            for p in self.eat.parameters():
                p.requires_grad = False

    @staticmethod
    def _find_sha() -> str:
        """Return a cached remote-code revision (commit hash) for SSLAM_HF_REPO.

        Scans ``~/.cache/huggingface/modules/transformers_modules/<repo>/`` and
        delegates the choice among subdirectory names to ``_select_sha``.

        Raises:
            RuntimeError: If no revision directory is found.
        """
        # NOTE(review): this path ignores HF_HOME / HF_MODULES_CACHE overrides.
        dirs = glob.glob(
            os.path.expanduser(
                f"~/.cache/huggingface/modules/transformers_modules/{SSLAM_HF_REPO}/*"
            )
        )
        names = [os.path.basename(d) for d in dirs if os.path.isdir(d)]
        return SSLAMEncoder._select_sha(names)

    @staticmethod
    def _select_sha(names: list[str]) -> str:
        """Pick a commit-hash directory name from ``names``.

        Only names made entirely of lowercase hex digits are candidates. This
        excludes e.g. a ``__pycache__`` directory that Python may write beside
        the revision folders when the parent package is imported; that name
        would otherwise sort after any hash starting with a digit. Among
        candidates, the lexicographically last is returned (arbitrary but
        deterministic).

        Raises:
            RuntimeError: If no candidate remains.
        """
        hex_chars = set("0123456789abcdef")
        shas = sorted(name for name in names if name and set(name) <= hex_chars)
        if not shas:
            raise RuntimeError("SSLAM modules not found in HF cache — run AutoConfig.from_pretrained first")
        return shas[-1]

    def unfreeze_last_n(self, n: int = 2):
        """Make the last ``n`` entries of ``eat.blocks`` and ``eat.pre_norm`` trainable.

        Args:
            n: Number of trailing blocks to unfreeze. ``n <= 0`` unfreezes no
                blocks; ``eat.pre_norm`` is unfrozen in every case.
        """
        # Guard n <= 0: the slice blocks[-0:] would select *every* block.
        if n > 0:
            for block in self.eat.blocks[-n:]:
                for p in block.parameters():
                    p.requires_grad = True

        for p in self.eat.pre_norm.parameters():
            p.requires_grad = True

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] => [B, 1, SSLAM_TIME_DIM, F] as float32.

        Zero-pads (at the end) or truncates the time axis to SSLAM_TIME_DIM frames.
        """
        x = mel.float()
        T = x.shape[2]
        if T < SSLAM_TIME_DIM:
            pad = torch.zeros(x.shape[0], 1, SSLAM_TIME_DIM - T, x.shape[3],
                              device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=2)
        elif T > SSLAM_TIME_DIM:
            x = x[:, :, :SSLAM_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the first token of ``eat.extract_features`` for each batch item."""
        x = self._prep(mel)
        feats = self.eat.extract_features(x)
        return feats[:, 0, :]



class SSLAMSampleDetector(nn.Module):
    """SampleDetector using SSLAM/EAT encoder instead of AST.

    Both inputs pass through one shared SSLAMEncoder. The embeddings are fused as
    [t, o, |t - o|, t * o] and concatenated with PairMaskHead summary statistics.
    An MLP head maps the result to 2 logits.
    """

    def __init__(self, freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = SSLAMEncoder(freeze=freeze_encoder)
        H = self.encoder.hidden_size
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        fused_dim = 4 * H + PAIR_SUMMARY_DIM
        self.head = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Delegate to SSLAMEncoder.unfreeze_last_n(n_blocks).

        The default of 2 matches SampleDetector.unfreeze_encoder and
        ContrastiveSampleDetector.unfreeze_encoder.
        """
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, 2] class logits for the (track, original) pair."""
        t = self.encoder(track_mel)
        o = self.encoder(orig_mel)
        pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
        combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1)
        return self.head(combined)


class ContrastiveSampleDetector(nn.Module):
    """Shared AST encoder plus an MLP projection into an L2-normalised embedding space.

    Intended for a cosine-embedding objective (e.g. CosineEmbeddingLoss): ``forward``
    returns both embeddings, and ``similarity`` returns their dot product, which
    equals cosine similarity because both vectors are unit-norm.
    """

    def __init__(
        self,
        model_name: str = os.environ["AST_MODEL"],
        freeze_encoder: bool = True,
        proj_dim: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        hidden = self.encoder.ast.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(hidden, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, proj_dim),
        )

    def embed(self, mel: torch.Tensor) -> torch.Tensor:
        """Encode ``mel`` and return its unit-norm projection, shape [B, proj_dim]."""
        projected = self.proj(self.encoder(mel))
        return torch.nn.functional.normalize(projected, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> tuple:
        """Return ``(track_embedding, orig_embedding)``."""
        track_z = self.embed(track_mel)
        orig_z = self.embed(orig_mel)
        return track_z, orig_z

    def similarity(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return the per-item dot product of the two unit-norm embeddings, shape [B]."""
        track_z = self.embed(track_mel)
        orig_z = self.embed(orig_mel)
        return torch.sum(track_z * orig_z, dim=-1)

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Delegate to ASTEncoder.unfreeze_last_n(n_blocks)."""
        self.encoder.unfreeze_last_n(n_blocks)