jorisvaneyghen committed on
Commit
b8279df
·
1 Parent(s): 76a31ca

improved controls

Browse files
Files changed (2) hide show
  1. app.py +116 -7
  2. index.html +75 -101
app.py CHANGED
@@ -1,12 +1,14 @@
1
- from madmom.features.downbeats import DBNDownBeatTrackingProcessor
 
 
 
 
2
  from beat_this.model.postprocessor import Postprocessor
3
  from flask import Flask, request, jsonify, send_from_directory
4
  from flask_cors import CORS
5
- import numpy as np
6
- import torch
7
- import logging
8
- import os
9
- import mimetypes
10
 
11
  # Add MIME types for JavaScript and WebAssembly
12
  mimetypes.add_type('application/javascript', '.js')
@@ -109,10 +111,17 @@ def postprocess_beats():
109
  downbeats = downbeats.tolist() if isinstance(downbeats, np.ndarray) else downbeats
110
  beats = beats.tolist() if isinstance(beats, np.ndarray) else beats
111
 
 
 
 
 
 
112
  bars = {i: beat for i, beat in enumerate(downbeats)}
113
 
114
  return jsonify({
115
- "bars": bars
 
 
116
  })
117
 
118
  except Exception as e:
@@ -168,5 +177,105 @@ def process_logits(beat_logits, downbeat_logits, type='minimal',
168
  return beats, downbeats
169
 
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  if __name__ == '__main__':
172
  app.run(debug=True)
 
1
+ import logging
2
+ import mimetypes
3
+ import numpy as np
4
+ import os
5
+ import torch
6
  from beat_this.model.postprocessor import Postprocessor
7
  from flask import Flask, request, jsonify, send_from_directory
8
  from flask_cors import CORS
9
+ from madmom.features.downbeats import DBNDownBeatTrackingProcessor
10
+ from statistics import median, mode
11
+ from typing import List, Tuple
 
 
12
 
13
  # Add MIME types for JavaScript and WebAssembly
14
  mimetypes.add_type('application/javascript', '.js')
 
111
  downbeats = downbeats.tolist() if isinstance(downbeats, np.ndarray) else downbeats
112
  beats = beats.tolist() if isinstance(beats, np.ndarray) else beats
113
 
114
+ estimated_bpm, detected_beats_per_bar, final_indices = analyze_beats(beats, downbeats)
115
+ print(f"estimated bpm: {estimated_bpm}, detected beats_per_bar: {detected_beats_per_bar}")
116
+ print(final_indices)
117
+
118
+
119
  bars = {i: beat for i, beat in enumerate(downbeats)}
120
 
121
  return jsonify({
122
+ "bars": bars,
123
+ "estimated_bpm": estimated_bpm,
124
+ "detected_beats_per_bar": detected_beats_per_bar
125
  })
126
 
127
  except Exception as e:
 
177
  return beats, downbeats
178
 
179
 
180
+
181
def analyze_beats(beats: List[float], downbeats: List[float]) -> Tuple[float, float, List[int]]:
    """
    Estimate tempo and meter from beat and downbeat positions.

    Each downbeat opens a bar that runs until the next downbeat. Bars whose
    beat count differs from the most common count, or whose duration is more
    than 25% away from the median duration, are treated as outliers and
    ignored when the tempo is computed.

    Args:
        beats: Beat positions in seconds.
        downbeats: Downbeat positions in seconds (first beat of each bar).

    Returns:
        Tuple containing:
            - estimated_bpm: BPM computed from the bars that survived filtering
            - beats_per_bar: Most common number of beats found in a bar
            - valid_bar_indices: Indices of the bars that passed all filters

        ``(0, 0, [])`` is returned whenever no tempo can be derived (no
        downbeats, no bar survives the filters, or the surviving bars have a
        zero beat count or zero duration).
    """
    # Imported locally: the module-level import only brings in median/mode, so
    # the handler below would otherwise reference an undefined name.
    from statistics import StatisticsError

    if len(downbeats) == 0:
        return 0, 0, []

    # Step 1: Count the beats and measure the duration of every bar.
    bar_beats_count = []
    bar_durations = []
    for start_time, end_time in zip(downbeats, downbeats[1:]):
        bar_beats_count.append(sum(1 for beat in beats if start_time <= beat < end_time))
        bar_durations.append(end_time - start_time)

    # The last bar has no closing downbeat: count every beat from the last
    # downbeat onwards and estimate its duration from the average beat spacing.
    last_start = downbeats[-1]
    last_beats = sum(1 for beat in beats if beat >= last_start)
    bar_beats_count.append(last_beats)
    if len(beats) > 1:
        avg_beat_duration = (beats[-1] - beats[0]) / (len(beats) - 1)
        last_duration = last_beats * avg_beat_duration
    else:
        last_duration = 0
    bar_durations.append(last_duration)

    # Step 2: Keep only the bars that have the most common beat count.
    try:
        common_beats_per_bar = mode(bar_beats_count)
    except StatisticsError:
        # Older Python versions raise when there is no unique mode; fall back
        # to the median so the variable is always defined for Step 4.
        common_beats_per_bar = median(bar_beats_count)

    candidates = [
        (index, duration)
        for index, (beat_count, duration) in enumerate(zip(bar_beats_count, bar_durations))
        if beat_count == common_beats_per_bar
    ]
    if not candidates:
        return 0, 0, []

    # Step 3: Drop bars whose duration is outside +/-25% of the median duration.
    median_duration = median([duration for _, duration in candidates])
    lower_bound = median_duration * 0.75
    upper_bound = median_duration * 1.25

    final_indices = []
    final_durations = []
    for index, duration in candidates:
        if lower_bound <= duration <= upper_bound:
            final_indices.append(index)
            final_durations.append(duration)

    if not final_durations:
        return 0, 0, []

    # Step 4: Convert the average bar duration into BPM.
    avg_bar_duration = sum(final_durations) / len(final_durations)
    if avg_bar_duration <= 0 or common_beats_per_bar <= 0:
        # A zero-length bar or a zero beat count would divide by zero below.
        return 0, 0, []

    # BPM = 60 / (beat duration in seconds)
    estimated_bpm = 60.0 / (avg_bar_duration / common_beats_per_bar)
    return estimated_bpm, common_beats_per_bar, final_indices
277
+
278
+
279
+
280
  if __name__ == '__main__':
281
  app.run(debug=True)
index.html CHANGED
@@ -3,7 +3,7 @@
3
  <head>
4
  <meta charset="UTF-8">
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
- <title>Music Beat Detection</title>
7
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
8
  <style>
9
  * {
@@ -166,34 +166,6 @@
166
  gap: 10px;
167
  margin-top: 15px;
168
  }
169
-
170
- .play-button, .stop-button {
171
- padding: 10px 20px;
172
- border: none;
173
- border-radius: 5px;
174
- cursor: pointer;
175
- font-weight: bold;
176
- transition: all 0.3s;
177
- }
178
-
179
- .play-button {
180
- background: #4CAF50;
181
- color: white;
182
- }
183
-
184
- .play-button:hover {
185
- background: #45a049;
186
- }
187
-
188
- .stop-button {
189
- background: #f44336;
190
- color: white;
191
- }
192
-
193
- .stop-button:hover {
194
- background: #d32f2f;
195
- }
196
-
197
  .form-group {
198
  margin-bottom: 15px;
199
  }
@@ -274,15 +246,6 @@
274
  border-bottom: none;
275
  }
