import glob import importlib import os import torch import torch.nn as nn from dotenv import load_dotenv from transformers import ASTModel, ASTConfig load_dotenv() AST_TIME_DIM = 1024 AST_FREQ_DIM = 128 SSLAM_HF_REPO = os.environ["SSLAM_MODEL"] SSLAM_TIME_DIM = 1024 SSLAM_FREQ_DIM = 128 PAIR_SUMMARY_DIM = 8 class ASTEncoder(nn.Module): """Wraps ASTModel and returns the [CLS] token embedding.""" def __init__(self, model_name: str, freeze: bool = True): super().__init__() self.ast = ASTModel.from_pretrained(model_name, ignore_mismatched_sizes=True) # print(f"AST hidden size: {self.ast.config.hidden_size}") if freeze: for p in self.ast.parameters(): p.requires_grad = False def unfreeze_last_n(self, n: int = 2): for block in self.ast.encoder.layer[-n:]: for p in block.parameters(): p.requires_grad = True for p in self.ast.layernorm.parameters(): p.requires_grad = True # trainable = sum(p.numel() for p in self.ast.parameters() if p.requires_grad) # print(f"unfroze {n} blocks, trainable params: {trainable:,}") @staticmethod def _prep(mel: torch.Tensor) -> torch.Tensor: """mel: [B, 1, T, F] => [B, AST_TIME_DIM, AST_FREQ_DIM]""" x = mel.squeeze(1) T = x.shape[1] # print(f"input T={T}, target={AST_TIME_DIM}") if T < AST_TIME_DIM: pad = torch.zeros(x.shape[0], AST_TIME_DIM - T, x.shape[2], device=x.device, dtype=x.dtype) x = torch.cat([x, pad], dim=1) elif T > AST_TIME_DIM: x = x[:, :AST_TIME_DIM, :] return x def forward(self, mel: torch.Tensor) -> torch.Tensor: x = self._prep(mel) out = self.ast(input_values=x) # print(f"AST output shape: {out.last_hidden_state.shape}") return out.last_hidden_state[:, 0, :] class PairMaskHead(nn.Module): """Beat-by-beat pair matching head over two mel spectrograms.""" def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64, frames_per_beat: int = 8): super().__init__() self.beats_per_window = beats_per_window self.frames_per_beat = frames_per_beat self.pool = nn.AdaptiveAvgPool2d((beats_per_window * frames_per_beat, n_mels)) self.patch_encoder = nn.Sequential( nn.Conv2d(1, 16, kernel_size=(3, 5), padding=(1, 2), bias=False), nn.GroupNorm(4, 16), nn.GELU(), nn.Conv2d(16, 32, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2), bias=False), nn.GroupNorm(8, 32), nn.GELU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, beat_dim), nn.GELU(), nn.Linear(beat_dim, beat_dim), ) self.logit_scale = nn.Parameter(torch.tensor(1.0)) self.bias = nn.Parameter(torch.zeros(())) def _beats(self, mel: torch.Tensor) -> torch.Tensor: # mel: [B, 1, T, F] -> [B * beats, 1, frames_per_beat, F] bsz = mel.shape[0] x = self.pool(mel) x = x.view(bsz, 1, self.beats_per_window, self.frames_per_beat, x.shape[-1]) x = x.permute(0, 2, 1, 3, 4).contiguous() x = x.view(bsz * self.beats_per_window, 1, self.frames_per_beat, x.shape[-1]) x = self.patch_encoder(x).view(bsz, self.beats_per_window, -1) return torch.nn.functional.normalize(x, dim=-1) def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor: t = self._beats(track_mel) o = self._beats(orig_mel) return torch.einsum("bih,bjh->bij", t, o) * self.logit_scale.exp() + self.bias def pair_summary_features(pair_logits: torch.Tensor) -> torch.Tensor: probs = torch.sigmoid(pair_logits) flat = probs.flatten(1) row_max = probs.max(dim=2).values col_max = probs.max(dim=1).values diag = torch.diagonal(probs, dim1=1, dim2=2) top_k = min(8, flat.shape[1]) topk_mean = flat.topk(top_k, dim=1).values.mean(dim=1) return torch.stack( [ flat.mean(dim=1), flat.max(dim=1).values, flat.std(dim=1, unbiased=False), topk_mean, row_max.mean(dim=1), row_max.max(dim=1).values, col_max.mean(dim=1), diag.mean(dim=1), ], dim=-1, ) class SampleDetector(nn.Module): """Siamese AST encoder + interaction head for binary sample detection.""" def __init__( self, model_name: str = os.environ["AST_MODEL"], freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128, ): super().__init__() self.encoder = ASTEncoder(model_name, freeze=freeze_encoder) H = self.encoder.ast.config.hidden_size self.pair_mask_head = PairMaskHead(beats_per_window, n_mels) self.head = nn.Sequential( nn.LayerNorm(4 * H + PAIR_SUMMARY_DIM), nn.Linear(4 * H + PAIR_SUMMARY_DIM, 512), nn.GELU(), nn.Dropout(dropout), nn.Linear(512, 128), nn.GELU(), nn.Dropout(dropout), nn.Linear(128, 2), ) def unfreeze_encoder(self, n_blocks: int = 2): self.encoder.unfreeze_last_n(n_blocks) def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor: t = self.encoder(track_mel) o = self.encoder(orig_mel) pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel)) # print(f"embeddings: t={t.shape}, o={o.shape}") combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1) # print(f"combined shape: {combined.shape}") return self.head(combined) class ConvBlock(nn.Module): def __init__(self, in_ch: int, out_ch: int, stride: int = 2): super().__init__() self.block = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.GELU(), nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.GELU(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.block(x) class CNNEncoder(nn.Module): def __init__(self, embed_dim: int = 256): super().__init__() self.net = nn.Sequential( ConvBlock(1, 32), ConvBlock(32, 64), ConvBlock(64, 128), ConvBlock(128, 256), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(256, embed_dim), ) def forward(self, mel: torch.Tensor) -> torch.Tensor: return self.net(mel) class CNNSampleDetector(nn.Module): """Drop-in CNN alternative to SampleDetector.""" def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128): super().__init__() self.encoder = CNNEncoder(embed_dim) self.pair_mask_head = PairMaskHead(beats_per_window, n_mels) self.head = nn.Sequential( nn.LayerNorm(4 * embed_dim + PAIR_SUMMARY_DIM), nn.Linear(4 * embed_dim + PAIR_SUMMARY_DIM, 256), nn.GELU(), nn.Dropout(dropout), nn.Linear(256, 64), nn.GELU(), nn.Dropout(dropout), nn.Linear(64, 2), ) def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor: t = self.encoder(track_mel) o = self.encoder(orig_mel) pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel)) combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1) return self.head(combined) class SSLAMEncoder(nn.Module): """Wraps the EAT (SSLAM) model and returns the CLS-like token embedding. Bypasses AutoModel.from_pretrained due to a transformers >= 5.5 incompatibility with EATModel's missing all_tied_weights_keys attribute. """ def __init__(self, freeze: bool = True): super().__init__() from transformers import AutoConfig import safetensors.torch from huggingface_hub import hf_hub_download cfg = AutoConfig.from_pretrained(SSLAM_HF_REPO, trust_remote_code=True) self.hidden_size = cfg.embed_dim sha = cfg._commit_hash or self._find_sha() eat_mod = importlib.import_module( f"transformers_modules.ta012.SSLAM_pretrain.{sha}.eat_model" ) self.eat = eat_mod.EAT(cfg) weights_path = hf_hub_download(SSLAM_HF_REPO, "model.safetensors") raw = safetensors.torch.load_file(weights_path) state = {k.removeprefix("model."): v for k, v in raw.items()} self.eat.load_state_dict(state, strict=True) if freeze: for p in self.eat.parameters(): p.requires_grad = False @staticmethod def _find_sha() -> str: dirs = glob.glob( os.path.expanduser( f"~/.cache/huggingface/modules/transformers_modules/{SSLAM_HF_REPO}/*" ) ) dirs = [d for d in dirs if os.path.isdir(d)] if not dirs: raise RuntimeError("SSLAM modules not found in HF cache — run AutoConfig.from_pretrained first") return os.path.basename(sorted(dirs)[-1]) def unfreeze_last_n(self, n: int): for block in self.eat.blocks[-n:]: for p in block.parameters(): p.requires_grad = True for p in self.eat.pre_norm.parameters(): p.requires_grad = True @staticmethod def _prep(mel: torch.Tensor) -> torch.Tensor: """mel: [B, 1, T, F] => [B, 1, SSLAM_TIME_DIM, SSLAM_FREQ_DIM]""" x = mel.float() T = x.shape[2] if T < SSLAM_TIME_DIM: pad = torch.zeros(x.shape[0], 1, SSLAM_TIME_DIM - T, x.shape[3], device=x.device, dtype=x.dtype) x = torch.cat([x, pad], dim=2) elif T > SSLAM_TIME_DIM: x = x[:, :, :SSLAM_TIME_DIM, :] return x def forward(self, mel: torch.Tensor) -> torch.Tensor: x = self._prep(mel) feats = self.eat.extract_features(x) # print(f"SSLAM features: {feats.shape}") # should be [B, 1+patches, 768] return feats[:, 0, :] class SSLAMSampleDetector(nn.Module): """SampleDetector using SSLAM/EAT encoder instead of AST.""" def __init__(self, freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128): super().__init__() self.encoder = SSLAMEncoder(freeze=freeze_encoder) H = self.encoder.hidden_size self.pair_mask_head = PairMaskHead(beats_per_window, n_mels) self.head = nn.Sequential( nn.LayerNorm(4 * H + PAIR_SUMMARY_DIM), nn.Linear(4 * H + PAIR_SUMMARY_DIM, 512), nn.GELU(), nn.Dropout(dropout), nn.Linear(512, 128), nn.GELU(), nn.Dropout(dropout), nn.Linear(128, 2), ) def unfreeze_encoder(self, n_blocks: int): self.encoder.unfreeze_last_n(n_blocks) def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor: t = self.encoder(track_mel) o = self.encoder(orig_mel) pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel)) combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1) return self.head(combined) class ContrastiveSampleDetector(nn.Module): """Siamese AST encoder + projection head trained with CosineEmbeddingLoss.""" def __init__( self, model_name: str = os.environ["AST_MODEL"], freeze_encoder: bool = True, proj_dim: int = 256, dropout: float = 0.2, ): super().__init__() self.encoder = ASTEncoder(model_name, freeze=freeze_encoder) H = self.encoder.ast.config.hidden_size self.proj = nn.Sequential( nn.Linear(H, 512), nn.GELU(), nn.Dropout(dropout), nn.Linear(512, proj_dim), ) def embed(self, mel: torch.Tensor) -> torch.Tensor: h = self.encoder(mel) # print(f"encoder output: {h.shape}, norm={h.norm(dim=-1).mean():.3f}") z = self.proj(h) return torch.nn.functional.normalize(z, dim=-1) def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> tuple: return self.embed(track_mel), self.embed(orig_mel) def similarity(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor: t, o = self.embed(track_mel), self.embed(orig_mel) return (t * o).sum(dim=-1) def unfreeze_encoder(self, n_blocks: int = 2): self.encoder.unfreeze_last_n(n_blocks)