JacobLinCool committed on
Commit 64bf319 · verified · 1 Parent(s): 6f01cc1

Upload folder using huggingface_hub

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +13 -0
  2. .gitignore +16 -0
  3. .python-version +1 -0
  4. BASELINE3_IMPROVEMENTS.md +163 -0
  5. README.md +299 -0
  6. SE/Squeeze-and-Excitation Networks 1.jpg +3 -0
  7. SE/Squeeze-and-Excitation Networks 10.jpg +3 -0
  8. SE/Squeeze-and-Excitation Networks 11.jpg +3 -0
  9. SE/Squeeze-and-Excitation Networks 12.jpg +3 -0
  10. SE/Squeeze-and-Excitation Networks 13.jpg +3 -0
  11. SE/Squeeze-and-Excitation Networks 2.jpg +3 -0
  12. SE/Squeeze-and-Excitation Networks 3.jpg +3 -0
  13. SE/Squeeze-and-Excitation Networks 4.jpg +3 -0
  14. SE/Squeeze-and-Excitation Networks 5.jpg +3 -0
  15. SE/Squeeze-and-Excitation Networks 6.jpg +3 -0
  16. SE/Squeeze-and-Excitation Networks 7.jpg +3 -0
  17. SE/Squeeze-and-Excitation Networks 8.jpg +3 -0
  18. SE/Squeeze-and-Excitation Networks 9.jpg +3 -0
  19. exp/__init__.py +0 -0
  20. exp/baseline1/__init__.py +0 -0
  21. exp/baseline1/data.py +128 -0
  22. exp/baseline1/eval.py +322 -0
  23. exp/baseline1/model.py +62 -0
  24. exp/baseline1/train.py +183 -0
  25. exp/baseline1/utils.py +53 -0
  26. exp/baseline2/__init__.py +0 -0
  27. exp/baseline2/data.py +137 -0
  28. exp/baseline2/eval.py +324 -0
  29. exp/baseline2/model.py +139 -0
  30. exp/baseline2/train.py +215 -0
  31. exp/baseline3/__init__.py +0 -0
  32. exp/baseline3/data.py +173 -0
  33. exp/baseline3/eval.py +336 -0
  34. exp/baseline3/model.py +173 -0
  35. exp/baseline3/train.py +433 -0
  36. exp/data/__init__.py +25 -0
  37. exp/data/audio.py +301 -0
  38. exp/data/eval.py +568 -0
  39. exp/data/load.py +91 -0
  40. exp/data/viz.py +441 -0
  41. outputs/baseline1/beats/README.md +10 -0
  42. outputs/baseline1/beats/config.json +3 -0
  43. outputs/baseline1/beats/final/README.md +10 -0
  44. outputs/baseline1/beats/final/config.json +3 -0
  45. outputs/baseline1/beats/final/model.safetensors +3 -0
  46. outputs/baseline1/beats/logs/events.out.tfevents.1766351314.msiit232.1284330.0 +3 -0
  47. outputs/baseline1/beats/model.safetensors +3 -0
  48. outputs/baseline1/downbeats/README.md +10 -0
  49. outputs/baseline1/downbeats/config.json +3 -0
  50. outputs/baseline1/downbeats/final/README.md +10 -0
.gitattributes CHANGED
@@ -33,3 +33,16 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]1.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]10.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]11.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]12.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]13.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]2.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]3.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]4.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]5.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]6.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]7.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]8.jpg filter=lfs diff=lfs merge=lfs -text
+ SE/Squeeze-and-Excitation[[:space:]]Networks[[:space:]]9.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ outputs/*
+ !outputs/baseline1/
+ !outputs/baseline2/
+
+ .ruff_cache/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
BASELINE3_IMPROVEMENTS.md ADDED
@@ -0,0 +1,163 @@
+ # Baseline3 improvements (beats + downbeats)
+
+ This document summarizes the changes made in `exp/baseline3` relative to `exp/baseline2` during this session, emphasizing improvements intended to increase beat/downbeat F1 and continuity while keeping the training/eval workflow consistent with baseline2.
+
+ ## Scope / goals
+
+ - Keep the same overall pipeline as baseline2 (same dataset, same context window, same mel multi-view preprocessing, same peak-picking evaluation).
+ - Add SE-inspired improvements to the **model** (baseline3) while preserving the baseline2 ResNet backbone structure.
+ - Make training and TensorBoard curves **comparable** to baseline2.
+ - Support faster iteration when needed (optional), but allow returning to baseline2-style "full" training defaults.
+
+ ---
+
+ ## Model improvements (affects both beats + downbeats)
+
+ ### 1) Extra SE-inspired gating (temporal excitation)
+
+ - File: `exp/baseline3/model.py`
+ - Added an additional SE-style gating mechanism that is **time-dependent** (a "temporal excitation" in addition to channel excitation).
+ - The intent is to help the network emphasize temporally salient patterns that correspond to rhythmic events, improving peak sharpness and reducing spurious activations.
+
+ ### 2) SE block robustness
+
+ - File: `exp/baseline3/model.py`
+ - Made the SE hidden dimension robust for small channel counts (ensuring the intermediate dimension is never zero).
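A minimal sketch of what such a block could look like (hypothetical names and shapes; the actual `exp/baseline3/model.py` implementation may differ):

```python
import torch
import torch.nn as nn


class SETemporalBlock(nn.Module):
    """Channel SE gating plus a time-dependent (temporal) excitation gate."""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        # Guard against tiny channel counts: the hidden dim is never zero.
        hidden = max(1, channels // reduction)
        # Classic channel excitation: squeeze over (freq, time), gate channels.
        self.channel_fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )
        # Temporal excitation: squeeze over (channel, freq), gate time steps.
        self.temporal_conv = nn.Sequential(
            nn.Conv1d(1, 1, kernel_size=7, padding=3), nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, freq, time)
        b, c, f, t = x.shape
        s = x.mean(dim=(2, 3))                     # squeeze -> (b, c)
        w_c = self.channel_fc(s).view(b, c, 1, 1)  # channel gate
        s_t = x.mean(dim=(1, 2)).unsqueeze(1)      # squeeze -> (b, 1, t)
        w_t = self.temporal_conv(s_t).view(b, 1, 1, t)  # temporal gate
        return x * w_c * w_t
```

Both gates are sigmoids, so the block only rescales activations and preserves the feature-map shape.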
+
+ ---
+
+ ## Data / sampling improvements (optional; applies to both beats + downbeats)
+
+ ### 3) Track capping support (optional)
+
+ - File: `exp/baseline3/data.py`
+ - Added support for limiting the number of tracks used when building indices.
+ - This was introduced for **fast iteration** runs (debugging / quick experiments). When not used, training uses the full dataset, as in baseline2.
+
+ ### 4) Hard-negative sampling near events (optional)
+
+ - File: `exp/baseline3/data.py`
+ - Added optional "hard negatives" close to ground-truth frames:
+   - For each beat/downbeat frame, add negative frames at offsets ±d for d = 2..R.
+   - Controlled by `hard_neg_radius` and `hard_neg_fraction`.
+ - Rationale: random negatives are often too easy; near-event negatives help reduce double peaks/jitter and can improve continuity.
+ - Status: kept **off by default** when running in baseline2-style mode.
+
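A sketch of the hard-negative generation described above (hypothetical helper; the exact parameter semantics in `exp/baseline3/data.py` are assumptions):

```python
import random


def hard_negative_frames(gt_frames, n_frames, hard_neg_radius=4,
                         hard_neg_fraction=1.0, seed=0):
    """Return negative frame indices near ground-truth frames.

    For each annotated frame f, offsets ±d for d = 2..hard_neg_radius are
    candidate hard negatives (d = 1 is excluded because baseline1/2 treat
    direct neighbors as fuzzy positives). `hard_neg_fraction` subsamples
    the candidate set.
    """
    positives = set(gt_frames)
    candidates = set()
    for f in gt_frames:
        for d in range(2, hard_neg_radius + 1):
            for g in (f - d, f + d):
                if 0 <= g < n_frames and g not in positives:
                    candidates.add(g)
    rng = random.Random(seed)
    k = int(len(candidates) * hard_neg_fraction)
    return sorted(rng.sample(sorted(candidates), k))
```

These indices would be appended with label 0.0, alongside the random negatives the baselines already draw.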
+ ---
+
+ ## Training-loop improvements
+
+ ### 5) Output directories fixed to avoid overwriting baseline2
+
+ - File: `exp/baseline3/train.py` (and, earlier in the session, the baseline3 eval defaults)
+ - Baseline3 outputs were adjusted to use baseline3-specific output directories so baseline2 artifacts aren't overwritten.
+
+ ### 6) Loss logging parity with baseline2
+
+ - File: `exp/baseline3/train.py`
+ - Baseline2 uses unweighted BCE (`nn.BCELoss`). Baseline3 introduced an optional weighted BCE objective for imbalance experiments.
+ - A key issue was discovered: TensorBoard curves looked "worse" in baseline3 because it was logging the weighted BCE as the main loss.
+ - Fix:
+   - `train/batch_loss` and `train/epoch_loss` are now **unweighted BCE** (baseline2-comparable).
+   - If weighting is enabled, the optimized objective is logged separately as `*_weighted`.
+
+ ### 7) Optional imbalance-aware objective (pos weighting)
+
+ - File: `exp/baseline3/train.py`
+ - Added an optional weighted BCE objective, controlled by `--pos-weight`.
+ - The default is `--pos-weight 0.0`, which matches baseline2 behavior.
+
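A sketch of the logging-parity arrangement (illustrative; the weighting scheme `1 + pos_weight * target` is an assumption about how `--pos-weight` is applied in `train.py`):

```python
import torch
import torch.nn as nn


def bce_losses(pred: torch.Tensor, target: torch.Tensor, pos_weight: float = 0.0):
    """Return (unweighted_bce, optimized_bce).

    The unweighted value is what gets logged as train/batch_loss for
    baseline2-comparable curves; the optimized value is the training
    objective (identical when pos_weight == 0.0, logged as *_weighted
    otherwise).
    """
    unweighted = nn.functional.binary_cross_entropy(pred, target)
    if pos_weight > 0.0:
        # Up-weight positive frames: per-element weight = 1 + pos_weight * target
        weight = 1.0 + pos_weight * target
        optimized = nn.functional.binary_cross_entropy(pred, target, weight=weight)
    else:
        optimized = unweighted
    return unweighted, optimized
```

The training loop would backprop through `optimized` but always log `unweighted`.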
+ ### 8) Optional gradient clipping
+
+ - File: `exp/baseline3/train.py`
+ - Added `--grad-clip` support to stabilize training when experimenting.
+ - For baseline2-style mode, the default was set back to **disabled** (`--grad-clip 0.0`).
+
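A minimal sketch of how `--grad-clip` could plug into the optimization step (illustrative; `0.0` disables clipping, matching the default):

```python
import torch


def step(model, loss, optimizer, grad_clip: float = 0.0):
    """One optimization step with optional gradient-norm clipping."""
    optimizer.zero_grad()
    loss.backward()
    if grad_clip > 0.0:
        # Rescales all gradients so their global L2 norm is at most grad_clip.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
    optimizer.step()
```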
+ ### 9) Fast-iteration controls (optional)
+
+ - File: `exp/baseline3/train.py`
+ - Added optional caps for quicker experiments:
+   - `--max-train-tracks`, `--max-val-tracks`
+   - `--max-train-steps`, `--max-val-steps`, `--max-steps-total`
+ - These are intended only for debugging/iteration. Baseline2-style training leaves them unset (0/unlimited).
+
+ ### 10) Back to baseline2-style default training mode
+
+ - File: `exp/baseline3/train.py`
+ - Returned baseline3 defaults to match the baseline2 training mode:
+   - `--epochs 3`
+   - `--patience 5`
+   - objective defaults to unweighted BCE when `--pos-weight 0.0`
+   - no grad clipping by default
+
+ ---
+
+ ## Evaluation improvements
+
+ ### 11) Mix-and-match beats and downbeats checkpoints
+
+ - File: `exp/baseline3/eval.py`
+ - Added support for evaluating with different model directories for beats vs. downbeats:
+   - `--beats-model-dir`
+   - `--downbeats-model-dir`
+ - This enables workflows like "new beats run + keep downbeats fixed".
+
+ ---
+
+ ## Beats-specific notes
+
+ - All model/training/eval improvements above apply to beats.
+ - A key gotcha found during quick experiments: some runs only saved the checkpoint under a `final/` subfolder. When evaluating, pointing at the correct folder matters.
+
+ ### Latest mixed eval result (beats improved)
+
+ Eval configuration used:
+
+ - Beats: `outputs/baseline3_b2mode_full3/beats`
+ - Downbeats: `outputs/baseline3_smoketest/downbeats`
+ - Output: `outputs/eval_mix_b3_b2modebeats_smoketestdownbeats`
+
+ Key metrics (116 tracks):
+
+ - Mean Beat Weighted F1: **0.3531**
+ - Beat continuity: CMLt **0.3567**, AMLt **0.3607**, CMLc **0.0603**, AMLc **0.0624**
+
+ Summary plot:
+
+ - `outputs/eval_mix_b3_b2modebeats_smoketestdownbeats/evaluation_summary.png`
+
+ ---
+
+ ## Downbeats-specific notes
+
+ - Downbeats training uses the same dataset/indexing logic, model architecture, and preprocessing as beats.
+ - The improvements (temporal excitation, loss logging parity, optional hard negatives, optional fast-iteration caps, mixed-checkpoint evaluation) all apply identically.
+ - In the mixed eval above, downbeats were held fixed at the baseline3 smoketest checkpoint.
+
+ ---
+
+ ## Repro commands
+
+ ### Full baseline2-style training (beats only)
+
+ ```bash
+ uv run -m exp.baseline3.train --target beats --output-dir outputs/baseline3_b2mode_full3
+ ```
+
+ ### Mixed evaluation (beats from a new run + downbeats from baseline3 smoketest)
+
+ ```bash
+ uv run -m exp.baseline3.eval \
+     --beats-model-dir outputs/baseline3_b2mode_full3/beats \
+     --downbeats-model-dir outputs/baseline3_smoketest/downbeats \
+     --output-dir outputs/eval_mix_b3_b2modebeats_smoketestdownbeats \
+     --summary-plot
+ ```
+
+ ---
+
+ ## Known warnings
+
+ - You may see repeated torchaudio warnings like:
+   - "At least one mel filterbank has all zero values…"
+ - This is produced by `torchaudio` mel filterbank construction for some parameter combinations and is not specific to baseline3.
README.md ADDED
@@ -0,0 +1,299 @@
+ # Beat Tracking Challenge
+
+ A challenge for detecting beats and downbeats in music audio, with a focus on handling the dynamic tempo changes common in rhythm game charts.
+
+ ## Goal
+
+ The goal is to **detect and identify beats and downbeats** in audio to assist composers by providing a flexible timing grid when working with samples that have dynamic BPM changes.
+
+ - **Beat**: A regular pulse in music (e.g., quarter notes in 4/4 time)
+ - **Downbeat**: The first beat of each measure (the "1" in counting "1-2-3-4")
+
+ This is particularly useful for:
+ - Music production with samples of varying tempos
+ - Rhythm game chart creation and verification
+ - Audio analysis and music information retrieval (MIR)
+
+ ---
+
+ ## Dataset
+
+ The dataset is derived from Taiko no Tatsujin rhythm game charts, providing high-quality human-annotated beat and downbeat ground truth.
+
+ **Source**: [`JacobLinCool/taiko-1000-parsed`](https://huggingface.co/datasets/JacobLinCool/taiko-1000-parsed)
+
+ | Split | Tracks | Duration | Description |
+ |-------|--------|----------|-------------|
+ | `train` | ~1000 | 1-3 min each | Training data with beat/downbeat annotations |
+ | `test` | ~100 | 1-3 min each | Held-out test set for final evaluation |
+
+ ### Data Features
+
+ Each example contains:
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `audio` | `Audio` | Audio waveform at 16kHz sample rate |
+ | `title` | `str` | Track title |
+ | `beats` | `list[float]` | Beat timestamps in seconds |
+ | `downbeats` | `list[float]` | Downbeat timestamps in seconds |
+
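The second-valued timestamps pair naturally with the 100 fps frame grid the baselines operate on (hop length 160 at 16 kHz); a typical conversion looks like this (illustrative helper, not part of the repo API):

```python
def times_to_frames(times_s, sample_rate=16000, hop_length=160):
    """Convert event timestamps in seconds to frame indices.

    With sample_rate=16000 and hop_length=160, one frame is 10 ms (100 fps),
    matching the resolution used by the baseline models.
    """
    return [int(t * sample_rate / hop_length) for t in times_s]
```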
+ ### Dataset Characteristics
+
+ - **Dynamic BPM**: Many tracks feature tempo changes mid-song
+ - **Variable Time Signatures**: Common patterns include 4/4, 3/4, 6/8, and more exotic meters
+ - **Diverse Genres**: Japanese pop, anime themes, classical arrangements, electronic music
+ - **High-Quality Annotations**: Derived from professional rhythm game charts
+
+ ---
+
+ ## Evaluation Metrics
+
+ The evaluation considers both **timing accuracy** and **metrical correctness**. Models are evaluated on both beat and downbeat detection tasks.
+
+ ### Primary Metrics
+
+ #### 1. Weighted F1-Score (Main Ranking Metric)
+
+ F1-scores are calculated at multiple timing thresholds (3ms to 30ms), then combined with inverse-threshold weighting:
+
+ | Threshold | Weight | Rationale |
+ |-----------|--------|-----------|
+ | 3ms | 1.000 | Full weight for highest precision |
+ | 6ms | 0.500 | Half weight |
+ | 9ms | 0.333 | One-third weight |
+ | 12ms | 0.250 | One-quarter weight |
+ | 15ms | 0.200 | |
+ | 18ms | 0.167 | |
+ | 21ms | 0.143 | |
+ | 24ms | 0.125 | |
+ | 27ms | 0.111 | |
+ | 30ms | 0.100 | Minimum weight for the coarsest threshold |
+
+ **Formula:**
+ ```
+ Weighted F1 = Σ(w_t × F1_t) / Σ(w_t)
+ where w_t = 3ms / t (inverse threshold weighting)
+ ```
+
+ This weighting scheme rewards models that achieve high precision at tight tolerances while still considering coarser thresholds.
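The formula translates directly into code; `weighted_f1` here is an illustrative helper, not the repo's actual `exp/data/eval.py` function:

```python
def weighted_f1(f1_by_threshold_ms: dict[float, float], base_ms: float = 3.0) -> float:
    """Combine per-threshold F1 scores with inverse-threshold weights w_t = base_ms / t."""
    num = sum((base_ms / t) * f1 for t, f1 in f1_by_threshold_ms.items())
    den = sum(base_ms / t for t in f1_by_threshold_ms)
    return num / den
```

If a model scores the same F1 at every threshold, the weighted score equals that F1; improvements at the tight 3 ms threshold move the score the most.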
80
+
81
+ #### 2. Continuity Metrics (CMLt, AMLt)
82
+
83
+ Based on the MIREX beat tracking evaluation protocol using `mir_eval`:
84
+
85
+ | Metric | Full Name | Description |
86
+ |--------|-----------|-------------|
87
+ | **CMLt** | Correct Metrical Level Total | Percentage of beats correctly tracked at the exact metrical level (±17.5% of beat interval) |
88
+ | **AMLt** | Any Metrical Level Total | Same as CMLt, but allows for acceptable metrical variations (double/half tempo, off-beat) |
89
+ | **CMLc** | Correct Metrical Level Continuous | Longest continuous correctly-tracked segment at exact metrical level |
90
+ | **AMLc** | Any Metrical Level Continuous | Longest continuous segment at any acceptable metrical level |
91
+
92
+ **Note:** Continuity metrics use a default `min_beat_time=5.0s` (skipping the first 5 seconds) to avoid evaluating potentially unstable tempo at the beginning of tracks.
93
+
94
+ ### Metric Interpretation
95
+
96
+ | Metric | What it measures | Good Score |
97
+ |--------|------------------|------------|
98
+ | Weighted F1 | Precise timing accuracy | > 0.7 |
99
+ | CMLt | Correct tempo tracking | > 0.8 |
100
+ | AMLt | Tempo tracking (flexible) | > 0.9 |
101
+ | CMLc | Longest stable segment | > 0.5 |
102
+
103
+ ### Evaluation Summary
104
+
105
+ For each model, we report:
106
+
107
+ ```
108
+ Beat Detection:
109
+ Weighted F1: X.XXXX
110
+ CMLt: X.XXXX AMLt: X.XXXX
111
+ CMLc: X.XXXX AMLc: X.XXXX
112
+
113
+ Downbeat Detection:
114
+ Weighted F1: X.XXXX
115
+ CMLt: X.XXXX AMLt: X.XXXX
116
+ CMLc: X.XXXX AMLc: X.XXXX
117
+
118
+ Combined Weighted F1: X.XXXX (average of beat and downbeat)
119
+ ```
120
+
121
+ ### Benchmark Results
122
+
123
+ Results evaluated on 100 tracks from the test set:
124
+
125
+ | Model | Combined F1 | Beat F1 | Downbeat F1 | CMLt (Beat) | CMLt (Downbeat) |
126
+ |-------|-------------|---------|-------------|-------------|-----------------|
127
+ | **Baseline 1 (ODCNN)** | 0.0765 | 0.0861 | 0.0669 | 0.0731 | 0.0321 |
128
+ | **Baseline 2 (ResNet-SE)** | **0.2775** | **0.3292** | **0.2258** | **0.3287** | **0.1146** |
129
+
130
+ *Note: Baseline 2 (ResNet-SE) demonstrates significantly better performance due to larger context window and deeper architecture.*
131
+
132
+ ---
133
+
+ ## Quick Start
+
+ ### Setup
+
+ ```bash
+ uv sync
+ ```
+
+ ### Train Models
+
+ ```bash
+ # Train Baseline 1 (ODCNN)
+ uv run -m exp.baseline1.train
+
+ # Train Baseline 2 (ResNet-SE)
+ uv run -m exp.baseline2.train
+
+ # Train a specific target only (e.g. for Baseline 2)
+ uv run -m exp.baseline2.train --target beats
+ uv run -m exp.baseline2.train --target downbeats
+ ```
+
+ ### Run Evaluation
+
+ ```bash
+ # Evaluation (replace baseline1 with baseline2 to evaluate the newer model)
+ uv run -m exp.baseline1.eval
+
+ # Full evaluation with visualization and audio
+ uv run -m exp.baseline1.eval --visualize --synthesize --summary-plot
+
+ # Evaluate more samples with a custom output directory
+ uv run -m exp.baseline1.eval --num-samples 50 --output-dir outputs/eval_baseline1
+ ```
+
+ ### Evaluation Options
+
+ | Option | Description |
+ |--------|-------------|
+ | `--model-dir DIR` | Model directory (default: `outputs/baseline1`) |
+ | `--num-samples N` | Number of samples to evaluate (default: 20) |
+ | `--output-dir DIR` | Output directory (default: `outputs/eval`) |
+ | `--visualize` | Generate visualization plots for each track |
+ | `--synthesize` | Generate audio files with click tracks |
+ | `--viz-tracks N` | Number of tracks to visualize/synthesize (default: 5) |
+ | `--time-range START END` | Limit visualization time range (seconds) |
+ | `--click-volume FLOAT` | Click sound volume (0.0 to 1.0, default: 0.5) |
+ | `--summary-plot` | Generate summary evaluation bar charts |
+
+ ---
+
+ ## Visualization & Audio Tools
+
+ ### Beat Visualization
+
+ Generate plots comparing predicted vs. ground-truth beats:
+
+ ```bash
+ uv run -m exp.baseline1.eval --visualize --viz-tracks 10
+ ```
+
+ Output: `outputs/eval/plots/track_XXX.png`
+
+ ### Click Track Audio
+
+ Generate audio files with click sounds overlaid on the original music:
+
+ ```bash
+ uv run -m exp.baseline1.eval --synthesize
+ ```
+
+ Output files in `outputs/eval/audio/`:
+ - `track_XXX_pred.wav` - Original audio + predicted beat clicks (1000Hz beat, 1500Hz downbeat)
+ - `track_XXX_gt.wav` - Original audio + ground truth clicks (800Hz beat, 1200Hz downbeat)
+ - `track_XXX_both.wav` - Original audio + both prediction and ground truth clicks
+
+ ### Summary Plot
+
+ Generate bar charts summarizing F1 scores and continuity metrics:
+
+ ```bash
+ uv run -m exp.baseline1.eval --summary-plot
+ ```
+
+ Output: `outputs/eval/evaluation_summary.png`
+
+ ---
+
+ ## Models
+
+ ### Baseline 1: ODCNN
+
+ A decade-old baseline model: <https://ieeexplore.ieee.org/document/6854953>.
+
+ The original baseline implements the **Onset Detection CNN (ODCNN)** architecture:
+
+ #### Architecture
+ - **Input**: Multi-view mel spectrogram (3 window sizes: 23ms, 46ms, 93ms)
+ - **CNN Backbone**: 3 convolutional blocks with max pooling
+ - **Output**: Frame-level beat/downbeat probability
+ - **Inference**: ±7 frames of context (±70ms)
+
+ ### Baseline 2: ResNet-SE
+
+ Inspired by Squeeze-and-Excitation networks: <https://arxiv.org/abs/1709.01507>.
+
+ A modernized architecture designed to capture longer temporal context:
+
+ #### Architecture
+ - **Input**: Mel spectrogram with larger context
+ - **Backbone**: ResNet with Squeeze-and-Excitation (SE) blocks
+ - **Context**: **±50 frames (~1s)** window
+ - **Features**: Deeper network (4 stages) with effective channel attention
+ - **Parameters**: ~400k (small and efficient)
+
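For reference, a minimal channel-SE block in the spirit of the paper linked above (a sketch; the exact `exp/baseline2/model.py` implementation may differ):

```python
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    """Squeeze-and-Excitation over a (batch, channels, freq, time) feature map."""

    def __init__(self, channels: int, reduction: int = 8):
        super().__init__()
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, _, _ = x.shape
        # Squeeze: global average over freq and time; excite: per-channel gate.
        w = self.fc(x.mean(dim=(2, 3))).view(b, c, 1, 1)
        return x * w
```

One such block after each residual stage adds only a few thousand parameters, consistent with the ~400k total.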
+ ### Training Details
+
+ Both models use similar training loops:
+ - **Optimizer**: SGD (Baseline 1) / AdamW (Baseline 2)
+ - **Learning Rate**: Cosine annealing
+ - **Loss**: Binary cross-entropy
+ - **Epochs**: 50 (Baseline 1) / 3 (Baseline 2)
+ - **Batch Size**: 512 (Baseline 1) / 128 (Baseline 2)
+
+ ---
+
+ ## Project Structure
+
+ ```
+ exp-onset/
+ ├── exp/
+ │   ├── baseline1/          # Baseline 1 (ODCNN)
+ │   │   ├── model.py        # ODCNN architecture
+ │   │   ├── train.py
+ │   │   ├── eval.py
+ │   │   ├── data.py
+ │   │   └── utils.py
+ │   ├── baseline2/          # Baseline 2 (ResNet-SE)
+ │   │   ├── model.py        # ResNet-SE
+ │   │   ├── train.py
+ │   │   ├── eval.py
+ │   │   └── data.py
+ │   └── data/
+ │       ├── load.py         # Dataset loading & preprocessing
+ │       ├── eval.py         # Evaluation metrics (F1, CML, AML)
+ │       ├── audio.py        # Click track synthesis
+ │       └── viz.py          # Visualization utilities
+ ├── outputs/
+ │   ├── baseline1/          # Trained models (Baseline 1)
+ │   ├── baseline2/          # Trained models (Baseline 2)
+ │   └── eval/               # Evaluation outputs
+ │       ├── plots/          # Visualization images
+ │       ├── audio/          # Click track audio files
+ │       └── evaluation_summary.png
+ ├── README.md
+ ├── DATASET.md              # Raw dataset specification
+ └── pyproject.toml
+ ```
+
+ ---
+
+ ## License
+
+ This project is for research and educational purposes. The dataset is derived from publicly available rhythm game charts.
SE/Squeeze-and-Excitation Networks 1.jpg ADDED

Git LFS Details

  • SHA256: d0380e82ecf8f2ffc4ff8553a4d40ab50f7503d8b9ffdc58d0ca067b67060c97
  • Pointer size: 132 Bytes
  • Size of remote file: 5.94 MB
SE/Squeeze-and-Excitation Networks 10.jpg ADDED

Git LFS Details

  • SHA256: 74d9f6c4dc9e6bffccf03e2a8f233fa55eb99bf4e2762a5e7cf6ec2cdd37837e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.27 MB
SE/Squeeze-and-Excitation Networks 11.jpg ADDED

Git LFS Details

  • SHA256: ad07ce4dcb5540e4878dd0b56cfe8db4c7f13f92429587ac75aa6203a6a7700c
  • Pointer size: 132 Bytes
  • Size of remote file: 5.06 MB
SE/Squeeze-and-Excitation Networks 12.jpg ADDED

Git LFS Details

  • SHA256: e195ff80703b9c88ee1d661fb8a0cbf9a312cdc96732ed304cbb878d6efd7777
  • Pointer size: 132 Bytes
  • Size of remote file: 6.86 MB
SE/Squeeze-and-Excitation Networks 13.jpg ADDED

Git LFS Details

  • SHA256: 25065cec902bbed6e52ae28cf5e8c52613ec1da8d36dcd027765edd70cc27e1c
  • Pointer size: 132 Bytes
  • Size of remote file: 5.26 MB
SE/Squeeze-and-Excitation Networks 2.jpg ADDED

Git LFS Details

  • SHA256: 63dc30e35ecae244ffcb197c491dcbe237f8724ec418c20ddbd91b89e5512135
  • Pointer size: 132 Bytes
  • Size of remote file: 5.96 MB
SE/Squeeze-and-Excitation Networks 3.jpg ADDED

Git LFS Details

  • SHA256: 2653f086da04f5d990ea817531663115396a77fc52962524623b8a49508f9412
  • Pointer size: 132 Bytes
  • Size of remote file: 6.24 MB
SE/Squeeze-and-Excitation Networks 4.jpg ADDED

Git LFS Details

  • SHA256: b80814bb2e06269784482d396949c835263084d05944d580a4981cc6ebad70d4
  • Pointer size: 132 Bytes
  • Size of remote file: 5.35 MB
SE/Squeeze-and-Excitation Networks 5.jpg ADDED

Git LFS Details

  • SHA256: af5e33be7e6cdb093a1e12df355406919d19920343c475eb8e31f65c859d9f76
  • Pointer size: 132 Bytes
  • Size of remote file: 5.07 MB
SE/Squeeze-and-Excitation Networks 6.jpg ADDED

Git LFS Details

  • SHA256: 4f5e9d98eaa72167d038b5469e158a418eb535780d24d7784124027e7b08e571
  • Pointer size: 132 Bytes
  • Size of remote file: 5.95 MB
SE/Squeeze-and-Excitation Networks 7.jpg ADDED

Git LFS Details

  • SHA256: b719f3619270a097f05d6909b8446c351a0e38eeb60aac2f6c900b0fe4d5275b
  • Pointer size: 132 Bytes
  • Size of remote file: 5.91 MB
SE/Squeeze-and-Excitation Networks 8.jpg ADDED

Git LFS Details

  • SHA256: 662aab9fdddcf2f68dc72c2d8499480198b53c96e101a7f502c803a7c5388a05
  • Pointer size: 132 Bytes
  • Size of remote file: 5.64 MB
SE/Squeeze-and-Excitation Networks 9.jpg ADDED

Git LFS Details

  • SHA256: 9f2c401006c2d645d3c043f7ca79b61e54c36c5a136df8e1bdbc1a3b5eb74bf3
  • Pointer size: 132 Bytes
  • Size of remote file: 5.47 MB
exp/__init__.py ADDED
File without changes
exp/baseline1/__init__.py ADDED
File without changes
exp/baseline1/data.py ADDED
@@ -0,0 +1,128 @@
+ import torch
+ from torch.utils.data import Dataset
+ import numpy as np
+ from tqdm import tqdm
+ from .utils import extract_context
+
+
+ class BeatTrackingDataset(Dataset):
+     def __init__(
+         self, hf_dataset, target_type="beats", sample_rate=16000, hop_length=160
+     ):
+         """
+         Args:
+             hf_dataset: HuggingFace dataset object
+             target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
+         """
+         self.sr = sample_rate
+         self.hop_length = hop_length
+         self.target_type = target_type
+
+         # Context window size in samples (±7 frames = ±70ms at 100fps)
+         self.context_frames = 7
+         self.context_samples = (self.context_frames * 2 + 1) * hop_length + max(
+             [368, 736, 1488]
+         )  # extra for FFT window
+
+         # Cache audio arrays in memory for fast access
+         self.audio_cache = []
+         self.indices = []
+         self._prepare_indices(hf_dataset)
+
+     def _prepare_indices(self, hf_dataset):
+         """
+         Prepares balanced indices and caches audio.
+         Paper Section 4.5: uses "fuzzier" training examples (neighbors weighted less).
+         """
+         print(f"Preparing dataset indices for target: {self.target_type}...")
+
+         for i, item in tqdm(
+             enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
+         ):
+             # Cache audio array (convert to numpy if tensor)
+             audio = item["audio"]["array"]
+             if hasattr(audio, "numpy"):
+                 audio = audio.numpy()
+             self.audio_cache.append(audio)
+
+             # Calculate total frames available in audio
+             audio_len = len(audio)
+             n_frames = int(audio_len / self.hop_length)
+
+             # Select ground truth based on target_type
+             if self.target_type == "downbeats":
+                 # Only downbeats are positives
+                 gt_times = item["downbeats"]
+             else:
+                 # All beats are positives (downbeats are also beats)
+                 gt_times = item["beats"]
+
+             # Convert to list if tensor
+             if hasattr(gt_times, "tolist"):
+                 gt_times = gt_times.tolist()
+
+             gt_frames = set([int(t * self.sr / self.hop_length) for t in gt_times])
+
+             # --- Positive examples (with fuzziness) ---
+             # "define a single frame before and after each annotated onset to be additional positive examples"
+             pos_frames = set()
+             for bf in gt_frames:
+                 if 0 <= bf < n_frames:
+                     self.indices.append((i, bf, 1.0))  # Center frame (sharp onset)
+                     pos_frames.add(bf)
+
+                     # Neighbors weighted at 0.25
+                     if 0 <= bf - 1 < n_frames:
+                         self.indices.append((i, bf - 1, 0.25))
+                         pos_frames.add(bf - 1)
+                     if 0 <= bf + 1 < n_frames:
+                         self.indices.append((i, bf + 1, 0.25))
+                         pos_frames.add(bf + 1)
+
+             # --- Negative examples ---
+             # Paper uses "all others as negative", but we balance 2:1 for stable SGD.
+             num_pos = len(pos_frames)
+             num_neg = num_pos * 2
+
+             count = 0
+             attempts = 0
+             while count < num_neg and attempts < num_neg * 5:
+                 f = np.random.randint(0, n_frames)
+                 if f not in pos_frames:
+                     self.indices.append((i, f, 0.0))
+                     count += 1
+                 attempts += 1
+
+         print(
+             f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
+         )
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         track_idx, frame_idx, label = self.indices[idx]
+
+         # Fast lookup from cache
+         audio = self.audio_cache[track_idx]
+         audio_len = len(audio)
+
+         # Calculate sample range for context window
+         center_sample = frame_idx * self.hop_length
+         half_context = self.context_samples // 2
+         start = center_sample - half_context
+         end = center_sample + half_context
+
+         # Handle padding if needed
+         pad_left = max(0, -start)
+         pad_right = max(0, end - audio_len)
+         start = max(0, start)
+         end = min(audio_len, end)
+
+         # Extract audio chunk
+         chunk = audio[start:end]
+         if pad_left > 0 or pad_right > 0:
+             chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")
+
+         waveform = torch.tensor(chunk, dtype=torch.float32)
+         return waveform, torch.tensor([label], dtype=torch.float32)
exp/baseline1/eval.py ADDED
@@ -0,0 +1,322 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from scipy.signal import find_peaks
+ import argparse
+ import os
+
+ from .model import ODCNN
+ from .utils import MultiViewSpectrogram
+ from ..data.load import ds
+ from ..data.eval import evaluate_all, format_results
+
+
+ def get_activation_function(model, waveform, device):
+     """
+     Computes probability curve over time.
+     """
+     processor = MultiViewSpectrogram().to(device)
+     waveform = waveform.unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         spec = processor(waveform)
+
+         # Normalize
+         mean = spec.mean(dim=(2, 3), keepdim=True)
+         std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+         spec = (spec - mean) / std
+
+         # Batchify with sliding window
+         spec = torch.nn.functional.pad(spec, (7, 7))  # Pad time
+         windows = spec.unfold(3, 15, 1)  # (1, 3, 80, Time, 15)
+         windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)  # (Time, 3, 80, 15)
+
+         # Inference
+         activations = []
+         batch_size = 512
+         for i in range(0, len(windows), batch_size):
+             batch = windows[i : i + batch_size]
+             out = model(batch)
+             activations.append(out.cpu().numpy())
+
+     return np.concatenate(activations).flatten()
+
+
+ def pick_peaks(activations, hop_length=160, sr=16000):
+     """
+     Smooth with Hamming window and report local maxima.
+     """
+     # Smoothing
+     window = np.hamming(5)
+     window /= window.sum()
+     smoothed = np.convolve(activations, window, mode="same")
+
+     # Peak Picking
+     peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
+
+     timestamps = peaks * hop_length / sr
+     return timestamps.tolist()
+
+
+ def visualize_track(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     output_dir: str,
+     track_idx: int,
+     time_range: tuple[float, float] | None = None,
+ ):
+     """
+     Create and save visualizations for a single track.
+     """
+     from ..data.viz import plot_waveform_with_beats, save_figure
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Full waveform plot
+     fig = plot_waveform_with_beats(
+         audio,
+         sr,
+         pred_beats,
+         gt_beats,
+         pred_downbeats,
+         gt_downbeats,
+         title=f"Track {track_idx}: Beat Comparison",
+         time_range=time_range,
+     )
+     save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
+
+
+ def synthesize_audio(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     output_dir: str,
+     track_idx: int,
+     click_volume: float = 0.5,
+ ):
+     """
+     Create and save audio files with click tracks for a single track.
+     """
+     from ..data.audio import create_comparison_audio, save_audio
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Create comparison audio
+     audio_pred, audio_gt, audio_both = create_comparison_audio(
+         audio,
+         pred_beats,
+         pred_downbeats,
+         gt_beats,
+         gt_downbeats,
+         sr=sr,
+         click_volume=click_volume,
+     )
+
+     # Save audio files
+     save_audio(
+         audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
+     )
+     save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
+     save_audio(
+         audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
+     )
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Evaluate beat tracking models with visualization and audio synthesis"
+     )
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         default="outputs/baseline1",
+         help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
+     )
+     parser.add_argument(
+         "--num-samples",
+         type=int,
+         default=116,
+         help="Number of samples to evaluate",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="outputs/eval_baseline1",
+         help="Directory to save visualizations and audio",
+     )
+     parser.add_argument(
+         "--visualize",
+         action="store_true",
+         help="Generate visualization plots for each track",
+     )
+     parser.add_argument(
+         "--synthesize",
+         action="store_true",
+         help="Generate audio files with click tracks",
+     )
+     parser.add_argument(
+         "--viz-tracks",
+         type=int,
+         default=5,
+         help="Number of tracks to visualize/synthesize (default: 5)",
+     )
+     parser.add_argument(
+         "--time-range",
+         type=float,
+         nargs=2,
+         default=None,
+         metavar=("START", "END"),
+         help="Time range for visualization in seconds (default: full track)",
+     )
+     parser.add_argument(
+         "--click-volume",
+         type=float,
+         default=0.5,
+         help="Volume of click sounds relative to audio (0.0 to 1.0)",
+     )
+     parser.add_argument(
+         "--summary-plot",
+         action="store_true",
+         help="Generate summary evaluation plot",
+     )
+     args = parser.parse_args()
+
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Load BOTH models using from_pretrained
+     beat_model = None
+     downbeat_model = None
+
+     has_beats = False
+     has_downbeats = False
+
+     beats_dir = os.path.join(args.model_dir, "beats")
+     downbeats_dir = os.path.join(args.model_dir, "downbeats")
+
+     if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
+         beat_model = ODCNN.from_pretrained(beats_dir).to(DEVICE)
+         beat_model.eval()
+         has_beats = True
+         print(f"Loaded Beat Model from {beats_dir}")
+     else:
+         print(f"Warning: No beat model found in {beats_dir}")
+
+     if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
+         downbeat_model = ODCNN.from_pretrained(downbeats_dir).to(DEVICE)
+         downbeat_model.eval()
+         has_downbeats = True
+         print(f"Loaded Downbeat Model from {downbeats_dir}")
+     else:
+         print(f"Warning: No downbeat model found in {downbeats_dir}")
+
+     if not has_beats and not has_downbeats:
+         print("No models found. Please run training first.")
+         return
+
+     predictions = []
+     ground_truths = []
+     audio_data = []  # Store audio for visualization/synthesis
+
+     # Eval on specified number of tracks
+     test_set = ds["train"].select(range(args.num_samples))
+
+     print("Running evaluation...")
+     for i, item in enumerate(tqdm(test_set)):
+         waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
+         waveform_device = waveform.to(DEVICE)
+
+         pred_entry = {"beats": [], "downbeats": []}
+
+         # 1. Predict Beats
+         if has_beats:
+             act_b = get_activation_function(beat_model, waveform_device, DEVICE)
+             pred_entry["beats"] = pick_peaks(act_b)
+
+         # 2. Predict Downbeats
+         if has_downbeats:
+             act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
+             pred_entry["downbeats"] = pick_peaks(act_d)
+
+         predictions.append(pred_entry)
+         ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
+
+         # Store audio for later visualization/synthesis
+         if args.visualize or args.synthesize:
+             if i < args.viz_tracks:
+                 audio_data.append(
+                     {
+                         "audio": waveform.numpy(),
+                         "sr": item["audio"]["sampling_rate"],
+                         "pred": pred_entry,
+                         "gt": ground_truths[-1],
+                     }
+                 )
+
+     # Run evaluation
+     results = evaluate_all(predictions, ground_truths)
+     print(format_results(results))
+
+     # Create output directory
+     if args.visualize or args.synthesize or args.summary_plot:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     # Generate visualizations
+     if args.visualize:
+         print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
+         viz_dir = os.path.join(args.output_dir, "plots")
+         for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
+             time_range = tuple(args.time_range) if args.time_range else None
+             visualize_track(
+                 data["audio"],
+                 data["sr"],
+                 data["pred"]["beats"],
+                 data["pred"]["downbeats"],
+                 data["gt"]["beats"],
+                 data["gt"]["downbeats"],
+                 viz_dir,
+                 i,
+                 time_range=time_range,
+             )
+         print(f"Saved visualizations to {viz_dir}")
+
+     # Generate audio with clicks
+     if args.synthesize:
+         print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
+         audio_dir = os.path.join(args.output_dir, "audio")
+         for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
+             synthesize_audio(
+                 data["audio"],
+                 data["sr"],
+                 data["pred"]["beats"],
+                 data["pred"]["downbeats"],
+                 data["gt"]["beats"],
+                 data["gt"]["downbeats"],
+                 audio_dir,
+                 i,
+                 click_volume=args.click_volume,
+             )
+         print(f"Saved audio files to {audio_dir}")
+         print("  *_pred.wav - Original audio with predicted beat clicks")
+         print("  *_gt.wav - Original audio with ground truth beat clicks")
+         print("  *_both.wav - Original audio with both predicted and GT clicks")
+
+     # Generate summary plot
+     if args.summary_plot:
+         from ..data.viz import plot_evaluation_summary, save_figure
+
+         print("\nGenerating summary plot...")
+         fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
+         summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
+         save_figure(fig, summary_path)
+         print(f"Saved summary plot to {summary_path}")
+
+
+ if __name__ == "__main__":
+     main()
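For reference, the smoothing-plus-peak-picking step in `pick_peaks` can be reproduced without scipy. The sketch below is a simplified stand-in (a naive local-maximum test with no minimum-distance constraint, unlike `find_peaks(..., distance=5)`), useful for checking the frame-to-seconds conversion:

```python
import numpy as np

def smooth_and_pick(activations, height=0.5, hop_length=160, sr=16000):
    # 5-tap normalized Hamming smoothing, as in pick_peaks above
    window = np.hamming(5)
    window /= window.sum()
    smoothed = np.convolve(activations, window, mode="same")
    # a frame is a peak if it clears the threshold and beats both neighbors
    peaks = [
        i for i in range(1, len(smoothed) - 1)
        if smoothed[i] > height
        and smoothed[i] >= smoothed[i - 1]
        and smoothed[i] > smoothed[i + 1]
    ]
    return [p * hop_length / sr for p in peaks]

act = np.zeros(100)
act[49:52] = 1.0  # a short burst of activation around frame 50
print(smooth_and_pick(act))  # -> [0.5]
```

Note that with the normalized 5-tap Hamming window, a single isolated spike of 1.0 is smoothed down to about 0.45 and would fall below the 0.5 height threshold; real beat activations span a few frames, as in the example.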
exp/baseline1/model.py ADDED
@@ -0,0 +1,62 @@
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class ODCNN(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, dropout_rate=0.5):
+         super().__init__()
+
+         # Input: 3 channels, 80 mel bands
+         # Conv 1: 3x7 (freq x time) filters -> 10 maps
+         self.conv1 = nn.Conv2d(3, 10, kernel_size=(3, 7))
+         self.relu1 = nn.ReLU()  # ReLU improvement
+         self.pool1 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))
+
+         # Conv 2: 3x3 filters -> 20 maps
+         self.conv2 = nn.Conv2d(10, 20, kernel_size=(3, 3))
+         self.relu2 = nn.ReLU()
+         self.pool2 = nn.MaxPool2d(kernel_size=(3, 1), stride=(3, 1))
+
+         # Flatten size calculation based on architecture
+         # (20 feature maps * 8 freq bands * 7 time frames)
+         self.flatten_size = 20 * 8 * 7
+
+         # Dropout on FC inputs
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+         # 256 hidden units
+         self.fc1 = nn.Linear(self.flatten_size, 256)
+         self.relu_fc = nn.ReLU()
+
+         # Output unit
+         self.fc2 = nn.Linear(256, 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         x = self.conv1(x)
+         x = self.relu1(x)
+         x = self.pool1(x)
+
+         x = self.conv2(x)
+         x = self.relu2(x)
+         x = self.pool2(x)
+
+         x = x.view(x.size(0), -1)
+
+         x = self.dropout(x)
+         x = self.fc1(x)
+         x = self.relu_fc(x)
+
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.sigmoid(x)
+
+         return x
+
+
+ if __name__ == "__main__":
+     from torchinfo import summary
+
+     model = ODCNN()
+     summary(model, (1, 3, 80, 15))
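The hard-coded `flatten_size = 20 * 8 * 7` follows from walking the input shape (3, 80, 15) through the layers; a quick stdlib sketch of that arithmetic (not part of the commit):

```python
def conv_out(size: int, kernel: int) -> int:
    # valid convolution, stride 1, no padding
    return size - kernel + 1

def pool_out(size: int, kernel: int) -> int:
    # max pooling with stride equal to the kernel
    return size // kernel

freq, time = 80, 15                                # input: 80 mel bands x 15 frames
freq, time = conv_out(freq, 3), conv_out(time, 7)  # conv1 (3, 7) -> (78, 9)
freq = pool_out(freq, 3)                           # pool1 (3, 1) -> (26, 9)
freq, time = conv_out(freq, 3), conv_out(time, 3)  # conv2 (3, 3) -> (24, 7)
freq = pool_out(freq, 3)                           # pool2 (3, 1) -> (8, 7)
flatten_size = 20 * freq * time                    # 20 feature maps
print(flatten_size)  # -> 1120
```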
exp/baseline1/train.py ADDED
@@ -0,0 +1,183 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard import SummaryWriter
+ from tqdm import tqdm
+ import argparse
+ import os
+
+ from .model import ODCNN
+ from .data import BeatTrackingDataset
+ from .utils import MultiViewSpectrogram
+ from ..data.load import ds
+
+
+ def train(target_type: str, output_dir: str):
+     # Note: Paper uses SGD with Momentum, Dropout, and ReLU
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     BATCH_SIZE = 512
+     EPOCHS = 50
+     LR = 0.05
+     MOMENTUM = 0.9
+     NUM_WORKERS = 4
+
+     print(f"--- Training Model for target: {target_type} ---")
+     print(f"Output directory: {output_dir}")
+
+     # Create output directory
+     os.makedirs(output_dir, exist_ok=True)
+
+     # TensorBoard writer
+     writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
+
+     # Data - use existing train/test splits
+     train_dataset = BeatTrackingDataset(ds["train"], target_type=target_type)
+     val_dataset = BeatTrackingDataset(ds["test"], target_type=target_type)
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=BATCH_SIZE,
+         shuffle=True,
+         num_workers=NUM_WORKERS,
+         pin_memory=True,
+         prefetch_factor=4,
+         persistent_workers=True,
+     )
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=BATCH_SIZE,
+         shuffle=False,
+         num_workers=NUM_WORKERS,
+         pin_memory=True,
+         prefetch_factor=4,
+         persistent_workers=True,
+     )
+
+     print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
+
+     # Model
+     model = ODCNN(dropout_rate=0.5).to(DEVICE)
+
+     # GPU Spectrogram Preprocessor
+     preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)
+
+     # Optimizer
+     optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
+     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
+     criterion = nn.BCELoss()  # Binary Cross Entropy
+
+     best_val_loss = float("inf")
+     global_step = 0
+
+     for epoch in range(EPOCHS):
+         # Training
+         model.train()
+         total_train_loss = 0
+         for waveform, y in tqdm(
+             train_loader,
+             desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
+             leave=False,
+         ):
+             waveform, y = waveform.to(DEVICE), y.to(DEVICE)
+
+             # Compute spectrogram on GPU
+             with torch.no_grad():
+                 spec = preprocessor(waveform)  # (B, 3, 80, T)
+                 # Normalize
+                 mean = spec.mean(dim=(2, 3), keepdim=True)
+                 std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+                 spec = (spec - mean) / std
+                 # Take the 15 center frames as the model input
+                 x = spec[:, :, :, 7:22]
+
+             optimizer.zero_grad()
+             output = model(x)
+             loss = criterion(output, y)
+             loss.backward()
+             optimizer.step()
+
+             total_train_loss += loss.item()
+             global_step += 1
+
+             # Log batch loss
+             writer.add_scalar("train/batch_loss", loss.item(), global_step)
+
+         avg_train_loss = total_train_loss / len(train_loader)
+
+         # Validation
+         model.eval()
+         total_val_loss = 0
+         with torch.no_grad():
+             for waveform, y in tqdm(
+                 val_loader,
+                 desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
+                 leave=False,
+             ):
+                 waveform, y = waveform.to(DEVICE), y.to(DEVICE)
+
+                 # Compute spectrogram on GPU
+                 spec = preprocessor(waveform)  # (B, 3, 80, T)
+                 # Normalize
+                 mean = spec.mean(dim=(2, 3), keepdim=True)
+                 std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+                 spec = (spec - mean) / std
+                 # Take the 15 center frames
+                 x = spec[:, :, :, 7:22]
+
+                 output = model(x)
+                 loss = criterion(output, y)
+                 total_val_loss += loss.item()
+
+         avg_val_loss = total_val_loss / len(val_loader)
+
+         # Log epoch metrics
+         writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
+         writer.add_scalar("val/loss", avg_val_loss, epoch)
+         writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
+
+         # Step the scheduler
+         scheduler.step()
+
+         print(
+             f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
+             f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
+         )
+
+         # Save best model
+         if avg_val_loss < best_val_loss:
+             best_val_loss = avg_val_loss
+             model.save_pretrained(output_dir)
+             print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
+
+     writer.close()
+
+     # Save final model
+     final_dir = os.path.join(output_dir, "final")
+     model.save_pretrained(final_dir)
+     print(f"Saved final model to {final_dir}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--target",
+         type=str,
+         choices=["beats", "downbeats"],
+         default=None,
+         help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="outputs/baseline1",
+         help="Directory to save model and logs",
+     )
+     args = parser.parse_args()
+
+     # Determine which targets to train
+     targets = [args.target] if args.target else ["beats", "downbeats"]
+
+     for target in targets:
+         output_dir = os.path.join(args.output_dir, target)
+         train(target, output_dir)
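The per-example, per-channel standardization used in both the train and validation loops is easy to sanity-check with numpy (equivalent math, `axis=(2, 3)` in place of `dim=(2, 3)`; the shapes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
spec = rng.normal(5.0, 2.0, size=(4, 3, 80, 29))  # (B, C, F, T), like the preprocessor output
mean = spec.mean(axis=(2, 3), keepdims=True)      # one mean per example and channel
std = spec.std(axis=(2, 3), keepdims=True) + 1e-6  # epsilon guards against silent inputs
norm = (spec - mean) / std
# every (example, channel) slice is now ~zero-mean, unit-variance
```

Because the statistics are computed per example and per channel, the normalization is identical at train and inference time and needs no stored dataset statistics.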
exp/baseline1/utils.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import torch.nn as nn
+ import torchaudio.transforms as T
+ import numpy as np
+
+
+ class MultiViewSpectrogram(nn.Module):
+     def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
+         super().__init__()
+         # Window sizes: 23ms, 46ms, 93ms at 16 kHz
+         self.win_lengths = [368, 736, 1488]
+         self.transforms = nn.ModuleList()
+
+         for win_len in self.win_lengths:
+             n_fft = 2 ** int(np.ceil(np.log2(win_len)))
+             mel = T.MelSpectrogram(
+                 sample_rate=sample_rate,
+                 n_fft=n_fft,
+                 win_length=win_len,
+                 hop_length=hop_length,
+                 f_min=27.5,
+                 f_max=8000.0,  # Nyquist frequency for 16 kHz input
+                 n_mels=n_mels,
+                 power=1.0,
+                 center=True,
+             )
+             self.transforms.append(mel)
+
+     def forward(self, waveform):
+         specs = []
+         for transform in self.transforms:
+             # Scale magnitudes logarithmically
+             s = transform(waveform)
+             s = torch.log(s + 1e-9)
+             specs.append(s)
+         return torch.stack(specs, dim=1)
+
+
+ def extract_context(spec, center_frame, context=7):
+     # Context of +/- 70ms (7 frames at a 10 ms hop)
+     channels, n_mels, total_time = spec.shape
+     start = center_frame - context
+     end = center_frame + context + 1
+
+     pad_left = max(0, -start)
+     pad_right = max(0, end - total_time)
+
+     if pad_left > 0 or pad_right > 0:
+         spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
+         start += pad_left
+         end += pad_left
+
+     return spec[:, :, start:end]
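The `n_fft = 2 ** int(np.ceil(np.log2(win_len)))` line rounds each window length up to the next power of two, so all three views share the 10 ms hop while using different FFT sizes. The same arithmetic with only the stdlib:

```python
import math

win_lengths = [368, 736, 1488]  # 23 ms, 46 ms, 93 ms at 16 kHz
n_ffts = [2 ** math.ceil(math.log2(w)) for w in win_lengths]
print(n_ffts)  # -> [512, 1024, 2048]
```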
exp/baseline2/__init__.py ADDED
File without changes
exp/baseline2/data.py ADDED
@@ -0,0 +1,137 @@
+ import torch
+ from torch.utils.data import Dataset
+ import numpy as np
+ from tqdm import tqdm
+
+
+ class BeatTrackingDataset(Dataset):
+     def __init__(
+         self,
+         hf_dataset,
+         target_type="beats",
+         sample_rate=16000,
+         hop_length=160,
+         context_frames=50,
+     ):
+         """
+         Args:
+             hf_dataset: HuggingFace dataset object
+             target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
+             context_frames (int): Number of frames before and after the center frame.
+                 Total frames = 2 * context_frames + 1.
+                 Default 50 means 101 frames (~1s).
+         """
+         self.sr = sample_rate
+         self.hop_length = hop_length
+         self.target_type = target_type
+
+         self.context_frames = context_frames
+         # Context window size in samples.
+         # We need enough samples for the center frame +/- context frames,
+         # PLUS the window size of the largest FFT to compute the edges correctly.
+         # Largest window in MultiViewSpectrogram is 1488.
+         self.context_samples = (self.context_frames * 2 + 1) * hop_length + 1488
+
+         # Cache audio arrays in memory for fast access
+         self.audio_cache = []
+         self.indices = []
+         self._prepare_indices(hf_dataset)
+
+     def _prepare_indices(self, hf_dataset):
+         """
+         Prepares balanced indices and caches audio.
+         Uses the same "fuzzier" training-example strategy as the baseline.
+         """
+         print(f"Preparing dataset indices for target: {self.target_type}...")
+
+         for i, item in tqdm(
+             enumerate(hf_dataset), total=len(hf_dataset), desc="Building indices"
+         ):
+             # Cache audio array (convert to numpy if tensor)
+             audio = item["audio"]["array"]
+             if hasattr(audio, "numpy"):
+                 audio = audio.numpy()
+             self.audio_cache.append(audio)
+
+             # Calculate total frames available in audio
+             audio_len = len(audio)
+             n_frames = int(audio_len / self.hop_length)
+
+             # Select ground truth based on target_type
+             if self.target_type == "downbeats":
+                 gt_times = item["downbeats"]
+             else:
+                 gt_times = item["beats"]
+
+             # Convert to list if tensor
+             if hasattr(gt_times, "tolist"):
+                 gt_times = gt_times.tolist()
+
+             gt_frames = set([int(t * self.sr / self.hop_length) for t in gt_times])
+
+             # --- Positive Examples (with Fuzziness) ---
+             pos_frames = set()
+             for bf in gt_frames:
+                 if 0 <= bf < n_frames:
+                     self.indices.append((i, bf, 1.0))  # Center frame
+                     pos_frames.add(bf)
+
+                     # Neighbors weighted at 0.25
+                     if 0 <= bf - 1 < n_frames:
+                         self.indices.append((i, bf - 1, 0.25))
+                         pos_frames.add(bf - 1)
+                     if 0 <= bf + 1 < n_frames:
+                         self.indices.append((i, bf + 1, 0.25))
+                         pos_frames.add(bf + 1)
+
+             # --- Negative Examples ---
+             # Balance 2:1
+             num_pos = len(pos_frames)
+             num_neg = num_pos * 2
+
+             count = 0
+             attempts = 0
+             while count < num_neg and attempts < num_neg * 5:
+                 f = np.random.randint(0, n_frames)
+                 if f not in pos_frames:
+                     self.indices.append((i, f, 0.0))
+                     count += 1
+                 attempts += 1
+
+         print(
+             f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
+         )
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         track_idx, frame_idx, label = self.indices[idx]
+
+         # Fast lookup from cache
+         audio = self.audio_cache[track_idx]
+         audio_len = len(audio)
+
+         # Calculate sample range for context window
+         center_sample = frame_idx * self.hop_length
+         half_context = self.context_samples // 2
+
+         # We want the window centered around center_sample
+         start = center_sample - half_context
+         end = center_sample + half_context
+
+         # Handle padding if needed
+         pad_left = max(0, -start)
+         pad_right = max(0, end - audio_len)
+
+         valid_start = max(0, start)
+         valid_end = min(audio_len, end)
+
+         # Extract audio chunk
+         chunk = audio[valid_start:valid_end]
+
+         if pad_left > 0 or pad_right > 0:
+             chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")
+
+         waveform = torch.tensor(chunk, dtype=torch.float32)
+         return waveform, torch.tensor([label], dtype=torch.float32)
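The "fuzzy" positive-labeling scheme in `_prepare_indices` (1.0 on the beat frame, 0.25 on its immediate neighbors) can be sketched in isolation; `fuzzy_positive_indices` below is a hypothetical helper for illustration, not part of the commit:

```python
def fuzzy_positive_indices(beat_times, sr=16000, hop_length=160, n_frames=1000):
    """Return (frame, label) pairs: center frame -> 1.0, neighbors -> 0.25."""
    pairs = []
    for t in beat_times:
        bf = int(t * sr / hop_length)  # beat time in seconds -> frame index
        for df, label in ((0, 1.0), (-1, 0.25), (1, 0.25)):
            f = bf + df
            if 0 <= f < n_frames:
                pairs.append((f, label))
    return pairs

print(fuzzy_positive_indices([0.5]))  # -> [(50, 1.0), (49, 0.25), (51, 0.25)]
```

The soft 0.25 targets make the BCE loss tolerant of one-frame (10 ms) annotation jitter instead of penalizing near-misses as hard negatives.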
exp/baseline2/eval.py ADDED
@@ -0,0 +1,324 @@
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from scipy.signal import find_peaks
+ import argparse
+ import os
+
+ from .model import ResNet
+ from ..baseline1.utils import MultiViewSpectrogram
+ from ..data.load import ds
+ from ..data.eval import evaluate_all, format_results
+
+
+ def get_activation_function(model, waveform, device):
+     """
+     Computes probability curve over time.
+     """
+     processor = MultiViewSpectrogram().to(device)
+     waveform = waveform.unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         spec = processor(waveform)
+
+         # Normalize
+         mean = spec.mean(dim=(2, 3), keepdim=True)
+         std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+         spec = (spec - mean) / std
+
+         # Batchify with sliding window.
+         # Context frames = 50, so total window = 101.
+         # Pad time by 50 on each side.
+         spec = torch.nn.functional.pad(spec, (50, 50))  # Pad time
+         windows = spec.unfold(3, 101, 1)  # (1, 3, 80, Time, 101)
+         windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)  # (Time, 3, 80, 101)
+
+         # Inference
+         activations = []
+         batch_size = 128  # Reduced batch size
+         for i in range(0, len(windows), batch_size):
+             batch = windows[i : i + batch_size]
+             out = model(batch)
+             activations.append(out.cpu().numpy())
+
+     return np.concatenate(activations).flatten()
+
+
+ def pick_peaks(activations, hop_length=160, sr=16000):
+     """
+     Smooth with Hamming window and report local maxima.
+     """
+     # Smoothing
+     window = np.hamming(5)
+     window /= window.sum()
+     smoothed = np.convolve(activations, window, mode="same")
+
+     # Peak Picking
+     peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
+
+     timestamps = peaks * hop_length / sr
+     return timestamps.tolist()
+
+
+ def visualize_track(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     output_dir: str,
+     track_idx: int,
+     time_range: tuple[float, float] | None = None,
+ ):
+     """
+     Create and save visualizations for a single track.
+     """
+     from ..data.viz import plot_waveform_with_beats, save_figure
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Full waveform plot
+     fig = plot_waveform_with_beats(
+         audio,
+         sr,
+         pred_beats,
+         gt_beats,
+         pred_downbeats,
+         gt_downbeats,
+         title=f"Track {track_idx}: Beat Comparison",
+         time_range=time_range,
+     )
+     save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
+
+
+ def synthesize_audio(
+     audio: np.ndarray,
+     sr: int,
+     pred_beats: list[float],
+     pred_downbeats: list[float],
+     gt_beats: list[float],
+     gt_downbeats: list[float],
+     output_dir: str,
+     track_idx: int,
+     click_volume: float = 0.5,
+ ):
+     """
+     Create and save audio files with click tracks for a single track.
+     """
+     from ..data.audio import create_comparison_audio, save_audio
+
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Create comparison audio
+     audio_pred, audio_gt, audio_both = create_comparison_audio(
+         audio,
+         pred_beats,
+         pred_downbeats,
+         gt_beats,
+         gt_downbeats,
+         sr=sr,
+         click_volume=click_volume,
+     )
+
+     # Save audio files
+     save_audio(
+         audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
+     )
+     save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
+     save_audio(
+         audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
+     )
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Evaluate beat tracking models with visualization and audio synthesis"
+     )
+     parser.add_argument(
+         "--model-dir",
+         type=str,
+         default="outputs/baseline2",
+         help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
+     )
+     parser.add_argument(
+         "--num-samples",
+         type=int,
+         default=116,
+         help="Number of samples to evaluate",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=str,
+         default="outputs/eval_baseline2",
+         help="Directory to save visualizations and audio",
+     )
+     parser.add_argument(
+         "--visualize",
+         action="store_true",
+         help="Generate visualization plots for each track",
+     )
+     parser.add_argument(
+         "--synthesize",
+         action="store_true",
+         help="Generate audio files with click tracks",
+     )
+     parser.add_argument(
+         "--viz-tracks",
+         type=int,
+         default=5,
+         help="Number of tracks to visualize/synthesize (default: 5)",
+     )
+     parser.add_argument(
+         "--time-range",
+         type=float,
+         nargs=2,
+         default=None,
+         metavar=("START", "END"),
+         help="Time range for visualization in seconds (default: full track)",
+     )
+     parser.add_argument(
+         "--click-volume",
+         type=float,
+         default=0.5,
+         help="Volume of click sounds relative to audio (0.0 to 1.0)",
+     )
+     parser.add_argument(
+         "--summary-plot",
+         action="store_true",
+         help="Generate summary evaluation plot",
+     )
+     args = parser.parse_args()
+
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # Load BOTH models using from_pretrained
+     beat_model = None
+     downbeat_model = None
+
+     has_beats = False
+     has_downbeats = False
+
+     beats_dir = os.path.join(args.model_dir, "beats")
+     downbeats_dir = os.path.join(args.model_dir, "downbeats")
+
+     if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
+         beat_model = ResNet.from_pretrained(beats_dir).to(DEVICE)
+         beat_model.eval()
+         has_beats = True
+         print(f"Loaded Beat Model from {beats_dir}")
+     else:
+         print(f"Warning: No beat model found in {beats_dir}")
+
+     if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
+         downbeat_model = ResNet.from_pretrained(downbeats_dir).to(DEVICE)
+         downbeat_model.eval()
+         has_downbeats = True
+         print(f"Loaded Downbeat Model from {downbeats_dir}")
+     else:
+         print(f"Warning: No downbeat model found in {downbeats_dir}")
+
+     if not has_beats and not has_downbeats:
+         print("No models found. Please run training first.")
+         return
+
+     predictions = []
+     ground_truths = []
+     audio_data = []  # Store audio for visualization/synthesis
+
+     # Eval on specified number of tracks
+     test_set = ds["train"].select(range(args.num_samples))
+
+     print("Running evaluation...")
+     for i, item in enumerate(tqdm(test_set)):
+         waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
+         waveform_device = waveform.to(DEVICE)
+
+         pred_entry = {"beats": [], "downbeats": []}
+
+         # 1. Predict Beats
+         if has_beats:
+             act_b = get_activation_function(beat_model, waveform_device, DEVICE)
+             pred_entry["beats"] = pick_peaks(act_b)
+
+         # 2. Predict Downbeats
+         if has_downbeats:
+             act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
+             pred_entry["downbeats"] = pick_peaks(act_d)
+
+         predictions.append(pred_entry)
+         ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
+
+         # Store audio for later visualization/synthesis
+         if args.visualize or args.synthesize:
+             if i < args.viz_tracks:
+                 audio_data.append(
+                     {
+                         "audio": waveform.numpy(),
+                         "sr": item["audio"]["sampling_rate"],
+                         "pred": pred_entry,
+                         "gt": ground_truths[-1],
+                     }
+                 )
+
+     # Run evaluation
+     results = evaluate_all(predictions, ground_truths)
+     print(format_results(results))
+
+     # Create output directory
+     if args.visualize or args.synthesize or args.summary_plot:
+         os.makedirs(args.output_dir, exist_ok=True)
+
+     # Generate visualizations
+     if args.visualize:
+         print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
+         viz_dir = os.path.join(args.output_dir, "plots")
+         for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
+             time_range = tuple(args.time_range) if args.time_range else None
+             visualize_track(
+                 data["audio"],
+                 data["sr"],
+                 data["pred"]["beats"],
+                 data["pred"]["downbeats"],
+                 data["gt"]["beats"],
+                 data["gt"]["downbeats"],
+                 viz_dir,
+                 i,
+                 time_range=time_range,
+             )
+         print(f"Saved visualizations to {viz_dir}")
+
+     # Generate audio with clicks
+     if args.synthesize:
+         print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
+         audio_dir = os.path.join(args.output_dir, "audio")
+         for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
+             synthesize_audio(
+                 data["audio"],
+                 data["sr"],
+                 data["pred"]["beats"],
+                 data["pred"]["downbeats"],
+                 data["gt"]["beats"],
+                 data["gt"]["downbeats"],
+                 audio_dir,
+                 i,
+                 click_volume=args.click_volume,
+             )
+         print(f"Saved audio files to {audio_dir}")
+         print("  *_pred.wav - Original audio with predicted beat clicks")
+         print("  *_gt.wav - Original audio with ground truth beat clicks")
+         print("  *_both.wav - Original audio with both predicted and GT clicks")
+
+     # Generate summary plot
+     if args.summary_plot:
+         from ..data.viz import plot_evaluation_summary, save_figure
+
+         print("\nGenerating summary plot...")
+         fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
318
+ summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
319
+ save_figure(fig, summary_path)
320
+ print(f"Saved summary plot to {summary_path}")
321
+
322
+
323
+ if __name__ == "__main__":
324
+ main()
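The loop above compares predicted beat lists against ground truth via `evaluate_all`. As an illustration of what a beat F-measure computes, here is a naive pure-Python sketch with greedy one-to-one matching inside a ±70 ms tolerance window. The repo's `evaluate_all` is a separate implementation (presumably mir_eval-style); this stand-in is only for intuition, and `beat_f_measure` is a hypothetical helper name.

```python
def beat_f_measure(pred, gt, tolerance=0.07):
    """Illustrative F-measure: greedily match each predicted beat to the first
    unused ground-truth beat within +/- tolerance seconds."""
    if not pred or not gt:
        return 0.0
    matched = 0
    used = [False] * len(gt)
    for p in pred:
        for j, g in enumerate(gt):
            if not used[j] and abs(p - g) <= tolerance:
                used[j] = True
                matched += 1
                break
    if matched == 0:
        return 0.0
    precision = matched / len(pred)
    recall = matched / len(gt)
    return 2 * precision * recall / (precision + recall)


# Two of three predictions land within tolerance -> P = R = 2/3
print(beat_f_measure([0.5, 1.01, 1.5], [0.5, 1.0, 2.0]))
```

Note that greedy matching can differ from an optimal bipartite matching on dense beat grids, which is one reason real evaluators are more careful.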
exp/baseline2/model.py ADDED
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+
+class SEBlock(nn.Module):
+    def __init__(self, channels, reduction=16):
+        super().__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channels, channels // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channels // reduction, channels, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
+        super().__init__()
+        self.conv1 = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=3,
+            stride=stride,
+            padding=1,
+            bias=False,
+        )
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(out_channels)
+        self.se = SEBlock(out_channels)
+        self.downsample = downsample
+
+    def forward(self, x):
+        identity = x
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.se(out)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class ResNet(nn.Module, PyTorchModelHubMixin):
+    def __init__(
+        self, layers=[2, 2, 2, 2], channels=[16, 24, 48, 96], dropout_rate=0.5
+    ):
+        super().__init__()
+        self.in_channels = 16
+
+        # Stem
+        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(16)
+        self.relu = nn.ReLU(inplace=True)
+
+        # Stages
+        self.layer1 = self._make_layer(channels[0], layers[0], stride=1)
+        self.layer2 = self._make_layer(channels[1], layers[1], stride=2)
+        self.layer3 = self._make_layer(channels[2], layers[2], stride=2)
+        self.layer4 = self._make_layer(channels[3], layers[3], stride=2)
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        # Final classification head
+        # H, W shrink stage by stage. Assuming input is (3, 80, 101):
+        # L1: (16, 80, 101) (stride 1)
+        # L2: (24, 40, 51)  (stride 2)
+        # L3: (48, 20, 26)  (stride 2)
+        # L4: (96, 10, 13)  (stride 2)
+
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(channels[3], 1)
+        self.sigmoid = nn.Sigmoid()
+
+    def _make_layer(self, out_channels, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.in_channels != out_channels:
+            downsample = nn.Sequential(
+                nn.Conv2d(
+                    self.in_channels,
+                    out_channels,
+                    kernel_size=1,
+                    stride=stride,
+                    bias=False,
+                ),
+                nn.BatchNorm2d(out_channels),
+            )
+
+        layers = []
+        layers.append(ResBlock(self.in_channels, out_channels, stride, downsample))
+        self.in_channels = out_channels
+        for _ in range(1, blocks):
+            layers.append(ResBlock(self.in_channels, out_channels))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        # x: (B, 3, 80, 101)
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)  # (B, 96, 1, 1)
+        x = torch.flatten(x, 1)  # (B, 96)
+        x = self.dropout(x)
+        x = self.fc(x)
+        x = self.sigmoid(x)
+
+        return x
+
+
+if __name__ == "__main__":
+    from torchinfo import summary
+
+    model = ResNet()
+    summary(model, (1, 3, 80, 101))
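The per-stage shape comments in `ResNet.__init__` follow from standard convolution arithmetic. A minimal pure-Python sketch (assuming the model's kernel-3 / padding-1 convolutions, and the default `channels=[16, 24, 48, 96]`):

```python
def conv_out(n, kernel=3, stride=1, padding=1):
    """Output length along one axis: floor((n + 2*padding - kernel) / stride) + 1."""
    return (n + 2 * padding - kernel) // stride + 1


def stage_shapes(h=80, w=101, channels=(16, 24, 48, 96), strides=(1, 2, 2, 2)):
    """Trace (C, H, W) through the four stages for an (3, 80, 101) input."""
    shapes = []
    for c, s in zip(channels, strides):
        h, w = conv_out(h, stride=s), conv_out(w, stride=s)
        shapes.append((c, h, w))
    return shapes


print(stage_shapes())
# -> [(16, 80, 101), (24, 40, 51), (48, 20, 26), (96, 10, 13)]
```

Only the first (strided) block of each stage changes the spatial size; the remaining blocks are stride-1 and shape-preserving, so tracing one conv per stage is enough.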
exp/baseline2/train.py ADDED
@@ -0,0 +1,215 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+import argparse
+import os
+
+from .model import ResNet
+from .data import BeatTrackingDataset
+from ..baseline1.utils import MultiViewSpectrogram
+from ..data.load import ds
+
+
+def train(target_type: str, output_dir: str):
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+    BATCH_SIZE = 128  # Reduced batch size due to larger context
+    EPOCHS = 3
+    LR = 0.001  # Adjusted LR for Adam (ResNet usually prefers Adam/AdamW)
+    NUM_WORKERS = 4
+    CONTEXT_FRAMES = 50  # +/- 50 frames -> 101 frames total
+    PATIENCE = 5  # Early stopping patience
+
+    print(f"--- Training Model for target: {target_type} ---")
+    print(f"Output directory: {output_dir}")
+
+    # Create output directory
+    os.makedirs(output_dir, exist_ok=True)
+
+    # TensorBoard writer
+    writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
+
+    # Data
+    train_dataset = BeatTrackingDataset(
+        ds["train"], target_type=target_type, context_frames=CONTEXT_FRAMES
+    )
+    val_dataset = BeatTrackingDataset(
+        ds["test"], target_type=target_type, context_frames=CONTEXT_FRAMES
+    )
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=NUM_WORKERS,
+        pin_memory=True,
+        prefetch_factor=4,
+        persistent_workers=True,
+    )
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        num_workers=NUM_WORKERS,
+        pin_memory=True,
+        prefetch_factor=4,
+        persistent_workers=True,
+    )
+
+    print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
+
+    # Model
+    model = ResNet(dropout_rate=0.5).to(DEVICE)
+
+    # GPU Spectrogram Preprocessor
+    preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)
+
+    # Optimizer - Using AdamW for ResNet
+    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
+    criterion = nn.BCELoss()  # Binary Cross Entropy
+
+    best_val_loss = float("inf")
+    patience_counter = 0
+    global_step = 0
+
+    for epoch in range(EPOCHS):
+        # Training
+        model.train()
+        total_train_loss = 0
+        for waveform, y in tqdm(
+            train_loader,
+            desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Train",
+            leave=False,
+        ):
+            waveform, y = waveform.to(DEVICE), y.to(DEVICE)
+
+            # Compute spectrogram on GPU
+            with torch.no_grad():
+                spec = preprocessor(waveform)  # (B, 3, 80, T_raw)
+                # Normalize
+                mean = spec.mean(dim=(2, 3), keepdim=True)
+                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+                spec = (spec - mean) / std
+
+                T_curr = spec.shape[-1]
+                target_T = CONTEXT_FRAMES * 2 + 1
+
+                if T_curr > target_T:
+                    start = (T_curr - target_T) // 2
+                    x = spec[:, :, :, start : start + target_T]
+                elif T_curr < target_T:
+                    # This shouldn't happen if dataset is correct, but just in case pad
+                    pad = target_T - T_curr
+                    x = torch.nn.functional.pad(spec, (0, pad))
+                else:
+                    x = spec
+
+            optimizer.zero_grad()
+            output = model(x)
+            loss = criterion(output, y)
+            loss.backward()
+            optimizer.step()
+
+            total_train_loss += loss.item()
+            global_step += 1
+
+            # Log batch loss
+            writer.add_scalar("train/batch_loss", loss.item(), global_step)
+
+        avg_train_loss = total_train_loss / len(train_loader)
+
+        # Validation
+        model.eval()
+        total_val_loss = 0
+        with torch.no_grad():
+            for waveform, y in tqdm(
+                val_loader,
+                desc=f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} Val",
+                leave=False,
+            ):
+                waveform, y = waveform.to(DEVICE), y.to(DEVICE)
+
+                # Compute spectrogram on GPU
+                spec = preprocessor(waveform)  # (B, 3, 80, T)
+                # Normalize
+                mean = spec.mean(dim=(2, 3), keepdim=True)
+                std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+                spec = (spec - mean) / std
+
+                T_curr = spec.shape[-1]
+                target_T = CONTEXT_FRAMES * 2 + 1
+
+                if T_curr > target_T:
+                    start = (T_curr - target_T) // 2
+                    x = spec[:, :, :, start : start + target_T]
+                else:
+                    pad = target_T - T_curr
+                    x = torch.nn.functional.pad(spec, (0, pad))
+
+                output = model(x)
+                loss = criterion(output, y)
+                total_val_loss += loss.item()
+
+        avg_val_loss = total_val_loss / len(val_loader)
+
+        # Log epoch metrics
+        writer.add_scalar("train/epoch_loss", avg_train_loss, epoch)
+        writer.add_scalar("val/loss", avg_val_loss, epoch)
+        writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
+
+        # Step the scheduler
+        scheduler.step()
+
+        print(
+            f"[{target_type}] Epoch {epoch + 1}/{EPOCHS} - "
+            f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}"
+        )
+
+        # Save best model
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            patience_counter = 0
+            model.save_pretrained(output_dir)
+            print(f"  -> Saved best model (val_loss: {best_val_loss:.4f})")
+        else:
+            patience_counter += 1
+            print(f"  -> No improvement (patience: {patience_counter}/{PATIENCE})")
+
+        if patience_counter >= PATIENCE:
+            print("Early stopping triggered.")
+            break
+
+    writer.close()
+
+    # Save final model
+    final_dir = os.path.join(output_dir, "final")
+    model.save_pretrained(final_dir)
+    print(f"Saved final model to {final_dir}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--target",
+        type=str,
+        choices=["beats", "downbeats"],
+        default=None,
+        help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="outputs/baseline2",
+        help="Directory to save model and logs",
+    )
+    args = parser.parse_args()
+
+    # Determine which targets to train
+    targets = [args.target] if args.target else ["beats", "downbeats"]
+
+    for target in targets:
+        output_dir = os.path.join(args.output_dir, target)
+        train(target, output_dir)
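Both the train and validation loops above center-crop (or right-pad) the spectrogram to exactly `2 * CONTEXT_FRAMES + 1` frames before the model sees it. The index arithmetic can be sketched in pure Python; `center_crop_bounds` is a hypothetical helper, not part of the repo:

```python
def center_crop_bounds(t_curr, context_frames=50):
    """Return (start, end, pad_right) for cropping/padding a T-frame spectrogram
    to target_t = 2 * context_frames + 1 frames, mirroring the training loop."""
    target_t = context_frames * 2 + 1
    if t_curr < target_t:
        # Keep everything, pad the remainder on the right (the rare fallback case).
        return 0, t_curr, target_t - t_curr
    start = (t_curr - target_t) // 2
    return start, start + target_t, 0


print(center_crop_bounds(103))  # two extra frames -> drop one from each side
```

For an even surplus the crop is symmetric; for an odd surplus the extra frame is dropped from the right, since `//` rounds the start index down.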
exp/baseline3/__init__.py ADDED
File without changes
exp/baseline3/data.py ADDED
@@ -0,0 +1,173 @@
+import torch
+from torch.utils.data import Dataset
+import numpy as np
+from tqdm import tqdm
+
+
+class BeatTrackingDataset(Dataset):
+    def __init__(
+        self,
+        hf_dataset,
+        target_type="beats",
+        sample_rate=16000,
+        hop_length=160,
+        context_frames=50,
+        max_tracks: int | None = None,
+        hard_neg_radius: int = 0,
+        hard_neg_fraction: float = 0.5,
+    ):
+        """
+        Args:
+            hf_dataset: HuggingFace dataset object
+            target_type (str): "beats" or "downbeats". Determines which labels are treated as positive.
+            context_frames (int): Number of frames before and after the center frame.
+                Total frames = 2 * context_frames + 1.
+                Default 50 means 101 frames (~1s).
+        """
+        self.sr = sample_rate
+        self.hop_length = hop_length
+        self.target_type = target_type
+        self.hard_neg_radius = int(hard_neg_radius)
+        self.hard_neg_fraction = float(hard_neg_fraction)
+
+        self.context_frames = context_frames
+        # Context window size in samples
+        # We need enough samples for the center frame +/- context frames
+        # PLUS the window size of the largest FFT to compute the edges correctly.
+        # Largest window in MultiViewSpectrogram is 1488.
+        self.context_samples = (self.context_frames * 2 + 1) * hop_length + 1488
+
+        # Cache audio arrays in memory for fast access
+        self.audio_cache = []
+        self.indices = []
+        self._prepare_indices(hf_dataset, max_tracks=max_tracks)
+
+    def _prepare_indices(self, hf_dataset, *, max_tracks: int | None):
+        """
+        Prepares balanced indices and caches audio.
+        Uses the same "Fuzzier" training examples strategy as the baseline.
+        """
+        print(f"Preparing dataset indices for target: {self.target_type}...")
+
+        total = len(hf_dataset)
+        if max_tracks is not None:
+            total = min(total, max_tracks)
+
+        for i, item in tqdm(
+            enumerate(hf_dataset), total=total, desc="Building indices"
+        ):
+            if max_tracks is not None and i >= max_tracks:
+                break
+            # Cache audio array (convert to numpy if tensor)
+            audio = item["audio"]["array"]
+            if hasattr(audio, "numpy"):
+                audio = audio.numpy()
+            self.audio_cache.append(audio)
+
+            # Calculate total frames available in audio
+            audio_len = len(audio)
+            n_frames = int(audio_len / self.hop_length)
+
+            # Select ground truth based on target_type
+            if self.target_type == "downbeats":
+                gt_times = item["downbeats"]
+            else:
+                gt_times = item["beats"]
+
+            # Convert to list if tensor
+            if hasattr(gt_times, "tolist"):
+                gt_times = gt_times.tolist()
+
+            gt_frames = set([int(t * self.sr / self.hop_length) for t in gt_times])
+
+            # --- Positive Examples (with Fuzziness) ---
+            pos_frames = set()
+            for bf in gt_frames:
+                if 0 <= bf < n_frames:
+                    self.indices.append((i, bf, 1.0))  # Center frame
+                    pos_frames.add(bf)
+
+                # Neighbors weighted at 0.25
+                if 0 <= bf - 1 < n_frames:
+                    self.indices.append((i, bf - 1, 0.25))
+                    pos_frames.add(bf - 1)
+                if 0 <= bf + 1 < n_frames:
+                    self.indices.append((i, bf + 1, 0.25))
+                    pos_frames.add(bf + 1)
+
+            # --- Negative Examples ---
+            # Balance 2:1
+            num_pos = len(pos_frames)
+            num_neg = num_pos * 2
+
+            # (Optional) hard negatives close to beats.
+            # Rationale: random negatives are often "easy" (silence/long gaps),
+            # while the model struggles most on near-beat confusions that cause
+            # double peaks / jitter.
+            hard_neg_target = 0
+            if self.hard_neg_radius > 1 and num_neg > 0:
+                hard_neg_target = int(num_neg * self.hard_neg_fraction)
+                hard_neg_target = max(0, min(num_neg, hard_neg_target))
+
+            hard_added = 0
+            if hard_neg_target > 0:
+                for bf in gt_frames:
+                    for d in range(2, self.hard_neg_radius + 1):
+                        for f in (bf - d, bf + d):
+                            if hard_added >= hard_neg_target:
+                                break
+                            if 0 <= f < n_frames and f not in pos_frames:
+                                self.indices.append((i, f, 0.0))
+                                hard_added += 1
+                        if hard_added >= hard_neg_target:
+                            break
+                    if hard_added >= hard_neg_target:
+                        break
+
+            count = 0
+            attempts = 0
+            remaining_neg = num_neg - hard_added
+            while count < remaining_neg and attempts < remaining_neg * 5:
+                f = np.random.randint(0, n_frames)
+                if f not in pos_frames:
+                    self.indices.append((i, f, 0.0))
+                    count += 1
+                attempts += 1
+
+        print(
+            f"Dataset ready. {len(self.indices)} samples, {len(self.audio_cache)} tracks cached."
+        )
+
+    def __len__(self):
+        return len(self.indices)
+
+    def __getitem__(self, idx):
+        track_idx, frame_idx, label = self.indices[idx]
+
+        # Fast lookup from cache
+        audio = self.audio_cache[track_idx]
+        audio_len = len(audio)
+
+        # Calculate sample range for context window
+        center_sample = frame_idx * self.hop_length
+        half_context = self.context_samples // 2
+
+        # We want the window centered around center_sample
+        start = center_sample - half_context
+        end = center_sample + half_context
+
+        # Handle padding if needed
+        pad_left = max(0, -start)
+        pad_right = max(0, end - audio_len)
+
+        valid_start = max(0, start)
+        valid_end = min(audio_len, end)
+
+        # Extract audio chunk
+        chunk = audio[valid_start:valid_end]
+
+        if pad_left > 0 or pad_right > 0:
+            chunk = np.pad(chunk, (pad_left, pad_right), mode="constant")
+
+        waveform = torch.tensor(chunk, dtype=torch.float32)
+        return waveform, torch.tensor([label], dtype=torch.float32)
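The dataset above maps beat times (seconds) to spectrogram frame indices at `sr / hop_length = 100` frames per second, then assigns "fuzzy" labels: 1.0 at the beat frame and 0.25 at its immediate neighbours. A small pure-Python sketch of that mapping (simplified: here overlapping labels are merged with `max`, whereas the dataset emits separate index entries; both function names are hypothetical):

```python
def beat_frames(beat_times, sr=16000, hop_length=160):
    """Beat times in seconds -> sorted, de-duplicated frame indices (100 fps)."""
    return sorted({int(t * sr / hop_length) for t in beat_times})


def fuzzy_labels(frames, n_frames):
    """1.0 at each beat frame, 0.25 at +/- 1 frame, clipped to [0, n_frames)."""
    labels = {}
    for bf in frames:
        for f, lab in ((bf, 1.0), (bf - 1, 0.25), (bf + 1, 0.25)):
            if 0 <= f < n_frames:
                labels[f] = max(labels.get(f, 0.0), lab)
    return labels


print(beat_frames([0.5, 1.0]))  # beats at 0.5 s and 1.0 s -> frames 50 and 100
```

At 100 fps a ±1-frame fuzz corresponds to ±10 ms, which is well inside the usual 70 ms beat-tracking tolerance.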
exp/baseline3/eval.py ADDED
@@ -0,0 +1,336 @@
+import torch
+import numpy as np
+from tqdm import tqdm
+from scipy.signal import find_peaks
+import argparse
+import os
+
+from .model import ResNet
+from ..baseline1.utils import MultiViewSpectrogram
+from ..data.load import ds
+from ..data.eval import evaluate_all, format_results
+
+
+def get_activation_function(model, waveform, device):
+    """
+    Computes probability curve over time.
+    """
+    processor = MultiViewSpectrogram().to(device)
+    waveform = waveform.unsqueeze(0).to(device)
+
+    with torch.no_grad():
+        spec = processor(waveform)
+
+        # Normalize
+        mean = spec.mean(dim=(2, 3), keepdim=True)
+        std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
+        spec = (spec - mean) / std
+
+        # Batchify with sliding window
+        # Context frames = 50, so total window = 101.
+        # Pad time by 50 on each side.
+        spec = torch.nn.functional.pad(spec, (50, 50))  # Pad time
+        windows = spec.unfold(3, 101, 1)  # (1, 3, 80, Time, 101)
+        windows = windows.permute(3, 0, 1, 2, 4).squeeze(1)  # (Time, 3, 80, 101)
+
+        # Inference
+        activations = []
+        batch_size = 128  # Reduced batch size
+        for i in range(0, len(windows), batch_size):
+            batch = windows[i : i + batch_size]
+            out = model(batch)
+            activations.append(out.cpu().numpy())
+
+    return np.concatenate(activations).flatten()
+
+
+def pick_peaks(activations, hop_length=160, sr=16000):
+    """
+    Smooth with Hamming window and report local maxima.
+    """
+    # Smoothing
+    window = np.hamming(5)
+    window /= window.sum()
+    smoothed = np.convolve(activations, window, mode="same")
+
+    # Peak Picking
+    peaks, _ = find_peaks(smoothed, height=0.5, distance=5)
+
+    timestamps = peaks * hop_length / sr
+    return timestamps.tolist()
+
+
+def visualize_track(
+    audio: np.ndarray,
+    sr: int,
+    pred_beats: list[float],
+    pred_downbeats: list[float],
+    gt_beats: list[float],
+    gt_downbeats: list[float],
+    output_dir: str,
+    track_idx: int,
+    time_range: tuple[float, float] | None = None,
+):
+    """
+    Create and save visualizations for a single track.
+    """
+    from ..data.viz import plot_waveform_with_beats, save_figure
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Full waveform plot
+    fig = plot_waveform_with_beats(
+        audio,
+        sr,
+        pred_beats,
+        gt_beats,
+        pred_downbeats,
+        gt_downbeats,
+        title=f"Track {track_idx}: Beat Comparison",
+        time_range=time_range,
+    )
+    save_figure(fig, os.path.join(output_dir, f"track_{track_idx:03d}.png"))
+
+
+def synthesize_audio(
+    audio: np.ndarray,
+    sr: int,
+    pred_beats: list[float],
+    pred_downbeats: list[float],
+    gt_beats: list[float],
+    gt_downbeats: list[float],
+    output_dir: str,
+    track_idx: int,
+    click_volume: float = 0.5,
+):
+    """
+    Create and save audio files with click tracks for a single track.
+    """
+    from ..data.audio import create_comparison_audio, save_audio
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Create comparison audio
+    audio_pred, audio_gt, audio_both = create_comparison_audio(
+        audio,
+        pred_beats,
+        pred_downbeats,
+        gt_beats,
+        gt_downbeats,
+        sr=sr,
+        click_volume=click_volume,
+    )
+
+    # Save audio files
+    save_audio(
+        audio_pred, os.path.join(output_dir, f"track_{track_idx:03d}_pred.wav"), sr
+    )
+    save_audio(audio_gt, os.path.join(output_dir, f"track_{track_idx:03d}_gt.wav"), sr)
+    save_audio(
+        audio_both, os.path.join(output_dir, f"track_{track_idx:03d}_both.wav"), sr
+    )
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Evaluate beat tracking models with visualization and audio synthesis"
+    )
+    parser.add_argument(
+        "--model-dir",
+        type=str,
+        default="outputs/baseline3",
+        help="Base directory containing trained models (with 'beats' and 'downbeats' subdirs)",
+    )
+    parser.add_argument(
+        "--beats-model-dir",
+        type=str,
+        default=None,
+        help="Directory containing the trained beats model (overrides --model-dir/beats)",
+    )
+    parser.add_argument(
+        "--downbeats-model-dir",
+        type=str,
+        default=None,
+        help="Directory containing the trained downbeats model (overrides --model-dir/downbeats)",
+    )
+    parser.add_argument(
+        "--num-samples",
+        type=int,
+        default=116,
+        help="Number of samples to evaluate",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="outputs/eval_baseline3",
+        help="Directory to save visualizations and audio",
+    )
+    parser.add_argument(
+        "--visualize",
+        action="store_true",
+        help="Generate visualization plots for each track",
+    )
+    parser.add_argument(
+        "--synthesize",
+        action="store_true",
+        help="Generate audio files with click tracks",
+    )
+    parser.add_argument(
+        "--viz-tracks",
+        type=int,
+        default=5,
+        help="Number of tracks to visualize/synthesize (default: 5)",
+    )
+    parser.add_argument(
+        "--time-range",
+        type=float,
+        nargs=2,
+        default=None,
+        metavar=("START", "END"),
+        help="Time range for visualization in seconds (default: full track)",
+    )
+    parser.add_argument(
+        "--click-volume",
+        type=float,
+        default=0.5,
+        help="Volume of click sounds relative to audio (0.0 to 1.0)",
+    )
+    parser.add_argument(
+        "--summary-plot",
+        action="store_true",
+        help="Generate summary evaluation plot",
+    )
+    args = parser.parse_args()
+
+    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # Load BOTH models using from_pretrained
+    beat_model = None
+    downbeat_model = None
+
+    has_beats = False
+    has_downbeats = False
+
+    beats_dir = args.beats_model_dir or os.path.join(args.model_dir, "beats")
+    downbeats_dir = args.downbeats_model_dir or os.path.join(
+        args.model_dir, "downbeats"
+    )
+
+    if os.path.exists(os.path.join(beats_dir, "model.safetensors")):
+        beat_model = ResNet.from_pretrained(beats_dir).to(DEVICE)
+        beat_model.eval()
+        has_beats = True
+        print(f"Loaded Beat Model from {beats_dir}")
+    else:
+        print(f"Warning: No beat model found in {beats_dir}")
+
+    if os.path.exists(os.path.join(downbeats_dir, "model.safetensors")):
+        downbeat_model = ResNet.from_pretrained(downbeats_dir).to(DEVICE)
+        downbeat_model.eval()
+        has_downbeats = True
+        print(f"Loaded Downbeat Model from {downbeats_dir}")
+    else:
+        print(f"Warning: No downbeat model found in {downbeats_dir}")
+
+    if not has_beats and not has_downbeats:
+        print("No models found. Please run training first.")
+        return
+
+    predictions = []
+    ground_truths = []
+    audio_data = []  # Store audio for visualization/synthesis
+
+    # Eval on specified number of tracks
+    test_set = ds["train"].select(range(args.num_samples))
+
+    print("Running evaluation...")
+    for i, item in enumerate(tqdm(test_set)):
+        waveform = torch.tensor(item["audio"]["array"], dtype=torch.float32)
+        waveform_device = waveform.to(DEVICE)
+
+        pred_entry = {"beats": [], "downbeats": []}
+
+        # 1. Predict Beats
+        if has_beats:
+            act_b = get_activation_function(beat_model, waveform_device, DEVICE)
+            pred_entry["beats"] = pick_peaks(act_b)
+
+        # 2. Predict Downbeats
+        if has_downbeats:
+            act_d = get_activation_function(downbeat_model, waveform_device, DEVICE)
+            pred_entry["downbeats"] = pick_peaks(act_d)
+
+        predictions.append(pred_entry)
+        ground_truths.append({"beats": item["beats"], "downbeats": item["downbeats"]})
+
+        # Store audio for later visualization/synthesis
+        if args.visualize or args.synthesize:
+            if i < args.viz_tracks:
+                audio_data.append(
+                    {
+                        "audio": waveform.numpy(),
+                        "sr": item["audio"]["sampling_rate"],
+                        "pred": pred_entry,
+                        "gt": ground_truths[-1],
+                    }
+                )
+
+    # Run evaluation
+    results = evaluate_all(predictions, ground_truths)
+    print(format_results(results))
+
+    # Create output directory
+    if args.visualize or args.synthesize or args.summary_plot:
+        os.makedirs(args.output_dir, exist_ok=True)
+
+    # Generate visualizations
+    if args.visualize:
+        print(f"\nGenerating visualizations for {len(audio_data)} tracks...")
+        viz_dir = os.path.join(args.output_dir, "plots")
+        for i, data in enumerate(tqdm(audio_data, desc="Visualizing")):
+            time_range = tuple(args.time_range) if args.time_range else None
+            visualize_track(
+                data["audio"],
+                data["sr"],
+                data["pred"]["beats"],
+                data["pred"]["downbeats"],
+                data["gt"]["beats"],
+                data["gt"]["downbeats"],
+                viz_dir,
+                i,
+                time_range=time_range,
+            )
+        print(f"Saved visualizations to {viz_dir}")
+
+    # Generate audio with clicks
+    if args.synthesize:
+        print(f"\nSynthesizing audio for {len(audio_data)} tracks...")
+        audio_dir = os.path.join(args.output_dir, "audio")
+        for i, data in enumerate(tqdm(audio_data, desc="Synthesizing")):
+            synthesize_audio(
+                data["audio"],
+                data["sr"],
+                data["pred"]["beats"],
+                data["pred"]["downbeats"],
+                data["gt"]["beats"],
+                data["gt"]["downbeats"],
+                audio_dir,
+                i,
+                click_volume=args.click_volume,
+            )
+        print(f"Saved audio files to {audio_dir}")
+        print("  *_pred.wav - Original audio with predicted beat clicks")
+        print("  *_gt.wav - Original audio with ground truth beat clicks")
+        print("  *_both.wav - Original audio with both predicted and GT clicks")
+
+    # Generate summary plot
+    if args.summary_plot:
+        from ..data.viz import plot_evaluation_summary, save_figure
+
+        print("\nGenerating summary plot...")
+        fig = plot_evaluation_summary(results, title="Beat Tracking Evaluation Summary")
+        summary_path = os.path.join(args.output_dir, "evaluation_summary.png")
+        save_figure(fig, summary_path)
+        print(f"Saved summary plot to {summary_path}")
+
+
+if __name__ == "__main__":
+    main()
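`pick_peaks` above delegates to `scipy.signal.find_peaks(height=0.5, distance=5)` and converts frame indices to seconds via `hop_length / sr`. As a dependency-free illustration, here is a simplified greedy stand-in (it scans left to right and enforces the minimum distance greedily, whereas `find_peaks` resolves conflicts by peak priority, so results can differ on dense activations; `pick_peaks_simple` is a hypothetical name):

```python
def pick_peaks_simple(activations, height=0.5, distance=5, hop_length=160, sr=16000):
    """Greedy local-maximum picking over a frame-rate activation curve.
    Returns timestamps in seconds at 100 fps (hop 160 samples @ 16 kHz)."""
    peaks = []
    for i in range(1, len(activations) - 1):
        a = activations[i]
        # Local maximum above the threshold...
        if a >= height and a > activations[i - 1] and a >= activations[i + 1]:
            # ...kept only if far enough from the previously accepted peak.
            if not peaks or i - peaks[-1] >= distance:
                peaks.append(i)
    return [p * hop_length / sr for p in peaks]


acts = [0.0, 0.1, 0.9, 0.2, 0.0, 0.0, 0.1, 0.8, 0.1, 0.0]
print(pick_peaks_simple(acts))  # peaks at frames 2 and 7 -> 0.02 s and 0.07 s
```

The `distance=5` floor (50 ms at 100 fps) is what suppresses the double peaks a noisy activation curve would otherwise produce around each beat.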
exp/baseline3/model.py ADDED
@@ -0,0 +1,173 @@
+ import torch
+ import torch.nn as nn
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class SEBlock(nn.Module):
+     def __init__(self, channels: int, reduction: int = 16, use_max_pool: bool = True):
+         super().__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.max_pool = nn.AdaptiveMaxPool2d(1) if use_max_pool else None
+
+         hidden = max(1, channels // reduction)
+         self.fc = nn.Sequential(
+             nn.Linear(channels, hidden, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Linear(hidden, channels, bias=False),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         b, c, _, _ = x.size()
+         y = self.avg_pool(x).view(b, c)
+         if self.max_pool is not None:
+             y = y + self.max_pool(x).view(b, c)
+         y = self.fc(y).view(b, c, 1, 1)
+         return x * y.expand_as(x)
+
+
+ class TemporalSEBlock(nn.Module):
+     """Temporal squeeze/excitation for (B, C, F, T) feature maps.
+
+     Squeezes across frequency (mean over F) to get a per-channel temporal descriptor
+     (B, C, T), then excites with a lightweight 1D bottleneck MLP implemented with
+     pointwise Conv1d.
+     """
+
+     def __init__(self, channels: int, reduction: int = 16):
+         super().__init__()
+         hidden = max(1, channels // reduction)
+         self.net = nn.Sequential(
+             nn.Conv1d(channels, hidden, kernel_size=1, bias=False),
+             nn.ReLU(inplace=True),
+             nn.Conv1d(hidden, channels, kernel_size=1, bias=False),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, C, F, T)
+         # squeeze over frequency -> (B, C, T)
+         t = x.mean(dim=2)
+         gate = self.net(t)  # (B, C, T)
+         return x * gate.unsqueeze(2)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1, downsample=None):
+         super().__init__()
+         self.conv1 = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=3,
+             stride=stride,
+             padding=1,
+             bias=False,
+         )
+         self.bn1 = nn.BatchNorm2d(out_channels)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = nn.Conv2d(
+             out_channels, out_channels, kernel_size=3, padding=1, bias=False
+         )
+         self.bn2 = nn.BatchNorm2d(out_channels)
+         # Baseline3: combine channel SE with a lightweight temporal SE gate.
+         self.cse = SEBlock(out_channels)
+         self.tse = TemporalSEBlock(out_channels)
+         self.downsample = downsample
+
+     def forward(self, x):
+         identity = x
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.cse(out)
+         out = self.tse(out)
+
+         out += identity
+         out = self.relu(out)
+         return out
+
+
+ class ResNet(nn.Module, PyTorchModelHubMixin):
+     def __init__(
+         self, layers=[2, 2, 2, 2], channels=[16, 24, 48, 96], dropout_rate=0.5
+     ):
+         super().__init__()
+         self.in_channels = 16
+
+         # Stem
+         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(16)
+         self.relu = nn.ReLU(inplace=True)
+
+         # Stages
+         self.layer1 = self._make_layer(channels[0], layers[0], stride=1)
+         self.layer2 = self._make_layer(channels[1], layers[1], stride=2)
+         self.layer3 = self._make_layer(channels[2], layers[2], stride=2)
+         self.layer4 = self._make_layer(channels[3], layers[3], stride=2)
+
+         self.dropout = nn.Dropout(p=dropout_rate)
+
+         # Final classification head.
+         # Spatial dims shrink stage by stage; with the default channels
+         # [16, 24, 48, 96] and input (3, 80, 101):
+         # L1: (16, 80, 101) (stride 1)
+         # L2: (24, 40, 51)  (stride 2)
+         # L3: (48, 20, 26)  (stride 2)
+         # L4: (96, 10, 13)  (stride 2)
+
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(channels[3], 1)
+         self.sigmoid = nn.Sigmoid()
+
+     def _make_layer(self, out_channels, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.in_channels != out_channels:
+             downsample = nn.Sequential(
+                 nn.Conv2d(
+                     self.in_channels,
+                     out_channels,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False,
+                 ),
+                 nn.BatchNorm2d(out_channels),
+             )
+
+         layers = []
+         layers.append(ResBlock(self.in_channels, out_channels, stride, downsample))
+         self.in_channels = out_channels
+         for _ in range(1, blocks):
+             layers.append(ResBlock(self.in_channels, out_channels))
+
+         return nn.Sequential(*layers)
+
+     def forward(self, x):
+         # x: (B, 3, 80, 101)
+         x = self.conv1(x)
+         x = self.bn1(x)
+         x = self.relu(x)
+
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+
+         x = self.avgpool(x)  # (B, 96, 1, 1) with default channels
+         x = torch.flatten(x, 1)  # (B, 96)
+         x = self.dropout(x)
+         x = self.fc(x)
+         x = self.sigmoid(x)
+
+         return x
+
+
+ if __name__ == "__main__":
+     from torchinfo import summary
+
+     model = ResNet()
+     summary(model, (1, 3, 80, 101))
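The two gates above differ only in what they squeeze over: the channel SE pools the whole (F, T) plane down to a (B, C) descriptor, while the temporal SE pools over frequency only, keeping a (B, C, T) descriptor. A minimal NumPy sketch of that broadcasting (a plain sigmoid stands in for the learned bottleneck MLPs; the shapes follow the model above):

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, F, T = 2, 16, 80, 101
x = rng.standard_normal((B, C, F, T))

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Channel SE: squeeze over (F, T) -> (B, C); gate broadcast as (B, C, 1, 1)
chan_desc = x.mean(axis=(2, 3))              # (B, C)
chan_gate = sigmoid(chan_desc)               # stand-in for the FC bottleneck
y_channel = x * chan_gate[:, :, None, None]  # (B, C, F, T)

# Temporal SE: squeeze over F only -> (B, C, T); gate broadcast as (B, C, 1, T)
temp_desc = x.mean(axis=2)                   # (B, C, T)
temp_gate = sigmoid(temp_desc)               # stand-in for the Conv1d bottleneck
y_temporal = x * temp_gate[:, :, None, :]    # (B, C, F, T)

assert y_channel.shape == x.shape and y_temporal.shape == x.shape
```

Both gates preserve the feature-map shape, so they can be dropped into the residual block without changing any downstream dimensions.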
exp/baseline3/train.py ADDED
@@ -0,0 +1,433 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader
5
+ from torch.utils.tensorboard import SummaryWriter
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import os
9
+
10
+ from .model import ResNet
11
+ from .data import BeatTrackingDataset
12
+ from ..baseline1.utils import MultiViewSpectrogram
13
+ from ..data.load import ds
14
+
15
+
16
+ def weighted_bce_loss(
17
+ y_pred: torch.Tensor, y_true: torch.Tensor, pos_weight: float
18
+ ) -> torch.Tensor:
19
+ """Weighted BCE on probabilities.
20
+
21
+ This training setup outputs probabilities (sigmoid in model). To better handle
22
+ the heavy class imbalance, we upweight positive-ish labels.
23
+
24
+ Labels are in {0.0, 0.25, 1.0}. We use a smooth weighting: w = 1 + pos_weight * y.
25
+ """
26
+
27
+ bce = torch.nn.functional.binary_cross_entropy(y_pred, y_true, reduction="none")
28
+ weights = 1.0 + (pos_weight * y_true)
29
+ return (bce * weights).mean()
30
+
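To see the effect of the smooth weighting, here is a pure-Python sketch of the same objective on scalar probabilities (toy numbers, not from the experiments): with labels in {0.0, 0.25, 1.0}, `pos_weight = 0` recovers plain BCE, and any positive `pos_weight` scales up the loss contribution of positive-ish labels.

```python
import math

def weighted_bce(probs, labels, pos_weight):
    # Per-element BCE on probabilities, weighted by w = 1 + pos_weight * y.
    total = 0.0
    for p, y in zip(probs, labels):
        bce = -(y * math.log(p) + (1 - y) * math.log(1 - p))
        total += (1.0 + pos_weight * y) * bce
    return total / len(probs)

probs = [0.9, 0.6, 0.2]
labels = [1.0, 0.25, 0.0]  # beat, near-beat, background
plain = weighted_bce(probs, labels, pos_weight=0.0)    # reduces to unweighted BCE
boosted = weighted_bce(probs, labels, pos_weight=3.0)  # upweights positive-ish labels
assert boosted > plain
```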
31
+
32
+ def unweighted_bce_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
33
+ """Unweighted BCE on probabilities.
34
+
35
+ This matches baseline2's loss definition and is useful for apples-to-apples
36
+ TensorBoard comparisons and for early stopping / best checkpoint selection.
37
+ """
38
+
39
+ return torch.nn.functional.binary_cross_entropy(y_pred, y_true)
40
+
41
+
42
+ def train(
43
+ target_type: str,
44
+ output_dir: str,
45
+ *,
46
+ batch_size: int,
47
+ epochs: int,
48
+ lr: float,
49
+ weight_decay: float,
50
+ num_workers: int,
51
+ context_frames: int,
52
+ patience: int,
53
+ pos_weight: float,
54
+ grad_clip: float,
55
+ max_train_tracks: int | None,
56
+ max_val_tracks: int | None,
57
+ max_train_steps: int,
58
+ max_val_steps: int,
59
+ max_steps_total: int,
60
+ hard_neg_radius: int,
61
+ hard_neg_fraction: float,
62
+ ):
63
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
64
+
65
+ print(f"--- Training Model for target: {target_type} ---")
66
+ print(f"Output directory: {output_dir}")
67
+
68
+ # Create output directory
69
+ os.makedirs(output_dir, exist_ok=True)
70
+
71
+ # TensorBoard writer
72
+ writer = SummaryWriter(log_dir=os.path.join(output_dir, "logs"))
73
+
74
+ # Data
75
+ train_dataset = BeatTrackingDataset(
76
+ ds["train"],
77
+ target_type=target_type,
78
+ context_frames=context_frames,
79
+ max_tracks=max_train_tracks,
80
+ hard_neg_radius=hard_neg_radius,
81
+ hard_neg_fraction=hard_neg_fraction,
82
+ )
83
+ val_dataset = BeatTrackingDataset(
84
+ ds["test"],
85
+ target_type=target_type,
86
+ context_frames=context_frames,
87
+ max_tracks=max_val_tracks,
88
+ hard_neg_radius=hard_neg_radius,
89
+ hard_neg_fraction=hard_neg_fraction,
90
+ )
91
+
92
+ train_loader = DataLoader(
93
+ train_dataset,
94
+ batch_size=batch_size,
95
+ shuffle=True,
96
+ num_workers=num_workers,
97
+ pin_memory=True,
98
+ prefetch_factor=4,
99
+ persistent_workers=True,
100
+ )
101
+ val_loader = DataLoader(
102
+ val_dataset,
103
+ batch_size=batch_size,
104
+ shuffle=False,
105
+ num_workers=num_workers,
106
+ pin_memory=True,
107
+ prefetch_factor=4,
108
+ persistent_workers=True,
109
+ )
110
+
111
+ print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
112
+
113
+ # Model
114
+ model = ResNet(dropout_rate=0.5).to(DEVICE)
115
+
116
+ # GPU Spectrogram Preprocessor
117
+ preprocessor = MultiViewSpectrogram(sample_rate=16000, hop_length=160).to(DEVICE)
118
+
119
+ # Optimizer - Using AdamW for ResNet
120
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
121
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
122
+
123
+ # Match baseline2's objective by default (unweighted BCE).
124
+ criterion = nn.BCELoss()
125
+
126
+ best_val_loss = float("inf")
127
+ patience_counter = 0
128
+ global_step = 0
129
+
130
+ for epoch in range(epochs):
131
+ # Training
132
+ model.train()
133
+ total_train_loss = 0
134
+ total_train_loss_unweighted = 0
135
+ steps_this_epoch = 0
136
+ for waveform, y in tqdm(
137
+ train_loader,
138
+ desc=f"[{target_type}] Epoch {epoch + 1}/{epochs} Train",
139
+ leave=False,
140
+ ):
141
+ if max_steps_total > 0 and global_step >= max_steps_total:
142
+ break
143
+ waveform, y = waveform.to(DEVICE), y.to(DEVICE)
144
+
145
+ # Compute spectrogram on GPU
146
+ with torch.no_grad():
147
+ spec = preprocessor(waveform) # (B, 3, 80, T_raw)
148
+ # Normalize
149
+ mean = spec.mean(dim=(2, 3), keepdim=True)
150
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
151
+ spec = (spec - mean) / std
152
+
153
+ T_curr = spec.shape[-1]
154
+ target_T = context_frames * 2 + 1
155
+
156
+ if T_curr > target_T:
157
+ start = (T_curr - target_T) // 2
158
+ x = spec[:, :, :, start : start + target_T]
159
+ elif T_curr < target_T:
160
+ # This shouldn't happen if dataset is correct, but just in case pad
161
+ pad = target_T - T_curr
162
+ x = torch.nn.functional.pad(spec, (0, pad))
163
+ else:
164
+ x = spec
165
+
166
+ optimizer.zero_grad()
167
+ output = model(x)
168
+
169
+ loss_unweighted = criterion(output, y)
170
+ loss = (
171
+ weighted_bce_loss(output, y, pos_weight=pos_weight)
172
+ if pos_weight > 0
173
+ else loss_unweighted
174
+ )
175
+ loss.backward()
176
+
177
+ if grad_clip > 0:
178
+ torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
179
+ optimizer.step()
180
+
181
+ total_train_loss += loss.item()
182
+ total_train_loss_unweighted += loss_unweighted.item()
183
+ global_step += 1
184
+ steps_this_epoch += 1
185
+
186
+ # Log losses
187
+ # - train/batch_loss matches baseline2 (unweighted)
188
+ # - train/batch_loss_weighted is the optimized objective (if pos_weight>0)
189
+ writer.add_scalar("train/batch_loss", loss_unweighted.item(), global_step)
190
+ writer.add_scalar("train/batch_loss_weighted", loss.item(), global_step)
191
+
192
+ if max_train_steps > 0 and steps_this_epoch >= max_train_steps:
193
+ break
194
+
195
+ if steps_this_epoch == 0:
196
+ print("No training steps executed (max_steps reached).")
197
+ break
198
+
199
+ avg_train_loss = total_train_loss / steps_this_epoch
200
+ avg_train_loss_unweighted = total_train_loss_unweighted / steps_this_epoch
201
+
202
+ # Validation
203
+ model.eval()
204
+ total_val_loss = 0
205
+ total_val_loss_unweighted = 0
206
+ val_steps = 0
207
+ with torch.no_grad():
208
+ for waveform, y in tqdm(
209
+ val_loader,
210
+ desc=f"[{target_type}] Epoch {epoch + 1}/{epochs} Val",
211
+ leave=False,
212
+ ):
213
+ if max_steps_total > 0 and global_step >= max_steps_total:
214
+ break
215
+ waveform, y = waveform.to(DEVICE), y.to(DEVICE)
216
+
217
+ # Compute spectrogram on GPU
218
+ spec = preprocessor(waveform) # (B, 3, 80, T)
219
+ # Normalize
220
+ mean = spec.mean(dim=(2, 3), keepdim=True)
221
+ std = spec.std(dim=(2, 3), keepdim=True) + 1e-6
222
+ spec = (spec - mean) / std
223
+
224
+ T_curr = spec.shape[-1]
225
+ target_T = context_frames * 2 + 1
226
+
227
+ if T_curr > target_T:
228
+ start = (T_curr - target_T) // 2
229
+ x = spec[:, :, :, start : start + target_T]
230
+ else:
231
+ pad = target_T - T_curr
232
+ x = torch.nn.functional.pad(spec, (0, pad))
233
+
234
+ output = model(x)
235
+ loss_unweighted = criterion(output, y)
236
+ loss = (
237
+ weighted_bce_loss(output, y, pos_weight=pos_weight)
238
+ if pos_weight > 0
239
+ else loss_unweighted
240
+ )
241
+ total_val_loss_unweighted += loss_unweighted.item()
242
+ total_val_loss += loss.item()
243
+ val_steps += 1
244
+
245
+ if max_val_steps > 0 and val_steps >= max_val_steps:
246
+ break
247
+
248
+ if val_steps == 0:
249
+ print("No validation steps executed (max_steps reached).")
250
+ break
251
+
252
+ avg_val_loss = total_val_loss / val_steps
253
+ avg_val_loss_unweighted = total_val_loss_unweighted / val_steps
254
+
255
+ # Log epoch metrics
256
+ writer.add_scalar("train/epoch_loss", avg_train_loss_unweighted, epoch)
257
+ writer.add_scalar("train/epoch_loss_weighted", avg_train_loss, epoch)
258
+ writer.add_scalar("val/loss", avg_val_loss_unweighted, epoch)
259
+ writer.add_scalar("val/loss_weighted", avg_val_loss, epoch)
260
+ writer.add_scalar("train/learning_rate", scheduler.get_last_lr()[0], epoch)
261
+
262
+ # Step the scheduler
263
+ scheduler.step()
264
+
265
+ print(
266
+ f"[{target_type}] Epoch {epoch + 1}/{epochs} - "
267
+ f"Train Loss: {avg_train_loss_unweighted:.4f}, Val Loss: {avg_val_loss_unweighted:.4f}"
268
+ )
269
+
270
+ # Save best model
271
+ if avg_val_loss_unweighted < best_val_loss:
272
+ best_val_loss = avg_val_loss_unweighted
273
+ patience_counter = 0
274
+ model.save_pretrained(output_dir)
275
+ print(f" -> Saved best model (val_loss: {best_val_loss:.4f})")
276
+ else:
277
+ patience_counter += 1
278
+ print(f" -> No improvement (patience: {patience_counter}/{patience})")
279
+
280
+ if patience_counter >= patience:
281
+ print("Early stopping triggered.")
282
+ break
283
+
284
+ if max_steps_total > 0 and global_step >= max_steps_total:
285
+ print("Reached max_steps_total; stopping training.")
286
+ break
287
+
288
+ writer.close()
289
+
290
+ # Save final model
291
+ final_dir = os.path.join(output_dir, "final")
292
+ model.save_pretrained(final_dir)
293
+ print(f"Saved final model to {final_dir}")
294
+
295
+
296
+ if __name__ == "__main__":
297
+ parser = argparse.ArgumentParser()
298
+ parser.add_argument(
299
+ "--target",
300
+ type=str,
301
+ choices=["beats", "downbeats"],
302
+ default=None,
303
+ help="Train a model for 'beats' or 'downbeats'. If not specified, trains both.",
304
+ )
305
+ parser.add_argument(
306
+ "--output-dir",
307
+ type=str,
308
+ default="outputs/baseline3",
309
+ help="Directory to save model and logs",
310
+ )
311
+ parser.add_argument(
312
+ "--batch-size",
313
+ type=int,
314
+ default=128,
315
+ help="Batch size (default: 128)",
316
+ )
317
+ parser.add_argument(
318
+ "--epochs",
319
+ type=int,
320
+ default=3,
321
+ help="Max epochs (default: 3; early stopping may stop sooner)",
322
+ )
323
+ parser.add_argument(
324
+ "--lr",
325
+ type=float,
326
+ default=0.001,
327
+ help="AdamW learning rate (default: 0.001)",
328
+ )
329
+ parser.add_argument(
330
+ "--weight-decay",
331
+ type=float,
332
+ default=1e-4,
333
+ help="AdamW weight decay (default: 1e-4)",
334
+ )
335
+ parser.add_argument(
336
+ "--num-workers",
337
+ type=int,
338
+ default=4,
339
+ help="DataLoader workers (default: 4)",
340
+ )
341
+ parser.add_argument(
342
+ "--max-train-tracks",
343
+ type=int,
344
+ default=0,
345
+ help="Limit train split to first N tracks (default: 0 = all)",
346
+ )
347
+ parser.add_argument(
348
+ "--max-val-tracks",
349
+ type=int,
350
+ default=0,
351
+ help="Limit val split to first N tracks (default: 0 = all)",
352
+ )
353
+ parser.add_argument(
354
+ "--max-train-steps",
355
+ type=int,
356
+ default=0,
357
+ help="Max train batches per epoch (default: 0 = all)",
358
+ )
359
+ parser.add_argument(
360
+ "--max-val-steps",
361
+ type=int,
362
+ default=0,
363
+ help="Max val batches per epoch (default: 0 = all)",
364
+ )
365
+ parser.add_argument(
366
+ "--max-steps-total",
367
+ type=int,
368
+ default=0,
369
+ help="Stop training after N total train batches (default: 0 = unlimited)",
370
+ )
371
+ parser.add_argument(
372
+ "--hard-neg-radius",
373
+ type=int,
374
+ default=0,
375
+ help="Add hard negatives at +/-d frames from each beat, for each d in 2..R (default: 0 = off)",
376
+ )
377
+ parser.add_argument(
378
+ "--hard-neg-fraction",
379
+ type=float,
380
+ default=0.5,
381
+ help="Fraction of negatives reserved for hard negatives (default: 0.5)",
382
+ )
383
+ parser.add_argument(
384
+ "--context-frames",
385
+ type=int,
386
+ default=50,
387
+ help="Context frames on each side (default: 50 -> 101 total frames)",
388
+ )
389
+ parser.add_argument(
390
+ "--patience",
391
+ type=int,
392
+ default=5,
393
+ help="Early stopping patience (default: 5)",
394
+ )
395
+ parser.add_argument(
396
+ "--pos-weight",
397
+ type=float,
398
+ default=0.0,
399
+ help="Positive label upweight factor (default: 0.0; 0 matches baseline2)",
400
+ )
401
+ parser.add_argument(
402
+ "--grad-clip",
403
+ type=float,
404
+ default=0.0,
405
+ help="Clip gradient norm; set 0 to disable (default: 0.0)",
406
+ )
407
+ args = parser.parse_args()
408
+
409
+ # Determine which targets to train
410
+ targets = [args.target] if args.target else ["beats", "downbeats"]
411
+
412
+ for target in targets:
413
+ output_dir = os.path.join(args.output_dir, target)
414
+ train(
415
+ target,
416
+ output_dir,
417
+ batch_size=args.batch_size,
418
+ epochs=args.epochs,
419
+ lr=args.lr,
420
+ weight_decay=args.weight_decay,
421
+ num_workers=args.num_workers,
422
+ context_frames=args.context_frames,
423
+ patience=args.patience,
424
+ pos_weight=args.pos_weight,
425
+ grad_clip=args.grad_clip,
426
+ max_train_tracks=(args.max_train_tracks or None),
427
+ max_val_tracks=(args.max_val_tracks or None),
428
+ max_train_steps=args.max_train_steps,
429
+ max_val_steps=args.max_val_steps,
430
+ max_steps_total=args.max_steps_total,
431
+ hard_neg_radius=args.hard_neg_radius,
432
+ hard_neg_fraction=args.hard_neg_fraction,
433
+ )
exp/data/__init__.py ADDED
@@ -0,0 +1,25 @@
+ """
+ Data loading and evaluation utilities for beat tracking.
+
+ Modules:
+ - load: Dataset loading and preprocessing
+ - eval: Evaluation metrics and utilities
+ """
+
+ from exp.data.eval import (
+     DEFAULT_THRESHOLDS_MS,
+     evaluate_beats,
+     evaluate_track,
+     evaluate_all,
+     compute_weighted_f1,
+     format_results,
+ )
+
+ __all__ = [
+     "DEFAULT_THRESHOLDS_MS",
+     "evaluate_beats",
+     "evaluate_track",
+     "evaluate_all",
+     "compute_weighted_f1",
+     "format_results",
+ ]
exp/data/audio.py ADDED
@@ -0,0 +1,301 @@
1
+ """
2
+ Audio synthesis utilities for beat tracking evaluation.
3
+
4
+ This module provides functions to:
5
+ - Generate click sounds for beats and downbeats
6
+ - Mix click tracks with original audio
7
+ - Save audio files with beat annotations
8
+
9
+ Example usage:
10
+ from exp.data.audio import create_click_track, mix_audio, save_audio
11
+
12
+ # Create click track
13
+ clicks = create_click_track(
14
+ beat_times=pred_beats,
15
+ downbeat_times=pred_downbeats,
16
+ duration=30.0,
17
+ sr=16000
18
+ )
19
+
20
+ # Mix with original audio
21
+ mixed = mix_audio(original_audio, clicks, click_volume=0.5)
22
+
23
+ # Save to file
24
+ save_audio(mixed, "output.wav", sr=16000)
25
+ """
26
+
27
+ import numpy as np
28
+ from pathlib import Path
29
+
30
+
31
+ def generate_click(
32
+ frequency: float = 1000.0,
33
+ duration: float = 0.02,
34
+ sr: int = 16000,
35
+ attack: float = 0.002,
36
+ decay: float = 0.018,
37
+ ) -> np.ndarray:
38
+ """
39
+ Generate a single click sound.
40
+
41
+ Args:
42
+ frequency: Frequency of the click tone in Hz
43
+ duration: Duration of the click in seconds
44
+ sr: Sample rate
45
+ attack: Attack time in seconds
46
+ decay: Decay time in seconds
47
+
48
+ Returns:
49
+ Click waveform as numpy array
50
+ """
51
+ t = np.arange(int(duration * sr)) / sr
52
+
53
+ # Generate sine wave
54
+ wave = np.sin(2 * np.pi * frequency * t)
55
+
56
+ # Apply envelope (attack-decay)
57
+ envelope = np.ones_like(t)
58
+ attack_samples = int(attack * sr)
59
+ decay_samples = int(decay * sr)
60
+
61
+ if attack_samples > 0:
62
+ envelope[:attack_samples] = np.linspace(0, 1, attack_samples)
63
+ if decay_samples > 0:
64
+ decay_start = len(t) - decay_samples
65
+ if decay_start > 0:
66
+ envelope[decay_start:] = np.linspace(1, 0, decay_samples)
67
+
68
+ return wave * envelope
69
+
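For intuition, the envelope applied above is just a linear attack ramp into a linear decay ramp. A small NumPy sketch with the default timings (20 ms click, 2 ms attack, 18 ms decay at 16 kHz):

```python
import numpy as np

# Attack/decay envelope with generate_click's default parameters.
sr, duration, attack, decay = 16000, 0.02, 0.002, 0.018
n = int(duration * sr)  # 320 samples total

env = np.ones(n)
a = int(attack * sr)    # 32-sample linear ramp 0 -> 1
d = int(decay * sr)     # 288-sample linear ramp 1 -> 0
env[:a] = np.linspace(0, 1, a)
env[n - d:] = np.linspace(1, 0, d)

# The click fades in, peaks at full amplitude, and fades out to silence.
assert env[0] == 0.0 and env[-1] == 0.0 and env.max() == 1.0
```

Multiplying this envelope into the sine burst is what keeps each click free of audible onset/offset pops.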
70
+
71
+ def create_click_track(
72
+ beat_times: list[float] | np.ndarray,
73
+ downbeat_times: list[float] | np.ndarray | None = None,
74
+ duration: float | None = None,
75
+ sr: int = 16000,
76
+ beat_freq: float = 1000.0,
77
+ downbeat_freq: float = 1500.0,
78
+ click_duration: float = 0.03,
79
+ ) -> np.ndarray:
80
+ """
81
+ Create a click track from beat and downbeat times.
82
+
83
+ Args:
84
+ beat_times: List of beat times in seconds
85
+ downbeat_times: List of downbeat times in seconds (optional)
86
+ duration: Total duration in seconds (auto-detected if None)
87
+ sr: Sample rate
88
+ beat_freq: Frequency for beat clicks (Hz)
89
+ downbeat_freq: Frequency for downbeat clicks (Hz)
90
+ click_duration: Duration of each click in seconds
91
+
92
+ Returns:
93
+ Click track as numpy array
94
+ """
95
+ beat_times = np.array(beat_times) if len(beat_times) > 0 else np.array([])
96
+ if downbeat_times is not None:
97
+ downbeat_times = (
98
+ np.array(downbeat_times) if len(downbeat_times) > 0 else np.array([])
99
+ )
100
+ else:
101
+ downbeat_times = np.array([])
102
+
103
+ # Determine duration
104
+ if duration is None:
105
+ all_times = np.concatenate([beat_times, downbeat_times])
106
+ if len(all_times) == 0:
107
+ return np.array([])
108
+ duration = float(np.max(all_times)) + 1.0
109
+
110
+ # Create output array
111
+ total_samples = int(duration * sr)
112
+ output = np.zeros(total_samples, dtype=np.float32)
113
+
114
+ # Generate click templates
115
+ beat_click = generate_click(frequency=beat_freq, duration=click_duration, sr=sr)
116
+ downbeat_click = generate_click(
117
+ frequency=downbeat_freq, duration=click_duration, sr=sr
118
+ )
119
+
120
+ # Convert downbeat times to set for fast lookup
121
+ downbeat_set = set(np.round(downbeat_times, 3))
122
+
123
+ # Add beat clicks
124
+ for t in beat_times:
125
+ sample_idx = int(t * sr)
126
+ if sample_idx < 0 or sample_idx >= total_samples:
127
+ continue
128
+
129
+ # Use downbeat click if this is also a downbeat
130
+ is_downbeat = np.round(t, 3) in downbeat_set
131
+ click = downbeat_click if is_downbeat else beat_click
132
+
133
+ # Add click to output
134
+ end_idx = min(sample_idx + len(click), total_samples)
135
+ click_len = end_idx - sample_idx
136
+ output[sample_idx:end_idx] += click[:click_len]
137
+
138
+ # Add downbeat clicks (for downbeats not already in beats)
139
+ beat_set = set(np.round(beat_times, 3))
140
+ for t in downbeat_times:
141
+ if np.round(t, 3) in beat_set:
142
+ continue # Already added as beat
143
+
144
+ sample_idx = int(t * sr)
145
+ if sample_idx < 0 or sample_idx >= total_samples:
146
+ continue
147
+
148
+ end_idx = min(sample_idx + len(downbeat_click), total_samples)
149
+ click_len = end_idx - sample_idx
150
+ output[sample_idx:end_idx] += downbeat_click[:click_len]
151
+
152
+ return output
153
+
154
+
155
+ def mix_audio(
156
+ audio: np.ndarray,
157
+ click_track: np.ndarray,
158
+ click_volume: float = 0.5,
159
+ ) -> np.ndarray:
160
+ """
161
+ Mix original audio with a click track.
162
+
163
+ Args:
164
+ audio: Original audio waveform
165
+ click_track: Click track to overlay
166
+ click_volume: Volume of clicks relative to audio (0.0 to 1.0)
167
+
168
+ Returns:
169
+ Mixed audio
170
+ """
171
+ # Ensure same length
172
+ max_len = max(len(audio), len(click_track))
173
+ audio_padded = np.zeros(max_len, dtype=np.float32)
174
+ click_padded = np.zeros(max_len, dtype=np.float32)
175
+
176
+ audio_padded[: len(audio)] = audio
177
+ click_padded[: len(click_track)] = click_track
178
+
179
+ # Normalize audio
180
+ audio_max = np.abs(audio_padded).max()
181
+ if audio_max > 0:
182
+ audio_padded = audio_padded / audio_max * 0.8
183
+
184
+ # Normalize clicks
185
+ click_max = np.abs(click_padded).max()
186
+ if click_max > 0:
187
+ click_padded = click_padded / click_max * click_volume * 0.8
188
+
189
+ # Mix
190
+ mixed = audio_padded + click_padded
191
+
192
+ # Prevent clipping
193
+ max_val = np.abs(mixed).max()
194
+ if max_val > 1.0:
195
+ mixed = mixed / max_val * 0.95
196
+
197
+ return mixed.astype(np.float32)
198
+
199
+
200
+ def create_comparison_audio(
201
+ audio: np.ndarray,
202
+ pred_beats: list[float],
203
+ pred_downbeats: list[float],
204
+ gt_beats: list[float],
205
+ gt_downbeats: list[float],
206
+ sr: int = 16000,
207
+ click_volume: float = 0.5,
208
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
209
+ """
210
+ Create audio files for comparison: prediction clicks, ground truth clicks, and combined.
211
+
212
+ Args:
213
+ audio: Original audio waveform
214
+ pred_beats: Predicted beat times
215
+ pred_downbeats: Predicted downbeat times
216
+ gt_beats: Ground truth beat times
217
+ gt_downbeats: Ground truth downbeat times
218
+ sr: Sample rate
219
+ click_volume: Volume of clicks
220
+
221
+ Returns:
222
+ Tuple of (audio_with_pred_clicks, audio_with_gt_clicks, audio_with_both)
223
+ """
224
+ duration = len(audio) / sr
225
+
226
+ # Create click tracks
227
+ pred_clicks = create_click_track(
228
+ pred_beats,
229
+ pred_downbeats,
230
+ duration=duration,
231
+ sr=sr,
232
+ beat_freq=1000.0,
233
+ downbeat_freq=1500.0,
234
+ )
235
+
236
+ gt_clicks = create_click_track(
237
+ gt_beats,
238
+ gt_downbeats,
239
+ duration=duration,
240
+ sr=sr,
241
+ beat_freq=800.0, # Different frequency for GT
242
+ downbeat_freq=1200.0,
243
+ )
244
+
245
+ # Mix
246
+ audio_pred = mix_audio(audio, pred_clicks, click_volume)
247
+ audio_gt = mix_audio(audio, gt_clicks, click_volume)
248
+ audio_both = mix_audio(audio, pred_clicks + gt_clicks, click_volume)
249
+
250
+ return audio_pred, audio_gt, audio_both
251
+
252
+
253
+ def save_audio(
254
+ audio: np.ndarray,
255
+ path: str | Path,
256
+ sr: int = 16000,
257
+ ) -> None:
258
+ """
259
+ Save audio to a WAV file.
260
+
261
+ Args:
262
+ audio: Audio waveform
263
+ path: Output file path
264
+ sr: Sample rate
265
+ """
266
+ import scipy.io.wavfile as wavfile
267
+
268
+ path = Path(path)
269
+ path.parent.mkdir(parents=True, exist_ok=True)
270
+
271
+ # Convert to int16
272
+ audio_int16 = (audio * 32767).astype(np.int16)
273
+ wavfile.write(str(path), sr, audio_int16)
274
+
275
+
276
+ if __name__ == "__main__":
277
+ # Demo
278
+ print("Audio synthesis demo...")
279
+
280
+ # Create a simple sine wave as "music"
281
+ sr = 16000
282
+ duration = 10.0
283
+ t = np.arange(int(duration * sr)) / sr
284
+ music = np.sin(2 * np.pi * 220 * t) * 0.3 # 220 Hz tone
285
+
286
+ # Beats every 0.5s, downbeats every 2s
287
+ beats = np.arange(0, duration, 0.5).tolist()
288
+ downbeats = np.arange(0, duration, 2.0).tolist()
289
+
290
+ # Create click track
291
+ clicks = create_click_track(beats, downbeats, duration=duration, sr=sr)
292
+
293
+ # Mix
294
+ mixed = mix_audio(music, clicks, click_volume=0.6)
295
+
296
+ print(f"Created mixed audio: {len(mixed)} samples ({len(mixed) / sr:.2f}s)")
297
+ print(f"Beats: {len(beats)}, Downbeats: {len(downbeats)}")
298
+
299
+ # Save demo
300
+ save_audio(mixed, "/tmp/beat_click_demo.wav", sr=sr)
301
+ print("Saved demo to /tmp/beat_click_demo.wav")
exp/data/eval.py ADDED
@@ -0,0 +1,568 @@
1
+ """
2
+ Evaluation utilities for beat and downbeat detection.
3
+
4
+ This module provides functions to evaluate beat/downbeat predictions against
5
+ ground truth annotations using F1-scores at various timing thresholds and
6
+ continuity-based metrics (CMLt, AMLt).
7
+
8
+ The evaluation metrics include:
9
+ - **F1-scores**: Calculated for timing thresholds from 3ms to 30ms
10
+ - **Weighted F1**: Weights are inversely proportional to threshold (e.g., 3ms: 1, 6ms: 1/2)
11
+ - **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level
12
+ - **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations
13
+ (double/half tempo, off-beat, etc.)
14
+ - **CMLc/AMLc**: Continuous versions (longest correct segment)
15
+
16
+ Example usage:
17
+ from ..data.eval import (
18
+ evaluate_beats, evaluate_all, compute_weighted_f1,
19
+ compute_continuity_metrics, format_results
20
+ )
21
+
22
+ # Evaluate single track
23
+ results = evaluate_beats(pred_beats, gt_beats)
24
+ print(f"Weighted F1: {results['weighted_f1']:.4f}")
25
+ print(f"CMLt: {results['continuity']['CMLt']:.4f}")
26
+ print(f"AMLt: {results['continuity']['AMLt']:.4f}")
27
+
28
+ # Evaluate with custom thresholds
29
+ results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20])
30
+
31
+ # Evaluate all tracks in dataset
32
+ summary = evaluate_all(predictions, ground_truths)
33
+ print(format_results(summary))
34
+ """
35
+
36
+ from typing import Sequence
37
+ import numpy as np
38
+ import mir_eval
39
+
40
+
41
+ # Default timing thresholds in milliseconds (3ms to 30ms, step 3ms)
42
+ DEFAULT_THRESHOLDS_MS = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]
43
+
44
+ # Default minimum beat time for mir_eval metrics (can be set to 0 to use all beats)
45
+ DEFAULT_MIN_BEAT_TIME = 5.0
46
+
47
+
48
+ def match_events(
49
+ pred: np.ndarray,
50
+ gt: np.ndarray,
51
+ tolerance_sec: float,
52
+ ) -> tuple[int, int, int]:
53
+ """
54
+ Match predicted events to ground truth events within a tolerance.
55
+
56
+ Uses greedy matching: each ground truth event is matched to the closest
57
+ unmatched prediction within the tolerance window.
58
+
59
+ Args:
60
+ pred: Predicted event times in seconds, shape (N,)
61
+ gt: Ground truth event times in seconds, shape (M,)
62
+ tolerance_sec: Maximum time difference for a match (in seconds)
63
+
64
+ Returns:
65
+ Tuple of (true_positives, false_positives, false_negatives)
66
+ """
67
+ if len(gt) == 0:
68
+ return 0, len(pred), 0
69
+ if len(pred) == 0:
70
+ return 0, 0, len(gt)
71
+
72
+ pred = np.sort(pred)
73
+ gt = np.sort(gt)
74
+
75
+ matched_pred = np.zeros(len(pred), dtype=bool)
76
+ matched_gt = np.zeros(len(gt), dtype=bool)
77
+
78
+ # For each ground truth, find the closest unmatched prediction
79
+ for i, gt_time in enumerate(gt):
80
+ # Find predictions within tolerance
81
+ diffs = np.abs(pred - gt_time)
82
+ candidates = np.where((diffs <= tolerance_sec) & ~matched_pred)[0]
83
+
84
+ if len(candidates) > 0:
85
+ # Match to closest candidate
86
+ best_idx = candidates[np.argmin(diffs[candidates])]
87
+ matched_pred[best_idx] = True
88
+ matched_gt[i] = True
89
+
90
+ tp = int(matched_gt.sum())
+ fp = len(pred) - tp
+ fn = len(gt) - tp
96
+
97
+ return tp, fp, fn
98
+
99
+
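The greedy matching above can be illustrated with a small self-contained sketch (a standalone reimplementation of the same logic for demonstration, not the function itself):

```python
import numpy as np

def greedy_match(pred, gt, tol):
    # Mirror of match_events: each GT event claims its closest unmatched prediction
    pred, gt = np.sort(pred), np.sort(gt)
    used = np.zeros(len(pred), dtype=bool)
    tp = 0
    for g in gt:
        d = np.abs(pred - g)
        cand = np.where((d <= tol) & ~used)[0]
        if len(cand) > 0:
            used[cand[np.argmin(d[cand])]] = True
            tp += 1
    return tp, len(pred) - tp, len(gt) - tp

# GT beats at 1.0, 2.0, 3.0 s; one prediction drifts by 100 ms, one is spurious
tp, fp, fn = greedy_match(
    np.array([1.004, 2.1, 2.995, 5.0]), np.array([1.0, 2.0, 3.0]), tol=0.010
)
print(tp, fp, fn)  # 2 2 1
```

With a 10 ms tolerance only the 1.004 s and 2.995 s predictions match; the 2.1 s prediction counts as both a false positive and (for the missed 2.0 s beat) a false negative.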
100
+ def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
101
+ """
102
+ Compute precision, recall, and F1-score from TP, FP, FN counts.
103
+
104
+ Args:
105
+ tp: True positives
106
+ fp: False positives
107
+ fn: False negatives
108
+
109
+ Returns:
110
+ Tuple of (precision, recall, f1_score)
111
+ """
112
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
113
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
114
+ f1 = (
115
+ 2 * precision * recall / (precision + recall)
116
+ if (precision + recall) > 0
117
+ else 0.0
118
+ )
119
+ return precision, recall, f1
120
+
121
+
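For instance, with the counts from a noisy tracker (2 hits, 2 spurious predictions, 1 missed beat), the formulas above give:

```python
tp, fp, fn = 2, 2, 1
precision = tp / (tp + fp)  # 0.5
recall = tp / (tp + fn)     # 2/3
f1 = 2 * precision * recall / (precision + recall)  # 4/7
print(round(f1, 4))  # 0.5714
```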
122
+ def compute_weighted_f1(
123
+ f1_scores: dict[int, float],
124
+ thresholds_ms: Sequence[int] | None = None,
125
+ ) -> float:
126
+ """
127
+ Compute weighted F1-score where weights are inversely proportional to threshold.
128
+
129
+ The weight for threshold T ms is 1 / (T / min_threshold).
130
+ For example, with thresholds [3, 6, 9, ...]:
131
+ - 3ms: weight = 1
132
+ - 6ms: weight = 0.5
133
+ - 9ms: weight = 0.333...
134
+
135
+ Args:
136
+ f1_scores: Dict mapping threshold (ms) to F1-score
137
+ thresholds_ms: List of thresholds used (for weight calculation)
138
+
139
+ Returns:
140
+ Weighted F1-score
141
+ """
142
+ if not f1_scores:
143
+ return 0.0
144
+
145
+ if thresholds_ms is None:
146
+ thresholds_ms = sorted(f1_scores.keys())
147
+
148
+ min_threshold = min(thresholds_ms)
149
+ total_weight = 0.0
150
+ weighted_sum = 0.0
151
+
152
+ for t in thresholds_ms:
153
+ if t in f1_scores:
154
+ weight = min_threshold / t # 3ms -> 1, 6ms -> 0.5, etc.
155
+ weighted_sum += weight * f1_scores[t]
156
+ total_weight += weight
157
+
158
+ return weighted_sum / total_weight if total_weight > 0 else 0.0
159
+
160
+
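The weighting can be checked by hand; with thresholds [3, 6, 9] ms the weights are 1, 1/2, and 1/3 (hypothetical F1 values):

```python
f1_by_threshold = {3: 0.60, 6: 0.80, 9: 0.90}
weights = {t: 3 / t for t in f1_by_threshold}  # min_threshold / t
weighted = sum(weights[t] * f1_by_threshold[t] for t in weights) / sum(weights.values())
print(round(weighted, 4))  # 0.7091
```

The weighted sum is 0.6 + 0.4 + 0.3 = 1.3 and the total weight is 11/6, so the tighter 3 ms threshold dominates the score.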
161
+ def compute_continuity_metrics(
162
+ pred_times: Sequence[float],
163
+ gt_times: Sequence[float],
164
+ min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
165
+ phase_threshold: float = 0.175,
166
+ period_threshold: float = 0.175,
167
+ ) -> dict:
168
+ """
169
+ Compute continuity-based beat tracking metrics using mir_eval.
170
+
171
+ These metrics evaluate beat tracking accuracy accounting for metrical level:
172
+ - CMLt (Correct Metric Level Total): Accuracy at the correct metrical level
173
+ - AMLt (Any Metric Level Total): Accuracy allowing for metrical variations
174
+ (double/half tempo, off-beat, etc.)
175
+ - CMLc/AMLc: Continuous versions (longest correct segment)
176
+
177
+ Args:
178
+ pred_times: Predicted beat times in seconds
179
+ gt_times: Ground truth beat times in seconds
180
+ min_beat_time: Minimum time to start evaluation (default: 5.0s)
181
+ Set to 0.0 to use all beats, but note that early beats
182
+ may not have stable inter-beat intervals.
183
+ phase_threshold: Maximum phase error as ratio of beat interval (default: 0.175)
184
+ period_threshold: Maximum period error as ratio of beat interval (default: 0.175)
185
+
186
+ Returns:
187
+ Dict containing:
188
+ - 'CMLc': Correct Metric Level Continuous
189
+ - 'CMLt': Correct Metric Level Total
190
+ - 'AMLc': Any Metric Level Continuous
191
+ - 'AMLt': Any Metric Level Total
192
+ """
193
+ pred_arr = np.sort(np.array(pred_times, dtype=np.float64))
194
+ gt_arr = np.sort(np.array(gt_times, dtype=np.float64))
195
+
196
+ # Trim beats before min_beat_time (standard preprocessing)
197
+ pred_trimmed = mir_eval.beat.trim_beats(pred_arr, min_beat_time=min_beat_time)
198
+ gt_trimmed = mir_eval.beat.trim_beats(gt_arr, min_beat_time=min_beat_time)
199
+
200
+ # Handle edge cases where trimming results in too few beats
201
+ if len(gt_trimmed) < 2 or len(pred_trimmed) < 2:
202
+ return {
203
+ "CMLc": 0.0,
204
+ "CMLt": 0.0,
205
+ "AMLc": 0.0,
206
+ "AMLt": 0.0,
207
+ }
208
+
209
+ # Compute continuity metrics
210
+ CMLc, CMLt, AMLc, AMLt = mir_eval.beat.continuity(
211
+ gt_trimmed,
212
+ pred_trimmed,
213
+ continuity_phase_threshold=phase_threshold,
214
+ continuity_period_threshold=period_threshold,
215
+ )
216
+
217
+ return {
218
+ "CMLc": float(CMLc),
219
+ "CMLt": float(CMLt),
220
+ "AMLc": float(AMLc),
221
+ "AMLt": float(AMLt),
222
+ }
223
+
224
+
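The min_beat_time preprocessing amounts to dropping early events; a numpy-only sketch of the trim step (the actual mir_eval.beat.trim_beats call should be preferred):

```python
import numpy as np

beats = np.arange(0, 10, 0.5)            # a beat every 0.5 s for 10 s
min_beat_time = 5.0
trimmed = beats[beats >= min_beat_time]  # keep beats from 5.0 s onward
print(len(beats), len(trimmed))  # 20 10
```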
225
+ def evaluate_beats(
226
+ pred_times: Sequence[float],
227
+ gt_times: Sequence[float],
228
+ thresholds_ms: Sequence[int] | None = None,
229
+ min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
230
+ ) -> dict:
231
+ """
232
+ Evaluate beat predictions against ground truth at multiple thresholds.
233
+
234
+ Args:
235
+ pred_times: Predicted beat times in seconds
236
+ gt_times: Ground truth beat times in seconds
237
+ thresholds_ms: Timing thresholds in milliseconds (default: 3ms to 30ms)
238
+ min_beat_time: Minimum time for continuity metrics (default: 5.0s)
239
+
240
+ Returns:
241
+ Dict containing:
242
+ - 'per_threshold': Dict[threshold_ms, {'precision', 'recall', 'f1'}]
243
+ - 'f1_scores': Dict[threshold_ms, f1_score] (convenience access)
244
+ - 'weighted_f1': Weighted F1-score across all thresholds
245
+ - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
246
+ - 'num_predictions': Number of predictions
247
+ - 'num_ground_truth': Number of ground truth events
248
+ """
249
+ if thresholds_ms is None:
250
+ thresholds_ms = DEFAULT_THRESHOLDS_MS
251
+
252
+ pred_arr = np.array(pred_times, dtype=np.float64)
253
+ gt_arr = np.array(gt_times, dtype=np.float64)
254
+
255
+ per_threshold = {}
256
+ f1_scores = {}
257
+
258
+ for threshold_ms in thresholds_ms:
259
+ tolerance_sec = threshold_ms / 1000.0
260
+ tp, fp, fn = match_events(pred_arr, gt_arr, tolerance_sec)
261
+ precision, recall, f1 = compute_f1(tp, fp, fn)
262
+
263
+ per_threshold[threshold_ms] = {
264
+ "precision": precision,
265
+ "recall": recall,
266
+ "f1": f1,
267
+ "tp": tp,
268
+ "fp": fp,
269
+ "fn": fn,
270
+ }
271
+ f1_scores[threshold_ms] = f1
272
+
273
+ weighted_f1 = compute_weighted_f1(f1_scores, thresholds_ms)
274
+ continuity = compute_continuity_metrics(pred_times, gt_times, min_beat_time)
275
+
276
+ return {
277
+ "per_threshold": per_threshold,
278
+ "f1_scores": f1_scores,
279
+ "weighted_f1": weighted_f1,
280
+ "continuity": continuity,
281
+ "num_predictions": len(pred_arr),
282
+ "num_ground_truth": len(gt_arr),
283
+ }
284
+
285
+
286
+ def evaluate_track(
287
+ pred_beats: Sequence[float],
288
+ pred_downbeats: Sequence[float],
289
+ gt_beats: Sequence[float],
290
+ gt_downbeats: Sequence[float],
291
+ thresholds_ms: Sequence[int] | None = None,
292
+ min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
293
+ ) -> dict:
294
+ """
295
+ Evaluate both beat and downbeat predictions for a single track.
296
+
297
+ Args:
298
+ pred_beats: Predicted beat times in seconds
299
+ pred_downbeats: Predicted downbeat times in seconds
300
+ gt_beats: Ground truth beat times in seconds
301
+ gt_downbeats: Ground truth downbeat times in seconds
302
+ thresholds_ms: Timing thresholds in milliseconds
303
+ min_beat_time: Minimum time for continuity metrics (default: 5.0s)
304
+
305
+ Returns:
306
+ Dict containing:
307
+ - 'beats': Results from evaluate_beats for beats
308
+ - 'downbeats': Results from evaluate_beats for downbeats
309
+ - 'combined_weighted_f1': Average of beat and downbeat weighted F1
310
+ """
311
+ beat_results = evaluate_beats(pred_beats, gt_beats, thresholds_ms, min_beat_time)
312
+ downbeat_results = evaluate_beats(
313
+ pred_downbeats, gt_downbeats, thresholds_ms, min_beat_time
314
+ )
315
+
316
+ combined_weighted_f1 = (
317
+ beat_results["weighted_f1"] + downbeat_results["weighted_f1"]
318
+ ) / 2
319
+
320
+ return {
321
+ "beats": beat_results,
322
+ "downbeats": downbeat_results,
323
+ "combined_weighted_f1": combined_weighted_f1,
324
+ }
325
+
326
+
327
+ def evaluate_all(
328
+ predictions: Sequence[dict],
329
+ ground_truths: Sequence[dict],
330
+ thresholds_ms: Sequence[int] | None = None,
331
+ min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
332
+ verbose: bool = False,
333
+ ) -> dict:
334
+ """
335
+ Evaluate predictions for multiple tracks.
336
+
337
+ Args:
338
+ predictions: List of dicts with 'beats' and 'downbeats' keys
339
+ ground_truths: List of dicts with 'beats' and 'downbeats' keys
340
+ thresholds_ms: Timing thresholds in milliseconds
341
+ min_beat_time: Minimum time for continuity metrics (default: 5.0s)
342
+ verbose: If True, print per-track results
343
+
344
+ Returns:
345
+ Dict containing:
346
+ - 'per_track': List of per-track results
347
+ - 'mean_beat_weighted_f1': Mean weighted F1 for beats
348
+ - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
349
+ - 'mean_combined_weighted_f1': Mean combined weighted F1
350
+ - 'beat_f1_by_threshold': Mean F1 per threshold for beats
351
+ - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
352
+ - 'beat_continuity': Mean continuity metrics for beats
353
+ - 'downbeat_continuity': Mean continuity metrics for downbeats
354
+ """
355
+ if len(predictions) != len(ground_truths):
356
+ raise ValueError(
357
+ f"Number of predictions ({len(predictions)}) must match "
358
+ f"number of ground truths ({len(ground_truths)})"
359
+ )
360
+
361
+ if thresholds_ms is None:
362
+ thresholds_ms = DEFAULT_THRESHOLDS_MS
363
+
364
+ per_track = []
365
+ beat_weighted_f1s = []
366
+ downbeat_weighted_f1s = []
367
+ combined_weighted_f1s = []
368
+
369
+ beat_f1_by_threshold = {t: [] for t in thresholds_ms}
370
+ downbeat_f1_by_threshold = {t: [] for t in thresholds_ms}
371
+
372
+ # Continuity metrics tracking
373
+ beat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}
374
+ downbeat_continuity = {"CMLc": [], "CMLt": [], "AMLc": [], "AMLt": []}
375
+
376
+ for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
377
+ result = evaluate_track(
378
+ pred_beats=pred["beats"],
379
+ pred_downbeats=pred["downbeats"],
380
+ gt_beats=gt["beats"],
381
+ gt_downbeats=gt["downbeats"],
382
+ thresholds_ms=thresholds_ms,
383
+ min_beat_time=min_beat_time,
384
+ )
385
+
386
+ per_track.append(result)
387
+ beat_weighted_f1s.append(result["beats"]["weighted_f1"])
388
+ downbeat_weighted_f1s.append(result["downbeats"]["weighted_f1"])
389
+ combined_weighted_f1s.append(result["combined_weighted_f1"])
390
+
391
+ for t in thresholds_ms:
392
+ beat_f1_by_threshold[t].append(result["beats"]["f1_scores"][t])
393
+ downbeat_f1_by_threshold[t].append(result["downbeats"]["f1_scores"][t])
394
+
395
+ # Track continuity metrics
396
+ for metric in ["CMLc", "CMLt", "AMLc", "AMLt"]:
397
+ beat_continuity[metric].append(result["beats"]["continuity"][metric])
398
+ downbeat_continuity[metric].append(
399
+ result["downbeats"]["continuity"][metric]
400
+ )
401
+
402
+ if verbose:
403
+ beat_cont = result["beats"]["continuity"]
404
+ print(
405
+ f"Track {i}: Beat F1={result['beats']['weighted_f1']:.4f}, "
406
+ f"CMLt={beat_cont['CMLt']:.4f}, AMLt={beat_cont['AMLt']:.4f}, "
407
+ f"Downbeat F1={result['downbeats']['weighted_f1']:.4f}, "
408
+ f"Combined={result['combined_weighted_f1']:.4f}"
409
+ )
410
+
411
+ return {
412
+ "per_track": per_track,
413
+ "mean_beat_weighted_f1": float(np.mean(beat_weighted_f1s)),
414
+ "mean_downbeat_weighted_f1": float(np.mean(downbeat_weighted_f1s)),
415
+ "mean_combined_weighted_f1": float(np.mean(combined_weighted_f1s)),
416
+ "beat_f1_by_threshold": {
417
+ t: float(np.mean(v)) for t, v in beat_f1_by_threshold.items()
418
+ },
419
+ "downbeat_f1_by_threshold": {
420
+ t: float(np.mean(v)) for t, v in downbeat_f1_by_threshold.items()
421
+ },
422
+ "beat_continuity": {
423
+ metric: float(np.mean(values)) for metric, values in beat_continuity.items()
424
+ },
425
+ "downbeat_continuity": {
426
+ metric: float(np.mean(values))
427
+ for metric, values in downbeat_continuity.items()
428
+ },
429
+ "num_tracks": len(predictions),
430
+ }
431
+
432
+
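The per-track aggregation above boils down to collecting each metric into a list and averaging it; a minimal sketch with hypothetical per-track numbers:

```python
import numpy as np

per_track = [{"CMLt": 0.75, "AMLt": 0.5}, {"CMLt": 0.25, "AMLt": 1.0}]
agg = {k: [] for k in ("CMLt", "AMLt")}
for result in per_track:
    for k in agg:
        agg[k].append(result[k])
means = {k: float(np.mean(v)) for k, v in agg.items()}
print(means)  # {'CMLt': 0.5, 'AMLt': 0.75}
```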
433
+ def format_results(results: dict, title: str = "Evaluation Results") -> str:
434
+ """
435
+ Format evaluation results as a human-readable string.
436
+
437
+ Args:
438
+ results: Results dict from evaluate_all or evaluate_track
439
+ title: Title for the report
440
+
441
+ Returns:
442
+ Formatted string report
443
+ """
444
+ lines = [title, "=" * len(title), ""]
445
+
446
+ # Check if this is aggregate results (from evaluate_all)
447
+ if "num_tracks" in results:
448
+ lines.append(f"Number of tracks: {results['num_tracks']}")
449
+ lines.append("")
450
+ lines.append("Overall Metrics:")
451
+ lines.append(
452
+ f" Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}"
453
+ )
454
+ lines.append(
455
+ f" Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}"
456
+ )
457
+ lines.append(
458
+ f" Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}"
459
+ )
460
+ lines.append("")
461
+
462
+ lines.append("Beat F1 by Threshold:")
463
+ for t, f1 in sorted(results["beat_f1_by_threshold"].items()):
464
+ lines.append(f" {t:2d}ms: {f1:.4f}")
465
+ lines.append("")
466
+
467
+ lines.append("Downbeat F1 by Threshold:")
468
+ for t, f1 in sorted(results["downbeat_f1_by_threshold"].items()):
469
+ lines.append(f" {t:2d}ms: {f1:.4f}")
470
+ lines.append("")
471
+
472
+ # Continuity metrics
473
+ if "beat_continuity" in results:
474
+ lines.append("Beat Continuity Metrics:")
475
+ bc = results["beat_continuity"]
476
+ lines.append(f" CMLt: {bc['CMLt']:.4f} (Correct Metrical Level Total)")
477
+ lines.append(f" AMLt: {bc['AMLt']:.4f} (Any Metrical Level Total)")
478
+ lines.append(
479
+ f" CMLc: {bc['CMLc']:.4f} (Correct Metrical Level Continuous)"
480
+ )
481
+ lines.append(f" AMLc: {bc['AMLc']:.4f} (Any Metrical Level Continuous)")
482
+ lines.append("")
483
+
484
+ if "downbeat_continuity" in results:
485
+ lines.append("Downbeat Continuity Metrics:")
486
+ dc = results["downbeat_continuity"]
487
+ lines.append(f" CMLt: {dc['CMLt']:.4f} (Correct Metrical Level Total)")
488
+ lines.append(f" AMLt: {dc['AMLt']:.4f} (Any Metrical Level Total)")
489
+ lines.append(
490
+ f" CMLc: {dc['CMLc']:.4f} (Correct Metrical Level Continuous)"
491
+ )
492
+ lines.append(f" AMLc: {dc['AMLc']:.4f} (Any Metrical Level Continuous)")
493
+
494
+ # Single track results (from evaluate_track)
495
+ elif "beats" in results and "downbeats" in results:
496
+ lines.append("Beat Detection:")
497
+ lines.append(f" Weighted F1: {results['beats']['weighted_f1']:.4f}")
498
+ lines.append(f" Predictions: {results['beats']['num_predictions']}")
499
+ lines.append(f" Ground Truth: {results['beats']['num_ground_truth']}")
500
+
501
+ # Beat continuity metrics
502
+ if "continuity" in results["beats"]:
503
+ bc = results["beats"]["continuity"]
504
+ lines.append(f" CMLt: {bc['CMLt']:.4f} AMLt: {bc['AMLt']:.4f}")
505
+ lines.append(f" CMLc: {bc['CMLc']:.4f} AMLc: {bc['AMLc']:.4f}")
506
+ lines.append("")
507
+
508
+ lines.append("Downbeat Detection:")
509
+ lines.append(f" Weighted F1: {results['downbeats']['weighted_f1']:.4f}")
510
+ lines.append(f" Predictions: {results['downbeats']['num_predictions']}")
511
+ lines.append(f" Ground Truth: {results['downbeats']['num_ground_truth']}")
512
+
513
+ # Downbeat continuity metrics
514
+ if "continuity" in results["downbeats"]:
515
+ dc = results["downbeats"]["continuity"]
516
+ lines.append(f" CMLt: {dc['CMLt']:.4f} AMLt: {dc['AMLt']:.4f}")
517
+ lines.append(f" CMLc: {dc['CMLc']:.4f} AMLc: {dc['AMLc']:.4f}")
518
+ lines.append("")
519
+
520
+ lines.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")
521
+
522
+ return "\n".join(lines)
523
+
524
+
525
+ if __name__ == "__main__":
526
+ # Demo with synthetic data
527
+ print("Running evaluation demo...\n")
528
+
529
+ # Simulate ground truth beats at regular intervals (30s to have beats after 5s)
530
+ gt_beats = np.arange(0, 30, 0.5).tolist() # Beat every 0.5s for 30s
531
+ gt_downbeats = np.arange(0, 30, 2.0).tolist() # Downbeat every 2s
532
+
533
+ # Simulate predictions with some noise and missed/extra detections
534
+ np.random.seed(42)
535
+ pred_beats = (
536
+ np.array(gt_beats) + np.random.normal(0, 0.005, len(gt_beats))
537
+ ).tolist()
538
+ pred_beats = pred_beats[:-2] # Miss last 2 beats
539
+ pred_beats.append(15.25) # Add false positive
540
+
541
+ pred_downbeats = (
542
+ np.array(gt_downbeats) + np.random.normal(0, 0.003, len(gt_downbeats))
543
+ ).tolist()
544
+
545
+ # Evaluate single track
546
+ results = evaluate_track(
547
+ pred_beats=pred_beats,
548
+ pred_downbeats=pred_downbeats,
549
+ gt_beats=gt_beats,
550
+ gt_downbeats=gt_downbeats,
551
+ )
552
+
553
+ print(format_results(results, "Single Track Demo"))
554
+ print("\n" + "=" * 50 + "\n")
555
+
556
+ # Multi-track demo
557
+ predictions = [
558
+ {"beats": pred_beats, "downbeats": pred_downbeats},
559
+ {"beats": pred_beats, "downbeats": pred_downbeats},
560
+ ]
561
+ ground_truths = [
562
+ {"beats": gt_beats, "downbeats": gt_downbeats},
563
+ {"beats": gt_beats, "downbeats": gt_downbeats},
564
+ ]
565
+
566
+ all_results = evaluate_all(predictions, ground_truths, verbose=True)
567
+ print()
568
+ print(format_results(all_results, "Multi-Track Demo"))
exp/data/load.py ADDED
@@ -0,0 +1,91 @@
1
+ from datasets import load_dataset, Audio
2
+
3
+ N_PROC = None
4
+
5
+ ds = load_dataset("JacobLinCool/taiko-1000-parsed")
6
+ ds = ds.remove_columns(["tja", "hard", "normal", "easy", "ura"])
7
+
8
+
9
+ def filter_out_broken(example):
10
+ try:
+ # Accessing the array forces audio decoding; broken files raise here
+ example["audio"]["array"]
+ return True
+ except Exception:
+ return False
15
+
16
+
17
+ ds = ds.filter(filter_out_broken, num_proc=N_PROC, batch_size=32, writer_batch_size=32)
18
+ ds = ds.cast_column("audio", Audio(sampling_rate=16000))
19
+
20
+
21
+ def build_beat_and_downbeat_labels(example):
22
+ """
23
+ Extract beat and downbeat times from the chart segments.
24
+
25
+ - Downbeats: First beat of each measure (segment timestamp)
26
+ - Beats: All beats within each measure based on time signature
27
+
28
+ Returns lists of times in seconds.
29
+ """
30
+ title = example["metadata"]["TITLE"]
31
+ segments = example["oni"]["segments"]
32
+
33
+ beats = []
34
+ downbeats = []
35
+
36
+ for i, segment in enumerate(segments):
37
+ seg_timestamp = segment["timestamp"]
38
+ measure_num = segment["measure_num"] # numerator (e.g., 4 in 4/4)
39
+ measure_den = segment["measure_den"] # denominator (e.g., 4 in 4/4)
40
+ notes = segment["notes"]
41
+
42
+ # Downbeat is the start of each measure
43
+ downbeats.append(seg_timestamp)
44
+
45
+ # Get BPM from the first note in segment, or fallback to next segment's first note
46
+ bpm = None
47
+ if notes:
48
+ bpm = notes[0]["bpm"]
49
+ else:
50
+ # Look ahead for BPM if current segment has no notes
51
+ for j in range(i + 1, len(segments)):
52
+ if segments[j]["notes"]:
53
+ bpm = segments[j]["notes"][0]["bpm"]
54
+ break
55
+
56
+ if bpm is None or bpm <= 0:
57
+ bpm = 120.0 # fallback default BPM
58
+
59
+ # Calculate beat duration: one beat = 60/BPM seconds (for quarter note)
60
+ # Adjust for time signature denominator (4 = quarter, 8 = eighth, etc.)
61
+ beat_duration = (60.0 / bpm) * (4.0 / measure_den)
62
+
63
+ # Calculate beat positions within this measure
64
+ for beat_idx in range(measure_num):
65
+ beat_time = seg_timestamp + beat_idx * beat_duration
66
+ beats.append(beat_time)
67
+
68
+ # Sort and deduplicate (in case of overlapping segments)
69
+ beats = sorted(set(beats))
70
+ downbeats = sorted(set(downbeats))
71
+
72
+ return {
73
+ "title": title,
74
+ "beats": beats,
75
+ "downbeats": downbeats,
76
+ }
77
+
78
+
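The beat-duration arithmetic above can be sanity-checked on a non-4/4 case; in 7/8 at 180 BPM each beat is an eighth note (values chosen for illustration):

```python
bpm, measure_num, measure_den = 180.0, 7, 8
beat_duration = (60.0 / bpm) * (4.0 / measure_den)  # quarter-note length scaled to eighths
beats = [12.0 + i * beat_duration for i in range(measure_num)]  # measure starting at 12 s
print(round(beat_duration, 4), len(beats))  # 0.1667 7
```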
79
+ ds = ds.map(
80
+ build_beat_and_downbeat_labels,
81
+ num_proc=N_PROC,
82
+ batch_size=32,
83
+ writer_batch_size=32,
84
+ remove_columns=["oni", "metadata"],
85
+ )
86
+
87
+ ds = ds.with_format("torch")
88
+
89
+ if __name__ == "__main__":
90
+ print(ds)
91
+ print(ds["train"].features)
exp/data/viz.py ADDED
@@ -0,0 +1,441 @@
1
+ """
2
+ Visualization utilities for beat tracking evaluation.
3
+
4
+ This module provides functions to:
5
+ - Plot beat and downbeat predictions vs ground truth
6
+ - Create waveform visualizations with beat annotations
7
+ - Generate comparison plots for evaluation
8
+
9
+ Example usage:
10
+ from exp.data.viz import plot_beats, plot_waveform_with_beats, save_figure
11
+
12
+ # Plot beat comparison
13
+ fig = plot_beats(pred_beats, gt_beats, pred_downbeats, gt_downbeats)
14
+ save_figure(fig, "beat_comparison.png")
15
+
16
+ # Plot waveform with beats
17
+ fig = plot_waveform_with_beats(audio, sr, pred_beats, gt_beats)
18
+ save_figure(fig, "waveform.png")
19
+ """
20
+
21
+ import numpy as np
22
+ from pathlib import Path
23
+
24
+ # Try to import matplotlib, but make it optional
25
+ try:
26
+ import matplotlib.pyplot as plt
27
+ import matplotlib.patches as mpatches
28
+
29
+ HAS_MATPLOTLIB = True
30
+ except ImportError:
31
+ HAS_MATPLOTLIB = False
32
+
33
+
34
+ def _check_matplotlib():
35
+ if not HAS_MATPLOTLIB:
36
+ raise ImportError(
37
+ "matplotlib is required for visualization. "
38
+ "Install with: pip install matplotlib"
39
+ )
40
+
41
+
42
+ def plot_beats(
43
+ pred_beats: list[float] | np.ndarray,
44
+ gt_beats: list[float] | np.ndarray,
45
+ pred_downbeats: list[float] | np.ndarray | None = None,
46
+ gt_downbeats: list[float] | np.ndarray | None = None,
47
+ title: str = "Beat Tracking Comparison",
48
+ figsize: tuple[int, int] = (14, 4),
49
+ time_range: tuple[float, float] | None = None,
50
+ ) -> "plt.Figure":
51
+ """
52
+ Create a visualization comparing predicted and ground truth beats.
53
+
54
+ Args:
55
+ pred_beats: Predicted beat times in seconds
56
+ gt_beats: Ground truth beat times in seconds
57
+ pred_downbeats: Predicted downbeat times (optional)
58
+ gt_downbeats: Ground truth downbeat times (optional)
59
+ title: Plot title
60
+ figsize: Figure size (width, height)
61
+ time_range: Optional tuple (start, end) to limit time range
62
+
63
+ Returns:
64
+ matplotlib Figure object
65
+ """
66
+ _check_matplotlib()
67
+
68
+ fig, ax = plt.subplots(figsize=figsize)
69
+
70
+ pred_beats = np.array(pred_beats)
71
+ gt_beats = np.array(gt_beats)
72
+
73
+ # Apply time range filter
74
+ if time_range is not None:
75
+ start, end = time_range
76
+ pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)]
77
+ gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)]
78
+
79
+ if pred_downbeats is not None:
80
+ pred_downbeats = np.array(pred_downbeats)
81
+ pred_downbeats = pred_downbeats[
82
+ (pred_downbeats >= start) & (pred_downbeats <= end)
83
+ ]
84
+ if gt_downbeats is not None:
85
+ gt_downbeats = np.array(gt_downbeats)
86
+ gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)]
87
+
88
+ # Plot ground truth beats
89
+ ax.vlines(
90
+ gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5, label="GT Beats"
91
+ )
92
+
93
+ # Plot predicted beats
94
+ ax.vlines(
95
+ pred_beats,
96
+ 0.6,
97
+ 1.0,
98
+ colors="blue",
99
+ alpha=0.7,
100
+ linewidth=1.5,
101
+ label="Pred Beats",
102
+ )
103
+
104
+ # Plot downbeats if provided
105
+ if gt_downbeats is not None and len(gt_downbeats) > 0:
106
+ gt_downbeats = np.array(gt_downbeats)
107
+ ax.vlines(
108
+ gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3, label="GT Downbeats"
109
+ )
110
+
111
+ if pred_downbeats is not None and len(pred_downbeats) > 0:
112
+ pred_downbeats = np.array(pred_downbeats)
113
+ ax.vlines(
114
+ pred_downbeats,
115
+ 0.6,
116
+ 1.0,
117
+ colors="darkblue",
118
+ linewidth=3,
119
+ label="Pred Downbeats",
120
+ )
121
+
122
+ # Styling
123
+ ax.set_ylim(-0.1, 1.1)
124
+ ax.set_yticks([0.2, 0.8])
125
+ ax.set_yticklabels(["Ground Truth", "Prediction"])
126
+ ax.set_xlabel("Time (seconds)")
127
+ ax.set_title(title)
128
+ ax.legend(loc="upper right", ncol=4)
129
+ ax.grid(True, alpha=0.3)
130
+
131
+ # Set x-axis range
132
+ if time_range is not None:
133
+ ax.set_xlim(time_range)
134
+ else:
135
+ all_times = np.concatenate([pred_beats, gt_beats])
136
+ if len(all_times) > 0:
137
+ ax.set_xlim(0, np.max(all_times) + 0.5)
138
+
139
+ plt.tight_layout()
140
+ return fig
141
+
142
+
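The time-range filtering used throughout these plot functions is a plain boolean mask over the event times:

```python
import numpy as np

beats = np.array([0.5, 1.0, 4.8, 5.2, 9.9])
start, end = 1.0, 6.0
kept = beats[(beats >= start) & (beats <= end)]  # inclusive on both ends
print(kept.tolist())  # [1.0, 4.8, 5.2]
```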
143
+ def plot_waveform_with_beats(
144
+ audio: np.ndarray,
145
+ sr: int,
146
+ pred_beats: list[float] | np.ndarray,
147
+ gt_beats: list[float] | np.ndarray,
148
+ pred_downbeats: list[float] | np.ndarray | None = None,
149
+ gt_downbeats: list[float] | np.ndarray | None = None,
150
+ title: str = "Waveform with Beat Annotations",
151
+ figsize: tuple[int, int] = (14, 6),
152
+ time_range: tuple[float, float] | None = None,
153
+ ) -> "plt.Figure":
154
+ """
155
+ Create a waveform visualization with beat annotations.
156
+
157
+ Args:
158
+ audio: Audio waveform
159
+ sr: Sample rate
160
+ pred_beats: Predicted beat times
161
+ gt_beats: Ground truth beat times
162
+ pred_downbeats: Predicted downbeat times (optional)
163
+ gt_downbeats: Ground truth downbeat times (optional)
164
+ title: Plot title
165
+ figsize: Figure size
166
+ time_range: Optional tuple (start, end) to limit time range
167
+
168
+ Returns:
169
+ matplotlib Figure object
170
+ """
171
+ _check_matplotlib()
172
+
173
+ fig, (ax1, ax2) = plt.subplots(
174
+ 2, 1, figsize=figsize, sharex=True, height_ratios=[3, 1]
175
+ )
176
+
177
+ # Time axis
178
+ duration = len(audio) / sr
179
+ t = np.linspace(0, duration, len(audio))
180
+
181
+ # Apply time range
182
+ if time_range is not None:
183
+ start, end = time_range
184
+ start_idx = int(start * sr)
185
+ end_idx = int(end * sr)
186
+ t = t[start_idx:end_idx]
187
+ audio_plot = audio[start_idx:end_idx]
188
+ else:
189
+ audio_plot = audio
190
+ start, end = 0, duration
191
+
192
+ # Plot waveform
193
+ ax1.plot(t, audio_plot, color="gray", alpha=0.7, linewidth=0.5)
194
+ ax1.set_ylabel("Amplitude")
195
+ ax1.set_title(title)
196
+
197
+ # Filter beats to time range
198
+ pred_beats = np.array(pred_beats)
199
+ gt_beats = np.array(gt_beats)
200
+ pred_beats = pred_beats[(pred_beats >= start) & (pred_beats <= end)]
201
+ gt_beats = gt_beats[(gt_beats >= start) & (gt_beats <= end)]
202
+
203
+ # Plot beat markers on waveform
204
+ audio_max = np.abs(audio_plot).max() if len(audio_plot) > 0 else 1.0
205
+
206
+ for beat in gt_beats:
207
+ ax1.axvline(beat, color="green", alpha=0.5, linewidth=1)
208
+ for beat in pred_beats:
209
+ ax1.axvline(beat, color="blue", alpha=0.3, linewidth=1, linestyle="--")
210
+
211
+ # Add downbeat markers (thicker lines)
212
+ if gt_downbeats is not None:
213
+ gt_downbeats = np.array(gt_downbeats)
214
+ gt_downbeats = gt_downbeats[(gt_downbeats >= start) & (gt_downbeats <= end)]
215
+ for db in gt_downbeats:
216
+ ax1.axvline(db, color="darkgreen", alpha=0.8, linewidth=2)
217
+
218
+ if pred_downbeats is not None:
219
+ pred_downbeats = np.array(pred_downbeats)
220
+ pred_downbeats = pred_downbeats[
221
+ (pred_downbeats >= start) & (pred_downbeats <= end)
222
+ ]
223
+ for db in pred_downbeats:
224
+ ax1.axvline(db, color="darkblue", alpha=0.5, linewidth=2, linestyle="--")
225
+
226
+ ax1.set_ylim(-audio_max * 1.1, audio_max * 1.1)
227
+
228
+ # Beat comparison subplot
229
+ ax2.vlines(gt_beats, 0, 0.4, colors="green", alpha=0.7, linewidth=1.5)
230
+ ax2.vlines(pred_beats, 0.6, 1.0, colors="blue", alpha=0.7, linewidth=1.5)
231
+
232
+ if gt_downbeats is not None and len(gt_downbeats) > 0:
233
+ ax2.vlines(gt_downbeats, 0, 0.4, colors="darkgreen", linewidth=3)
234
+ if pred_downbeats is not None and len(pred_downbeats) > 0:
235
+ ax2.vlines(pred_downbeats, 0.6, 1.0, colors="darkblue", linewidth=3)
236
+
237
+ ax2.set_ylim(-0.1, 1.1)
238
+ ax2.set_yticks([0.2, 0.8])
239
+ ax2.set_yticklabels(["GT", "Pred"])
240
+ ax2.set_xlabel("Time (seconds)")
241
+
242
+ # Legend
243
+ legend_elements = [
244
+ mpatches.Patch(color="green", alpha=0.7, label="GT Beats"),
245
+ mpatches.Patch(color="blue", alpha=0.7, label="Pred Beats"),
246
+ mpatches.Patch(color="darkgreen", label="GT Downbeats"),
247
+        mpatches.Patch(color="darkblue", label="Pred Downbeats"),
+    ]
+    ax1.legend(handles=legend_elements, loc="upper right", ncol=4)
+
+    ax1.grid(True, alpha=0.3)
+    ax2.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+    return fig
+
+
+def plot_evaluation_summary(
+    results: dict,
+    title: str = "Evaluation Summary",
+    figsize: tuple[int, int] = (12, 8),
+) -> "plt.Figure":
+    """
+    Create a summary visualization of evaluation results.
+
+    Args:
+        results: Results dict from evaluate_all
+        title: Plot title
+        figsize: Figure size
+
+    Returns:
+        matplotlib Figure object
+    """
+    _check_matplotlib()
+
+    fig, axes = plt.subplots(2, 2, figsize=figsize)
+
+    # F1 by threshold for beats
+    ax1 = axes[0, 0]
+    if "beat_f1_by_threshold" in results:
+        thresholds = sorted(results["beat_f1_by_threshold"].keys())
+        f1_scores = [results["beat_f1_by_threshold"][t] for t in thresholds]
+        ax1.bar(range(len(thresholds)), f1_scores, color="steelblue", alpha=0.8)
+        ax1.set_xticks(range(len(thresholds)))
+        ax1.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45)
+        ax1.set_ylabel("F1 Score")
+        ax1.set_title("Beat F1 by Threshold")
+        ax1.set_ylim(0, 1)
+        ax1.grid(True, alpha=0.3)
+
+    # F1 by threshold for downbeats
+    ax2 = axes[0, 1]
+    if "downbeat_f1_by_threshold" in results:
+        thresholds = sorted(results["downbeat_f1_by_threshold"].keys())
+        f1_scores = [results["downbeat_f1_by_threshold"][t] for t in thresholds]
+        ax2.bar(range(len(thresholds)), f1_scores, color="coral", alpha=0.8)
+        ax2.set_xticks(range(len(thresholds)))
+        ax2.set_xticklabels([f"{t}ms" for t in thresholds], rotation=45)
+        ax2.set_ylabel("F1 Score")
+        ax2.set_title("Downbeat F1 by Threshold")
+        ax2.set_ylim(0, 1)
+        ax2.grid(True, alpha=0.3)
+
+    # Continuity metrics for beats
+    ax3 = axes[1, 0]
+    if "beat_continuity" in results:
+        metrics = ["CMLc", "CMLt", "AMLc", "AMLt"]
+        values = [results["beat_continuity"][m] for m in metrics]
+        colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]
+        bars = ax3.bar(metrics, values, color=colors, alpha=0.8)
+        ax3.set_ylabel("Score")
+        ax3.set_title("Beat Continuity Metrics")
+        ax3.set_ylim(0, 1)
+        ax3.grid(True, alpha=0.3)
+        # Add value labels
+        for bar, val in zip(bars, values):
+            ax3.text(
+                bar.get_x() + bar.get_width() / 2,
+                bar.get_height() + 0.02,
+                f"{val:.3f}",
+                ha="center",
+                fontsize=9,
+            )
+
+    # Continuity metrics for downbeats
+    ax4 = axes[1, 1]
+    if "downbeat_continuity" in results:
+        metrics = ["CMLc", "CMLt", "AMLc", "AMLt"]
+        values = [results["downbeat_continuity"][m] for m in metrics]
+        colors = ["#2E86AB", "#A23B72", "#F18F01", "#C73E1D"]
+        bars = ax4.bar(metrics, values, color=colors, alpha=0.8)
+        ax4.set_ylabel("Score")
+        ax4.set_title("Downbeat Continuity Metrics")
+        ax4.set_ylim(0, 1)
+        ax4.grid(True, alpha=0.3)
+        # Add value labels
+        for bar, val in zip(bars, values):
+            ax4.text(
+                bar.get_x() + bar.get_width() / 2,
+                bar.get_height() + 0.02,
+                f"{val:.3f}",
+                ha="center",
+                fontsize=9,
+            )
+
+    fig.suptitle(title, fontsize=14, fontweight="bold")
+    plt.tight_layout()
+    return fig
+
+
+def save_figure(
+    fig: "plt.Figure",
+    path: str | Path,
+    dpi: int = 150,
+) -> None:
+    """
+    Save a matplotlib figure to file.
+
+    Args:
+        fig: Figure to save
+        path: Output file path
+        dpi: Resolution (dots per inch)
+    """
+    _check_matplotlib()
+
+    path = Path(path)
+    path.parent.mkdir(parents=True, exist_ok=True)
+    fig.savefig(str(path), dpi=dpi, bbox_inches="tight")
+    plt.close(fig)
+
+
+if __name__ == "__main__":
+    # Demo
+    _check_matplotlib()
+    print("Visualization demo...")
+
+    # Generate synthetic data
+    np.random.seed(42)
+    gt_beats = np.arange(0, 10, 0.5)
+    gt_downbeats = np.arange(0, 10, 2.0)
+    pred_beats = gt_beats + np.random.normal(0, 0.02, len(gt_beats))
+    pred_downbeats = gt_downbeats + np.random.normal(0, 0.01, len(gt_downbeats))
+
+    # Generate fake audio
+    sr = 16000
+    duration = 10.0
+    t = np.arange(int(duration * sr)) / sr
+    audio = np.sin(2 * np.pi * 220 * t) * 0.3
+
+    # Create plots
+    fig1 = plot_beats(
+        pred_beats, gt_beats, pred_downbeats, gt_downbeats, title="Beat Comparison Demo"
+    )
+    save_figure(fig1, "/tmp/beat_comparison_demo.png")
+    print("Saved /tmp/beat_comparison_demo.png")
+
+    fig2 = plot_waveform_with_beats(
+        audio,
+        sr,
+        pred_beats,
+        gt_beats,
+        pred_downbeats,
+        gt_downbeats,
+        title="Waveform Demo",
+        time_range=(2, 8),
+    )
+    save_figure(fig2, "/tmp/waveform_demo.png")
+    print("Saved /tmp/waveform_demo.png")
+
+    # Fake evaluation results
+    results = {
+        "beat_f1_by_threshold": {
+            3: 0.5,
+            6: 0.7,
+            9: 0.85,
+            12: 0.9,
+            15: 0.95,
+            18: 0.96,
+            21: 0.97,
+            24: 0.97,
+            27: 0.98,
+            30: 0.98,
+        },
+        "downbeat_f1_by_threshold": {
+            3: 0.6,
+            6: 0.8,
+            9: 0.9,
+            12: 0.95,
+            15: 0.97,
+            18: 0.98,
+            21: 0.98,
+            24: 0.99,
+            27: 0.99,
+            30: 0.99,
+        },
+        "beat_continuity": {"CMLc": 0.75, "CMLt": 0.92, "AMLc": 0.80, "AMLt": 0.95},
+        "downbeat_continuity": {"CMLc": 0.85, "CMLt": 0.95, "AMLc": 0.88, "AMLt": 0.97},
+    }
+    fig3 = plot_evaluation_summary(results, title="Evaluation Summary Demo")
+    save_figure(fig3, "/tmp/eval_summary_demo.png")
+    print("Saved /tmp/eval_summary_demo.png")
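The demo above fills `beat_f1_by_threshold` with placeholder scores keyed by tolerance in milliseconds. For context, an F1 at a given tolerance comes from one-to-one matching of predicted and reference beat times within that window; the sketch below is a hedged illustration of such matching, not the repository's evaluation code (which presumably lives in `exp/data/eval.py`):

```python
def beat_f1(pred: list[float], gt: list[float], tol: float = 0.07) -> float:
    """F1 from greedy one-to-one matching of beat times within +/- tol seconds."""
    unmatched = list(gt)
    tp = 0
    for p in pred:
        hit = next((g for g in unmatched if abs(g - p) <= tol), None)
        if hit is not None:
            tp += 1
            unmatched.remove(hit)  # each reference beat may match at most once
    if tp == 0:
        return 0.0
    precision = tp / len(pred)
    recall = tp / len(gt)
    return 2 * precision * recall / (precision + recall)


print(beat_f1([0.0, 0.52, 1.0], [0.0, 0.5, 1.0]))  # 1.0: every beat within 70 ms
print(beat_f1([0.0, 0.25, 0.5], [0.0, 0.5, 1.0]))  # only 2 of 3 beats match
```

For reference, `mir_eval.beat.f_measure` uses a similar tolerance window (70 ms by default), which is why the demo sweeps thresholds in the millisecond range.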
outputs/baseline1/beats/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/beats/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
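The `config.json` stored next to each checkpoint holds the model's constructor kwargs (here just `dropout_rate`), which is how `PyTorchModelHubMixin.from_pretrained` can rebuild the module with the same hyperparameters. A minimal torch-free sketch of that round trip; the names `TinyModel`, `save_config`, and `load_from_config` are illustrative, not the repo's or the mixin's API:

```python
import json
import tempfile
from pathlib import Path


class TinyModel:
    def __init__(self, dropout_rate: float = 0.1):
        self.dropout_rate = dropout_rate


def save_config(model: TinyModel, save_dir: str) -> None:
    # Mirrors what the mixin writes beside model.safetensors
    cfg = {"dropout_rate": model.dropout_rate}
    (Path(save_dir) / "config.json").write_text(json.dumps(cfg))


def load_from_config(save_dir: str) -> TinyModel:
    # Re-instantiate the model from the persisted constructor kwargs
    cfg = json.loads((Path(save_dir) / "config.json").read_text())
    return TinyModel(**cfg)


with tempfile.TemporaryDirectory() as d:
    save_config(TinyModel(dropout_rate=0.5), d)
    print(load_from_config(d).dropout_rate)  # 0.5
```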
outputs/baseline1/beats/final/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/beats/final/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
outputs/baseline1/beats/final/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f0ee01ee41360f0b486e16d6022f896a19f9ead901be0180bdbd9cad2a3b8597
+size 1159372
outputs/baseline1/beats/logs/events.out.tfevents.1766351314.msiit232.1284330.0 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b2d91a22ba01091bf072f5a5e8f12fc7d49801d6538914c973ccb2700978934
+size 17749022
outputs/baseline1/beats/model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1e7a0d5178bc5dfeee6da26345e7956aeb6bf64a21be7e541db4bcc37b290249
+size 1159372
outputs/baseline1/downbeats/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]
outputs/baseline1/downbeats/config.json ADDED
@@ -0,0 +1,3 @@
+{
+  "dropout_rate": 0.5
+}
outputs/baseline1/downbeats/final/README.md ADDED
@@ -0,0 +1,10 @@
+---
+tags:
+- model_hub_mixin
+- pytorch_model_hub_mixin
+---
+
+This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+- Code: [More Information Needed]
+- Paper: [More Information Needed]
+- Docs: [More Information Needed]