import json import os from dataclasses import dataclass from functools import lru_cache from pathlib import Path os.environ.setdefault("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593") os.environ.setdefault("SSLAM_MODEL", "ta012/SSLAM_pretrain") import gradio as gr import librosa import matplotlib import numpy as np import torch import torchaudio.transforms as T from huggingface_hub import hf_hub_download matplotlib.use("Agg") import matplotlib.pyplot as plt from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector, pair_summary_features SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000")) MODEL_REPO = os.environ.get("MODEL_REPO", "dayngerous/whoSampledAST") def _resolve_checkpoint() -> str: """Return local checkpoint path, downloading from HF Hub if needed.""" env_path = os.environ.get("MODEL_CHECKPOINT", "") for p in [env_path, "models/best.pt", "checkpoints/best.pt", "checkpoints2/best.pt"]: if p and Path(p).exists(): return p try: return hf_hub_download(repo_id=MODEL_REPO, filename="models/best.pt") except Exception as exc: raise FileNotFoundError( f"No local checkpoint found and download from {MODEL_REPO} failed: {exc}" ) def _resolve_meta() -> str: """Return local test_indices.json path, downloading from HF Hub if needed.""" for p in ["models/test_indices.json", "checkpoints2/test_indices.json", "checkpoints/test_indices.json"]: if Path(p).exists(): return p try: return hf_hub_download(repo_id=MODEL_REPO, filename="models/test_indices.json") except Exception: return "" DEFAULT_CHECKPOINT = _resolve_checkpoint() DEFAULT_META = os.environ.get("MODEL_META", "") or _resolve_meta() TARGET_FRAMES_PER_BEAT = 50 N_FFT = 1024 MEL_HOP = 512 N_MELS_VIZ = 128 @dataclass class AudioClip: waveform: torch.Tensor sample_rate: int offset_sec: float duration_sec: float @dataclass class BeatWindow: waveform: torch.Tensor start_sec: float end_sec: float beat_intervals: list[tuple[float, float]] def _format_time(seconds: float) -> str: seconds = max(0.0, float(seconds)) minutes = int(seconds // 60) rem = seconds - minutes * 60 return f"{minutes}:{rem:04.1f}" def _format_intervals(intervals: list[tuple[float, float]], limit: int = 4) -> str: if not intervals: return "none" shown = ", ".join(f"{_format_time(a)}-{_format_time(b)}" for a, b in intervals[:limit]) if len(intervals) > limit: shown += f", +{len(intervals) - limit} more" return shown def _merge_intervals(intervals: list[tuple[float, float]], gap: float = 0.05) -> list[tuple[float, float]]: if not intervals: return [] ordered = sorted((float(a), float(b)) for a, b in intervals if b > a) if not ordered: return [] merged = [ordered[0]] for start, end in ordered[1:]: prev_start, prev_end = merged[-1] if start <= prev_end + gap: merged[-1] = (prev_start, max(prev_end, end)) else: merged.append((start, end)) return merged def _load_args(checkpoint_path: Path) -> dict: meta_path = Path(DEFAULT_META) if DEFAULT_META else checkpoint_path.parent / "test_indices.json" args = {} if meta_path.exists(): with open(meta_path) as f: args = json.load(f).get("args", {}) args.setdefault("backbone", os.environ.get("MODEL_BACKBONE", "ast")) args.setdefault("ast_model", os.environ.get("AST_MODEL")) args.setdefault("bars", int(os.environ.get("MODEL_BARS", "4"))) args.setdefault("n_mels", int(os.environ.get("MODEL_N_MELS", "128"))) args.setdefault("sample_rate", SAMPLE_RATE) return args def _build_model(args: dict, device: torch.device): beats_per_window = int(args.get("bars", 4)) * 4 n_mels = int(args.get("n_mels", 128)) backbone = args.get("backbone", "ast") if backbone == "ast": model = SampleDetector( model_name=args.get("ast_model", os.environ["AST_MODEL"]), freeze_encoder=True, beats_per_window=beats_per_window, n_mels=n_mels, ) elif backbone == "sslam": model = SSLAMSampleDetector( freeze_encoder=True, beats_per_window=beats_per_window, n_mels=n_mels, ) else: model = CNNSampleDetector(beats_per_window=beats_per_window, n_mels=n_mels) return model.to(device) @lru_cache(maxsize=2) def _load_model(checkpoint_path: str): path = Path(checkpoint_path) if not path.exists(): raise FileNotFoundError( f"Checkpoint not found: {path}. Set MODEL_CHECKPOINT or place a checkpoint at models/best.pt." ) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args = _load_args(path) model = _build_model(args, device) ckpt = torch.load(path, map_location=device) state = ckpt.get("model_state", ckpt) pair_head_loaded = any(k.startswith("pair_mask_head.") for k in state) missing, unexpected = model.load_state_dict(state, strict=False) model.eval() return { "model": model, "args": args, "device": device, "epoch": ckpt.get("epoch", "?"), "pair_head_loaded": pair_head_loaded, "missing": missing, "unexpected": unexpected, } def _load_audio(path: str, offset_sec: float, max_seconds: float) -> AudioClip: if not path: raise gr.Error("Upload both audio files before running verification.") audio, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True) waveform = torch.from_numpy(audio).float() offset_sec = max(0.0, float(offset_sec or 0.0)) max_seconds = max(1.0, float(max_seconds or 1.0)) start = min(int(offset_sec * sr), max(waveform.numel() - 1, 0)) end = min(start + int(max_seconds * sr), waveform.numel()) waveform = waveform[start:end].float().contiguous() if waveform.numel() < sr // 4: raise gr.Error("Each upload must contain at least 0.25 seconds of audio after offset trimming.") peak = waveform.abs().max().clamp_min(1e-6) waveform = waveform / peak return AudioClip( waveform=waveform, sample_rate=sr, offset_sec=offset_sec, duration_sec=waveform.numel() / sr, ) def _estimate_beats(waveform: torch.Tensor, sample_rate: int) -> tuple[float, np.ndarray]: y = waveform.detach().cpu().numpy().astype(np.float32) tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sample_rate, hop_length=512) bpm = float(np.atleast_1d(tempo)[0]) if np.size(tempo) else 120.0 if not np.isfinite(bpm) or bpm <= 0: bpm = 120.0 bpm = float(np.clip(bpm, 60.0, 200.0)) beat_samples = librosa.frames_to_samples(beat_frames, hop_length=512) beat_samples = beat_samples[(beat_samples >= 0) & (beat_samples < waveform.numel())] if len(beat_samples) < 2: step = max(1, int(round(sample_rate * 60.0 / bpm))) beat_samples = np.arange(0, waveform.numel(), step, dtype=np.int64) elif beat_samples[0] > sample_rate * 60.0 / bpm: beat_samples = np.insert(beat_samples, 0, 0) return bpm, beat_samples.astype(np.int64) def _to_mel(waveform: torch.Tensor, bpm: float, args: dict) -> torch.Tensor: sample_rate = int(args.get("sample_rate", SAMPLE_RATE)) n_mels = int(args.get("n_mels", 128)) bars = int(args.get("bars", 4)) fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT hop = max(1, round(60 * sample_rate / (bpm * TARGET_FRAMES_PER_BEAT))) mel_transform = T.MelSpectrogram( sample_rate=sample_rate, n_fft=N_FFT, hop_length=hop, n_mels=n_mels, power=2.0, ) amp_to_db = T.AmplitudeToDB(stype="power", top_db=80) mel = amp_to_db(mel_transform(waveform)).T if mel.shape[0] > fixed_frames: mel = mel[:fixed_frames] elif mel.shape[0] < fixed_frames: mel = torch.cat([mel, torch.zeros(fixed_frames - mel.shape[0], mel.shape[1])], dim=0) mel = (mel - mel.mean()) / (mel.std() + 1e-6) return mel.unsqueeze(0) def _make_windows( clip: AudioClip, bpm: float, beat_samples: np.ndarray, args: dict, stride_beats: int, max_windows: int, ) -> list[BeatWindow]: bars = int(args.get("bars", 4)) beats_per_window = bars * 4 window_samples = max(1, int(round(beats_per_window * 60.0 / bpm * clip.sample_rate))) beat_seconds = 60.0 / bpm stride_beats = max(1, int(stride_beats)) max_windows = max(1, int(max_windows)) valid = [i for i in range(0, len(beat_samples), stride_beats) if beat_samples[i] < clip.waveform.numel()] if not valid: valid = [0] if len(valid) > max_windows: chosen_positions = np.linspace(0, len(valid) - 1, max_windows, dtype=np.int64) valid = [valid[i] for i in sorted(set(chosen_positions.tolist()))] windows = [] for beat_idx in valid: start_sample = int(beat_samples[beat_idx]) if len(beat_samples) else 0 chunk = clip.waveform[start_sample:start_sample + window_samples] if chunk.numel() < window_samples: chunk = torch.nn.functional.pad(chunk, (0, window_samples - chunk.numel())) start_sec = clip.offset_sec + start_sample / clip.sample_rate end_sec = start_sec + window_samples / clip.sample_rate beat_intervals = [ (start_sec + i * beat_seconds, start_sec + (i + 1) * beat_seconds) for i in range(beats_per_window) ] windows.append(BeatWindow(chunk, start_sec, end_sec, beat_intervals)) return windows def _encode(model, mels: torch.Tensor, batch_size: int) -> torch.Tensor: embs = [] for start in range(0, mels.shape[0], batch_size): embs.append(model.encoder(mels[start:start + batch_size])) return torch.cat(embs, dim=0) def _score_pairs(model, track_mels: torch.Tensor, source_mels: torch.Tensor, batch_size: int) -> torch.Tensor: """Score each (track, source) window pair using the classifier head (model.forward).""" track_emb = _encode(model, track_mels, batch_size) source_emb = _encode(model, source_mels, batch_size) n_track, n_source = track_emb.shape[0], source_emb.shape[0] scores = torch.zeros(n_track, n_source, device=track_emb.device) for i in range(n_track): for j in range(n_source): t = track_emb[i:i + 1] s = source_emb[j:j + 1] pair_feat = pair_summary_features( model.pair_mask_head(track_mels[i:i + 1], source_mels[j:j + 1]) ) combined = torch.cat([t, s, torch.abs(t - s), t * s, pair_feat], dim=-1) logits = model.head(combined) scores[i, j] = torch.softmax(logits, dim=-1)[0, 1] return scores def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]: intervals = [] for use, (start, end) in zip(mask.tolist(), window.beat_intervals): if use: intervals.append((start, min(end, max_end))) return _merge_intervals(intervals) def _find_contiguous_beats(pair_probs: np.ndarray, min_beats: int = 2) -> tuple[np.ndarray, np.ndarray]: """Find the best contiguous diagonal run in the beat similarity matrix. Searches every diagonal offset (track_beat - source_beat) and uses Kadane's algorithm to find the highest-scoring contiguous segment along each diagonal. Returns boolean masks over track and source beats. """ n_track, n_source = pair_probs.shape best_score = -np.inf best_track_mask = np.zeros(n_track, dtype=bool) best_source_mask = np.zeros(n_source, dtype=bool) for d in range(-(n_source - 1), n_track): # diagonal: track[i], source[i - d] for valid i i0 = max(0, d) j0 = max(0, -d) length = min(n_track - i0, n_source - j0) if length < min_beats: continue diag = pair_probs[i0:i0 + length, j0:j0 + length].diagonal() # Kadane's max-subarray on the diagonal values curr_sum = 0.0 curr_start = 0 best_sum = -np.inf seg_start = seg_end = 0 for k, val in enumerate(diag): curr_sum += val if curr_sum > best_sum: best_sum = curr_sum seg_start = curr_start seg_end = k if curr_sum < 0: curr_sum = 0.0 curr_start = k + 1 seg_len = seg_end - seg_start + 1 if seg_len < min_beats: continue avg_score = best_sum / seg_len if avg_score > best_score: best_score = avg_score track_mask = np.zeros(n_track, dtype=bool) source_mask = np.zeros(n_source, dtype=bool) track_mask[i0 + seg_start: i0 + seg_end + 1] = True source_mask[j0 + seg_start: j0 + seg_end + 1] = True best_track_mask = track_mask best_source_mask = source_mask return best_track_mask, best_source_mask def _localize_match( model, track_mel: torch.Tensor, source_mel: torch.Tensor, track_window: BeatWindow, source_window: BeatWindow, track_clip: AudioClip, source_clip: AudioClip, threshold: float, pair_head_loaded: bool, ) -> tuple[list[tuple[float, float]], list[tuple[float, float]], str]: if not pair_head_loaded: return ( [(track_window.start_sec, min(track_window.end_sec, track_clip.offset_sec + track_clip.duration_sec))], [(source_window.start_sec, min(source_window.end_sec, source_clip.offset_sec + source_clip.duration_sec))], "The checkpoint does not include a trained pairwise beat head, so the highlight covers the best matching window.", ) with torch.inference_mode(): pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy() track_mask, source_mask = _find_contiguous_beats(pair_probs, min_beats=2) # Fall back to top-k individual beats if no contiguous run was found if not track_mask.any(): top_k = min(6, pair_probs.size) flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:] selected = np.zeros_like(pair_probs, dtype=bool) selected.reshape(-1)[flat] = True track_mask = selected.any(axis=1) source_mask = selected.any(axis=0) track_regions = _intervals_from_mask( track_mask, track_window, track_clip.offset_sec + track_clip.duration_sec, ) source_regions = _intervals_from_mask( source_mask, source_window, source_clip.offset_sec + source_clip.duration_sec, ) return track_regions, source_regions, "" def _draw_waveform(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str): y = clip.waveform.detach().cpu().numpy() n = len(y) points = min(20000, n) idx = np.linspace(0, n - 1, points, dtype=np.int64) x = clip.offset_sec + idx / clip.sample_rate ax.plot(x, y[idx], color="#111827", linewidth=0.55) for start, end in regions: ax.axvspan(start, end, color=color, alpha=0.28) ax.set_title(title, loc="left", fontsize=10) ax.set_ylabel("Amplitude") ax.set_xlim(clip.offset_sec, clip.offset_sec + clip.duration_sec) ax.set_ylim(-1.05, 1.05) ax.grid(True, alpha=0.18) def _draw_mel(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str, matched: bool): y = clip.waveform.detach().cpu().numpy().astype(np.float32) mel = librosa.feature.melspectrogram(y=y, sr=clip.sample_rate, n_mels=N_MELS_VIZ, hop_length=MEL_HOP) mel_db = librosa.power_to_db(mel, ref=np.max) t_start = clip.offset_sec t_end = clip.offset_sec + clip.duration_sec f_max = clip.sample_rate / 2 ax.imshow( mel_db, aspect="auto", origin="lower", extent=[t_start, t_end, 0, f_max], cmap="magma", interpolation="nearest", ) ax.set_title(title, loc="left", fontsize=10) ax.set_ylabel("Frequency (Hz)") ax.set_xlim(t_start, t_end) if regions: for start, end in regions: ax.axvspan(start, end, color=color, alpha=0.38 if matched else 0.22, linewidth=0) if not matched: ax.text( 0.5, 0.5, "No Match", transform=ax.transAxes, fontsize=18, color="white", ha="center", va="center", fontweight="bold", bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65), ) def _plot_waveforms( track_clip: AudioClip, source_clip: AudioClip, track_regions: list[tuple[float, float]], source_regions: list[tuple[float, float]], score: float | None, matched: bool, ) -> plt.Figure: fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=False) color = "#22c55e" if matched else "#f59e0b" title_score = "unavailable" if score is None else f"{score:.3f}" fig.suptitle(f"Best match score: {title_score}" if score is not None else "Waveform preview", fontsize=12) _draw_waveform(axes[0], track_clip, track_regions, color, "Track / song audio") _draw_waveform(axes[1], source_clip, source_regions, color, "Source sample audio") axes[1].set_xlabel("Time in uploaded file (seconds)") fig.tight_layout() return fig def _plot_mels( track_clip: AudioClip, source_clip: AudioClip, track_regions: list[tuple[float, float]], source_regions: list[tuple[float, float]], matched: bool, ) -> plt.Figure: fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=False) color = "#22c55e" if matched else "#f59e0b" _draw_mel(axes[0], track_clip, track_regions, color, "Track mel spectrogram", matched) _draw_mel(axes[1], source_clip, source_regions, color, "Source mel spectrogram", matched) axes[1].set_xlabel("Time in uploaded file (seconds)") fig.tight_layout() return fig def _image_to_mel_tensor(image_path: str, args: dict) -> torch.Tensor: """Load a BPM-normalized mel spectrogram PNG as the model's input tensor.""" from PIL import Image as PILImage n_mels = int(args.get("n_mels", 128)) bars = int(args.get("bars", 4)) fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT img = PILImage.open(image_path).convert("L") img = img.resize((fixed_frames, n_mels), PILImage.LANCZOS) arr = np.array(img, dtype=np.float32) / 255.0 # [n_mels, fixed_frames] # Image was saved with origin="lower": row 0 in pixels = highest freq bin arr = arr[::-1] # flip so row 0 = lowest mel bin mel = torch.from_numpy(arr.T.copy()).float() # [fixed_frames, n_mels] mel = (mel - mel.mean()) / (mel.std() + 1e-6) return mel.unsqueeze(0) # [1, fixed_frames, n_mels] def _plot_spectrograms_with_mask( track_img_path: str, source_img_path: str, track_beats: np.ndarray, source_beats: np.ndarray, score: float, matched: bool, ) -> plt.Figure: from PIL import Image as PILImage color = "#22c55e" if matched else "#f59e0b" fig, axes = plt.subplots(2, 1, figsize=(12, 5)) fig.suptitle(f"Score: {score:.3f}", fontsize=12) for ax, img_path, label, beats in [ (axes[0], track_img_path, "Track spectrogram", track_beats), (axes[1], source_img_path, "Source spectrogram", source_beats), ]: img = np.array(PILImage.open(img_path).convert("RGB")) W = img.shape[1] ax.imshow(img, aspect="auto") ax.set_title(label, loc="left", fontsize=10) ax.set_xlabel("Time frame (BPM-normalized)") ax.set_ylabel("Mel bin") ax.tick_params(labelsize=7) if beats is not None and beats.any(): n_beats = len(beats) beat_w = W / n_beats for i, active in enumerate(beats): if active: ax.axvspan(i * beat_w, (i + 1) * beat_w, color=color, alpha=0.38, linewidth=0) if not matched: ax.text(0.5, 0.5, "No Match", transform=ax.transAxes, fontsize=18, color="white", ha="center", va="center", fontweight="bold", bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65)) fig.tight_layout() return fig def _norm_file_list(files) -> list[str]: """Normalise whatever gr.File returns into a flat list of path strings.""" if not files: return [] if isinstance(files, (str, bytes)): return [str(files)] paths = [] for f in (files if isinstance(files, list) else [files]): if isinstance(f, str): paths.append(f) elif hasattr(f, "name"): paths.append(f.name) return paths def verify_spectrograms( track_specs, source_specs, checkpoint_path, match_threshold, localization_threshold, ): track_paths = _norm_file_list(track_specs) source_paths = _norm_file_list(source_specs) if not track_paths or not source_paths: raise gr.Error("Upload at least one spectrogram image for both track and source.") try: loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT) except Exception as exc: return f"Model could not be loaded: {exc}", None, None model = loaded["model"] args = loaded["args"] device = loaded["device"] batch_size = 8 if device.type == "cpu" else 32 track_mels = torch.stack([_image_to_mel_tensor(p, args) for p in track_paths]).to(device) source_mels = torch.stack([_image_to_mel_tensor(p, args) for p in source_paths]).to(device) with torch.inference_mode(): score_matrix = _score_pairs(model, track_mels, source_mels, batch_size) best_flat = int(torch.argmax(score_matrix).item()) best_track_idx = best_flat // score_matrix.shape[1] best_source_idx = best_flat % score_matrix.shape[1] best_score = float(score_matrix[best_track_idx, best_source_idx]) matched = best_score >= float(match_threshold) best_track_mel = track_mels[best_track_idx:best_track_idx + 1] best_source_mel = source_mels[best_source_idx:best_source_idx + 1] beats_per_window = int(args.get("bars", 4)) * 4 if loaded["pair_head_loaded"]: with torch.inference_mode(): pair_probs = torch.sigmoid(model.pair_mask_head(best_track_mel, best_source_mel))[0].cpu().numpy() track_beats, source_beats = _find_contiguous_beats(pair_probs, min_beats=2) if not track_beats.any(): track_beats = np.ones(beats_per_window, dtype=bool) source_beats = np.ones(beats_per_window, dtype=bool) else: track_beats = np.ones(beats_per_window, dtype=bool) source_beats = np.ones(beats_per_window, dtype=bool) spec_fig = _plot_spectrograms_with_mask( track_paths[best_track_idx], source_paths[best_source_idx], track_beats, source_beats, best_score, matched, ) verdict = "Likely match" if matched else "No match" details = [ f"**{verdict}**", f"Classifier score: `{best_score:.3f}` (threshold `{float(match_threshold):.2f}`).", f"Best window: track `w{best_track_idx:02d}` × source `w{best_source_idx:02d}` " f"({len(track_paths)} × {len(source_paths)} combinations tried).", f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.", ] if not loaded["pair_head_loaded"]: details.append("Checkpoint does not include a trained pairwise beat head.") return "\n\n".join(details), None, spec_fig def preview_waveforms(track_audio, source_audio): if not track_audio or not source_audio: return None, None try: track_clip = _load_audio(track_audio, 0.0, 120.0) source_clip = _load_audio(source_audio, 0.0, 120.0) wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False) mfig = _plot_mels(track_clip, source_clip, [], [], False) return wfig, mfig except Exception: return None, None def verify( track_audio, source_audio, checkpoint_path, match_threshold, localization_threshold, track_offset, source_offset, max_seconds, stride_beats, max_windows, ): try: track_clip = _load_audio(track_audio, track_offset, max_seconds) source_clip = _load_audio(source_audio, source_offset, max_seconds) except Exception as exc: raise gr.Error(str(exc)) try: loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT) except Exception as exc: wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False) mfig = _plot_mels(track_clip, source_clip, [], [], False) return f"Model could not be loaded: {exc}", wfig, mfig model = loaded["model"] args = loaded["args"] device = loaded["device"] batch_size = 8 if device.type == "cpu" else 32 track_bpm, track_beats = _estimate_beats(track_clip.waveform, track_clip.sample_rate) source_bpm, source_beats = _estimate_beats(source_clip.waveform, source_clip.sample_rate) track_windows = _make_windows(track_clip, track_bpm, track_beats, args, stride_beats, max_windows) source_windows = _make_windows(source_clip, source_bpm, source_beats, args, stride_beats, max_windows) track_mels = torch.stack([_to_mel(w.waveform, track_bpm, args) for w in track_windows]).to(device) source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device) with torch.inference_mode(): score_matrix = _score_pairs(model, track_mels, source_mels, batch_size) best_flat = int(torch.argmax(score_matrix).item()) best_track = best_flat // score_matrix.shape[1] best_source = best_flat % score_matrix.shape[1] best_score = float(score_matrix[best_track, best_source].detach().cpu()) matched = best_score >= float(match_threshold) track_regions, source_regions, note = _localize_match( model, track_mels[best_track:best_track + 1], source_mels[best_source:best_source + 1], track_windows[best_track], source_windows[best_source], track_clip, source_clip, localization_threshold, loaded["pair_head_loaded"], ) if not track_regions or not source_regions: matched = False track_regions = [] source_regions = [] if not note: note = "Localization was inconclusive, so the result is treated as no match." wfig = _plot_waveforms(track_clip, source_clip, track_regions, source_regions, best_score, matched) mfig = _plot_mels(track_clip, source_clip, track_regions, source_regions, matched) verdict = "Likely match" if matched else "No match" details = [ f"**{verdict}**", f"Classifier score: `{best_score:.3f}` (threshold `{float(match_threshold):.2f}`).", f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.", f"{'Matched' if matched else 'Proposed'} track section(s): {_format_intervals(track_regions)}.", f"{'Matched' if matched else 'Proposed'} source section(s): {_format_intervals(source_regions)}.", f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.", ] if note: details.append(note) if loaded["missing"]: details.append(f"Missing checkpoint keys initialized at load time: `{len(loaded['missing'])}`.") return "\n\n".join(details), wfig, mfig with gr.Blocks(title="Sample Match Verifier") as demo: gr.Markdown("# Sample Match Verifier") gr.Markdown( "Upload a track and a possible source sample. " "Click **Verify match** to run the model." ) with gr.Tabs(): with gr.Tab("Audio"): gr.Markdown("Waveforms appear immediately on upload.") with gr.Row(): track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"]) source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"]) audio_run = gr.Button("Verify match", variant="primary") with gr.Tab("Spectrogram"): gr.Markdown( "Upload the window images " "(`*_w00.png`, `*_w01.png`, …). Select **all windows** for each file — " "the app will score every combination and return the best match." ) with gr.Row(): track_spec = gr.File(label="Track spectrogram windows", file_count="multiple", file_types=[".png", ".jpg", ".jpeg"]) source_spec = gr.File(label="Source spectrogram windows", file_count="multiple", file_types=[".png", ".jpg", ".jpeg"]) spec_run = gr.Button("Verify match", variant="primary") with gr.Accordion("Settings", open=False): checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT) with gr.Row(): match_threshold = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Match threshold") localization_threshold = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Highlight threshold") with gr.Row(): track_offset = gr.Number(value=0.0, label="Track start offset, seconds") source_offset = gr.Number(value=0.0, label="Source start offset, seconds") with gr.Row(): max_seconds = gr.Slider(5, 180, value=60, step=5, label="Analyze duration per upload, seconds") stride_beats = gr.Slider(1, 16, value=16, step=1, label="Window stride, beats") max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload") result = gr.Markdown() waveform_plot = gr.Plot(label="Waveforms") mel_plot = gr.Plot(label="Mel Spectrograms") # Show waveforms as soon as both audio files are uploaded for audio_input in [track_audio, source_audio]: audio_input.change( preview_waveforms, inputs=[track_audio, source_audio], outputs=[waveform_plot, mel_plot], ) audio_run.click( verify, inputs=[ track_audio, source_audio, checkpoint_path, match_threshold, localization_threshold, track_offset, source_offset, max_seconds, stride_beats, max_windows, ], outputs=[result, waveform_plot, mel_plot], ) spec_run.click( verify_spectrograms, inputs=[track_spec, source_spec, checkpoint_path, match_threshold, localization_threshold], outputs=[result, waveform_plot, mel_plot], ) if __name__ == "__main__": demo.queue(max_size=8).launch()