import json
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
os.environ.setdefault("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593")
os.environ.setdefault("SSLAM_MODEL", "ta012/SSLAM_pretrain")
import gradio as gr
import librosa
import matplotlib
import numpy as np
import torch
import torchaudio.transforms as T
from huggingface_hub import hf_hub_download
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector, pair_summary_features
SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000"))
MODEL_REPO = os.environ.get("MODEL_REPO", "dayngerous/whoSampledAST")
def _resolve_checkpoint() -> str:
"""Return local checkpoint path, downloading from HF Hub if needed."""
env_path = os.environ.get("MODEL_CHECKPOINT", "")
for p in [env_path, "models/best.pt", "checkpoints/best.pt", "checkpoints2/best.pt"]:
if p and Path(p).exists():
return p
try:
return hf_hub_download(repo_id=MODEL_REPO, filename="models/best.pt")
    except Exception as exc:
        raise FileNotFoundError(
            f"No local checkpoint found and download from {MODEL_REPO} failed: {exc}"
        ) from exc
def _resolve_meta() -> str:
"""Return local test_indices.json path, downloading from HF Hub if needed."""
for p in ["models/test_indices.json", "checkpoints2/test_indices.json", "checkpoints/test_indices.json"]:
if Path(p).exists():
return p
try:
return hf_hub_download(repo_id=MODEL_REPO, filename="models/test_indices.json")
except Exception:
return ""
DEFAULT_CHECKPOINT = _resolve_checkpoint()
DEFAULT_META = os.environ.get("MODEL_META", "") or _resolve_meta()
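# TARGET_FRAMES_PER_BEAT makes model inputs tempo-invariant (see _to_mel);
# MEL_HOP and N_MELS_VIZ feed only the display spectrograms in _draw_mel.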
TARGET_FRAMES_PER_BEAT = 50
N_FFT = 1024
MEL_HOP = 512
N_MELS_VIZ = 128
@dataclass
class AudioClip:
waveform: torch.Tensor
sample_rate: int
offset_sec: float
duration_sec: float
@dataclass
class BeatWindow:
waveform: torch.Tensor
start_sec: float
end_sec: float
beat_intervals: list[tuple[float, float]]
def _format_time(seconds: float) -> str:
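    """Format seconds as M:SS.s, e.g. 75.3 -> "1:15.3"."""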
seconds = max(0.0, float(seconds))
minutes = int(seconds // 60)
rem = seconds - minutes * 60
return f"{minutes}:{rem:04.1f}"
def _format_intervals(intervals: list[tuple[float, float]], limit: int = 4) -> str:
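    """Render intervals as "M:SS.s-M:SS.s" pairs, truncating after `limit` entries."""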
if not intervals:
return "none"
shown = ", ".join(f"{_format_time(a)}-{_format_time(b)}" for a, b in intervals[:limit])
if len(intervals) > limit:
shown += f", +{len(intervals) - limit} more"
return shown
def _merge_intervals(intervals: list[tuple[float, float]], gap: float = 0.05) -> list[tuple[float, float]]:
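    """Merge overlapping intervals, bridging gaps of up to `gap` seconds."""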
if not intervals:
return []
ordered = sorted((float(a), float(b)) for a, b in intervals if b > a)
if not ordered:
return []
merged = [ordered[0]]
for start, end in ordered[1:]:
prev_start, prev_end = merged[-1]
if start <= prev_end + gap:
merged[-1] = (prev_start, max(prev_end, end))
else:
merged.append((start, end))
return merged
def _load_args(checkpoint_path: Path) -> dict:
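    """Read training args from the test_indices.json metadata, with env-var fallbacks."""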
meta_path = Path(DEFAULT_META) if DEFAULT_META else checkpoint_path.parent / "test_indices.json"
args = {}
if meta_path.exists():
with open(meta_path) as f:
args = json.load(f).get("args", {})
args.setdefault("backbone", os.environ.get("MODEL_BACKBONE", "ast"))
args.setdefault("ast_model", os.environ.get("AST_MODEL"))
args.setdefault("bars", int(os.environ.get("MODEL_BARS", "4")))
args.setdefault("n_mels", int(os.environ.get("MODEL_N_MELS", "128")))
args.setdefault("sample_rate", SAMPLE_RATE)
return args
def _build_model(args: dict, device: torch.device):
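    """Instantiate the detector for the configured backbone: "ast", "sslam", or a CNN fallback."""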
beats_per_window = int(args.get("bars", 4)) * 4
n_mels = int(args.get("n_mels", 128))
backbone = args.get("backbone", "ast")
if backbone == "ast":
model = SampleDetector(
model_name=args.get("ast_model", os.environ["AST_MODEL"]),
freeze_encoder=True,
beats_per_window=beats_per_window,
n_mels=n_mels,
)
elif backbone == "sslam":
model = SSLAMSampleDetector(
freeze_encoder=True,
beats_per_window=beats_per_window,
n_mels=n_mels,
)
else:
model = CNNSampleDetector(beats_per_window=beats_per_window, n_mels=n_mels)
return model.to(device)
@lru_cache(maxsize=2)
def _load_model(checkpoint_path: str):
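    """Build the model, load checkpoint weights (non-strict), and cache the result via lru_cache."""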
path = Path(checkpoint_path)
if not path.exists():
raise FileNotFoundError(
f"Checkpoint not found: {path}. Set MODEL_CHECKPOINT or place a checkpoint at models/best.pt."
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = _load_args(path)
model = _build_model(args, device)
ckpt = torch.load(path, map_location=device)
state = ckpt.get("model_state", ckpt)
pair_head_loaded = any(k.startswith("pair_mask_head.") for k in state)
missing, unexpected = model.load_state_dict(state, strict=False)
model.eval()
return {
"model": model,
"args": args,
"device": device,
"epoch": ckpt.get("epoch", "?"),
"pair_head_loaded": pair_head_loaded,
"missing": missing,
"unexpected": unexpected,
}
def _load_audio(path: str, offset_sec: float, max_seconds: float) -> AudioClip:
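    """Load mono audio at SAMPLE_RATE, trim to [offset, offset + max_seconds], and peak-normalize."""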
if not path:
raise gr.Error("Upload both audio files before running verification.")
audio, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True)
waveform = torch.from_numpy(audio).float()
offset_sec = max(0.0, float(offset_sec or 0.0))
max_seconds = max(1.0, float(max_seconds or 1.0))
start = min(int(offset_sec * sr), max(waveform.numel() - 1, 0))
end = min(start + int(max_seconds * sr), waveform.numel())
waveform = waveform[start:end].float().contiguous()
if waveform.numel() < sr // 4:
raise gr.Error("Each upload must contain at least 0.25 seconds of audio after offset trimming.")
peak = waveform.abs().max().clamp_min(1e-6)
waveform = waveform / peak
return AudioClip(
waveform=waveform,
sample_rate=sr,
offset_sec=offset_sec,
duration_sec=waveform.numel() / sr,
)
def _estimate_beats(waveform: torch.Tensor, sample_rate: int) -> tuple[float, np.ndarray]:
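    """Beat-track the clip; returns (bpm clamped to 60-200, beat positions in samples).

    Falls back to an evenly spaced grid at the estimated tempo when fewer
    than two beats are detected.
    """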
y = waveform.detach().cpu().numpy().astype(np.float32)
tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sample_rate, hop_length=512)
bpm = float(np.atleast_1d(tempo)[0]) if np.size(tempo) else 120.0
if not np.isfinite(bpm) or bpm <= 0:
bpm = 120.0
bpm = float(np.clip(bpm, 60.0, 200.0))
beat_samples = librosa.frames_to_samples(beat_frames, hop_length=512)
beat_samples = beat_samples[(beat_samples >= 0) & (beat_samples < waveform.numel())]
if len(beat_samples) < 2:
step = max(1, int(round(sample_rate * 60.0 / bpm)))
beat_samples = np.arange(0, waveform.numel(), step, dtype=np.int64)
elif beat_samples[0] > sample_rate * 60.0 / bpm:
beat_samples = np.insert(beat_samples, 0, 0)
return bpm, beat_samples.astype(np.int64)
def _to_mel(waveform: torch.Tensor, bpm: float, args: dict) -> torch.Tensor:
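    """Compute a BPM-normalized log-mel spectrogram.

    The hop length is chosen so each beat spans TARGET_FRAMES_PER_BEAT frames
    (e.g. at 120 BPM and 16 kHz: hop = 60*16000/(120*50) = 160 samples, so one
    0.5 s beat covers exactly 50 frames). The result is cropped/zero-padded to
    a fixed frame count and z-normalized. Output shape: [1, fixed_frames, n_mels].
    """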
sample_rate = int(args.get("sample_rate", SAMPLE_RATE))
n_mels = int(args.get("n_mels", 128))
bars = int(args.get("bars", 4))
fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
hop = max(1, round(60 * sample_rate / (bpm * TARGET_FRAMES_PER_BEAT)))
mel_transform = T.MelSpectrogram(
sample_rate=sample_rate,
n_fft=N_FFT,
hop_length=hop,
n_mels=n_mels,
power=2.0,
)
amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)
mel = amp_to_db(mel_transform(waveform)).T
if mel.shape[0] > fixed_frames:
mel = mel[:fixed_frames]
elif mel.shape[0] < fixed_frames:
mel = torch.cat([mel, torch.zeros(fixed_frames - mel.shape[0], mel.shape[1])], dim=0)
mel = (mel - mel.mean()) / (mel.std() + 1e-6)
return mel.unsqueeze(0)
def _make_windows(
clip: AudioClip,
bpm: float,
beat_samples: np.ndarray,
args: dict,
stride_beats: int,
max_windows: int,
) -> list[BeatWindow]:
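    """Cut the clip into bar-aligned windows anchored at detected beats.

    Each window spans `bars * 4` beats; candidate starts are taken every
    `stride_beats` beats and subsampled evenly down to `max_windows`.
    """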
bars = int(args.get("bars", 4))
beats_per_window = bars * 4
window_samples = max(1, int(round(beats_per_window * 60.0 / bpm * clip.sample_rate)))
beat_seconds = 60.0 / bpm
stride_beats = max(1, int(stride_beats))
max_windows = max(1, int(max_windows))
valid = [i for i in range(0, len(beat_samples), stride_beats) if beat_samples[i] < clip.waveform.numel()]
if not valid:
valid = [0]
if len(valid) > max_windows:
chosen_positions = np.linspace(0, len(valid) - 1, max_windows, dtype=np.int64)
valid = [valid[i] for i in sorted(set(chosen_positions.tolist()))]
windows = []
for beat_idx in valid:
start_sample = int(beat_samples[beat_idx]) if len(beat_samples) else 0
chunk = clip.waveform[start_sample:start_sample + window_samples]
if chunk.numel() < window_samples:
chunk = torch.nn.functional.pad(chunk, (0, window_samples - chunk.numel()))
start_sec = clip.offset_sec + start_sample / clip.sample_rate
end_sec = start_sec + window_samples / clip.sample_rate
beat_intervals = [
(start_sec + i * beat_seconds, start_sec + (i + 1) * beat_seconds)
for i in range(beats_per_window)
]
windows.append(BeatWindow(chunk, start_sec, end_sec, beat_intervals))
return windows
def _encode(model, mels: torch.Tensor, batch_size: int) -> torch.Tensor:
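    """Embed mel windows through the encoder in mini-batches to bound memory."""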
embs = []
for start in range(0, mels.shape[0], batch_size):
embs.append(model.encoder(mels[start:start + batch_size]))
return torch.cat(embs, dim=0)
def _score_pairs(model, track_mels: torch.Tensor, source_mels: torch.Tensor, batch_size: int) -> torch.Tensor:
"""Score each (track, source) window pair using the classifier head (model.forward)."""
track_emb = _encode(model, track_mels, batch_size)
source_emb = _encode(model, source_mels, batch_size)
n_track, n_source = track_emb.shape[0], source_emb.shape[0]
scores = torch.zeros(n_track, n_source, device=track_emb.device)
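    # Exhaustive O(n_track * n_source) loop: each pair also runs the pairwise
    # beat head so its summary features can feed the classifier head.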
for i in range(n_track):
for j in range(n_source):
t = track_emb[i:i + 1]
s = source_emb[j:j + 1]
pair_feat = pair_summary_features(
model.pair_mask_head(track_mels[i:i + 1], source_mels[j:j + 1])
)
combined = torch.cat([t, s, torch.abs(t - s), t * s, pair_feat], dim=-1)
logits = model.head(combined)
scores[i, j] = torch.softmax(logits, dim=-1)[0, 1]
return scores
def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]:
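    """Turn a per-beat boolean mask into merged (start, end) intervals in seconds."""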
intervals = []
for use, (start, end) in zip(mask.tolist(), window.beat_intervals):
if use:
intervals.append((start, min(end, max_end)))
return _merge_intervals(intervals)
def _find_contiguous_beats(pair_probs: np.ndarray, min_beats: int = 2) -> tuple[np.ndarray, np.ndarray]:
    """Find the best contiguous diagonal run in the beat similarity matrix.

    Expects scores centered so that positive means "likely a shared beat"
    (e.g. sigmoid probabilities minus a threshold). Searches every diagonal
    offset (track_beat - source_beat) and uses Kadane's algorithm to find the
    highest-scoring contiguous segment along each diagonal. Returns boolean
    masks over track and source beats.
    """
n_track, n_source = pair_probs.shape
best_score = -np.inf
best_track_mask = np.zeros(n_track, dtype=bool)
best_source_mask = np.zeros(n_source, dtype=bool)
for d in range(-(n_source - 1), n_track):
# diagonal: track[i], source[i - d] for valid i
i0 = max(0, d)
j0 = max(0, -d)
length = min(n_track - i0, n_source - j0)
if length < min_beats:
continue
diag = pair_probs[i0:i0 + length, j0:j0 + length].diagonal()
# Kadane's max-subarray on the diagonal values
curr_sum = 0.0
curr_start = 0
best_sum = -np.inf
seg_start = seg_end = 0
for k, val in enumerate(diag):
curr_sum += val
if curr_sum > best_sum:
best_sum = curr_sum
seg_start = curr_start
seg_end = k
if curr_sum < 0:
curr_sum = 0.0
curr_start = k + 1
seg_len = seg_end - seg_start + 1
if seg_len < min_beats:
continue
avg_score = best_sum / seg_len
if avg_score > best_score:
best_score = avg_score
track_mask = np.zeros(n_track, dtype=bool)
source_mask = np.zeros(n_source, dtype=bool)
track_mask[i0 + seg_start: i0 + seg_end + 1] = True
source_mask[j0 + seg_start: j0 + seg_end + 1] = True
best_track_mask = track_mask
best_source_mask = source_mask
return best_track_mask, best_source_mask
def _localize_match(
model,
track_mel: torch.Tensor,
source_mel: torch.Tensor,
track_window: BeatWindow,
source_window: BeatWindow,
track_clip: AudioClip,
source_clip: AudioClip,
threshold: float,
pair_head_loaded: bool,
) -> tuple[list[tuple[float, float]], list[tuple[float, float]], str]:
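    """Highlight matching beat spans in both clips using the pairwise beat head.

    Sigmoid beat probabilities are centered at `threshold` before the
    contiguous-run search; without a trained pair head, the whole best
    window is highlighted instead.
    """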
if not pair_head_loaded:
return (
[(track_window.start_sec, min(track_window.end_sec, track_clip.offset_sec + track_clip.duration_sec))],
[(source_window.start_sec, min(source_window.end_sec, source_clip.offset_sec + source_clip.duration_sec))],
"The checkpoint does not include a trained pairwise beat head, so the highlight covers the best matching window.",
)
    with torch.inference_mode():
        pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
    # Center at the localization threshold: on raw sigmoid outputs (all
    # non-negative) Kadane would always select an entire diagonal.
    track_mask, source_mask = _find_contiguous_beats(pair_probs - float(threshold), min_beats=2)
# Fall back to top-k individual beats if no contiguous run was found
if not track_mask.any():
top_k = min(6, pair_probs.size)
flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
selected = np.zeros_like(pair_probs, dtype=bool)
selected.reshape(-1)[flat] = True
track_mask = selected.any(axis=1)
source_mask = selected.any(axis=0)
track_regions = _intervals_from_mask(
track_mask,
track_window,
track_clip.offset_sec + track_clip.duration_sec,
)
source_regions = _intervals_from_mask(
source_mask,
source_window,
source_clip.offset_sec + source_clip.duration_sec,
)
return track_regions, source_regions, ""
def _draw_waveform(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str):
y = clip.waveform.detach().cpu().numpy()
n = len(y)
points = min(20000, n)
idx = np.linspace(0, n - 1, points, dtype=np.int64)
x = clip.offset_sec + idx / clip.sample_rate
ax.plot(x, y[idx], color="#111827", linewidth=0.55)
for start, end in regions:
ax.axvspan(start, end, color=color, alpha=0.28)
ax.set_title(title, loc="left", fontsize=10)
ax.set_ylabel("Amplitude")
ax.set_xlim(clip.offset_sec, clip.offset_sec + clip.duration_sec)
ax.set_ylim(-1.05, 1.05)
ax.grid(True, alpha=0.18)
def _draw_mel(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str, matched: bool):
y = clip.waveform.detach().cpu().numpy().astype(np.float32)
mel = librosa.feature.melspectrogram(y=y, sr=clip.sample_rate, n_mels=N_MELS_VIZ, hop_length=MEL_HOP)
mel_db = librosa.power_to_db(mel, ref=np.max)
t_start = clip.offset_sec
t_end = clip.offset_sec + clip.duration_sec
f_max = clip.sample_rate / 2
ax.imshow(
mel_db,
aspect="auto",
origin="lower",
extent=[t_start, t_end, 0, f_max],
cmap="magma",
interpolation="nearest",
)
ax.set_title(title, loc="left", fontsize=10)
ax.set_ylabel("Frequency (Hz)")
ax.set_xlim(t_start, t_end)
if regions:
for start, end in regions:
ax.axvspan(start, end, color=color, alpha=0.38 if matched else 0.22, linewidth=0)
if not matched:
ax.text(
0.5, 0.5, "No Match",
transform=ax.transAxes,
fontsize=18,
color="white",
ha="center",
va="center",
fontweight="bold",
bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65),
)
def _plot_waveforms(
track_clip: AudioClip,
source_clip: AudioClip,
track_regions: list[tuple[float, float]],
source_regions: list[tuple[float, float]],
score: float | None,
matched: bool,
) -> plt.Figure:
fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=False)
color = "#22c55e" if matched else "#f59e0b"
    fig.suptitle("Waveform preview" if score is None else f"Best match score: {score:.3f}", fontsize=12)
_draw_waveform(axes[0], track_clip, track_regions, color, "Track / song audio")
_draw_waveform(axes[1], source_clip, source_regions, color, "Source sample audio")
axes[1].set_xlabel("Time in uploaded file (seconds)")
fig.tight_layout()
return fig
def _plot_mels(
track_clip: AudioClip,
source_clip: AudioClip,
track_regions: list[tuple[float, float]],
source_regions: list[tuple[float, float]],
matched: bool,
) -> plt.Figure:
fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
color = "#22c55e" if matched else "#f59e0b"
_draw_mel(axes[0], track_clip, track_regions, color, "Track mel spectrogram", matched)
_draw_mel(axes[1], source_clip, source_regions, color, "Source mel spectrogram", matched)
axes[1].set_xlabel("Time in uploaded file (seconds)")
fig.tight_layout()
return fig
def _image_to_mel_tensor(image_path: str, args: dict) -> torch.Tensor:
"""Load a BPM-normalized mel spectrogram PNG as the model's input tensor."""
from PIL import Image as PILImage
n_mels = int(args.get("n_mels", 128))
bars = int(args.get("bars", 4))
fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
img = PILImage.open(image_path).convert("L")
img = img.resize((fixed_frames, n_mels), PILImage.LANCZOS)
arr = np.array(img, dtype=np.float32) / 255.0 # [n_mels, fixed_frames]
# Image was saved with origin="lower": row 0 in pixels = highest freq bin
arr = arr[::-1] # flip so row 0 = lowest mel bin
mel = torch.from_numpy(arr.T.copy()).float() # [fixed_frames, n_mels]
mel = (mel - mel.mean()) / (mel.std() + 1e-6)
return mel.unsqueeze(0) # [1, fixed_frames, n_mels]
def _plot_spectrograms_with_mask(
track_img_path: str,
source_img_path: str,
track_beats: np.ndarray,
source_beats: np.ndarray,
score: float,
matched: bool,
) -> plt.Figure:
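    """Display the winning pre-rendered spectrogram pair with per-beat highlight bands."""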
from PIL import Image as PILImage
color = "#22c55e" if matched else "#f59e0b"
fig, axes = plt.subplots(2, 1, figsize=(12, 5))
fig.suptitle(f"Score: {score:.3f}", fontsize=12)
for ax, img_path, label, beats in [
(axes[0], track_img_path, "Track spectrogram", track_beats),
(axes[1], source_img_path, "Source spectrogram", source_beats),
]:
img = np.array(PILImage.open(img_path).convert("RGB"))
W = img.shape[1]
ax.imshow(img, aspect="auto")
ax.set_title(label, loc="left", fontsize=10)
ax.set_xlabel("Time frame (BPM-normalized)")
ax.set_ylabel("Mel bin")
ax.tick_params(labelsize=7)
if beats is not None and beats.any():
n_beats = len(beats)
beat_w = W / n_beats
for i, active in enumerate(beats):
if active:
ax.axvspan(i * beat_w, (i + 1) * beat_w, color=color, alpha=0.38, linewidth=0)
if not matched:
ax.text(0.5, 0.5, "No Match", transform=ax.transAxes,
fontsize=18, color="white", ha="center", va="center", fontweight="bold",
bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65))
fig.tight_layout()
return fig
def _norm_file_list(files) -> list[str]:
"""Normalise whatever gr.File returns into a flat list of path strings."""
if not files:
return []
if isinstance(files, (str, bytes)):
return [str(files)]
paths = []
for f in (files if isinstance(files, list) else [files]):
if isinstance(f, str):
paths.append(f)
elif hasattr(f, "name"):
paths.append(f.name)
return paths
def verify_spectrograms(
track_specs,
source_specs,
checkpoint_path,
match_threshold,
localization_threshold,
):
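    """Score every track x source spectrogram-window pair and report the best one."""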
track_paths = _norm_file_list(track_specs)
source_paths = _norm_file_list(source_specs)
if not track_paths or not source_paths:
raise gr.Error("Upload at least one spectrogram image for both track and source.")
try:
loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
except Exception as exc:
return f"Model could not be loaded: {exc}", None, None
model = loaded["model"]
args = loaded["args"]
device = loaded["device"]
batch_size = 8 if device.type == "cpu" else 32
track_mels = torch.stack([_image_to_mel_tensor(p, args) for p in track_paths]).to(device)
source_mels = torch.stack([_image_to_mel_tensor(p, args) for p in source_paths]).to(device)
with torch.inference_mode():
score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
best_flat = int(torch.argmax(score_matrix).item())
best_track_idx = best_flat // score_matrix.shape[1]
best_source_idx = best_flat % score_matrix.shape[1]
best_score = float(score_matrix[best_track_idx, best_source_idx])
matched = best_score >= float(match_threshold)
best_track_mel = track_mels[best_track_idx:best_track_idx + 1]
best_source_mel = source_mels[best_source_idx:best_source_idx + 1]
beats_per_window = int(args.get("bars", 4)) * 4
if loaded["pair_head_loaded"]:
with torch.inference_mode():
pair_probs = torch.sigmoid(model.pair_mask_head(best_track_mel, best_source_mel))[0].cpu().numpy()
        # Center at the highlight threshold (see _find_contiguous_beats).
        track_beats, source_beats = _find_contiguous_beats(
            pair_probs - float(localization_threshold), min_beats=2
        )
if not track_beats.any():
track_beats = np.ones(beats_per_window, dtype=bool)
source_beats = np.ones(beats_per_window, dtype=bool)
else:
track_beats = np.ones(beats_per_window, dtype=bool)
source_beats = np.ones(beats_per_window, dtype=bool)
spec_fig = _plot_spectrograms_with_mask(
track_paths[best_track_idx], source_paths[best_source_idx],
track_beats, source_beats, best_score, matched,
)
verdict = "Likely match" if matched else "No match"
details = [
f"**{verdict}**",
f"Classifier score: `{best_score:.3f}` (threshold `{float(match_threshold):.2f}`).",
f"Best window: track `w{best_track_idx:02d}` × source `w{best_source_idx:02d}` "
f"({len(track_paths)} × {len(source_paths)} combinations tried).",
f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
]
if not loaded["pair_head_loaded"]:
details.append("Checkpoint does not include a trained pairwise beat head.")
return "\n\n".join(details), None, spec_fig
def preview_waveforms(track_audio, source_audio):
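    """Render un-highlighted waveform and mel previews once both files are uploaded."""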
if not track_audio or not source_audio:
return None, None
try:
track_clip = _load_audio(track_audio, 0.0, 120.0)
source_clip = _load_audio(source_audio, 0.0, 120.0)
wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
mfig = _plot_mels(track_clip, source_clip, [], [], False)
return wfig, mfig
except Exception:
return None, None
def verify(
track_audio,
source_audio,
checkpoint_path,
match_threshold,
localization_threshold,
track_offset,
source_offset,
max_seconds,
stride_beats,
max_windows,
):
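    """End-to-end audio check: load clips, beat-track, window, score all window pairs, localize, and plot."""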
try:
track_clip = _load_audio(track_audio, track_offset, max_seconds)
source_clip = _load_audio(source_audio, source_offset, max_seconds)
except Exception as exc:
raise gr.Error(str(exc))
try:
loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
except Exception as exc:
wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
mfig = _plot_mels(track_clip, source_clip, [], [], False)
return f"Model could not be loaded: {exc}", wfig, mfig
model = loaded["model"]
args = loaded["args"]
device = loaded["device"]
batch_size = 8 if device.type == "cpu" else 32
track_bpm, track_beats = _estimate_beats(track_clip.waveform, track_clip.sample_rate)
source_bpm, source_beats = _estimate_beats(source_clip.waveform, source_clip.sample_rate)
track_windows = _make_windows(track_clip, track_bpm, track_beats, args, stride_beats, max_windows)
source_windows = _make_windows(source_clip, source_bpm, source_beats, args, stride_beats, max_windows)
track_mels = torch.stack([_to_mel(w.waveform, track_bpm, args) for w in track_windows]).to(device)
source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device)
with torch.inference_mode():
score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
best_flat = int(torch.argmax(score_matrix).item())
best_track = best_flat // score_matrix.shape[1]
best_source = best_flat % score_matrix.shape[1]
best_score = float(score_matrix[best_track, best_source].detach().cpu())
matched = best_score >= float(match_threshold)
track_regions, source_regions, note = _localize_match(
model,
track_mels[best_track:best_track + 1],
source_mels[best_source:best_source + 1],
track_windows[best_track],
source_windows[best_source],
track_clip,
source_clip,
localization_threshold,
loaded["pair_head_loaded"],
)
if not track_regions or not source_regions:
matched = False
track_regions = []
source_regions = []
if not note:
note = "Localization was inconclusive, so the result is treated as no match."
wfig = _plot_waveforms(track_clip, source_clip, track_regions, source_regions, best_score, matched)
mfig = _plot_mels(track_clip, source_clip, track_regions, source_regions, matched)
verdict = "Likely match" if matched else "No match"
details = [
f"**{verdict}**",
f"Classifier score: `{best_score:.3f}` (threshold `{float(match_threshold):.2f}`).",
f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.",
f"{'Matched' if matched else 'Proposed'} track section(s): {_format_intervals(track_regions)}.",
f"{'Matched' if matched else 'Proposed'} source section(s): {_format_intervals(source_regions)}.",
f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
]
if note:
details.append(note)
if loaded["missing"]:
details.append(f"Missing checkpoint keys initialized at load time: `{len(loaded['missing'])}`.")
return "\n\n".join(details), wfig, mfig
with gr.Blocks(title="Sample Match Verifier") as demo:
gr.Markdown("# Sample Match Verifier")
gr.Markdown(
"Upload a track and a possible source sample. "
"Click **Verify match** to run the model."
)
with gr.Tabs():
with gr.Tab("Audio"):
gr.Markdown("Waveforms appear immediately on upload.")
with gr.Row():
track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])
audio_run = gr.Button("Verify match", variant="primary")
with gr.Tab("Spectrogram"):
gr.Markdown(
"Upload the window images "
"(`*_w00.png`, `*_w01.png`, …). Select **all windows** for each file — "
"the app will score every combination and return the best match."
)
with gr.Row():
track_spec = gr.File(label="Track spectrogram windows", file_count="multiple",
file_types=[".png", ".jpg", ".jpeg"])
source_spec = gr.File(label="Source spectrogram windows", file_count="multiple",
file_types=[".png", ".jpg", ".jpeg"])
spec_run = gr.Button("Verify match", variant="primary")
with gr.Accordion("Settings", open=False):
checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
with gr.Row():
match_threshold = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Match threshold")
localization_threshold = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Highlight threshold")
with gr.Row():
track_offset = gr.Number(value=0.0, label="Track start offset, seconds")
source_offset = gr.Number(value=0.0, label="Source start offset, seconds")
with gr.Row():
max_seconds = gr.Slider(5, 180, value=60, step=5, label="Analyze duration per upload, seconds")
stride_beats = gr.Slider(1, 16, value=16, step=1, label="Window stride, beats")
max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
result = gr.Markdown()
waveform_plot = gr.Plot(label="Waveforms")
mel_plot = gr.Plot(label="Mel Spectrograms")
# Show waveforms as soon as both audio files are uploaded
for audio_input in [track_audio, source_audio]:
audio_input.change(
preview_waveforms,
inputs=[track_audio, source_audio],
outputs=[waveform_plot, mel_plot],
)
audio_run.click(
verify,
inputs=[
track_audio,
source_audio,
checkpoint_path,
match_threshold,
localization_threshold,
track_offset,
source_offset,
max_seconds,
stride_beats,
max_windows,
],
outputs=[result, waveform_plot, mel_plot],
)
spec_run.click(
verify_spectrograms,
inputs=[track_spec, source_spec, checkpoint_path, match_threshold, localization_threshold],
outputs=[result, waveform_plot, mel_plot],
)
if __name__ == "__main__":
demo.queue(max_size=8).launch()