Spaces:
Sleeping
Sleeping
import glob
import importlib
import os
from typing import Optional

import torch
import torch.nn as nn
from dotenv import load_dotenv
from transformers import ASTModel, ASTConfig
| load_dotenv() | |
| AST_TIME_DIM = 1024 | |
| AST_FREQ_DIM = 128 | |
| SSLAM_HF_REPO = os.environ["SSLAM_MODEL"] | |
| SSLAM_TIME_DIM = 1024 | |
| SSLAM_FREQ_DIM = 128 | |
| PAIR_SUMMARY_DIM = 8 | |
class ASTEncoder(nn.Module):
    """Wraps ASTModel and returns the [CLS] token embedding.

    Args:
        model_name: HF hub id or local path passed to ``ASTModel.from_pretrained``.
        freeze: when True, all backbone parameters start with ``requires_grad=False``.
    """

    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        # ignore_mismatched_sizes lets checkpoints with a different input
        # geometry load; mismatched tensors are re-initialized by transformers.
        self.ast = ASTModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
        if freeze:
            for p in self.ast.parameters():
                p.requires_grad = False

    def unfreeze_last_n(self, n: int = 2):
        """Make the last *n* transformer blocks (plus the final layernorm) trainable."""
        for block in self.ast.encoder.layer[-n:]:
            for p in block.parameters():
                p.requires_grad = True
        for p in self.ast.layernorm.parameters():
            p.requires_grad = True

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """Pad/crop mel [B, 1, T, F] to the fixed AST input [B, AST_TIME_DIM, F].

        BUG FIX: this was a plain method defined without ``self``, so the
        ``self._prep(mel)`` call in :meth:`forward` raised ``TypeError``
        (two positional args passed to a one-arg function). It is now a
        staticmethod, matching how it is invoked.
        """
        x = mel.squeeze(1)  # drop channel dim -> [B, T, F]
        T = x.shape[1]
        if T < AST_TIME_DIM:
            # zero-pad along the time axis up to the model's expected length
            pad = torch.zeros(
                x.shape[0], AST_TIME_DIM - T, x.shape[2], device=x.device, dtype=x.dtype
            )
            x = torch.cat([x, pad], dim=1)
        elif T > AST_TIME_DIM:
            x = x[:, :AST_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the [CLS]-token embedding, shape [B, hidden_size]."""
        x = self._prep(mel)
        out = self.ast(input_values=x)
        # token 0 is the CLS-style summary token of the AST sequence output
        return out.last_hidden_state[:, 0, :]
class PairMaskHead(nn.Module):
    """Beat-by-beat pair matching head over two mel spectrograms."""

    def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64, frames_per_beat: int = 8):
        super().__init__()
        self.beats_per_window = beats_per_window
        self.frames_per_beat = frames_per_beat
        # Resample any input spectrogram onto a fixed (beats * frames, mels) grid.
        self.pool = nn.AdaptiveAvgPool2d((beats_per_window * frames_per_beat, n_mels))
        # Small conv stack that embeds one beat patch into a beat_dim vector.
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 5), padding=(1, 2), bias=False),
            nn.GroupNorm(4, 16),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2), bias=False),
            nn.GroupNorm(8, 32),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, beat_dim),
            nn.GELU(),
            nn.Linear(beat_dim, beat_dim),
        )
        # Learnable temperature (applied via exp) and additive logit bias.
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.zeros(()))

    def _beats(self, mel: torch.Tensor) -> torch.Tensor:
        """mel [B, 1, T, F] -> L2-normalized beat embeddings [B, beats, beat_dim]."""
        batch = mel.shape[0]
        grid = self.pool(mel)
        mels = grid.shape[-1]
        # Split the pooled time axis into per-beat patches, fold beats into batch.
        patches = grid.view(batch, 1, self.beats_per_window, self.frames_per_beat, mels)
        patches = patches.permute(0, 2, 1, 3, 4).contiguous()
        patches = patches.view(batch * self.beats_per_window, 1, self.frames_per_beat, mels)
        emb = self.patch_encoder(patches).view(batch, self.beats_per_window, -1)
        return torch.nn.functional.normalize(emb, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return scaled beat-vs-beat cosine-similarity logits, shape [B, beats, beats]."""
        track_beats = self._beats(track_mel)
        orig_beats = self._beats(orig_mel)
        sims = torch.matmul(track_beats, orig_beats.transpose(1, 2))
        return sims * self.logit_scale.exp() + self.bias
def pair_summary_features(pair_logits: torch.Tensor) -> torch.Tensor:
    """Reduce a [B, N, M] pair-logit grid to 8 scalar match statistics per item.

    Features (in order): global mean / max / std of the sigmoid probabilities,
    mean of the top-8 probabilities, mean and max of per-row maxima, mean of
    per-column maxima, and mean of the diagonal.
    """
    probs = torch.sigmoid(pair_logits)
    flat = probs.flatten(1)
    best_per_row = probs.max(dim=2).values
    best_per_col = probs.max(dim=1).values
    diag = torch.diagonal(probs, dim1=1, dim2=2)
    k = min(8, flat.shape[1])  # guard against grids with fewer than 8 cells
    top_mean = flat.topk(k, dim=1).values.mean(dim=1)
    features = [
        flat.mean(dim=1),
        flat.max(dim=1).values,
        flat.std(dim=1, unbiased=False),
        top_mean,
        best_per_row.mean(dim=1),
        best_per_row.max(dim=1).values,
        best_per_col.mean(dim=1),
        diag.mean(dim=1),
    ]
    return torch.stack(features, dim=-1)
class SampleDetector(nn.Module):
    """Siamese AST encoder + interaction head for binary sample detection.

    Args:
        model_name: HF hub id for the AST backbone. Defaults to the AST_MODEL
            env var, read lazily at construction time.
        freeze_encoder: freeze the AST backbone initially.
        dropout: dropout rate inside the classification head.
        beats_per_window, n_mels: geometry forwarded to PairMaskHead.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        freeze_encoder: bool = True,
        dropout: float = 0.3,
        beats_per_window: int = 16,
        n_mels: int = 128,
    ):
        super().__init__()
        # BUG FIX: the default used to be os.environ["AST_MODEL"] evaluated at
        # class-definition time, so importing this module raised KeyError when
        # the env var was unset — even for callers passing model_name explicitly.
        # The lookup is now deferred to instantiation.
        if model_name is None:
            model_name = os.environ["AST_MODEL"]
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        H = self.encoder.ast.config.hidden_size
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # Head input: [t, o, |t-o|, t*o] concatenated with pair-summary stats.
        self.head = nn.Sequential(
            nn.LayerNorm(4 * H + PAIR_SUMMARY_DIM),
            nn.Linear(4 * H + PAIR_SUMMARY_DIM, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Unfreeze the last n_blocks of the AST backbone for fine-tuning."""
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return 2-class logits [B, 2] for (track, candidate original) mel pairs."""
        t = self.encoder(track_mel)
        o = self.encoder(orig_mel)
        pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
        combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1)
        return self.head(combined)
class ConvBlock(nn.Module):
    """Two 3x3 convs with BatchNorm + GELU; the first conv optionally downsamples."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x [B, in_ch, H, W] -> [B, out_ch, ceil(H/stride), ceil(W/stride)]."""
        return self.block(x)
class CNNEncoder(nn.Module):
    """Four downsampling ConvBlocks + global average pool -> embed_dim vector."""

    def __init__(self, embed_dim: int = 256):
        super().__init__()
        channels = [1, 32, 64, 128, 256]
        stages = [ConvBlock(c_in, c_out) for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.net = nn.Sequential(
            *stages,
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, embed_dim),
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """mel [B, 1, T, F] -> embedding [B, embed_dim]."""
        return self.net(mel)
class CNNSampleDetector(nn.Module):
    """Drop-in CNN alternative to SampleDetector."""

    def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = CNNEncoder(embed_dim)
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # Head input: [t, o, |t-o|, t*o] concatenated with pair-summary stats.
        in_dim = 4 * embed_dim + PAIR_SUMMARY_DIM
        self.head = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return 2-class logits [B, 2]."""
        track_emb = self.encoder(track_mel)
        orig_emb = self.encoder(orig_mel)
        pair_logits = self.pair_mask_head(track_mel, orig_mel)
        pair_features = pair_summary_features(pair_logits)
        fused = torch.cat(
            [track_emb, orig_emb, (track_emb - orig_emb).abs(), track_emb * orig_emb, pair_features],
            dim=-1,
        )
        return self.head(fused)
class SSLAMEncoder(nn.Module):
    """Wraps the EAT (SSLAM) model and returns the CLS-like token embedding.

    Bypasses AutoModel.from_pretrained due to a transformers >= 5.5 incompatibility
    with EATModel's missing all_tied_weights_keys attribute.
    """

    def __init__(self, freeze: bool = True):
        super().__init__()
        # Local imports keep these heavyweight dependencies off module import time.
        from transformers import AutoConfig
        import safetensors.torch
        from huggingface_hub import hf_hub_download

        cfg = AutoConfig.from_pretrained(SSLAM_HF_REPO, trust_remote_code=True)
        self.hidden_size = cfg.embed_dim
        sha = cfg._commit_hash or self._find_sha()
        # NOTE(review): the dynamic-module path hard-codes "ta012.SSLAM_pretrain"
        # while the repo id comes from the SSLAM_MODEL env var — confirm the two
        # stay in sync if the repo is ever changed.
        eat_mod = importlib.import_module(
            f"transformers_modules.ta012.SSLAM_pretrain.{sha}.eat_model"
        )
        self.eat = eat_mod.EAT(cfg)
        weights_path = hf_hub_download(SSLAM_HF_REPO, "model.safetensors")
        raw = safetensors.torch.load_file(weights_path)
        # Checkpoint keys are prefixed with "model."; strip to match EAT's state_dict.
        state = {k.removeprefix("model."): v for k, v in raw.items()}
        self.eat.load_state_dict(state, strict=True)
        if freeze:
            for p in self.eat.parameters():
                p.requires_grad = False

    @staticmethod
    def _find_sha() -> str:
        """Locate the commit sha of the cached SSLAM dynamic module in the HF cache.

        BUG FIX: this was defined without ``self`` yet called as
        ``self._find_sha()`` in __init__, which raised TypeError whenever
        cfg._commit_hash was falsy; it is now a staticmethod.
        """
        dirs = glob.glob(
            os.path.expanduser(
                f"~/.cache/huggingface/modules/transformers_modules/{SSLAM_HF_REPO}/*"
            )
        )
        dirs = [d for d in dirs if os.path.isdir(d)]
        if not dirs:
            raise RuntimeError("SSLAM modules not found in HF cache — run AutoConfig.from_pretrained first")
        # Lexicographically-last entry; assumes at most one relevant snapshot dir.
        return os.path.basename(sorted(dirs)[-1])

    def unfreeze_last_n(self, n: int):
        """Make the last *n* EAT blocks (plus pre_norm) trainable."""
        for block in self.eat.blocks[-n:]:
            for p in block.parameters():
                p.requires_grad = True
        for p in self.eat.pre_norm.parameters():
            p.requires_grad = True

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """Pad/crop mel [B, 1, T, F] to [B, 1, SSLAM_TIME_DIM, F] in float32.

        BUG FIX: this was a plain method defined without ``self``, so the
        ``self._prep(mel)`` call in :meth:`forward` always raised TypeError;
        it is now a staticmethod.
        """
        x = mel.float()
        T = x.shape[2]
        if T < SSLAM_TIME_DIM:
            pad = torch.zeros(x.shape[0], 1, SSLAM_TIME_DIM - T, x.shape[3],
                              device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=2)
        elif T > SSLAM_TIME_DIM:
            x = x[:, :, :SSLAM_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the first (CLS-like) token of the EAT feature sequence, [B, hidden]."""
        x = self._prep(mel)
        feats = self.eat.extract_features(x)
        return feats[:, 0, :]
class SSLAMSampleDetector(nn.Module):
    """SampleDetector using SSLAM/EAT encoder instead of AST."""

    def __init__(self, freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = SSLAMEncoder(freeze=freeze_encoder)
        H = self.encoder.hidden_size
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # Head input: [t, o, |t-o|, t*o] concatenated with pair-summary stats.
        feat_dim = 4 * H + PAIR_SUMMARY_DIM
        self.head = nn.Sequential(
            nn.LayerNorm(feat_dim),
            nn.Linear(feat_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def unfreeze_encoder(self, n_blocks: int):
        """Unfreeze the last n_blocks of the SSLAM encoder for fine-tuning."""
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return 2-class logits [B, 2]."""
        track_emb = self.encoder(track_mel)
        orig_emb = self.encoder(orig_mel)
        pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
        fused = torch.cat(
            [track_emb, orig_emb, (track_emb - orig_emb).abs(), track_emb * orig_emb, pair_features],
            dim=-1,
        )
        return self.head(fused)
class ContrastiveSampleDetector(nn.Module):
    """Siamese AST encoder + projection head trained with CosineEmbeddingLoss.

    Args:
        model_name: HF hub id for the AST backbone. Defaults to the AST_MODEL
            env var, read lazily at construction time.
        freeze_encoder: freeze the AST backbone initially.
        proj_dim: output dimension of the projection head.
        dropout: dropout rate inside the projection head.
    """

    def __init__(
        self,
        model_name: Optional[str] = None,
        freeze_encoder: bool = True,
        proj_dim: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()
        # BUG FIX: the default used to be os.environ["AST_MODEL"] evaluated at
        # class-definition time, so importing this module raised KeyError when
        # the env var was unset; the lookup is now deferred to instantiation.
        if model_name is None:
            model_name = os.environ["AST_MODEL"]
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        H = self.encoder.ast.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(H, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, proj_dim),
        )

    def embed(self, mel: torch.Tensor) -> torch.Tensor:
        """Return an L2-normalized projection of the AST [CLS] embedding, [B, proj_dim]."""
        h = self.encoder(mel)
        z = self.proj(h)
        return torch.nn.functional.normalize(z, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> tuple:
        """Return (track_embedding, original_embedding)."""
        return self.embed(track_mel), self.embed(orig_mel)

    def similarity(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Cosine similarity per pair (embeddings are unit-norm), shape [B]."""
        t, o = self.embed(track_mel), self.embed(orig_mel)
        return (t * o).sum(dim=-1)

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Unfreeze the last n_blocks of the AST backbone for fine-tuning."""
        self.encoder.unfreeze_last_n(n_blocks)