dayngerous commited on
Commit
4242909
·
1 Parent(s): dde4389

Add Gradio app, model code, and deps — checkpoint downloads from dayngerous/whoSampledAST

Browse files
Files changed (4) hide show
  1. README.md +24 -7
  2. app.py +597 -0
  3. model.py +316 -0
  4. requirements.txt +17 -0
README.md CHANGED
@@ -1,13 +1,30 @@
1
  ---
2
- title: Sampled
3
- emoji: 🚀
4
- colorFrom: pink
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.13.0
8
  app_file: app.py
9
  pinned: false
10
- short_description: Detect if a sample is in another song
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Sample Match Verifier
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: "5.0"
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
  ---
12
 
13
+ # Sample Match Verifier
14
+
15
+ Upload a track and a possible source sample. Waveforms appear immediately on upload. Click **Verify match** to run the model — it scans beat-aligned windows, scores the best match, and highlights the predicted sampled sections on both the waveform and mel spectrogram. If no confident match is found, the mel spectrogram shows a **No Match** overlay.
16
+
17
+ ## Model checkpoint
18
+
19
+ Place your checkpoint at `models/best.pt` (committed via Git LFS) or set the `MODEL_CHECKPOINT` environment variable to its path. The app falls back to `checkpoints/best.pt` if `models/best.pt` is not found.
20
+
21
+ ## Environment variables
22
+
23
+ | Variable | Default | Description |
24
+ |---|---|---|
25
+ | `MODEL_CHECKPOINT` | `models/best.pt` | Path to the `.pt` checkpoint |
26
+ | `MODEL_BACKBONE` | `ast` | Backbone: `ast`, `sslam`, or `cnn` |
27
+ | `AST_MODEL` | `MIT/ast-finetuned-audioset-10-10-0.4593` | HuggingFace AST model ID |
28
+ | `MODEL_BARS` | `4` | Bars per analysis window |
29
+ | `MODEL_N_MELS` | `128` | Mel frequency bins |
30
+ | `APP_SAMPLE_RATE` | `16000` | Audio sample rate |
app.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from dataclasses import dataclass
4
+ from functools import lru_cache
5
+ from pathlib import Path
6
+
7
+ os.environ.setdefault("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593")
8
+ os.environ.setdefault("SSLAM_MODEL", "ta012/SSLAM_pretrain")
9
+
10
+ import gradio as gr
11
+ import librosa
12
+ import matplotlib
13
+ import numpy as np
14
+ import torch
15
+ import torchaudio.transforms as T
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ matplotlib.use("Agg")
19
+ import matplotlib.pyplot as plt
20
+
21
+ from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector
22
+
23
+
24
+ SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000"))
25
+ MODEL_REPO = os.environ.get("MODEL_REPO", "dayngerous/whoSampledAST")
26
+
27
+
28
+ def _resolve_checkpoint() -> str:
29
+ """Return local checkpoint path, downloading from HF Hub if needed."""
30
+ env_path = os.environ.get("MODEL_CHECKPOINT", "")
31
+ for p in [env_path, "models/best.pt", "checkpoints/best.pt", "checkpoints2/best.pt"]:
32
+ if p and Path(p).exists():
33
+ return p
34
+ try:
35
+ return hf_hub_download(repo_id=MODEL_REPO, filename="models/best.pt")
36
+ except Exception as exc:
37
+ raise FileNotFoundError(
38
+ f"No local checkpoint found and download from {MODEL_REPO} failed: {exc}"
39
+ )
40
+
41
+
42
+ def _resolve_meta() -> str:
43
+ """Return local test_indices.json path, downloading from HF Hub if needed."""
44
+ for p in ["models/test_indices.json", "checkpoints2/test_indices.json", "checkpoints/test_indices.json"]:
45
+ if Path(p).exists():
46
+ return p
47
+ try:
48
+ return hf_hub_download(repo_id=MODEL_REPO, filename="models/test_indices.json")
49
+ except Exception:
50
+ return ""
51
+
52
+
53
+ DEFAULT_CHECKPOINT = _resolve_checkpoint()
54
+ DEFAULT_META = DEFAULT_META or _resolve_meta()
55
+ TARGET_FRAMES_PER_BEAT = 50
56
+ N_FFT = 1024
57
+ MEL_HOP = 512
58
+ N_MELS_VIZ = 128
59
+
60
+
61
+ @dataclass
62
+ class AudioClip:
63
+ waveform: torch.Tensor
64
+ sample_rate: int
65
+ offset_sec: float
66
+ duration_sec: float
67
+
68
+
69
+ @dataclass
70
+ class BeatWindow:
71
+ waveform: torch.Tensor
72
+ start_sec: float
73
+ end_sec: float
74
+ beat_intervals: list[tuple[float, float]]
75
+
76
+
77
+ def _format_time(seconds: float) -> str:
78
+ seconds = max(0.0, float(seconds))
79
+ minutes = int(seconds // 60)
80
+ rem = seconds - minutes * 60
81
+ return f"{minutes}:{rem:04.1f}"
82
+
83
+
84
+ def _format_intervals(intervals: list[tuple[float, float]], limit: int = 4) -> str:
85
+ if not intervals:
86
+ return "none"
87
+ shown = ", ".join(f"{_format_time(a)}-{_format_time(b)}" for a, b in intervals[:limit])
88
+ if len(intervals) > limit:
89
+ shown += f", +{len(intervals) - limit} more"
90
+ return shown
91
+
92
+
93
+ def _merge_intervals(intervals: list[tuple[float, float]], gap: float = 0.05) -> list[tuple[float, float]]:
94
+ if not intervals:
95
+ return []
96
+ ordered = sorted((float(a), float(b)) for a, b in intervals if b > a)
97
+ merged = [ordered[0]]
98
+ for start, end in ordered[1:]:
99
+ prev_start, prev_end = merged[-1]
100
+ if start <= prev_end + gap:
101
+ merged[-1] = (prev_start, max(prev_end, end))
102
+ else:
103
+ merged.append((start, end))
104
+ return merged
105
+
106
+
107
+ def _load_args(checkpoint_path: Path) -> dict:
108
+ meta_path = Path(DEFAULT_META) if DEFAULT_META else checkpoint_path.parent / "test_indices.json"
109
+ args = {}
110
+ if meta_path.exists():
111
+ with open(meta_path) as f:
112
+ args = json.load(f).get("args", {})
113
+
114
+ args.setdefault("backbone", os.environ.get("MODEL_BACKBONE", "ast"))
115
+ args.setdefault("ast_model", os.environ.get("AST_MODEL"))
116
+ args.setdefault("bars", int(os.environ.get("MODEL_BARS", "4")))
117
+ args.setdefault("n_mels", int(os.environ.get("MODEL_N_MELS", "128")))
118
+ args.setdefault("sample_rate", SAMPLE_RATE)
119
+ return args
120
+
121
+
122
+ def _build_model(args: dict, device: torch.device):
123
+ beats_per_window = int(args.get("bars", 4)) * 4
124
+ n_mels = int(args.get("n_mels", 128))
125
+ backbone = args.get("backbone", "ast")
126
+ if backbone == "ast":
127
+ model = SampleDetector(
128
+ model_name=args.get("ast_model", os.environ["AST_MODEL"]),
129
+ freeze_encoder=True,
130
+ beats_per_window=beats_per_window,
131
+ n_mels=n_mels,
132
+ )
133
+ elif backbone == "sslam":
134
+ model = SSLAMSampleDetector(
135
+ freeze_encoder=True,
136
+ beats_per_window=beats_per_window,
137
+ n_mels=n_mels,
138
+ )
139
+ else:
140
+ model = CNNSampleDetector(beats_per_window=beats_per_window, n_mels=n_mels)
141
+ return model.to(device)
142
+
143
+
144
+ @lru_cache(maxsize=2)
145
+ def _load_model(checkpoint_path: str):
146
+ path = Path(checkpoint_path)
147
+ if not path.exists():
148
+ raise FileNotFoundError(
149
+ f"Checkpoint not found: {path}. Set MODEL_CHECKPOINT or place a checkpoint at models/best.pt."
150
+ )
151
+
152
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153
+ args = _load_args(path)
154
+ model = _build_model(args, device)
155
+ ckpt = torch.load(path, map_location=device)
156
+ state = ckpt.get("model_state", ckpt)
157
+ pair_head_loaded = any(k.startswith("pair_mask_head.") for k in state)
158
+ missing, unexpected = model.load_state_dict(state, strict=False)
159
+ model.eval()
160
+ return {
161
+ "model": model,
162
+ "args": args,
163
+ "device": device,
164
+ "epoch": ckpt.get("epoch", "?"),
165
+ "pair_head_loaded": pair_head_loaded,
166
+ "missing": missing,
167
+ "unexpected": unexpected,
168
+ }
169
+
170
+
171
+ def _load_audio(path: str, offset_sec: float, max_seconds: float) -> AudioClip:
172
+ if not path:
173
+ raise gr.Error("Upload both audio files before running verification.")
174
+
175
+ audio, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True)
176
+ waveform = torch.from_numpy(audio).float()
177
+
178
+ offset_sec = max(0.0, float(offset_sec or 0.0))
179
+ max_seconds = max(1.0, float(max_seconds or 1.0))
180
+ start = min(int(offset_sec * sr), max(waveform.numel() - 1, 0))
181
+ end = min(start + int(max_seconds * sr), waveform.numel())
182
+ waveform = waveform[start:end].float().contiguous()
183
+ if waveform.numel() < sr // 4:
184
+ raise gr.Error("Each upload must contain at least 0.25 seconds of audio after offset trimming.")
185
+
186
+ peak = waveform.abs().max().clamp_min(1e-6)
187
+ waveform = waveform / peak
188
+ return AudioClip(
189
+ waveform=waveform,
190
+ sample_rate=sr,
191
+ offset_sec=offset_sec,
192
+ duration_sec=waveform.numel() / sr,
193
+ )
194
+
195
+
196
+ def _estimate_beats(waveform: torch.Tensor, sample_rate: int) -> tuple[float, np.ndarray]:
197
+ y = waveform.detach().cpu().numpy().astype(np.float32)
198
+ tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sample_rate, hop_length=512)
199
+ bpm = float(np.atleast_1d(tempo)[0]) if np.size(tempo) else 120.0
200
+ if not np.isfinite(bpm) or bpm <= 0:
201
+ bpm = 120.0
202
+ bpm = float(np.clip(bpm, 60.0, 200.0))
203
+
204
+ beat_samples = librosa.frames_to_samples(beat_frames, hop_length=512)
205
+ beat_samples = beat_samples[(beat_samples >= 0) & (beat_samples < waveform.numel())]
206
+ if len(beat_samples) < 2:
207
+ step = max(1, int(round(sample_rate * 60.0 / bpm)))
208
+ beat_samples = np.arange(0, waveform.numel(), step, dtype=np.int64)
209
+ elif beat_samples[0] > sample_rate * 60.0 / bpm:
210
+ beat_samples = np.insert(beat_samples, 0, 0)
211
+ return bpm, beat_samples.astype(np.int64)
212
+
213
+
214
+ def _to_mel(waveform: torch.Tensor, bpm: float, args: dict) -> torch.Tensor:
215
+ sample_rate = int(args.get("sample_rate", SAMPLE_RATE))
216
+ n_mels = int(args.get("n_mels", 128))
217
+ bars = int(args.get("bars", 4))
218
+ fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
219
+ hop = max(1, round(60 * sample_rate / (bpm * TARGET_FRAMES_PER_BEAT)))
220
+
221
+ mel_transform = T.MelSpectrogram(
222
+ sample_rate=sample_rate,
223
+ n_fft=N_FFT,
224
+ hop_length=hop,
225
+ n_mels=n_mels,
226
+ power=2.0,
227
+ )
228
+ amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)
229
+ mel = amp_to_db(mel_transform(waveform)).T
230
+ if mel.shape[0] > fixed_frames:
231
+ mel = mel[:fixed_frames]
232
+ elif mel.shape[0] < fixed_frames:
233
+ mel = torch.cat([mel, torch.zeros(fixed_frames - mel.shape[0], mel.shape[1])], dim=0)
234
+ mel = (mel - mel.mean()) / (mel.std() + 1e-6)
235
+ return mel.unsqueeze(0)
236
+
237
+
238
+ def _make_windows(
239
+ clip: AudioClip,
240
+ bpm: float,
241
+ beat_samples: np.ndarray,
242
+ args: dict,
243
+ stride_beats: int,
244
+ max_windows: int,
245
+ ) -> list[BeatWindow]:
246
+ bars = int(args.get("bars", 4))
247
+ beats_per_window = bars * 4
248
+ window_samples = max(1, int(round(beats_per_window * 60.0 / bpm * clip.sample_rate)))
249
+ beat_seconds = 60.0 / bpm
250
+ stride_beats = max(1, int(stride_beats))
251
+ max_windows = max(1, int(max_windows))
252
+
253
+ valid = [i for i in range(0, len(beat_samples), stride_beats) if beat_samples[i] < clip.waveform.numel()]
254
+ if not valid:
255
+ valid = [0]
256
+
257
+ if len(valid) > max_windows:
258
+ chosen_positions = np.linspace(0, len(valid) - 1, max_windows, dtype=np.int64)
259
+ valid = [valid[i] for i in sorted(set(chosen_positions.tolist()))]
260
+
261
+ windows = []
262
+ for beat_idx in valid:
263
+ start_sample = int(beat_samples[beat_idx]) if len(beat_samples) else 0
264
+ chunk = clip.waveform[start_sample:start_sample + window_samples]
265
+ if chunk.numel() < window_samples:
266
+ chunk = torch.nn.functional.pad(chunk, (0, window_samples - chunk.numel()))
267
+
268
+ start_sec = clip.offset_sec + start_sample / clip.sample_rate
269
+ end_sec = start_sec + window_samples / clip.sample_rate
270
+ beat_intervals = [
271
+ (start_sec + i * beat_seconds, start_sec + (i + 1) * beat_seconds)
272
+ for i in range(beats_per_window)
273
+ ]
274
+ windows.append(BeatWindow(chunk, start_sec, end_sec, beat_intervals))
275
+ return windows
276
+
277
+
278
+ def _encode(model, mels: torch.Tensor, batch_size: int) -> torch.Tensor:
279
+ embs = []
280
+ for start in range(0, mels.shape[0], batch_size):
281
+ embs.append(model.encoder(mels[start:start + batch_size]))
282
+ return torch.cat(embs, dim=0)
283
+
284
+
285
+ def _score_pairs(model, track_mels: torch.Tensor, source_mels: torch.Tensor, batch_size: int) -> torch.Tensor:
286
+ track_emb = _encode(model, track_mels, batch_size)
287
+ source_emb = _encode(model, source_mels, batch_size)
288
+ n_track, n_source = track_emb.shape[0], source_emb.shape[0]
289
+ scores = []
290
+
291
+ pair_indices = [(i, j) for i in range(n_track) for j in range(n_source)]
292
+ for start in range(0, len(pair_indices), batch_size):
293
+ chunk = pair_indices[start:start + batch_size]
294
+ ti = torch.tensor([p[0] for p in chunk], device=track_emb.device)
295
+ sj = torch.tensor([p[1] for p in chunk], device=track_emb.device)
296
+ t = track_emb.index_select(0, ti)
297
+ s = source_emb.index_select(0, sj)
298
+ combined = torch.cat([t, s, torch.abs(t - s), t * s], dim=-1)
299
+ logits = model.head(combined)
300
+ scores.append(torch.softmax(logits, dim=-1)[:, 1])
301
+
302
+ return torch.cat(scores).reshape(n_track, n_source)
303
+
304
+
305
+ def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]:
306
+ intervals = []
307
+ for use, (start, end) in zip(mask.tolist(), window.beat_intervals):
308
+ if use:
309
+ intervals.append((start, min(end, max_end)))
310
+ return _merge_intervals(intervals)
311
+
312
+
313
+ def _localize_match(
314
+ model,
315
+ track_mel: torch.Tensor,
316
+ source_mel: torch.Tensor,
317
+ track_window: BeatWindow,
318
+ source_window: BeatWindow,
319
+ track_clip: AudioClip,
320
+ source_clip: AudioClip,
321
+ threshold: float,
322
+ pair_head_loaded: bool,
323
+ ) -> tuple[list[tuple[float, float]], list[tuple[float, float]], str]:
324
+ if not pair_head_loaded:
325
+ return (
326
+ [(track_window.start_sec, min(track_window.end_sec, track_clip.offset_sec + track_clip.duration_sec))],
327
+ [(source_window.start_sec, min(source_window.end_sec, source_clip.offset_sec + source_clip.duration_sec))],
328
+ "The checkpoint does not include a trained pairwise beat head, so the highlight covers the best matching window.",
329
+ )
330
+
331
+ with torch.inference_mode():
332
+ pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
333
+
334
+ selected = pair_probs >= float(threshold)
335
+ if not selected.any():
336
+ top_k = min(6, pair_probs.size)
337
+ flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
338
+ selected = np.zeros_like(pair_probs, dtype=bool)
339
+ selected.reshape(-1)[flat] = True
340
+
341
+ track_mask = selected.any(axis=1)
342
+ source_mask = selected.any(axis=0)
343
+ track_regions = _intervals_from_mask(
344
+ track_mask,
345
+ track_window,
346
+ track_clip.offset_sec + track_clip.duration_sec,
347
+ )
348
+ source_regions = _intervals_from_mask(
349
+ source_mask,
350
+ source_window,
351
+ source_clip.offset_sec + source_clip.duration_sec,
352
+ )
353
+ return track_regions, source_regions, ""
354
+
355
+
356
+ def _draw_waveform(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str):
357
+ y = clip.waveform.detach().cpu().numpy()
358
+ n = len(y)
359
+ points = min(20000, n)
360
+ idx = np.linspace(0, n - 1, points, dtype=np.int64)
361
+ x = clip.offset_sec + idx / clip.sample_rate
362
+ ax.plot(x, y[idx], color="#111827", linewidth=0.55)
363
+ for start, end in regions:
364
+ ax.axvspan(start, end, color=color, alpha=0.28)
365
+ ax.set_title(title, loc="left", fontsize=10)
366
+ ax.set_ylabel("Amplitude")
367
+ ax.set_xlim(clip.offset_sec, clip.offset_sec + clip.duration_sec)
368
+ ax.set_ylim(-1.05, 1.05)
369
+ ax.grid(True, alpha=0.18)
370
+
371
+
372
+ def _draw_mel(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str, matched: bool):
373
+ y = clip.waveform.detach().cpu().numpy().astype(np.float32)
374
+ mel = librosa.feature.melspectrogram(y=y, sr=clip.sample_rate, n_mels=N_MELS_VIZ, hop_length=MEL_HOP)
375
+ mel_db = librosa.power_to_db(mel, ref=np.max)
376
+
377
+ t_start = clip.offset_sec
378
+ t_end = clip.offset_sec + clip.duration_sec
379
+ f_max = clip.sample_rate / 2
380
+
381
+ ax.imshow(
382
+ mel_db,
383
+ aspect="auto",
384
+ origin="lower",
385
+ extent=[t_start, t_end, 0, f_max],
386
+ cmap="magma",
387
+ interpolation="nearest",
388
+ )
389
+ ax.set_title(title, loc="left", fontsize=10)
390
+ ax.set_ylabel("Frequency (Hz)")
391
+ ax.set_xlim(t_start, t_end)
392
+
393
+ if matched and regions:
394
+ for start, end in regions:
395
+ ax.axvspan(start, end, color=color, alpha=0.38, linewidth=0)
396
+ elif not matched:
397
+ ax.text(
398
+ 0.5, 0.5, "No Match",
399
+ transform=ax.transAxes,
400
+ fontsize=18,
401
+ color="white",
402
+ ha="center",
403
+ va="center",
404
+ fontweight="bold",
405
+ bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65),
406
+ )
407
+
408
+
409
+ def _plot_waveforms(
410
+ track_clip: AudioClip,
411
+ source_clip: AudioClip,
412
+ track_regions: list[tuple[float, float]],
413
+ source_regions: list[tuple[float, float]],
414
+ score: float | None,
415
+ matched: bool,
416
+ ) -> plt.Figure:
417
+ fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=False)
418
+ color = "#22c55e" if matched else "#f59e0b"
419
+ title_score = "unavailable" if score is None else f"{score:.3f}"
420
+ fig.suptitle(f"Best match score: {title_score}" if score is not None else "Waveform preview", fontsize=12)
421
+
422
+ _draw_waveform(axes[0], track_clip, track_regions, color, "Track / song audio")
423
+ _draw_waveform(axes[1], source_clip, source_regions, color, "Source sample audio")
424
+ axes[1].set_xlabel("Time in uploaded file (seconds)")
425
+ fig.tight_layout()
426
+ return fig
427
+
428
+
429
+ def _plot_mels(
430
+ track_clip: AudioClip,
431
+ source_clip: AudioClip,
432
+ track_regions: list[tuple[float, float]],
433
+ source_regions: list[tuple[float, float]],
434
+ matched: bool,
435
+ ) -> plt.Figure:
436
+ fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
437
+ color = "#22c55e" if matched else "#f59e0b"
438
+
439
+ _draw_mel(axes[0], track_clip, track_regions, color, "Track mel spectrogram", matched)
440
+ _draw_mel(axes[1], source_clip, source_regions, color, "Source mel spectrogram", matched)
441
+ axes[1].set_xlabel("Time in uploaded file (seconds)")
442
+ fig.tight_layout()
443
+ return fig
444
+
445
+
446
+ def preview_waveforms(track_audio, source_audio):
447
+ if not track_audio or not source_audio:
448
+ return None, None
449
+ try:
450
+ track_clip = _load_audio(track_audio, 0.0, 120.0)
451
+ source_clip = _load_audio(source_audio, 0.0, 120.0)
452
+ wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
453
+ mfig = _plot_mels(track_clip, source_clip, [], [], False)
454
+ return wfig, mfig
455
+ except Exception:
456
+ return None, None
457
+
458
+
459
+ def verify(
460
+ track_audio,
461
+ source_audio,
462
+ checkpoint_path,
463
+ match_threshold,
464
+ localization_threshold,
465
+ track_offset,
466
+ source_offset,
467
+ max_seconds,
468
+ stride_beats,
469
+ max_windows,
470
+ ):
471
+ try:
472
+ track_clip = _load_audio(track_audio, track_offset, max_seconds)
473
+ source_clip = _load_audio(source_audio, source_offset, max_seconds)
474
+ except Exception as exc:
475
+ raise gr.Error(str(exc))
476
+
477
+ try:
478
+ loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
479
+ except Exception as exc:
480
+ wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
481
+ mfig = _plot_mels(track_clip, source_clip, [], [], False)
482
+ return f"Model could not be loaded: {exc}", wfig, mfig
483
+
484
+ model = loaded["model"]
485
+ args = loaded["args"]
486
+ device = loaded["device"]
487
+ batch_size = 8 if device.type == "cpu" else 32
488
+
489
+ track_bpm, track_beats = _estimate_beats(track_clip.waveform, track_clip.sample_rate)
490
+ source_bpm, source_beats = _estimate_beats(source_clip.waveform, source_clip.sample_rate)
491
+ track_windows = _make_windows(track_clip, track_bpm, track_beats, args, stride_beats, max_windows)
492
+ source_windows = _make_windows(source_clip, source_bpm, source_beats, args, stride_beats, max_windows)
493
+
494
+ track_mels = torch.stack([_to_mel(w.waveform, track_bpm, args) for w in track_windows]).to(device)
495
+ source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device)
496
+
497
+ with torch.inference_mode():
498
+ score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
499
+ best_flat = int(torch.argmax(score_matrix).item())
500
+ best_track = best_flat // score_matrix.shape[1]
501
+ best_source = best_flat % score_matrix.shape[1]
502
+ best_score = float(score_matrix[best_track, best_source].detach().cpu())
503
+ matched = best_score >= float(match_threshold)
504
+
505
+ track_regions, source_regions, note = _localize_match(
506
+ model,
507
+ track_mels[best_track:best_track + 1],
508
+ source_mels[best_source:best_source + 1],
509
+ track_windows[best_track],
510
+ source_windows[best_source],
511
+ track_clip,
512
+ source_clip,
513
+ localization_threshold,
514
+ loaded["pair_head_loaded"],
515
+ )
516
+
517
+ highlight_track = track_regions if matched else []
518
+ highlight_source = source_regions if matched else []
519
+
520
+ wfig = _plot_waveforms(track_clip, source_clip, highlight_track, highlight_source, best_score, matched)
521
+ mfig = _plot_mels(track_clip, source_clip, highlight_track, highlight_source, matched)
522
+
523
+ verdict = "Likely match" if matched else "No confident match"
524
+ details = [
525
+ f"**{verdict}**",
526
+ f"Score: `{best_score:.3f}` with threshold `{float(match_threshold):.2f}`.",
527
+ f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.",
528
+ f"Highlighted track section(s): {_format_intervals(highlight_track)}.",
529
+ f"Highlighted source section(s): {_format_intervals(highlight_source)}.",
530
+ f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
531
+ ]
532
+ if note:
533
+ details.append(note)
534
+ if loaded["missing"]:
535
+ details.append(f"Missing checkpoint keys initialized at load time: `{len(loaded['missing'])}`.")
536
+ return "\n\n".join(details), wfig, mfig
537
+
538
+
539
+ with gr.Blocks(title="Sample Match Verifier") as demo:
540
+ gr.Markdown("# Sample Match Verifier")
541
+ gr.Markdown(
542
+ "Upload a track and a possible source sample. "
543
+ "Waveforms appear immediately on upload. "
544
+ "Click **Verify match** to run the model and highlight sampled sections."
545
+ )
546
+
547
+ with gr.Row():
548
+ track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
549
+ source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])
550
+
551
+ with gr.Accordion("Settings", open=False):
552
+ checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
553
+ with gr.Row():
554
+ match_threshold = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Match threshold")
555
+ localization_threshold = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Highlight threshold")
556
+ with gr.Row():
557
+ track_offset = gr.Number(value=0.0, label="Track start offset, seconds")
558
+ source_offset = gr.Number(value=0.0, label="Source start offset, seconds")
559
+ with gr.Row():
560
+ max_seconds = gr.Slider(5, 180, value=60, step=5, label="Analyze duration per upload, seconds")
561
+ stride_beats = gr.Slider(1, 16, value=4, step=1, label="Window stride, beats")
562
+ max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
563
+
564
+ run = gr.Button("Verify match", variant="primary")
565
+ result = gr.Markdown()
566
+
567
+ waveform_plot = gr.Plot(label="Waveforms")
568
+ mel_plot = gr.Plot(label="Mel Spectrograms")
569
+
570
+ # Show waveforms as soon as both files are uploaded
571
+ for audio_input in [track_audio, source_audio]:
572
+ audio_input.change(
573
+ preview_waveforms,
574
+ inputs=[track_audio, source_audio],
575
+ outputs=[waveform_plot, mel_plot],
576
+ )
577
+
578
+ run.click(
579
+ verify,
580
+ inputs=[
581
+ track_audio,
582
+ source_audio,
583
+ checkpoint_path,
584
+ match_threshold,
585
+ localization_threshold,
586
+ track_offset,
587
+ source_offset,
588
+ max_seconds,
589
+ stride_beats,
590
+ max_windows,
591
+ ],
592
+ outputs=[result, waveform_plot, mel_plot],
593
+ )
594
+
595
+
596
+ if __name__ == "__main__":
597
+ demo.queue(max_size=8).launch()
model.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import importlib
3
+ import os
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from dotenv import load_dotenv
8
+ from transformers import ASTModel, ASTConfig
9
+
10
+ load_dotenv()
11
+
12
+ AST_TIME_DIM = 1024
13
+ AST_FREQ_DIM = 128
14
+ SSLAM_HF_REPO = os.environ["SSLAM_MODEL"]
15
+ SSLAM_TIME_DIM = 1024
16
+ SSLAM_FREQ_DIM = 128
17
+
18
+
19
+ class ASTEncoder(nn.Module):
20
+ """Wraps ASTModel and returns the [CLS] token embedding."""
21
+
22
+ def __init__(self, model_name: str, freeze: bool = True):
23
+ super().__init__()
24
+ self.ast = ASTModel.from_pretrained(model_name, ignore_mismatched_sizes=True)
25
+ # print(f"AST hidden size: {self.ast.config.hidden_size}")
26
+ if freeze:
27
+ for p in self.ast.parameters():
28
+ p.requires_grad = False
29
+
30
+ def unfreeze_last_n(self, n: int = 2):
31
+ for block in self.ast.encoder.layer[-n:]:
32
+ for p in block.parameters():
33
+ p.requires_grad = True
34
+ for p in self.ast.layernorm.parameters():
35
+ p.requires_grad = True
36
+ # trainable = sum(p.numel() for p in self.ast.parameters() if p.requires_grad)
37
+ # print(f"unfroze {n} blocks, trainable params: {trainable:,}")
38
+
39
+
40
+
41
+ @staticmethod
42
+ def _prep(mel: torch.Tensor) -> torch.Tensor:
43
+ """mel: [B, 1, T, F] => [B, AST_TIME_DIM, AST_FREQ_DIM]"""
44
+ x = mel.squeeze(1)
45
+ T = x.shape[1]
46
+ # print(f"input T={T}, target={AST_TIME_DIM}")
47
+ if T < AST_TIME_DIM:
48
+ pad = torch.zeros(x.shape[0], AST_TIME_DIM - T, x.shape[2], device=x.device, dtype=x.dtype)
49
+ x = torch.cat([x, pad], dim=1)
50
+ elif T > AST_TIME_DIM:
51
+ x = x[:, :AST_TIME_DIM, :]
52
+ return x
53
+
54
+ def forward(self, mel: torch.Tensor) -> torch.Tensor:
55
+ x = self._prep(mel)
56
+ out = self.ast(input_values=x)
57
+ # print(f"AST output shape: {out.last_hidden_state.shape}")
58
+ return out.last_hidden_state[:, 0, :]
59
+
60
+
61
+ class PairMaskHead(nn.Module):
62
+ """Beat-by-beat pair matching head over two mel spectrograms."""
63
+
64
+ def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64):
65
+ super().__init__()
66
+ self.pool = nn.AdaptiveAvgPool2d((beats_per_window, n_mels))
67
+ self.beat_proj = nn.Sequential(
68
+ nn.Linear(n_mels, beat_dim),
69
+ nn.GELU(),
70
+ nn.Linear(beat_dim, beat_dim),
71
+ )
72
+ self.logit_scale = nn.Parameter(torch.tensor(1.0))
73
+ self.bias = nn.Parameter(torch.zeros(()))
74
+
75
+ def _beats(self, mel: torch.Tensor) -> torch.Tensor:
76
+ # mel: [B, 1, T, F] -> [B, beats, F] -> [B, beats, beat_dim]
77
+ x = self.pool(mel).squeeze(1)
78
+ return torch.nn.functional.normalize(self.beat_proj(x), dim=-1)
79
+
80
+ def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
81
+ t = self._beats(track_mel)
82
+ o = self._beats(orig_mel)
83
+ return torch.einsum("bih,bjh->bij", t, o) * self.logit_scale.exp() + self.bias
84
+
85
+
86
+ class SampleDetector(nn.Module):
87
+ """Siamese AST encoder + interaction head for binary sample detection."""
88
+
89
+ def __init__(
90
+ self,
91
+ model_name: str = os.environ["AST_MODEL"],
92
+ freeze_encoder: bool = True,
93
+ dropout: float = 0.3,
94
+ beats_per_window: int = 16,
95
+ n_mels: int = 128,
96
+ ):
97
+ super().__init__()
98
+ self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
99
+ H = self.encoder.ast.config.hidden_size
100
+ self.head = nn.Sequential(
101
+ nn.LayerNorm(4 * H),
102
+ nn.Linear(4 * H, 512),
103
+ nn.GELU(),
104
+ nn.Dropout(dropout),
105
+ nn.Linear(512, 128),
106
+ nn.GELU(),
107
+ nn.Dropout(dropout),
108
+ nn.Linear(128, 2),
109
+ )
110
+ self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
111
+
112
+ def unfreeze_encoder(self, n_blocks: int = 2):
113
+ self.encoder.unfreeze_last_n(n_blocks)
114
+
115
+ def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
116
+ t = self.encoder(track_mel)
117
+ o = self.encoder(orig_mel)
118
+ # print(f"embeddings: t={t.shape}, o={o.shape}")
119
+ combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
120
+ # print(f"combined shape: {combined.shape}")
121
+ return self.head(combined)
122
+
123
+
124
+ class ConvBlock(nn.Module):
125
+ def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
126
+ super().__init__()
127
+ self.block = nn.Sequential(
128
+ nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
129
+ nn.BatchNorm2d(out_ch),
130
+ nn.GELU(),
131
+ nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
132
+ nn.BatchNorm2d(out_ch),
133
+ nn.GELU(),
134
+ )
135
+
136
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
137
+ return self.block(x)
138
+
139
+
140
+ class CNNEncoder(nn.Module):
141
+ def __init__(self, embed_dim: int = 256):
142
+ super().__init__()
143
+ self.net = nn.Sequential(
144
+ ConvBlock(1, 32),
145
+ ConvBlock(32, 64),
146
+ ConvBlock(64, 128),
147
+ ConvBlock(128, 256),
148
+ nn.AdaptiveAvgPool2d(1),
149
+ nn.Flatten(),
150
+ nn.Linear(256, embed_dim),
151
+ )
152
+
153
+ def forward(self, mel: torch.Tensor) -> torch.Tensor:
154
+ return self.net(mel)
155
+
156
+
157
+ class CNNSampleDetector(nn.Module):
158
+ """Drop-in CNN alternative to SampleDetector."""
159
+
160
+ def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
161
+ super().__init__()
162
+ self.encoder = CNNEncoder(embed_dim)
163
+ self.head = nn.Sequential(
164
+ nn.LayerNorm(4 * embed_dim),
165
+ nn.Linear(4 * embed_dim, 256),
166
+ nn.GELU(),
167
+ nn.Dropout(dropout),
168
+ nn.Linear(256, 64),
169
+ nn.GELU(),
170
+ nn.Dropout(dropout),
171
+ nn.Linear(64, 2),
172
+ )
173
+ self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
174
+
175
+ def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
176
+ t = self.encoder(track_mel)
177
+ o = self.encoder(orig_mel)
178
+ combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
179
+ return self.head(combined)
180
+
181
+
182
+
183
+ class SSLAMEncoder(nn.Module):
184
+ """Wraps the EAT (SSLAM) model and returns the CLS-like token embedding.
185
+
186
+ Bypasses AutoModel.from_pretrained due to a transformers >= 5.5 incompatibility
187
+ with EATModel's missing all_tied_weights_keys attribute.
188
+ """
189
+
190
+ def __init__(self, freeze: bool = True):
191
+ super().__init__()
192
+ from transformers import AutoConfig
193
+ import safetensors.torch
194
+ from huggingface_hub import hf_hub_download
195
+
196
+ cfg = AutoConfig.from_pretrained(SSLAM_HF_REPO, trust_remote_code=True)
197
+ self.hidden_size = cfg.embed_dim
198
+ sha = cfg._commit_hash or self._find_sha()
199
+ eat_mod = importlib.import_module(
200
+ f"transformers_modules.ta012.SSLAM_pretrain.{sha}.eat_model"
201
+ )
202
+ self.eat = eat_mod.EAT(cfg)
203
+
204
+ weights_path = hf_hub_download(SSLAM_HF_REPO, "model.safetensors")
205
+ raw = safetensors.torch.load_file(weights_path)
206
+ state = {k.removeprefix("model."): v for k, v in raw.items()}
207
+ self.eat.load_state_dict(state, strict=True)
208
+ if freeze:
209
+ for p in self.eat.parameters():
210
+ p.requires_grad = False
211
+
212
+ @staticmethod
213
+ def _find_sha() -> str:
214
+ dirs = glob.glob(
215
+ os.path.expanduser(
216
+ f"~/.cache/huggingface/modules/transformers_modules/{SSLAM_HF_REPO}/*"
217
+ )
218
+ )
219
+ dirs = [d for d in dirs if os.path.isdir(d)]
220
+ if not dirs:
221
+ raise RuntimeError("SSLAM modules not found in HF cache — run AutoConfig.from_pretrained first")
222
+ return os.path.basename(sorted(dirs)[-1])
223
+
224
+ def unfreeze_last_n(self, n: int):
225
+ for block in self.eat.blocks[-n:]:
226
+ for p in block.parameters():
227
+ p.requires_grad = True
228
+
229
+ for p in self.eat.pre_norm.parameters():
230
+ p.requires_grad = True
231
+
232
+ @staticmethod
233
+ def _prep(mel: torch.Tensor) -> torch.Tensor:
234
+ """mel: [B, 1, T, F] => [B, 1, SSLAM_TIME_DIM, SSLAM_FREQ_DIM]"""
235
+ x = mel.float()
236
+ T = x.shape[2]
237
+ if T < SSLAM_TIME_DIM:
238
+ pad = torch.zeros(x.shape[0], 1, SSLAM_TIME_DIM - T, x.shape[3],
239
+ device=x.device, dtype=x.dtype)
240
+ x = torch.cat([x, pad], dim=2)
241
+ elif T > SSLAM_TIME_DIM:
242
+ x = x[:, :, :SSLAM_TIME_DIM, :]
243
+ return x
244
+
245
+ def forward(self, mel: torch.Tensor) -> torch.Tensor:
246
+ x = self._prep(mel)
247
+ feats = self.eat.extract_features(x)
248
+ # print(f"SSLAM features: {feats.shape}") # should be [B, 1+patches, 768]
249
+ return feats[:, 0, :]
250
+
251
+
252
+
253
+ class SSLAMSampleDetector(nn.Module):
254
+ """SampleDetector using SSLAM/EAT encoder instead of AST."""
255
+
256
+ def __init__(self, freeze_encoder: bool = True, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
257
+ super().__init__()
258
+ self.encoder = SSLAMEncoder(freeze=freeze_encoder)
259
+ H = self.encoder.hidden_size
260
+ self.head = nn.Sequential(
261
+ nn.LayerNorm(4 * H),
262
+ nn.Linear(4 * H, 512),
263
+ nn.GELU(),
264
+ nn.Dropout(dropout),
265
+ nn.Linear(512, 128),
266
+ nn.GELU(),
267
+ nn.Dropout(dropout),
268
+ nn.Linear(128, 2),
269
+ )
270
+ self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
271
+
272
+ def unfreeze_encoder(self, n_blocks: int):
273
+ self.encoder.unfreeze_last_n(n_blocks)
274
+
275
+ def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
276
+ t = self.encoder(track_mel)
277
+ o = self.encoder(orig_mel)
278
+ combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
279
+ return self.head(combined)
280
+
281
+
282
+ class ContrastiveSampleDetector(nn.Module):
283
+ """Siamese AST encoder + projection head trained with CosineEmbeddingLoss."""
284
+
285
+ def __init__(
286
+ self,
287
+ model_name: str = os.environ["AST_MODEL"],
288
+ freeze_encoder: bool = True,
289
+ proj_dim: int = 256,
290
+ dropout: float = 0.2,
291
+ ):
292
+ super().__init__()
293
+ self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
294
+ H = self.encoder.ast.config.hidden_size
295
+ self.proj = nn.Sequential(
296
+ nn.Linear(H, 512),
297
+ nn.GELU(),
298
+ nn.Dropout(dropout),
299
+ nn.Linear(512, proj_dim),
300
+ )
301
+
302
+ def embed(self, mel: torch.Tensor) -> torch.Tensor:
303
+ h = self.encoder(mel)
304
+ # print(f"encoder output: {h.shape}, norm={h.norm(dim=-1).mean():.3f}")
305
+ z = self.proj(h)
306
+ return torch.nn.functional.normalize(z, dim=-1)
307
+
308
+ def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> tuple:
309
+ return self.embed(track_mel), self.embed(orig_mel)
310
+
311
+ def similarity(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
312
+ t, o = self.embed(track_mel), self.embed(orig_mel)
313
+ return (t * o).sum(dim=-1)
314
+
315
+ def unfreeze_encoder(self, n_blocks: int = 2):
316
+ self.encoder.unfreeze_last_n(n_blocks)
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=5.0
2
+ matplotlib>=3.8
3
+ torch>=2.5
4
+ torchaudio>=2.5
5
+ accelerate==1.13.0
6
+ python-dotenv==1.2.2
7
+ safetensors==0.7.0
8
+ audiomentations==0.43.1
9
+ av==17.0.0
10
+ huggingface-hub==1.10.1
11
+ librosa==0.11.0
12
+ numpy==2.4.4
13
+ scikit-learn==1.8.0
14
+ scipy==1.17.1
15
+ soundfile==0.13.1
16
+ transformers==5.5.4
17
+ yt-dlp==2026.3.17