// loop_maestro / js / BeatDetector.js
// Author: jorisvaneyghen — commit a77e892 ("cache Results")
/**
 * Detects beats and downbeats in audio using two ONNX models (a log-mel
 * spectrogram frontend and a beat-tracking network), then post-processes the
 * raw logits into bars either in-browser via Pyodide or on a remote server.
 */
export default class BeatDetector {
  constructor() {
    this.beatSession = null;   // ONNX session for the beat/downbeat model
    this.melSession = null;    // ONNX session for the log-mel spectrogram model
    this.sampleRate = 22050;   // sample rate the models expect — TODO confirm resampling happens upstream
    this.chunkSize = 1500;     // spectrogram frames per inference chunk
    this.borderSize = 6;       // frames discarded at each chunk edge (overlap region)
    this.serverUrl = '';       // Update this to your server URL
    this.isCancelled = false;  // cooperative cancellation flag checked between chunks
    this.pyodide = null;       // Pyodide instance supplied to init()
  }

  /**
   * Initialize the Pyodide post-processing environment and load both ONNX models.
   * @param {object} pyodide - A ready Pyodide instance.
   * @param {?function(number, string): Promise<void>} progressCallback - Optional
   *   async callback receiving (percent, message) progress updates.
   * @returns {Promise<boolean>} true when everything loaded, false on any failure.
   */
  async init(pyodide, progressCallback = null) {
    this.pyodide = pyodide;
    try {
      if (progressCallback) await progressCallback(10, "Loading beat detection model...");
      await this.initializePyodide();
      if (progressCallback) await progressCallback(25, "Loading beat detection model...");
      // Load beat and mel spectrogram models (no postprocessor needed)
      this.beatSession = await ort.InferenceSession.create(
        './beat_model.onnx',
        { executionProviders: ['wasm'] }
      );
      if (progressCallback) await progressCallback(50, "Loading spectrogram model...");
      this.melSession = await ort.InferenceSession.create(
        './log_mel_spec.onnx',
        { executionProviders: ['wasm'] }
      );
      if (progressCallback) await progressCallback(100, "Models loaded successfully!");
      console.log("ONNX models loaded successfully");
      return true;
    } catch (error) {
      console.error("Failed to load models:", error);
      if (progressCallback) await progressCallback(0, "Failed to load models");
      return false;
    }
  }

  /**
   * Load the Python packages and helper code used for post-processing.
   * @throws Rethrows any failure so init() reports it instead of silently
   *   continuing with a broken Python pipeline (previously swallowed).
   */
  async initializePyodide() {
    try {
      // Load required Python packages
      await this.pyodide.loadPackage(["numpy", "micropip"]);
      // Load the custom Python code
      await this.loadPythonCode();
      console.log("Pyodide initialized successfully");
    } catch (error) {
      console.error("Error initializing Pyodide:", error);
      // FIX: propagate the failure; callers must not believe setup succeeded.
      throw error;
    }
  }

  /**
   * Fetch /logits_to_bars.py and evaluate it inside Pyodide so that
   * logits_to_bars() becomes available in the Python global namespace.
   * @throws Rethrows fetch/eval failures (previously swallowed).
   */
  async loadPythonCode() {
    try {
      // Fetch the Python code from the URL
      const response = await fetch('/logits_to_bars.py');
      if (!response.ok) {
        throw new Error(`HTTP error! status: ${response.status}`);
      }
      const pythonCode = await response.text();
      // Load the Python code into Pyodide
      await this.pyodide.runPythonAsync(pythonCode);
      console.log('Python code loaded successfully from /logits_to_bars.py');
    } catch (error) {
      console.error('Error loading Python code:', error);
      // FIX: propagate so initializePyodide()/init() can report the failure.
      throw error;
    }
  }

  /** Request cancellation; processAudio() checks this flag between chunks. */
  cancel() {
    this.isCancelled = true;
  }

  /** Clear the cancellation flag before starting a new run. */
  resetCancellation() {
    this.isCancelled = false;
  }

  /**
   * Downmix the AudioBuffer to mono and compute its log-mel spectrogram.
   * FIX: previously crashed on mono input by unconditionally reading channel 1.
   * @param {AudioBuffer} audioBuffer - Decoded audio (1 or more channels).
   * @returns {Promise<number[][]>} spectrogram as [frames][melBands].
   */
  async preprocessAudio(audioBuffer) {
    const originalSampleRate = audioBuffer.sampleRate;
    console.info('originalSampleRate :', originalSampleRate);
    const channel0 = audioBuffer.getChannelData(0);
    let audioData;
    if (audioBuffer.numberOfChannels > 1) {
      // Stereo (or more): average the first two channels, as before.
      const channel1 = audioBuffer.getChannelData(1);
      audioData = new Float32Array(channel0.length);
      for (let i = 0; i < channel0.length; i++) {
        audioData[i] = (channel0[i] + channel1[i]) / 2;
      }
    } else {
      // Mono: use the single channel directly (copy to avoid aliasing the buffer).
      audioData = channel0.slice();
    }
    // Use the ONNX model to compute log mel spectrogram
    return await this.computeLogMelSpectrogramONNX(audioData);
  }

  /**
   * Run the log-mel spectrogram ONNX model on raw mono samples.
   * @param {Float32Array} audioData - Mono PCM samples.
   * @returns {Promise<number[][]>} 2-D array [numFrames][numMels].
   * @throws {Error} If the mel model is not initialized or inference fails.
   */
  async computeLogMelSpectrogramONNX(audioData) {
    if (!this.melSession) {
      throw new Error("Log Mel Spectrogram model not initialized");
    }
    // Prepare input tensor without batch dimension
    const inputTensor = new ort.Tensor('float32', audioData, [audioData.length]);
    try {
      // Run inference
      const results = await this.melSession.run({
        'input': inputTensor
      });
      // Extract the log mel spectrogram from the output (first/only output)
      const outputName = Object.keys(results)[0];
      const logMelOutput = results[outputName];
      const numFrames = logMelOutput.dims[0];
      const numMels = logMelOutput.dims[1];
      // Convert the flat row-major buffer into an array of frame rows.
      const spectrogram = [];
      for (let i = 0; i < numFrames; i++) {
        spectrogram.push(Array.from(logMelOutput.data.subarray(i * numMels, (i + 1) * numMels)));
      }
      console.log(`Log mel spectrogram computed: ${spectrogram.length} frames, ${spectrogram[0].length} mel bands`);
      return spectrogram;
    } catch (error) {
      console.error("Error computing log mel spectrogram:", error);
      throw error;
    }
  }

  /**
   * Split a spectrogram into overlapping chunks for windowed inference,
   * zero-padding at the edges (mirrors the Python training pipeline).
   * @param {number[][]} spectrogram - [frames][melBands].
   * @param {number} chunkSize - Frames per chunk.
   * @param {number} borderSize - Overlap frames discarded at each edge later.
   * @param {boolean} [avoidShortEnd=true] - Shift the last chunk back so it is full-length.
   * @returns {{chunks: number[][][], starts: number[]}} chunks plus their (possibly negative) start offsets.
   */
  splitIntoChunks(spectrogram, chunkSize, borderSize, avoidShortEnd = true) {
    const chunks = [];
    const starts = [];
    // FIX: derive pad width from the data instead of hard-coding 128 mel bins.
    const numMels = spectrogram[0]?.length ?? 128;
    // Generate start positions similar to Python's np.arange
    const startPositions = [];
    for (let i = -borderSize; i < spectrogram.length - borderSize; i += chunkSize - 2 * borderSize) {
      startPositions.push(i);
    }
    // Adjust last start position if avoidShortEnd is true and piece is long enough
    if (avoidShortEnd && spectrogram.length > chunkSize - 2 * borderSize && startPositions.length > 0) {
      startPositions[startPositions.length - 1] = spectrogram.length - (chunkSize - borderSize);
    }
    // Process each start position
    for (const start of startPositions) {
      const chunkStart = Math.max(0, start);
      const chunkEnd = Math.min(spectrogram.length, start + chunkSize);
      // Extract the chunk
      const chunk = spectrogram.slice(chunkStart, chunkEnd);
      // Calculate padding needed (similar to Python's zeropad)
      const leftPad = Math.max(0, -start);
      const rightPad = Math.max(0, Math.min(borderSize, start + chunkSize - spectrogram.length));
      // Apply padding if needed
      if (leftPad > 0 || rightPad > 0) {
        const paddedChunk = [];
        for (let i = 0; i < leftPad; i++) {
          paddedChunk.push(new Array(numMels).fill(0));
        }
        paddedChunk.push(...chunk);
        for (let i = 0; i < rightPad; i++) {
          paddedChunk.push(new Array(numMels).fill(0));
        }
        chunks.push(paddedChunk);
      } else {
        chunks.push(chunk);
      }
      starts.push(start);
    }
    return { chunks, starts };
  }