276
 
277
- .visualizer {
278
- height: 100px;
279
- background: rgba(0, 0, 0, 0.2);
280
- border-radius: 10px;
281
- margin-top: 20px;
282
- position: relative;
283
- overflow: hidden;
284
- }
285
-
286
  .beat-marker {
287
  position: absolute;
288
  bottom: 0;
@@ -368,7 +331,7 @@
368
  <body>
369
  <div class="container">
370
  <header>
371
- <h1>Music Beat Detection</h1>
372
  <p class="subtitle">Upload a song, detect beats and bars, then play specific sections</p>
373
  </header>
374
 
@@ -417,15 +380,6 @@
417
  <button id="cancelButton" class="cancel-button">Cancel Processing</button>
418
  </div>
419
 
420
- <div class="controls">
421
- <button id="playButton" class="play-button" disabled>Play</button>
422
- <button id="stopButton" class="stop-button" disabled>Stop</button>
423
- </div>
424
-
425
- <div class="visualizer" id="visualizer">
426
- <div class="current-time" id="currentTime"></div>
427
- </div>
428
-
429
  <audio id="audioPlayer" class="audio-player" controls></audio>
430
  </div>
431
 
@@ -434,22 +388,14 @@
434
 
435
  <div id="results" class="results">
436
  <div class="results-grid">
437
- <div class="result-item">
438
- <div class="result-label">Song Duration</div>
439
- <div id="songDuration" class="result-value">--:--</div>
440
- </div>
441
- <div class="result-item">
442
- <div class="result-label">Detected Beats</div>
443
- <div id="beatCount" class="result-value">0</div>
444
- </div>
445
- <div class="result-item">
446
- <div class="result-label">Detected Downbeats</div>
447
- <div id="downbeatCount" class="result-value">0</div>
448
- </div>
449
  <div class="result-item">
450
  <div class="result-label">Estimated BPM</div>
451
  <div id="estimatedBPM" class="result-value">--</div>
452
  </div>
 
 
 
 
453
  </div>
454
 
455
  <h3 style="margin-top: 20px;">Bar Detection Parameters</h3>
@@ -563,12 +509,6 @@
563
  audioData[i] = (channel0[i] + channel1[i]) / 2;
564
  }
