Spaces:
Sleeping
Sleeping
Commit ·
c2160cd
1
Parent(s): df731c1
Add spectrogram image input tab
Browse files
app.py
CHANGED
|
@@ -501,6 +501,130 @@ def _plot_mels(
|
|
| 501 |
return fig
|
| 502 |
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
def preview_waveforms(track_audio, source_audio):
|
| 505 |
if not track_audio or not source_audio:
|
| 506 |
return None, None
|
|
@@ -595,13 +719,28 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
|
|
| 595 |
gr.Markdown("# Sample Match Verifier")
|
| 596 |
gr.Markdown(
|
| 597 |
"Upload a track and a possible source sample. "
|
| 598 |
-
"
|
| 599 |
-
"
|
|
|
|
| 600 |
)
|
| 601 |
|
| 602 |
-
with gr.
|
| 603 |
-
|
| 604 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
|
| 606 |
with gr.Accordion("Settings", open=False):
|
| 607 |
checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
|
|
@@ -616,13 +755,11 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
|
|
| 616 |
stride_beats = gr.Slider(1, 16, value=4, step=1, label="Window stride, beats")
|
| 617 |
max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
|
| 618 |
|
| 619 |
-
run = gr.Button("Verify match", variant="primary")
|
| 620 |
result = gr.Markdown()
|
| 621 |
-
|
| 622 |
waveform_plot = gr.Plot(label="Waveforms")
|
| 623 |
mel_plot = gr.Plot(label="Mel Spectrograms")
|
| 624 |
|
| 625 |
-
# Show waveforms as soon as both files are uploaded
|
| 626 |
for audio_input in [track_audio, source_audio]:
|
| 627 |
audio_input.change(
|
| 628 |
preview_waveforms,
|
|
@@ -630,7 +767,7 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
|
|
| 630 |
outputs=[waveform_plot, mel_plot],
|
| 631 |
)
|
| 632 |
|
| 633 |
-
|
| 634 |
verify,
|
| 635 |
inputs=[
|
| 636 |
track_audio,
|
|
@@ -647,6 +784,12 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
|
|
| 647 |
outputs=[result, waveform_plot, mel_plot],
|
| 648 |
)
|
| 649 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
|
| 651 |
if __name__ == "__main__":
|
| 652 |
demo.queue(max_size=8).launch()
|
|
|
|
| 501 |
return fig
|
| 502 |
|
| 503 |
|
| 504 |
+
def _image_to_mel_tensor(image_path: str, args: dict) -> torch.Tensor:
    """Rebuild a standardized mel-like tensor from a spectrogram image file.

    The image is resized to ``fixed_frames x n_mels`` pixels (width x height),
    collapsed to a single luminance channel, flipped vertically, transposed to
    time-major order, and standardized to zero mean / unit standard deviation.

    Intended for BPM-normalized mel spectrogram PNGs rendered with a magma
    colormap and ``origin="lower"``; luminance is only a monotone proxy for
    the original mel values, not an exact colormap inversion.

    Args:
        image_path: Path to the spectrogram image on disk.
        args: Checkpoint arguments. ``n_mels`` (default 128) and ``bars``
            (default 4) are read from it.

    Returns:
        Float tensor of shape ``[1, fixed_frames, n_mels]`` where
        ``fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT``.
    """
    from PIL import Image as PILImage

    n_mels = int(args.get("n_mels", 128))
    bars = int(args.get("bars", 4))
    fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT

    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory copy.
    with PILImage.open(image_path) as opened:
        img = opened.convert("RGB")

    # PIL sizes are (width, height): time runs along the width, mel bins along the height.
    img = img.resize((fixed_frames, n_mels), PILImage.LANCZOS)
    arr = np.array(img, dtype=np.float32) / 255.0  # [n_mels, fixed_frames, 3]

    # Invert magma via luminance — monotone proxy for the original mel value
    luminance = 0.2126 * arr[:, :, 0] + 0.7152 * arr[:, :, 1] + 0.0722 * arr[:, :, 2]
    # PNG rows are top-to-bottom; origin="lower" means row 0 in data = bottom of image
    luminance = luminance[::-1]  # flip so row 0 = lowest mel bin

    # .copy() makes the flipped + transposed view contiguous for torch.from_numpy.
    mel = torch.from_numpy(luminance.T.copy()).float()  # [fixed_frames, n_mels]
    mel = (mel - mel.mean()) / (mel.std() + 1e-6)  # epsilon guards flat images
    return mel.unsqueeze(0)  # [1, fixed_frames, n_mels]
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def _plot_spectrograms_with_mask(
    track_img_path: str,
    source_img_path: str,
    track_beats: np.ndarray,
    source_beats: np.ndarray,
    score: float,
    matched: bool,
) -> plt.Figure:
    """Plot both spectrogram images stacked, with active beats highlighted.

    Each image gets its own axes. The image width is divided into
    ``len(beats)`` equal columns and every truthy entry in the beat mask is
    shaded. When ``matched`` is false a "No Match" banner is drawn over each
    panel.

    Args:
        track_img_path: Path to the track spectrogram image.
        source_img_path: Path to the source spectrogram image.
        track_beats: Boolean per-beat mask for the track panel, or ``None``.
        source_beats: Boolean per-beat mask for the source panel, or ``None``.
        score: Classifier score shown in the figure title.
        matched: Selects the highlight colour (green when true, amber
            otherwise) and whether the "No Match" banner is shown.

    Returns:
        The matplotlib figure containing both panels.
    """
    from PIL import Image as PILImage

    color = "#22c55e" if matched else "#f59e0b"
    fig, axes = plt.subplots(2, 1, figsize=(12, 5))
    fig.suptitle(f"Score: {score:.3f}", fontsize=12)

    panels = [
        (axes[0], track_img_path, "Track spectrogram", track_beats),
        (axes[1], source_img_path, "Source spectrogram", source_beats),
    ]
    for ax, img_path, label, beats in panels:
        # Context manager releases the file handle as soon as pixels are copied.
        with PILImage.open(img_path) as opened:
            img = np.array(opened.convert("RGB"))
        height, width = img.shape[:2]

        # By default imshow centres pixels on integers, i.e. the image spans
        # [-0.5, width - 0.5]. An explicit extent puts it on [0, width] so the
        # beat spans below line up with the image edges instead of being
        # shifted half a pixel and overhanging on the right.
        ax.imshow(img, aspect="auto", extent=(0, width, height, 0))
        ax.set_title(label, loc="left", fontsize=10)
        ax.set_xlabel("Time frame (BPM-normalized)")
        ax.set_ylabel("Mel bin")
        ax.tick_params(labelsize=7)

        if beats is not None and beats.any():
            beat_w = width / len(beats)
            for i, active in enumerate(beats):
                if active:
                    ax.axvspan(i * beat_w, (i + 1) * beat_w, color=color, alpha=0.38, linewidth=0)

        if not matched:
            ax.text(
                0.5, 0.5, "No Match", transform=ax.transAxes,
                fontsize=18, color="white", ha="center", va="center", fontweight="bold",
                bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65),
            )

    fig.tight_layout()
    return fig
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def verify_spectrograms(
    track_spec_path,
    source_spec_path,
    checkpoint_path,
    match_threshold,
    localization_threshold,
):
    """Score a track/source pair supplied as spectrogram images.

    Both images are converted to model inputs, encoded, and classified. The
    result is rendered as a markdown summary and a figure of the two images
    with the localized beats highlighted.

    Args:
        track_spec_path: Path to the track spectrogram image.
        source_spec_path: Path to the source spectrogram image.
        checkpoint_path: Checkpoint to load; ``DEFAULT_CHECKPOINT`` is used
            when this is empty.
        match_threshold: Minimum classifier score reported as a match.
        localization_threshold: Accepted so the signature matches the UI
            wiring; not used by this function.

    Returns:
        A ``(markdown, None, figure)`` tuple matching the
        ``[result, waveform_plot, mel_plot]`` outputs. If the model cannot be
        loaded, ``(error_message, None, None)`` is returned instead.

    Raises:
        gr.Error: If either spectrogram path is missing.
    """
    if not track_spec_path or not source_spec_path:
        raise gr.Error("Upload both spectrogram images before running verification.")

    try:
        loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
    except Exception as exc:
        # Report load failures in the result panel rather than as a traceback.
        return f"Model could not be loaded: {exc}", None, None

    model = loaded["model"]
    args = loaded["args"]
    device = loaded["device"]
    pair_head_loaded = loaded["pair_head_loaded"]
    threshold = float(match_threshold)

    # _image_to_mel_tensor returns [1, frames, n_mels]; the extra unsqueeze
    # adds one more leading dimension before the tensors go to the model.
    track_mel = _image_to_mel_tensor(track_spec_path, args).unsqueeze(0).to(device)
    source_mel = _image_to_mel_tensor(source_spec_path, args).unsqueeze(0).to(device)

    pair_probs = None
    with torch.inference_mode():
        track_emb = model.encoder(track_mel)
        source_emb = model.encoder(source_mel)
        # Run the pairwise head once and reuse its output for both the summary
        # features and the beat probabilities.
        # NOTE(review): assumes the head is deterministic at inference and that
        # pair_summary_features does not modify its argument in place — confirm.
        pair_logits = model.pair_mask_head(track_mel, source_mel)
        pair_feat = pair_summary_features(pair_logits)
        combined = torch.cat(
            [track_emb, source_emb, torch.abs(track_emb - source_emb), track_emb * source_emb, pair_feat],
            dim=-1,
        )
        score = torch.softmax(model.head(combined), dim=-1)[0, 1].item()
        if pair_head_loaded:
            pair_probs = torch.sigmoid(pair_logits)[0].cpu().numpy()

    matched = score >= threshold
    beats_per_window = int(args.get("bars", 4)) * 4

    track_beats = source_beats = None
    if pair_probs is not None:
        track_beats, source_beats = _find_contiguous_beats(pair_probs, min_beats=2)
    # Without a trained pairwise head, or when no beat run was found,
    # highlight the whole window in both panels.
    if track_beats is None or not track_beats.any():
        track_beats = np.ones(beats_per_window, dtype=bool)
        source_beats = np.ones(beats_per_window, dtype=bool)

    spec_fig = _plot_spectrograms_with_mask(
        track_spec_path, source_spec_path,
        track_beats, source_beats,
        score, matched,
    )

    verdict = "Likely match" if matched else "No match"
    details = [
        f"**{verdict}**",
        f"Classifier score: `{score:.3f}` (threshold `{threshold:.2f}`).",
        f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
    ]
    if not pair_head_loaded:
        details.append("Checkpoint does not include a trained pairwise beat head.")
    return "\n\n".join(details), None, spec_fig
