/**
 * Browser-side beat / downbeat detector.
 *
 * Pipeline implemented by this class:
 *  1. Mix the decoded audio down to mono and run `log_mel_spec.onnx` to get a
 *     log-mel spectrogram (an array of frames, each an array of mel values).
 *  2. Split the spectrogram into overlapping chunks and run `beat_model.onnx`
 *     on each chunk.
 *  3. Stitch the per-chunk beat/downbeat logits back into full-length arrays.
 *  4. Optionally turn the logits into bars, either in-browser through Pyodide
 *     (`logits_to_bars`, backed by `/logits_to_bars.py`) or through a remote
 *     server (`logits_to_bars_online`).
 *
 * Expects the onnxruntime-web global `ort` to be available when methods run.
 */
export default class BeatDetector {
  constructor() {
    this.beatSession = null;
    this.melSession = null;
    // Expected audio sample rate. This class does no resampling; the value is
    // only compared against the input rate in preprocessAudio to warn on mismatch.
    this.sampleRate = 22050;
    // Spectrogram frames fed to the beat model per inference call.
    this.chunkSize = 1500;
    // Frames at each chunk edge whose predictions are discarded when stitching.
    this.borderSize = 6;
    this.serverUrl = ''; // Update this to your server URL
    this.isCancelled = false;
    this.pyodide = null;
  }

  /**
   * Prepare Pyodide and create both ONNX inference sessions.
   *
   * @param {object} pyodide - Loaded Pyodide instance (used by logits_to_bars).
   * @param {?Function} progressCallback - Optional `(percent, message)` reporter;
   *   awaited at each step.
   * @returns {Promise<boolean>} true when both ONNX sessions were created, false
   *   if creating them failed. Pyodide setup problems are only logged by
   *   initializePyodide and do not affect this return value.
   */
  async init(pyodide, progressCallback = null) {
    this.pyodide = pyodide;
    try {
      if (progressCallback) await progressCallback(10, "Loading beat detection model...");
      await this.initializePyodide();

      if (progressCallback) await progressCallback(25, "Loading beat detection model...");
      // Load beat and mel spectrogram models (no postprocessor needed)
      this.beatSession = await ort.InferenceSession.create(
        './beat_model.onnx',
        { executionProviders: ['wasm'] },
      );

      if (progressCallback) await progressCallback(50, "Loading spectrogram model...");
      this.melSession = await ort.InferenceSession.create(
        './log_mel_spec.onnx',
        { executionProviders: ['wasm'] },
      );

      if (progressCallback) await progressCallback(100, "Models loaded successfully!");
      console.log("ONNX models loaded successfully");
      return true;
    } catch (error) {
      console.error("Failed to load models:", error);
      if (progressCallback) await progressCallback(0, "Failed to load models");
      return false;
    }
  }

  /**
   * Load the Python packages and the `/logits_to_bars.py` helper into Pyodide.
   * Best-effort: errors are logged, never thrown.
   *
   * @returns {Promise<boolean>} true if the packages and the Python code loaded.
   */
  async initializePyodide() {
    try {
      // Load required Python packages
      await this.pyodide.loadPackage(["numpy", "micropip"]);
      // Load the custom Python code (logs its own errors)
      const loaded = await this.loadPythonCode();
      if (loaded) {
        console.log("Pyodide initialized successfully");
      }
      return loaded;
    } catch (error) {
      console.error("Error initializing Pyodide:", error);
      return false;
    }
  }

  /**
   * Fetch `/logits_to_bars.py` and execute it in the Pyodide interpreter, so
   * that its definitions (e.g. `logits_to_bars`) become available to Python.
   * Best-effort: errors are logged, never thrown.
   *
   * @returns {Promise<boolean>} true if the code was fetched and executed.
   */
  async loadPythonCode() {
    try {
      const response = await fetch('/logits_to_bars.py');
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const pythonCode = await response.text();
      await this.pyodide.runPythonAsync(pythonCode);
      console.log('Python code loaded successfully from /logits_to_bars.py');
      return true;
    } catch (error) {
      console.error('Error loading Python code:', error);
      return false;
    }
  }

  /** Request cancellation; processAudio checks this flag between steps. */
  cancel() {
    this.isCancelled = true;
  }

  /** Clear a previous cancellation request. */
  resetCancellation() {
    this.isCancelled = false;
  }

  /**
   * Mix an AudioBuffer down to mono and compute its log-mel spectrogram with
   * the log_mel_spec.onnx model.
   *
   * Works for any channel count (mono, stereo, multichannel): each output
   * sample is the mean of all channels at that index.
   *
   * @param {AudioBuffer} audioBuffer - Decoded audio.
   * @returns {Promise<number[][]>} Spectrogram as an array of frames.
   */
  async preprocessAudio(audioBuffer) {
    const originalSampleRate = audioBuffer.sampleRate;
    console.info('originalSampleRate :', originalSampleRate);
    if (originalSampleRate !== this.sampleRate) {
      // No resampling happens here; surface the mismatch instead of failing silently.
      console.warn(
        `Audio sample rate (${originalSampleRate} Hz) differs from BeatDetector.sampleRate (${this.sampleRate} Hz); audio is not resampled`,
      );
    }

    // Previously channel 1 was always read, which throws for mono buffers.
    const numChannels = audioBuffer.numberOfChannels || 1;
    const channels = [];
    for (let c = 0; c < numChannels; c++) {
      channels.push(audioBuffer.getChannelData(c));
    }

    const numSamples = channels[0].length;
    const audioData = new Float32Array(numSamples);
    for (let i = 0; i < numSamples; i++) {
      // Start from channel 0 so stereo gives exactly (c0 + c1) / 2.
      let sum = channels[0][i];
      for (let c = 1; c < numChannels; c++) {
        sum += channels[c][i];
      }
      audioData[i] = sum / numChannels;
    }

    // Use the ONNX model to compute log mel spectrogram
    return await this.computeLogMelSpectrogramONNX(audioData);
  }

  /**
   * Run the mel-spectrogram ONNX model on mono audio.
   *
   * @param {Float32Array} audioData - Mono samples, fed without a batch dimension.
   * @returns {Promise<number[][]>} The model's first output, read as a
   *   row-major [numFrames, numMels] matrix and converted to nested arrays.
   * @throws {Error} If the model is not initialized, or rethrows inference errors.
   */
  async computeLogMelSpectrogramONNX(audioData) {
    if (!this.melSession) {
      throw new Error("Log Mel Spectrogram model not initialized");
    }

    // Prepare input tensor without batch dimension
    const inputTensor = new ort.Tensor('float32', audioData, [audioData.length]);

    try {
      const results = await this.melSession.run({ 'input': inputTensor });

      // Take the first output, whatever its name.
      const outputName = Object.keys(results)[0];
      const logMelOutput = results[outputName];

      const [numFrames, numMels] = logMelOutput.dims;
      const { data } = logMelOutput;
      const spectrogram = [];
      for (let i = 0; i < numFrames; i++) {
        const offset = i * numMels;
        const frame = new Array(numMels);
        for (let j = 0; j < numMels; j++) {
          frame[j] = data[offset + j];
        }
        spectrogram.push(frame);
      }

      // Use numMels rather than spectrogram[0] so zero-frame output doesn't throw here.
      console.log(`Log mel spectrogram computed: ${spectrogram.length} frames, ${numMels} mel bands`);
      return spectrogram;
    } catch (error) {
      console.error("Error computing log mel spectrogram:", error);
      throw error;
    }
  }

  /**
   * Split a spectrogram into overlapping, zero-padded chunks.
   *
   * Start positions step by `chunkSize - 2 * borderSize`, beginning at
   * `-borderSize`. Frames outside the spectrogram are filled with zero rows as
   * wide as the spectrogram's own frames.
   *
   * @param {number[][]} spectrogram - Array of frames.
   * @param {number} chunkSize - Frames per chunk (including borders).
   * @param {number} borderSize - Overlap frames at each chunk edge.
   * @param {boolean} [avoidShortEnd=true] - If the piece is longer than one
   *   hop, move the last chunk so it ends `borderSize` frames past the piece end.
   * @returns {{chunks: number[][][], starts: number[]}} Chunks and their
   *   (possibly negative) start frame indices.
   * @throws {RangeError} If `chunkSize <= 2 * borderSize` (the hop would not
   *   advance, so the original loop never terminated).
   */
  splitIntoChunks(spectrogram, chunkSize, borderSize, avoidShortEnd = true) {
    const numFrames = spectrogram.length;
    const hop = chunkSize - 2 * borderSize;
    if (hop <= 0) {
      throw new RangeError(`chunkSize (${chunkSize}) must exceed 2 * borderSize (${2 * borderSize})`);
    }

    // Padding rows match the width of the real frames instead of a fixed 128.
    const numMels = numFrames > 0 ? spectrogram[0].length : 0;
    const zeroRows = (count) => Array.from({ length: count }, () => new Array(numMels).fill(0));

    // Generate start positions similar to Python's np.arange
    const starts = [];
    for (let i = -borderSize; i < numFrames - borderSize; i += hop) {
      starts.push(i);
    }

    // Adjust last start position if avoidShortEnd is true and piece is long enough
    if (avoidShortEnd && numFrames > hop && starts.length > 0) {
      starts[starts.length - 1] = numFrames - (chunkSize - borderSize);
    }

    const chunks = starts.map((start) => {
      const chunk = spectrogram.slice(Math.max(0, start), Math.min(numFrames, start + chunkSize));
      // Padding needed (similar to Python's zeropad)
      const leftPad = Math.max(0, -start);
      const rightPad = Math.max(0, Math.min(borderSize, start + chunkSize - numFrames));
      if (leftPad === 0 && rightPad === 0) {
        return chunk;
      }
      return [...zeroRows(leftPad), ...chunk, ...zeroRows(rightPad)];
    });

    return { chunks, starts };
  }

  /**
   * Full detection pass: spectrogram -> chunked beat-model inference -> stitched logits.
   *
   * @param {AudioBuffer} audioBuffer - Decoded audio.
   * @param {?Function} progressCallback - Optional `(percent, message)` reporter.
   * @returns {Promise<{prediction_beat: number[], prediction_downbeat: number[]}>}
   *   One beat and one downbeat value per spectrogram frame.
   * @throws {Error} "Beat model not initialized", or "Processing cancelled" after cancel().
   */
  async processAudio(audioBuffer, progressCallback = null) {
    if (!this.beatSession) {
      throw new Error("Beat model not initialized");
    }

    // Reset cancellation flag
    this.resetCancellation();

    if (progressCallback) await progressCallback(0, "Detecting beats...");
    const spectrogram = await this.preprocessAudio(audioBuffer);
    if (this.isCancelled) throw new Error("Processing cancelled");

    const { chunks, starts } = this.splitIntoChunks(spectrogram, this.chunkSize, this.borderSize);
    const predChunks = [];

    if (progressCallback) await progressCallback(5, "Detecting beats...");

    const totalChunks = chunks.length;
    for (let i = 0; i < totalChunks; i++) {
      if (this.isCancelled) throw new Error("Processing cancelled");

      const chunk = chunks[i];
      // Shape [1, frames, mel bands]; the mel width is taken from the data.
      const inputTensor = new ort.Tensor(
        'float32',
        this.flattenArray(chunk),
        [1, chunk.length, chunk[0].length],
      );

      const results = await this.beatSession.run({ 'input': inputTensor });
      predChunks.push({
        beat: Array.from(results.beat.data),
        downbeat: Array.from(results.downbeat.data),
      });

      // Chunk inference spans the 5%..95% progress range.
      const currentProgress = 5 + ((i + 1) / totalChunks) * 90;
      if (progressCallback) {
        await progressCallback(
          Math.floor(currentProgress),
          `Detecting beats... ${i + 1}/${totalChunks}`,
        );
      }
    }

    if (this.isCancelled) throw new Error("Processing cancelled");
    if (progressCallback) await progressCallback(95, "Post-processing beats...");

    const aggregated = this.aggregatePrediction(
      predChunks,
      starts,
      spectrogram.length,
      this.chunkSize,
      this.borderSize,
      'keep_first',
    );

    if (this.isCancelled) throw new Error("Processing cancelled");
    if (progressCallback) await progressCallback(100, "Complete!");

    return {
      prediction_beat: aggregated.beat,
      prediction_downbeat: aggregated.downbeat,
    };
  }