565
 
566
- // Output first 10 values
567
- console.log('First 10 values of audioData:');
568
- for (let i = 0; i < 10 && i < audioData.length; i++) {
569
- console.log(`[${i}]: ${audioData[i]}`);
570
- }
571
-
572
  // Use the ONNX model to compute log mel spectrogram
573
  return await this.computeLogMelSpectrogramONNX(audioData);
574
  }
@@ -780,11 +720,20 @@
780
 
781
  console.log(`Server postprocessing results: ${result.bars ? Object.keys(result.bars).length : 0} bars`);
782
 
783
- return result.bars || {};
 
 
 
 
 
784
  } catch (error) {
785
  console.error("Error in server postprocessing:", error);
786
  // Return empty object as fallback
787
- return {};
 
 
 
 
788
  }
789
  }
790
 
@@ -875,6 +824,8 @@
875
  this.audioContext = null;
876
  this.logits = null;
877
  this.bars = null;
 
 
878
  this.isLooping = false;
879
  this.audioSource = null;
880
  this.isPlayingSection = false;
@@ -964,8 +915,6 @@
964
  const uploadArea = document.getElementById('uploadArea');
965
  const fileInput = document.getElementById('audioFile');
966
  const cancelButton = document.getElementById('cancelButton');
967
- const playButton = document.getElementById('playButton');
968
- const stopButton = document.getElementById('stopButton');
969
  const calculateBarsButton = document.getElementById('calculateBars');
970
  const playSectionButton = document.getElementById('playSection');
971
  const audioPlayer = document.getElementById('audioPlayer');
@@ -996,13 +945,10 @@
996
 
997
  cancelButton.addEventListener('click', () => this.cancelProcessing());
998
 
999
- playButton.addEventListener('click', () => this.playAudio());
1000
- stopButton.addEventListener('click', () => this.stopAudio());
1001
-
1002
  calculateBarsButton.addEventListener('click', () => this.calculateBars());
