"""Beat/kick detection using madmom's RNN beat tracker.""" import json import subprocess import tempfile from pathlib import Path from typing import Optional import numpy as np from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor # Bandpass filter: isolate kick drum frequency range (50-200 Hz) HIGHPASS_CUTOFF = 50 LOWPASS_CUTOFF = 500 def _bandpass_filter(input_path: Path) -> Path: """Apply a 50-200 Hz bandpass filter to isolate kick drum transients. Returns path to a temporary filtered WAV file. """ filtered = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) filtered.close() subprocess.run([ "ffmpeg", "-y", "-i", str(input_path), "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}", str(filtered.name), ], check=True, capture_output=True) return Path(filtered.name) def detect_beats( drum_stem_path: str | Path, min_bpm: float = 55.0, max_bpm: float = 215.0, transition_lambda: float = 100, fps: int = 1000, ) -> np.ndarray: """Detect beat timestamps from a drum stem using madmom. Uses an ensemble of bidirectional LSTMs to produce a beat activation function, then a Dynamic Bayesian Network to decode beat positions. Args: drum_stem_path: Path to the isolated drum stem WAV file. min_bpm: Minimum expected tempo. Narrow this if you know the song's approximate BPM for better accuracy. max_bpm: Maximum expected tempo. transition_lambda: Tempo smoothness — higher values penalise tempo changes more (100 = very steady, good for most pop/rock). fps: Frames per second for the DBN decoder. The RNN outputs at 100fps; higher values interpolate for finer timestamp resolution (1ms at 1000fps). Returns: 1D numpy array of beat timestamps in seconds, sorted chronologically. """ drum_stem_path = Path(drum_stem_path) # Step 0: Bandpass filter to isolate kick drum range (50-200 Hz) filtered_path = _bandpass_filter(drum_stem_path) # Step 1: RNN produces beat activation function (probability per frame at 100fps) act_proc = RNNBeatProcessor() activations = act_proc(str(filtered_path)) # Clean up temp file filtered_path.unlink(missing_ok=True) # Step 2: Interpolate to higher fps for finer timestamp resolution (1ms at 1000fps) if fps != 100: from scipy.interpolate import interp1d n_frames = len(activations) t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False) n_new = int(n_frames * fps / 100) t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False) activations = interp1d(t_orig, activations, kind="cubic", fill_value="extrapolate")(t_new) activations = np.clip(activations, 0, None) # cubic spline can go negative # Step 3: DBN decodes activations into beat timestamps # correct=False lets the DBN place beats using its own high-res state space # instead of snapping to the coarse 100fps activation peaks beat_proc = DBNBeatTrackingProcessor( min_bpm=min_bpm, max_bpm=max_bpm, transition_lambda=transition_lambda, fps=fps, correct=False, ) beats = beat_proc(activations) return beats def detect_drop( audio_path: str | Path, beat_times: np.ndarray, window_sec: float = 0.5, ) -> float: """Find the beat where the biggest energy jump occurs (the drop). Computes RMS energy in a window around each beat and returns the beat with the largest increase compared to the previous beat. Args: audio_path: Path to the full mix audio file. beat_times: Array of beat timestamps in seconds. window_sec: Duration of the analysis window around each beat. Returns: Timestamp (seconds) of the detected drop beat. """ import librosa y, sr = librosa.load(str(audio_path), sr=None, mono=True) half_win = int(window_sec / 2 * sr) rms_values = [] for t in beat_times: center = int(t * sr) start = max(0, center - half_win) end = min(len(y), center + half_win) segment = y[start:end] rms = np.sqrt(np.mean(segment ** 2)) if len(segment) > 0 else 0.0 rms_values.append(rms) rms_values = np.array(rms_values) # Find largest positive jump between consecutive beats diffs = np.diff(rms_values) drop_idx = int(np.argmax(diffs)) + 1 # +1 because diff shifts by one drop_time = float(beat_times[drop_idx]) print(f" Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s " f"(energy jump: {diffs[drop_idx - 1]:.4f})") return drop_time def select_beats( beats: np.ndarray, max_duration: float = 15.0, min_interval: float = 0.3, ) -> np.ndarray: """Select a subset of beats for video generation. Filters beats to fit within a duration limit and enforces a minimum interval between consecutive beats (to avoid generating too many frames). Args: beats: Array of beat timestamps in seconds. max_duration: Maximum video duration in seconds. min_interval: Minimum time between selected beats in seconds. Beats closer together than this are skipped. Returns: Filtered array of beat timestamps. """ if len(beats) == 0: return beats # Trim to max duration beats = beats[beats <= max_duration] if len(beats) == 0: return beats # Enforce minimum interval between beats selected = [beats[0]] for beat in beats[1:]: if beat - selected[-1] >= min_interval: selected.append(beat) return np.array(selected) def save_beats( beats: np.ndarray, output_path: str | Path, ) -> Path: """Save beat timestamps to a JSON file. Format matches the project convention (same style as lyrics.json): a list of objects with beat index and timestamp. Args: beats: Array of beat timestamps in seconds. output_path: Path to save the JSON file. Returns: Path to the saved JSON file. """ output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) data = [ {"beat": i + 1, "time": round(float(t), 3)} for i, t in enumerate(beats) ] with open(output_path, "w") as f: json.dump(data, f, indent=2) return output_path def run( drum_stem_path: str | Path, output_dir: Optional[str | Path] = None, min_bpm: float = 55.0, max_bpm: float = 215.0, ) -> dict: """Full beat detection pipeline: detect, select, and save. Args: drum_stem_path: Path to the isolated drum stem WAV file. output_dir: Directory to save beats.json. Defaults to the parent of the drum stem's parent (e.g. data/Gone/ if stem is at data/Gone/stems/drums.wav). min_bpm: Minimum expected tempo. max_bpm: Maximum expected tempo. Returns: Dict with 'all_beats', 'selected_beats', and 'beats_path'. """ drum_stem_path = Path(drum_stem_path) if output_dir is None: # stems/drums.wav -> parent is stems/, parent.parent is data/Gone/ output_dir = drum_stem_path.parent.parent output_dir = Path(output_dir) all_beats = detect_beats(drum_stem_path, min_bpm=min_bpm, max_bpm=max_bpm) selected = select_beats(all_beats) # Detect drop using the full mix audio (one level above stems/) song_dir = output_dir.parent if output_dir.name.startswith("run_") else output_dir audio_path = None for ext in [".wav", ".mp3", ".flac", ".m4a"]: candidates = list(song_dir.glob(f"*{ext}")) if candidates: audio_path = candidates[0] break drop_time = None if audio_path and len(all_beats) > 2: drop_time = detect_drop(audio_path, all_beats) beats_path = save_beats(all_beats, output_dir / "beats.json") # Save drop time alongside beats if drop_time is not None: drop_path = output_dir / "drop.json" with open(drop_path, "w") as f: json.dump({"drop_time": round(drop_time, 3)}, f, indent=2) return { "all_beats": all_beats, "selected_beats": selected, "beats_path": beats_path, "drop_time": drop_time, } if __name__ == "__main__": import sys if len(sys.argv) < 2: print("Usage: python -m src.beat_detector ") sys.exit(1) result = run(sys.argv[1]) all_beats = result["all_beats"] selected = result["selected_beats"] print(f"Detected {len(all_beats)} beats (saved to {result['beats_path']})") print(f"Selected {len(selected)} beats (max 15s, min 0.3s apart):") for i, t in enumerate(selected): print(f" Beat {i + 1}: {t:.3f}s")