Spaces:
Sleeping
Sleeping
/**
 * Beat and downbeat detection driven by two ONNX models (a log-mel
 * spectrogram frontend and a beat-tracking network), with Pyodide- or
 * server-based post-processing of the raw logits into bars.
 */
export default class BeatDetector {
  constructor() {
    this.beatSession = null;  // ort.InferenceSession for the beat model
    this.melSession = null;   // ort.InferenceSession for the log-mel model
    this.sampleRate = 22050;  // sample rate the models expect (Hz) — TODO confirm resampling happens upstream
    this.chunkSize = 1500;    // spectrogram frames per inference chunk
    this.borderSize = 6;      // frames trimmed from each chunk border when aggregating
    this.serverUrl = ''; // Update this to your server URL
    this.isCancelled = false;
    this.pyodide = null;
  }

  /**
   * Load the Pyodide helpers and both ONNX models.
   * @param {object} pyodide - An already-created Pyodide instance.
   * @param {?function(number, string): Promise<void>} progressCallback -
   *   Optional async progress reporter (percent, message).
   * @returns {Promise<boolean>} true on success, false on any load failure.
   */
  async init(pyodide, progressCallback = null) {
    this.pyodide = pyodide;
    try {
      if (progressCallback) await progressCallback(10, "Loading beat detection model...");
      await this.initializePyodide();
      if (progressCallback) await progressCallback(25, "Loading beat detection model...");
      // Load beat and mel spectrogram models (no postprocessor needed)
      this.beatSession = await ort.InferenceSession.create(
        './beat_model.onnx',
        { executionProviders: ['wasm'] }
      );
      if (progressCallback) await progressCallback(50, "Loading spectrogram model...");
      this.melSession = await ort.InferenceSession.create(
        './log_mel_spec.onnx',
        { executionProviders: ['wasm'] }
      );
      if (progressCallback) await progressCallback(100, "Models loaded successfully!");
      console.log("ONNX models loaded successfully");
      return true;
    } catch (error) {
      console.error("Failed to load models:", error);
      if (progressCallback) await progressCallback(0, "Failed to load models");
      return false;
    }
  }

  /**
   * Load the Python packages and helper code needed for post-processing.
   * Rethrows on failure so init() can report it instead of silently
   * continuing with a broken Pyodide environment.
   */
  async initializePyodide() {
    try {
      // Load required Python packages
      await this.pyodide.loadPackage(["numpy", "micropip"]);
      // Load the custom Python code
      await this.loadPythonCode();
      console.log("Pyodide initialized successfully");
    } catch (error) {
      console.error("Error initializing Pyodide:", error);
      throw error; // surface to init() so the caller sees the failure
    }
  }

  /**
   * Fetch /logits_to_bars.py and execute it inside Pyodide, defining the
   * logits_to_bars() Python function used later. Rethrows on failure.
   */
  async loadPythonCode() {
    try {
      const response = await fetch('/logits_to_bars.py');
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const pythonCode = await response.text();
      // Executing the module body defines logits_to_bars in Pyodide globals.
      await this.pyodide.runPythonAsync(pythonCode);
      console.log('Python code loaded successfully from /logits_to_bars.py');
    } catch (error) {
      console.error('Error loading Python code:', error);
      throw error; // propagate so initializePyodide()/init() can react
    }
  }

  /** Request cancellation of an in-flight processAudio() call. */
  cancel() {
    this.isCancelled = true;
  }

  /** Clear the cancellation flag before starting a new run. */
  resetCancellation() {
    this.isCancelled = false;
  }

  /**
   * Downmix an AudioBuffer to mono and compute its log-mel spectrogram
   * via the ONNX frontend model.
   * @param {AudioBuffer} audioBuffer - Decoded audio (mono or multichannel).
   * @returns {Promise<number[][]>} spectrogram as [frames][melBands].
   */
  async preprocessAudio(audioBuffer) {
    const originalSampleRate = audioBuffer.sampleRate;
    console.info('originalSampleRate :', originalSampleRate);
    const numChannels = audioBuffer.numberOfChannels;
    const channel0 = audioBuffer.getChannelData(0);
    let audioData;
    if (numChannels === 1) {
      // Mono input: use the single channel directly. The previous code
      // unconditionally read channel 1 and crashed on mono files.
      audioData = channel0;
    } else {
      // Average all channels (identical to the old (ch0+ch1)/2 for stereo).
      audioData = new Float32Array(channel0.length);
      for (let c = 0; c < numChannels; c++) {
        const ch = audioBuffer.getChannelData(c);
        for (let i = 0; i < ch.length; i++) {
          audioData[i] += ch[i] / numChannels;
        }
      }
    }
    // Use the ONNX model to compute the log mel spectrogram
    return await this.computeLogMelSpectrogramONNX(audioData);
  }