1003
  playSectionButton.addEventListener('click', () => this.playSection());
1004
 
1005
- audioPlayer.addEventListener('timeupdate', () => this.updateVisualizer());
1006
  audioPlayer.addEventListener('ended', () => this.handleAudioEnded());
1007
 
1008
  loopSectionCheckbox.addEventListener('change', (e) => {
@@ -1089,10 +1035,14 @@
1089
  updateProgress
1090
  );
1091
 
 
 
 
 
1092
  // Update UI with results
1093
  this.updateResultsUI();
1094
 
1095
- // Enable the calculate bars button
1096
  calculateBarsButton.disabled = false;
1097
 
1098
  // Show results section
@@ -1125,25 +1075,41 @@
1125
  }
1126
  }
1127
 
1128
- updateResultsUI() {
1129
- // Update song duration
1130
- const songDuration = document.getElementById('songDuration');
1131
- songDuration.textContent = this.formatTime(this.audioBuffer.duration);
1132
 
1133
- // Update beat and downbeat counts
1134
- const beatCount = document.getElementById('beatCount');
1135
- const downbeatCount = document.getElementById('downbeatCount');
 
 
 
 
 
1136
 
1137
- // Count beats and downbeats (you might need to adjust this based on your logits)
1138
- const beats = this.logits.prediction_beat.filter(val => val > 1).length;
1139
- const downbeats = this.logits.prediction_downbeat.filter(val => val > 1).length;
1140
 
1141
- beatCount.textContent = beats;
1142
- downbeatCount.textContent = downbeats;
 
 
 
 
 
 
 
 
 
 
 
1143
 
1144
- // Enable play button
1145
- document.getElementById('playButton').disabled = false;
1146
- document.getElementById('stopButton').disabled = false;
1147
 
1148
  // Set up audio player
1149
  const audioPlayer = document.getElementById('audioPlayer');
@@ -1151,8 +1117,12 @@
1151
  const url = URL.createObjectURL(blob);
1152
  audioPlayer.src = url;
1153
 
1154
- // Update visualizer
1155
- this.updateVisualizer();
 
 
 
 
1156
  }
1157
 
1158
  async calculateBars() {
@@ -1170,7 +1140,7 @@
1170
  calculateBarsButton.textContent = 'Calculating...';
1171
 
1172
  try {
1173
- this.bars = await this.detector.logits_to_bars(
1174
  this.logits.prediction_beat,
1175
  this.logits.prediction_downbeat,
1176
  minBPM,
@@ -1178,8 +1148,19 @@
1178
  beatsPerBar
1179
  );
1180
 
 
 
 
 
1181
  this.displayBars();
1182
 
 
 
 
 
 
 
 
1183
  // Show bars results section
1184
  document.getElementById('barsResults').style.display = 'block';
1185
 
@@ -1276,15 +1257,8 @@
1276
  }
1277
  }
1278
 