  /**
   * Convert logits to bars with the `logits_to_bars` Python function, which
   * must already be defined in the Pyodide namespace (see loadPythonCode).
   *
   * @param {ArrayLike<number>} beatLogits - Per-frame beat logits.
   * @param {ArrayLike<number>} downbeatLogits - Per-frame downbeat logits.
   * @param {*} min_bpm - Interpolated verbatim into the Python source.
   * @param {*} max_bpm - Interpolated verbatim into the Python source.
   * @param {*} beats_per_bar - Interpolated verbatim into the Python source.
   * @returns {Promise<{bars: object, estimated_bpm: ?number, detected_beats_per_bar: ?number}>}
   */
  async logits_to_bars(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
    // Pass the (potentially long) logit arrays as one JSON string through
    // pyodide.globals instead of splicing them into the Python source.
    // json.loads is generally cheaper than compiling a huge list literal, and
    // Array.from makes typed arrays serialize as lists rather than {"0": ...}.
    this.pyodide.globals.set(
      '_beat_detector_logits_json',
      JSON.stringify({
        beat_logits: Array.from(beatLogits),
        downbeat_logits: Array.from(downbeatLogits),
      }),
    );

    // Lines are joined explicitly so the Python code has no leading indentation.
    // Scalars are still interpolated as Python source, exactly as before.
    const pythonCode = [
      'import json',
      '_beat_detector_args = json.loads(_beat_detector_logits_json)',
      'del _beat_detector_logits_json',
      'beat_logits = _beat_detector_args["beat_logits"]',
      'downbeat_logits = _beat_detector_args["downbeat_logits"]',
      'del _beat_detector_args',
      `beats_per_bar = ${beats_per_bar}`,
      `min_bpm = ${min_bpm}`,
      `max_bpm = ${max_bpm}`,
      'result = logits_to_bars(beat_logits, downbeat_logits, beats_per_bar, min_bpm, max_bpm)',
      'json.dumps(result)',
    ].join('\n');

    const result = await this.pyodide.runPythonAsync(pythonCode);
    console.log(result);

    const resultObj = JSON.parse(result);
    return {
      bars: resultObj.bars || {},
      estimated_bpm: resultObj.estimated_bpm || null,
      detected_beats_per_bar: resultObj.detected_beats_per_bar || null,
    };
  }

