# percu-backend / backend/notation/beat_grid.py
# Synced from GitHub (via GitHub Actions) @ 55d21041cef6937693a261edde4a3fa17e5a12dc
# d5f474a
import logging
import numpy as np
log = logging.getLogger(__name__)

# Module-wide cache for the beat_this predictor; populated on first use.
_file2beats = None


def _get_predictor():
    """Return the shared beat_this ``File2Beats`` predictor.

    The predictor (checkpoint ``"small0"``, running on ``"cpu"``) is built the
    first time this is called and then reused. The ``beat_this`` import is
    deferred until that first call.
    """
    global _file2beats
    if _file2beats is not None:
        return _file2beats
    from beat_this.inference import File2Beats

    _file2beats = File2Beats(checkpoint_path="small0", device="cpu")
    return _file2beats
def _median_interval(times):
    """Return the median gap between consecutive entries of *times*, or None.

    None is returned when there are fewer than two entries, or when the median
    gap is not strictly positive (e.g. duplicated timestamps). Such a value
    cannot be used as a divisor.
    """
    if len(times) < 2:
        return None
    interval = float(np.median(np.diff(times)))
    # `interval > 0` is False for NaN too, so NaN also maps to None.
    return interval if interval > 0 else None


def get_beat_grid(
    audio_path: str,
    *,
    predictor=None,
) -> tuple[np.ndarray, np.ndarray, int, float]:
    """
    Detect beats and downbeats using beat-this (transformer model).

    Parameters
    ----------
    audio_path : path of the audio file passed to the predictor
    predictor  : optional callable ``predictor(audio_path) -> (beats, downbeats)``.
                 Defaults to the shared instance returned by ``_get_predictor()``.
                 Lets callers supply a differently configured model.

    Returns
    -------
    beat_times     : all beat positions in seconds (float array)
    downbeat_times : positions of bar starts in seconds (float array)
    beats_per_bar  : round(median bar length / median beat length), used for
                     time-signature display. A ratio <= 2, or a missing median,
                     falls back to 4.
    bpm            : 60 / median beat interval, or 120.0 when that interval is
                     unavailable (fewer than two beats, or a non-positive
                     median gap)

    Raises
    ------
    ValueError : if the predictor returns no beats or no downbeats
    """
    if predictor is None:
        predictor = _get_predictor()
    raw_beats, raw_downbeats = predictor(audio_path)
    beats = np.asarray(raw_beats, dtype=float)
    downbeats = np.asarray(raw_downbeats, dtype=float)
    if len(downbeats) == 0 or len(beats) == 0:
        raise ValueError("beat-this could not detect beats in audio.")

    # Computed once: the median beat interval drives both meter and tempo.
    beat_interval = _median_interval(beats)
    bar_interval = _median_interval(downbeats)

    if beat_interval is None and len(beats) > 1:
        # Non-positive median gap: dividing by it would raise or yield inf.
        log.warning("beat-this: degenerate beat grid; using default meter/tempo")

    # Meter from the ratio of bar duration to beat duration. Ratios of 2 or
    # less are treated as unreliable and mapped to 4 beats per bar.
    beats_per_bar = 4
    if beat_interval is not None and bar_interval is not None:
        ratio = int(np.round(bar_interval / beat_interval))
        if ratio > 2:
            beats_per_bar = ratio

    bpm = 60.0 / beat_interval if beat_interval is not None else 120.0
    log.info("beat-this: BPM=%.1f meter=%d/4 beats=%d downbeats=%d", bpm, beats_per_bar, len(beats), len(downbeats))
    return beats, downbeats, beats_per_bar, bpm