1279
- updateVisualizer() {
1280
  const audioPlayer = document.getElementById('audioPlayer');
1281
- const currentTime = document.getElementById('currentTime');
1282
- const visualizer = document.getElementById('visualizer');
1283
-
1284
- if (audioPlayer.duration > 0) {
1285
- const progress = (audioPlayer.currentTime / audioPlayer.duration) * 100;
1286
- currentTime.style.left = `${progress}%`;
1287
- }
1288
 
1289
  // If playing a section and reached the end, loop if enabled
1290
  if (this.isPlayingSection && audioPlayer.currentTime >= this.sectionEndTime) {
 
3
  <head>
4
  <meta charset="UTF-8">
5
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Loop Maestro</title>
7
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
8
  <style>
9
  * {
 
166
  gap: 10px;
167
  margin-top: 15px;
168
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  .form-group {
170
  margin-bottom: 15px;
171
  }
 
246
  border-bottom: none;
247
  }
248
 
 
 
 
 
 
 
 
 
 
249
  .beat-marker {
250
  position: absolute;
251
  bottom: 0;
 
331
  <body>
332
  <div class="container">
333
  <header>
334
+ <h1>Loop Maestro</h1>
335
  <p class="subtitle">Upload a song, detect beats and bars, then play specific sections</p>
336
  </header>
337
 
 
380
  <button id="cancelButton" class="cancel-button">Cancel Processing</button>
381
  </div>
382
 
 
 
 
 
 
 
 
 
 
383
  <audio id="audioPlayer" class="audio-player" controls></audio>
384
  </div>
385
 
 
388
 
389
  <div id="results" class="results">
390
  <div class="results-grid">
 
 
 
 
 
 
 
 
 
 
 
 
391
  <div class="result-item">
392
  <div class="result-label">Estimated BPM</div>
393
  <div id="estimatedBPM" class="result-value">--</div>
394
  </div>
395
+ <div class="result-item">
396
+ <div class="result-label">Detected Beats Per Bar</div>
397
+ <div id="detectedBeatsPerBar" class="result-value">--</div>
398
+ </div>
399
  </div>
400
 
401
  <h3 style="margin-top: 20px;">Bar Detection Parameters</h3>
 
509
  audioData[i] = (channel0[i] + channel1[i]) / 2;
510
  }
511
 
 
 
 
 
 
 
512
  // Use the ONNX model to compute log mel spectrogram
513
  return await this.computeLogMelSpectrogramONNX(audioData);
514
  }
 
720
 
721
  console.log(`Server postprocessing results: ${result.bars ? Object.keys(result.bars).length : 0} bars`);
722
 
723
+ // Return bars along with estimated_bpm and detected_beats_per_bar
724
+ return {
725
+ bars: result.bars || {},
726
+ estimated_bpm: result.estimated_bpm || null,
727
+ detected_beats_per_bar: result.detected_beats_per_bar || null
728
+ };
729
  } catch (error) {
730
  console.error("Error in server postprocessing:", error);
731
  // Return empty object as fallback
732
+ return {
733
+ bars: {},
734
+ estimated_bpm: null,
735
+ detected_beats_per_bar: null
736
+ };
737
  }
738
  }
739
 
 
824
  this.audioContext = null;
825
  this.logits = null;
826
  this.bars = null;
827
+ this.estimatedBPM = null;
828
+ this.detectedBeatsPerBar = null;
829
  this.isLooping = false;
830
  this.audioSource = null;
831
  this.isPlayingSection = false;
 
915
  const uploadArea = document.getElementById('uploadArea');
916
  const fileInput = document.getElementById('audioFile');
917
  const cancelButton = document.getElementById('cancelButton');
 
 
918
  const calculateBarsButton = document.getElementById('calculateBars');
919
  const playSectionButton = document.getElementById('playSection');
920
  const audioPlayer = document.getElementById('audioPlayer');
 
945
 
946
  cancelButton.addEventListener('click', () => this.cancelProcessing());
947
 
 
 
 
948
  calculateBarsButton.addEventListener('click', () => this.calculateBars());
949
  playSectionButton.addEventListener('click', () => this.playSection());
950
 
951
+ audioPlayer.addEventListener('timeupdate', () => this.checkLoop());
952
  audioPlayer.addEventListener('ended', () => this.handleAudioEnded());
953
 
954
  loopSectionCheckbox.addEventListener('change', (e) => {
 
1035
  updateProgress
1036
  );
1037
 
1038
+ // Automatically run logits_to_bars after preprocessing is complete
1039
+ await updateProgress(95, 'Running bar detection...');
1040
+ await this.runAutomaticBarDetection();
1041
+
1042
  // Update UI with results
1043
  this.updateResultsUI();
1044
 
1045
+ // Enable the calculate bars button for manual recalculation
1046
  calculateBarsButton.disabled = false;
1047
 
1048
  // Show results section
 
1075
  }
1076
  }
1077
 
