# sampled / model.py
# Author: dayngerous — commit 0a95bc3:
# "Use classifier head for match verdict, show proposed masks on no-match"
import glob
import importlib
import os
import torch
import torch.nn as nn
from dotenv import load_dotenv
from transformers import ASTModel, ASTConfig
load_dotenv()  # pull model-name env vars (AST_MODEL, SSLAM_MODEL) from a local .env
# Fixed spectrogram geometry the AST checkpoint expects (see ASTEncoder._prep).
AST_TIME_DIM = 1024
AST_FREQ_DIM = 128
# NOTE(review): raises KeyError at import time when SSLAM_MODEL is unset,
# even for users who never touch the SSLAM encoder — consider lazy lookup.
SSLAM_HF_REPO = os.environ["SSLAM_MODEL"]
# Fixed spectrogram geometry for the SSLAM/EAT encoder (see SSLAMEncoder._prep).
SSLAM_TIME_DIM = 1024
SSLAM_FREQ_DIM = 128
# Number of summary statistics emitted by pair_summary_features().
PAIR_SUMMARY_DIM = 8
class ASTEncoder(nn.Module):
    """Wraps ASTModel and returns the [CLS] token embedding."""

    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        self.ast = ASTModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
        if freeze:
            for param in self.ast.parameters():
                param.requires_grad = False

    def unfreeze_last_n(self, n: int = 2):
        """Make the last *n* transformer blocks plus the final layernorm trainable."""
        for block in self.ast.encoder.layer[-n:]:
            for param in block.parameters():
                param.requires_grad = True
        for param in self.ast.layernorm.parameters():
            param.requires_grad = True

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] => [B, AST_TIME_DIM, AST_FREQ_DIM]"""
        x = mel.squeeze(1)
        frames = x.shape[1]
        if frames < AST_TIME_DIM:
            # Zero-pad along time up to the fixed AST input length.
            filler = x.new_zeros(x.shape[0], AST_TIME_DIM - frames, x.shape[2])
            x = torch.cat([x, filler], dim=1)
        elif frames > AST_TIME_DIM:
            # Truncate anything longer than the model's time window.
            x = x[:, :AST_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the [CLS] embedding, shape [B, hidden_size]."""
        out = self.ast(input_values=self._prep(mel))
        return out.last_hidden_state[:, 0, :]
class PairMaskHead(nn.Module):
    """Beat-by-beat pair matching head over two mel spectrograms."""

    def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64, frames_per_beat: int = 8):
        super().__init__()
        self.beats_per_window = beats_per_window
        self.frames_per_beat = frames_per_beat
        # Resample any input to a fixed (beats * frames_per_beat, n_mels) grid.
        self.pool = nn.AdaptiveAvgPool2d((beats_per_window * frames_per_beat, n_mels))
        # Small conv stack that embeds one beat-sized patch into beat_dim dims.
        self.patch_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(3, 5), padding=(1, 2), bias=False),
            nn.GroupNorm(4, 16),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2), bias=False),
            nn.GroupNorm(8, 32),
            nn.GELU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, beat_dim),
            nn.GELU(),
            nn.Linear(beat_dim, beat_dim),
        )
        # Learnable temperature (stored in log-space) and offset for the logits.
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.zeros(()))

    def _beats(self, mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] -> [B, beats, beat_dim] of L2-normalised beat embeddings."""
        batch = mel.shape[0]
        beats, frames = self.beats_per_window, self.frames_per_beat
        pooled = self.pool(mel)
        mels = pooled.shape[-1]
        # Split the time axis into one patch per beat: [B*beats, 1, frames, F].
        patches = (
            pooled.view(batch, 1, beats, frames, mels)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(batch * beats, 1, frames, mels)
        )
        emb = self.patch_encoder(patches).view(batch, beats, -1)
        return torch.nn.functional.normalize(emb, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, beats, beats] pairwise beat-similarity logits."""
        track_beats = self._beats(track_mel)
        orig_beats = self._beats(orig_mel)
        sim = torch.einsum("bih,bjh->bij", track_beats, orig_beats)
        return sim * self.logit_scale.exp() + self.bias
def pair_summary_features(pair_logits: torch.Tensor) -> torch.Tensor:
    """Collapse a [B, I, J] beat-pair logit grid into 8 summary stats per item.

    Returns a [B, 8] tensor (PAIR_SUMMARY_DIM) of sigmoid-probability
    statistics describing how strongly the two windows' beats match.
    """
    probs = torch.sigmoid(pair_logits)
    flat = probs.flatten(1)
    best_per_row = probs.max(dim=2).values
    best_per_col = probs.max(dim=1).values
    diagonal = torch.diagonal(probs, dim1=1, dim2=2)
    k = min(8, flat.shape[1])
    features = [
        flat.mean(dim=1),                        # overall match density
        flat.max(dim=1).values,                  # single strongest pair
        flat.std(dim=1, unbiased=False),         # spread of match scores
        flat.topk(k, dim=1).values.mean(dim=1),  # mean of the top-k pairs
        best_per_row.mean(dim=1),
        best_per_row.max(dim=1).values,
        best_per_col.mean(dim=1),
        diagonal.mean(dim=1),                    # aligned-beat agreement
    ]
    return torch.stack(features, dim=-1)
