# SyncAI — src/beat_detector.py
"""Beat/kick detection using madmom's RNN beat tracker."""
import json
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
import numpy as np
from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor
# Band-pass filter cutoffs (Hz) applied before beat tracking to emphasize
# kick-drum transients.
# NOTE(review): comments elsewhere in this file mention a 50-200 Hz range,
# but with LOWPASS_CUTOFF = 500 the effective passband is 50-500 Hz —
# confirm which range is intended.
HIGHPASS_CUTOFF = 50
LOWPASS_CUTOFF = 500
def _bandpass_filter(input_path: Path) -> Path:
    """Band-pass filter the audio (HIGHPASS_CUTOFF-LOWPASS_CUTOFF Hz) via ffmpeg
    to emphasize kick-drum transients.

    Args:
        input_path: Path to the source audio file.

    Returns:
        Path to a temporary filtered WAV file. The caller is responsible
        for deleting it when done.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    try:
        subprocess.run([
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}",
            str(tmp.name),
        ], check=True, capture_output=True)
    except Exception:
        # Fix: don't leak the temp file if ffmpeg fails.
        Path(tmp.name).unlink(missing_ok=True)
        raise
    return Path(tmp.name)
def detect_beats(
    drum_stem_path: str | Path,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
    transition_lambda: float = 100,
    fps: int = 1000,
) -> np.ndarray:
    """Detect beat timestamps from a drum stem using madmom.

    Uses an ensemble of bidirectional LSTMs to produce a beat activation
    function, then a Dynamic Bayesian Network to decode beat positions.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        min_bpm: Minimum expected tempo. Narrow this if you know the song's
            approximate BPM for better accuracy.
        max_bpm: Maximum expected tempo.
        transition_lambda: Tempo smoothness — higher values penalise tempo
            changes more (100 = very steady, good for most pop/rock).
        fps: Frames per second for the DBN decoder. The RNN outputs at
            100 fps; higher values interpolate for finer timestamp
            resolution (1 ms at 1000 fps).

    Returns:
        1D numpy array of beat timestamps in seconds, sorted chronologically.
    """
    drum_stem_path = Path(drum_stem_path)

    # Step 0: band-pass filter (HIGHPASS_CUTOFF-LOWPASS_CUTOFF Hz) to
    # emphasize the kick drum range.
    filtered_path = _bandpass_filter(drum_stem_path)
    try:
        # Step 1: RNN produces a beat activation function
        # (beat probability per frame at 100 fps).
        activations = RNNBeatProcessor()(str(filtered_path))
    finally:
        # Fix: always remove the temp file, even if the RNN raises.
        filtered_path.unlink(missing_ok=True)

    # Step 2: interpolate to a higher fps for finer timestamp resolution
    # (1 ms at 1000 fps).
    if fps != 100:
        from scipy.interpolate import interp1d

        n_frames = len(activations)
        t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False)
        n_new = int(n_frames * fps / 100)
        t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False)
        activations = interp1d(
            t_orig, activations, kind="cubic", fill_value="extrapolate"
        )(t_new)
        activations = np.clip(activations, 0, None)  # cubic spline can go negative

    # Step 3: DBN decodes activations into beat timestamps.
    # correct=False lets the DBN place beats using its own high-res state
    # space instead of snapping to the coarse 100 fps activation peaks.
    beat_proc = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        transition_lambda=transition_lambda,
        fps=fps,
        correct=False,
    )
    return beat_proc(activations)
def detect_drop(
audio_path: str | Path,
beat_times: np.ndarray,
window_sec: float = 0.5,
) -> float:
"""Find the beat where the biggest energy jump occurs (the drop).
Computes RMS energy in a window around each beat and returns the beat
with the largest increase compared to the previous beat.
Args:
audio_path: Path to the full mix audio file.
beat_times: Array of beat timestamps in seconds.
window_sec: Duration of the analysis window around each beat.
Returns:
Timestamp (seconds) of the detected drop beat.
"""
import librosa
y, sr = librosa.load(str(audio_path), sr=None, mono=True)
half_win = int(window_sec / 2 * sr)
rms_values = []
for t in beat_times:
center = int(t * sr)
start = max(0, center - half_win)
end = min(len(y), center + half_win)
segment = y[start:end]
rms = np.sqrt(np.mean(segment ** 2)) if len(segment) > 0 else 0.0
rms_values.append(rms)
rms_values = np.array(rms_values)
# Find largest positive jump between consecutive beats
diffs = np.diff(rms_values)
drop_idx = int(np.argmax(diffs)) + 1 # +1 because diff shifts by one
drop_time = float(beat_times[drop_idx])
print(f" Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s "
f"(energy jump: {diffs[drop_idx - 1]:.4f})")
return drop_time
def select_beats(
    beats: np.ndarray,
    max_duration: float = 15.0,
    min_interval: float = 0.3,
) -> np.ndarray:
    """Select a subset of beats for video generation.

    Keeps only beats within ``max_duration`` seconds, then greedily drops
    any beat closer than ``min_interval`` to the previously kept one
    (to avoid generating too many frames).

    Args:
        beats: Array of beat timestamps in seconds.
        max_duration: Maximum video duration in seconds.
        min_interval: Minimum time between selected beats in seconds.
            Beats closer together than this are skipped.

    Returns:
        Filtered array of beat timestamps.
    """
    if beats.size == 0:
        return beats

    # Keep only beats inside the duration budget.
    within = beats[beats <= max_duration]
    if within.size == 0:
        return within

    # Greedy pass: accept a beat only if it is far enough from the last
    # accepted one.
    kept = [within[0]]
    for ts in within[1:]:
        if ts - kept[-1] >= min_interval:
            kept.append(ts)
    return np.array(kept)
def save_beats(
beats: np.ndarray,
output_path: str | Path,
) -> Path:
"""Save beat timestamps to a JSON file.
Format matches the project convention (same style as lyrics.json):
a list of objects with beat index and timestamp.
Args:
beats: Array of beat timestamps in seconds.
output_path: Path to save the JSON file.
Returns:
Path to the saved JSON file.
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
data = [
{"beat": i + 1, "time": round(float(t), 3)}
for i, t in enumerate(beats)
]
with open(output_path, "w") as f:
json.dump(data, f, indent=2)
return output_path
def run(
    drum_stem_path: str | Path,
    output_dir: Optional[str | Path] = None,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
) -> dict:
    """Full beat detection pipeline: detect, select, and save.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        output_dir: Directory to save beats.json. Defaults to the
            parent of the drum stem's parent (e.g. data/Gone/ if
            stem is at data/Gone/stems/drums.wav).
        min_bpm: Minimum expected tempo.
        max_bpm: Maximum expected tempo.

    Returns:
        Dict with 'all_beats', 'selected_beats', 'beats_path', and
        'drop_time'.
    """
    drum_stem_path = Path(drum_stem_path)
    # Default output dir: two levels up from the stem
    # (stems/drums.wav -> parent is stems/, parent.parent is data/Gone/).
    output_dir = (
        drum_stem_path.parent.parent if output_dir is None else Path(output_dir)
    )

    all_beats = detect_beats(drum_stem_path, min_bpm=min_bpm, max_bpm=max_bpm)
    selected = select_beats(all_beats)

    # Locate the full mix for drop detection (one level above run_*/ dirs).
    song_dir = output_dir.parent if output_dir.name.startswith("run_") else output_dir
    audio_path = next(
        (
            hit
            for ext in (".wav", ".mp3", ".flac", ".m4a")
            for hit in song_dir.glob(f"*{ext}")
        ),
        None,
    )

    drop_time = None
    if audio_path is not None and len(all_beats) > 2:
        drop_time = detect_drop(audio_path, all_beats)

    beats_path = save_beats(all_beats, output_dir / "beats.json")

    # Persist the drop time alongside beats.json when one was found.
    if drop_time is not None:
        with open(output_dir / "drop.json", "w") as fh:
            json.dump({"drop_time": round(drop_time, 3)}, fh, indent=2)

    return {
        "all_beats": all_beats,
        "selected_beats": selected,
        "beats_path": beats_path,
        "drop_time": drop_time,
    }
if __name__ == "__main__":
    import sys

    # Require the drum stem path as the single CLI argument.
    if len(sys.argv) < 2:
        print("Usage: python -m src.beat_detector <drum_stem.wav>")
        sys.exit(1)

    result = run(sys.argv[1])
    all_beats, selected = result["all_beats"], result["selected_beats"]
    print(f"Detected {len(all_beats)} beats (saved to {result['beats_path']})")
    print(f"Selected {len(selected)} beats (max 15s, min 0.3s apart):")
    for idx, ts in enumerate(selected, start=1):
        print(f" Beat {idx}: {ts:.3f}s")