dayngerous commited on
Commit
c2160cd
·
1 Parent(s): df731c1

Add spectrogram image input tab

Browse files
Files changed (1) hide show
  1. app.py +152 -9
app.py CHANGED
@@ -501,6 +501,130 @@ def _plot_mels(
501
  return fig
502
 
503
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  def preview_waveforms(track_audio, source_audio):
505
  if not track_audio or not source_audio:
506
  return None, None
@@ -595,13 +719,28 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
595
  gr.Markdown("# Sample Match Verifier")
596
  gr.Markdown(
597
  "Upload a track and a possible source sample. "
598
- "Waveforms appear immediately on upload. "
599
- "Click **Verify match** to run the model and highlight sampled sections."
 
600
  )
601
 
602
- with gr.Row():
603
- track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
604
- source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
 
606
  with gr.Accordion("Settings", open=False):
607
  checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
@@ -616,13 +755,11 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
616
  stride_beats = gr.Slider(1, 16, value=4, step=1, label="Window stride, beats")
617
  max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
618
 
619
- run = gr.Button("Verify match", variant="primary")
620
  result = gr.Markdown()
621
-
622
  waveform_plot = gr.Plot(label="Waveforms")
623
  mel_plot = gr.Plot(label="Mel Spectrograms")
624
 
625
- # Show waveforms as soon as both files are uploaded
626
  for audio_input in [track_audio, source_audio]:
627
  audio_input.change(
628
  preview_waveforms,
@@ -630,7 +767,7 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
630
  outputs=[waveform_plot, mel_plot],
631
  )
632
 
633
- run.click(
634
  verify,
635
  inputs=[
636
  track_audio,
@@ -647,6 +784,12 @@ with gr.Blocks(title="Sample Match Verifier") as demo:
647
  outputs=[result, waveform_plot, mel_plot],
648
  )
649
 
 
 
 
 
 
 
650
 
651
  if __name__ == "__main__":
652
  demo.queue(max_size=8).launch()
 
501
  return fig
502
 
503
 
504
+ def _image_to_mel_tensor(image_path: str, args: dict) -> torch.Tensor:
505
+ """Reconstruct the model's input tensor from a saved BPM-normalized mel spectrogram PNG."""
506
+ from PIL import Image as PILImage
507
+ n_mels = int(args.get("n_mels", 128))
508
+ bars = int(args.get("bars", 4))
509
+ fixed_frames = bars * 4 * TARGET_FRAMES_PER_BEAT
510
+
511
+ img = PILImage.open(image_path).convert("RGB")
512
+ img = img.resize((fixed_frames, n_mels), PILImage.LANCZOS)
513
+ arr = np.array(img, dtype=np.float32) / 255.0 # [n_mels, fixed_frames, 3]
514
+
515
+ # Invert magma via luminance — monotone proxy for the original mel value
516
+ luminance = 0.2126 * arr[:, :, 0] + 0.7152 * arr[:, :, 1] + 0.0722 * arr[:, :, 2]
517
+ # PNG rows are top-to-bottom; origin="lower" means row 0 in data = bottom of image
518
+ luminance = luminance[::-1] # flip so row 0 = lowest mel bin
519
+ mel = torch.from_numpy(luminance.T.copy()).float() # [fixed_frames, n_mels]
520
+ mel = (mel - mel.mean()) / (mel.std() + 1e-6)
521
+ return mel.unsqueeze(0) # [1, fixed_frames, n_mels]
522
+
523
+
524
+ def _plot_spectrograms_with_mask(
525
+ track_img_path: str,
526
+ source_img_path: str,
527
+ track_beats: np.ndarray,
528
+ source_beats: np.ndarray,
529
+ score: float,
530
+ matched: bool,
531
+ ) -> plt.Figure:
532
+ from PIL import Image as PILImage
533
+ color = "#22c55e" if matched else "#f59e0b"
534
+ fig, axes = plt.subplots(2, 1, figsize=(12, 5))
535
+ fig.suptitle(f"Score: {score:.3f}", fontsize=12)
536
+
537
+ for ax, img_path, label, beats in [
538
+ (axes[0], track_img_path, "Track spectrogram", track_beats),
539
+ (axes[1], source_img_path, "Source spectrogram", source_beats),
540
+ ]:
541
+ img = np.array(PILImage.open(img_path).convert("RGB"))
542
+ W = img.shape[1]
543
+ ax.imshow(img, aspect="auto")
544
+ ax.set_title(label, loc="left", fontsize=10)
545
+ ax.set_xlabel("Time frame (BPM-normalized)")
546
+ ax.set_ylabel("Mel bin")
547
+ ax.tick_params(labelsize=7)
548
+
549
+ if beats is not None and beats.any():
550
+ n_beats = len(beats)
551
+ beat_w = W / n_beats
552
+ for i, active in enumerate(beats):
553
+ if active:
554
+ ax.axvspan(i * beat_w, (i + 1) * beat_w, color=color, alpha=0.38, linewidth=0)
555
+
556
+ if not matched:
557
+ ax.text(0.5, 0.5, "No Match", transform=ax.transAxes,
558
+ fontsize=18, color="white", ha="center", va="center", fontweight="bold",
559
+ bbox=dict(boxstyle="round,pad=0.4", facecolor="#111827", alpha=0.65))
560
+
561
+ fig.tight_layout()
562
+ return fig
563
+
564
+
565
+ def verify_spectrograms(
566
+ track_spec_path,
567
+ source_spec_path,
568
+ checkpoint_path,
569
+ match_threshold,
570
+ localization_threshold,
571
+ ):
572
+ if not track_spec_path or not source_spec_path:
573
+ raise gr.Error("Upload both spectrogram images before running verification.")
574
+
575
+ try:
576
+ loaded = _load_model(checkpoint_path or DEFAULT_CHECKPOINT)
577
+ except Exception as exc:
578
+ return f"Model could not be loaded: {exc}", None, None
579
+
580
+ model = loaded["model"]
581
+ args = loaded["args"]
582
+ device = loaded["device"]
583
+
584
+ track_mel = _image_to_mel_tensor(track_spec_path, args).unsqueeze(0).to(device)
585
+ source_mel = _image_to_mel_tensor(source_spec_path, args).unsqueeze(0).to(device)
586
+
587
+ with torch.inference_mode():
588
+ track_emb = model.encoder(track_mel)
589
+ source_emb = model.encoder(source_mel)
590
+ pair_feat = pair_summary_features(model.pair_mask_head(track_mel, source_mel))
591
+ combined = torch.cat(
592
+ [track_emb, source_emb, torch.abs(track_emb - source_emb), track_emb * source_emb, pair_feat],
593
+ dim=-1,
594
+ )
595
+ score = torch.softmax(model.head(combined), dim=-1)[0, 1].item()
596
+
597
+ matched = score >= float(match_threshold)
598
+ beats_per_window = int(args.get("bars", 4)) * 4
599
+
600
+ if loaded["pair_head_loaded"]:
601
+ with torch.inference_mode():
602
+ pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].cpu().numpy()
603
+ track_beats, source_beats = _find_contiguous_beats(pair_probs, min_beats=2)
604
+ if not track_beats.any():
605
+ track_beats = np.ones(beats_per_window, dtype=bool)
606
+ source_beats = np.ones(beats_per_window, dtype=bool)
607
+ else:
608
+ track_beats = np.ones(beats_per_window, dtype=bool)
609
+ source_beats = np.ones(beats_per_window, dtype=bool)
610
+
611
+ spec_fig = _plot_spectrograms_with_mask(
612
+ track_spec_path, source_spec_path,
613
+ track_beats, source_beats,
614
+ score, matched,
615
+ )
616
+
617
+ verdict = "Likely match" if matched else "No match"
618
+ details = [
619
+ f"**{verdict}**",
620
+ f"Classifier score: `{score:.3f}` (threshold `{float(match_threshold):.2f}`).",
621
+ f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
622
+ ]
623
+ if not loaded["pair_head_loaded"]:
624
+ details.append("Checkpoint does not include a trained pairwise beat head.")
625
+ return "\n\n".join(details), None, spec_fig
626
+
627
+
628
  def preview_waveforms(track_audio, source_audio):
629
  if not track_audio or not source_audio:
630
  return None, None
 
719
  gr.Markdown("# Sample Match Verifier")
720
  gr.Markdown(
721
  "Upload a track and a possible source sample. "
722
+ "Use the **Audio** tab for raw audio files, or the **Spectrogram** tab to upload "
723
+ "pre-computed BPM-normalized mel spectrogram images. "
724
+ "Click **Verify match** to run the model."
725
  )
726
 
727
+ with gr.Tabs():
728
+ with gr.Tab("Audio"):
729
+ gr.Markdown("Waveforms appear immediately on upload.")
730
+ with gr.Row():
731
+ track_audio = gr.Audio(label="Track / song audio", type="filepath", sources=["upload"])
732
+ source_audio = gr.Audio(label="Source sample audio", type="filepath", sources=["upload"])
733
+ audio_run = gr.Button("Verify match", variant="primary")
734
+
735
+ with gr.Tab("Spectrogram"):
736
+ gr.Markdown(
737
+ "Upload BPM-normalized mel spectrogram images (e.g. from `make_test_spectrograms.py`). "
738
+ "Offset / duration / stride settings are ignored in this mode."
739
+ )
740
+ with gr.Row():
741
+ track_spec = gr.Image(label="Track spectrogram", type="filepath", sources=["upload"])
742
+ source_spec = gr.Image(label="Source spectrogram", type="filepath", sources=["upload"])
743
+ spec_run = gr.Button("Verify match", variant="primary")
744
 
745
  with gr.Accordion("Settings", open=False):
746
  checkpoint_path = gr.Textbox(label="Checkpoint path", value=DEFAULT_CHECKPOINT)
 
755
  stride_beats = gr.Slider(1, 16, value=4, step=1, label="Window stride, beats")
756
  max_windows = gr.Slider(4, 64, value=24, step=1, label="Max windows per upload")
757
 
 
758
  result = gr.Markdown()
 
759
  waveform_plot = gr.Plot(label="Waveforms")
760
  mel_plot = gr.Plot(label="Mel Spectrograms")
761
 
762
+ # Show waveforms as soon as both audio files are uploaded
763
  for audio_input in [track_audio, source_audio]:
764
  audio_input.change(
765
  preview_waveforms,
 
767
  outputs=[waveform_plot, mel_plot],
768
  )
769
 
770
+ audio_run.click(
771
  verify,
772
  inputs=[
773
  track_audio,
 
784
  outputs=[result, waveform_plot, mel_plot],
785
  )
786
 
787
+ spec_run.click(
788
+ verify_spectrograms,
789
+ inputs=[track_spec, source_spec, checkpoint_path, match_threshold, localization_threshold],
790
+ outputs=[result, waveform_plot, mel_plot],
791
+ )
792
+
793
 
794
  if __name__ == "__main__":
795
  demo.queue(max_size=8).launch()