File size: 8,895 Bytes
72f552e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
"""Beat/kick detection using madmom's RNN beat tracker."""

import json
import subprocess
import tempfile
from pathlib import Path
from typing import Optional

import numpy as np
from madmom.features.beats import DBNBeatTrackingProcessor, RNNBeatProcessor

# Bandpass filter cutoffs in Hz: isolate the low-frequency kick-drum band (50-500 Hz)
HIGHPASS_CUTOFF = 50
LOWPASS_CUTOFF = 500


def _bandpass_filter(input_path: Path) -> Path:
    """Apply a band-pass filter (HIGHPASS_CUTOFF-LOWPASS_CUTOFF Hz, i.e.
    50-500 Hz) to isolate low-frequency kick-drum transients.

    Args:
        input_path: Path to the source audio file.

    Returns:
        Path to a temporary filtered WAV file. The caller is responsible
        for deleting it.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits non-zero.
    """
    filtered = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    filtered.close()
    try:
        subprocess.run([
            "ffmpeg", "-y",
            "-i", str(input_path),
            "-af", f"highpass=f={HIGHPASS_CUTOFF},lowpass=f={LOWPASS_CUTOFF}",
            str(filtered.name),
        ], check=True, capture_output=True)
    except Exception:
        # Don't leave an orphaned temp file behind when ffmpeg fails.
        Path(filtered.name).unlink(missing_ok=True)
        raise
    return Path(filtered.name)


def detect_beats(
    drum_stem_path: str | Path,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
    transition_lambda: float = 100,
    fps: int = 1000,
) -> np.ndarray:
    """Detect beat timestamps from a drum stem using madmom.

    Uses an ensemble of bidirectional LSTMs to produce a beat activation
    function, then a Dynamic Bayesian Network to decode beat positions.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        min_bpm: Minimum expected tempo. Narrow this if you know the song's
            approximate BPM for better accuracy.
        max_bpm: Maximum expected tempo.
        transition_lambda: Tempo smoothness — higher values penalise tempo
            changes more (100 = very steady, good for most pop/rock).
        fps: Frames per second for the DBN decoder. The RNN outputs at 100fps;
            higher values interpolate for finer timestamp resolution (1ms at 1000fps).

    Returns:
        1D numpy array of beat timestamps in seconds, sorted chronologically.
    """
    drum_stem_path = Path(drum_stem_path)

    # Step 0: Bandpass filter to isolate the kick drum band
    # (HIGHPASS_CUTOFF-LOWPASS_CUTOFF Hz).
    filtered_path = _bandpass_filter(drum_stem_path)

    try:
        # Step 1: RNN produces beat activation function
        # (probability per frame at 100fps).
        act_proc = RNNBeatProcessor()
        activations = act_proc(str(filtered_path))
    finally:
        # Always remove the temp file, even if the RNN raises.
        filtered_path.unlink(missing_ok=True)

    # Step 2: Interpolate to higher fps for finer timestamp resolution (1ms at 1000fps)
    if fps != 100:
        from scipy.interpolate import interp1d
        n_frames = len(activations)
        t_orig = np.linspace(0, n_frames / 100, n_frames, endpoint=False)
        n_new = int(n_frames * fps / 100)
        t_new = np.linspace(0, n_frames / 100, n_new, endpoint=False)
        activations = interp1d(t_orig, activations, kind="cubic", fill_value="extrapolate")(t_new)
        activations = np.clip(activations, 0, None)  # cubic spline can go negative

    # Step 3: DBN decodes activations into beat timestamps
    # correct=False lets the DBN place beats using its own high-res state space
    # instead of snapping to the coarse 100fps activation peaks
    beat_proc = DBNBeatTrackingProcessor(
        min_bpm=min_bpm,
        max_bpm=max_bpm,
        transition_lambda=transition_lambda,
        fps=fps,
        correct=False,
    )
    beats = beat_proc(activations)

    return beats


def detect_drop(
    audio_path: str | Path,
    beat_times: np.ndarray,
    window_sec: float = 0.5,
) -> float:
    """Find the beat where the biggest energy jump occurs (the drop).

    Computes RMS energy in a window around each beat and returns the beat
    with the largest increase compared to the previous beat.

    Args:
        audio_path: Path to the full mix audio file.
        beat_times: Array of beat timestamps in seconds.
        window_sec: Duration of the analysis window around each beat.

    Returns:
        Timestamp (seconds) of the detected drop beat.

    Raises:
        ValueError: If fewer than two beats are given (no consecutive
            pair exists to compare energies between).
    """
    import librosa

    # Explicit guard: np.argmax on the empty diff array below would raise
    # a cryptic ValueError otherwise.
    if len(beat_times) < 2:
        raise ValueError("detect_drop requires at least two beat timestamps")

    y, sr = librosa.load(str(audio_path), sr=None, mono=True)
    half_win = int(window_sec / 2 * sr)

    rms_values = []
    for t in beat_times:
        center = int(t * sr)
        start = max(0, center - half_win)
        end = min(len(y), center + half_win)
        segment = y[start:end]
        # Empty window (beat beyond audio end) contributes zero energy.
        rms = np.sqrt(np.mean(segment ** 2)) if len(segment) > 0 else 0.0
        rms_values.append(rms)

    rms_values = np.array(rms_values)

    # diffs[i] = rms[i+1] - rms[i]; the beat AFTER the largest jump is the
    # drop, hence the +1 index shift.
    diffs = np.diff(rms_values)
    drop_idx = int(np.argmax(diffs)) + 1
    drop_time = float(beat_times[drop_idx])

    print(f"  Drop detected at beat {drop_idx + 1}: {drop_time:.3f}s "
          f"(energy jump: {diffs[drop_idx - 1]:.4f})")
    return drop_time


