"""Beat/kick detection using madmom's RNN beat tracker."""
import json
import subprocess
import tempfile
from pathlib import Path
from typing import Optional
import numpy as np
from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor
# Bandpass filter to isolate the kick drum range.
# NOTE(review): docstrings elsewhere say 50-200 Hz, but LOWPASS_CUTOFF is 500 —
# confirm the intended upper cutoff before changing either.
HIGHPASS_CUTOFF = 50
LOWPASS_CUTOFF = 500
def _bandpass_filter(input_path: Path) -> Path:
    """Apply a bandpass filter (HIGHPASS_CUTOFF-LOWPASS_CUTOFF Hz) via ffmpeg
    to isolate kick drum transients.

    Args:
        input_path: Path to the source WAV file.

    Returns:
        Path to a temporary filtered WAV file. The caller is responsible
        for deleting it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # close so ffmpeg can write to it on all platforms
    out_path = Path(tmp.name)
    try:
        subprocess.run(
            [
                "ffmpeg", "-y",
                "-i", str(input_path),
                "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}",
                str(out_path),
            ],
            check=True,
            capture_output=True,
        )
    except BaseException:
        # Fix: don't leak the temp file when ffmpeg fails (or on KeyboardInterrupt).
        out_path.unlink(missing_ok=True)
        raise
    return out_path
def detect_beats(
    drum_stem_path: str | Path,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
    transition_lambda: float = 100,
    fps: int = 1000,
) -> np.ndarray:
    """Detect beat timestamps from a drum stem using madmom.

    Uses an ensemble of bidirectional LSTMs to produce a beat activation
    function, then a Dynamic Bayesian Network to decode beat positions.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        min_bpm: Minimum expected tempo. Narrow this if you know the song's
            approximate BPM for better accuracy.
        max_bpm: Maximum expected tempo.
        transition_lambda: Tempo smoothness — higher values penalise tempo
            changes more (100 = very steady, good for most pop/rock).
        fps: Frames per second for the DBN decoder. The RNN outputs at 100fps;
            higher values interpolate for finer timestamp resolution
            (1ms at 1000fps).

    Returns:
        1D numpy array of beat timestamps in seconds, sorted chronologically.
    """
    drum_stem_path = Path(drum_stem_path)

    # Step 0: Bandpass filter to isolate the kick drum range
    # (HIGHPASS_CUTOFF-LOWPASS_CUTOFF Hz, see module constants).
    filtered_path = _bandpass_filter(drum_stem_path)

    # Step 1: RNN produces beat activation function (probability per frame
    # at 100fps). Fix: try/finally guarantees the temp file is removed even
    # if madmom raises mid-processing.
    try:
        act_proc = RNNBeatProcessor()
        activations = act_proc(str(filtered_path))
    finally:
        filtered_path.unlink(missing_ok=True)

    # Step 2: Interpolate to a higher fps for finer timestamp resolution
    # (1ms at 1000fps). The RNN always emits 100fps, so only resample when
    # the caller asked for something else.
    if fps != 100:
        from scipy.interpolate import interp1d

        n_frames = len(activations)
        t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False)
        n_new = int(n_frames * fps / 100)
        t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False)
        activations = interp1d(
            t_orig, activations, kind="cubic", fill_value="extrapolate"
        )(t_new)
        # Cubic splines can overshoot below zero; activations are probabilities.
        activations = np.clip(activations, 0, None)

    # Step 3: DBN decodes activations into beat timestamps.
    # correct=False lets the DBN place beats using its own high-res state space
    # instead of snapping to the coarse 100fps activation peaks.
    beat_proc = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        transition_lambda=transition_lambda,
        fps=fps,
        correct=False,
    )
    beats = beat_proc(activations)
    return beats
def detect_drop(
    audio_path: str | Path,
    beat_times: np.ndarray,
    window_sec: float = 0.5,
) -> float:
    """Find the beat where the biggest energy jump occurs (the drop).

    Measures RMS energy in a short window centred on each beat, then picks
    the beat whose energy rose the most relative to the beat before it.

    Args:
        audio_path: Path to the full mix audio file.
        beat_times: Array of beat timestamps in seconds.
        window_sec: Duration of the analysis window around each beat.

    Returns:
        Timestamp (seconds) of the detected drop beat.
    """
    import librosa

    samples, rate = librosa.load(str(audio_path), sr=None, mono=True)
    half_window = int(window_sec / 2 * rate)

    def _rms_around(ts: float):
        # RMS of the samples within +/- half_window of the beat position.
        mid = int(ts * rate)
        lo = max(0, mid - half_window)
        hi = min(len(samples), mid + half_window)
        window = samples[lo:hi]
        return np.sqrt(np.mean(window ** 2)) if len(window) > 0 else 0.0

    energies = np.array([_rms_around(t) for t in beat_times])

    # The drop is the largest positive jump between consecutive beats;
    # +1 compensates for np.diff shortening the array by one.
    jumps = np.diff(energies)
    drop_idx = int(np.argmax(jumps)) + 1
    drop_time = float(beat_times[drop_idx])
    print(f" Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s "
          f"(energy jump: {jumps[drop_idx - 1]:.4f})")
    return drop_time
def select_beats(
    beats: np.ndarray,
    max_duration: float = 15.0,
    min_interval: float = 0.3,
) -> np.ndarray:
    """Select a subset of beats for video generation.

    Keeps only beats that fall within the duration limit, then thins them
    so no two kept beats are closer together than min_interval (avoids
    generating too many frames).

    Args:
        beats: Array of beat timestamps in seconds.
        max_duration: Maximum video duration in seconds.
        min_interval: Minimum time between selected beats in seconds.
            Beats closer together than this are skipped.

    Returns:
        Filtered array of beat timestamps.
    """
    if len(beats) == 0:
        return beats

    # Keep only beats inside the duration budget.
    within = beats[beats <= max_duration]
    if len(within) == 0:
        return within

    # Greedy thinning: always keep the first beat, then skip any beat that
    # lands too soon after the last one we kept.
    kept = [within[0]]
    for candidate in within[1:]:
        if candidate - kept[-1] < min_interval:
            continue
        kept.append(candidate)
    return np.array(kept)
def save_beats(
beats: np.ndarray,
output_path: str | Path,
) -> Path:
"""Save beat timestamps to a JSON file.
Format matches the project convention (same style as lyrics.json):
a list of objects with beat index and timestamp.
Args:
beats: Array of beat timestamps in seconds.
output_path: Path to save the JSON file.
Returns:
Path to the saved JSON file.
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
data = [
{"beat": i + 1, "time": round(float(t), 3)}
for i, t in enumerate(beats)
]
with open(output_path, "w") as f:
json.dump(data, f, indent=2)
return output_path
def run(
    drum_stem_path: str | Path,
    output_dir: Optional[str | Path] = None,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
) -> dict:
    """Full beat detection pipeline: detect, select, and save.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        output_dir: Directory to save beats.json. Defaults to the
            parent of the drum stem's parent (e.g. data/Gone/ if
            stem is at data/Gone/stems/drums.wav).
        min_bpm: Minimum expected tempo.
        max_bpm: Maximum expected tempo.

    Returns:
        Dict with 'all_beats', 'selected_beats', and 'beats_path'.
    """
    stem_path = Path(drum_stem_path)
    if output_dir is None:
        # e.g. stems/drums.wav -> parent is stems/, parent.parent is data/Gone/
        output_dir = stem_path.parent.parent
    out_dir = Path(output_dir)

    all_beats = detect_beats(stem_path, min_bpm=min_bpm, max_bpm=max_bpm)
    selected = select_beats(all_beats)

    # Locate the full mix for drop detection (one level above run_* dirs).
    song_dir = out_dir.parent if out_dir.name.startswith("run_") else out_dir
    audio_path = None
    for ext in (".wav", ".mp3", ".flac", ".m4a"):
        match = next(song_dir.glob(f"*{ext}"), None)
        if match is not None:
            audio_path = match
            break

    drop_time = None
    if audio_path is not None and len(all_beats) > 2:
        drop_time = detect_drop(audio_path, all_beats)

    beats_path = save_beats(all_beats, out_dir / "beats.json")

    # Persist the drop timestamp next to beats.json when one was found.
    if drop_time is not None:
        (out_dir / "drop.json").write_text(
            json.dumps({"drop_time": round(drop_time, 3)}, indent=2)
        )

    return {
        "all_beats": all_beats,
        "selected_beats": selected,
        "beats_path": beats_path,
        "drop_time": drop_time,
    }
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python -m src.beat_detector <drum_stem.wav>")
        sys.exit(1)

    result = run(sys.argv[1])
    detected = result["all_beats"]
    chosen = result["selected_beats"]
    print(f"Detected {len(detected)} beats (saved to {result['beats_path']})")
    print(f"Selected {len(chosen)} beats (max 15s, min 0.3s apart):")
    for idx, ts in enumerate(chosen, start=1):
        print(f" Beat {idx}: {ts:.3f}s")