bs_polarformer / index.html
bgkb's picture
Remove enableGraphCapture (incompatible), keep time_frames override
9158719 verified
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>BS PolarFormer – Vocal Separator</title>
<style>
/* Global reset + dark theme */
* { box-sizing: border-box; margin: 0; padding: 0; }
body {
font-family: system-ui, -apple-system, sans-serif;
background: #0a0e17; color: #c8d0e0;
min-height: 100vh; display: flex; flex-direction: column;
align-items: center; padding: 2rem;
}
h1 { font-size: 1.4rem; margin-bottom: .3rem; color: #e2e8f0; }
.subtitle { font-size: .85rem; color: #64748b; margin-bottom: 1.5rem; }
/* Main panel holding upload area, controls, progress and players */
.card {
background: #141a2a; border: 1px solid #1e293b; border-radius: 12px;
padding: 1.5rem; width: 100%; max-width: 600px; margin-bottom: 1rem;
}
/* Dashed drop target; also the label for the hidden file input */
label.upload {
display: flex; flex-direction: column; align-items: center; gap: .5rem;
padding: 2rem; border: 2px dashed #334155; border-radius: 8px;
cursor: pointer; transition: border-color .2s;
}
label.upload:hover { border-color: #4f8ff7; }
label.upload svg { width: 36px; height: 36px; stroke: #4f8ff7; fill: none; }
input[type=file] { display: none; }
.file-name { font-size: .85rem; color: #94a3b8; }
/* Run + download buttons */
.controls { display: flex; gap: .75rem; margin-top: 1rem; flex-wrap: wrap; }
button {
padding: .55rem 1.2rem; border: none; border-radius: 6px;
font-size: .85rem; cursor: pointer; transition: .15s;
}
button:disabled { opacity: .4; cursor: not-allowed; }
#btn-run { background: #4f8ff7; color: #fff; }
#btn-run:hover:not(:disabled) { background: #3b7de6; }
.btn-dl { background: #1e293b; color: #c8d0e0; }
.btn-dl:hover:not(:disabled) { background: #2a3a52; }
/* Thin progress bar; hidden until the script first sets a progress value */
.progress-wrap {
margin-top: 1rem; background: #0f1520; border-radius: 6px;
overflow: hidden; height: 6px; display: none;
}
.progress-bar {
height: 100%; background: #4f8ff7; width: 0%; transition: width .3s;
}
/* Status line; pre-line preserves multi-line messages */
#status {
margin-top: .75rem; font-size: .8rem; color: #94a3b8;
min-height: 1.2rem; white-space: pre-line;
}
/* Result players; revealed only after a successful separation */
.outputs { display: none; }
.player { margin-top: .75rem; }
.player-label { font-size: .8rem; color: #64748b; margin-bottom: .25rem; }
audio { width: 100%; height: 36px; }
/* FP32/FP16 model picker */
#model-select { display: flex; gap: .75rem; align-items: center; margin-top: .75rem; font-size: .85rem; }
#model-select label { cursor: pointer; }
#model-select input { margin-right: .25rem; }
/* Warning shown when running on the WASM backend */
.warn { color: #f59e0b; font-size: .78rem; margin-top: .5rem; display: none; }
</style>
</head>
<body>
<h1>BS PolarFormer Vocal Separator</h1>
<p class="subtitle">Runs entirely in your browser via ONNX Runtime + WebGPU</p>
<div class="card">
<!-- Upload area: clicking the label opens the hidden file input; the same
     element is also wired up as a drag-and-drop target in the script -->
<label class="upload" id="drop-zone">
<svg viewBox="0 0 24 24" stroke-width="1.5"><path d="M12 16V4m0 0l-4 4m4-4l4 4"/><path d="M2 17l.621 2.485A2 2 0 004.561 21h14.878a2 2 0 001.94-1.515L22 17"/></svg>
<span>Drop an audio file or click to browse</span>
<span class="file-name" id="file-name"></span>
<input type="file" id="file-input" accept="audio/*">
</label>
<!-- Precision picker: FP16 roughly halves the model download -->
<div id="model-select">
<span>Model:</span>
<label><input type="radio" name="precision" value="fp32" checked> FP32 (201 MB)</label>
<label><input type="radio" name="precision" value="fp16"> FP16 (103 MB)</label>
</div>
<!-- Buttons stay disabled until audio is loaded / results exist -->
<div class="controls">
<button id="btn-run" disabled>Separate Vocals</button>
<button class="btn-dl" id="btn-dl-vocals" disabled>Download Vocals</button>
<button class="btn-dl" id="btn-dl-other" disabled>Download Other</button>
</div>
<div class="progress-wrap" id="progress-wrap">
<div class="progress-bar" id="progress-bar"></div>
</div>
<div id="status"></div>
<div class="warn" id="webgpu-warn">WebGPU not available or GPU has insufficient storage buffer limits — using WASM (slower).</div>
<!-- Players for the two separated stems; shown when a run completes -->
<div class="outputs" id="outputs">
<div class="player">
<div class="player-label">Vocals</div>
<audio id="audio-vocals" controls></audio>
</div>
<div class="player">
<div class="player-label">Other (instrumental)</div>
<audio id="audio-other" controls></audio>
</div>
</div>
</div>
<!-- ONNX Runtime Web, "all" bundle (filename indicates all backends) -->
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.21.0/dist/ort.all.min.js"></script>
<script>
// ── Config matching the model YAML ──────────────────────────────────────────
const SAMPLE_RATE = 44100;
const N_FFT = 2048;
const HOP = 512;
const WIN = 2048; // analysis/synthesis window length; equals N_FFT here
const N_FREQ = N_FFT / 2 + 1; // 1025
const AUDIO_CH = 2; // stereo
const CHUNK = 131072; // samples per chunk (~2.97 s at 44.1 kHz)
const OVERLAP = 2; // chunk hop = CHUNK / OVERLAP, i.e. 50% overlap
// ── ONNX model paths (downloaded from Hugging Face) ────────────────────────
const HF_BASE = 'https://huggingface.co/bgkb/bs_polarformer/resolve/main';
// WebGPU models have cascaded Split/Concat ops to fit within 8-buffer limit
const MODEL_PATHS = {
fp32: { wasm: `${HF_BASE}/bs_polarformer.onnx`, webgpu: `${HF_BASE}/bs_polarformer_webgpu.onnx` },
fp16: { wasm: `${HF_BASE}/bs_polarformer_fp16.onnx`, webgpu: `${HF_BASE}/bs_polarformer_webgpu_fp16.onnx` },
};
// ── State ───────────────────────────────────────────────────────────────────
let audioBuffer = null; // decoded AudioBuffer; null until a file is loaded
let vocalsBlob = null; // separated-vocals WAV, set after a successful run
let otherBlob = null; // instrumental WAV (original minus vocals)
// ── DOM refs ────────────────────────────────────────────────────────────────
const $ = id => document.getElementById(id); // tiny getElementById shorthand
const fileInput = $('file-input');
const fileName = $('file-name');
const btnRun = $('btn-run');
const btnDlV = $('btn-dl-vocals');
const btnDlO = $('btn-dl-other');
const statusEl = $('status');
const progressWrap = $('progress-wrap');
const progressBar = $('progress-bar');
const outputsEl = $('outputs');
const warn = $('webgpu-warn');
// ── Helpers ─────────────────────────────────────────────────────────────────
/** Show a status message below the controls. */
function setStatus(text) {
  statusEl.textContent = text;
}
/** Reveal the progress bar and set its fill to the given fraction (0..1). */
function setProgress(fraction) {
  const percent = (fraction * 100).toFixed(1);
  progressWrap.style.display = 'block';
  progressBar.style.width = `${percent}%`;
}
/**
 * Build an all-zero tensor shaped like a real model input:
 * [batch, frames, features] = [1, 253, 4100] for CHUNK=131072, N_FFT=2048, HOP=512.
 * Used as a cheap warmup probe right after session creation.
 */
function makeZeroInputTensor() {
  const frames = Math.floor((CHUNK - N_FFT) / HOP) + 1;
  const features = N_FREQ * AUDIO_CH * 2;
  const zeros = new Float32Array(frames * features);
  return new ort.Tensor('float32', zeros, [1, frames, features]);
}
/**
 * Create an ONNX Runtime session, trying the preferred provider first and
 * falling back to WASM on any failure.
 * @returns {Promise<{session: object, provider: string}>} the live session
 *          plus the provider that actually succeeded.
 */
async function createStableSession(precision, preferredProvider) {
  const candidates = preferredProvider === 'webgpu' ? ['webgpu', 'wasm'] : ['wasm'];
  let lastError = null;
  for (const provider of candidates) {
    try {
      const session = await initProvider(precision, provider);
      return { session, provider };
    } catch (e) {
      console.warn(`Failed to initialize ${provider} backend:`, e);
      lastError = e;
    }
  }
  throw lastError || new Error('Failed to create inference session.');
}
/** Configure one execution provider, create its session, and warm it up. */
async function initProvider(precision, provider) {
  const opts = {
    executionProviders: [provider],
    graphOptimizationLevel: 'all',
  };
  if (provider === 'webgpu') {
    ort.env.webgpu.powerPreference = 'high-performance';
    // Pin both free dimensions so ORT can pre-plan GPU memory layouts.
    const frames = Math.floor((CHUNK - N_FFT) / HOP) + 1;
    opts.freeDimensionOverrides = { batch: 1, time_frames: frames };
  } else {
    // WASM path: spread work across all available cores.
    ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
  }
  const session = await ort.InferenceSession.create(MODEL_PATHS[precision][provider], opts);
  // Important: some WebGPU failures only appear on the first run() (shader
  // generation), not during session creation — warm up to catch them early.
  await session.run({ stft_features: makeZeroInputTensor() });
  return session;
}
// ── Audio decode ────────────────────────────────────────────────────────────
// Fixed-rate context: decodeAudioData resamples any input to SAMPLE_RATE.
const audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
fileInput.addEventListener('change', async () => {
  const file = fileInput.files[0];
  if (!file) return;
  fileName.textContent = file.name;
  setStatus('Decoding audio...');
  try {
    const arrayBuf = await file.arrayBuffer();
    audioBuffer = await audioCtx.decodeAudioData(arrayBuf);
  } catch (e) {
    // Unsupported/corrupt file: previously this rejection was unhandled,
    // leaving the UI stuck on "Decoding audio..." (and, after a successful
    // earlier load, the run button enabled on stale audio). Clear state.
    audioBuffer = null;
    btnRun.disabled = true;
    setStatus(`Failed to decode audio: ${e?.message || e}`);
    return;
  }
  setStatus(`Loaded: ${audioBuffer.duration.toFixed(1)}s, ${audioBuffer.numberOfChannels}ch, ${audioBuffer.sampleRate}Hz`);
  btnRun.disabled = false;
});
// ── Drag & drop ─────────────────────────────────────────────────────────────
const dropZone = $('drop-zone');
/** Toggle the drop-zone border highlight. */
const highlightDropZone = active => {
  dropZone.style.borderColor = active ? '#4f8ff7' : '#334155';
};
dropZone.addEventListener('dragover', e => {
  e.preventDefault();
  highlightDropZone(true);
});
dropZone.addEventListener('dragleave', () => highlightDropZone(false));
dropZone.addEventListener('drop', e => {
  e.preventDefault();
  highlightDropZone(false);
  // Route dropped files through the same change handler as the file picker.
  if (e.dataTransfer.files.length) {
    fileInput.files = e.dataTransfer.files;
    fileInput.dispatchEvent(new Event('change'));
  }
});
// ── DSP: Hann window, STFT, iSTFT ──────────────────────────────────────────
/** Periodic Hann window of the given length (denominator len, not len-1). */
function hannWindow(len) {
  const w = new Float32Array(len);
  let i = 0;
  while (i < len) {
    w[i] = 0.5 * (1 - Math.cos(2 * Math.PI * i / len));
    i++;
  }
  return w;
}
/**
 * Radix-2 Cooley-Tukey FFT, in place, on split real/imaginary arrays
 * (`re` and `im` each hold N values; N must be a power of two).
 */
function fftInPlace(re, im, N) {
  // Reorder inputs into bit-reversed index order.
  let rev = 0;
  for (let idx = 1; idx < N; idx++) {
    let mask = N >> 1;
    while (rev & mask) {
      rev ^= mask;
      mask >>= 1;
    }
    rev ^= mask;
    if (idx < rev) {
      const tr = re[idx]; re[idx] = re[rev]; re[rev] = tr;
      const ti = im[idx]; im[idx] = im[rev]; im[rev] = ti;
    }
  }
  // Iterative butterflies, doubling the transform size each pass.
  for (let size = 2; size <= N; size <<= 1) {
    const half = size >> 1;
    const theta = -2 * Math.PI / size;
    const stepRe = Math.cos(theta), stepIm = Math.sin(theta);
    for (let base = 0; base < N; base += size) {
      // Twiddle factor, advanced by one step per butterfly.
      let twRe = 1, twIm = 0;
      for (let k = 0; k < half; k++) {
        const lo = base + k, hi = base + k + half;
        const prodRe = twRe * re[hi] - twIm * im[hi];
        const prodIm = twRe * im[hi] + twIm * re[hi];
        re[hi] = re[lo] - prodRe; im[hi] = im[lo] - prodIm;
        re[lo] += prodRe; im[lo] += prodIm;
        const nextTwRe = twRe * stepRe - twIm * stepIm;
        twIm = twRe * stepIm + twIm * stepRe;
        twRe = nextTwRe;
      }
    }
  }
}
/**
 * Forward FFT of a real signal of length N.
 * Returns the first N/2+1 bins as interleaved [re, im, re, im, ...].
 */
function rfft(x, N) {
  const re = new Float32Array(N);
  const im = new Float32Array(N);
  re.set(x);
  fftInPlace(re, im, N);
  const half = N / 2;
  const bins = new Float32Array((half + 1) * 2);
  for (let k = 0, j = 0; k <= half; k++) {
    bins[j++] = re[k];
    bins[j++] = im[k];
  }
  return bins;
}
/**
 * Inverse of rfft: takes N/2+1 interleaved complex bins, returns N real
 * samples. Uses the conjugation trick ifft(X) = conj(fft(conj(X))) / N.
 */
function irfft(spec, N) {
  const re = new Float32Array(N);
  const im = new Float32Array(N);
  const half = N / 2;
  // Conjugated lower half of the spectrum (DC through Nyquist).
  for (let k = 0; k <= half; k++) {
    re[k] = spec[k * 2];
    im[k] = -spec[k * 2 + 1];
  }
  // Mirror bins: a real signal has X[N-k] = conj(X[k]); conjugating that
  // full spectrum yields +im for the upper half.
  for (let k = 1; k < half; k++) {
    re[N - k] = spec[k * 2];
    im[N - k] = spec[k * 2 + 1];
  }
  fftInPlace(re, im, N);
  // Imaginary output is ~0 for conjugate-symmetric input; keep the real part.
  const samples = new Float32Array(N);
  for (let n = 0; n < N; n++) samples[n] = re[n] / N;
  return samples;
}
/**
 * STFT of a mono signal.
 * @param {Float32Array} signal input samples
 * @param {number} nFft FFT size (power of two)
 * @param {number} hop hop size in samples
 * @param {Float32Array} win analysis window of length nFft
 * @returns {{data: Float32Array, nFrames: number}} complex spectrogram
 *          flattened as [n_freq, n_frames, 2], plus the frame count.
 */
function stft(signal, nFft, hop, win) {
  // Fix: derive the bin count from the nFft parameter instead of the global
  // N_FREQ, so the function honors its own arguments (identical result for
  // the actual call site, where nFft === N_FFT).
  const nFreq = nFft / 2 + 1;
  const nFrames = Math.floor((signal.length - nFft) / hop) + 1;
  const out = new Float32Array(nFreq * nFrames * 2);
  const windowed = new Float32Array(nFft);
  for (let t = 0; t < nFrames; t++) {
    const off = t * hop;
    // `|| 0` treats reads past the end of the signal as silence.
    for (let i = 0; i < nFft; i++) windowed[i] = (signal[off + i] || 0) * win[i];
    const spec = rfft(windowed, nFft);
    for (let f = 0; f < nFreq; f++) {
      out[(f * nFrames + t) * 2] = spec[f * 2];
      out[(f * nFrames + t) * 2 + 1] = spec[f * 2 + 1];
    }
  }
  return { data: out, nFrames };
}
/**
 * Inverse STFT via windowed overlap-add with window-power normalization.
 * @param {Float32Array} stftData complex spectrogram, [n_freq, n_frames, 2] flattened
 * @param {number} nFrames frame count
 * @param {number} nFft FFT size (power of two)
 * @param {number} hop hop size in samples
 * @param {Float32Array} win synthesis window of length nFft
 * @param {number} length output length in samples
 * @returns {Float32Array} reconstructed signal of `length` samples
 */
function istft(stftData, nFrames, nFft, hop, win, length) {
  // Fix: derive the bin count from the nFft parameter instead of the global
  // N_FREQ, matching stft (identical result for nFft === N_FFT).
  const nFreq = nFft / 2 + 1;
  const out = new Float32Array(length);
  const winSum = new Float32Array(length);
  const spec = new Float32Array(nFreq * 2);
  for (let t = 0; t < nFrames; t++) {
    // Gather this frame's spectrum from the [freq-major] layout.
    for (let f = 0; f < nFreq; f++) {
      spec[f * 2] = stftData[(f * nFrames + t) * 2];
      spec[f * 2 + 1] = stftData[(f * nFrames + t) * 2 + 1];
    }
    const frame = irfft(spec, nFft);
    const off = t * hop;
    for (let i = 0; i < nFft && off + i < length; i++) {
      out[off + i] += frame[i] * win[i];
      winSum[off + i] += win[i] * win[i];
    }
  }
  // Normalize by accumulated squared window; skip near-zero sums (edges).
  for (let i = 0; i < length; i++) {
    if (winSum[i] > 1e-8) out[i] /= winSum[i];
  }
  return out;
}
// ── Prepare model input from stereo chunk ───────────────────────────────────
/**
 * STFT both channels and pack them into the model's feature layout.
 * Band order interleaves channels per frequency: band (f*2) = left,
 * band (f*2+1) = right, each contributing a [re, im] pair, giving a
 * (nFrames, N_FREQ*2*2) = (nFrames, 4100) row-major matrix.
 */
function prepareChunkInput(left, right, win) {
  const stftL = stft(left, N_FFT, HOP, win);
  const stftR = stft(right, N_FFT, HOP, win);
  const nFrames = stftL.nFrames;
  const featDim = N_FREQ * AUDIO_CH * 2; // 4100
  const input = new Float32Array(nFrames * featDim);
  for (let t = 0; t < nFrames; t++) {
    const rowBase = t * featDim;
    for (let f = 0; f < N_FREQ; f++) {
      const specIdx = (f * nFrames + t) * 2;
      const slot = rowBase + f * 4; // == rowBase + (f*2)*2
      input[slot] = stftL.data[specIdx];          // left re
      input[slot + 1] = stftL.data[specIdx + 1];  // left im
      input[slot + 2] = stftR.data[specIdx];      // right re ((f*2+1)*2)
      input[slot + 3] = stftR.data[specIdx + 1];  // right im
    }
  }
  return { input, nFrames, stftL, stftR };
}
// ── Apply mask and iSTFT ────────────────────────────────────────────────────
/**
 * Complex-multiply the model's mask into each channel's STFT, zero the DC
 * bin, and inverse-transform both channels.
 * @param mask flattened [1, 1, 2050, nFrames, 2]; bands interleave channels
 *             ((f*2) = left, (f*2+1) = right), matching prepareChunkInput.
 * @param stftL,stftR {data, nFrames} spectrograms laid out [n_freq, nFrames, 2]
 * @returns {{left: Float32Array, right: Float32Array}} time-domain channels
 */
function applyMaskAndReconstruct(mask, stftL, stftR, nFrames, win, length) {
  const maskedL = new Float32Array(N_FREQ * nFrames * 2);
  const maskedR = new Float32Array(N_FREQ * nFrames * 2);
  for (let f = 0; f < N_FREQ; f++) {
    for (let t = 0; t < nFrames; t++) {
      // mask indices: [0, 0, band, t, re/im]
      const mLIdx = ((f * 2) * nFrames + t) * 2;
      const mRIdx = ((f * 2 + 1) * nFrames + t) * 2;
      const mLRe = mask[mLIdx], mLIm = mask[mLIdx + 1];
      const mRRe = mask[mRIdx], mRIm = mask[mRIdx + 1];
      const sIdx = (f * nFrames + t) * 2;
      const sLRe = stftL.data[sIdx], sLIm = stftL.data[sIdx + 1];
      const sRRe = stftR.data[sIdx], sRIm = stftR.data[sIdx + 1];
      // Complex multiply spectrum by mask
      maskedL[sIdx] = sLRe * mLRe - sLIm * mLIm;
      maskedL[sIdx + 1] = sLRe * mLIm + sLIm * mLRe;
      maskedR[sIdx] = sRRe * mRRe - sRIm * mRIm;
      maskedR[sIdx + 1] = sRRe * mRIm + sRIm * mRRe;
    }
    // (A stray no-op expression statement that used to sit here was removed.)
  }
  // Zero the DC bin (f = 0) on both channels
  for (let t = 0; t < nFrames; t++) {
    maskedL[t * 2] = 0; maskedL[t * 2 + 1] = 0;
    maskedR[t * 2] = 0; maskedR[t * 2 + 1] = 0;
  }
  const reconL = istft(maskedL, nFrames, N_FFT, HOP, win, length);
  const reconR = istft(maskedR, nFrames, N_FFT, HOP, win, length);
  return { left: reconL, right: reconR };
}
// ── WAV encoding ────────────────────────────────────────────────────────────
/**
 * Encode two equal-length channels as a 16-bit PCM stereo WAV blob.
 * Samples are clamped to [-1, 1] before quantization to int16.
 */
function encodeWav(left, right, sr) {
  const frames = left.length;
  const BYTES_PER_FRAME = 4; // 2 channels x 16-bit
  const HEADER = 44;
  const buf = new ArrayBuffer(HEADER + frames * BYTES_PER_FRAME);
  const view = new DataView(buf);
  const writeStr = (off, s) => {
    for (let i = 0; i < s.length; i++) view.setUint8(off + i, s.charCodeAt(i));
  };
  // RIFF container header
  writeStr(0, 'RIFF');
  view.setUint32(4, 36 + frames * BYTES_PER_FRAME, true); // file size - 8
  writeStr(8, 'WAVE');
  // fmt chunk: PCM, stereo, 16-bit
  writeStr(12, 'fmt ');
  view.setUint32(16, 16, true);                    // fmt chunk size
  view.setUint16(20, 1, true);                     // audio format: PCM
  view.setUint16(22, 2, true);                     // channel count
  view.setUint32(24, sr, true);                    // sample rate
  view.setUint32(28, sr * BYTES_PER_FRAME, true);  // byte rate
  view.setUint16(32, BYTES_PER_FRAME, true);       // block align
  view.setUint16(34, 16, true);                    // bits per sample
  // data chunk: interleaved L/R int16 samples
  writeStr(36, 'data');
  view.setUint32(40, frames * BYTES_PER_FRAME, true);
  const clamp = v => Math.max(-1, Math.min(1, v));
  let off = HEADER;
  for (let i = 0; i < frames; i++) {
    view.setInt16(off, clamp(left[i]) * 32767, true); off += 2;
    view.setInt16(off, clamp(right[i]) * 32767, true); off += 2;
  }
  return new Blob([buf], { type: 'audio/wav' });
}
// ── Main separation pipeline ────────────────────────────────────────────────
// Click handler: chunked STFT -> ONNX mask inference -> iSTFT, with 50%
// overlap-add averaging across chunks; "other" = original - vocals.
// Fix applied: the status strings below contained mojibake ("Β·") where a
// middle dot ("·") was intended.
btnRun.addEventListener('click', async () => {
if (!audioBuffer) return;
// Lock the UI while a run is in flight.
btnRun.disabled = true;
btnDlV.disabled = true;
btnDlO.disabled = true;
outputsEl.style.display = 'none';
const totalSamples = audioBuffer.length;
const left = audioBuffer.getChannelData(0);
// Mono input: reuse the single channel for both sides.
const right = audioBuffer.numberOfChannels > 1 ? audioBuffer.getChannelData(1) : left;
// Check WebGPU availability
let provider = 'wasm';
if (navigator.gpu) {
try {
const adapter = await navigator.gpu.requestAdapter();
if (adapter) provider = 'webgpu';
} catch (e) {
console.warn('WebGPU probe failed:', e);
}
}
if (provider === 'wasm') warn.style.display = 'block';
else warn.style.display = 'none';
const precision = document.querySelector('input[name=precision]:checked').value;
setStatus(`Loading ${precision.toUpperCase()} model (${provider})...`);
setProgress(0);
let session;
try {
const stable = await createStableSession(precision, provider);
session = stable.session;
// createStableSession may have fallen back; reflect the provider actually used.
provider = stable.provider;
warn.style.display = provider === 'wasm' ? 'block' : 'none';
} catch (e) {
console.error('Unable to initialize any backend:', e);
setStatus(`Failed to initialize model backend: ${e?.message || e}`);
btnRun.disabled = false;
return;
}
const win = hannWindow(WIN);
// 50% overlap: each chunk starts CHUNK/OVERLAP samples after the previous one.
const step = Math.floor(CHUNK / OVERLAP);
const starts = [];
for (let s = 0; s < totalSamples; s += step) starts.push(s);
const vocalsL = new Float32Array(totalSamples);
const vocalsR = new Float32Array(totalSamples);
// Per-sample overlap counter used to average overlapping chunk outputs
// (rectangular averaging, not a windowed crossfade).
const count = new Float32Array(totalSamples);
const t0 = performance.now();
for (let ci = 0; ci < starts.length; ci++) {
const start = starts[ci];
const end = Math.min(start + CHUNK, totalSamples);
const chunkLen = end - start;
// Extract & pad chunk (tail chunk is zero-padded to CHUNK samples)
const cL = new Float32Array(CHUNK);
const cR = new Float32Array(CHUNK);
cL.set(left.subarray(start, end));
cR.set(right.subarray(start, end));
// STFT & prepare input
const tStft0 = performance.now();
const { input, nFrames, stftL, stftR } = prepareChunkInput(cL, cR, win);
const tStft1 = performance.now();
// Run ONNX
const tensor = new ort.Tensor('float32', input, [1, nFrames, N_FREQ * AUDIO_CH * 2]);
const tOrt0 = performance.now();
const results = await session.run({ stft_features: tensor });
// NOTE(review): assumes the "mask" output is the vocals mask (the
// instrumental is derived below as original minus vocals) — confirm
// against the exported model's output naming.
const mask = results.mask.data;
const tOrt1 = performance.now();
// Reconstruct
const recon = applyMaskAndReconstruct(mask, stftL, stftR, nFrames, win, CHUNK);
const tRecon = performance.now();
if (ci === 0) console.log(`Chunk timing: STFT=${(tStft1-tStft0).toFixed(0)}ms, ORT=${(tOrt1-tOrt0).toFixed(0)}ms, iSTFT=${(tRecon-tOrt1).toFixed(0)}ms`);
// Accumulate with overlap
for (let i = 0; i < chunkLen; i++) {
vocalsL[start + i] += recon.left[i];
vocalsR[start + i] += recon.right[i];
count[start + i] += 1;
}
const frac = (ci + 1) / starts.length;
const elapsed = (performance.now() - t0) / 1000;
const eta = elapsed / frac * (1 - frac);
setProgress(frac);
setStatus(`Chunk ${ci + 1}/${starts.length} · ${elapsed.toFixed(1)}s elapsed · ~${eta.toFixed(0)}s remaining`);
// Yield to UI between chunks so progress/status can repaint
await new Promise(r => setTimeout(r, 0));
}
// Average overlaps
for (let i = 0; i < totalSamples; i++) {
if (count[i] > 0) { vocalsL[i] /= count[i]; vocalsR[i] /= count[i]; }
}
const elapsed = ((performance.now() - t0) / 1000).toFixed(1);
const duration = audioBuffer.duration.toFixed(1);
const rtf = (parseFloat(elapsed) / audioBuffer.duration).toFixed(2);
// Build other = original - vocals
const otherL = new Float32Array(totalSamples);
const otherR = new Float32Array(totalSamples);
for (let i = 0; i < totalSamples; i++) {
otherL[i] = left[i] - vocalsL[i];
otherR[i] = right[i] - vocalsR[i];
}
vocalsBlob = encodeWav(vocalsL, vocalsR, SAMPLE_RATE);
otherBlob = encodeWav(otherL, otherR, SAMPLE_RATE);
$('audio-vocals').src = URL.createObjectURL(vocalsBlob);
$('audio-other').src = URL.createObjectURL(otherBlob);
outputsEl.style.display = 'block';
btnDlV.disabled = false;
btnDlO.disabled = false;
btnRun.disabled = false;
setStatus(`Done in ${elapsed}s (${duration}s audio, ${rtf}x realtime) · ${provider.toUpperCase()}`);
});
// ── Downloads ───────────────────────────────────────────────────────────────
/** Trigger a browser download of `blob` under `name`, then free the URL. */
function download(blob, name) {
  const a = document.createElement('a');
  const url = URL.createObjectURL(blob);
  a.href = url;
  a.download = name;
  a.click();
  // Object URLs pin their Blob in memory until revoked; the original leaked
  // one per download. Revoke after the click has been dispatched.
  setTimeout(() => URL.revokeObjectURL(url), 10000);
}
btnDlV.addEventListener('click', () => vocalsBlob && download(vocalsBlob, 'vocals.wav'));
btnDlO.addEventListener('click', () => otherBlob && download(otherBlob, 'other.wav'));
</script>
</body>
</html>