1078
+ async runAutomaticBarDetection() {
1079
+ const minBPM = parseFloat(document.getElementById('minBPM').value);
1080
+ const maxBPM = parseFloat(document.getElementById('maxBPM').value);
1081
+ const beatsPerBar = parseInt(document.getElementById('beatsPerBar').value);
1082
 
1083
+ try {
1084
+ const result = await this.detector.logits_to_bars(
1085
+ this.logits.prediction_beat,
1086
+ this.logits.prediction_downbeat,
1087
+ minBPM,
1088
+ maxBPM,
1089
+ beatsPerBar
1090
+ );
1091
 
1092
+ this.bars = result.bars;
1093
+ this.estimatedBPM = result.estimated_bpm;
1094
+ this.detectedBeatsPerBar = result.detected_beats_per_bar;
1095
 
1096
+ console.log(`Automatic bar detection: BPM=${this.estimatedBPM}, BeatsPerBar=${this.detectedBeatsPerBar}`);
1097
+ } catch (error) {
1098
+ console.error('Error in automatic bar detection:', error);
1099
+ // Set default values if detection fails
1100
+ this.estimatedBPM = null;
1101
+ this.detectedBeatsPerBar = null;
1102
+ }
1103
+ }
1104
+
1105
+ updateResultsUI() {
1106
+ // Update estimated BPM
1107
+ const estimatedBPM = document.getElementById('estimatedBPM');
1108
+ estimatedBPM.textContent = this.estimatedBPM !== null ? this.estimatedBPM.toFixed(1) : '--';
1109
 
1110
+ // Update detected beats per bar
1111
+ const detectedBeatsPerBar = document.getElementById('detectedBeatsPerBar');
1112
+ detectedBeatsPerBar.textContent = this.detectedBeatsPerBar !== null ? this.detectedBeatsPerBar : '--';
1113
 
1114
  // Set up audio player
1115
  const audioPlayer = document.getElementById('audioPlayer');
 
1117
  const url = URL.createObjectURL(blob);
1118
  audioPlayer.src = url;
1119
 
1120
+ // Display bars if available
1121
+ if (this.bars && Object.keys(this.bars).length > 0) {
1122
+ this.displayBars();
1123
+ document.getElementById('barsResults').style.display = 'block';
1124
+ }
1125
+
1126
  }
1127
 
1128
  async calculateBars() {
 
1140
  calculateBarsButton.textContent = 'Calculating...';
1141
 
1142
  try {
1143
+ const result = await this.detector.logits_to_bars(
1144
  this.logits.prediction_beat,
1145
  this.logits.prediction_downbeat,
1146
  minBPM,
 
1148
  beatsPerBar
1149
  );
1150
 
1151
+ this.bars = result.bars;
1152
+ this.estimatedBPM = result.estimated_bpm;
1153
+ this.detectedBeatsPerBar = result.detected_beats_per_bar;
1154
+
1155
  this.displayBars();
1156
 
1157
+ // Update the displayed BPM and beats per bar
1158
+ const estimatedBPM = document.getElementById('estimatedBPM');
1159
+ estimatedBPM.textContent = this.estimatedBPM !== null ? this.estimatedBPM.toFixed(1) : '--';
1160
+
1161
+ const detectedBeatsPerBar = document.getElementById('detectedBeatsPerBar');
1162
+ detectedBeatsPerBar.textContent = this.detectedBeatsPerBar !== null ? this.detectedBeatsPerBar : '--';
1163
+
1164
  // Show bars results section
1165
  document.getElementById('barsResults').style.display = 'block';
1166
 
 
1257
  }
1258
  }
1259
 
1260
+ checkLoop() {
1261
  const audioPlayer = document.getElementById('audioPlayer');
 
 
 
 
 
 
 
1262
 
1263
  // If playing a section and reached the end, loop if enabled
1264
  if (this.isPlayingSection && audioPlayer.currentTime >= this.sectionEndTime) {