|
| 626 |
+
|
| 627 |
+
|
| 628 |
def preview_waveforms(track_audio, source_audio):
|
| 629 |
if not track_audio or not source_audio:
|
| 630 |
return None, None
|
|
|
|
| 719 |
gr.Markdown("# Sample Match Verifier")
|
| 720 |
gr.Markdown(
|
| 721 |
"Upload a track and a possible source sample. "
|
| 722 |
+
"Use the **Audio** tab for raw audio files, or the **Spectrogram** tab to upload "
|
| 723 |
+
"pre-computed BPM-normalized mel spectrogram images. "
|
| 724 |
+
"Click **Verify match** to run the model."
|
| 725 |
)
|
| 726 |
|
| 727 |
+
with gr.Tabs():
|
| 728 |
+
with gr.Tab("Audio"):
|
| 729 |
+
gr.Markdown("Waveforms appear immediately on upload.")
|
| 730 |
+
with gr.Row():
|
| 731 |
+
track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
|
| 732 |
+
source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])
|
| 733 |
+
audio_run = gr.Button("Verify match", variant="primary")
|
| 734 |
+
|
| 735 |
+
with gr.Tab("Spectrogram"):
|
| 736 |
+
gr.Markdown(
|
| 737 |
+
"Upload BPM-normalized mel spectrogram images (e.g. from `make_test_spectrograms.py`). "
|
| 738 |
+
"Offset / duration / stride settings are ignored in this mode."
|
| 739 |
+
)
|
| 740 |
+
with gr.Row():
|
| 741 |
+
track_spec = gr.Image(label="Track spectrogram", type="filepath", sources=["upload"])
|
| 742 |
+
source_spec = gr.Image(label="Source spectrogram", type="filepath", sources=["upload"])
|
| 743 |
+
spec_run = gr.Button("Verify match", variant="primary")
|
| 744 |
|
| 745 |
with gr.Accordion("Settings", open=False):
|
| 746 |
checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
|
|
|
|
| 755 |
stride_beats = gr.Slider(1, 16, value=4, step=1, label="Window stride, beats")
|
| 756 |
max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
|
| 757 |
|
|
|
|
| 758 |
result = gr.Markdown()
|
|
|
|
| 759 |
waveform_plot = gr.Plot(label="Waveforms")
|
| 760 |
mel_plot = gr.Plot(label="Mel Spectrograms")
|
| 761 |
|
| 762 |
+
# Show waveforms as soon as both audio files are uploaded
|
| 763 |
for audio_input in [track_audio, source_audio]:
|
| 764 |
audio_input.change(
|
| 765 |
preview_waveforms,
|
|
|
|
| 767 |
outputs=[waveform_plot, mel_plot],
|
| 768 |
)
|
| 769 |
|
| 770 |
+
audio_run.click(
|
| 771 |
verify,
|
| 772 |
inputs=[
|
| 773 |
track_audio,
|
|
|
|
| 784 |
outputs=[result, waveform_plot, mel_plot],
|
| 785 |
)
|
| 786 |
|
| 787 |
+
spec_run.click(
|
| 788 |
+
verify_spectrograms,
|
| 789 |
+
inputs=[track_spec, source_spec, checkpoint_path, match_threshold, localization_threshold],
|
| 790 |
+
outputs=[result, waveform_plot, mel_plot],
|
| 791 |
+
)
|
| 792 |
+
|
| 793 |
|
| 794 |
if __name__ == "__main__":
|
| 795 |
demo.queue(max_size=8).launch()
|