class SampleDetector(nn.Module):
    """Siamese AST encoder + interaction head for binary sample detection.

    Args:
        model_name: HF checkpoint name for the AST encoder; defaults to the
            ``AST_MODEL`` environment variable, resolved lazily.
        freeze_encoder: freeze all AST weights at construction.
        dropout: dropout rate in the classifier head.
        beats_per_window / n_mels: geometry passed to PairMaskHead.
    """

    def __init__(
        self,
        model_name: "str | None" = None,
        freeze_encoder: bool = True,
        dropout: float = 0.3,
        beats_per_window: int = 16,
        n_mels: int = 128,
    ):
        super().__init__()
        # Resolve the env-var default lazily: the original signature evaluated
        # os.environ["AST_MODEL"] at import time, which raised KeyError on
        # module import whenever the variable was unset — even for callers
        # that pass model_name explicitly or only use the CNN detector.
        if model_name is None:
            model_name = os.environ["AST_MODEL"]
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        H = self.encoder.ast.config.hidden_size
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # Head input: [t, o, |t-o|, t*o] embeddings plus pair summary stats.
        self.head = nn.Sequential(
            nn.LayerNorm(4 * H + PAIR_SUMMARY_DIM),
            nn.Linear(4 * H + PAIR_SUMMARY_DIM, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Unfreeze the last *n_blocks* AST transformer blocks for fine-tuning."""
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, 2] match/no-match logits for a (track, original) mel pair."""
        t = self.encoder(track_mel)
        o = self.encoder(orig_mel)
        pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
        combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1)
        return self.head(combined)
class ConvBlock(nn.Module):
    """Two 3x3 convs with BatchNorm + GELU; the first conv downsamples by *stride*."""

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)
class CNNEncoder(nn.Module):
    """Conv stack mapping a mel spectrogram [B, 1, T, F] to a [B, embed_dim] embedding."""

    def __init__(self, embed_dim: int = 256):
        super().__init__()
        channels = [1, 32, 64, 128, 256]
        stages = [ConvBlock(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])]
        self.net = nn.Sequential(
            *stages,
            nn.AdaptiveAvgPool2d(1),  # global average over remaining time/freq
            nn.Flatten(),
            nn.Linear(channels[-1], embed_dim),
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        return self.net(mel)
class CNNSampleDetector(nn.Module):
    """Drop-in CNN alternative to SampleDetector."""

    def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = CNNEncoder(embed_dim)
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # Head input: [t, o, |t-o|, t*o] embeddings plus pair summary stats.
        feature_dim = 4 * embed_dim + PAIR_SUMMARY_DIM
        self.head = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, 2] match/no-match logits for a (track, original) mel pair."""
        track_emb = self.encoder(track_mel)
        orig_emb = self.encoder(orig_mel)
        pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
        interaction = [track_emb, orig_emb, (track_emb - orig_emb).abs(), track_emb * orig_emb, pair_features]
        return self.head(torch.cat(interaction, dim=-1))
class SSLAMEncoder(nn.Module):
    """Wraps the EAT (SSLAM) model and returns the CLS-like token embedding.
    Bypasses AutoModel.from_pretrained due to a transformers >= 5.5 incompatibility
    with EATModel's missing all_tied_weights_keys attribute.
    """
    def __init__(self, freeze: bool = True):
        super().__init__()
        # Local imports: only needed when the SSLAM encoder is actually used.
        from transformers import AutoConfig
        import safetensors.torch
        from huggingface_hub import hf_hub_download
        # Side effect: trust_remote_code downloads the repo's custom modeling
        # code into the HF modules cache, which we import manually below.
        cfg = AutoConfig.from_pretrained(SSLAM_HF_REPO, trust_remote_code=True)
        self.hidden_size = cfg.embed_dim
        # The cached module path embeds the repo commit sha; fall back to
        # scanning the cache directory when the config carries no hash.
        sha = cfg._commit_hash or self._find_sha()
        # NOTE(review): module path hard-codes "ta012.SSLAM_pretrain" — this
        # assumes SSLAM_MODEL == "ta012/SSLAM_pretrain"; verify if repo changes.
        eat_mod = importlib.import_module(
            f"transformers_modules.ta012.SSLAM_pretrain.{sha}.eat_model"
        )
        self.eat = eat_mod.EAT(cfg)
        # Load weights by hand; checkpoint keys are prefixed with "model.".
        weights_path = hf_hub_download(SSLAM_HF_REPO, "model.safetensors")
        raw = safetensors.torch.load_file(weights_path)
        state = {k.removeprefix("model."): v for k, v in raw.items()}
        self.eat.load_state_dict(state, strict=True)
        if freeze:
            for p in self.eat.parameters():
                p.requires_grad = False
    @staticmethod
    def _find_sha() -> str:
        """Return the newest cached commit-sha directory for the SSLAM remote code."""
        dirs = glob.glob(
            os.path.expanduser(
                f"~/.cache/huggingface/modules/transformers_modules/{SSLAM_HF_REPO}/*"
            )
        )
        dirs = [d for d in dirs if os.path.isdir(d)]
        if not dirs:
            raise RuntimeError("SSLAM modules not found in HF cache — run AutoConfig.from_pretrained first")
        # Lexicographically-last sha dir; presumably "most recent" — TODO confirm.
        return os.path.basename(sorted(dirs)[-1])
    def unfreeze_last_n(self, n: int):
        """Make the last *n* EAT transformer blocks (and pre_norm) trainable."""
        for block in self.eat.blocks[-n:]:
            for p in block.parameters():
                p.requires_grad = True
        # NOTE(review): assumes the EAT module exposes a `pre_norm` layer.
        for p in self.eat.pre_norm.parameters():
            p.requires_grad = True
    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] => [B, 1, SSLAM_TIME_DIM, SSLAM_FREQ_DIM]"""
        x = mel.float()
        T = x.shape[2]
        if T < SSLAM_TIME_DIM:
            # Zero-pad along time up to the fixed EAT input length.
            pad = torch.zeros(x.shape[0], 1, SSLAM_TIME_DIM - T, x.shape[3],
                              device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=2)
        elif T > SSLAM_TIME_DIM:
            # Truncate anything longer than the model's time window.
            x = x[:, :, :SSLAM_TIME_DIM, :]
        return x
    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the first (CLS-like) token embedding, shape [B, hidden_size]."""
        x = self._prep(mel)
        feats = self.eat.extract_features(x)
        return feats[:, 0, :]
class SSLAMSampleDetector(nn.Module):
    """SampleDetector using SSLAM/EAT encoder instead of AST."""

    def __init__(self, freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = SSLAMEncoder(freeze=freeze_encoder)
        hidden = self.encoder.hidden_size
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
        # Head input: [t, o, |t-o|, t*o] embeddings plus pair summary stats.
        in_dim = 4 * hidden + PAIR_SUMMARY_DIM
        self.head = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )

    def unfreeze_encoder(self, n_blocks: int):
        """Unfreeze the last *n_blocks* EAT transformer blocks for fine-tuning."""
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Return [B, 2] match/no-match logits for a (track, original) mel pair."""
        track_emb = self.encoder(track_mel)
        orig_emb = self.encoder(orig_mel)
        pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
        fused = torch.cat(
            [track_emb, orig_emb, (track_emb - orig_emb).abs(), track_emb * orig_emb, pair_features],
            dim=-1,
        )
        return self.head(fused)
class ContrastiveSampleDetector(nn.Module):
    """Siamese AST encoder + projection head trained with CosineEmbeddingLoss.

    Args:
        model_name: HF checkpoint name for the AST encoder; defaults to the
            ``AST_MODEL`` environment variable, resolved lazily.
        freeze_encoder: freeze all AST weights at construction.
        proj_dim: output dimension of the projection head.
        dropout: dropout rate inside the projection head.
    """

    def __init__(
        self,
        model_name: "str | None" = None,
        freeze_encoder: bool = True,
        proj_dim: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()
        # Resolve the env-var default lazily: the original signature evaluated
        # os.environ["AST_MODEL"] at import time, which raised KeyError on
        # module import whenever the variable was unset — even for callers
        # that pass model_name explicitly.
        if model_name is None:
            model_name = os.environ["AST_MODEL"]
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        H = self.encoder.ast.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(H, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, proj_dim),
        )

    def embed(self, mel: torch.Tensor) -> torch.Tensor:
        """Return the L2-normalised projection of the encoder's [CLS] embedding."""
        h = self.encoder(mel)
        z = self.proj(h)
        return torch.nn.functional.normalize(z, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> tuple:
        """Return the (track, original) embedding pair."""
        return self.embed(track_mel), self.embed(orig_mel)

    def similarity(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        """Cosine similarity between track and original embeddings, shape [B]."""
        t, o = self.embed(track_mel), self.embed(orig_mel)
        return (t * o).sum(dim=-1)

    def unfreeze_encoder(self, n_blocks: int = 2):
        """Unfreeze the last *n_blocks* AST transformer blocks for fine-tuning."""
        self.encoder.unfreeze_last_n(n_blocks)