def select_beats(
    beats: np.ndarray,
    max_duration: float = 15.0,
    min_interval: float = 0.3,
) -> np.ndarray:
    """Pick a usable subset of beats for video generation.

    Keeps only beats within ``max_duration`` seconds of the start, then
    greedily drops any beat closer than ``min_interval`` to the last kept
    one, so the renderer is not asked to produce frames too close together.

    Args:
        beats: Beat timestamps in seconds.
        max_duration: Hard cap on video duration in seconds.
        min_interval: Smallest allowed gap between kept beats, in seconds.
            Beats closer together than this are skipped.

    Returns:
        Filtered array of beat timestamps.
    """
    if len(beats) == 0:
        return beats

    # Discard everything past the duration cap.
    within_limit = beats[beats <= max_duration]
    if len(within_limit) == 0:
        return within_limit

    # Greedy pass: always keep the first beat, then only beats that are
    # far enough from the most recently kept one.
    kept = [within_limit[0]]
    for timestamp in within_limit[1:]:
        gap = timestamp - kept[-1]
        if gap >= min_interval:
            kept.append(timestamp)

    return np.array(kept)


def save_beats(
    beats: np.ndarray,
    output_path: str | Path,
) -> Path:
    """Save beat timestamps to a JSON file.

    Format matches the project convention (same style as lyrics.json):
    a list of objects with beat index and timestamp.

    Args:
        beats: Array of beat timestamps in seconds.
        output_path: Path to save the JSON file.

    Returns:
        Path to the saved JSON file.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    data = [
        {"beat": i + 1, "time": round(float(t), 3)}
        for i, t in enumerate(beats)
    ]

    with open(output_path, "w") as f:
        json.dump(data, f, indent=2)

    return output_path


def run(
    drum_stem_path: str | Path,
    output_dir: Optional[str | Path] = None,
    min_bpm: float = 55.0,
    max_bpm: float = 215.0,
) -> dict:
    """Full beat detection pipeline: detect, select, and save.

    Args:
        drum_stem_path: Path to the isolated drum stem WAV file.
        output_dir: Directory where beats.json (and drop.json, when a drop
            is found) are written. Defaults to two levels above the stem
            (e.g. data/Gone/ for a stem at data/Gone/stems/drums.wav).
        min_bpm: Minimum expected tempo.
        max_bpm: Maximum expected tempo.

    Returns:
        Dict with 'all_beats', 'selected_beats', 'beats_path', and
        'drop_time' (None when no full mix was found or too few beats).
    """
    stem = Path(drum_stem_path)

    # Default output dir: stems/drums.wav -> stems/ -> song directory.
    out_dir = Path(output_dir) if output_dir is not None else stem.parent.parent

    all_beats = detect_beats(stem, min_bpm=min_bpm, max_bpm=max_bpm)
    selected = select_beats(all_beats)

    # Locate the full mix for drop detection; run_* directories live one
    # level inside the song directory.
    if out_dir.name.startswith("run_"):
        song_dir = out_dir.parent
    else:
        song_dir = out_dir

    audio_path = None
    for ext in (".wav", ".mp3", ".flac", ".m4a"):
        matches = list(song_dir.glob(f"*{ext}"))
        if matches:
            audio_path = matches[0]
            break

    # Drop detection needs a full mix and enough beats to compare.
    drop_time = None
    if audio_path and len(all_beats) > 2:
        drop_time = detect_drop(audio_path, all_beats)

    beats_path = save_beats(all_beats, out_dir / "beats.json")

    # Persist the drop timestamp alongside the beats, if one was found.
    if drop_time is not None:
        with open(out_dir / "drop.json", "w") as fh:
            json.dump({"drop_time": round(drop_time, 3)}, fh, indent=2)

    return {
        "all_beats": all_beats,
        "selected_beats": selected,
        "beats_path": beats_path,
        "drop_time": drop_time,
    }


if __name__ == "__main__":
    import sys

    # Require exactly one positional argument: the drum stem path.
    if len(sys.argv) < 2:
        print("Usage: python -m src.beat_detector <drum_stem.wav>")
        sys.exit(1)

    result = run(sys.argv[1])
    detected = result["all_beats"]
    chosen = result["selected_beats"]

    print(f"Detected {len(detected)} beats (saved to {result['beats_path']})")
    print(f"Selected {len(chosen)} beats (max 15s, min 0.3s apart):")
    for idx, ts in enumerate(chosen, start=1):
        print(f"  Beat {idx}: {ts:.3f}s")