"""Beat and downbeat tracking via beat_this (CPJKU).""" from dataclasses import dataclass import numpy as np @dataclass class BeatData: beats: np.ndarray downbeats: np.ndarray beat_numbers: np.ndarray def track_beats(audio_path: str, device: str = "cuda") -> BeatData: """Run beat and downbeat tracking on an audio file.""" from beat_this.inference import File2Beats processor = File2Beats(checkpoint_path="final0", device=device) beats, downbeats = processor(audio_path) beat_numbers = _assign_beat_numbers(beats, downbeats) return BeatData( beats=np.asarray(beats), downbeats=np.asarray(downbeats), beat_numbers=beat_numbers, ) def _assign_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray: beats = np.asarray(beats) downbeats_set = set(np.round(downbeats, 6)) numbers = np.zeros(len(beats), dtype=int) beat_num = 1 for i, t in enumerate(beats): if round(float(t), 6) in downbeats_set: beat_num = 1 numbers[i] = beat_num beat_num += 1 return numbers