Spaces:
Sleeping
Sleeping
Commit ·
df731c1
1
Parent(s): af29998
Select contiguous beat runs via diagonal Kadane search on pair similarity matrix
Browse files
app.py
CHANGED
|
@@ -311,6 +311,61 @@ def _intervals_from_mask(mask: np.ndarray, window: BeatWindow, max_end: float) -
|
|
| 311 |
return _merge_intervals(intervals)
|
| 312 |
|
| 313 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
def _localize_match(
|
| 315 |
model,
|
| 316 |
track_mel: torch.Tensor,
|
|
@@ -332,15 +387,17 @@ def _localize_match(
|
|
| 332 |
with torch.inference_mode():
|
| 333 |
pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
|
| 334 |
|
| 335 |
-
|
| 336 |
-
|
|
|
|
|
|
|
| 337 |
top_k = min(6, pair_probs.size)
|
| 338 |
flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
|
| 339 |
selected = np.zeros_like(pair_probs, dtype=bool)
|
| 340 |
selected.reshape(-1)[flat] = True
|
|
|
|
|
|
|
| 341 |
|
| 342 |
-
track_mask = selected.any(axis=1)
|
| 343 |
-
source_mask = selected.any(axis=0)
|
| 344 |
track_regions = _intervals_from_mask(
|
| 345 |
track_mask,
|
| 346 |
track_window,
|
|
|
|
| 311 |
return _merge_intervals(intervals)
|
| 312 |
|
| 313 |
|
| 314 |
+
def _find_contiguous_beats(pair_probs: np.ndarray, min_beats: int = 2) -> tuple[np.ndarray, np.ndarray]:
|
| 315 |
+
"""Find the best contiguous diagonal run in the beat similarity matrix.
|
| 316 |
+
|
| 317 |
+
Searches every diagonal offset (track_beat - source_beat) and uses
|
| 318 |
+
Kadane's algorithm to find the highest-scoring contiguous segment along
|
| 319 |
+
each diagonal. Returns boolean masks over track and source beats.
|
| 320 |
+
"""
|
| 321 |
+
n_track, n_source = pair_probs.shape
|
| 322 |
+
best_score = -np.inf
|
| 323 |
+
best_track_mask = np.zeros(n_track, dtype=bool)
|
| 324 |
+
best_source_mask = np.zeros(n_source, dtype=bool)
|
| 325 |
+
|
| 326 |
+
for d in range(-(n_source - 1), n_track):
|
| 327 |
+
# diagonal: track[i], source[i - d] for valid i
|
| 328 |
+
i0 = max(0, d)
|
| 329 |
+
j0 = max(0, -d)
|
| 330 |
+
length = min(n_track - i0, n_source - j0)
|
| 331 |
+
if length < min_beats:
|
| 332 |
+
continue
|
| 333 |
+
|
| 334 |
+
diag = pair_probs[i0:i0 + length, j0:j0 + length].diagonal()
|
| 335 |
+
|
| 336 |
+
# Kadane's max-subarray on the diagonal values
|
| 337 |
+
curr_sum = 0.0
|
| 338 |
+
curr_start = 0
|
| 339 |
+
best_sum = -np.inf
|
| 340 |
+
seg_start = seg_end = 0
|
| 341 |
+
|
| 342 |
+
for k, val in enumerate(diag):
|
| 343 |
+
curr_sum += val
|
| 344 |
+
if curr_sum > best_sum:
|
| 345 |
+
best_sum = curr_sum
|
| 346 |
+
seg_start = curr_start
|
| 347 |
+
seg_end = k
|
| 348 |
+
if curr_sum < 0:
|
| 349 |
+
curr_sum = 0.0
|
| 350 |
+
curr_start = k + 1
|
| 351 |
+
|
| 352 |
+
seg_len = seg_end - seg_start + 1
|
| 353 |
+
if seg_len < min_beats:
|
| 354 |
+
continue
|
| 355 |
+
avg_score = best_sum / seg_len
|
| 356 |
+
|
| 357 |
+
if avg_score > best_score:
|
| 358 |
+
best_score = avg_score
|
| 359 |
+
track_mask = np.zeros(n_track, dtype=bool)
|
| 360 |
+
source_mask = np.zeros(n_source, dtype=bool)
|
| 361 |
+
track_mask[i0 + seg_start: i0 + seg_end + 1] = True
|
| 362 |
+
source_mask[j0 + seg_start: j0 + seg_end + 1] = True
|
| 363 |
+
best_track_mask = track_mask
|
| 364 |
+
best_source_mask = source_mask
|
| 365 |
+
|
| 366 |
+
return best_track_mask, best_source_mask
|
| 367 |
+
|
| 368 |
+
|
| 369 |
def _localize_match(
|
| 370 |
model,
|
| 371 |
track_mel: torch.Tensor,
|
|
|
|
| 387 |
with torch.inference_mode():
|
| 388 |
pair_probs = torch.sigmoid(model.pair_mask_head(track_mel, source_mel))[0].detach().cpu().numpy()
|
| 389 |
|
| 390 |
+
track_mask, source_mask = _find_contiguous_beats(pair_probs, min_beats=2)
|
| 391 |
+
|
| 392 |
+
# Fall back to top-k individual beats if no contiguous run was found
|
| 393 |
+
if not track_mask.any():
|
| 394 |
top_k = min(6, pair_probs.size)
|
| 395 |
flat = np.argpartition(pair_probs.reshape(-1), -top_k)[-top_k:]
|
| 396 |
selected = np.zeros_like(pair_probs, dtype=bool)
|
| 397 |
selected.reshape(-1)[flat] = True
|
| 398 |
+
track_mask = selected.any(axis=1)
|
| 399 |
+
source_mask = selected.any(axis=0)
|
| 400 |
|
|
|
|
|
|
|
| 401 |
track_regions = _intervals_from_mask(
|
| 402 |
track_mask,
|
| 403 |
track_window,
|