Commit 4242909
Parent(s): dde4389
Add Gradio app, model code, and deps — checkpoint downloads from dayngerous/whoSampledAST
- README.md +24 -7
- app.py +597 -0
- model.py +316 -0
- requirements.txt +17 -0
README.md CHANGED
@@ -1,13 +1,30 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Sample Match Verifier
+emoji: 🎵
+colorFrom: blue
+colorTo: purple
 sdk: gradio
-sdk_version:
+sdk_version: "5.0"
 app_file: app.py
 pinned: false
-
+license: mit
 ---
 
-
+# Sample Match Verifier
+
+Upload a track and a possible source sample. Waveforms appear immediately on upload. Click **Verify match** to run the model: it scans beat-aligned windows, scores the best match, and highlights the predicted sampled sections on both the waveform and mel spectrogram. If no confident match is found, the mel spectrogram shows a **No Match** overlay.
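As a reading aid, the decision this paragraph describes (pick the best-scoring pair of windows, then compare its score with a threshold) can be sketched in plain Python. The name `best_window_pair` and the nested-list score matrix are assumptions made for illustration; `app.py` in this commit does the equivalent with `torch.argmax` on a tensor.

```python
def best_window_pair(scores, threshold=0.5):
    """Return (track_idx, source_idx, score, matched) for the best pair.

    `scores[i][j]` is assumed to hold the match probability between track
    window i and source window j (hypothetical layout for this sketch).
    """
    n_cols = len(scores[0])
    # Flatten row by row so a flat index maps back to (row, col) via divmod.
    flat = [value for row in scores for value in row]
    best_flat = max(range(len(flat)), key=flat.__getitem__)
    track_idx, source_idx = divmod(best_flat, n_cols)
    best = flat[best_flat]
    return track_idx, source_idx, best, best >= threshold
```

For a matrix such as `[[0.10, 0.20, 0.05], [0.30, 0.91, 0.40]]` the best pair is track window 1 against source window 1, and with the default threshold of 0.5 it counts as a match.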
+
+## Model checkpoint
+
+Place your checkpoint at `models/best.pt` (committed via Git LFS) or set the `MODEL_CHECKPOINT` environment variable to its path. The app falls back to `checkpoints/best.pt` if `models/best.pt` is not found.
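For orientation, the lookup order used by `_resolve_checkpoint` in `app.py` appears to be: the `MODEL_CHECKPOINT` path, then a list of local candidates, then a download from the Hub repo. Below is a minimal sketch of that order. The `download` argument is a hypothetical zero-argument callable standing in for the Hub download, injected here only so the sketch runs without network access.

```python
import os
from pathlib import Path


def resolve_checkpoint(candidates, download, env=None):
    """Return the first existing checkpoint path, else the result of `download()`.

    `candidates` are local paths tried in order after MODEL_CHECKPOINT.
    `download` is an illustrative stand-in for a Hub download call.
    """
    env = os.environ if env is None else env
    env_path = env.get("MODEL_CHECKPOINT", "")
    for path in [env_path, *candidates]:
        # Skip empty strings so an unset variable does not match the cwd.
        if path and Path(path).exists():
            return path
    try:
        return download()
    except Exception as exc:
        raise FileNotFoundError(f"No local checkpoint and download failed: {exc}") from exc
```

Passing the downloader in as a parameter is a choice made for this sketch; the committed code calls `hf_hub_download` directly.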
+
+## Environment variables
+
+| Variable | Default | Description |
+|---|---|---|
+| `MODEL_CHECKPOINT` | `models/best.pt` | Path to the `.pt` checkpoint |
+| `MODEL_BACKBONE` | `ast` | Backbone: `ast`, `sslam`, or `cnn` |
+| `AST_MODEL` | `MIT/ast-finetuned-audioset-10-10-0.4593` | HuggingFace AST model ID |
+| `MODEL_BARS` | `4` | Bars per analysis window |
+| `MODEL_N_MELS` | `128` | Mel frequency bins |
+| `APP_SAMPLE_RATE` | `16000` | Audio sample rate |
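A small sketch of how the variables in this table could be read with their documented defaults. The helper `load_settings` and the `DEFAULTS`/`INT_KEYS` names are illustrative assumptions; `app.py` reads each variable inline with `os.environ.get` instead.

```python
import os

# Defaults copied from the README table; the helper itself is hypothetical.
DEFAULTS = {
    "MODEL_CHECKPOINT": "models/best.pt",
    "MODEL_BACKBONE": "ast",
    "AST_MODEL": "MIT/ast-finetuned-audioset-10-10-0.4593",
    "MODEL_BARS": "4",
    "MODEL_N_MELS": "128",
    "APP_SAMPLE_RATE": "16000",
}
INT_KEYS = {"MODEL_BARS", "MODEL_N_MELS", "APP_SAMPLE_RATE"}


def load_settings(env=None):
    """Return a dict of settings, converting numeric variables to int."""
    env = os.environ if env is None else env
    settings = {}
    for key, default in DEFAULTS.items():
        raw = env.get(key, default)
        settings[key] = int(raw) if key in INT_KEYS else raw
    return settings
```

With the defaults, `MODEL_BARS` of 4 corresponds to 16 beats per analysis window, since the app multiplies bars by 4.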
app.py ADDED
@@ -0,0 +1,597 @@
+import json
+import os
+from dataclasses import dataclass
+from functools import lru_cache
+from pathlib import Path
+
+os.environ.setdefault("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593")
+os.environ.setdefault("SSLAM_MODEL", "ta012/SSLAM_pretrain")
+
+import gradio as gr
+import librosa
+import matplotlib
+import numpy as np
+import torch
+import torchaudio.transforms as T
+from huggingface_hub import hf_hub_download
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+
+from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector
+
+
+SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000"))
+MODEL_REPO = os.environ.get("MODEL_REPO", "dayngerous/whoSampledAST")
+
+
+def _resolve_checkpoint() -> str:
+    """Return local checkpoint path, downloading from HF Hub if needed."""
+    env_path = os.environ.get("MODEL_CHECKPOINT", "")
+    for p in [env_path, "models/best.pt", "checkpoints/best.pt", "checkpoints2/best.pt"]:
+        if p and Path(p).exists():
+            return p
+    try:
+        return hf_hub_download(repo_id=MODEL_REPO, filename="models/best.pt")
+    except Exception as exc:
+        raise FileNotFoundError(
+            f"No local checkpoint found and download from {MODEL_REPO} failed: {exc}"
+        )
+
+
+def _resolve_meta() -> str:
+    """Return local test_indices.json path, downloading from HF Hub if needed."""
+    for p in ["models/test_indices.json", "checkpoints2/test_indices.json", "checkpoints/test_indices.json"]:
+        if Path(p).exists():
+            return p
+    try:
+        return hf_hub_download(repo_id=MODEL_REPO, filename="models/test_indices.json")
+    except Exception:
+        return ""
+
+
+DEFAULT_CHECKPOINT = _resolve_checkpoint()
+DEFAULT_META = _resolve_meta()
+TARGET_FRAMES_PER_BEAT = 50
+N_FFT = 1024
+MEL_HOP = 512
+N_MELS_VIZ = 128
+
+
+@dataclass
+class AudioClip:
+    waveform: torch.Tensor
+    sample_rate: int
+    offset_sec: float
+    duration_sec: float
+
+
+@dataclass
+class BeatWindow:
+    waveform: torch.Tensor
+    start_sec: float
+    end_sec: float
+    beat_intervals: list[tuple[float, float]]
+
+
+def _format_time(seconds: float) -> str:
+    seconds = max(0.0, float(seconds))
+    minutes = int(seconds // 60)
+    rem = seconds - minutes * 60
+    return f"{minutes}:{rem:04.1f}"
+
+
+def _format_intervals(intervals: list[tuple[float, float]], limit: int = 4) -> str:
+    if not intervals:
+        return "none"
+    shown = ", ".join(f"{_format_time(a)}-{_format_time(b)}" for a, b in intervals[:limit])
+    if len(intervals) > limit:
+        shown += f", +{len(intervals) - limit} more"
+    return shown
+
+
+def _merge_intervals(intervals: list[tuple[float, float]], gap: float = 0.05) -> list[tuple[float, float]]:
+    if not intervals:
+        return []
+    ordered = sorted((float(a), float(b)) for a, b in intervals if b > a)
+    merged = [ordered[0]]
+    for start, end in ordered[1:]:
+        prev_start, prev_end = merged[-1]
+        if start <= prev_end + gap:
+            merged[-1] = (prev_start, max(prev_end, end))
+        else:
+            merged.append((start, end))
+    return merged
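Review note on `_merge_intervals`: the early return only checks the raw input, so a non-empty list whose intervals are all empty (`b <= a`) would leave `ordered` empty and `ordered[0]` would raise `IndexError`. The standalone sketch below shows the same merge logic with one extra guard. The guard and the public name `merge_intervals` are additions for illustration, not part of the commit.

```python
def merge_intervals(intervals, gap=0.05):
    """Merge (start, end) pairs that overlap or lie within `gap` seconds."""
    ordered = sorted((float(a), float(b)) for a, b in intervals if b > a)
    if not ordered:  # added guard: also covers inputs with only empty spans
        return []
    merged = [ordered[0]]
    for start, end in ordered[1:]:
        prev_start, prev_end = merged[-1]
        if start <= prev_end + gap:
            # Close enough to the previous span: extend it.
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return merged
```

Adjacent beat intervals such as `(0.0, 0.5)` and `(0.5, 1.0)` collapse into `(0.0, 1.0)`, which is why consecutive highlighted beats render as one shaded region.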
+
+
+def _load_args(checkpoint_path: Path) -> dict:
+    meta_path = Path(DEFAULT_META) if DEFAULT_META else checkpoint_path.parent / "test_indices.json"
+    args = {}
+    if meta_path.exists():
+        with open(meta_path) as f:
+            args = json.load(f).get("args", {})
+
+    args.setdefault("backbone", os.environ.get("MODEL_BACKBONE", "ast"))
+    args.setdefault("ast_model", os.environ.get("AST_MODEL"))
+    args.setdefault("bars", int(os.environ.get("MODEL_BARS", "4")))
+    args.setdefault("n_mels", int(os.environ.get("MODEL_N_MELS", "128")))
+    args.setdefault("sample_rate", SAMPLE_RATE)
+    return args
+
+
+def _build_model(args: dict, device: torch.device):
+    beats_per_window = int(args.get("bars", 4)) * 4
+    n_mels = int(args.get("n_mels", 128))
+    backbone = args.get("backbone", "ast")
+    if backbone == "ast":
+        model = SampleDetector(
+            model_name=args.get("ast_model", os.environ["AST_MODEL"]),
+            freeze_encoder=True,
+            beats_per_window=beats_per_window,
+            n_mels=n_mels,
+        )
+    elif backbone == "sslam":
+        model = SSLAMSampleDetector(
+            freeze_encoder=True,
+            beats_per_window=beats_per_window,
+            n_mels=n_mels,
+        )
+    else:
+        model = CNNSampleDetector(beats_per_window=beats_per_window, n_mels=n_mels)
+    return model.to(device)
+
+
+@lru_cache(maxsize=2)
+def _load_model(checkpoint_path: str):
+    path = Path(checkpoint_path)
+    if not path.exists():
+        raise FileNotFoundError(
+            f"Checkpoint not found: {path}. Set MODEL_CHECKPOINT or place a checkpoint at models/best.pt."
+        )
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    args = _load_args(path)
+    model = _build_model(args, device)
+    ckpt = torch.load(path, map_location=device)
+    state = ckpt.get("model_state", ckpt)
+    pair_head_loaded = any(k.startswith("pair_mask_head.") for k in state)
+    missing, unexpected = model.load_state_dict(state, strict=False)
+    model.eval()
+    return {
+        "model": model,
+        "args": args,
+        "device": device,
+        "epoch": ckpt.get("epoch", "?"),
+        "pair_head_loaded": pair_head_loaded,
+        "missing": missing,
+        "unexpected": unexpected,
+    }
+
+
+def _load_audio(path: str, offset_sec: float, max_seconds: float) -> AudioClip:
+    if not path:
+        raise gr.Error("Upload both audio files before running verification.")
+
+    audio, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True)
+    waveform = torch.from_numpy(audio).float()
+
+    offset_sec = max(0.0, float(offset_sec or 0.0))
+    max_seconds = max(1.0, float(max_seconds or 1.0))
+    start = min(int(offset_sec * sr), max(waveform.numel() - 1, 0))
+    end = min(start + int(max_seconds * sr), waveform.numel())
+    waveform = waveform[start:end].float().contiguous()
+    if waveform.numel() < sr // 4:
+        raise gr.Error("Each upload must contain at least 0.25 seconds of audio after offset trimming.")
+
+    peak = waveform.abs().max().clamp_min(1e-6)
+    waveform = waveform / peak
+    return AudioClip(
+        waveform=waveform,
+        sample_rate=sr,
+        offset_sec=offset_sec,
+        duration_sec=waveform.numel() / sr,
+    )
+
+
+def _estimate_beats(waveform: torch.Tensor, sample_rate: int) -> tuple[float, np.ndarray]:
+    y = waveform.detach().cpu().numpy().astype(np.float32)
+    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sample_rate, hop_length=512)
+    bpm = float(np.atleast_1d(tempo)[0]) if np.size(tempo) else 120.0
+    if not np.isfinite(bpm) or bpm <= 0:
+        bpm = 120.0
+    bpm = float(np.clip(bpm, 60.0, 200.0))
+
+    beat_samples = librosa.frames_to_samples(beat_frames, hop_length=512)
+    beat_samples = beat_samples[(beat_samples >= 0) & (beat_samples < waveform.numel())]
+    if len(beat_samples) < 2:
+        step = max(1, int(round(sample_rate * 60.0 / bpm)))
+        beat_samples = np.arange(0, waveform.numel(), step, dtype=np.int64)
+    elif beat_samples[0] > sample_rate * 60.0 / bpm:
+        beat_samples = np.insert(beat_samples, 0, 0)
+    return bpm, beat_samples.astype(np.int64)
+
+
+def _to_mel(waveform: torch.Tensor, bpm: float, args: dict) -> torch.Tensor:
+    sample_rate = int(args.get("sample_rate", SAMPLE_RATE))
+    n_mels = int(args.get("n_mels", 128))
+    bars = int(args.get("bars", 4))
+    fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
+    hop = max(1, round(60 * sample_rate / (bpm * TARGET_FRAMES_PER_BEAT)))
+
+    mel_transform = T.MelSpectrogram(
+        sample_rate=sample_rate,
+        n_fft=N_FFT,
+        hop_length=hop,
+        n_mels=n_mels,
+        power=2.0,
+    )
+    amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)
+    mel = amp_to_db(mel_transform(waveform)).T
+    if mel.shape[0] > fixed_frames:
+        mel = mel[:fixed_frames]
+    elif mel.shape[0] < fixed_frames:
+        mel = torch.cat([mel, torch.zeros(fixed_frames - mel.shape[0], mel.shape[1])], dim=0)
+    mel = (mel - mel.mean()) / (mel.std() + 1e-6)
+    return mel.unsqueeze(0)
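Review note on `_to_mel`: the hop length is tied to tempo so that every beat spans roughly `TARGET_FRAMES_PER_BEAT` spectrogram frames regardless of BPM. The arithmetic can be checked in isolation with the sketch below; the function name `tempo_synced_hop` is made up for illustration, and the formulas are copied from the two assignments above.

```python
TARGET_FRAMES_PER_BEAT = 50


def tempo_synced_hop(bpm, sample_rate=16000, bars=4):
    """Return (hop_length, fixed_frames) for a beat-aligned mel window."""
    # Samples per beat divided by the desired frames per beat.
    hop = max(1, round(60 * sample_rate / (bpm * TARGET_FRAMES_PER_BEAT)))
    # Assumes 4 beats per bar, as the app does.
    fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
    return hop, fixed_frames
```

At 120 BPM and 16 kHz this gives a hop of 160 samples. A 16-beat window then lasts 8 s (128,000 samples), which is about 800 frames, so the crop/pad step should only need to adjust a frame or so.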
+
+
+def _make_windows(
+    clip: AudioClip,
+    bpm: float,
+    beat_samples: np.ndarray,
+    args: dict,
+    stride_beats: int,
+    max_windows: int,
+) -> list[BeatWindow]:
+    bars = int(args.get("bars", 4))
+    beats_per_window = bars * 4
+    window_samples = max(1, int(round(beats_per_window * 60.0 / bpm * clip.sample_rate)))
+    beat_seconds = 60.0 / bpm
+    stride_beats = max(1, int(stride_beats))
+    max_windows = max(1, int(max_windows))
+
+    valid = [i for i in range(0, len(beat_samples), stride_beats) if beat_samples[i] < clip.waveform.numel()]
+    if not valid:
+        valid = [0]
+
+    if len(valid) > max_windows:
+        chosen_positions = np.linspace(0, len(valid) - 1, max_windows, dtype=np.int64)
+        valid = [valid[i] for i in sorted(set(chosen_positions.tolist()))]
+
+    windows = []
+    for beat_idx in valid:
+        start_sample = int(beat_samples[beat_idx]) if len(beat_samples) else 0
+        chunk = clip.waveform[start_sample:start_sample + window_samples]
+        if chunk.numel() < window_samples:
+            chunk = torch.nn.functional.pad(chunk, (0, window_samples - chunk.numel()))
+
+        start_sec = clip.offset_sec + start_sample / clip.sample_rate
+        end_sec = start_sec + window_samples / clip.sample_rate
+        beat_intervals = [
+            (start_sec + i * beat_seconds, start_sec + (i + 1) * beat_seconds)
+            for i in range(beats_per_window)
+        ]
+        windows.append(BeatWindow(chunk, start_sec, end_sec, beat_intervals))
+    return windows
+
+
+def _encode(model, mels: torch.Tensor, batch_size: int) -> torch.Tensor:
+    embs = []
+    for start in range(0, mels.shape[0], batch_size):
+        embs.append(model.encoder(mels[start:start + batch_size]))
+    return torch.cat(embs, dim=0)
+
+
+def _score_pairs(model, track_mels: torch.Tensor, source_mels: torch.Tensor, batch_size: int) -> torch.Tensor:
+    track_emb = _encode(model, track_mels, batch_size)
+    source_emb = _encode(model, source_mels, batch_size)
+    n_track, n_source = track_emb.shape[0], source_emb.shape[0]
+    scores = []
+
+    pair_indices = [(i, j) for i in range(n_track) for j in range(n_source)]
+    for start in range(0, len(pair_indices), batch_size):
+        chunk = pair_indices[start:start + batch_size]
+        ti = torch.tensor([p[0] for p in chunk], device=track_emb.device)
+        sj = torch.tensor([p[1] for p in chunk], device=track_emb.device)
+        t = track_emb.index_select(0, ti)
+        s = source_emb.index_select(0, sj)
+        combined = torch.cat([t, s, torch.abs(t - s), t * s], dim=-1)
+        logits = model.head(combined)
+        scores.append(torch.softmax(logits, dim=-1)[:, 1])
+
+    return torch.cat(scores).reshape(n_track, n_source)
+
+
+def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]:
+    intervals = []
+    for use, (start, end) in zip(mask.tolist(), window.beat_intervals):
+        if use:
+            intervals.append((start, min(end, max_end)))
+    return _merge_intervals(intervals)
+
+
+def _localize_match(
+    model,
+    track_mel: torch.Tensor,
+    source_mel: torch.Tensor,
+    track_window: BeatWindow,
+    source_window: BeatWindow,
+    track_clip: AudioClip,
+    source_clip: AudioClip,
+    threshold: float,
+    pair_head_loaded: bool,
+) -> tuple[list[tuple[float, float]], list[tuple[float, float]], str]:
+    if not pair_head_loaded:
+        return (
+            [(track_window.start_sec, min(track_window.end_sec, track_clip.offset_sec + track_clip.duration_sec))],
+            [(source_window.start_sec, min(source_window.end_sec, source_clip.offset_sec + source_clip.duration_sec))],
+            "The checkpoint does not include a trained pairwise beat head, so the highlight covers the best matching window.",
+        )
+
+    with torch.inference_mode():
+        pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
+
+    selected = pair_probs >= float(threshold)
+    if not selected.any():
+        top_k = min(6, pair_probs.size)
+        flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
+        selected = np.zeros_like(pair_probs, dtype=bool)
+        selected.reshape(-1)[flat] = True
+
+    track_mask = selected.any(axis=1)
+    source_mask = selected.any(axis=0)
+    track_regions = _intervals_from_mask(
+        track_mask,
+        track_window,
+        track_clip.offset_sec + track_clip.duration_sec,
+    )
+    source_regions = _intervals_from_mask(
+        source_mask,
+        source_window,
+        source_clip.offset_sec + source_clip.duration_sec,
+    )
+    return track_regions, source_regions, ""
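Review note on `_localize_match`: when no beat pair clears the threshold, the function falls back to the top-k pair probabilities, so a window already judged a match always gets some highlight. A plain-Python sketch of that selection follows. The name `select_beat_pairs` and the nested-list input are assumptions, and ties may be broken differently from `np.argpartition`.

```python
def select_beat_pairs(probs, threshold=0.55, fallback_k=6):
    """Return (track_mask, source_mask) as lists of booleans.

    `probs[i][j]` is assumed to be the probability that track beat i and
    source beat j belong to the sampled material.
    """
    n_rows, n_cols = len(probs), len(probs[0])
    cells = [(i, j) for i in range(n_rows) for j in range(n_cols)]
    selected = {(i, j) for i, j in cells if probs[i][j] >= threshold}
    if not selected:
        # Fallback: keep the k most probable pairs even though none pass.
        ranked = sorted(cells, key=lambda ij: probs[ij[0]][ij[1]], reverse=True)
        selected = set(ranked[:fallback_k])
    # A beat is highlighted if it takes part in any selected pair.
    track_mask = [any((i, j) in selected for j in range(n_cols)) for i in range(n_rows)]
    source_mask = [any((i, j) in selected for i in range(n_rows)) for j in range(n_cols)]
    return track_mask, source_mask
```

The row-wise and column-wise `any` reductions mirror `selected.any(axis=1)` and `selected.any(axis=0)` in the committed code.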
+
+
+def _draw_waveform(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str):
+    y = clip.waveform.detach().cpu().numpy()
+    n = len(y)
+    points = min(20000, n)
+    idx = np.linspace(0, n - 1, points, dtype=np.int64)
+    x = clip.offset_sec + idx / clip.sample_rate
+    ax.plot(x, y[idx], color="#111827", linewidth=0.55)
+    for start, end in regions:
+        ax.axvspan(start, end, color=color, alpha=0.28)
+    ax.set_title(title, loc="left", fontsize=10)
+    ax.set_ylabel("Amplitude")
+    ax.set_xlim(clip.offset_sec, clip.offset_sec + clip.duration_sec)
+    ax.set_ylim(-1.05, 1.05)
+    ax.grid(True, alpha=0.18)
+
+
+def _draw_mel(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str, matched: bool):
+    y = clip.waveform.detach().cpu().numpy().astype(np.float32)
+    mel = librosa.feature.melspectrogram(y=y, sr=clip.sample_rate, n_mels=N_MELS_VIZ, hop_length=MEL_HOP)
+    mel_db = librosa.power_to_db(mel, ref=np.max)
+
+    t_start = clip.offset_sec
+    t_end = clip.offset_sec + clip.duration_sec
+    f_max = clip.sample_rate / 2
+
+    ax.imshow(
+        mel_db,
+        aspect="auto",
+        origin="lower",
+        extent=[t_start, t_end, 0, f_max],
+        cmap="magma",
+        interpolation="nearest",
+    )
+    ax.set_title(title, loc="left", fontsize=10)
+    ax.set_ylabel("Frequency (Hz)")
+    ax.set_xlim(t_start, t_end)
+
+    if matched and regions:
+        for start, end in regions:
+            ax.axvspan(start, end, color=color, alpha=0.38, linewidth=0)
+    elif not matched:
+        ax.text(
+            0.5, 0.5, "No Match",
+            transform=ax.transAxes,
+            fontsize=18,
+            color="white",
+            ha="center",
+            va="center",
+            fontweight="bold",
+            bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65),
+        )
+
+
+def _plot_waveforms(
+    track_clip: AudioClip,
+    source_clip: AudioClip,
+    track_regions: list[tuple[float, float]],
+    source_regions: list[tuple[float, float]],
+    score: float | None,
+    matched: bool,
+) -> plt.Figure:
+    fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=False)
+    color = "#22c55e" if matched else "#f59e0b"
+    title_score = "unavailable" if score is None else f"{score:.3f}"
+    fig.suptitle(f"Best match score: {title_score}" if score is not None else "Waveform preview", fontsize=12)
+
+    _draw_waveform(axes[0], track_clip, track_regions, color, "Track / song audio")
+    _draw_waveform(axes[1], source_clip, source_regions, color, "Source sample audio")
+    axes[1].set_xlabel("Time in uploaded file (seconds)")
+    fig.tight_layout()
+    return fig
+
+
+def _plot_mels(
+    track_clip: AudioClip,
+    source_clip: AudioClip,
+    track_regions: list[tuple[float, float]],
+    source_regions: list[tuple[float, float]],
+    matched: bool,
+) -> plt.Figure:
+    fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
+    color = "#22c55e" if matched else "#f59e0b"
+
+    _draw_mel(axes[0], track_clip, track_regions, color, "Track mel spectrogram", matched)
+    _draw_mel(axes[1], source_clip, source_regions, color, "Source mel spectrogram", matched)
+    axes[1].set_xlabel("Time in uploaded file (seconds)")
+    fig.tight_layout()
+    return fig
| 444 |
+
|
| 445 |
+
|
| 446 |
+
def preview_waveforms(track_audio, source_audio):
|
| 447 |
+
if not track_audio or not source_audio:
|
| 448 |
+
return None, None
|
| 449 |
+
try:
|
| 450 |
+
track_clip = _load_audio(track_audio, 0.0, 120.0)
|
| 451 |
+
source_clip = _load_audio(source_audio, 0.0, 120.0)
|
| 452 |
+
wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
|
| 453 |
+
mfig = _plot_mels(track_clip, source_clip, [], [], False)
|
| 454 |
+
return wfig, mfig
|
| 455 |
+
except Exception:
|
| 456 |
+
return None, None
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def verify(
|
| 460 |
+
track_audio,
|
| 461 |
+
source_audio,
|
| 462 |
+
checkpoint_path,
|
| 463 |
+
match_threshold,
|
| 464 |
+
localization_threshold,
|
| 465 |
+
track_offset,
|
| 466 |
+
source_offset,
|
| 467 |
+
max_seconds,
|
| 468 |
+
stride_beats,
|
| 469 |
+
max_windows,
|
| 470 |
+
):
|
| 471 |
+
try:
|
| 472 |
+
track_clip = _load_audio(track_audio, track_offset, max_seconds)
|
| 473 |
+
source_clip = _load_audio(source_audio, source_offset, max_seconds)
|
| 474 |
+
except Exception as exc:
|
| 475 |
+
raise gr.Error(str(exc))
|
| 476 |
+
|
| 477 |
+
try:
|
| 478 |
+
loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
|
| 479 |
+
except Exception as exc:
|
| 480 |
+
wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
|
| 481 |
+
mfig = _plot_mels(track_clip, source_clip, [], [], False)
|
| 482 |
+
return f"Model could not be loaded: {exc}", wfig, mfig
|
| 483 |
+
|
| 484 |
+
model = loaded["model"]
|
| 485 |
+
args = loaded["args"]
|
| 486 |
+
device = loaded["device"]
|
| 487 |
+
batch_size = 8 if device.type == "cpu" else 32
|
| 488 |
+
|
| 489 |
+
track_bpm, track_beats = _estimate_beats(track_clip.waveform, track_clip.sample_rate)
|
| 490 |
+
source_bpm, source_beats = _estimate_beats(source_clip.waveform, source_clip.sample_rate)
|
| 491 |
+
track_windows = _make_windows(track_clip, track_bpm, track_beats, args, stride_beats, max_windows)
|
| 492 |
+
source_windows = _make_windows(source_clip, source_bpm, source_beats, args, stride_beats, max_windows)
|
| 493 |
+
|
| 494 |
+
track_mels = torch.stack([_to_mel(w.waveform, track_bpm, args) for w in track_windows]).to(device)
|
| 495 |
+
source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device)
|
| 496 |
+
|
| 497 |
+
with torch.inference_mode():
|
| 498 |
+
score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
|
| 499 |
+
best_flat = int(torch.argmax(score_matrix).item())
|
| 500 |
+
best_track = best_flat // score_matrix.shape[1]
|
| 501 |
+
best_source = best_flat % score_matrix.shape[1]
|
| 502 |
+
best_score = float(score_matrix[best_track, best_source].detach().cpu())
|
| 503 |
+
matched = best_score >= float(match_threshold)
|
| 504 |
+
|
| 505 |
+
track_regions, source_regions, note = _localize_match(
|
| 506 |
+
model,
|
| 507 |
+
track_mels[best_track:best_track + 1],
|
| 508 |
+
source_mels[best_source:best_source + 1],
|
| 509 |
+
track_windows[best_track],
|
| 510 |
+
source_windows[best_source],
|
| 511 |
+
track_clip,
|
| 512 |
+
source_clip,
|
| 513 |
+
localization_threshold,
|
| 514 |
+
loaded["pair_head_loaded"],
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
highlight_track = track_regions if matched else []
|
| 518 |
+
highlight_source = source_regions if matched else []
|
| 519 |
+
|
| 520 |
+
wfig = _plot_waveforms(track_clip, source_clip, highlight_track, highlight_source, best_score, matched)
|
| 521 |
+
mfig = _plot_mels(track_clip, source_clip, highlight_track, highlight_source, matched)
|
| 522 |
+
|
| 523 |
+
verdict = "Likely match" if matched else "No confident match"
|
| 524 |
+
details = [
|
| 525 |
+
f"**{verdict}**",
|
| 526 |
+
f"Score: `{best_score:.3f}` with threshold `{float(match_threshold):.2f}`.",
|
| 527 |
+
f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.",
|
| 528 |
+
f"Highlighted track section(s): {_format_intervals(highlight_track)}.",
|
| 529 |
+
f"Highlighted source section(s): {_format_intervals(highlight_source)}.",
|
| 530 |
+
f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
|
| 531 |
+
]
|
| 532 |
+
if note:
|
| 533 |
+
details.append(note)
|
| 534 |
+
if loaded["missing"]:
|
| 535 |
+
details.append(f"Missing checkpoint keys initialized at load time: `{len(loaded['missing'])}`.")
|
| 536 |
+
return "\n\n".join(details), wfig, mfig
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
with gr.Blocks(title="Sample Match Verifier") as demo:
|
| 540 |
+
gr.Markdown("# Sample Match Verifier")
|
| 541 |
+
gr.Markdown(
|
| 542 |
+
"Upload a track and a possible source sample. "
|
| 543 |
+
"Waveforms appear immediately on upload. "
|
| 544 |
+
"Click **Verify match** to run the model and highlight sampled sections."
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
with gr.Row():
|
| 548 |
+
track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
|
| 549 |
+
        source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])

    with gr.Accordion("Settings", open=False):
        checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
        with gr.Row():
            match_threshold = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Match threshold")
            localization_threshold = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Highlight threshold")
        with gr.Row():
            track_offset = gr.Number(value=0.0, label="Track start offset, seconds")
            source_offset = gr.Number(value=0.0, label="Source start offset, seconds")
        with gr.Row():
            max_seconds = gr.Slider(5, 180, value=60, step=5, label="Analyze duration per upload, seconds")
            stride_beats = gr.Slider(1, 16, value=4, step=1, label="Window stride, beats")
            max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")

    run = gr.Button("Verify match", variant="primary")
    result = gr.Markdown()

    waveform_plot = gr.Plot(label="Waveforms")
    mel_plot = gr.Plot(label="Mel Spectrograms")

    # Show waveforms as soon as both files are uploaded
    for audio_input in [track_audio, source_audio]:
        audio_input.change(
            preview_waveforms,
            inputs=[track_audio, source_audio],
            outputs=[waveform_plot, mel_plot],
        )

    run.click(
        verify,
        inputs=[
            track_audio,
            source_audio,
            checkpoint_path,
            match_threshold,
            localization_threshold,
            track_offset,
            source_offset,
            max_seconds,
            stride_beats,
            max_windows,
        ],
        outputs=[result, waveform_plot, mel_plot],
    )


if __name__ == "__main__":
    demo.queue(max_size=8).launch()
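The three sliders for analysis duration, window stride, and window cap together bound how many beat-aligned windows are scanned. The following is a minimal sketch of how such window start times could be derived, assuming a constant tempo and 4/4 time. The helper `window_starts` and all of its parameters are hypothetical; the actual scheduling inside `verify` is not shown in this part of the diff and may differ.

```python
def window_starts(tempo_bpm, max_seconds, stride_beats,
                  bars=4, beats_per_bar=4, max_windows=24):
    """Return start times (seconds) of beat-aligned windows within max_seconds.

    Hypothetical helper: assumes a constant tempo and 4/4 time.
    """
    if tempo_bpm <= 0:
        raise ValueError("tempo_bpm must be positive")
    beat_len = 60.0 / tempo_bpm                    # seconds per beat
    window_len = bars * beats_per_bar * beat_len   # e.g. 4 bars of 4 beats
    stride_len = stride_beats * beat_len           # hop between window starts
    starts = []
    index = 0
    while len(starts) < max_windows:
        start = index * stride_len
        if start + window_len > max_seconds:       # window would run past the cap
            break
        starts.append(start)
        index += 1
    return starts
```

With the slider defaults (60 seconds, stride of 4 beats, 24 windows) and a 120 BPM track, each window is 8 seconds long and the window cap, not the duration, is what limits the scan.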
|
model.py
ADDED
|
@@ -0,0 +1,316 @@
import glob
import importlib
import os

import torch
import torch.nn as nn
from dotenv import load_dotenv
from transformers import ASTModel, ASTConfig

load_dotenv()

AST_TIME_DIM = 1024
AST_FREQ_DIM = 128
# A bare os.environ["SSLAM_MODEL"] lookup raises KeyError at import time when the
# variable is unset, even if only the AST or CNN backbone is used. The fallback
# below is inferred from the hard-coded `ta012.SSLAM_pretrain` module path in
# SSLAMEncoder and is an assumption, not a documented default.
SSLAM_HF_REPO = os.environ.get("SSLAM_MODEL", "ta012/SSLAM_pretrain")
SSLAM_TIME_DIM = 1024
SSLAM_FREQ_DIM = 128
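Environment-driven settings such as `SSLAM_MODEL` and `AST_MODEL` can be resolved with fallbacks so that importing the module does not depend on a fully populated environment. A minimal sketch follows, using the defaults listed in the README's environment variable table; the function name `load_settings` and the returned dictionary keys are hypothetical and not part of the committed code.

```python
import os


def load_settings(environ=os.environ):
    """Resolve model settings from a mapping, falling back to README defaults.

    Hypothetical helper: an empty or whitespace-only value counts as unset.
    """
    def pick(name, default):
        value = environ.get(name, "").strip()
        return value or default

    return {
        "ast_model": pick("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593"),
        "backbone": pick("MODEL_BACKBONE", "ast"),
        "bars": int(pick("MODEL_BARS", "4")),
    }
```

Passing the mapping as an argument keeps the lookup testable: `load_settings({})` returns the defaults without touching the real process environment.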


class ASTEncoder(nn.Module):
    """Wraps ASTModel and returns the [CLS] token embedding."""

    def __init__(self, model_name: str, freeze: bool = True):
        super().__init__()
        self.ast = ASTModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
        # print(f"AST hidden size: {self.ast.config.hidden_size}")
        if freeze:
            for p in self.ast.parameters():
                p.requires_grad = False

    def unfreeze_last_n(self, n: int = 2):
        for block in self.ast.encoder.layer[-n:]:
            for p in block.parameters():
                p.requires_grad = True
        for p in self.ast.layernorm.parameters():
            p.requires_grad = True
        # trainable = sum(p.numel() for p in self.ast.parameters() if p.requires_grad)
        # print(f"unfroze {n} blocks, trainable params: {trainable:,}")

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] => [B, AST_TIME_DIM, AST_FREQ_DIM]"""
        x = mel.squeeze(1)
        T = x.shape[1]
        # print(f"input T={T}, target={AST_TIME_DIM}")
        if T < AST_TIME_DIM:
            pad = torch.zeros(x.shape[0], AST_TIME_DIM - T, x.shape[2], device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)
        elif T > AST_TIME_DIM:
            x = x[:, :AST_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        x = self._prep(mel)
        out = self.ast(input_values=x)
        # print(f"AST output shape: {out.last_hidden_state.shape}")
        return out.last_hidden_state[:, 0, :]
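The `_prep` step fixes the time axis to `AST_TIME_DIM` frames by zero-padding short inputs at the end and truncating long inputs from the start. The same rule can be sketched on plain nested lists, so it runs without PyTorch; `fit_time_axis` is a hypothetical name used only for illustration and works on a single spectrogram rather than a batch.

```python
def fit_time_axis(frames, target_len):
    """Pad with zero frames or truncate so the result has exactly target_len frames.

    frames: list of per-frame feature lists, shape [T][F].
    Hypothetical stand-in for the tensor logic; no batch dimension.
    """
    if not frames:
        raise ValueError("frames must contain at least one frame")
    n_features = len(frames[0])
    if len(frames) >= target_len:
        # Keep the first target_len frames, drop the rest.
        return [list(frame) for frame in frames[:target_len]]
    # Append all-zero frames until the target length is reached.
    padding = [[0.0] * n_features for _ in range(target_len - len(frames))]
    return [list(frame) for frame in frames] + padding
```

One consequence worth noting: anything past the first `target_len` frames is dropped silently, so a window longer than the target contributes only its opening portion to the embedding.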


class PairMaskHead(nn.Module):
    """Beat-by-beat pair matching head over two mel spectrograms."""

    def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((beats_per_window, n_mels))
        self.beat_proj = nn.Sequential(
            nn.Linear(n_mels, beat_dim),
            nn.GELU(),
            nn.Linear(beat_dim, beat_dim),
        )
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.zeros(()))

    def _beats(self, mel: torch.Tensor) -> torch.Tensor:
        # mel: [B, 1, T, F] -> [B, beats, F] -> [B, beats, beat_dim]
        x = self.pool(mel).squeeze(1)
        return torch.nn.functional.normalize(self.beat_proj(x), dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        t = self._beats(track_mel)
        o = self._beats(orig_mel)
        return torch.einsum("bih,bjh->bij", t, o) * self.logit_scale.exp() + self.bias
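The forward pass of `PairMaskHead` reduces to a scaled cosine-similarity matrix: every track beat embedding is compared with every source beat embedding, multiplied by `exp(logit_scale)`, and shifted by `bias`. A plain-Python sketch of that arithmetic for one pair of windows is shown below. The function names are hypothetical, the learned projection is omitted, and zero vectors are passed through unchanged, which differs slightly from the epsilon clamp used by `torch.nn.functional.normalize`.

```python
import math


def l2_normalize(vector):
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector] if norm > 0 else list(vector)


def pair_logits(track_beats, source_beats, logit_scale=1.0, bias=0.0):
    """Return logits[i][j] = exp(logit_scale) * cos(track_i, source_j) + bias."""
    track = [l2_normalize(v) for v in track_beats]
    source = [l2_normalize(v) for v in source_beats]
    scale = math.exp(logit_scale)
    return [
        [scale * sum(a * b for a, b in zip(t, s)) + bias for s in source]
        for t in track
    ]
```

Because both sides are unit vectors, each entry lies in `[-exp(logit_scale) + bias, exp(logit_scale) + bias]`, so the learned scale controls how sharply matched and unmatched beat pairs are separated.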


class SampleDetector(nn.Module):
    """Siamese AST encoder + interaction head for binary sample detection."""

    def __init__(
        self,
        # Default arguments are evaluated when the class is defined, so a bare
        # os.environ["AST_MODEL"] raises KeyError at import time if the variable
        # is unset. Fall back to the default documented in the README.
        model_name: str = os.environ.get("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593"),
        freeze_encoder: bool = True,
        dropout: float = 0.3,
        beats_per_window: int = 16,
        n_mels: int = 128,
    ):
        super().__init__()
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        H = self.encoder.ast.config.hidden_size
        self.head = nn.Sequential(
            nn.LayerNorm(4 * H),
            nn.Linear(4 * H, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)

    def unfreeze_encoder(self, n_blocks: int = 2):
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        t = self.encoder(track_mel)
        o = self.encoder(orig_mel)
        # print(f"embeddings: t={t.shape}, o={o.shape}")
        combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
        # print(f"combined shape: {combined.shape}")
        return self.head(combined)
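The classification head takes `4 * H` inputs because the two embeddings are concatenated with their element-wise absolute difference and element-wise product. A small plain-Python sketch of that feature construction for a single pair of embeddings follows; `combine_pair` is a hypothetical name, and the real model does this on batched tensors with `torch.cat`.

```python
def combine_pair(track_embedding, source_embedding):
    """Build the interaction features [t, o, |t - o|, t * o] for one pair."""
    if len(track_embedding) != len(source_embedding):
        raise ValueError("embeddings must have the same length")
    pairs = list(zip(track_embedding, source_embedding))
    difference = [abs(t - o) for t, o in pairs]   # symmetric distance signal
    product = [t * o for t, o in pairs]           # agreement signal
    return list(track_embedding) + list(source_embedding) + difference + product
```

The difference and product terms are symmetric in the two inputs, while the raw embeddings are not, so the head can still tell the track apart from the source.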


class ConvBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


class CNNEncoder(nn.Module):
    def __init__(self, embed_dim: int = 256):
        super().__init__()
        self.net = nn.Sequential(
            ConvBlock(1, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(256, embed_dim),
        )

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        return self.net(mel)


class CNNSampleDetector(nn.Module):
    """Drop-in CNN alternative to SampleDetector."""

    def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = CNNEncoder(embed_dim)
        self.head = nn.Sequential(
            nn.LayerNorm(4 * embed_dim),
            nn.Linear(4 * embed_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 64),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        t = self.encoder(track_mel)
        o = self.encoder(orig_mel)
        combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
        return self.head(combined)


class SSLAMEncoder(nn.Module):
    """Wraps the EAT (SSLAM) model and returns the CLS-like token embedding.

    Bypasses AutoModel.from_pretrained due to a transformers >= 5.5 incompatibility
    with EATModel's missing all_tied_weights_keys attribute.
    """

    def __init__(self, freeze: bool = True):
        super().__init__()
        from transformers import AutoConfig
        import safetensors.torch
        from huggingface_hub import hf_hub_download

        cfg = AutoConfig.from_pretrained(SSLAM_HF_REPO, trust_remote_code=True)
        self.hidden_size = cfg.embed_dim
        sha = cfg._commit_hash or self._find_sha()
        eat_mod = importlib.import_module(
            f"transformers_modules.ta012.SSLAM_pretrain.{sha}.eat_model"
        )
        self.eat = eat_mod.EAT(cfg)

        weights_path = hf_hub_download(SSLAM_HF_REPO, "model.safetensors")
        raw = safetensors.torch.load_file(weights_path)
        state = {k.removeprefix("model."): v for k, v in raw.items()}
        self.eat.load_state_dict(state, strict=True)
        if freeze:
            for p in self.eat.parameters():
                p.requires_grad = False

    @staticmethod
    def _find_sha() -> str:
        dirs = glob.glob(
            os.path.expanduser(
                f"~/.cache/huggingface/modules/transformers_modules/{SSLAM_HF_REPO}/*"
            )
        )
        dirs = [d for d in dirs if os.path.isdir(d)]
        if not dirs:
            raise RuntimeError("SSLAM modules not found in HF cache; run AutoConfig.from_pretrained first")
        return os.path.basename(sorted(dirs)[-1])

    def unfreeze_last_n(self, n: int):
        for block in self.eat.blocks[-n:]:
            for p in block.parameters():
                p.requires_grad = True

        for p in self.eat.pre_norm.parameters():
            p.requires_grad = True

    @staticmethod
    def _prep(mel: torch.Tensor) -> torch.Tensor:
        """mel: [B, 1, T, F] => [B, 1, SSLAM_TIME_DIM, SSLAM_FREQ_DIM]"""
        x = mel.float()
        T = x.shape[2]
        if T < SSLAM_TIME_DIM:
            pad = torch.zeros(x.shape[0], 1, SSLAM_TIME_DIM - T, x.shape[3],
                              device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=2)
        elif T > SSLAM_TIME_DIM:
            x = x[:, :, :SSLAM_TIME_DIM, :]
        return x

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        x = self._prep(mel)
        feats = self.eat.extract_features(x)
        # print(f"SSLAM features: {feats.shape}")  # should be [B, 1+patches, 768]
        return feats[:, 0, :]
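The `_find_sha` fallback picks the last directory name in sorted order. Commit hashes sort lexicographically, not chronologically, so when several revisions are cached this may not return the most recent one. One possible alternative, sketched below, selects the most recently modified subdirectory instead. The helper `newest_subdir` is hypothetical, and it assumes that directory modification times reflect download order, which is not guaranteed on every filesystem.

```python
import os


def newest_subdir(parent):
    """Return the name of the most recently modified subdirectory of parent.

    Hypothetical alternative to sorting hash-named directories by name.
    """
    subdirs = [entry.path for entry in os.scandir(parent) if entry.is_dir()]
    if not subdirs:
        raise RuntimeError(f"no subdirectories found in {parent}")
    return os.path.basename(max(subdirs, key=os.path.getmtime))
```

Pinning an explicit revision when loading the config would remove the ambiguity entirely; the modification-time heuristic is only a fallback for when no commit hash is available.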


class SSLAMSampleDetector(nn.Module):
    """SampleDetector using SSLAM/EAT encoder instead of AST."""

    def __init__(self, freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
        super().__init__()
        self.encoder = SSLAMEncoder(freeze=freeze_encoder)
        H = self.encoder.hidden_size
        self.head = nn.Sequential(
            nn.LayerNorm(4 * H),
            nn.Linear(4 * H, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 2),
        )
        self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)

    def unfreeze_encoder(self, n_blocks: int):
        self.encoder.unfreeze_last_n(n_blocks)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        t = self.encoder(track_mel)
        o = self.encoder(orig_mel)
        combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
        return self.head(combined)


class ContrastiveSampleDetector(nn.Module):
    """Siamese AST encoder + projection head trained with CosineEmbeddingLoss."""

    def __init__(
        self,
        # Evaluated at class-definition time; fall back to the README default so
        # importing this module does not require AST_MODEL to be set.
        model_name: str = os.environ.get("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593"),
        freeze_encoder: bool = True,
        proj_dim: int = 256,
        dropout: float = 0.2,
    ):
        super().__init__()
        self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
        H = self.encoder.ast.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(H, 512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, proj_dim),
        )

    def embed(self, mel: torch.Tensor) -> torch.Tensor:
        h = self.encoder(mel)
        # print(f"encoder output: {h.shape}, norm={h.norm(dim=-1).mean():.3f}")
        z = self.proj(h)
        return torch.nn.functional.normalize(z, dim=-1)

    def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> tuple:
        return self.embed(track_mel), self.embed(orig_mel)

    def similarity(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
        # Embeddings are unit-normalized, so the dot product is the cosine similarity.
        t, o = self.embed(track_mel), self.embed(orig_mel)
        return (t * o).sum(dim=-1)

    def unfreeze_encoder(self, n_blocks: int = 2):
        self.encoder.unfreeze_last_n(n_blocks)
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
gradio>=5.0
matplotlib>=3.8
torch>=2.5
torchaudio>=2.5
accelerate==1.13.0
python-dotenv==1.2.2
safetensors==0.7.0
audiomentations==0.43.1
av==17.0.0
huggingface-hub==1.10.1
librosa==0.11.0
numpy==2.4.4
scikit-learn==1.8.0
scipy==1.17.1
soundfile==0.13.1
transformers==5.5.4
yt-dlp==2026.3.17