JacobLinCool committed (verified)
Commit 975d5cb · Parent: 28666cc

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,13 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ outputs/*
+ !outputs/baseline/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
.ruff_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
+ # Automatically created by ruff.
+ *
.ruff_cache/0.12.5/14293067367466839361 ADDED
Binary file (495 Bytes).
.ruff_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1 @@
+ Signature: 8a477f597d28d172789f06886806bc55
README.md ADDED
@@ -0,0 +1,268 @@
+ # Beat Tracking Challenge
+
+ A challenge for detecting beats and downbeats in music audio, with a focus on handling the dynamic tempo changes common in rhythm game charts.
+
+ ## Goal
+
+ The goal is to **detect and identify beats and downbeats** in audio, giving composers a flexible timing grid when working with samples that have dynamic BPM changes.
+
+ - **Beat**: A regular pulse in music (e.g., quarter notes in 4/4 time)
+ - **Downbeat**: The first beat of each measure (the "1" when counting "1-2-3-4")
+
+ This is particularly useful for:
+ - Music production with samples of varying tempos
+ - Rhythm game chart creation and verification
+ - Audio analysis and music information retrieval (MIR)
+
+ ---
+
+ ## Dataset
+
+ The dataset is derived from Taiko no Tatsujin rhythm game charts, providing high-quality human-annotated beat and downbeat ground truth.
+
+ **Source**: [`JacobLinCool/taiko-1000-parsed`](https://huggingface.co/datasets/JacobLinCool/taiko-1000-parsed)
+
+ | Split | Tracks | Duration | Description |
+ |-------|--------|----------|-------------|
+ | `train` | ~900 | 1-3 min each | Training data with beat/downbeat annotations |
+ | `test` | ~100 | 1-3 min each | Held-out test set for final evaluation |
+
+ ### Data Features
+
+ Each example contains:
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `audio` | `Audio` | Audio waveform at a 16kHz sample rate |
+ | `title` | `str` | Track title |
+ | `beats` | `list[float]` | Beat timestamps in seconds |
+ | `downbeats` | `list[float]` | Downbeat timestamps in seconds |
+
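Beat timestamps in seconds map directly onto spectrogram frames at the baseline's 16kHz sample rate and 160-sample hop (100 frames per second). A minimal sketch of that conversion, with constants taken from this README (the helper names are illustrative, not the repository's API):

```python
SR = 16000  # dataset sample rate (Hz)
HOP = 160   # baseline hop length in samples -> 100 frames per second

def time_to_frame(t: float) -> int:
    """Convert a beat timestamp in seconds to a spectrogram frame index."""
    return int(t * SR / HOP)

def frame_to_time(f: int) -> float:
    """Convert a frame index back to a timestamp in seconds."""
    return f * HOP / SR

beats = [0.5, 1.0, 1.5]
print([time_to_frame(t) for t in beats])  # [50, 100, 150]
```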
+ ### Dataset Characteristics
+
+ - **Dynamic BPM**: Many tracks feature tempo changes mid-song
+ - **Variable Time Signatures**: Common patterns include 4/4, 3/4, 6/8, and more exotic meters
+ - **Diverse Genres**: Japanese pop, anime themes, classical arrangements, electronic music
+ - **High-Quality Annotations**: Derived from professional rhythm game charts
+
+ ---
+
+ ## Evaluation Metrics
+
+ The evaluation considers both **timing accuracy** and **metrical correctness**. Models are evaluated on both beat and downbeat detection.
+
+ ### Primary Metrics
+
+ #### 1. Weighted F1-Score (Main Ranking Metric)
+
+ F1-scores are calculated at multiple timing thresholds (3ms to 30ms), then combined with inverse-threshold weighting:
+
+ | Threshold | Weight | Rationale |
+ |-----------|--------|-----------|
+ | 3ms | 1.000 | Full weight for the tightest threshold |
+ | 6ms | 0.500 | Half weight |
+ | 9ms | 0.333 | One-third weight |
+ | 12ms | 0.250 | ... |
+ | 15ms | 0.200 | |
+ | 18ms | 0.167 | |
+ | 21ms | 0.143 | |
+ | 24ms | 0.125 | |
+ | 27ms | 0.111 | |
+ | 30ms | 0.100 | Minimum weight for the coarsest threshold |
+
+ **Formula:**
+ ```
+ Weighted F1 = Σ(w_t × F1_t) / Σ(w_t)
+ where w_t = 3ms / t (inverse threshold weighting)
+ ```
+
+ This weighting scheme rewards models that achieve high precision at tight tolerances while still crediting coarser thresholds.
+
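The combination step amounts to a weighted average; a minimal sketch assuming per-threshold F1 scores have already been computed (the function name is illustrative, not the repository's API):

```python
def weighted_f1(f1_by_threshold: dict[int, float]) -> float:
    """Combine F1 scores keyed by threshold (ms) using weights w_t = 3ms / t."""
    weights = {t: 3.0 / t for t in f1_by_threshold}
    num = sum(weights[t] * f1 for t, f1 in f1_by_threshold.items())
    return num / sum(weights.values())

# Perfect F1 at every threshold from 3ms to 30ms collapses to 1.0
print(weighted_f1({t: 1.0 for t in range(3, 31, 3)}))  # 1.0
```

Because `w_t` decays as `3/t`, precision gained at 3ms moves the score roughly ten times more than the same gain at 30ms.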
+ #### 2. Continuity Metrics (CMLt, AMLt)
+
+ Based on the MIREX beat tracking evaluation protocol, computed with `mir_eval`:
+
+ | Metric | Full Name | Description |
+ |--------|-----------|-------------|
+ | **CMLt** | Correct Metrical Level, Total | Percentage of beats correctly tracked at the exact metrical level (within ±17.5% of the beat interval) |
+ | **AMLt** | Any Metrical Level, Total | Same as CMLt, but allows acceptable metrical variations (double/half tempo, off-beat) |
+ | **CMLc** | Correct Metrical Level, Continuous | Longest continuous correctly-tracked segment at the exact metrical level |
+ | **AMLc** | Any Metrical Level, Continuous | Longest continuous segment at any acceptable metrical level |
+
+ **Note:** Continuity metrics use a default `min_beat_time=5.0s` (the first 5 seconds are skipped) to avoid evaluating the potentially unstable tempo at the start of a track.
+
+ ### Metric Interpretation
+
+ | Metric | What it measures | Good Score |
+ |--------|------------------|------------|
+ | Weighted F1 | Precise timing accuracy | > 0.7 |
+ | CMLt | Correct tempo tracking | > 0.8 |
+ | AMLt | Tempo tracking (flexible) | > 0.9 |
+ | CMLc | Longest stable segment | > 0.5 |
+
+ ### Evaluation Summary
+
+ For each model, we report:
+
+ ```
+ Beat Detection:
+   Weighted F1: X.XXXX
+   CMLt: X.XXXX  AMLt: X.XXXX
+   CMLc: X.XXXX  AMLc: X.XXXX
+
+ Downbeat Detection:
+   Weighted F1: X.XXXX
+   CMLt: X.XXXX  AMLt: X.XXXX
+   CMLc: X.XXXX  AMLc: X.XXXX
+
+ Combined Weighted F1: X.XXXX (average of beat and downbeat)
+ ```
+
+ ---
+
+ ## Quick Start
+
+ ### Setup
+
+ ```bash
+ uv sync
+ ```
+
+ ### Train Baseline Model
+
+ ```bash
+ # Train both beat and downbeat models
+ uv run -m exp.baseline.train
+
+ # Train a specific model only
+ uv run -m exp.baseline.train --target beats
+ uv run -m exp.baseline.train --target downbeats
+ ```
+
+ ### Run Evaluation
+
+ ```bash
+ # Basic evaluation
+ uv run -m exp.baseline.eval
+
+ # Full evaluation with visualization and audio
+ uv run -m exp.baseline.eval --visualize --synthesize --summary-plot
+
+ # Evaluate more samples with a custom output directory
+ uv run -m exp.baseline.eval --num-samples 50 --output-dir outputs/my_eval
+ ```
+
+ ### Evaluation Options
+
+ | Option | Description |
+ |--------|-------------|
+ | `--model-dir DIR` | Model directory (default: `outputs/baseline`) |
+ | `--num-samples N` | Number of samples to evaluate (default: 20) |
+ | `--output-dir DIR` | Output directory (default: `outputs/eval`) |
+ | `--visualize` | Generate visualization plots for each track |
+ | `--synthesize` | Generate audio files with click tracks |
+ | `--viz-tracks N` | Number of tracks to visualize/synthesize (default: 5) |
+ | `--time-range START END` | Limit visualization time range (seconds) |
+ | `--click-volume FLOAT` | Click sound volume (0.0 to 1.0, default: 0.5) |
+ | `--summary-plot` | Generate summary evaluation bar charts |
+
+ ---
+
+ ## Visualization & Audio Tools
+
+ ### Beat Visualization
+
+ Generate plots comparing predicted vs. ground truth beats:
+
+ ```bash
+ uv run -m exp.baseline.eval --visualize --viz-tracks 10
+ ```
+
+ Output: `outputs/eval/plots/track_XXX.png`
+
+ ### Click Track Audio
+
+ Generate audio files with click sounds overlaid on the original music:
+
+ ```bash
+ uv run -m exp.baseline.eval --synthesize
+ ```
+
+ Output files in `outputs/eval/audio/`:
+ - `track_XXX_pred.wav` - Original audio + predicted beat clicks (1000Hz beat, 1500Hz downbeat)
+ - `track_XXX_gt.wav` - Original audio + ground truth clicks (800Hz beat, 1200Hz downbeat)
+ - `track_XXX_both.wav` - Original audio + both prediction and ground truth clicks
+
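Each click is a short tone burst mixed into the original audio at a beat timestamp. A rough sketch of how such a click can be synthesized and overlaid (`make_click` and `overlay_clicks` are illustrative names, not the API of the repository's `exp/data/audio.py`):

```python
import numpy as np

def make_click(freq: float, sr: int = 16000, dur: float = 0.03) -> np.ndarray:
    """A short exponentially-decaying sine burst used as a metronome click."""
    t = np.arange(int(sr * dur)) / sr
    return np.sin(2 * np.pi * freq * t) * np.exp(-t * 80.0)

def overlay_clicks(audio, times, freq, sr=16000, volume=0.5):
    """Mix a click at each timestamp into a copy of the audio."""
    out = audio.astype(np.float32).copy()
    click = (make_click(freq, sr) * volume).astype(np.float32)
    for t in times:
        start = int(t * sr)
        end = min(start + len(click), len(out))
        if start < len(out):
            out[start:end] += click[: end - start]
    return out

audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence
mixed = overlay_clicks(audio, [0.0, 0.5], freq=1000.0)  # clicks at 0.0 s and 0.5 s
```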
+ ### Summary Plot
+
+ Generate bar charts summarizing F1 scores and continuity metrics:
+
+ ```bash
+ uv run -m exp.baseline.eval --summary-plot
+ ```
+
+ Output: `outputs/eval/evaluation_summary.png`
+
+ ---
+
+ ## Baseline Model
+
+ The provided baseline implements the **Onset Detection CNN (ODCNN)** architecture:
+
+ ### Architecture
+
+ - **Input**: Multi-view mel spectrogram (3 window sizes: 23ms, 46ms, 93ms)
+ - **CNN Backbone**: 2 convolutional blocks with max pooling
+ - **Output**: Frame-level beat/downbeat probability
+
+ ### Training Details
+
+ - **Optimizer**: SGD with momentum (0.9)
+ - **Learning Rate**: 0.05 with cosine annealing
+ - **Loss**: Binary Cross-Entropy
+ - **Epochs**: 50
+ - **Batch Size**: 512
+
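The optimizer setup above corresponds to a few lines of PyTorch. A sketch with a stand-in module (the real loop lives in `exp/baseline/train.py`):

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 1)  # stand-in for the ODCNN

optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
criterion = nn.BCELoss()

for epoch in range(50):
    # ...forward pass, loss = criterion(output, target), backward(), optimizer.step()...
    scheduler.step()

# After T_max epochs the learning rate has annealed from 0.05 to (near) zero
print(scheduler.get_last_lr()[0])
```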
+ ### Inference Pipeline
+
+ 1. Compute the multi-view mel spectrogram on GPU
+ 2. Sliding-window inference (±7 frames of context = ±70ms)
+ 3. Hamming-window smoothing
+ 4. Peak picking with a threshold (0.5) and minimum distance (5 frames)
+
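Steps 3-4 can be sketched with SciPy, mirroring the baseline's `pick_peaks` in `exp/baseline/eval.py` (threshold 0.5, minimum peak distance 5 frames, 100 frames per second; the synthetic activation curve is illustrative):

```python
import numpy as np
from scipy.signal import find_peaks

def pick_peaks(activations: np.ndarray, hop_length: int = 160, sr: int = 16000):
    """Hamming-smooth the activation curve, then keep local maxima above 0.5."""
    window = np.hamming(5)
    window /= window.sum()
    smoothed = np.convolve(activations, window, mode="same")
    peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
    return (peaks * hop_length / sr).tolist()

# A synthetic activation curve with two broad peaks, 50 frames apart
act = np.zeros(100)
act[19:22] = 1.0
act[69:72] = 1.0
print(pick_peaks(act))  # beats at ~0.2 s and ~0.7 s
```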
+ ---
+
+ ## Project Structure
+
+ ```
+ exp-onset/
+ ├── exp/
+ │   ├── baseline/          # Baseline model implementation
+ │   │   ├── model.py       # ODCNN architecture
+ │   │   ├── train.py       # Training script
+ │   │   ├── eval.py        # Evaluation with viz/audio
+ │   │   ├── data.py        # Dataset wrapper
+ │   │   └── utils.py       # Spectrogram processing
+ │   └── data/
+ │       ├── load.py        # Dataset loading & preprocessing
+ │       ├── eval.py        # Evaluation metrics (F1, CML, AML)
+ │       ├── audio.py       # Click track synthesis
+ │       └── viz.py         # Visualization utilities
+ ├── outputs/
+ │   ├── baseline/          # Trained models
+ │   │   ├── beats/         # Beat detection model
+ │   │   └── downbeats/     # Downbeat detection model
+ │   └── eval/              # Evaluation outputs
+ │       ├── plots/         # Visualization images
+ │       ├── audio/         # Click track audio files
+ │       └── evaluation_summary.png
+ ├── README.md
+ ├── DATASET.md             # Raw dataset specification
+ └── pyproject.toml
+ ```
+
+ ---
+
+ ## License
+
+ This project is for research and educational purposes. The dataset is derived from publicly available rhythm game charts.
exp/__init__.py ADDED
File without changes
exp/baseline/__init__.py ADDED
File without changes
exp/baseline/data.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ from torch.utils.data import Dataset
+ import numpy as np
+ from tqdm import tqdm
+ from .utils import extract_context
+
+
+ class BeatTrackingDataset(Dataset):
+     def __init__(
+         self, hf_dataset, target_type="beats", sample_rate=16000, hop_length=160
+     ):
+         """
+         Args:
+             hf_dataset: HuggingFace dataset object
+             target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
+         """
+         self.sr = sample_rate
+         self.hop_length = hop_length
+         self.target_type = target_type
+
+         # Context window size in samples (±7 frames = ±70ms at 100fps)
+         self.context_frames = 7
+         self.context_samples = (self.context_frames * 2 + 1) * hop_length + max(
+             [368, 736, 1488]
+         )  # extra for the largest FFT window
+
+         # Cache audio arrays in memory for fast access
+         self.audio_cache = []
+         self.indices = []
+         self._prepare_indices(hf_dataset)
+
+     def _prepare_indices(self, hf_dataset):
+         """
+         Prepares balanced indices and caches audio.
+         Paper Section 4.5: uses "fuzzier" training examples (neighbors weighted less).
+         """
+         print(f"Preparing dataset indices for target: {self.target_type}...")
+
+         for i, item in tqdm(
+             enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
+         ):
+             # Cache audio array (convert to numpy if tensor)
+             audio = item["audio"]["array"]
+             if hasattr(audio, "numpy"):
+                 audio = audio.numpy()
+             self.audio_cache.append(audio)
+
+             # Calculate total frames available in the audio
+             audio_len = len(audio)
+             n_frames = int(audio_len / self.hop_length)
+
+             # Select ground truth based on target_type
+             if self.target_type == "downbeats":
+                 # Only downbeats are positives
+                 gt_times = item["downbeats"]
+             else:
+                 # All beats are positives (downbeats are also beats)
+                 gt_times = item["beats"]
+
+             # Convert to list if tensor
+             if hasattr(gt_times, "tolist"):
+                 gt_times = gt_times.tolist()
+
+             gt_frames = set(int(t * self.sr / self.hop_length) for t in gt_times)
+
+             # --- Positive Examples (with Fuzziness) ---
+             # "define a single frame before and after each annotated onset to be additional positive examples"
+             pos_frames = set()
+             for bf in gt_frames:
+                 if 0 <= bf < n_frames:
+                     self.indices.append((i, bf, 1.0))  # Center frame (sharp onset)
+                     pos_frames.add(bf)
+
+                     # Neighbors weighted at 0.25
+                     if 0 <= bf - 1 < n_frames:
+                         self.indices.append((i, bf - 1, 0.25))
+                         pos_frames.add(bf - 1)
+                     if 0 <= bf + 1 < n_frames:
+                         self.indices.append((i, bf + 1, 0.25))
+                         pos_frames.add(bf + 1)
+
+             # --- Negative Examples ---
+             # Paper uses "all others as negative", but we balance 2:1 for stable SGD.
+             num_pos = len(pos_frames)
+             num_neg = num_pos * 2
+
+             count = 0
+             attempts = 0
+             while count < num_neg and attempts < num_neg * 5:
+                 f = np.random.randint(0, n_frames)
+                 if f not in pos_frames:
+                     self.indices.append((i, f, 0.0))
+                     count += 1
+                 attempts += 1
+
+         print(
+             f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
+         )
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         track_idx, frame_idx, label = self.indices[idx]
+
+         # Fast lookup from cache
+         audio = self.audio_cache[track_idx]
+         audio_len = len(audio)
+
+         # Calculate sample range for the context window
+         center_sample = frame_idx * self.hop_length
+         half_context = self.context_samples // 2
+         start = center_sample - half_context
+         end = center_sample + half_context
+
+         # Handle padding if needed
+         pad_left = max(0, -start)
+         pad_right = max(0, end - audio_len)
+         start = max(0, start)
+         end = min(audio_len, end)
+
+         # Extract the audio chunk
+         chunk = audio[start:end]
+         if pad_left > 0 or pad_right > 0:
+             chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")
+
+         waveform = torch.tensor(chunk, dtype=torch.float32)
+         return waveform, torch.tensor([label], dtype=torch.float32)
exp/baseline/eval.py ADDED
@@ -0,0 +1,326 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from scipy.signal import find_peaks
+ import argparse
+ import os
+
+ from .model import ODCNN
+ from .utils import MultiViewSpectrogram
+ from ..data.load import ds
+ from ..data.eval import evaluate_all, format_results
+
+
+ def get_activation_function(model, waveform, device):
+     """
+     Computes the beat probability curve over time.
+     """
+     processor = MultiViewSpectrogram().to(device)
+     waveform = waveform.unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         spec = processor(waveform)
+
+         # Normalize
+         mean = spec.mean(dim=(2, 3), keepdim=True)
+         std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+         spec = (spec - mean) / std
+
+         # Batchify with a sliding window
+         spec = torch.nn.functional.pad(spec, (7, 7))  # Pad time
+         windows = spec.unfold(3, 15, 1)  # (1, 3, 80, Time, 15)
+         windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)  # (Time, 3, 80, 15)
+
+         # Inference
+         activations = []
+         batch_size = 512
+         for i in range(0, len(windows), batch_size):
+             batch = windows[i : i + batch_size]
+             out = model(batch)
+             activations.append(out.cpu().numpy())
+
+     return np.concatenate(activations).flatten()
+
+
+ def pick_peaks(activations, hop_length=160, sr=16000):
+     """
+     Smooth with a Hamming window and report local maxima.
+     """
+     # Smoothing
+     window = np.hamming(5)
+     window /= window.sum()
+     smoothed = np.convolve(activations, window, mode="same")
+
+     # Peak picking
+     peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
+
+     timestamps = peaks * hop_length / sr
+     return timestamps.tolist()
+
+
+ def visualize_track(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     output_dir: str,
+     track_idx: int,
+     time_range: tuple[float, float] | None = None,
+ ):
+     """
+     Create and save visualizations for a single track.
+     """
+     from ..data.viz import plot_waveform_with_beats, save_figure
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Full waveform plot
+     fig = plot_waveform_with_beats(
+         audio,
+         sr,
+         pred_beats,
+         gt_beats,
+         pred_downbeats,
+         gt_downbeats,
+         title=f"Track {track_idx}: Beat Comparison",
+         time_range=time_range,
+     )
+     save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
+
+
+ def synthesize_audio(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     output_dir: str,
+     track_idx: int,
+     click_volume: float = 0.5,
+ ):
+     """
+     Create and save audio files with click tracks for a single track.
+     """
+     from ..data.audio import create_comparison_audio, save_audio
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Create comparison audio
+     audio_pred, audio_gt, audio_both = create_comparison_audio(
+         audio,
+         pred_beats,
+         pred_downbeats,
+         gt_beats,
+         gt_downbeats,
+         sr=sr,
+         click_volume=click_volume,
+     )
+
+     # Save audio files
+     save_audio(
+         audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
+     )
+     save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
+     save_audio(
+         audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
+     )
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Evaluate beat tracking models with visualization and audio synthesis"
+     )
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         default="outputs/baseline",
+         help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
+     )
+     parser.add_argument(
+         "--num-samples",
+         type=int,
+         default=20,
+         help="Number of samples to evaluate",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="outputs/eval",
+         help="Directory to save visualizations and audio",
+     )
+     parser.add_argument(
+         "--visualize",
+         action="store_true",
+         help="Generate visualization plots for each track",
+     )
+     parser.add_argument(
+         "--synthesize",
+         action="store_true",
+         help="Generate audio files with click tracks",
+     )
+     parser.add_argument(
+         "--viz-tracks",
+         type=int,
+         default=5,
+         help="Number of tracks to visualize/synthesize (default: 5)",
+     )
+     parser.add_argument(
+         "--time-range",
+         type=float,
+         nargs=2,
+         default=None,
+         metavar=("START", "END"),
+         help="Time range for visualization in seconds (default: full track)",
+     )
+     parser.add_argument(
+         "--click-volume",
+         type=float,
+         default=0.5,
+         help="Volume of click sounds relative to the audio (0.0 to 1.0)",
+     )
+     parser.add_argument(
+         "--summary-plot",
+         action="store_true",
+         help="Generate summary evaluation plot",
+     )
+     args = parser.parse_args()
+
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Load BOTH models using from_pretrained
+     beat_model = None
+     downbeat_model = None
+
+     has_beats = False
+     has_downbeats = False
+
+     beats_dir = os.path.join(args.model_dir, "beats")
+     downbeats_dir = os.path.join(args.model_dir, "downbeats")
+
+     if os.path.exists(os.path.join(beats_dir, "model.safetensors")) or os.path.exists(
+         os.path.join(beats_dir, "pytorch_model.bin")
+     ):
+         beat_model = ODCNN.from_pretrained(beats_dir).to(DEVICE)
+         beat_model.eval()
+         has_beats = True
+         print(f"Loaded beat model from {beats_dir}")
+     else:
+         print(f"Warning: no beat model found in {beats_dir}")
+
+     if os.path.exists(
+         os.path.join(downbeats_dir, "model.safetensors")
+     ) or os.path.exists(os.path.join(downbeats_dir, "pytorch_model.bin")):
+         downbeat_model = ODCNN.from_pretrained(downbeats_dir).to(DEVICE)
+         downbeat_model.eval()
+         has_downbeats = True
+         print(f"Loaded downbeat model from {downbeats_dir}")
+     else:
+         print(f"Warning: no downbeat model found in {downbeats_dir}")
+
+     if not has_beats and not has_downbeats:
+         print("No models found. Please run training first.")
+         return
+
+     predictions = []
+     ground_truths = []
+     audio_data = []  # Store audio for visualization/synthesis
+
+     # Evaluate on the specified number of tracks
+     test_set = ds["train"].select(range(args.num_samples))
+
+     print("Running evaluation...")
+     for i, item in enumerate(tqdm(test_set)):
+         waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
+         waveform_device = waveform.to(DEVICE)
+
+         pred_entry = {"beats": [], "downbeats": []}
+
+         # 1. Predict beats
+         if has_beats:
+             act_b = get_activation_function(beat_model, waveform_device, DEVICE)
+             pred_entry["beats"] = pick_peaks(act_b)
+
+         # 2. Predict downbeats
+         if has_downbeats:
+             act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
+             pred_entry["downbeats"] = pick_peaks(act_d)
+
+         predictions.append(pred_entry)
+         ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
+
+         # Store audio for later visualization/synthesis
+         if args.visualize or args.synthesize:
+             if i < args.viz_tracks:
+                 audio_data.append(
+                     {
+                         "audio": waveform.numpy(),
+                         "sr": item["audio"]["sampling_rate"],
+                         "pred": pred_entry,
+                         "gt": ground_truths[-1],
+                     }
+                 )
+
+     # Run evaluation
+     results = evaluate_all(predictions, ground_truths)
+     print(format_results(results))
+
+     # Create the output directory
+     if args.visualize or args.synthesize or args.summary_plot:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     # Generate visualizations
+     if args.visualize:
+         print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
+         viz_dir = os.path.join(args.output_dir, "plots")
+         for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
+             time_range = tuple(args.time_range) if args.time_range else None
+             visualize_track(
+                 data["audio"],
+                 data["sr"],
+                 data["pred"]["beats"],
+                 data["pred"]["downbeats"],
+                 data["gt"]["beats"],
+                 data["gt"]["downbeats"],
+                 viz_dir,
+                 i,
+                 time_range=time_range,
+             )
+         print(f"Saved visualizations to {viz_dir}")
+
+     # Generate audio with clicks
+     if args.synthesize:
+         print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
+         audio_dir = os.path.join(args.output_dir, "audio")
+         for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
+             synthesize_audio(
+                 data["audio"],
+                 data["sr"],
+                 data["pred"]["beats"],
+                 data["pred"]["downbeats"],
+                 data["gt"]["beats"],
+                 data["gt"]["downbeats"],
+                 audio_dir,
+                 i,
+                 click_volume=args.click_volume,
+             )
+         print(f"Saved audio files to {audio_dir}")
+         print("  *_pred.wav - Original audio with predicted beat clicks")
+         print("  *_gt.wav   - Original audio with ground truth beat clicks")
+         print("  *_both.wav - Original audio with both predicted and GT clicks")
+
+     # Generate the summary plot
+     if args.summary_plot:
+         from ..data.viz import plot_evaluation_summary, save_figure
+
+         print("\nGenerating summary plot...")
+         fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
+         summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
+         save_figure(fig, summary_path)
+         print(f"Saved summary plot to {summary_path}")
+
+
+ if __name__ == "__main__":
+     main()
exp/baseline/model.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class ODCNN(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, dropout_rate=0.5):
+         super().__init__()
+
+         # Input: 3 channels, 80 mel bands
+         # Conv 1: 3x7 (freq x time) filters -> 10 maps
+         self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
+         self.relu1 = nn.ReLU()  # ReLU improvement
+         self.pool1 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))
+
+         # Conv 2: 3x3 filters -> 20 maps
+         self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
+         self.relu2 = nn.ReLU()
+         self.pool2 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))
+
+         # Flatten size calculation based on the architecture
+         # (20 feature maps * 8 freq bands * 7 time frames)
+         self.flatten_size = 20 * 8 * 7
+
+         # Dropout on FC inputs
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+         # 256 hidden units
+         self.fc1 = nn.Linear(self.flatten_size, 256)
+         self.relu_fc = nn.ReLU()
+
+         # Output unit
+         self.fc2 = nn.Linear(256, 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.relu1(x)
+         x = self.pool1(x)
+
+         x = self.conv2(x)
+         x = self.relu2(x)
+         x = self.pool2(x)
+
+         x = x.view(x.size(0), -1)
+
+         x = self.dropout(x)
+         x = self.fc1(x)
+         x = self.relu_fc(x)
+
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.sigmoid(x)
+
+         return x
+
+
+ if __name__ == "__main__":
+     from torchinfo import summary
+
+     model = ODCNN()
+     summary(model, (1, 3, 80, 15))
exp/baseline/train.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import os
9
+
10
+ from .model import ODCNN
11
+ from .data import BeatTrackingDataset
12
+ from .utils import MultiViewSpectrogram
13
+ from ..data.load import ds
14
+
15
+
16
+ def train(target_type: str, output_dir: str):
17
+ # Note: Paper uses SGD with Momentum, Dropout, and ReLU
18
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
+ BATCH_SIZE = 512
20
+ EPOCHS = 50
21
+ LR = 0.05
22
+ MOMENTUM = 0.9
23
+ NUM_WORKERS = 4
24
+
25
+ print(f"--- Training Model for target: {target_type} ---")
26
+ print(f"Output directory: {output_dir}")
27
+
28
+ # Create output directory
29
+ os.makedirs(output_dir, exist_ok=True)
30
+
31
+ # TensorBoard writer
32
+ writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
33
+
34
+ # Data - use existing train/test splits
35
+ train_dataset = BeatTrackingDataset(ds["train"], target_type=target_type)
36
+ val_dataset = BeatTrackingDataset(ds["test"], target_type=target_type)
37
+
38
+ train_loader = DataLoader(
39
+ train_dataset,
40
+ batch_size=BATCH_SIZE,
41
+ shuffle=True,
42
+ num_workers=NUM_WORKERS,
43
+ pin_memory=True,
44
+ prefetch_factor=4,
45
+ persistent_workers=True,
46
+ )
47
+ val_loader = DataLoader(
48
+ val_dataset,
49
+ batch_size=BATCH_SIZE,
50
+ shuffle=False,
51
+ num_workers=NUM_WORKERS,
52
+ pin_memory=True,
53
+ prefetch_factor=4,
54
+ persistent_workers=True,
55
+ )
56
+
57
+     print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
+
+     # Model
+     model = ODCNN(dropout_rate=0.5).to(DEVICE)
+
+     # GPU spectrogram preprocessor
+     preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)
+
+     # Optimizer
+     optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
+     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
+     criterion = nn.BCELoss()  # Binary cross entropy
+
+     best_val_loss = float("inf")
+     global_step = 0
+
+     for epoch in range(EPOCHS):
+         # Training
+         model.train()
+         total_train_loss = 0
+         for waveform, y in tqdm(
+             train_loader,
+             desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
+             leave=False,
+         ):
+             waveform, y = waveform.to(DEVICE), y.to(DEVICE)
+
+             # Compute spectrogram on GPU
+             with torch.no_grad():
+                 spec = preprocessor(waveform)  # (B, 3, 80, T)
+                 # Normalize per sample and per view
+                 mean = spec.mean(dim=(2, 3), keepdim=True)
+                 std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+                 spec = (spec - mean) / std
+                 # Extract the 15-frame context window around the center
+                 x = spec[:, :, :, 7:22]  # center 15 frames
+
+             optimizer.zero_grad()
+             output = model(x)
+             loss = criterion(output, y)
+             loss.backward()
+             optimizer.step()
+
+             total_train_loss += loss.item()
+             global_step += 1
+
+             # Log batch loss
+             writer.add_scalar("train/batch_loss", loss.item(), global_step)
+
+         avg_train_loss = total_train_loss / len(train_loader)
+
+         # Validation
+         model.eval()
+         total_val_loss = 0
+         with torch.no_grad():
+             for waveform, y in tqdm(
+                 val_loader,
+                 desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
+                 leave=False,
+             ):
+                 waveform, y = waveform.to(DEVICE), y.to(DEVICE)
+
+                 # Compute spectrogram on GPU
+                 spec = preprocessor(waveform)  # (B, 3, 80, T)
+                 # Normalize per sample and per view
+                 mean = spec.mean(dim=(2, 3), keepdim=True)
+                 std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+                 spec = (spec - mean) / std
+                 # Extract the 15-frame context window around the center
+                 x = spec[:, :, :, 7:22]
+
+                 output = model(x)
+                 loss = criterion(output, y)
+                 total_val_loss += loss.item()
+
+         avg_val_loss = total_val_loss / len(val_loader)
+
+         # Log epoch metrics
+         writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
+         writer.add_scalar("val/loss", avg_val_loss, epoch)
+         writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
+
+         # Step the scheduler
+         scheduler.step()
+
+         print(
+             f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
+             f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
+         )
+
+         # Save best model
+         if avg_val_loss < best_val_loss:
+             best_val_loss = avg_val_loss
+             model.save_pretrained(output_dir)
+             print(f"  -> Saved best model (val_loss: {best_val_loss:.4f})")
+
+     writer.close()
+
+     # Save final model
+     final_dir = os.path.join(output_dir, "final")
+     model.save_pretrained(final_dir)
+     print(f"Saved final model to {final_dir}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--target",
+         type=str,
+         choices=["beats", "downbeats"],
+         default=None,
+         help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="outputs/baseline",
+         help="Directory to save model and logs",
+     )
+     args = parser.parse_args()
+
+     # Determine which targets to train
+     targets = [args.target] if args.target else ["beats", "downbeats"]
+
+     for target in targets:
+         output_dir = os.path.join(args.output_dir, target)
+         train(target, output_dir)
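The per-sample, per-view normalization used in both loops above can be sanity-checked in isolation. A minimal standalone sketch (illustrative only, not part of the committed files), assuming a batch shaped like the preprocessor output `(B, views, n_mels, T)`:

```python
import torch

# Fake batch shaped like the preprocessor output: (B, views, n_mels, T)
spec = torch.randn(2, 3, 80, 29) * 5 + 2

# Normalize each (sample, view) plane independently, as the training loop does
mean = spec.mean(dim=(2, 3), keepdim=True)
std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
norm = (spec - mean) / std

# Each plane is now roughly zero-mean, unit-variance
print(abs(float(norm.mean())) < 1e-4, abs(float(norm.std()) - 1.0) < 1e-2)
```

Because the statistics are taken over `dim=(2, 3)` with `keepdim=True`, each sample/view plane is standardized on its own, so loud and quiet clips land in the same input range.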
exp/baseline/utils.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torchaudio.transforms as T
+ import numpy as np
+
+
+ class MultiViewSpectrogram(nn.Module):
+     def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
+         super().__init__()
+         # Window sizes: 23ms, 46ms, 93ms
+         self.win_lengths = [368, 736, 1488]
+         self.transforms = nn.ModuleList()
+
+         for win_len in self.win_lengths:
+             n_fft = 2 ** int(np.ceil(np.log2(win_len)))
+             mel = T.MelSpectrogram(
+                 sample_rate=sample_rate,
+                 n_fft=n_fft,
+                 win_length=win_len,
+                 hop_length=hop_length,
+                 f_min=27.5,
+                 f_max=16000.0,
+                 n_mels=n_mels,
+                 power=1.0,
+                 center=True,
+             )
+             self.transforms.append(mel)
+
+     def forward(self, waveform):
+         specs = []
+         for transform in self.transforms:
+             # Scale magnitudes logarithmically
+             s = transform(waveform)
+             s = torch.log(s + 1e-9)
+             specs.append(s)
+         return torch.stack(specs, dim=1)
+
+
+ def extract_context(spec, center_frame, context=7):
+     # Context of +/- 70ms (7 frames)
+     channels, n_mels, total_time = spec.shape
+     start = center_frame - context
+     end = center_frame + context + 1
+
+     pad_left = max(0, -start)
+     pad_right = max(0, end - total_time)
+
+     if pad_left > 0 or pad_right > 0:
+         spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
+         start += pad_left
+         end += pad_left
+
+     return spec[:, :, start:end]
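To see how `extract_context` behaves at clip boundaries, here is a small standalone sketch that inlines the same logic (duplicated so it runs without the `exp` package):

```python
import torch
import torch.nn.functional as F


def extract_context(spec, center_frame, context=7):
    # Same logic as exp/baseline/utils.py: take +/- `context` frames around center
    channels, n_mels, total_time = spec.shape
    start, end = center_frame - context, center_frame + context + 1
    pad_left, pad_right = max(0, -start), max(0, end - total_time)
    if pad_left > 0 or pad_right > 0:
        spec = F.pad(spec, (pad_left, pad_right))  # zero-pad along time
        start += pad_left
        end += pad_left
    return spec[:, :, start:end]


spec = torch.randn(3, 80, 10)  # a short clip with only 10 frames
window = extract_context(spec, center_frame=0)
print(tuple(window.shape))  # (3, 80, 15)
```

A center frame near the start is zero-padded on the left, so the returned window is always `2 * context + 1` frames wide regardless of where the center falls.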
exp/data/__init__.py ADDED
@@ -0,0 +1,25 @@
+ """
+ Data loading and evaluation utilities for beat tracking.
+
+ Modules:
+ - load: Dataset loading and preprocessing
+ - eval: Evaluation metrics and utilities
+ """
+
+ from exp.data.eval import (
+     DEFAULT_THRESHOLDS_MS,
+     evaluate_beats,
+     evaluate_track,
+     evaluate_all,
+     compute_weighted_f1,
+     format_results,
+ )
+
+ __all__ = [
+     "DEFAULT_THRESHOLDS_MS",
+     "evaluate_beats",
+     "evaluate_track",
+     "evaluate_all",
+     "compute_weighted_f1",
+     "format_results",
+ ]
exp/data/audio.py ADDED
@@ -0,0 +1,301 @@
+ """
+ Audio synthesis utilities for beat tracking evaluation.
+
+ This module provides functions to:
+ - Generate click sounds for beats and downbeats
+ - Mix click tracks with original audio
+ - Save audio files with beat annotations
+
+ Example usage:
+     from exp.data.audio import create_click_track, mix_audio, save_audio
+
+     # Create click track
+     clicks = create_click_track(
+         beat_times=pred_beats,
+         downbeat_times=pred_downbeats,
+         duration=30.0,
+         sr=16000,
+     )
+
+     # Mix with original audio
+     mixed = mix_audio(original_audio, clicks, click_volume=0.5)
+
+     # Save to file
+     save_audio(mixed, "output.wav", sr=16000)
+ """
+
+ import numpy as np
+ from pathlib import Path
+
+
+ def generate_click(
+     frequency: float = 1000.0,
+     duration: float = 0.02,
+     sr: int = 16000,
+     attack: float = 0.002,
+     decay: float = 0.018,
+ ) -> np.ndarray:
+     """
+     Generate a single click sound.
+
+     Args:
+         frequency: Frequency of the click tone in Hz
+         duration: Duration of the click in seconds
+         sr: Sample rate
+         attack: Attack time in seconds
+         decay: Decay time in seconds
+
+     Returns:
+         Click waveform as numpy array
+     """
+     t = np.arange(int(duration * sr)) / sr
+
+     # Generate sine wave
+     wave = np.sin(2 * np.pi * frequency * t)
+
+     # Apply attack-decay envelope
+     envelope = np.ones_like(t)
+     attack_samples = int(attack * sr)
+     decay_samples = int(decay * sr)
+
+     if attack_samples > 0:
+         envelope[:attack_samples] = np.linspace(0, 1, attack_samples)
+     if decay_samples > 0:
+         decay_start = len(t) - decay_samples
+         if decay_start > 0:
+             envelope[decay_start:] = np.linspace(1, 0, decay_samples)
+
+     return wave * envelope
+
+
+ def create_click_track(
+     beat_times: list[float] | np.ndarray,
+     downbeat_times: list[float] | np.ndarray | None = None,
+     duration: float | None = None,
+     sr: int = 16000,
+     beat_freq: float = 1000.0,
+     downbeat_freq: float = 1500.0,
+     click_duration: float = 0.03,
+ ) -> np.ndarray:
+     """
+     Create a click track from beat and downbeat times.
+
+     Args:
+         beat_times: List of beat times in seconds
+         downbeat_times: List of downbeat times in seconds (optional)
+         duration: Total duration in seconds (auto-detected if None)
+         sr: Sample rate
+         beat_freq: Frequency for beat clicks (Hz)
+         downbeat_freq: Frequency for downbeat clicks (Hz)
+         click_duration: Duration of each click in seconds
+
+     Returns:
+         Click track as numpy array
+     """
+     beat_times = np.array(beat_times) if len(beat_times) > 0 else np.array([])
+     if downbeat_times is not None:
+         downbeat_times = (
+             np.array(downbeat_times) if len(downbeat_times) > 0 else np.array([])
+         )
+     else:
+         downbeat_times = np.array([])
+
+     # Determine duration
+     if duration is None:
+         all_times = np.concatenate([beat_times, downbeat_times])
+         if len(all_times) == 0:
+             return np.array([])
+         duration = float(np.max(all_times)) + 1.0
+
+     # Create output array
+     total_samples = int(duration * sr)
+     output = np.zeros(total_samples, dtype=np.float32)
+
+     # Generate click templates
+     beat_click = generate_click(frequency=beat_freq, duration=click_duration, sr=sr)
+     downbeat_click = generate_click(
+         frequency=downbeat_freq, duration=click_duration, sr=sr
+     )
+
+     # Convert downbeat times to a set for fast lookup
+     downbeat_set = set(np.round(downbeat_times, 3))
+
+     # Add beat clicks
+     for t in beat_times:
+         sample_idx = int(t * sr)
+         if sample_idx < 0 or sample_idx >= total_samples:
+             continue
+
+         # Use the downbeat click if this beat is also a downbeat
+         is_downbeat = np.round(t, 3) in downbeat_set
+         click = downbeat_click if is_downbeat else beat_click
+
+         # Add click to output
+         end_idx = min(sample_idx + len(click), total_samples)
+         click_len = end_idx - sample_idx
+         output[sample_idx:end_idx] += click[:click_len]
+
+     # Add downbeat clicks (for downbeats not already in beats)
+     beat_set = set(np.round(beat_times, 3))
+     for t in downbeat_times:
+         if np.round(t, 3) in beat_set:
+             continue  # Already added as a beat
+
+         sample_idx = int(t * sr)
+         if sample_idx < 0 or sample_idx >= total_samples:
+             continue
+
+         end_idx = min(sample_idx + len(downbeat_click), total_samples)
+         click_len = end_idx - sample_idx
+         output[sample_idx:end_idx] += downbeat_click[:click_len]
+
+     return output
+
+
+ def mix_audio(
+     audio: np.ndarray,
+     click_track: np.ndarray,
+     click_volume: float = 0.5,
+ ) -> np.ndarray:
+     """
+     Mix original audio with a click track.
+
+     Args:
+         audio: Original audio waveform
+         click_track: Click track to overlay
+         click_volume: Volume of clicks relative to audio (0.0 to 1.0)
+
+     Returns:
+         Mixed audio
+     """
+     # Ensure same length
+     max_len = max(len(audio), len(click_track))
+     audio_padded = np.zeros(max_len, dtype=np.float32)
+     click_padded = np.zeros(max_len, dtype=np.float32)
+
+     audio_padded[: len(audio)] = audio
+     click_padded[: len(click_track)] = click_track
+
+     # Normalize audio
+     audio_max = np.abs(audio_padded).max()
+     if audio_max > 0:
+         audio_padded = audio_padded / audio_max * 0.8
+
+     # Normalize clicks
+     click_max = np.abs(click_padded).max()
+     if click_max > 0:
+         click_padded = click_padded / click_max * click_volume * 0.8
+
+     # Mix
+     mixed = audio_padded + click_padded
+
+     # Prevent clipping
+     max_val = np.abs(mixed).max()
+     if max_val > 1.0:
+         mixed = mixed / max_val * 0.95
+
+     return mixed.astype(np.float32)
+
+
+ def create_comparison_audio(
+     audio: np.ndarray,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     sr: int = 16000,
+     click_volume: float = 0.5,
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+     """
+     Create audio files for comparison: prediction clicks, ground truth clicks, and combined.
+
+     Args:
+         audio: Original audio waveform
+         pred_beats: Predicted beat times
+         pred_downbeats: Predicted downbeat times
+         gt_beats: Ground truth beat times
+         gt_downbeats: Ground truth downbeat times
+         sr: Sample rate
+         click_volume: Volume of clicks
+
+     Returns:
+         Tuple of (audio_with_pred_clicks, audio_with_gt_clicks, audio_with_both)
+     """
+     duration = len(audio) / sr
+
+     # Create click tracks
+     pred_clicks = create_click_track(
+         pred_beats,
+         pred_downbeats,
+         duration=duration,
+         sr=sr,
+         beat_freq=1000.0,
+         downbeat_freq=1500.0,
+     )
+
+     gt_clicks = create_click_track(
+         gt_beats,
+         gt_downbeats,
+         duration=duration,
+         sr=sr,
+         beat_freq=800.0,  # Different frequency for GT
+         downbeat_freq=1200.0,
+     )
+
+     # Mix
+     audio_pred = mix_audio(audio, pred_clicks, click_volume)
+     audio_gt = mix_audio(audio, gt_clicks, click_volume)
+     audio_both = mix_audio(audio, pred_clicks + gt_clicks, click_volume)
+
+     return audio_pred, audio_gt, audio_both
+
+
+ def save_audio(
+     audio: np.ndarray,
+     path: str | Path,
+     sr: int = 16000,
+ ) -> None:
+     """
+     Save audio to a WAV file.
+
+     Args:
+         audio: Audio waveform
+         path: Output file path
+         sr: Sample rate
+     """
+     import scipy.io.wavfile as wavfile
+
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+
+     # Convert to int16
+     audio_int16 = (audio * 32767).astype(np.int16)
+     wavfile.write(str(path), sr, audio_int16)
+
+
+ if __name__ == "__main__":
+     # Demo
+     print("Audio synthesis demo...")
+
+     # Create a simple sine wave as "music"
+     sr = 16000
+     duration = 10.0
+     t = np.arange(int(duration * sr)) / sr
+     music = np.sin(2 * np.pi * 220 * t) * 0.3  # 220 Hz tone
+
+     # Beats every 0.5s, downbeats every 2s
+     beats = np.arange(0, duration, 0.5).tolist()
+     downbeats = np.arange(0, duration, 2.0).tolist()
+
+     # Create click track
+     clicks = create_click_track(beats, downbeats, duration=duration, sr=sr)
+
+     # Mix
+     mixed = mix_audio(music, clicks, click_volume=0.6)
+
+     print(f"Created mixed audio: {len(mixed)} samples ({len(mixed) / sr:.2f}s)")
+     print(f"Beats: {len(beats)}, Downbeats: {len(downbeats)}")
+
+     # Save demo
+     save_audio(mixed, "/tmp/beat_click_demo.wav", sr=sr)
+     print("Saved demo to /tmp/beat_click_demo.wav")
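The clipping guard in `mix_audio` (rescale the sum only when it would exceed full scale) can be illustrated with a tiny numpy sketch; the helper name `guard_clipping` is made up for this example:

```python
import numpy as np


def guard_clipping(mixed: np.ndarray) -> np.ndarray:
    # Same guard as mix_audio: rescale to 0.95 full scale only if the sum clips
    peak = np.abs(mixed).max()
    if peak > 1.0:
        mixed = mixed / peak * 0.95
    return mixed


a = np.full(4, 0.8, dtype=np.float32)  # normalized audio peak
b = np.full(4, 0.6, dtype=np.float32)  # click track
out = guard_clipping(a + b)            # a peak of 1.4 would clip
print(round(float(np.abs(out).max()), 2))  # 0.95
```

Quiet mixes pass through untouched; only mixes whose summed peak exceeds 1.0 are attenuated, which preserves the relative level between music and clicks.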
exp/data/eval.py ADDED
@@ -0,0 +1,568 @@
+ """
+ Evaluation utilities for beat and downbeat detection.
+
+ This module provides functions to evaluate beat/downbeat predictions against
+ ground truth annotations using F1-scores at various timing thresholds and
+ continuity-based metrics (CMLt, AMLt).
+
+ The evaluation metrics include:
+ - **F1-scores**: Calculated for timing thresholds from 3ms to 30ms
+ - **Weighted F1**: Weights are inversely proportional to threshold (e.g., 3ms: 1, 6ms: 1/2)
+ - **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level
+ - **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations
+   (double/half tempo, off-beat, etc.)
+ - **CMLc/AMLc**: Continuous versions (longest correct segment)
+
+ Example usage:
+     from exp.data.eval import (
+         evaluate_beats, evaluate_all, compute_weighted_f1,
+         compute_continuity_metrics, format_results,
+     )
+
+     # Evaluate a single track
+     results = evaluate_beats(pred_beats, gt_beats)
+     print(f"Weighted F1: {results['weighted_f1']:.4f}")
+     print(f"CMLt: {results['continuity']['CMLt']:.4f}")
+     print(f"AMLt: {results['continuity']['AMLt']:.4f}")
+
+     # Evaluate with custom thresholds
+     results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20])
+
+     # Evaluate all tracks in a dataset
+     summary = evaluate_all(predictions, ground_truths)
+     print(format_results(summary))
+ """
+
+ from typing import Sequence
+ import numpy as np
+ import mir_eval
+
+
+ # Default timing thresholds in milliseconds (3ms to 30ms, step 3ms)
+ DEFAULT_THRESHOLDS_MS = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]
+
+ # Default minimum beat time for mir_eval metrics (can be set to 0 to use all beats)
+ DEFAULT_MIN_BEAT_TIME = 5.0
+
+
48
+ def match_events(
49
+ pred: np.ndarray,
50
+ gt: np.ndarray,
51
+ tolerance_sec: float,
52
+ ) -> tuple[int, int, int]:
53
+ """
54
+ Match predicted events to ground truth events within a tolerance.
55
+
56
+ Uses greedy matching: each ground truth event is matched to the closest
57
+ unmatched prediction within the tolerance window.
58
+
59
+ Args:
60
+ pred: Predicted event times in seconds, shape (N,)
61
+ gt: Ground truth event times in seconds, shape (M,)
62
+ tolerance_sec: Maximum time difference for a match (in seconds)
63
+
64
+ Returns:
65
+ Tuple of (true_positives, false_positives, false_negatives)
66
+ """
67
+ if len(gt) == 0:
68
+ return 0, len(pred), 0
69
+ if len(pred) == 0:
70
+ return 0, 0, len(gt)
71
+
72
+ pred = np.sort(pred)
73
+ gt = np.sort(gt)
74
+
75
+ matched_pred = np.zeros(len(pred), dtype=bool)
76
+ matched_gt = np.zeros(len(gt), dtype=bool)
77
+
78
+ # For each ground truth, find the closest unmatched prediction
79
+ for i, gt_time in enumerate(gt):
80
+ # Find predictions within tolerance
81
+ diffs = np.abs(pred - gt_time)
82
+ candidates = np.where((diffs <= tolerance_sec) & ~matched_pred)[0]
83
+
84
+ if len(candidates) > 0:
85
+ # Match to closest candidate
86
+ best_idx = candidates[np.argmin(diffs[candidates])]
87
+ matched_pred[best_idx] = True
88
+ matched_gt[i] = True
89
+
90
+ tp = int(matched_gt.sum())
91
+ fp = int((~matched_pred).sum() == 0 and len(pred) - tp or len(pred) - tp)
92
+ fn = int(len(gt) - tp)
93
+
94
+ # Recalculate fp correctly
95
+ fp = len(pred) - tp
96
+
97
+ return tp, fp, fn
98
+
99
+
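For intuition, a compact standalone version of the greedy matcher above (illustrative only; times are in seconds and the tolerance is 30 ms):

```python
import numpy as np


def greedy_match(pred, gt, tol):
    # Each ground-truth event claims its closest unmatched prediction within tol
    pred, gt = np.sort(pred), np.sort(gt)
    used = np.zeros(len(pred), dtype=bool)
    tp = 0
    for g in gt:
        diffs = np.abs(pred - g)
        candidates = np.where((diffs <= tol) & ~used)[0]
        if len(candidates) > 0:
            used[candidates[np.argmin(diffs[candidates])]] = True
            tp += 1
    return tp, len(pred) - tp, len(gt) - tp


pred = np.array([0.01, 0.52, 1.30])
gt = np.array([0.00, 0.50, 1.00])
print(greedy_match(pred, gt, tol=0.03))  # (2, 1, 1)
```

The first two predictions land within 30 ms of a ground-truth beat (true positives); the prediction at 1.30 s is a false positive and the unmatched beat at 1.00 s a false negative.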
+ def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
+     """
+     Compute precision, recall, and F1-score from TP, FP, FN counts.
+
+     Args:
+         tp: True positives
+         fp: False positives
+         fn: False negatives
+
+     Returns:
+         Tuple of (precision, recall, f1_score)
+     """
+     precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+     recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+     f1 = (
+         2 * precision * recall / (precision + recall)
+         if (precision + recall) > 0
+         else 0.0
+     )
+     return precision, recall, f1
+
+
+ def compute_weighted_f1(
+     f1_scores: dict[int, float],
+     thresholds_ms: Sequence[int] | None = None,
+ ) -> float:
+     """
+     Compute weighted F1-score where weights are inversely proportional to threshold.
+
+     The weight for threshold T ms is min_threshold / T.
+     For example, with thresholds [3, 6, 9, ...]:
+     - 3ms: weight = 1
+     - 6ms: weight = 0.5
+     - 9ms: weight = 0.333...
+
+     Args:
+         f1_scores: Dict mapping threshold (ms) to F1-score
+         thresholds_ms: List of thresholds used (for weight calculation)
+
+     Returns:
+         Weighted F1-score
+     """
+     if not f1_scores:
+         return 0.0
+
+     if thresholds_ms is None:
+         thresholds_ms = sorted(f1_scores.keys())
+
+     min_threshold = min(thresholds_ms)
+     total_weight = 0.0
+     weighted_sum = 0.0
+
+     for t in thresholds_ms:
+         if t in f1_scores:
+             weight = min_threshold / t  # 3ms -> 1, 6ms -> 0.5, etc.
+             weighted_sum += weight * f1_scores[t]
+             total_weight += weight
+
+     return weighted_sum / total_weight if total_weight > 0 else 0.0
+
+
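The weighting scheme reduces to a couple of lines; a quick worked example with hypothetical F1 values at three thresholds:

```python
# Weight for threshold t is min_threshold / t: 3 ms -> 1, 6 ms -> 1/2, 9 ms -> 1/3
f1_scores = {3: 1.0, 6: 0.8, 9: 0.6}
weights = {t: 3 / t for t in f1_scores}
weighted = sum(weights[t] * f1_scores[t] for t in f1_scores) / sum(weights.values())
print(round(weighted, 4))  # 0.8727 (= 48/55)
```

Tight thresholds dominate the average, so a tracker that is only accurate at coarse tolerances is penalized relative to one that nails beats within a few milliseconds.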
+ def compute_continuity_metrics(
+     pred_times: Sequence[float],
+     gt_times: Sequence[float],
+     min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
+     phase_threshold: float = 0.175,
+     period_threshold: float = 0.175,
+ ) -> dict:
+     """
+     Compute continuity-based beat tracking metrics using mir_eval.
+
+     These metrics evaluate beat tracking accuracy accounting for metrical level:
+     - CMLt (Correct Metric Level Total): Accuracy at the correct metrical level
+     - AMLt (Any Metric Level Total): Accuracy allowing for metrical variations
+       (double/half tempo, off-beat, etc.)
+     - CMLc/AMLc: Continuous versions (longest correct segment)
+
+     Args:
+         pred_times: Predicted beat times in seconds
+         gt_times: Ground truth beat times in seconds
+         min_beat_time: Minimum time to start evaluation (default: 5.0s).
+             Set to 0.0 to use all beats, but note that early beats
+             may not have stable inter-beat intervals.
+         phase_threshold: Maximum phase error as ratio of beat interval (default: 0.175)
+         period_threshold: Maximum period error as ratio of beat interval (default: 0.175)
+
+     Returns:
+         Dict containing:
+         - 'CMLc': Correct Metric Level Continuous
+         - 'CMLt': Correct Metric Level Total
+         - 'AMLc': Any Metric Level Continuous
+         - 'AMLt': Any Metric Level Total
+     """
+     pred_arr = np.sort(np.array(pred_times, dtype=np.float64))
+     gt_arr = np.sort(np.array(gt_times, dtype=np.float64))
+
+     # Trim beats before min_beat_time (standard preprocessing)
+     pred_trimmed = mir_eval.beat.trim_beats(pred_arr, min_beat_time=min_beat_time)
+     gt_trimmed = mir_eval.beat.trim_beats(gt_arr, min_beat_time=min_beat_time)
+
+     # Handle edge cases where trimming results in too few beats
+     if len(gt_trimmed) < 2 or len(pred_trimmed) < 2:
+         return {
+             "CMLc": 0.0,
+             "CMLt": 0.0,
+             "AMLc": 0.0,
+             "AMLt": 0.0,
+         }
+
+     # Compute continuity metrics
+     CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(
+         gt_trimmed,
+         pred_trimmed,
+         continuity_phase_threshold=phase_threshold,
+         continuity_period_threshold=period_threshold,
+     )
+
+     return {
+         "CMLc": float(CMLc),
+         "CMLt": float(CMLt),
+         "AMLc": float(AMLc),
+         "AMLt": float(AMLt),
+     }
+
+
+ def evaluate_beats(
+     pred_times: Sequence[float],
+     gt_times: Sequence[float],
+     thresholds_ms: Sequence[int] | None = None,
+     min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
+ ) -> dict:
+     """
+     Evaluate beat predictions against ground truth at multiple thresholds.
+
+     Args:
+         pred_times: Predicted beat times in seconds
+         gt_times: Ground truth beat times in seconds
+         thresholds_ms: Timing thresholds in milliseconds (default: 3ms to 30ms)
+         min_beat_time: Minimum time for continuity metrics (default: 5.0s)
+
+     Returns:
+         Dict containing:
+         - 'per_threshold': Dict[threshold_ms, {'precision', 'recall', 'f1'}]
+         - 'f1_scores': Dict[threshold_ms, f1_score] (convenience access)
+         - 'weighted_f1': Weighted F1-score across all thresholds
+         - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
+         - 'num_predictions': Number of predictions
+         - 'num_ground_truth': Number of ground truth events
+     """
+     if thresholds_ms is None:
+         thresholds_ms = DEFAULT_THRESHOLDS_MS
+
+     pred_arr = np.array(pred_times, dtype=np.float64)
+     gt_arr = np.array(gt_times, dtype=np.float64)
+
+     per_threshold = {}
+     f1_scores = {}
+
+     for threshold_ms in thresholds_ms:
+         tolerance_sec = threshold_ms / 1000.0
+         tp, fp, fn = match_events(pred_arr, gt_arr, tolerance_sec)
+         precision, recall, f1 = compute_f1(tp, fp, fn)
+
+         per_threshold[threshold_ms] = {
+             "precision": precision,
+             "recall": recall,
+             "f1": f1,
+             "tp": tp,
+             "fp": fp,
+             "fn": fn,
+         }
+         f1_scores[threshold_ms] = f1
+
+     weighted_f1 = compute_weighted_f1(f1_scores, thresholds_ms)
+     continuity = compute_continuity_metrics(pred_times, gt_times, min_beat_time)
+
+     return {
+         "per_threshold": per_threshold,
+         "f1_scores": f1_scores,
+         "weighted_f1": weighted_f1,
+         "continuity": continuity,
+         "num_predictions": len(pred_arr),
+         "num_ground_truth": len(gt_arr),
+     }
+
+
+ def evaluate_track(
+     pred_beats: Sequence[float],
+     pred_downbeats: Sequence[float],
+     gt_beats: Sequence[float],
+     gt_downbeats: Sequence[float],
+     thresholds_ms: Sequence[int] | None = None,
+     min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
+ ) -> dict:
+     """
+     Evaluate both beat and downbeat predictions for a single track.
+
+     Args:
+         pred_beats: Predicted beat times in seconds
+         pred_downbeats: Predicted downbeat times in seconds
+         gt_beats: Ground truth beat times in seconds
+         gt_downbeats: Ground truth downbeat times in seconds
+         thresholds_ms: Timing thresholds in milliseconds
+         min_beat_time: Minimum time for continuity metrics (default: 5.0s)
+
+     Returns:
+         Dict containing:
+         - 'beats': Results from evaluate_beats for beats
+         - 'downbeats': Results from evaluate_beats for downbeats
+         - 'combined_weighted_f1': Average of beat and downbeat weighted F1
+     """
+     beat_results = evaluate_beats(pred_beats, gt_beats, thresholds_ms, min_beat_time)
+     downbeat_results = evaluate_beats(
+         pred_downbeats, gt_downbeats, thresholds_ms, min_beat_time
+     )
+
+     combined_weighted_f1 = (
+         beat_results["weighted_f1"] + downbeat_results["weighted_f1"]
+     ) / 2
+
+     return {
+         "beats": beat_results,
+         "downbeats": downbeat_results,
+         "combined_weighted_f1": combined_weighted_f1,
+     }
+
+
+ def evaluate_all(
+     predictions: Sequence[dict],
+     ground_truths: Sequence[dict],
+     thresholds_ms: Sequence[int] | None = None,
+     min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
+     verbose: bool = False,
+ ) -> dict:
+     """
+     Evaluate predictions for multiple tracks.
+
+     Args:
+         predictions: List of dicts with 'beats' and 'downbeats' keys
+         ground_truths: List of dicts with 'beats' and 'downbeats' keys
+         thresholds_ms: Timing thresholds in milliseconds
+         min_beat_time: Minimum time for continuity metrics (default: 5.0s)
+         verbose: If True, print per-track results
+
+     Returns:
+         Dict containing:
+         - 'per_track': List of per-track results
+         - 'mean_beat_weighted_f1': Mean weighted F1 for beats
+         - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
+         - 'mean_combined_weighted_f1': Mean combined weighted F1
+         - 'beat_f1_by_threshold': Mean F1 per threshold for beats
+         - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
+         - 'beat_continuity': Mean continuity metrics for beats
+         - 'downbeat_continuity': Mean continuity metrics for downbeats
+     """
+     if len(predictions) != len(ground_truths):
+         raise ValueError(
+             f"Number of predictions ({len(predictions)}) must match "
+             f"number of ground truths ({len(ground_truths)})"
+         )
+
+     if thresholds_ms is None:
+         thresholds_ms = DEFAULT_THRESHOLDS_MS
+
+     per_track = []
+     beat_weighted_f1s = []
+     downbeat_weighted_f1s = []
+     combined_weighted_f1s = []
+
+     beat_f1_by_threshold = {t: [] for t in thresholds_ms}
+     downbeat_f1_by_threshold = {t: [] for t in thresholds_ms}
+
+     # Continuity metrics tracking
+     beat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}
+     downbeat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}
+
+     for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
+         result = evaluate_track(
+             pred_beats=pred["beats"],
+             pred_downbeats=pred["downbeats"],
+             gt_beats=gt["beats"],
+             gt_downbeats=gt["downbeats"],
+             thresholds_ms=thresholds_ms,
+             min_beat_time=min_beat_time,
+         )
+
+         per_track.append(result)
+         beat_weighted_f1s.append(result["beats"]["weighted_f1"])
+         downbeat_weighted_f1s.append(result["downbeats"]["weighted_f1"])
+         combined_weighted_f1s.append(result["combined_weighted_f1"])
+
+         for t in thresholds_ms:
+             beat_f1_by_threshold[t].append(result["beats"]["f1_scores"][t])
+             downbeat_f1_by_threshold[t].append(result["downbeats"]["f1_scores"][t])
+
+         # Track continuity metrics
+         for metric in ["CMLc", "CMLt", "AMLc", "AMLt"]:
+             beat_continuity[metric].append(result["beats"]["continuity"][metric])
+             downbeat_continuity[metric].append(
+                 result["downbeats"]["continuity"][metric]
+             )
+
+         if verbose:
+             beat_cont = result["beats"]["continuity"]
+             print(
+                 f"Track {i}: Beat F1={result['beats']['weighted_f1']:.4f}, "
+                 f"CMLt={beat_cont['CMLt']:.4f}, AMLt={beat_cont['AMLt']:.4f}, "
+                 f"Downbeat F1={result['downbeats']['weighted_f1']:.4f}, "
+                 f"Combined={result['combined_weighted_f1']:.4f}"
+             )
+
+     return {
+         "per_track": per_track,
+         "mean_beat_weighted_f1": float(np.mean(beat_weighted_f1s)),
+         "mean_downbeat_weighted_f1": float(np.mean(downbeat_weighted_f1s)),
+         "mean_combined_weighted_f1": float(np.mean(combined_weighted_f1s)),
+         "beat_f1_by_threshold": {
+             t: float(np.mean(v)) for t, v in beat_f1_by_threshold.items()
+         },
+         "downbeat_f1_by_threshold": {
+             t: float(np.mean(v)) for t, v in downbeat_f1_by_threshold.items()
+         },
+         "beat_continuity": {
+             metric: float(np.mean(values)) for metric, values in beat_continuity.items()
+         },
+         "downbeat_continuity": {
+             metric: float(np.mean(values))
+             for metric, values in downbeat_continuity.items()
+         },
+         "num_tracks": len(predictions),
+     }
+
+
+ def format_results(results: dict, title: str = "Evaluation Results") -> str:
+     """
+     Format evaluation results as a human-readable string.
+
+     Args:
+         results: Results dict from evaluate_all or evaluate_track
+         title: Title for the report
+
+     Returns:
+         Formatted string report
+     """
+     lines = [title, "=" * len(title), ""]
+
+     # Check if this is aggregate results (from evaluate_all)
+     if "num_tracks" in results:
+         lines.append(f"Number of tracks: {results['num_tracks']}")
+         lines.append("")
+         lines.append("Overall Metrics:")
+         lines.append(
+             f"  Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}"
+         )
+         lines.append(
+             f"  Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}"
+         )
+         lines.append(
+             f"  Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}"
+         )
+         lines.append("")
+
+         lines.append("Beat F1 by Threshold:")
+         for t, f1 in sorted(results["beat_f1_by_threshold"].items()):
+             lines.append(f"  {t:2d}ms: {f1:.4f}")
+         lines.append("")
+
+         lines.append("Downbeat F1 by Threshold:")
+         for t, f1 in sorted(results["downbeat_f1_by_threshold"].items()):
+             lines.append(f"  {t:2d}ms: {f1:.4f}")
+         lines.append("")
+
+         # Continuity metrics
+         if "beat_continuity" in results:
+             lines.append("Beat Continuity Metrics:")
+             bc = results["beat_continuity"]
+             lines.append(f"  CMLt: {bc['CMLt']:.4f} (Correct Metrical Level Total)")
+             lines.append(f"  AMLt: {bc['AMLt']:.4f} (Any Metrical Level Total)")
+             lines.append(
+                 f"  CMLc: {bc['CMLc']:.4f} (Correct Metrical Level Continuous)"
+             )
+             lines.append(f"  AMLc: {bc['AMLc']:.4f} (Any Metrical Level Continuous)")
+             lines.append("")
+
+         if "downbeat_continuity" in results:
+             lines.append("Downbeat Continuity Metrics:")
+             dc = results["downbeat_continuity"]
+             lines.append(f"  CMLt: {dc['CMLt']:.4f} (Correct Metrical Level Total)")
+             lines.append(f"  AMLt: {dc['AMLt']:.4f} (Any Metrical Level Total)")
+             lines.append(
+                 f"  CMLc: {dc['CMLc']:.4f} (Correct Metrical Level Continuous)"
+             )
+             lines.append(f"  AMLc: {dc['AMLc']:.4f} (Any Metrical Level Continuous)")
+
+     # Single track results (from evaluate_track)
+     elif "beats" in results and "downbeats" in results:
+         lines.append("Beat Detection:")
+         lines.append(f"  Weighted F1: {results['beats']['weighted_f1']:.4f}")
+         lines.append(f"  Predictions: {results['beats']['num_predictions']}")
+         lines.append(f"  Ground Truth: {results['beats']['num_ground_truth']}")
+
+         # Beat continuity metrics
+         if "continuity" in results["beats"]:
+             bc = results["beats"]["continuity"]
+             lines.append(f"  CMLt: {bc['CMLt']:.4f}  AMLt: {bc['AMLt']:.4f}")
+             lines.append(f"  CMLc: {bc['CMLc']:.4f}  AMLc: {bc['AMLc']:.4f}")
+         lines.append("")
+
+         lines.append("Downbeat Detection:")
+         lines.append(f"  Weighted F1: {results['downbeats']['weighted_f1']:.4f}")
+         lines.append(f"  Predictions: {results['downbeats']['num_predictions']}")
+         lines.append(f"  Ground Truth: {results['downbeats']['num_ground_truth']}")
+
+         # Downbeat continuity metrics
+         if "continuity" in results["downbeats"]:
+             dc = results["downbeats"]["continuity"]
+             lines.append(f"  CMLt: {dc['CMLt']:.4f}  AMLt: {dc['AMLt']:.4f}")
+             lines.append(f"  CMLc: {dc['CMLc']:.4f}  AMLc: {dc['AMLc']:.4f}")
+         lines.append("")
+
+         lines.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")
+
+     return "\n".join(lines)
+
+
+ if __name__ == "__main__":
+     # Demo with synthetic data
+     print("Running evaluation demo...\n")
+
+     # Simulate ground truth beats at regular intervals (30s to have beats after 5s)
+     gt_beats = np.arange(0, 30, 0.5).tolist()  # Beat every 0.5s for 30s
+     gt_downbeats = np.arange(0, 30, 2.0).tolist()  # Downbeat every 2s
532
+
533
+ # Simulate predictions with some noise and missed/extra detections
534
+ np.random.seed(42)
535
+ pred_beats = (
536
+ np.array(gt_beats) + np.random.normal(0, 0.005, len(gt_beats))
537
+ ).tolist()
538
+ pred_beats = pred_beats[:-2] # Miss last 2 beats
539
+ pred_beats.append(15.25) # Add false positive
540
+
541
+ pred_downbeats = (
542
+ np.array(gt_downbeats) + np.random.normal(0, 0.003, len(gt_downbeats))
543
+ ).tolist()
544
+
545
+ # Evaluate single track
546
+ results = evaluate_track(
547
+ pred_beats=pred_beats,
548
+ pred_downbeats=pred_downbeats,
549
+ gt_beats=gt_beats,
550
+ gt_downbeats=gt_downbeats,
551
+ )
552
+
553
+ print(format_results(results, "Single Track Demo"))
554
+ print("\n" + "=" * 50 + "\n")
555
+
556
+ # Multi-track demo
557
+ predictions = [
558
+ {"beats": pred_beats, "downbeats": pred_downbeats},
559
+ {"beats": pred_beats, "downbeats": pred_downbeats},
560
+ ]
561
+ ground_truths = [
562
+ {"beats": gt_beats, "downbeats": gt_downbeats},
563
+ {"beats": gt_beats, "downbeats": gt_downbeats},
564
+ ]
565
+
566
+ all_results = evaluate_all(predictions, ground_truths, verbose=True)
567
+ print()
568
+ print(format_results(all_results, "Multi-Track Demo"))
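The evaluation above reports F1 at several millisecond tolerance thresholds. As a rough illustration of what a per-threshold score computes, here is a minimal sketch of tolerance-window F1 with greedy one-to-one matching. This is not the repo's `evaluate_track` (which presumably uses `mir_eval`-style matching); `f1_at_tolerance` is a hypothetical helper for illustration only.

```python
def f1_at_tolerance(pred: list[float], gt: list[float], tol: float) -> float:
    """F1 after greedily matching each prediction to the nearest
    unused ground-truth event within +/- tol seconds."""
    pred, gt = sorted(pred), sorted(gt)
    used = [False] * len(gt)
    matched = 0
    for p in pred:
        best, best_dist = None, tol
        for j, g in enumerate(gt):
            d = abs(p - g)
            if not used[j] and d <= best_dist:
                best, best_dist = j, d
        if best is not None:
            used[best] = True
            matched += 1
    if not pred or not gt or matched == 0:
        return 0.0
    precision = matched / len(pred)
    recall = matched / len(gt)
    return 2 * precision * recall / (precision + recall)
```

Sweeping `tol` over the thresholds (3 ms to 30 ms) and averaging the resulting F1 scores gives a quantity of the same shape as the weighted F1 reported here.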
exp/data/load.py ADDED
@@ -0,0 +1,91 @@
+ from datasets import load_dataset, Audio
+
+ N_PROC = None
+
+ ds = load_dataset("JacobLinCool/taiko-1000-parsed")
+ ds = ds.remove_columns(["tja", "hard", "normal", "easy", "ura"])
+
+
+ def filter_out_broken(example):
+     try:
+         example["audio"]["array"]
+         return True
+     except Exception:
+         return False
+
+
+ ds = ds.filter(filter_out_broken, num_proc=N_PROC, batch_size=32, writer_batch_size=32)
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+
+
+ def build_beat_and_downbeat_labels(example):
+     """
+     Extract beat and downbeat times from the chart segments.
+
+     - Downbeats: first beat of each measure (the segment timestamp)
+     - Beats: all beats within each measure, based on the time signature
+
+     Returns lists of times in seconds.
+     """
+     title = example["metadata"]["TITLE"]
+     segments = example["oni"]["segments"]
+
+     beats = []
+     downbeats = []
+
+     for i, segment in enumerate(segments):
+         seg_timestamp = segment["timestamp"]
+         measure_num = segment["measure_num"]  # numerator (e.g., 4 in 4/4)
+         measure_den = segment["measure_den"]  # denominator (e.g., 4 in 4/4)
+         notes = segment["notes"]
+
+         # The downbeat is the start of each measure
+         downbeats.append(seg_timestamp)
+
+         # Get BPM from the first note in the segment, or fall back to the
+         # first note of a later segment
+         bpm = None
+         if notes:
+             bpm = notes[0]["bpm"]
+         else:
+             for j in range(i + 1, len(segments)):
+                 if segments[j]["notes"]:
+                     bpm = segments[j]["notes"][0]["bpm"]
+                     break
+
+         if bpm is None or bpm <= 0:
+             bpm = 120.0  # fallback default BPM
+
+         # One beat = 60/BPM seconds (for a quarter note), adjusted for the
+         # time signature denominator (4 = quarter, 8 = eighth, etc.)
+         beat_duration = (60.0 / bpm) * (4.0 / measure_den)
+
+         # Beat positions within this measure
+         for beat_idx in range(measure_num):
+             beats.append(seg_timestamp + beat_idx * beat_duration)
+
+     # Sort and deduplicate (in case of overlapping segments)
+     beats = sorted(set(beats))
+     downbeats = sorted(set(downbeats))
+
+     return {
+         "title": title,
+         "beats": beats,
+         "downbeats": downbeats,
+     }
+
+
+ ds = ds.map(
+     build_beat_and_downbeat_labels,
+     num_proc=N_PROC,
+     batch_size=32,
+     writer_batch_size=32,
+     remove_columns=["oni", "metadata"],
+ )
+
+ ds = ds.with_format("torch")
+
+ if __name__ == "__main__":
+     print(ds)
+     print(ds["train"].features)
exp/data/viz.py ADDED
@@ -0,0 +1,441 @@
+ """
+ Visualization utilities for beat tracking evaluation.
+
+ This module provides functions to:
+ - Plot beat and downbeat predictions vs ground truth
+ - Create waveform visualizations with beat annotations
+ - Generate comparison plots for evaluation
+
+ Example usage:
+     from exp.data.viz import plot_beats, plot_waveform_with_beats, save_figure
+
+     # Plot beat comparison
+     fig = plot_beats(pred_beats, gt_beats, pred_downbeats, gt_downbeats)
+     save_figure(fig, "beat_comparison.png")
+
+     # Plot waveform with beats
+     fig = plot_waveform_with_beats(audio, sr, pred_beats, gt_beats)
+     save_figure(fig, "waveform.png")
+ """
+
+ import numpy as np
+ from pathlib import Path
+
+ # Try to import matplotlib, but keep it optional
+ try:
+     import matplotlib.pyplot as plt
+     import matplotlib.patches as mpatches
+
+     HAS_MATPLOTLIB = True
+ except ImportError:
+     HAS_MATPLOTLIB = False
+
+
+ def _check_matplotlib():
+     if not HAS_MATPLOTLIB:
+         raise ImportError(
+             "matplotlib is required for visualization. "
+             "Install with: pip install matplotlib"
+         )
+
+
+ def plot_beats(
+     pred_beats: list[float] | np.ndarray,
+     gt_beats: list[float] | np.ndarray,
+     pred_downbeats: list[float] | np.ndarray | None = None,
+     gt_downbeats: list[float] | np.ndarray | None = None,
+     title: str = "Beat Tracking Comparison",
+     figsize: tuple[int, int] = (14, 4),
+     time_range: tuple[float, float] | None = None,
+ ) -> "plt.Figure":
+     """
+     Create a visualization comparing predicted and ground truth beats.
+
+     Args:
+         pred_beats: Predicted beat times in seconds
+         gt_beats: Ground truth beat times in seconds
+         pred_downbeats: Predicted downbeat times (optional)
+         gt_downbeats: Ground truth downbeat times (optional)
+         title: Plot title
+         figsize: Figure size (width, height)
+         time_range: Optional tuple (start, end) to limit the time range
+
+     Returns:
+         matplotlib Figure object
+     """
+     _check_matplotlib()
+
+     fig, ax = plt.subplots(figsize=figsize)
+
+     pred_beats = np.array(pred_beats)
+     gt_beats = np.array(gt_beats)
+
+     # Apply time range filter
+     if time_range is not None:
+         start, end = time_range
+         pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)]
+         gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)]
+
+         if pred_downbeats is not None:
+             pred_downbeats = np.array(pred_downbeats)
+             pred_downbeats = pred_downbeats[
+                 (pred_downbeats >= start) & (pred_downbeats <= end)
+             ]
+         if gt_downbeats is not None:
+             gt_downbeats = np.array(gt_downbeats)
+             gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)]
+
+     # Plot ground truth beats
+     ax.vlines(
+         gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5, label="GT Beats"
+     )
+
+     # Plot predicted beats
+     ax.vlines(
+         pred_beats,
+         0.6,
+         1.0,
+         colors="blue",
+         alpha=0.7,
+         linewidth=1.5,
+         label="Pred Beats",
+     )
+
+     # Plot downbeats if provided
+     if gt_downbeats is not None and len(gt_downbeats) > 0:
+         gt_downbeats = np.array(gt_downbeats)
+         ax.vlines(
+             gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3, label="GT Downbeats"
+         )
+
+     if pred_downbeats is not None and len(pred_downbeats) > 0:
+         pred_downbeats = np.array(pred_downbeats)
+         ax.vlines(
+             pred_downbeats,
+             0.6,
+             1.0,
+             colors="darkblue",
+             linewidth=3,
+             label="Pred Downbeats",
+         )
+
+     # Styling
+     ax.set_ylim(-0.1, 1.1)
+     ax.set_yticks([0.2, 0.8])
+     ax.set_yticklabels(["Ground Truth", "Prediction"])
+     ax.set_xlabel("Time (seconds)")
+     ax.set_title(title)
+     ax.legend(loc="upper right", ncol=4)
+     ax.grid(True, alpha=0.3)
+
+     # Set x-axis range
+     if time_range is not None:
+         ax.set_xlim(time_range)
+     else:
+         all_times = np.concatenate([pred_beats, gt_beats])
+         if len(all_times) > 0:
+             ax.set_xlim(0, np.max(all_times) + 0.5)
+
+     plt.tight_layout()
+     return fig
+
+
+ def plot_waveform_with_beats(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float] | np.ndarray,
+     gt_beats: list[float] | np.ndarray,
+     pred_downbeats: list[float] | np.ndarray | None = None,
+     gt_downbeats: list[float] | np.ndarray | None = None,
+     title: str = "Waveform with Beat Annotations",
+     figsize: tuple[int, int] = (14, 6),
+     time_range: tuple[float, float] | None = None,
+ ) -> "plt.Figure":
+     """
+     Create a waveform visualization with beat annotations.
+
+     Args:
+         audio: Audio waveform
+         sr: Sample rate
+         pred_beats: Predicted beat times
+         gt_beats: Ground truth beat times
+         pred_downbeats: Predicted downbeat times (optional)
+         gt_downbeats: Ground truth downbeat times (optional)
+         title: Plot title
+         figsize: Figure size
+         time_range: Optional tuple (start, end) to limit the time range
+
+     Returns:
+         matplotlib Figure object
+     """
+     _check_matplotlib()
+
+     fig, (ax1, ax2) = plt.subplots(
+         2, 1, figsize=figsize, sharex=True, height_ratios=[3, 1]
+     )
+
+     # Time axis
+     duration = len(audio) / sr
+     t = np.linspace(0, duration, len(audio))
+
+     # Apply time range
+     if time_range is not None:
+         start, end = time_range
+         start_idx = int(start * sr)
+         end_idx = int(end * sr)
+         t = t[start_idx:end_idx]
+         audio_plot = audio[start_idx:end_idx]
+     else:
+         audio_plot = audio
+         start, end = 0, duration
+
+     # Plot waveform
+     ax1.plot(t, audio_plot, color="gray", alpha=0.7, linewidth=0.5)
+     ax1.set_ylabel("Amplitude")
+     ax1.set_title(title)
+
+     # Filter beats to the time range
+     pred_beats = np.array(pred_beats)
+     gt_beats = np.array(gt_beats)
+     pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)]
+     gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)]
+
+     # Plot beat markers on the waveform
+     audio_max = np.abs(audio_plot).max() if len(audio_plot) > 0 else 1.0
+
+     for beat in gt_beats:
+         ax1.axvline(beat, color="green", alpha=0.5, linewidth=1)
+     for beat in pred_beats:
+         ax1.axvline(beat, color="blue", alpha=0.3, linewidth=1, linestyle="--")
+
+     # Add downbeat markers (thicker lines)
+     if gt_downbeats is not None:
+         gt_downbeats = np.array(gt_downbeats)
+         gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)]
+         for db in gt_downbeats:
+             ax1.axvline(db, color="darkgreen", alpha=0.8, linewidth=2)
+
+     if pred_downbeats is not None:
+         pred_downbeats = np.array(pred_downbeats)
+         pred_downbeats = pred_downbeats[
+             (pred_downbeats >= start) & (pred_downbeats <= end)
+         ]
+         for db in pred_downbeats:
+             ax1.axvline(db, color="darkblue", alpha=0.5, linewidth=2, linestyle="--")
+
+     ax1.set_ylim(-audio_max * 1.1, audio_max * 1.1)
+
+     # Beat comparison subplot
+     ax2.vlines(gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5)
+     ax2.vlines(pred_beats, 0.6, 1.0, colors="blue", alpha=0.7, linewidth=1.5)
+
+     if gt_downbeats is not None and len(gt_downbeats) > 0:
+         ax2.vlines(gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3)
+     if pred_downbeats is not None and len(pred_downbeats) > 0:
+         ax2.vlines(pred_downbeats, 0.6, 1.0, colors="darkblue", linewidth=3)
+
+     ax2.set_ylim(-0.1, 1.1)
+     ax2.set_yticks([0.2, 0.8])
+     ax2.set_yticklabels(["GT", "Pred"])
+     ax2.set_xlabel("Time (seconds)")
+
+     # Legend
+     legend_elements = [
+         mpatches.Patch(color="green", alpha=0.7, label="GT Beats"),
+         mpatches.Patch(color="blue", alpha=0.7, label="Pred Beats"),
+         mpatches.Patch(color="darkgreen", label="GT Downbeats"),
+         mpatches.Patch(color="darkblue", label="Pred Downbeats"),
+     ]
+     ax1.legend(handles=legend_elements, loc="upper right", ncol=4)
+
+     ax1.grid(True, alpha=0.3)
+     ax2.grid(True, alpha=0.3)
+
+     plt.tight_layout()
+     return fig
+
+
+ def plot_evaluation_summary(
+     results: dict,
+     title: str = "Evaluation Summary",
+     figsize: tuple[int, int] = (12, 8),
+ ) -> "plt.Figure":
+     """
+     Create a summary visualization of evaluation results.
+
+     Args:
+         results: Results dict from evaluate_all
+         title: Plot title
+         figsize: Figure size
+
+     Returns:
+         matplotlib Figure object
+     """
+     _check_matplotlib()
+
+     fig, axes = plt.subplots(2, 2, figsize=figsize)
+
+     # F1 by threshold for beats
+     ax1 = axes[0, 0]
+     if "beat_f1_by_threshold" in results:
+         thresholds = sorted(results["beat_f1_by_threshold"].keys())
+         f1_scores = [results["beat_f1_by_threshold"][t] for t in thresholds]
+         ax1.bar(range(len(thresholds)), f1_scores, color="steelblue", alpha=0.8)
+         ax1.set_xticks(range(len(thresholds)))
+         ax1.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45)
+         ax1.set_ylabel("F1 Score")
+         ax1.set_title("Beat F1 by Threshold")
+         ax1.set_ylim(0, 1)
+         ax1.grid(True, alpha=0.3)
+
+     # F1 by threshold for downbeats
+     ax2 = axes[0, 1]
+     if "downbeat_f1_by_threshold" in results:
+         thresholds = sorted(results["downbeat_f1_by_threshold"].keys())
+         f1_scores = [results["downbeat_f1_by_threshold"][t] for t in thresholds]
+         ax2.bar(range(len(thresholds)), f1_scores, color="coral", alpha=0.8)
+         ax2.set_xticks(range(len(thresholds)))
+         ax2.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45)
+         ax2.set_ylabel("F1 Score")
+         ax2.set_title("Downbeat F1 by Threshold")
+         ax2.set_ylim(0, 1)
+         ax2.grid(True, alpha=0.3)
+
+     # Continuity metrics for beats
+     ax3 = axes[1, 0]
+     if "beat_continuity" in results:
+         metrics = ["CMLc", "CMLt", "AMLc", "AMLt"]
+         values = [results["beat_continuity"][m] for m in metrics]
+         colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]
+         bars = ax3.bar(metrics, values, color=colors, alpha=0.8)
+         ax3.set_ylabel("Score")
+         ax3.set_title("Beat Continuity Metrics")
+         ax3.set_ylim(0, 1)
+         ax3.grid(True, alpha=0.3)
+         # Add value labels
+         for bar, val in zip(bars, values):
+             ax3.text(
+                 bar.get_x() + bar.get_width() / 2,
+                 bar.get_height() + 0.02,
+                 f"{val:.3f}",
+                 ha="center",
+                 fontsize=9,
+             )
+
+     # Continuity metrics for downbeats
+     ax4 = axes[1, 1]
+     if "downbeat_continuity" in results:
+         metrics = ["CMLc", "CMLt", "AMLc", "AMLt"]
+         values = [results["downbeat_continuity"][m] for m in metrics]
+         colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]
+         bars = ax4.bar(metrics, values, color=colors, alpha=0.8)
+         ax4.set_ylabel("Score")
+         ax4.set_title("Downbeat Continuity Metrics")
+         ax4.set_ylim(0, 1)
+         ax4.grid(True, alpha=0.3)
+         # Add value labels
+         for bar, val in zip(bars, values):
+             ax4.text(
+                 bar.get_x() + bar.get_width() / 2,
+                 bar.get_height() + 0.02,
+                 f"{val:.3f}",
+                 ha="center",
+                 fontsize=9,
+             )
+
+     fig.suptitle(title, fontsize=14, fontweight="bold")
+     plt.tight_layout()
+     return fig
+
+
+ def save_figure(
+     fig: "plt.Figure",
+     path: str | Path,
+     dpi: int = 150,
+ ) -> None:
+     """
+     Save a matplotlib figure to file.
+
+     Args:
+         fig: Figure to save
+         path: Output file path
+         dpi: Resolution (dots per inch)
+     """
+     _check_matplotlib()
+
+     path = Path(path)
+     path.parent.mkdir(parents=True, exist_ok=True)
+     fig.savefig(str(path), dpi=dpi, bbox_inches="tight")
+     plt.close(fig)
+
+
+ if __name__ == "__main__":
+     # Demo
+     _check_matplotlib()
+     print("Visualization demo...")
+
+     # Generate synthetic data
+     np.random.seed(42)
+     gt_beats = np.arange(0, 10, 0.5)
+     gt_downbeats = np.arange(0, 10, 2.0)
+     pred_beats = gt_beats + np.random.normal(0, 0.02, len(gt_beats))
+     pred_downbeats = gt_downbeats + np.random.normal(0, 0.01, len(gt_downbeats))
+
+     # Generate fake audio
+     sr = 16000
+     duration = 10.0
+     t = np.arange(int(duration * sr)) / sr
+     audio = np.sin(2 * np.pi * 220 * t) * 0.3
+
+     # Create plots
+     fig1 = plot_beats(
+         pred_beats, gt_beats, pred_downbeats, gt_downbeats, title="Beat Comparison Demo"
+     )
+     save_figure(fig1, "/tmp/beat_comparison_demo.png")
+     print("Saved /tmp/beat_comparison_demo.png")
+
+     fig2 = plot_waveform_with_beats(
+         audio,
+         sr,
+         pred_beats,
+         gt_beats,
+         pred_downbeats,
+         gt_downbeats,
+         title="Waveform Demo",
+         time_range=(2, 8),
+     )
+     save_figure(fig2, "/tmp/waveform_demo.png")
+     print("Saved /tmp/waveform_demo.png")
+
+     # Fake evaluation results
+     results = {
+         "beat_f1_by_threshold": {
+             3: 0.5, 6: 0.7, 9: 0.85, 12: 0.9, 15: 0.95,
+             18: 0.96, 21: 0.97, 24: 0.97, 27: 0.98, 30: 0.98,
+         },
+         "downbeat_f1_by_threshold": {
+             3: 0.6, 6: 0.8, 9: 0.9, 12: 0.95, 15: 0.97,
+             18: 0.98, 21: 0.98, 24: 0.99, 27: 0.99, 30: 0.99,
+         },
+         "beat_continuity": {"CMLc": 0.75, "CMLt": 0.92, "AMLc": 0.80, "AMLt": 0.95},
+         "downbeat_continuity": {"CMLc": 0.85, "CMLt": 0.95, "AMLc": 0.88, "AMLt": 0.97},
+     }
+     fig3 = plot_evaluation_summary(results, title="Evaluation Summary Demo")
+     save_figure(fig3, "/tmp/eval_summary_demo.png")
+     print("Saved /tmp/eval_summary_demo.png")
exp/sota/__init__.py ADDED
File without changes
outputs/baseline/beats/README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ tags:
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ ---
+
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+ - Code: [More Information Needed]
+ - Paper: [More Information Needed]
+ - Docs: [More Information Needed]
outputs/baseline/beats/config.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "dropout_rate": 0.5
+ }
outputs/baseline/beats/final/README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ tags:
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ ---
+
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+ - Code: [More Information Needed]
+ - Paper: [More Information Needed]
+ - Docs: [More Information Needed]
outputs/baseline/beats/final/config.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "dropout_rate": 0.5
+ }
outputs/baseline/beats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f0ee01ee41360f0b486e16d6022f896a19f9ead901be0180bdbd9cad2a3b8597
+ size 1159372
outputs/baseline/beats/logs/events.out.tfevents.1766351314.msiit232.1284330.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b2d91a22ba01091bf072f5a5e8f12fc7d49801d6538914c973ccb2700978934
+ size 17749022
outputs/baseline/beats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e7a0d5178bc5dfeee6da26345e7956aeb6bf64a21be7e541db4bcc37b290249
+ size 1159372
outputs/baseline/downbeats/README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ tags:
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ ---
+
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+ - Code: [More Information Needed]
+ - Paper: [More Information Needed]
+ - Docs: [More Information Needed]
outputs/baseline/downbeats/config.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "dropout_rate": 0.5
+ }
outputs/baseline/downbeats/final/README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ tags:
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
+ ---
+
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+ - Code: [More Information Needed]
+ - Paper: [More Information Needed]
+ - Docs: [More Information Needed]
outputs/baseline/downbeats/final/config.json ADDED
@@ -0,0 +1,3 @@
+ {
+     "dropout_rate": 0.5
+ }
outputs/baseline/downbeats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:870e3425ffd366be9a0e8fafcda62fa28b2c25917c8354570edc53a67e132d38
+ size 1159372
outputs/baseline/downbeats/logs/events.out.tfevents.1766353075.msiit232.1284330.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8744916b2c1a8255cd5379e6956a4ad2acbf8bcc1fcfaed21ca11285a771550c
+ size 4272622
outputs/baseline/downbeats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8895be0bff1c3210f46b04c490596490fe03081728e17fffb33c80369b472134
+ size 1159372
pyproject.toml ADDED
@@ -0,0 +1,19 @@
+ [project]
+ name = "exp-beat-tracking"
+ version = "0.1.0"
+ description = "Add your description here"
+ readme = "README.md"
+ requires-python = ">=3.12"
+ dependencies = [
+     "datasets>=4.4.2",
+     "einops>=0.8.1",
+     "matplotlib>=3.10.0",
+     "mir-eval>=0.8.2",
+     "safetensors>=0.7.0",
+     "scipy>=1.16.3",
+     "tensorboard>=2.20.0",
+     "torch>=2.9.1",
+     "torchaudio>=2.9.1",
+     "torchcodec>=0.9.1",
+     "torchinfo>=1.8.0",
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff