dayngerous commited on
Commit
0a95bc3
·
1 Parent(s): 20c4cc2

Use classifier head for match verdict, show proposed masks on no-match

Browse files
Files changed (2) hide show
  1. app.py +32 -26
  2. model.py +61 -19
app.py CHANGED
@@ -18,7 +18,7 @@ from huggingface_hub import hf_hub_download
18
  matplotlib.use("Agg")
19
  import matplotlib.pyplot as plt
20
 
21
- from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector
22
 
23
 
24
  SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000"))
@@ -282,24 +282,33 @@ def _encode(model, mels: torch.Tensor, batch_size: int) -> torch.Tensor:
282
  return torch.cat(embs, dim=0)
283
 
284
 
285
- def _score_pairs(model, track_mels: torch.Tensor, source_mels: torch.Tensor, batch_size: int) -> torch.Tensor:
 
 
 
 
 
 
286
  track_emb = _encode(model, track_mels, batch_size)
287
  source_emb = _encode(model, source_mels, batch_size)
288
  n_track, n_source = track_emb.shape[0], source_emb.shape[0]
289
- scores = []
290
 
291
- pair_indices = [(i, j) for i in range(n_track) for j in range(n_source)]
292
- for start in range(0, len(pair_indices), batch_size):
293
- chunk = pair_indices[start:start + batch_size]
294
- ti = torch.tensor([p[0] for p in chunk], device=track_emb.device)
295
- sj = torch.tensor([p[1] for p in chunk], device=track_emb.device)
296
- t = track_emb.index_select(0, ti)
297
- s = source_emb.index_select(0, sj)
298
- combined = torch.cat([t, s, torch.abs(t - s), t * s], dim=-1)
299
- logits = model.head(combined)
300
- scores.append(torch.softmax(logits, dim=-1)[:, 1])
 
 
 
301
 
302
- return torch.cat(scores).reshape(n_track, n_source)
303
 
304
 
305
  def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]:
@@ -390,10 +399,10 @@ def _draw_mel(ax, clip: AudioClip, regions: list[tuple[float, float]], color: st
390
  ax.set_ylabel("Frequency (Hz)")
391
  ax.set_xlim(t_start, t_end)
392
 
393
- if matched and regions:
394
  for start, end in regions:
395
- ax.axvspan(start, end, color=color, alpha=0.38, linewidth=0)
396
- elif not matched:
397
  ax.text(
398
  0.5, 0.5, "No Match",
399
  transform=ax.transAxes,
@@ -495,7 +504,7 @@ def verify(
495
  source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device)
496
 
497
  with torch.inference_mode():
498
- score_matrix = _score_pairs(model, track_mels, source_mels, batch_size)
499
  best_flat = int(torch.argmax(score_matrix).item())
500
  best_track = best_flat // score_matrix.shape[1]
501
  best_source = best_flat % score_matrix.shape[1]
@@ -514,19 +523,16 @@ def verify(
514
  loaded["pair_head_loaded"],
515
  )
516
 
517
- highlight_track = track_regions if matched else []
518
- highlight_source = source_regions if matched else []
519
-
520
- wfig = _plot_waveforms(track_clip, source_clip, highlight_track, highlight_source, best_score, matched)
521
- mfig = _plot_mels(track_clip, source_clip, highlight_track, highlight_source, matched)
522
 
523
- verdict = "Likely match" if matched else "No confident match"
524
  details = [
525
  f"**{verdict}**",
526
  f"Score: `{best_score:.3f}` with threshold `{float(match_threshold):.2f}`.",
527
  f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.",
528
- f"Highlighted track section(s): {_format_intervals(highlight_track)}.",
529
- f"Highlighted source section(s): {_format_intervals(highlight_source)}.",
530
  f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
531
  ]
532
  if note:
 
18
  matplotlib.use("Agg")
19
  import matplotlib.pyplot as plt
20
 
21
+ from model import CNNSampleDetector, SSLAMSampleDetector, SampleDetector, pair_summary_features
22
 
23
 
24
  SAMPLE_RATE = int(os.environ.get("APP_SAMPLE_RATE", "16000"))
 
282
  return torch.cat(embs, dim=0)
283
 
284
 
285
+ def _score_pairs(
286
+ model,
287
+ track_mels: torch.Tensor,
288
+ source_mels: torch.Tensor,
289
+ batch_size: int,
290
+ pair_head_loaded: bool,
291
+ ) -> torch.Tensor:
292
  track_emb = _encode(model, track_mels, batch_size)
293
  source_emb = _encode(model, source_mels, batch_size)
294
  n_track, n_source = track_emb.shape[0], source_emb.shape[0]
295
+ scores = torch.zeros(n_track, n_source, device=track_emb.device)
296
 
297
+ for i in range(n_track):
298
+ for j in range(n_source):
299
+ t = track_emb[i:i + 1]
300
+ s = source_emb[j:j + 1]
301
+ if pair_head_loaded:
302
+ pair_feat = pair_summary_features(
303
+ model.pair_mask_head(track_mels[i:i + 1], source_mels[j:j + 1])
304
+ )
305
+ combined = torch.cat([t, s, torch.abs(t - s), t * s, pair_feat], dim=-1)
306
+ else:
307
+ combined = torch.cat([t, s, torch.abs(t - s), t * s], dim=-1)
308
+ logits = model.head(combined)
309
+ scores[i, j] = torch.softmax(logits, dim=-1)[0, 1]
310
 
311
+ return scores
312
 
313
 
314
  def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -> list[tuple[float, float]]:
 
399
  ax.set_ylabel("Frequency (Hz)")
400
  ax.set_xlim(t_start, t_end)
401
 
402
+ if regions:
403
  for start, end in regions:
404
+ ax.axvspan(start, end, color=color, alpha=0.38 if matched else 0.22, linewidth=0)
405
+ if not matched:
406
  ax.text(
407
  0.5, 0.5, "No Match",
408
  transform=ax.transAxes,
 
504
  source_mels = torch.stack([_to_mel(w.waveform, source_bpm, args) for w in source_windows]).to(device)
505
 
506
  with torch.inference_mode():
507
+ score_matrix = _score_pairs(model, track_mels, source_mels, batch_size, loaded["pair_head_loaded"])
508
  best_flat = int(torch.argmax(score_matrix).item())
509
  best_track = best_flat // score_matrix.shape[1]
510
  best_source = best_flat % score_matrix.shape[1]
 
523
  loaded["pair_head_loaded"],
524
  )
525
 
526
+ wfig = _plot_waveforms(track_clip, source_clip, track_regions, source_regions, best_score, matched)
527
+ mfig = _plot_mels(track_clip, source_clip, track_regions, source_regions, matched)
 
 
 
528
 
529
+ verdict = "Likely match" if matched else "No match"
530
  details = [
531
  f"**{verdict}**",
532
  f"Score: `{best_score:.3f}` with threshold `{float(match_threshold):.2f}`.",
533
  f"Estimated BPM: track `{track_bpm:.1f}`, source `{source_bpm:.1f}`.",
534
+ f"{'Matched' if matched else 'Proposed'} track section(s): {_format_intervals(track_regions)}.",
535
+ f"{'Matched' if matched else 'Proposed'} source section(s): {_format_intervals(source_regions)}.",
536
  f"Model: `{args.get('backbone', 'ast')}` checkpoint epoch `{loaded['epoch']}` on `{device}`.",
537
  ]
538
  if note:
model.py CHANGED
@@ -14,6 +14,7 @@ AST_FREQ_DIM = 128
14
  SSLAM_HF_REPO = os.environ["SSLAM_MODEL"]
15
  SSLAM_TIME_DIM = 1024
16
  SSLAM_FREQ_DIM = 128
 
17
 
18
 
19
  class ASTEncoder(nn.Module):
@@ -61,11 +62,21 @@ class ASTEncoder(nn.Module):
61
  class PairMaskHead(nn.Module):
62
  """Beat-by-beat pair matching head over two mel spectrograms."""
63
 
64
- def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64):
65
  super().__init__()
66
- self.pool = nn.AdaptiveAvgPool2d((beats_per_window, n_mels))
67
- self.beat_proj = nn.Sequential(
68
- nn.Linear(n_mels, beat_dim),
 
 
 
 
 
 
 
 
 
 
69
  nn.GELU(),
70
  nn.Linear(beat_dim, beat_dim),
71
  )
@@ -73,9 +84,14 @@ class PairMaskHead(nn.Module):
73
  self.bias = nn.Parameter(torch.zeros(()))
74
 
75
  def _beats(self, mel: torch.Tensor) -> torch.Tensor:
76
- # mel: [B, 1, T, F] -> [B, beats, F] -> [B, beats, beat_dim]
77
- x = self.pool(mel).squeeze(1)
78
- return torch.nn.functional.normalize(self.beat_proj(x), dim=-1)
 
 
 
 
 
79
 
80
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
81
  t = self._beats(track_mel)
@@ -83,6 +99,29 @@ class PairMaskHead(nn.Module):
83
  return torch.einsum("bih,bjh->bij", t, o) * self.logit_scale.exp() + self.bias
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  class SampleDetector(nn.Module):
87
  """Siamese AST encoder + interaction head for binary sample detection."""
88
 
@@ -97,9 +136,10 @@ class SampleDetector(nn.Module):
97
  super().__init__()
98
  self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
99
  H = self.encoder.ast.config.hidden_size
 
100
  self.head = nn.Sequential(
101
- nn.LayerNorm(4 * H),
102
- nn.Linear(4 * H, 512),
103
  nn.GELU(),
104
  nn.Dropout(dropout),
105
  nn.Linear(512, 128),
@@ -107,7 +147,6 @@ class SampleDetector(nn.Module):
107
  nn.Dropout(dropout),
108
  nn.Linear(128, 2),
109
  )
110
- self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
111
 
112
  def unfreeze_encoder(self, n_blocks: int = 2):
113
  self.encoder.unfreeze_last_n(n_blocks)
@@ -115,8 +154,9 @@ class SampleDetector(nn.Module):
115
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
116
  t = self.encoder(track_mel)
117
  o = self.encoder(orig_mel)
 
118
  # print(f"embeddings: t={t.shape}, o={o.shape}")
119
- combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
120
  # print(f"combined shape: {combined.shape}")
121
  return self.head(combined)
122
 
@@ -160,9 +200,10 @@ class CNNSampleDetector(nn.Module):
160
  def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
161
  super().__init__()
162
  self.encoder = CNNEncoder(embed_dim)
 
163
  self.head = nn.Sequential(
164
- nn.LayerNorm(4 * embed_dim),
165
- nn.Linear(4 * embed_dim, 256),
166
  nn.GELU(),
167
  nn.Dropout(dropout),
168
  nn.Linear(256, 64),
@@ -170,12 +211,12 @@ class CNNSampleDetector(nn.Module):
170
  nn.Dropout(dropout),
171
  nn.Linear(64, 2),
172
  )
173
- self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
174
 
175
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
176
  t = self.encoder(track_mel)
177
  o = self.encoder(orig_mel)
178
- combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
 
179
  return self.head(combined)
180
 
181
 
@@ -257,9 +298,10 @@ class SSLAMSampleDetector(nn.Module):
257
  super().__init__()
258
  self.encoder = SSLAMEncoder(freeze=freeze_encoder)
259
  H = self.encoder.hidden_size
 
260
  self.head = nn.Sequential(
261
- nn.LayerNorm(4 * H),
262
- nn.Linear(4 * H, 512),
263
  nn.GELU(),
264
  nn.Dropout(dropout),
265
  nn.Linear(512, 128),
@@ -267,7 +309,6 @@ class SSLAMSampleDetector(nn.Module):
267
  nn.Dropout(dropout),
268
  nn.Linear(128, 2),
269
  )
270
- self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
271
 
272
  def unfreeze_encoder(self, n_blocks: int):
273
  self.encoder.unfreeze_last_n(n_blocks)
@@ -275,7 +316,8 @@ class SSLAMSampleDetector(nn.Module):
275
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
276
  t = self.encoder(track_mel)
277
  o = self.encoder(orig_mel)
278
- combined = torch.cat([t, o, torch.abs(t - o), t * o], dim=-1)
 
279
  return self.head(combined)
280
 
281
 
 
14
  SSLAM_HF_REPO = os.environ["SSLAM_MODEL"]
15
  SSLAM_TIME_DIM = 1024
16
  SSLAM_FREQ_DIM = 128
17
+ PAIR_SUMMARY_DIM = 8
18
 
19
 
20
  class ASTEncoder(nn.Module):
 
62
  class PairMaskHead(nn.Module):
63
  """Beat-by-beat pair matching head over two mel spectrograms."""
64
 
65
+ def __init__(self, beats_per_window: int, n_mels: int, beat_dim: int = 64, frames_per_beat: int = 8):
66
  super().__init__()
67
+ self.beats_per_window = beats_per_window
68
+ self.frames_per_beat = frames_per_beat
69
+ self.pool = nn.AdaptiveAvgPool2d((beats_per_window * frames_per_beat, n_mels))
70
+ self.patch_encoder = nn.Sequential(
71
+ nn.Conv2d(1, 16, kernel_size=(3, 5), padding=(1, 2), bias=False),
72
+ nn.GroupNorm(4, 16),
73
+ nn.GELU(),
74
+ nn.Conv2d(16, 32, kernel_size=(3, 5), stride=(2, 2), padding=(1, 2), bias=False),
75
+ nn.GroupNorm(8, 32),
76
+ nn.GELU(),
77
+ nn.AdaptiveAvgPool2d(1),
78
+ nn.Flatten(),
79
+ nn.Linear(32, beat_dim),
80
  nn.GELU(),
81
  nn.Linear(beat_dim, beat_dim),
82
  )
 
84
  self.bias = nn.Parameter(torch.zeros(()))
85
 
86
  def _beats(self, mel: torch.Tensor) -> torch.Tensor:
87
+ # mel: [B, 1, T, F] -> [B * beats, 1, frames_per_beat, F]
88
+ bsz = mel.shape[0]
89
+ x = self.pool(mel)
90
+ x = x.view(bsz, 1, self.beats_per_window, self.frames_per_beat, x.shape[-1])
91
+ x = x.permute(0, 2, 1, 3, 4).contiguous()
92
+ x = x.view(bsz * self.beats_per_window, 1, self.frames_per_beat, x.shape[-1])
93
+ x = self.patch_encoder(x).view(bsz, self.beats_per_window, -1)
94
+ return torch.nn.functional.normalize(x, dim=-1)
95
 
96
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
97
  t = self._beats(track_mel)
 
99
  return torch.einsum("bih,bjh->bij", t, o) * self.logit_scale.exp() + self.bias
100
 
101
 
102
+ def pair_summary_features(pair_logits: torch.Tensor) -> torch.Tensor:
103
+ probs = torch.sigmoid(pair_logits)
104
+ flat = probs.flatten(1)
105
+ row_max = probs.max(dim=2).values
106
+ col_max = probs.max(dim=1).values
107
+ diag = torch.diagonal(probs, dim1=1, dim2=2)
108
+ top_k = min(8, flat.shape[1])
109
+ topk_mean = flat.topk(top_k, dim=1).values.mean(dim=1)
110
+ return torch.stack(
111
+ [
112
+ flat.mean(dim=1),
113
+ flat.max(dim=1).values,
114
+ flat.std(dim=1, unbiased=False),
115
+ topk_mean,
116
+ row_max.mean(dim=1),
117
+ row_max.max(dim=1).values,
118
+ col_max.mean(dim=1),
119
+ diag.mean(dim=1),
120
+ ],
121
+ dim=-1,
122
+ )
123
+
124
+
125
  class SampleDetector(nn.Module):
126
  """Siamese AST encoder + interaction head for binary sample detection."""
127
 
 
136
  super().__init__()
137
  self.encoder = ASTEncoder(model_name, freeze=freeze_encoder)
138
  H = self.encoder.ast.config.hidden_size
139
+ self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
140
  self.head = nn.Sequential(
141
+ nn.LayerNorm(4 * H + PAIR_SUMMARY_DIM),
142
+ nn.Linear(4 * H + PAIR_SUMMARY_DIM, 512),
143
  nn.GELU(),
144
  nn.Dropout(dropout),
145
  nn.Linear(512, 128),
 
147
  nn.Dropout(dropout),
148
  nn.Linear(128, 2),
149
  )
 
150
 
151
  def unfreeze_encoder(self, n_blocks: int = 2):
152
  self.encoder.unfreeze_last_n(n_blocks)
 
154
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
155
  t = self.encoder(track_mel)
156
  o = self.encoder(orig_mel)
157
+ pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
158
  # print(f"embeddings: t={t.shape}, o={o.shape}")
159
+ combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1)
160
  # print(f"combined shape: {combined.shape}")