  /**
   * Full pipeline: spectrogram -> chunked beat-model inference -> aggregated
   * per-frame beat/downbeat logits. Supports cooperative cancellation.
   * @param {AudioBuffer} audioBuffer - Decoded audio.
   * @param {?function(number, string): Promise<void>} progressCallback
   * @returns {Promise<{prediction_beat: number[], prediction_downbeat: number[]}>}
   * @throws {Error} "Processing cancelled" when cancel() was called, or if the
   *   beat model is not initialized.
   */
  async processAudio(audioBuffer, progressCallback = null) {
    if (!this.beatSession) {
      throw new Error("Beat model not initialized");
    }
    // Reset cancellation flag
    this.resetCancellation();
    if (progressCallback) await progressCallback(0, "Detecting beats...");
    const spectrogram = await this.preprocessAudio(audioBuffer);
    if (this.isCancelled) throw new Error("Processing cancelled");
    const { chunks, starts } = this.splitIntoChunks(spectrogram, this.chunkSize, this.borderSize);
    // Store predictions for each chunk
    const predChunks = [];
    if (progressCallback) await progressCallback(5, "Detecting beats...");
    const totalChunks = chunks.length;
    // Process each chunk sequentially (ONNX wasm session is not re-entrant)
    // with progress updates and cancellation checks between chunks.
    for (let i = 0; i < totalChunks; i++) {
      if (this.isCancelled) throw new Error("Processing cancelled");
      const chunk = chunks[i];
      // Convert to tensor format: [batch=1, frames, melBands]
      const inputTensor = new ort.Tensor(
        'float32',
        this.flattenArray(chunk),
        [1, chunk.length, 128]
      );
      // Run inference
      const results = await this.beatSession.run({
        'input': inputTensor
      });
      // Extract per-frame logits for both heads
      predChunks.push({
        beat: Array.from(results.beat.data),
        downbeat: Array.from(results.downbeat.data)
      });
      // Progress spans 5%..95% across the chunks
      const currentProgress = 5 + ((i + 1) / totalChunks) * 90;
      if (progressCallback) {
        await progressCallback(
          Math.floor(currentProgress),
          `Detecting beats... ${i + 1}/${totalChunks}`
        );
      }
    }
    if (this.isCancelled) throw new Error("Processing cancelled");
    if (progressCallback) await progressCallback(95, "Post-processing beats...");
    // Stitch chunk predictions back into a full-length timeline
    const aggregated = this.aggregatePrediction(
      predChunks,
      starts,
      spectrogram.length,
      this.chunkSize,
      this.borderSize,
      'keep_first'
    );
    if (this.isCancelled) throw new Error("Processing cancelled");
    if (progressCallback) await progressCallback(100, "Complete!");
    return {
      prediction_beat: aggregated.beat,
      prediction_downbeat: aggregated.downbeat,
    };
  }

  /**
   * Post-process logits into bars using the Python logits_to_bars() loaded
   * into Pyodide. NOTE(review): logits are injected as JSON source text;
   * non-finite floats (NaN/Infinity) would serialize to null — assumed not to
   * occur in model output, TODO confirm.
   * @returns {Promise<{bars: object, estimated_bpm: ?number, detected_beats_per_bar: ?number}>}
   */
  async logits_to_bars(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
    // Call the Python function
    const result = await this.pyodide.runPythonAsync(`
import json
beat_logits = ${JSON.stringify(beatLogits)}
downbeat_logits = ${JSON.stringify(downbeatLogits)}
beats_per_bar = ${beats_per_bar}
min_bpm = ${min_bpm}
max_bpm = ${max_bpm}
result = logits_to_bars(beat_logits, downbeat_logits, beats_per_bar, min_bpm, max_bpm)
json.dumps(result)
`);
    console.log(result);
    // Parse and normalize the result (?? so a legitimate 0 is not coerced to null)
    const resultObj = JSON.parse(result);
    return {
      bars: resultObj.bars ?? {},
      estimated_bpm: resultObj.estimated_bpm ?? null,
      detected_beats_per_bar: resultObj.detected_beats_per_bar ?? null
    };
  }

