jorisvaneyghen commited on
Commit
ee88cac
·
1 Parent(s): b9af2f7

run python code (logits_to_bars) in browser

Browse files
Files changed (4) hide show
  1. app.py +10 -10
  2. index.html +4 -1
  3. js/BeatDetector.js +66 -2
  4. sw.js +5 -4
app.py CHANGED
@@ -7,7 +7,7 @@ from beat_this.model.postprocessor import Postprocessor
7
  from flask import Flask, request, jsonify, send_from_directory
8
  from flask_cors import CORS
9
  from madmom.features.downbeats import DBNDownBeatTrackingProcessor
10
- from statistics import median, mode
11
  from typing import List, Tuple
12
 
13
  # Add MIME types for JavaScript and WebAssembly
@@ -101,7 +101,7 @@ def postprocess_beats():
101
  return jsonify({"error": "Beat and downbeat logits must have the same length"}), 400
102
 
103
  # Process the logits to extract beat and downbeat timings
104
- beats, downbeats = process_logits(beat_logits, downbeat_logits, type='dbn',
105
  beats_per_bar=beats_per_bar,
106
  min_bpm=min_bpm,
107
  max_bpm=max_bpm)
@@ -156,14 +156,14 @@ def process_logits(beat_logits, downbeat_logits, type='minimal',
156
  Tuple of (beats, downbeats) where each is an array of timings in seconds
157
  """
158
  frames2beats = Postprocessor(type=type)
159
- # if type == 'dbn' and (beats_per_bar != [3, 4] or min_bpm != 55.0 or max_bpm != 215.0):
160
- # frames2beats.dbb = DBNDownBeatTrackingProcessor(
161
- # beats_per_bar=beats_per_bar,
162
- # min_bpm=min_bpm,
163
- # max_bpm=max_bpm,
164
- # fps=50,
165
- # transition_lambda=100,
166
- # )
167
 
168
 
169
  # Convert numpy arrays to PyTorch tensors
 
7
  from flask import Flask, request, jsonify, send_from_directory
8
  from flask_cors import CORS
9
  from madmom.features.downbeats import DBNDownBeatTrackingProcessor
10
+ from statistics import median, mode, StatisticsError
11
  from typing import List, Tuple
12
 
13
  # Add MIME types for JavaScript and WebAssembly
 
101
  return jsonify({"error": "Beat and downbeat logits must have the same length"}), 400
102
 
103
  # Process the logits to extract beat and downbeat timings
104
+ beats, downbeats = process_logits(beat_logits, downbeat_logits, type='minimal',
105
  beats_per_bar=beats_per_bar,
106
  min_bpm=min_bpm,
107
  max_bpm=max_bpm)
 
156
  Tuple of (beats, downbeats) where each is an array of timings in seconds
157
  """
158
  frames2beats = Postprocessor(type=type)
159
+ if type == 'dbn' and (beats_per_bar != [3, 4] or min_bpm != 55.0 or max_bpm != 215.0):
160
+ frames2beats.dbb = DBNDownBeatTrackingProcessor(
161
+ beats_per_bar=beats_per_bar,
162
+ min_bpm=min_bpm,
163
+ max_bpm=max_bpm,
164
+ fps=50,
165
+ transition_lambda=100,
166
+ )
167
 
168
 
169
  # Convert numpy arrays to PyTorch tensors
index.html CHANGED
@@ -11,6 +11,7 @@
11
  <link rel="manifest" href="/site.webmanifest"/>
12
  <link rel="stylesheet" href="css/styles.css"/>
13
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
 
14
  </head>
15
  <body>
16
  <div class="container">
@@ -181,8 +182,10 @@
181
  // Show initialization progress
182
  this.showInitProgress();
183
 
 
 
184
  // Initialize the detector with progress updates
185
- const success = await this.detector.init(this.updateInitProgress.bind(this));
186
 
187
  if (success) {
188
  this.hideInitProgress();
 
11
  <link rel="manifest" href="/site.webmanifest"/>
12
  <link rel="stylesheet" href="css/styles.css"/>
13
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
14
+ <script src="https://cdn.jsdelivr.net/pyodide/v0.29.0/full/pyodide.js"></script>
15
  </head>
16
  <body>
17
  <div class="container">
 
182
  // Show initialization progress
183
  this.showInitProgress();
184
 
185
+ let pyodide = await loadPyodide();
186
+
187
  // Initialize the detector with progress updates
188
+ const success = await this.detector.init(pyodide, this.updateInitProgress.bind(this));
189
 
190
  if (success) {
191
  this.hideInitProgress();
js/BeatDetector.js CHANGED
@@ -7,12 +7,18 @@ export default class BeatDetector {
7
  this.borderSize = 6;
8
  this.serverUrl = ''; // Update this to your server URL
9
  this.isCancelled = false;
 
10
  }
11
 
12
- async init(progressCallback = null) {
 
13
  try {
14
  if (progressCallback) await progressCallback(10, "Loading beat detection model...");
15
 
 
 
 
 
16
 
17
  // Load beat and mel spectrogram models (no postprocessor needed)
18
  this.beatSession = await ort.InferenceSession.create(
@@ -38,6 +44,39 @@ export default class BeatDetector {
38
  }
39
  }
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  cancel() {
42
  this.isCancelled = true;
43
  }
@@ -200,7 +239,7 @@ export default class BeatDetector {
200
  await new Promise(resolve => setTimeout(resolve, 0));
201
 
202
  // Update progress
203
- const progress = Math.floor( ((i+1) / chunks.length) * 95);
204
  if (progressCallback) await progressCallback(progress, `Detecting beats ... ${i + 1}/${chunks.length}...`);
205
 
206
  // Extract predictions
@@ -239,6 +278,31 @@ export default class BeatDetector {
239
 
240
  // Use Python server for postprocessing
241
  async logits_to_bars(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  try {
243
  const response = await fetch(`${this.serverUrl}/logits_to_bars`, {
244
  method: 'POST',
 
7
  this.borderSize = 6;
8
  this.serverUrl = ''; // Update this to your server URL
9
  this.isCancelled = false;
10
+ this.pyodide = null;
11
  }
12
 
13
+ async init(pyodide, progressCallback = null) {
14
+ this.pyodide = pyodide;
15
  try {
16
  if (progressCallback) await progressCallback(10, "Loading beat detection model...");
17
 
18
+ await this.initializePyodide()
19
+
20
+ if (progressCallback) await progressCallback(25, "Loading beat detection model...");
21
+
22
 
23
  // Load beat and mel spectrogram models (no postprocessor needed)
24
  this.beatSession = await ort.InferenceSession.create(
 
44
  }
45
  }
46
 
47
+ async initializePyodide() {
48
+ try {
49
+ // Load required Python packages
50
+ await this.pyodide.loadPackage(["numpy", "micropip"]);
51
+ // Load the custom Python code
52
+ await this.loadPythonCode();
53
+
54
+ console.log("Pyodide initialized successfully");
55
+ } catch (error) {
56
+ console.error("Error initializing Pyodide:", error);
57
+ }
58
+ }
59
+
60
+ async loadPythonCode() {
61
+ try {
62
+ // Fetch the Python code from the URL
63
+ const response = await fetch('/logits_to_bars.py');
64
+
65
+ if (!response.ok) {
66
+ throw new Error(`HTTP error! status: ${response.status}`);
67
+ }
68
+
69
+ const pythonCode = await response.text();
70
+
71
+ // Load the Python code into Pyodide
72
+ await this.pyodide.runPythonAsync(pythonCode);
73
+
74
+ console.log('Python code loaded successfully from /logits_to_bars.py');
75
+ } catch (error) {
76
+ console.error('Error loading Python code:', error);
77
+ }
78
+ }
79
+
80
  cancel() {
81
  this.isCancelled = true;
82
  }
 
239
  await new Promise(resolve => setTimeout(resolve, 0));
240
 
241
  // Update progress
242
+ const progress = Math.floor(((i + 1) / chunks.length) * 95);
243
  if (progressCallback) await progressCallback(progress, `Detecting beats ... ${i + 1}/${chunks.length}...`);
244
 
245
  // Extract predictions
 
278
 
279
  // Use Python server for postprocessing
280
  async logits_to_bars(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
281
+ // Call the Python function
282
+ const result = await this.pyodide.runPythonAsync(`
283
+ import json
284
+ beat_logits = ${JSON.stringify(beatLogits)}
285
+ downbeat_logits = ${JSON.stringify(downbeatLogits)}
286
+ beats_per_bar = ${beats_per_bar}
287
+ min_bpm = ${min_bpm}
288
+ max_bpm = ${max_bpm}
289
+
290
+ result = logits_to_bars(beat_logits, downbeat_logits, beats_per_bar, min_bpm, max_bpm)
291
+ json.dumps(result)
292
+ `);
293
+
294
+ console.log(result);
295
+
296
+ // Parse and display the result
297
+ const resultObj = JSON.parse(result);
298
+ return {
299
+ bars: resultObj.bars || {},
300
+ estimated_bpm: resultObj.estimated_bpm || null,
301
+ detected_beats_per_bar: resultObj.detected_beats_per_bar || null
302
+ };
303
+ }
304
+
305
+ async logits_to_bars_online(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
306
  try {
307
  const response = await fetch(`${this.serverUrl}/logits_to_bars`, {
308
  method: 'POST',
sw.js CHANGED
@@ -1,7 +1,8 @@
1
- // A version for your cache. Change this to force an update.
2
- const CACHE_NAME = 'my-pwa-cache-v1';
 
3
 
4
- // List of files to cache immediately during the install step.
5
  const urlsToCache = [
6
  '/',
7
  '/css/styles.css',
@@ -17,7 +18,7 @@ const urlsToCache = [
17
  '/web-app-manifest-512x512.png',
18
  '/log_mel_spec.onnx',
19
  '/beat_model.onnx',
20
- 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js'
21
  ];
22
 
23
 
 
1
+ // Use a version that you can update with each release
2
+ const APP_VERSION = '0.1.1';
3
+ const CACHE_NAME = `my-pwa-cache-${APP_VERSION}`;
4
 
5
+ // List of files to cache
6
  const urlsToCache = [
7
  '/',
8
  '/css/styles.css',
 
18
  '/web-app-manifest-512x512.png',
19
  '/log_mel_spec.onnx',
20
  '/beat_model.onnx',
21
+ '/logits_to_bars.py',
22
  ];
23
 
24