| """Beat and downbeat tracking via beat_this (CPJKU).""" | |
| from dataclasses import dataclass | |
| import numpy as np | |
@dataclass(eq=False)  # eq=False: element-wise array fields make the generated __eq__ raise
class BeatData:
    """Result of beat/downbeat tracking for one audio file."""

    # Beat times as returned by the tracker (presumably seconds — confirm).
    beats: np.ndarray
    # Downbeat times; expected to coincide with entries of ``beats``.
    downbeats: np.ndarray
    # 1-based position of each beat within its bar, aligned with ``beats``.
    beat_numbers: np.ndarray
def track_beats(audio_path: str, device: str = "cuda") -> BeatData:
    """Detect beats and downbeats in an audio file using beat_this.

    Args:
        audio_path: Path of the audio file to analyse.
        device: Device string forwarded to ``File2Beats`` (default ``"cuda"``).

    Returns:
        A ``BeatData`` holding the beat times, the downbeat times (both as
        numpy arrays) and the per-beat position within the bar.
    """
    # Imported inside the function so that importing this module does not
    # require beat_this to be installed.
    from beat_this.inference import File2Beats

    file2beats = File2Beats(checkpoint_path="final0", device=device)
    beat_times, downbeat_times = file2beats(audio_path)
    return BeatData(
        beats=np.asarray(beat_times),
        downbeats=np.asarray(downbeat_times),
        beat_numbers=_assign_beat_numbers(beat_times, downbeat_times),
    )
| def _assign_beat_numbers(beats: np.ndarray, downbeats: np.ndarray) -> np.ndarray: | |
| beats = np.asarray(beats) | |
| downbeats_set = set(np.round(downbeats, 6)) | |
| numbers = np.zeros(len(beats), dtype=int) | |
| beat_num = 1 | |
| for i, t in enumerate(beats): | |
| if round(float(t), 6) in downbeats_set: | |
| beat_num = 1 | |
| numbers[i] = beat_num | |
| beat_num += 1 | |
| return numbers | |