  /**
   * Run the log-mel spectrogram ONNX model on raw mono samples.
   * @param {Float32Array} audioData - Mono PCM samples.
   * @returns {Promise<number[][]>} spectrogram as [frames][melBands].
   * @throws {Error} if the model is not initialized or inference fails.
   */
  async computeLogMelSpectrogramONNX(audioData) {
    if (!this.melSession) {
      throw new Error("Log Mel Spectrogram model not initialized");
    }
    // The model takes a flat 1-D waveform (no batch dimension).
    const inputTensor = new ort.Tensor('float32', audioData, [audioData.length]);
    try {
      const results = await this.melSession.run({
        'input': inputTensor
      });
      // The model has a single output; take it regardless of its name.
      const outputName = Object.keys(results)[0];
      const logMelOutput = results[outputName];
      // Reshape the flat [frames * mels] buffer into a 2-D array.
      const numFrames = logMelOutput.dims[0];
      const numMels = logMelOutput.dims[1];
      const spectrogram = [];
      for (let i = 0; i < numFrames; i++) {
        const frame = [];
        for (let j = 0; j < numMels; j++) {
          frame.push(logMelOutput.data[i * numMels + j]);
        }
        spectrogram.push(frame);
      }
      console.log(`Log mel spectrogram computed: ${numFrames} frames, ${numMels} mel bands`);
      return spectrogram;
    } catch (error) {
      console.error("Error computing log mel spectrogram:", error);
      throw error;
    }
  }

  /**
   * Split a spectrogram into overlapping chunks for windowed inference.
   * Mirrors the Python preprocessing: chunks overlap by 2*borderSize frames
   * and are zero-padded at the piece boundaries.
   * @param {number[][]} spectrogram - [frames][melBands].
   * @param {number} chunkSize - Frames per chunk.
   * @param {number} borderSize - Border frames discarded later.
   * @param {boolean} [avoidShortEnd=true] - Shift the last chunk backward so
   *   it is full-length instead of emitting a short tail chunk.
   * @returns {{chunks: number[][][], starts: number[]}}
   */
  splitIntoChunks(spectrogram, chunkSize, borderSize, avoidShortEnd = true) {
    const chunks = [];
    const starts = [];
    // Width of one frame; used for zero padding (was hard-coded to 128).
    const numBins = spectrogram.length > 0 ? spectrogram[0].length : 0;
    // Generate start positions similar to Python's np.arange
    const startPositions = [];
    for (let i = -borderSize; i < spectrogram.length - borderSize; i += chunkSize - 2 * borderSize) {
      startPositions.push(i);
    }
    // Adjust last start position if avoidShortEnd is true and piece is long enough
    if (avoidShortEnd && spectrogram.length > chunkSize - 2 * borderSize && startPositions.length > 0) {
      startPositions[startPositions.length - 1] = spectrogram.length - (chunkSize - borderSize);
    }
    for (const start of startPositions) {
      const chunkStart = Math.max(0, start);
      const chunkEnd = Math.min(spectrogram.length, start + chunkSize);
      const chunk = spectrogram.slice(chunkStart, chunkEnd);
      // Zero padding needed at the piece boundaries (Python's zeropad).
      const leftPad = Math.max(0, -start);
      const rightPad = Math.max(0, Math.min(borderSize, start + chunkSize - spectrogram.length));
      if (leftPad > 0 || rightPad > 0) {
        const paddedChunk = [];
        for (let i = 0; i < leftPad; i++) {
          paddedChunk.push(new Array(numBins).fill(0));
        }
        paddedChunk.push(...chunk);
        for (let i = 0; i < rightPad; i++) {
          paddedChunk.push(new Array(numBins).fill(0));
        }
        chunks.push(paddedChunk);
      } else {
        chunks.push(chunk);
      }
      starts.push(start);
    }
    return { chunks, starts };
  }

  /**
   * Full pipeline: spectrogram -> chunked beat-model inference -> aggregated
   * per-frame beat/downbeat logits.
   * @param {AudioBuffer} audioBuffer - Decoded audio.
   * @param {?function(number, string): Promise<void>} progressCallback
   * @returns {Promise<{prediction_beat: number[], prediction_downbeat: number[]}>}
   * @throws {Error} "Processing cancelled" if cancel() was called, or if the
   *   beat model is not initialized.
   */
  async processAudio(audioBuffer, progressCallback = null) {
    if (!this.beatSession) {
      throw new Error("Beat model not initialized");
    }
    this.resetCancellation();
    if (progressCallback) await progressCallback(0, "Detecting beats...");
    const spectrogram = await this.preprocessAudio(audioBuffer);
    if (this.isCancelled) throw new Error("Processing cancelled");
    const { chunks, starts } = this.splitIntoChunks(spectrogram, this.chunkSize, this.borderSize);
    const predChunks = [];
    if (progressCallback) await progressCallback(5, "Detecting beats...");
    const totalChunks = chunks.length;
    // Chunks are processed sequentially so progress and cancellation stay responsive.
    for (let i = 0; i < totalChunks; i++) {
      if (this.isCancelled) throw new Error("Processing cancelled");
      const chunk = chunks[i];
      // Batch dimension of 1; frame width taken from the chunk itself.
      const inputTensor = new ort.Tensor(
        'float32',
        this.flattenArray(chunk),
        [1, chunk.length, chunk[0].length]
      );
      const results = await this.beatSession.run({
        'input': inputTensor
      });
      predChunks.push({
        beat: Array.from(results.beat.data),
        downbeat: Array.from(results.downbeat.data)
      });
      // Map chunk progress onto the 5-95% range.
      const currentProgress = 5 + ((i + 1) / totalChunks) * 90;
      if (progressCallback) {
        await progressCallback(
          Math.floor(currentProgress),
          `Detecting beats... ${i + 1}/${totalChunks}`
        );
      }
    }
    if (this.isCancelled) throw new Error("Processing cancelled");
    if (progressCallback) await progressCallback(95, "Post-processing beats...");
    const aggregated = this.aggregatePrediction(
      predChunks,
      starts,
      spectrogram.length,
      this.chunkSize,
      this.borderSize,
      'keep_first'
    );
    if (this.isCancelled) throw new Error("Processing cancelled");
    if (progressCallback) await progressCallback(100, "Complete!");
    return {
      prediction_beat: aggregated.beat,
      prediction_downbeat: aggregated.downbeat,
    };
  }

  /**
   * Post-process logits into bars using the Pyodide-hosted logits_to_bars().
   * Logits are handed to Python as JSON strings through Pyodide globals
   * instead of being interpolated into the Python source — interpolation
   * built multi-megabyte source strings for long audio.
   * @returns {Promise<{bars: object, estimated_bpm: ?number, detected_beats_per_bar: ?number}>}
   */
  async logits_to_bars(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
    this.pyodide.globals.set('beat_logits_json', JSON.stringify(beatLogits));
    this.pyodide.globals.set('downbeat_logits_json', JSON.stringify(downbeatLogits));
    const result = await this.pyodide.runPythonAsync(`
import json
beat_logits = json.loads(beat_logits_json)
downbeat_logits = json.loads(downbeat_logits_json)
beats_per_bar = ${beats_per_bar}
min_bpm = ${min_bpm}
max_bpm = ${max_bpm}
result = logits_to_bars(beat_logits, downbeat_logits, beats_per_bar, min_bpm, max_bpm)
json.dumps(result)
`);
    console.log(result);
    const resultObj = JSON.parse(result);
    // ?? keeps legitimate falsy values (e.g. 0) instead of nulling them.
    return {
      bars: resultObj.bars ?? {},
      estimated_bpm: resultObj.estimated_bpm ?? null,
      detected_beats_per_bar: resultObj.detected_beats_per_bar ?? null
    };
  }

  /**
   * Same post-processing as logits_to_bars(), delegated to a Python server.
   * Falls back to an empty result on any network/server error.
   */
  async logits_to_bars_online(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
    try {
      const response = await fetch(`${this.serverUrl}/logits_to_bars`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          beat_logits: beatLogits,
          downbeat_logits: downbeatLogits,
          min_bpm: min_bpm,
          max_bpm: max_bpm,
          beats_per_bar: beats_per_bar,
        })
      });
      if (!response.ok) {
        throw new Error(`Server returned ${response.status}: ${response.statusText}`);
      }
      const result = await response.json();
      if (result.error) {
        throw new Error(`Server error: ${result.error}`);
      }
      console.log(`Server postprocessing results: ${result.bars ? Object.keys(result.bars).length : 0} bars`);
      // ?? keeps legitimate falsy values (e.g. 0) instead of nulling them.
      return {
        bars: result.bars ?? {},
        estimated_bpm: result.estimated_bpm ?? null,
        detected_beats_per_bar: result.detected_beats_per_bar ?? null
      };
    } catch (error) {
      console.error("Error in server postprocessing:", error);
      // Deliberate best-effort fallback: caller gets an empty result.
      return {
        bars: {},
        estimated_bpm: null,
        detected_beats_per_bar: null
      };
    }
  }

  /**
   * Ping the post-processing server's /health endpoint.
   * @returns {Promise<boolean>} true when the server answered 2xx.
   */
  async checkServerStatus() {
    try {
      const response = await fetch(`${this.serverUrl}/health`, {
        method: 'GET',
        headers: {
          'Content-Type': 'application/json',
        }
      });
      return response.ok;
    } catch (error) {
      console.warn("Server health check failed:", error);
      return false;
    }
  }

  /**
   * Stitch per-chunk predictions back into piece-length logit arrays.
   * Borders are trimmed from each chunk; with overlapMode 'keep_first',
   * chunks are written in reverse order so earlier chunks win in overlaps.
   * @param {{beat: number[], downbeat: number[]}[]} predChunks
   * @param {number[]} starts - Chunk start frames (may be negative).
   * @param {number} fullSize - Total frames in the piece.
   * @param {number} chunkSize
   * @param {number} borderSize
   * @param {string} overlapMode - Only 'keep_first' changes behavior.
   * @returns {{beat: number[], downbeat: number[]}}
   */
  aggregatePrediction(predChunks, starts, fullSize, chunkSize, borderSize, overlapMode) {
    let processedChunks = predChunks;
    if (borderSize > 0) {
      processedChunks = predChunks.map(pchunk => ({
        beat: pchunk.beat.slice(borderSize, -borderSize),
        downbeat: pchunk.downbeat.slice(borderSize, -borderSize)
      }));
    }
    // -1000.0 acts as "no prediction" filler (matches the Python pipeline).
    const piecePredictionBeat = new Array(fullSize).fill(-1000.0);
    const piecePredictionDownbeat = new Array(fullSize).fill(-1000.0);
    let chunksToProcess = processedChunks;
    let startsToProcess = starts;
    if (overlapMode === "keep_first") {
      // Reverse so earlier chunks overwrite later ones in overlap regions.
      chunksToProcess = [...processedChunks].reverse();
      startsToProcess = [...starts].reverse();
    }
    for (let i = 0; i < chunksToProcess.length; i++) {
      const start = startsToProcess[i];
      const pchunk = chunksToProcess[i];
      const effectiveStart = start + borderSize;
      for (let j = 0; j < pchunk.beat.length; j++) {
        const pos = effectiveStart + j;
        if (pos < fullSize) {
          piecePredictionBeat[pos] = pchunk.beat[j];
          piecePredictionDownbeat[pos] = pchunk.downbeat[j];
        }
      }
    }
    return {
      beat: piecePredictionBeat,
      downbeat: piecePredictionDownbeat
    };
  }

  /**
   * Flatten a rectangular 2-D array into a Float32Array (row-major),
   * ready to back a 'float32' ort.Tensor.
   * @param {number[][]} arr
   * @returns {Float32Array}
   */
  flattenArray(arr) {
    const rows = arr.length;
    const cols = rows > 0 ? arr[0].length : 0;
    const flat = new Float32Array(rows * cols);
    let k = 0;
    for (const row of arr) {
      for (const v of row) {
        flat[k++] = v;
      }
    }
    return flat;
  }
}