  /**
   * Convert logits to bars by POSTing them to `${serverUrl}/logits_to_bars`.
   * Never throws: on any failure it logs and returns empty results.
   *
   * @returns {Promise<{bars: object, estimated_bpm: ?number, detected_beats_per_bar: ?number}>}
   */
  async logits_to_bars_online(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
    try {
      const response = await fetch(`${this.serverUrl}/logits_to_bars`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          beat_logits: beatLogits,
          downbeat_logits: downbeatLogits,
          min_bpm: min_bpm,
          max_bpm: max_bpm,
          beats_per_bar: beats_per_bar,
        }),
      });

      if (!response.ok) {
        throw new Error(`Server returned ${response.status}: ${response.statusText}`);
      }

      const result = await response.json();
      if (result.error) {
        throw new Error(`Server error: ${result.error}`);
      }

      console.log(`Server postprocessing results: ${result.bars ? Object.keys(result.bars).length : 0} bars`);

      return {
        bars: result.bars || {},
        estimated_bpm: result.estimated_bpm || null,
        detected_beats_per_bar: result.detected_beats_per_bar || null,
      };
    } catch (error) {
      console.error("Error in server postprocessing:", error);
      // Return empty object as fallback
      return { bars: {}, estimated_bpm: null, detected_beats_per_bar: null };
    }
  }

  /**
   * Probe `${serverUrl}/health`.
   *
   * @returns {Promise<boolean>} true on a 2xx response, false on error.
   */
  async checkServerStatus() {
    try {
      const response = await fetch(`${this.serverUrl}/health`, {
        method: 'GET',
        headers: {
          'Content-Type': 'application/json',
        },
      });
      return response.ok;
    } catch (error) {
      console.warn("Server health check failed:", error);
      return false;
    }
  }

  /**
   * Stitch per-chunk predictions back into full-length arrays.
   *
   * `borderSize` frames are dropped from both ends of every chunk. The rest is
   * written starting at `start + borderSize`. Frames no chunk covers keep the
   * filler value -1000.0.
   *
   * @param {{beat: number[], downbeat: number[]}[]} predChunks - Per-chunk predictions.
   * @param {number[]} starts - Chunk start frames, as returned by splitIntoChunks.
   * @param {number} fullSize - Number of frames in the full piece.
   * @param {number} chunkSize - Kept for interface parity; placement depends
   *   only on start + borderSize and the trimmed chunk length.
   * @param {number} borderSize - Frames trimmed from each chunk edge.
   * @param {string} overlapMode - "keep_first": where chunks overlap, the
   *   earlier chunk wins. Any other value: the later chunk wins.
   * @returns {{beat: number[], downbeat: number[]}}
   */
  aggregatePrediction(predChunks, starts, fullSize, chunkSize, borderSize, overlapMode) {
    let processedChunks = predChunks;
    if (borderSize > 0) {
      processedChunks = predChunks.map((pchunk) => ({
        beat: pchunk.beat.slice(borderSize, -borderSize),
        downbeat: pchunk.downbeat.slice(borderSize, -borderSize),
      }));
    }

    // Initialize arrays with very low values (equivalent to -1000.0 in Python)
    const piecePredictionBeat = new Array(fullSize).fill(-1000.0);
    const piecePredictionDownbeat = new Array(fullSize).fill(-1000.0);

    let chunksToProcess = processedChunks;
    let startsToProcess = starts;
    if (overlapMode === "keep_first") {
      // Process in reverse order so earlier predictions overwrite later ones
      chunksToProcess = [...processedChunks].reverse();
      startsToProcess = [...starts].reverse();
    }

    for (let i = 0; i < chunksToProcess.length; i++) {
      const pchunk = chunksToProcess[i];
      const effectiveStart = startsToProcess[i] + borderSize;
      for (let j = 0; j < pchunk.beat.length; j++) {
        const pos = effectiveStart + j;
        if (pos < fullSize) {
          piecePredictionBeat[pos] = pchunk.beat[j];
          piecePredictionDownbeat[pos] = pchunk.downbeat[j];
        }
      }
    }

    return {
      beat: piecePredictionBeat,
      downbeat: piecePredictionDownbeat,
    };
  }

  /**
   * Flatten a 2-D array (rows may be arrays or typed arrays) into a new
   * row-major plain array.
   *
   * @param {ArrayLike<number>[]} arr
   * @returns {number[]}
   */
  flattenArray(arr) {
    const flat = [];
    for (let i = 0; i < arr.length; i++) {
      const row = arr[i];
      for (let j = 0; j < row.length; j++) {
        flat.push(row[j]);
      }
    }
    return flat;
  }
}