dayngerous committed on
Commit
df731c1
·
1 Parent(s): af29998

Select contiguous beat runs via diagonal Kadane search on pair similarity matrix

Browse files
Files changed (1) hide show
  1. app.py +61 -4
app.py CHANGED
@@ -311,6 +311,61 @@ def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -
311
  return _merge_intervals(intervals)
312
 
313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
  def _localize_match(
315
  model,
316
  track_mel: torch.Tensor,
@@ -332,15 +387,17 @@ def _localize_match(
332
  with torch.inference_mode():
333
  pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
334
 
335
- selected = pair_probs >= float(threshold)
336
- if not selected.any():
 
 
337
  top_k = min(6, pair_probs.size)
338
  flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
339
  selected = np.zeros_like(pair_probs, dtype=bool)
340
  selected.reshape(-1)[flat] = True
 
 
341
 
342
- track_mask = selected.any(axis=1)
343
- source_mask = selected.any(axis=0)
344
  track_regions = _intervals_from_mask(
345
  track_mask,
346
  track_window,
 
311
  return _merge_intervals(intervals)
312
 
313
 
314
def _find_contiguous_beats(
    pair_probs: np.ndarray,
    min_beats: int = 2,
) -> tuple[np.ndarray, np.ndarray]:
    """Find the best contiguous diagonal run in the beat similarity matrix.

    Every diagonal offset ``d = track_beat - source_beat`` is scanned. On each
    diagonal the contiguous segment with the largest sum *among segments of at
    least ``min_beats`` cells* is located; diagonals are then ranked by the
    mean value of their best segment and the top one wins.

    The length constraint is applied during the search rather than after it,
    so a single strong cell cannot hide a slightly weaker but sufficiently
    long run on the same diagonal.

    Note: when every value is non-negative (e.g. probabilities), extending a
    segment never lowers its sum, so the best segment of a diagonal is always
    the full diagonal. Center the scores before calling if shorter runs are
    meant to be able to win.

    Args:
        pair_probs: 2-D array indexed as ``[track_beat, source_beat]``.
        min_beats: Minimum number of consecutive beats a run must span.
            Values below 1 are treated as 1.

    Returns:
        ``(track_mask, source_mask)``: boolean arrays of length
        ``pair_probs.shape[0]`` and ``pair_probs.shape[1]`` marking the
        selected run. Both are all ``False`` when no diagonal is long enough.
    """
    n_track, n_source = pair_probs.shape
    min_len = max(1, int(min_beats))
    best_score = -np.inf
    best_track_mask = np.zeros(n_track, dtype=bool)
    best_source_mask = np.zeros(n_source, dtype=bool)

    for d in range(-(n_source - 1), n_track):
        # Diagonal d pairs track[i0 + k] with source[j0 + k] for k in [0, length).
        i0 = max(0, d)
        j0 = max(0, -d)
        length = min(n_track - i0, n_source - j0)
        if length < min_len:
            continue

        # Accumulate in float64 so low-precision inputs do not skew the sums.
        diag = np.asarray(
            pair_probs[i0:i0 + length, j0:j0 + length].diagonal(),
            dtype=np.float64,
        )

        # prefix[k] == diag[:k].sum(), so segment [s, e] sums to
        # prefix[e + 1] - prefix[s].
        prefix = np.concatenate(([0.0], np.cumsum(diag)))

        # For the t-th admissible end (e = t + min_len - 1) the start may be
        # any s <= t; the best start is the one with the smallest prefix.
        lowest_prefix = np.minimum.accumulate(prefix[:length - min_len + 1])
        segment_sums = prefix[min_len:] - lowest_prefix

        t = int(np.argmax(segment_sums))
        seg_end = t + min_len - 1
        seg_start = int(np.argmin(prefix[:t + 1]))
        seg_len = seg_end - seg_start + 1
        avg_score = float(segment_sums[t]) / seg_len

        if avg_score > best_score:
            best_score = avg_score
            best_track_mask = np.zeros(n_track, dtype=bool)
            best_source_mask = np.zeros(n_source, dtype=bool)
            best_track_mask[i0 + seg_start:i0 + seg_end + 1] = True
            best_source_mask[j0 + seg_start:j0 + seg_end + 1] = True

    return best_track_mask, best_source_mask
367
+
368
+
369
  def _localize_match(
370
  model,
371
  track_mel: torch.Tensor,
 
387
  with torch.inference_mode():
388
  pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
389
 
390
+ track_mask, source_mask = _find_contiguous_beats(pair_probs, min_beats=2)
391
+
392
+ # Fall back to top-k individual beats if no contiguous run was found
393
+ if not track_mask.any():
394
  top_k = min(6, pair_probs.size)
395
  flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
396
  selected = np.zeros_like(pair_probs, dtype=bool)
397
  selected.reshape(-1)[flat] = True
398
+ track_mask = selected.any(axis=1)
399
+ source_mask = selected.any(axis=0)
400
 
 
 
401
  track_regions = _intervals_from_mask(
402
  track_mask,
403
  track_window,