  /**
   * Same contract as logits_to_bars(), but delegates to the HTTP server at
   * this.serverUrl. Falls back to an empty result on any failure.
   */
  async logits_to_bars_online(beatLogits, downbeatLogits, min_bpm, max_bpm, beats_per_bar) {
    try {
      const response = await fetch(`${this.serverUrl}/logits_to_bars`, {
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
        },
        body: JSON.stringify({
          beat_logits: beatLogits,
          downbeat_logits: downbeatLogits,
          min_bpm: min_bpm,
          max_bpm: max_bpm,
          beats_per_bar: beats_per_bar,
        })
      });
      if (!response.ok) {
        throw new Error(`Server returned ${response.status}: ${response.statusText}`);
      }
      const result = await response.json();
      if (result.error) {
        throw new Error(`Server error: ${result.error}`);
      }
      console.log(`Server postprocessing results: ${result.bars ? Object.keys(result.bars).length : 0} bars`);
      // Return bars along with estimated_bpm and detected_beats_per_bar
      // (?? so a legitimate 0 is not coerced to null)
      return {
        bars: result.bars ?? {},
        estimated_bpm: result.estimated_bpm ?? null,
        detected_beats_per_bar: result.detected_beats_per_bar ?? null
      };
    } catch (error) {
      console.error("Error in server postprocessing:", error);
      // Best-effort fallback: empty result rather than a thrown error
      return {
        bars: {},
        estimated_bpm: null,
        detected_beats_per_bar: null
      };
    }
  }

  /**
   * @returns {Promise<boolean>} true if the post-processing server responds
   *   OK on its /health endpoint; false on any error.
   */
  async checkServerStatus() {
    try {
      const response = await fetch(`${this.serverUrl}/health`, {
        method: 'GET',
        headers: {
          'Content-Type': 'application/json',
        }
      });
      return response.ok;
    } catch (error) {
      console.warn("Server health check failed:", error);
      return false;
    }
  }

  /**
   * Merge per-chunk predictions back into full-length arrays, trimming
   * borderSize frames from each chunk edge. Positions never covered by a
   * chunk keep the sentinel value -1000.0.
   * @param {{beat: number[], downbeat: number[]}[]} predChunks
   * @param {number[]} starts - Chunk start offsets (may be negative).
   * @param {number} fullSize - Total frame count of the piece.
   * @param {string} overlapMode - 'keep_first' keeps the earliest chunk's
   *   values in overlap regions; anything else keeps the latest.
   * @returns {{beat: number[], downbeat: number[]}}
   */
  aggregatePrediction(predChunks, starts, fullSize, chunkSize, borderSize, overlapMode) {
    let processedChunks = predChunks;
    // Remove borders if borderSize > 0 (slice(b, -b) would be empty for b=0)
    if (borderSize > 0) {
      processedChunks = predChunks.map(pchunk => ({
        beat: pchunk.beat.slice(borderSize, -borderSize),
        downbeat: pchunk.downbeat.slice(borderSize, -borderSize)
      }));
    }
    // Initialize arrays with very low values (equivalent to -1000.0 in Python)
    const piecePredictionBeat = new Array(fullSize).fill(-1000.0);
    const piecePredictionDownbeat = new Array(fullSize).fill(-1000.0);
    // Prepare iteration based on overlap mode
    let chunksToProcess = processedChunks;
    let startsToProcess = starts;
    if (overlapMode === "keep_first") {
      // Process in reverse order so earlier predictions overwrite later ones
      chunksToProcess = [...processedChunks].reverse();
      startsToProcess = [...starts].reverse();
    }
    // Aggregate predictions
    for (let i = 0; i < chunksToProcess.length; i++) {
      const start = startsToProcess[i];
      const pchunk = chunksToProcess[i];
      // After trimming, the chunk's first frame lands at start + borderSize
      const effectiveStart = start + borderSize;
      for (let j = 0; j < pchunk.beat.length; j++) {
        const pos = effectiveStart + j;
        if (pos < fullSize) {
          piecePredictionBeat[pos] = pchunk.beat[j];
          piecePredictionDownbeat[pos] = pchunk.downbeat[j];
        }
      }
    }
    return {
      beat: piecePredictionBeat,
      downbeat: piecePredictionDownbeat
    };
  }

  /**
   * Flatten a 2-D array of frames into a single flat array (row-major),
   * suitable for constructing an ort.Tensor.
   * @param {number[][]} arr
   * @returns {number[]}
   */
  flattenArray(arr) {
    const flat = [];
    for (const row of arr) {
      for (const value of row) {
        flat.push(value);
      }
    }
    return flat;
  }
}