161
  return self.head(combined)
162
 
 
200
  def __init__(self, embed_dim: int = 256, dropout: float = 0.3, beats_per_window: int = 16, n_mels: int = 128):
201
  super().__init__()
202
  self.encoder = CNNEncoder(embed_dim)
203
+ self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
204
  self.head = nn.Sequential(
205
+ nn.LayerNorm(4 * embed_dim + PAIR_SUMMARY_DIM),
206
+ nn.Linear(4 * embed_dim + PAIR_SUMMARY_DIM, 256),
207
  nn.GELU(),
208
  nn.Dropout(dropout),
209
  nn.Linear(256, 64),
 
211
  nn.Dropout(dropout),
212
  nn.Linear(64, 2),
213
  )
 
214
 
215
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
216
  t = self.encoder(track_mel)
217
  o = self.encoder(orig_mel)
218
+ pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
219
+ combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1)
220
  return self.head(combined)
221
 
222
 
 
298
  super().__init__()
299
  self.encoder = SSLAMEncoder(freeze=freeze_encoder)
300
  H = self.encoder.hidden_size
301
+ self.pair_mask_head = PairMaskHead(beats_per_window, n_mels)
302
  self.head = nn.Sequential(
303
+ nn.LayerNorm(4 * H + PAIR_SUMMARY_DIM),
304
+ nn.Linear(4 * H + PAIR_SUMMARY_DIM, 512),
305
  nn.GELU(),
306
  nn.Dropout(dropout),
307
  nn.Linear(512, 128),
 
309
  nn.Dropout(dropout),
310
  nn.Linear(128, 2),
311
  )
 
312
 
313
  def unfreeze_encoder(self, n_blocks: int):
314
  self.encoder.unfreeze_last_n(n_blocks)
 
316
  def forward(self, track_mel: torch.Tensor, orig_mel: torch.Tensor) -> torch.Tensor:
317
  t = self.encoder(track_mel)
318
  o = self.encoder(orig_mel)
319
+ pair_features = pair_summary_features(self.pair_mask_head(track_mel, orig_mel))
320
+ combined = torch.cat([t, o, torch.abs(t - o), t * o, pair_features], dim=-1)
321
  return self.head(combined)
322
 
323