import json
import os
from dataclasses import dataclass
from pathlib import Path

# Default model IDs; set via setdefault before importing model code that may
# read them, so explicit environment variables still take precedence.
os.environ.setdefault("AST_MODEL", "MIT/ast-finetuned-audioset-10-10-0.4593")
os.environ.setdefault("SSLAM_MODEL", "ta012/SSLAM_pretrain")

import gradio as gr
import librosa
import matplotlib
import numpy as np
import torch
import torchaudio.transforms as T
from huggingface_hub import hf_hub_download

matplotlib.use("Agg")  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt

from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector, pair_summary_features

SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000"))
MODEL_REPO = os.environ.get("MODEL_REPO", "dayngerous/whoSampledAST")


def _resolve_checkpoint() -> str:
    """Return local checkpoint path, downloading from HF Hub if needed."""
    env_path = os.environ.get("MODEL_CHECKPOINT", "")
    for p in [env_path, "models/best.pt", "checkpoints/best.pt", "checkpoints2/best.pt"]:
        if p and Path(p).exists():
            return p
    try:
        return hf_hub_download(repo_id=MODEL_REPO, filename="models/best.pt")
    except Exception as exc:
        raise FileNotFoundError(
            f"No local checkpoint found and download from {MODEL_REPO} failed: {exc}"
        )


def _resolve_meta() -> str:
    """Return local test_indices.json path, downloading from HF Hub if needed."""
    for p in ["models/test_indices.json", "checkpoints2/test_indices.json", "checkpoints/test_indices.json"]:
        if Path(p).exists():
            return p
    try:
        return hf_hub_download(repo_id=MODEL_REPO, filename="models/test_indices.json")
    except Exception:
        return ""


DEFAULT_CHECKPOINT = _resolve_checkpoint()
DEFAULT_META = os.environ.get("MODEL_META", "") or _resolve_meta()
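
# Framing constants: _to_mel picks its hop length so each beat spans
# TARGET_FRAMES_PER_BEAT frames (BPM normalization). MEL_HOP and N_MELS_VIZ
# only affect the preview spectrograms, not the model input.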
TARGET_FRAMES_PER_BEAT = 50
N_FFT = 1024
MEL_HOP = 512
N_MELS_VIZ = 128


@dataclass
class AudioClip:
    waveform: torch.Tensor
    sample_rate: int
    offset_sec: float
    duration_sec: float


@dataclass
class BeatWindow:
    waveform: torch.Tensor
    start_sec: float
    end_sec: float
    beat_intervals: list[tuple[float, float]]


def _format_time(seconds: float) -> str:
    seconds = max(0.0, float(seconds))
    minutes = int(seconds // 60)
    rem = seconds - minutes * 60
    return f"{minutes}:{rem:04.1f}"


def _format_intervals(intervals: list[tuple[float, float]], limit: int = 4) -> str:
    if not intervals:
        return "none"
    shown = ", ".join(f"{_format_time(a)}-{_format_time(b)}" for a, b in intervals[:limit])
    if len(intervals) > limit:
        shown += f", +{len(intervals) - limit} more"
    return shown


def _merge_intervals(intervals: list[tuple[float, float]], gap: float = 0.05) -> list[tuple[float, float]]:
    if not intervals:
        return []
    ordered = sorted((float(a), float(b)) for a, b in intervals if b > a)
    if not ordered:
        return []
    merged = [ordered[0]]
    for start, end in ordered[1:]:
        prev_start, prev_end = merged[-1]
        if start <= prev_end + gap:
            merged[-1] = (prev_start, max(prev_end, end))
        else:
            merged.append((start, end))
    return merged
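
# Training metadata: the checkpoint's companion test_indices.json is assumed to
# carry the training run's "args" dict; environment variables backfill any
# fields missing from it in _load_args below.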


def _load_args(checkpoint_path: Path) -> dict:
    meta_path = Path(DEFAULT_META) if DEFAULT_META else checkpoint_path.parent / "test_indices.json"
    args = {}
    if meta_path.exists():
        with open(meta_path) as f:
            args = json.load(f).get("args", {})
    args.setdefault("backbone", os.environ.get("MODEL_BACKBONE", "ast"))
    args.setdefault("ast_model", os.environ.get("AST_MODEL"))
    args.setdefault("bars", int(os.environ.get("MODEL_BARS", "4")))
    args.setdefault("n_mels", int(os.environ.get("MODEL_N_MELS", "128")))
    args.setdefault("sample_rate", SAMPLE_RATE)
    return args


def _build_model(args: dict, device: torch.device):
    beats_per_window = int(args.get("bars", 4)) * 4
    n_mels = int(args.get("n_mels", 128))
    backbone = args.get("backbone", "ast")
    if backbone == "ast":
        model = SampleDetector(
            model_name=args.get("ast_model", os.environ["AST_MODEL"]),
            freeze_encoder=True,
            beats_per_window=beats_per_window,
            n_mels=n_mels,
        )
    elif backbone == "sslam":
        model = SSLAMSampleDetector(
            freeze_encoder=True,
            beats_per_window=beats_per_window,
            n_mels=n_mels,
        )
    else:
        model = CNNSampleDetector(beats_per_window=beats_per_window, n_mels=n_mels)
    return model.to(device)


def _load_model(checkpoint_path: str):
    path = Path(checkpoint_path)
    if not path.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {path}. Set MODEL_CHECKPOINT or place a checkpoint at models/best.pt."
        )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args = _load_args(path)
    model = _build_model(args, device)
    ckpt = torch.load(path, map_location=device)
    state = ckpt.get("model_state", ckpt)
    pair_head_loaded = any(k.startswith("pair_mask_head.") for k in state)
    missing, unexpected = model.load_state_dict(state, strict=False)
    model.eval()
    return {
        "model": model,
        "args": args,
        "device": device,
        "epoch": ckpt.get("epoch", "?"),
        "pair_head_loaded": pair_head_loaded,
        "missing": missing,
        "unexpected": unexpected,
    }
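
# Audio loading below resamples uploads to SAMPLE_RATE, trims to the requested
# offset/duration, and peak-normalizes so amplitude plots share a [-1, 1] scale.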


def _load_audio(path: str, offset_sec: float, max_seconds: float) -> AudioClip:
    if not path:
        raise gr.Error("Upload both audio files before running verification.")
    audio, sr = librosa.load(path, sr=SAMPLE_RATE, mono=True)
    waveform = torch.from_numpy(audio).float()
    offset_sec = max(0.0, float(offset_sec or 0.0))
    max_seconds = max(1.0, float(max_seconds or 1.0))
    start = min(int(offset_sec * sr), max(waveform.numel() - 1, 0))
    end = min(start + int(max_seconds * sr), waveform.numel())
    waveform = waveform[start:end].float().contiguous()
    if waveform.numel() < sr // 4:
        raise gr.Error("Each upload must contain at least 0.25 seconds of audio after offset trimming.")
    peak = waveform.abs().max().clamp_min(1e-6)
    waveform = waveform / peak
    return AudioClip(
        waveform=waveform,
        sample_rate=sr,
        offset_sec=offset_sec,
        duration_sec=waveform.numel() / sr,
    )
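
# Beat tracking: librosa can return tempo as a scalar or a 1-element array
# depending on version (hence np.atleast_1d); implausible tempos are clamped to
# 60-200 BPM, and a uniform grid is synthesized if fewer than two beats are found.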


def _estimate_beats(waveform: torch.Tensor, sample_rate: int) -> tuple[float, np.ndarray]:
    y = waveform.detach().cpu().numpy().astype(np.float32)
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sample_rate, hop_length=512)
    bpm = float(np.atleast_1d(tempo)[0]) if np.size(tempo) else 120.0
    if not np.isfinite(bpm) or bpm <= 0:
        bpm = 120.0
    bpm = float(np.clip(bpm, 60.0, 200.0))
    beat_samples = librosa.frames_to_samples(beat_frames, hop_length=512)
    beat_samples = beat_samples[(beat_samples >= 0) & (beat_samples < waveform.numel())]
    if len(beat_samples) < 2:
        step = max(1, int(round(sample_rate * 60.0 / bpm)))
        beat_samples = np.arange(0, waveform.numel(), step, dtype=np.int64)
    elif beat_samples[0] > sample_rate * 60.0 / bpm:
        beat_samples = np.insert(beat_samples, 0, 0)
    return bpm, beat_samples.astype(np.int64)
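
# BPM normalization: _to_mel chooses the hop so one beat maps to exactly
# TARGET_FRAMES_PER_BEAT frames, then crops or zero-pads to a fixed
# bars * 4 * TARGET_FRAMES_PER_BEAT frame count and z-scores the result.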


def _to_mel(waveform: torch.Tensor, bpm: float, args: dict) -> torch.Tensor:
    sample_rate = int(args.get("sample_rate", SAMPLE_RATE))
    n_mels = int(args.get("n_mels", 128))
    bars = int(args.get("bars", 4))
    fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
    hop = max(1, round(60 * sample_rate / (bpm * TARGET_FRAMES_PER_BEAT)))
    mel_transform = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=N_FFT,
        hop_length=hop,
        n_mels=n_mels,
        power=2.0,
    )
    amp_to_db = T.AmplitudeToDB(stype="power", top_db=80)
    mel = amp_to_db(mel_transform(waveform)).T  # [frames, n_mels]
    if mel.shape[0] > fixed_frames:
        mel = mel[:fixed_frames]
    elif mel.shape[0] < fixed_frames:
        mel = torch.cat([mel, torch.zeros(fixed_frames - mel.shape[0], mel.shape[1])], dim=0)
    mel = (mel - mel.mean()) / (mel.std() + 1e-6)
    return mel.unsqueeze(0)  # [1, frames, n_mels]


def _make_windows(
    clip: AudioClip,
    bpm: float,
    beat_samples: np.ndarray,
    args: dict,
    stride_beats: int,
    max_windows: int,
) -> list[BeatWindow]:
    bars = int(args.get("bars", 4))
    beats_per_window = bars * 4
    window_samples = max(1, int(round(beats_per_window * 60.0 / bpm * clip.sample_rate)))
    beat_seconds = 60.0 / bpm
    stride_beats = max(1, int(stride_beats))
    max_windows = max(1, int(max_windows))
    valid = [i for i in range(0, len(beat_samples), stride_beats) if beat_samples[i] < clip.waveform.numel()]
    if not valid:
        valid = [0]
    if len(valid) > max_windows:
        chosen_positions = np.linspace(0, len(valid) - 1, max_windows, dtype=np.int64)
        valid = [valid[i] for i in sorted(set(chosen_positions.tolist()))]
    windows = []
    for beat_idx in valid:
        start_sample = int(beat_samples[beat_idx]) if len(beat_samples) else 0
        chunk = clip.waveform[start_sample:start_sample + window_samples]
        if chunk.numel() < window_samples:
            chunk = torch.nn.functional.pad(chunk, (0, window_samples - chunk.numel()))
        start_sec = clip.offset_sec + start_sample / clip.sample_rate
        end_sec = start_sec + window_samples / clip.sample_rate
        beat_intervals = [
            (start_sec + i * beat_seconds, start_sec + (i + 1) * beat_seconds)
            for i in range(beats_per_window)
        ]
        windows.append(BeatWindow(chunk, start_sec, end_sec, beat_intervals))
    return windows


def _encode(model, mels: torch.Tensor, batch_size: int) -> torch.Tensor:
    embs = []
    for start in range(0, mels.shape[0], batch_size):
        embs.append(model.encoder(mels[start:start + batch_size]))
    return torch.cat(embs, dim=0)
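
# Pair scoring below is O(n_track * n_source) through the pair head; window
# counts are capped by max_windows, so the loop stays small. Batching the pair
# head would be the natural optimization if larger grids are ever needed.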


def _score_pairs(model, track_mels: torch.Tensor, source_mels: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Score each (track, source) window pair using the classifier head (model.forward)."""
    track_emb = _encode(model, track_mels, batch_size)
    source_emb = _encode(model, source_mels, batch_size)
    n_track, n_source = track_emb.shape[0], source_emb.shape[0]
    scores = torch.zeros(n_track, n_source, device=track_emb.device)
    for i in range(n_track):
        for j in range(n_source):
            t = track_emb[i:i + 1]
            s = source_emb[j:j + 1]
            pair_feat = pair_summary_features(
                model.pair_mask_head(track_mels[i:i + 1], source_mels[j:j + 1])
            )
            combined = torch.cat([t, s, torch.abs(t - s), t * s, pair_feat], dim=-1)
            logits = model.head(combined)
            scores[i, j] = torch.softmax(logits, dim=-1)[0, 1]
    return scores


def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]:
    intervals = []
    for use, (start, end) in zip(mask.tolist(), window.beat_intervals):
        if use:
            intervals.append((start, min(end, max_end)))
    return _merge_intervals(intervals)


def _find_contiguous_beats(pair_probs: np.ndarray, min_beats: int = 2) -> tuple[np.ndarray, np.ndarray]:
    """Find the best contiguous diagonal run in the beat similarity matrix.

    Searches every diagonal offset (track_beat - source_beat) and uses
    Kadane's algorithm to find the highest-scoring contiguous segment along
    each diagonal. Returns boolean masks over track and source beats.
    """
    n_track, n_source = pair_probs.shape
    best_score = -np.inf
    best_track_mask = np.zeros(n_track, dtype=bool)
    best_source_mask = np.zeros(n_source, dtype=bool)
    for d in range(-(n_source - 1), n_track):
        # diagonal: track[i], source[i - d] for valid i
        i0 = max(0, d)
        j0 = max(0, -d)
        length = min(n_track - i0, n_source - j0)
        if length < min_beats:
            continue
        diag = pair_probs[i0:i0 + length, j0:j0 + length].diagonal()
        # Kadane's max-subarray on the diagonal values. Note: sigmoid
        # probabilities are all positive, so the best segment is always the
        # full diagonal; centering the values (e.g. subtracting a threshold)
        # would let Kadane isolate shorter high-confidence runs.
        curr_sum = 0.0
        curr_start = 0
        best_sum = -np.inf
        seg_start = seg_end = 0
        for k, val in enumerate(diag):
            curr_sum += val
            if curr_sum > best_sum:
                best_sum = curr_sum
                seg_start = curr_start
                seg_end = k
            if curr_sum < 0:
                curr_sum = 0.0
                curr_start = k + 1
        seg_len = seg_end - seg_start + 1
        if seg_len < min_beats:
            continue
        avg_score = best_sum / seg_len
        if avg_score > best_score:
            best_score = avg_score
            track_mask = np.zeros(n_track, dtype=bool)
            source_mask = np.zeros(n_source, dtype=bool)
            track_mask[i0 + seg_start: i0 + seg_end + 1] = True
            source_mask[j0 + seg_start: j0 + seg_end + 1] = True
            best_track_mask = track_mask
            best_source_mask = source_mask
    return best_track_mask, best_source_mask
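
# Localization: the pair head emits a beats-by-beats probability grid for the
# best window pair; the diagonal-run search above reduces it to per-beat masks,
# which _localize_match maps back to absolute timestamps in each upload.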


def _localize_match(
    model,
    track_mel: torch.Tensor,
    source_mel: torch.Tensor,
    track_window: BeatWindow,
    source_window: BeatWindow,
    track_clip: AudioClip,
    source_clip: AudioClip,
    threshold: float,  # accepted for UI symmetry; unused by the contiguous-run search
    pair_head_loaded: bool,
) -> tuple[list[tuple[float, float]], list[tuple[float, float]], str]:
    if not pair_head_loaded:
        return (
            [(track_window.start_sec, min(track_window.end_sec, track_clip.offset_sec + track_clip.duration_sec))],
            [(source_window.start_sec, min(source_window.end_sec, source_clip.offset_sec + source_clip.duration_sec))],
            "The checkpoint does not include a trained pairwise beat head, so the highlight covers the best matching window.",
        )
    with torch.inference_mode():
        pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
    track_mask, source_mask = _find_contiguous_beats(pair_probs, min_beats=2)
    # Fall back to top-k individual beats if no contiguous run was found
    if not track_mask.any():
        top_k = min(6, pair_probs.size)
        flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
        selected = np.zeros_like(pair_probs, dtype=bool)
        selected.reshape(-1)[flat] = True
        track_mask = selected.any(axis=1)
        source_mask = selected.any(axis=0)
    track_regions = _intervals_from_mask(
        track_mask,
        track_window,
        track_clip.offset_sec + track_clip.duration_sec,
    )
    source_regions = _intervals_from_mask(
        source_mask,
        source_window,
        source_clip.offset_sec + source_clip.duration_sec,
    )
    return track_regions, source_regions, ""
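
# Rendering helpers: green (#22c55e) marks a likely match, amber (#f59e0b)
# marks proposed-but-unmatched regions; long waveforms are downsampled to at
# most 20k points before plotting.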


def _draw_waveform(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str):
    y = clip.waveform.detach().cpu().numpy()
    n = len(y)
    points = min(20000, n)
    idx = np.linspace(0, n - 1, points, dtype=np.int64)
    x = clip.offset_sec + idx / clip.sample_rate
    ax.plot(x, y[idx], color="#111827", linewidth=0.55)
    for start, end in regions:
        ax.axvspan(start, end, color=color, alpha=0.28)
    ax.set_title(title, loc="left", fontsize=10)
    ax.set_ylabel("Amplitude")
    ax.set_xlim(clip.offset_sec, clip.offset_sec + clip.duration_sec)
    ax.set_ylim(-1.05, 1.05)
    ax.grid(True, alpha=0.18)


def _draw_mel(ax, clip: AudioClip, regions: list[tuple[float, float]], color: str, title: str, matched: bool):
    y = clip.waveform.detach().cpu().numpy().astype(np.float32)
    mel = librosa.feature.melspectrogram(y=y, sr=clip.sample_rate, n_mels=N_MELS_VIZ, hop_length=MEL_HOP)
    mel_db = librosa.power_to_db(mel, ref=np.max)
    t_start = clip.offset_sec
    t_end = clip.offset_sec + clip.duration_sec
    f_max = clip.sample_rate / 2
    ax.imshow(
        mel_db,
        aspect="auto",
        origin="lower",
        extent=[t_start, t_end, 0, f_max],
        cmap="magma",
        interpolation="nearest",
    )
    ax.set_title(title, loc="left", fontsize=10)
    ax.set_ylabel("Frequency (Hz)")
    ax.set_xlim(t_start, t_end)
    if regions:
        for start, end in regions:
            ax.axvspan(start, end, color=color, alpha=0.38 if matched else 0.22, linewidth=0)
    if not matched:
        ax.text(
            0.5, 0.5, "No Match",
            transform=ax.transAxes,
            fontsize=18,
            color="white",
            ha="center",
            va="center",
            fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65),
        )


def _plot_waveforms(
    track_clip: AudioClip,
    source_clip: AudioClip,
    track_regions: list[tuple[float, float]],
    source_regions: list[tuple[float, float]],
    score: float | None,
    matched: bool,
) -> plt.Figure:
    fig, axes = plt.subplots(2, 1, figsize=(12, 5), sharex=False)
    color = "#22c55e" if matched else "#f59e0b"
    title = "Waveform preview" if score is None else f"Best match score: {score:.3f}"
    fig.suptitle(title, fontsize=12)
    _draw_waveform(axes[0], track_clip, track_regions, color, "Track / song audio")
    _draw_waveform(axes[1], source_clip, source_regions, color, "Source sample audio")
    axes[1].set_xlabel("Time in uploaded file (seconds)")
    fig.tight_layout()
    return fig


def _plot_mels(
    track_clip: AudioClip,
    source_clip: AudioClip,
    track_regions: list[tuple[float, float]],
    source_regions: list[tuple[float, float]],
    matched: bool,
) -> plt.Figure:
    fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=False)
    color = "#22c55e" if matched else "#f59e0b"
    _draw_mel(axes[0], track_clip, track_regions, color, "Track mel spectrogram", matched)
    _draw_mel(axes[1], source_clip, source_regions, color, "Source mel spectrogram", matched)
    axes[1].set_xlabel("Time in uploaded file (seconds)")
    fig.tight_layout()
    return fig
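
# Spectrogram-image path: the training pipeline is assumed to export one PNG
# per beat window, rendered with origin="lower"; _image_to_mel_tensor reverses
# that rendering back into the model's [frames, n_mels] input layout.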


def _image_to_mel_tensor(image_path: str, args: dict) -> torch.Tensor:
    """Load a BPM-normalized mel spectrogram PNG as the model's input tensor."""
    from PIL import Image as PILImage

    n_mels = int(args.get("n_mels", 128))
    bars = int(args.get("bars", 4))
    fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
    img = PILImage.open(image_path).convert("L")
    img = img.resize((fixed_frames, n_mels), PILImage.LANCZOS)
    arr = np.array(img, dtype=np.float32) / 255.0  # [n_mels, fixed_frames]
    # Image was saved with origin="lower": row 0 in pixels = highest freq bin
    arr = arr[::-1]  # flip so row 0 = lowest mel bin
    mel = torch.from_numpy(arr.T.copy()).float()  # [fixed_frames, n_mels]
    mel = (mel - mel.mean()) / (mel.std() + 1e-6)
    return mel.unsqueeze(0)  # [1, fixed_frames, n_mels]


def _plot_spectrograms_with_mask(
    track_img_path: str,
    source_img_path: str,
    track_beats: np.ndarray,
    source_beats: np.ndarray,
    score: float,
    matched: bool,
) -> plt.Figure:
    from PIL import Image as PILImage

    color = "#22c55e" if matched else "#f59e0b"
    fig, axes = plt.subplots(2, 1, figsize=(12, 5))
    fig.suptitle(f"Score: {score:.3f}", fontsize=12)
    for ax, img_path, label, beats in [
        (axes[0], track_img_path, "Track spectrogram", track_beats),
        (axes[1], source_img_path, "Source spectrogram", source_beats),
    ]:
        img = np.array(PILImage.open(img_path).convert("RGB"))
        W = img.shape[1]
        ax.imshow(img, aspect="auto")
        ax.set_title(label, loc="left", fontsize=10)
        ax.set_xlabel("Time frame (BPM-normalized)")
        ax.set_ylabel("Mel bin")
        ax.tick_params(labelsize=7)
        if beats is not None and beats.any():
            n_beats = len(beats)
            beat_w = W / n_beats
            for i, active in enumerate(beats):
                if active:
                    ax.axvspan(i * beat_w, (i + 1) * beat_w, color=color, alpha=0.38, linewidth=0)
        if not matched:
            ax.text(
                0.5, 0.5, "No Match",
                transform=ax.transAxes,
                fontsize=18, color="white", ha="center", va="center", fontweight="bold",
                bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65),
            )
    fig.tight_layout()
    return fig


def _norm_file_list(files) -> list[str]:
    """Normalise whatever gr.File returns into a flat list of path strings."""
    if not files:
        return []
    if isinstance(files, (str, bytes)):
        return [str(files)]
    paths = []
    for f in (files if isinstance(files, list) else [files]):
        if isinstance(f, str):
            paths.append(f)
        elif hasattr(f, "name"):
            paths.append(f.name)
    return paths
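
# Spectrogram tab handler: every uploaded track window is scored against every
# source window, and the single best-scoring pair drives the verdict and highlight.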


def verify_spectrograms(
    track_specs,
    source_specs,
    checkpoint_path,
    match_threshold,
    localization_threshold,  # received from the UI but unused in this path
):
    track_paths = _norm_file_list(track_specs)
    source_paths = _norm_file_list(source_specs)
    if not track_paths or not source_paths:
        raise gr.Error("Upload at least one spectrogram image for both track and source.")
    try:
        loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
    except Exception as exc:
        return f"Model could not be loaded: {exc}", None, None
    model = loaded["model"]
    args = loaded["args"]
    device = loaded["device"]
    batch_size = 8 if device.type == "cpu" else 32
    track_mels = torch.stack([_image_to_mel_tensor(p, args) for p in track_paths]).to(device)
    source_mels = torch.stack([_image_to_mel_tensor(p, args) for p in source_paths]).to(device)
    with torch.inference_mode():
        score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
    best_flat = int(torch.argmax(score_matrix).item())
    best_track_idx = best_flat // score_matrix.shape[1]
    best_source_idx = best_flat % score_matrix.shape[1]
    best_score = float(score_matrix[best_track_idx, best_source_idx])
    matched = best_score >= float(match_threshold)
    best_track_mel = track_mels[best_track_idx:best_track_idx + 1]
    best_source_mel = source_mels[best_source_idx:best_source_idx + 1]
    beats_per_window = int(args.get("bars", 4)) * 4
    if loaded["pair_head_loaded"]:
        with torch.inference_mode():
            pair_probs = torch.sigmoid(model.pair_mask_head(best_track_mel, best_source_mel))[0].cpu().numpy()
        track_beats, source_beats = _find_contiguous_beats(pair_probs, min_beats=2)
        if not track_beats.any():
            track_beats = np.ones(beats_per_window, dtype=bool)
            source_beats = np.ones(beats_per_window, dtype=bool)
    else:
        track_beats = np.ones(beats_per_window, dtype=bool)
        source_beats = np.ones(beats_per_window, dtype=bool)
    spec_fig = _plot_spectrograms_with_mask(
        track_paths[best_track_idx], source_paths[best_source_idx],
        track_beats, source_beats, best_score, matched,
    )
    verdict = "Likely match" if matched else "No match"
    details = [
        f"**{verdict}**",
        f"Classifier score: `{best_score:.3f}` (threshold `{float(match_threshold):.2f}`).",
        f"Best window: track `w{best_track_idx:02d}` × source `w{best_source_idx:02d}` "
        f"({len(track_paths)} × {len(source_paths)} combinations tried).",
        f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
    ]
    if not loaded["pair_head_loaded"]:
        details.append("Checkpoint does not include a trained pairwise beat head.")
    return "\n\n".join(details), None, spec_fig


def preview_waveforms(track_audio, source_audio):
    if not track_audio or not source_audio:
        return None, None
    try:
        track_clip = _load_audio(track_audio, 0.0, 120.0)
        source_clip = _load_audio(source_audio, 0.0, 120.0)
        wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
        mfig = _plot_mels(track_clip, source_clip, [], [], False)
        return wfig, mfig
    except Exception:
        # Previews are best-effort; a bad upload should not surface an error here.
        return None, None
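
# Full audio pipeline: load and trim both uploads, estimate beats, slice into
# BPM-normalized windows, score every track/source window pair, then localize
# the matching beats inside the best pair.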


def verify(
    track_audio,
    source_audio,
    checkpoint_path,
    match_threshold,
    localization_threshold,
    track_offset,
    source_offset,
    max_seconds,
    stride_beats,
    max_windows,
):
    try:
        track_clip = _load_audio(track_audio, track_offset, max_seconds)
        source_clip = _load_audio(source_audio, source_offset, max_seconds)
    except Exception as exc:
        raise gr.Error(str(exc))
    try:
        loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
    except Exception as exc:
        wfig = _plot_waveforms(track_clip, source_clip, [], [], None, False)
        mfig = _plot_mels(track_clip, source_clip, [], [], False)
        return f"Model could not be loaded: {exc}", wfig, mfig
    model = loaded["model"]
    args = loaded["args"]
    device = loaded["device"]
    batch_size = 8 if device.type == "cpu" else 32
    track_bpm, track_beats = _estimate_beats(track_clip.waveform, track_clip.sample_rate)
    source_bpm, source_beats = _estimate_beats(source_clip.waveform, source_clip.sample_rate)
    track_windows = _make_windows(track_clip, track_bpm, track_beats, args, stride_beats, max_windows)
    source_windows = _make_windows(source_clip, source_bpm, source_beats, args, stride_beats, max_windows)
    track_mels = torch.stack([_to_mel(w.waveform, track_bpm, args) for w in track_windows]).to(device)
    source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device)
    with torch.inference_mode():
        score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
    best_flat = int(torch.argmax(score_matrix).item())
    best_track = best_flat // score_matrix.shape[1]
    best_source = best_flat % score_matrix.shape[1]
    best_score = float(score_matrix[best_track, best_source].detach().cpu())
    matched = best_score >= float(match_threshold)
    track_regions, source_regions, note = _localize_match(
        model,
        track_mels[best_track:best_track + 1],
        source_mels[best_source:best_source + 1],
        track_windows[best_track],
        source_windows[best_source],
        track_clip,
        source_clip,
        localization_threshold,
        loaded["pair_head_loaded"],
    )
    if not track_regions or not source_regions:
        matched = False
        track_regions = []
        source_regions = []
        if not note:
            note = "Localization was inconclusive, so the result is treated as no match."
    wfig = _plot_waveforms(track_clip, source_clip, track_regions, source_regions, best_score, matched)
    mfig = _plot_mels(track_clip, source_clip, track_regions, source_regions, matched)
    verdict = "Likely match" if matched else "No match"
    details = [
        f"**{verdict}**",
        f"Classifier score: `{best_score:.3f}` (threshold `{float(match_threshold):.2f}`).",
        f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.",
        f"{'Matched' if matched else 'Proposed'} track section(s): {_format_intervals(track_regions)}.",
        f"{'Matched' if matched else 'Proposed'} source section(s): {_format_intervals(source_regions)}.",
        f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
    ]
    if note:
        details.append(note)
    if loaded["missing"]:
        details.append(f"Missing checkpoint keys initialized at load time: `{len(loaded['missing'])}`.")
    return "\n\n".join(details), wfig, mfig


with gr.Blocks(title="Sample Match Verifier") as demo:
    gr.Markdown("# Sample Match Verifier")
    gr.Markdown(
        "Upload a track and a possible source sample. "
        "Click **Verify match** to run the model."
    )
    with gr.Tabs():
        with gr.Tab("Audio"):
            gr.Markdown("Waveforms appear immediately on upload.")
            with gr.Row():
                track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
                source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])
            audio_run = gr.Button("Verify match", variant="primary")
        with gr.Tab("Spectrogram"):
            gr.Markdown(
                "Upload the window images "
                "(`*_w00.png`, `*_w01.png`, …). Select **all windows** for each file — "
                "the app will score every combination and return the best match."
            )
            with gr.Row():
                track_spec = gr.File(
                    label="Track spectrogram windows", file_count="multiple",
                    file_types=[".png", ".jpg", ".jpeg"],
                )
                source_spec = gr.File(
                    label="Source spectrogram windows", file_count="multiple",
                    file_types=[".png", ".jpg", ".jpeg"],
                )
            spec_run = gr.Button("Verify match", variant="primary")
    with gr.Accordion("Settings", open=False):
        checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
        with gr.Row():
            match_threshold = gr.Slider(0.0, 1.0, value=0.50, step=0.01, label="Match threshold")
            localization_threshold = gr.Slider(0.0, 1.0, value=0.55, step=0.01, label="Highlight threshold")
        with gr.Row():
            track_offset = gr.Number(value=0.0, label="Track start offset, seconds")
            source_offset = gr.Number(value=0.0, label="Source start offset, seconds")
        with gr.Row():
            max_seconds = gr.Slider(5, 180, value=60, step=5, label="Analyze duration per upload, seconds")
            stride_beats = gr.Slider(1, 16, value=16, step=1, label="Window stride, beats")
            max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
    result = gr.Markdown()
    waveform_plot = gr.Plot(label="Waveforms")
    mel_plot = gr.Plot(label="Mel Spectrograms")

    # Show waveforms as soon as both audio files are uploaded
    for audio_input in [track_audio, source_audio]:
        audio_input.change(
            preview_waveforms,
            inputs=[track_audio, source_audio],
            outputs=[waveform_plot, mel_plot],
        )

    audio_run.click(
        verify,
        inputs=[
            track_audio,
            source_audio,
            checkpoint_path,
            match_threshold,
            localization_threshold,
            track_offset,
            source_offset,
            max_seconds,
            stride_beats,
            max_windows,
        ],
        outputs=[result, waveform_plot, mel_plot],
    )
    spec_run.click(
        verify_spectrograms,
        inputs=[track_spec, source_spec, checkpoint_path, match_threshold, localization_threshold],
        outputs=[result, waveform_plot, mel_plot],
    )


if __name__ == "__main__":
    demo.queue(max_size=8).launch()