"""
Transcription Engine Comparison Space — Free CPU.
Single-file app: Gradio + FastAPI routes + WhisperLiveKit WebSocket +
Voxtral Realtime browser-side transcription (WebGPU via transformers.js) +
inline recorder UI (HTML/CSS/JS).
FER runs entirely in browser via ONNX (no server cost).
The only external file is static/emotion_model_web.onnx (~4.8MB).
"""
import base64
import logging
import os
import gc
import sys
import traceback
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
_CLI_MODE = len(sys.argv) > 1 and not sys.argv[1].startswith("--")
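# CLI example: `python app.py recording.wav` transcribes + diarizes that file and exits;
# no argument (or one starting with "--") falls through to the web app below.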
if _CLI_MODE:
# -- CLI MODE: transcribe + diarize an audio file -------------------------
import numpy as np
import torch
import librosa
from pyannote.audio import Pipeline
audio_file = sys.argv[1]
print(f"Loading: {audio_file}")
audio, _ = librosa.load(audio_file, sr=16000, mono=True)
audio = audio.astype(np.float32)
print(f"Audio: {len(audio)/16000:.1f}s")
# Diarization
print("Loading diarization (pyannote)...")
pipeline = Pipeline.from_pretrained("models/speaker-diarization-3.1")
waveform = torch.tensor(audio).unsqueeze(0)
result = pipeline({"waveform": waveform, "sample_rate": 16000})
diar = result.speaker_diarization
# Post-processing: merge speakers with similar embeddings (numpy, no sklearn)
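# Cosine similarity: sim(a, b) = (a . b) / (|a| |b|). Rows are L2-normalized first,
# so one matrix product yields every pairwise speaker similarity at once.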
speaker_labels = sorted(diar.labels())
merge_map = {}
if hasattr(result, "speaker_embeddings") and result.speaker_embeddings is not None and len(speaker_labels) > 1:
emb = result.speaker_embeddings
norms = np.linalg.norm(emb, axis=1, keepdims=True)
norms[norms == 0] = 1
sim = (emb / norms) @ (emb / norms).T
for i in range(len(speaker_labels)):
for j in range(i + 1, len(speaker_labels)):
if sim[i][j] >= 0.6:
target = merge_map.get(speaker_labels[i], speaker_labels[i])
merge_map[speaker_labels[j]] = target
print(f" Merging {speaker_labels[j]} -> {target} (sim: {sim[i][j]:.3f})")
merged = []
speakers_seen = set()
for turn, _, spk in diar.itertracks(yield_label=True):
actual_spk = merge_map.get(spk, spk)
speaker_id = int(actual_spk.split("_")[-1]) + 1
speakers_seen.add(speaker_id)
merged.append({"start": turn.start, "end": turn.end, "speakers": [speaker_id]})
num_speakers = len(speakers_seen)
print(f"Speakers: {num_speakers} | Segments: {len(merged)}\n")
# Transcription (Parakeet)
print("Running Parakeet TDT v3 (with timestamps)...")
import onnx_asr
model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3", providers=["CPUExecutionProvider"]).with_timestamps()
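# .with_timestamps() makes recognize() return subword tokens plus per-token times
# (treated as seconds by the speaker-alignment step below).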
audio_int16 = (audio * 32767).astype(np.int16)
output = model.recognize(audio_int16)
del model
gc.collect()
tokens = output.tokens if hasattr(output, "tokens") else []
timestamps = output.timestamps if hasattr(output, "timestamps") else []
# Reconstruct full words from subword tokens
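# A leading space/newline marks a word boundary, assuming space-prefixed subwords,
# e.g. [" hel", "lo", " wor", "ld"] -> " hello", " world".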
words = [] # list of {"text": str, "start": float, "end": float}
current_word = ""
current_start = 0.0
current_end = 0.0
for tok, ts in zip(tokens, timestamps):
if tok.startswith(" ") or tok.startswith("\n"):
if current_word.strip():
words.append({"text": current_word, "start": current_start, "end": current_end})
current_word = tok
current_start = ts
current_end = ts
else:
if not current_word:
current_start = ts
current_word += tok
current_end = ts
if current_word.strip():
words.append({"text": current_word, "start": current_start, "end": current_end})
# Align each word to speaker with greatest temporal overlap
def best_speaker(word_start, word_end):
best = None
max_overlap = 0
for seg in merged:
ov_start = max(word_start, seg["start"])
ov_end = min(word_end, seg["end"])
if ov_start < ov_end:
overlap = ov_end - ov_start
if overlap > max_overlap:
max_overlap = overlap
best = " & ".join(f"SPEAKER {s}" for s in seg["speakers"])
return best
# Assign speaker to each word, then merge consecutive same-speaker
labeled = []
for w in words:
spk = best_speaker(w["start"], w["end"] + 0.05)
if spk is None:
spk = labeled[-1][0] if labeled else "UNKNOWN"
labeled.append((spk, w["start"], w["text"]))
# Merge consecutive same-speaker words
print("=" * 60)
print(f"{num_speakers} speakers detected:\n")
if labeled:
current_spk = labeled[0][0]
current_start = labeled[0][1]
current_text = labeled[0][2]
for spk, ts, txt in labeled[1:]:
if spk == current_spk:
current_text += txt
else:
chunk = current_text.strip()
if chunk:
m, s = divmod(int(current_start), 60)
print(f"{current_spk} [{m:02d}:{s:02d}]: {chunk}")
current_spk = spk
current_start = ts
current_text = txt
chunk = current_text.strip()
if chunk:
m, s = divmod(int(current_start), 60)
print(f"{current_spk} [{m:02d}:{s:02d}]: {chunk}")
else:
text = output.text if hasattr(output, "text") else str(output)
print(text)
print("=" * 60)
sys.exit(0)
if not _CLI_MODE:
import gradio as gr
import asyncio
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from starlette.requests import Request
from starlette.responses import Response
from whisperlivekit import TranscriptionEngine, AudioProcessor
# -- WhisperLiveKit engine (loaded at startup, ~3-6GB) -------------------
logger.info("Loading TranscriptionEngine (large-v3-turbo model, CPU)...")
transcription_engine = TranscriptionEngine(
model_size="large-v3-turbo",
vac=True,
min_chunk_size=1.0,
lan="auto",
direct_english_translation=False,
)
logger.info("TranscriptionEngine ready.")
# -- Inline CSS --------------------------------------------------------------
RECORDER_CSS = r"""
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
:root {
--bg: #0f0f0f;
--surface: #1a1a2e;
--surface2: #16213e;
--accent: #e94560;
--accent2: #0f3460;
--text: #eee;
--text-dim: #888;
--success: #4ecca3;
--warning: #f5a623;
--radius: 12px;
--font: 'Segoe UI', system-ui, -apple-system, sans-serif;
}
body {
background: var(--bg);
color: var(--text);
font-family: var(--font);
padding: 12px;
line-height: 1.5;
}
.mode-selector {
display: flex;
gap: 6px;
justify-content: center;
margin-bottom: 8px;
flex-wrap: wrap;
}
.engine-btn {
background: var(--surface);
border: 2px solid transparent;
border-radius: var(--radius);
padding: 8px 14px;
cursor: pointer;
font-size: 0.85rem;
color: var(--text);
transition: all 0.2s;
}
.engine-btn:hover { border-color: var(--accent2); }
.engine-btn.active {
border-color: var(--accent);
background: var(--surface2);
}
.options-row {
display: flex;
gap: 16px;
justify-content: center;
align-items: center;
margin-bottom: 8px;
flex-wrap: wrap;
}
.options-row label {
font-size: 0.85rem;
color: var(--text-dim);
cursor: pointer;
display: flex;
align-items: center;
gap: 6px;
}
.options-row input[type="checkbox"] {
accent-color: var(--accent);
}
.controls {
display: flex;
align-items: center;
justify-content: center;
gap: 16px;
margin-bottom: 10px;
}
#recordButton {
width: 64px;
height: 64px;
border-radius: 50%;
border: 3px solid var(--accent);
background: transparent;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.3s;
flex-shrink: 0;
}
#recordButton .inner {
width: 28px;
height: 28px;
background: var(--accent);
border-radius: 50%;
transition: all 0.3s;
}
#recordButton.recording .inner {
border-radius: 4px;
width: 24px;
height: 24px;
}
#recordButton:hover {
transform: scale(1.05);
}
.upload-btn {
width: 40px;
height: 40px;
border-radius: 50%;
border: 2px solid var(--accent2);
background: transparent;
color: var(--text-dim);
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s;
}
.upload-btn:hover { border-color: var(--accent); color: var(--text); }
.timer {
font-size: 1.2rem;
font-variant-numeric: tabular-nums;
color: var(--text-dim);
min-width: 60px;
}
.timer.recording {
color: var(--accent);
}
#waveCanvas {
width: 200px;
height: 48px;
border-radius: 8px;
background: var(--surface);
}
#status {
text-align: center;
font-size: 0.85rem;
color: var(--text-dim);
margin-bottom: 8px;
min-height: 1.3em;
}
#status.error { color: var(--accent); }
#status.success { color: var(--success); }
.results-grid {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 8px;
margin-top: 8px;
}
/* Panel visibility controlled by JS updateResultsLayout() */
.result-panel {
background: var(--surface);
border-radius: var(--radius);
padding: 12px;
min-height: 120px;
position: relative;
}
.copy-btn {
position: absolute;
top: 6px;
right: 6px;
background: transparent;
border: 1px solid rgba(255,255,255,0.15);
color: var(--text-dim);
cursor: pointer;
padding: 3px 10px;
border-radius: 6px;
font-size: 0.72rem;
z-index: 2;
display: flex;
align-items: center;
gap: 4px;
transition: all 0.15s;
}
.copy-btn:hover { background: rgba(255,255,255,0.1); color: var(--text); border-color: rgba(255,255,255,0.3); }
.copy-btn.copied { color: var(--success); border-color: var(--success); }
.short-hint {
color: var(--text-dim);
font-size: 0.8rem;
margin-top: 8px;
font-style: italic;
}
.result-panel h3 {
font-size: 0.9rem;
margin-bottom: 6px;
display: flex;
align-items: center;
gap: 6px;
flex-wrap: wrap;
padding-right: 50px;
}
.badge {
font-size: 0.7rem;
padding: 2px 8px;
border-radius: 999px;
font-weight: 500;
}
.badge.realtime { background: var(--success); color: #000; }
.badge.browser { background: var(--warning); color: #000; }
.timing {
font-size: 0.8rem;
color: var(--text-dim);
margin-bottom: 8px;
}
.transcript {
font-size: 0.85rem;
line-height: 1.6;
white-space: pre-wrap;
word-break: break-word;
max-height: 220px;
overflow-y: auto;
background: rgba(0,0,0,0.25);
border-radius: 8px;
padding: 10px 12px;
font-family: 'SF Mono', 'Cascadia Code', 'Fira Code', 'Consolas', monospace;
border: 1px solid rgba(255,255,255,0.06);
}
.transcript .buffer {
color: var(--text-dim);
font-style: italic;
}
.transcript .line {
margin-bottom: 4px;
}
.transcript .timestamp {
color: var(--accent2);
font-size: 0.75rem;
margin-right: 6px;
opacity: 0.7;
font-variant-numeric: tabular-nums;
}
.transcript .speaker {
color: var(--success);
font-weight: 600;
font-size: 0.8rem;
margin-right: 6px;
}
.spinner {
display: inline-block;
width: 20px;
height: 20px;
border: 2px solid var(--text-dim);
border-top-color: var(--accent);
border-radius: 50%;
animation: spin 0.8s linear infinite;
margin-right: 8px;
vertical-align: middle;
}
@keyframes spin { to { transform: rotate(360deg); } }
.fer-container {
position: relative;
display: flex;
justify-content: center;
margin: 0 auto 8px;
}
.fer-container.hidden { display: none; }
#webcamVideo {
width: 240px;
height: 180px;
border-radius: var(--radius);
object-fit: cover;
transform: scaleX(-1);
background: #000;
}
.emotion-bars {
position: absolute;
bottom: 8px;
left: 8px;
right: 8px;
display: flex;
flex-direction: column;
gap: 2px;
background: rgba(0,0,0,0.6);
padding: 6px;
border-radius: 6px;
font-size: 0.65rem;
}
.emotion-bar {
display: flex;
align-items: center;
gap: 4px;
}
.emotion-bar .label {
width: 55px;
text-align: right;
flex-shrink: 0;
}
.emotion-bar .bar {
flex: 1;
height: 6px;
background: rgba(255,255,255,0.15);
border-radius: 3px;
overflow: hidden;
}
.emotion-bar .fill {
height: 100%;
border-radius: 3px;
transition: width 0.3s;
background: var(--success);
}
.emotion-bar .pct {
width: 30px;
text-align: right;
font-variant-numeric: tabular-nums;
}
#webgpuWarning {
display: none;
text-align: center;
padding: 10px 16px;
margin-bottom: 12px;
background: rgba(233, 69, 96, 0.15);
border: 1px solid var(--accent);
border-radius: var(--radius);
font-size: 0.85rem;
color: var(--accent);
}
.progress-bar-container {
width: 100%;
background: rgba(255,255,255,0.1);
border-radius: 4px;
margin: 6px 0;
height: 8px;
overflow: hidden;
}
.progress-bar-fill {
height: 100%;
background: var(--success);
border-radius: 4px;
transition: width 0.3s;
width: 0%;
}
@media (max-width: 600px) {
.results-grid { grid-template-columns: 1fr; }
.results-grid .result-panel { display: block !important; }
#waveCanvas { width: 120px; }
}
"""
# -- Inline FER JS -----------------------------------------------------------
FER_JS = r"""
const FER_LABELS = [
"Anger", "Contempt", "Disgust", "Fear",
"Happy", "Neutral", "Sad", "Surprise"
];
const IMAGE_SIZE = 224;
const IMAGENET_MEAN = [0.485, 0.456, 0.406];
const IMAGENET_STD = [0.229, 0.224, 0.225];
let ferSession = null;
let ferCanvas = null;
let ortModule = null;
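// Numerically stable softmax: shift scores by their max before exponentiating to avoid overflow.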
function softmax(scores) {
let max = -Infinity;
for (let i = 0; i < scores.length; i++) {
if (scores[i] > max) max = scores[i];
}
const exps = new Float32Array(scores.length);
let sum = 0;
for (let i = 0; i < scores.length; i++) {
exps[i] = Math.exp(scores[i] - max);
sum += exps[i];
}
for (let i = 0; i < exps.length; i++) {
exps[i] /= sum;
}
return exps;
}
async function loadFERModel() {
if (ferSession) return true;
try {
const ort = window.ort;
if (!ort) { console.error("[FER] onnxruntime-web not loaded"); return false; }
ortModule = ort;
ort.env.wasm.numThreads = 1;
const response = await fetch("/static/emotion_model_web.onnx");
const modelBuffer = await response.arrayBuffer();
ferSession = await ort.InferenceSession.create(
new Uint8Array(modelBuffer),
{ executionProviders: ["wasm"] }
);
console.log("[FER] Model loaded");
return true;
} catch (err) {
console.error("[FER] Failed to load model:", err);
return false;
}
}
async function classifyEmotion(videoElement) {
if (!ferSession || !ortModule) return null;
try {
if (!ferCanvas) {
ferCanvas = document.createElement("canvas");
ferCanvas.width = IMAGE_SIZE;
ferCanvas.height = IMAGE_SIZE;
}
const ctx = ferCanvas.getContext("2d", { willReadFrequently: true });
if (!ctx) return null;
ctx.drawImage(videoElement, 0, 0, IMAGE_SIZE, IMAGE_SIZE);
const imageData = ctx.getImageData(0, 0, IMAGE_SIZE, IMAGE_SIZE);
const { data } = imageData;
const floatData = new Float32Array(1 * 3 * IMAGE_SIZE * IMAGE_SIZE);
const pixelCount = IMAGE_SIZE * IMAGE_SIZE;
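// Repack interleaved RGBA (HWC) into planar CHW floats with ImageNet mean/std normalization.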
for (let i = 0; i < pixelCount; i++) {
const srcIdx = i * 4;
floatData[i] = (data[srcIdx] / 255 - IMAGENET_MEAN[0]) / IMAGENET_STD[0];
floatData[pixelCount + i] = (data[srcIdx + 1] / 255 - IMAGENET_MEAN[1]) / IMAGENET_STD[1];
floatData[2 * pixelCount + i] = (data[srcIdx + 2] / 255 - IMAGENET_MEAN[2]) / IMAGENET_STD[2];
}
const inputTensor = new ortModule.Tensor("float32", floatData, [1, 3, IMAGE_SIZE, IMAGE_SIZE]);
const inputName = ferSession.inputNames[0];
const results = await ferSession.run({ [inputName]: inputTensor });
const outputName = ferSession.outputNames[0];
const output = results[outputName];
if (!output) return null;
const rawScores = output.data;
const probs = softmax(rawScores);
const scores = {};
let maxIdx = 0;
let maxVal = probs[0];
for (let i = 0; i < probs.length; i++) {
scores[FER_LABELS[i]] = probs[i];
if (probs[i] > maxVal) { maxVal = probs[i]; maxIdx = i; }
}
return { emotion: FER_LABELS[maxIdx], confidence: maxVal, scores: scores };
} catch (err) {
console.error("[FER] Classification error:", err);
return null;
}
}
function releaseFER() {
if (ferSession) { ferSession.release().catch(() => {}); ferSession = null; }
}
"""
# -- Inline Recorder JS -----------------------------------------------------
RECORDER_JS = r"""
// -- State -------------------------------------------------------------------
let activeEngines = new Set(["parakeet"]);
let isRecording = false;
let websocket = null;
let mediaRecorder = null;
let audioChunks = [];
let micRecorder = null;
let micChunks = [];
let screenRecorder = null;
let screenChunks = [];
let mixedStream = null;
let micStream = null;
let displayStream = null;
let audioContext = null;
let analyserNode = null;
let animFrameId = null;
let timerInterval = null;
let recordingStartTime = null;
let ferInterval = null;
let webcamStream = null;
// -- Voxtral Realtime state --------------------------------------------------
let voxtralModel = null;
let voxtralProcessor = null;
let voxtralLoading = false;
let voxtralAudioChunks = [];
let voxtralAudioLength = 0;
// Lazy-concatenate: only rebuild when new chunks arrive
let _voxtralCached = new Float32Array(0);
let _voxtralCachedLen = 0;
function getVoxtralAudio() {
if (voxtralAudioLength === _voxtralCachedLen) return _voxtralCached;
if (voxtralAudioChunks.length === 0) { _voxtralCached = new Float32Array(0); _voxtralCachedLen = 0; return _voxtralCached; }
if (voxtralAudioChunks.length === 1) { _voxtralCached = voxtralAudioChunks[0]; _voxtralCachedLen = voxtralAudioLength; return _voxtralCached; }
const combined = new Float32Array(voxtralAudioLength);
let offset = 0;
for (const chunk of voxtralAudioChunks) { combined.set(chunk, offset); offset += chunk.length; }
voxtralAudioChunks = [combined];
_voxtralCached = combined;
_voxtralCachedLen = voxtralAudioLength;
return combined;
}
let voxtralIsRunning = false;
let voxtralStopRequested = false;
let voxtralAudioContext = null;
let voxtralWorkletNode = null;
let voxtralMicSource = null;
let transformersModule = null;
const VOXTRAL_MODEL_ID = "onnx-community/Voxtral-Mini-4B-Realtime-2602-ONNX";
const SEGMENTATION_MODEL_ID = "onnx-community/pyannote-segmentation-3.0";
let segmentationModel = null;
let segmentationProcessor = null;
// -- DOM refs ----------------------------------------------------------------
const modeSelector = document.getElementById("modeSelector");
const recordButton = document.getElementById("recordButton");
const waveCanvas = document.getElementById("waveCanvas");
const timerEl = document.getElementById("timer");
const statusEl = document.getElementById("status");
const resultsGrid = document.getElementById("resultsGrid");
const whisperPanel = document.getElementById("whisperPanel");
const voxtralPanel = document.getElementById("voxtralPanel");
const whisperTranscript = document.getElementById("whisperTranscript");
const voxtralTranscript = document.getElementById("voxtralTranscript");
const whisperTiming = document.getElementById("whisperTiming");
const voxtralTiming = document.getElementById("voxtralTiming");
const screenAudioToggle = document.getElementById("screenAudioToggle");
const ferToggle = document.getElementById("ferToggle");
const diarizeToggle = document.getElementById("diarizeToggle");
const ferContainer = document.getElementById("ferContainer");
const webcamVideo = document.getElementById("webcamVideo");
const emotionBarsEl = document.getElementById("emotionBars");
const webgpuWarning = document.getElementById("webgpuWarning");
const parakeetPanel = document.getElementById("parakeetPanel");
const parakeetTranscript = document.getElementById("parakeetTranscript");
const parakeetTiming = document.getElementById("parakeetTiming");
const nemotronPanel = document.getElementById("nemotronPanel");
const nemotronTranscript = document.getElementById("nemotronTranscript");
const nemotronTiming = document.getElementById("nemotronTiming");
// -- WebGPU check ------------------------------------------------------------
async function checkWebGPU() {
if (!navigator.gpu) {
webgpuWarning.style.display = "block";
webgpuWarning.textContent = "WebGPU is not supported in this browser. Voxtral Realtime requires WebGPU (Chrome 113+, Edge 113+).";
return false;
}
try {
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) {
webgpuWarning.style.display = "block";
webgpuWarning.textContent = "WebGPU adapter not available. Check your GPU drivers.";
return false;
}
return true;
} catch (e) {
webgpuWarning.style.display = "block";
webgpuWarning.textContent = "WebGPU check failed: " + e.message;
return false;
}
}
checkWebGPU();
// -- Engine toggle selector ---------------------------------------------------
modeSelector.querySelectorAll(".engine-btn").forEach((btn) => {
btn.addEventListener("click", () => {
if (isRecording) return;
const engine = btn.dataset.engine;
if (activeEngines.has(engine)) {
if (activeEngines.size > 1) {
activeEngines.delete(engine);
btn.classList.remove("active");
}
} else {
activeEngines.add(engine);
btn.classList.add("active");
}
updateResultsLayout();
});
});
function updateResultsLayout() {
const panelMap = {
whisper: whisperPanel,
voxtral: voxtralPanel,
parakeet: parakeetPanel,
nemotron: nemotronPanel,
};
Object.entries(panelMap).forEach(([key, panel]) => {
panel.style.display = activeEngines.has(key) ? '' : 'none';
});
const count = activeEngines.size;
resultsGrid.style.gridTemplateColumns = count <= 1 ? '1fr' : '1fr 1fr';
}
updateResultsLayout();
// -- FER toggle --------------------------------------------------------------
ferToggle.addEventListener("change", async () => {
if (ferToggle.checked) {
ferContainer.classList.remove("hidden");
await startWebcam();
await loadFERModel();
startFERLoop();
} else {
ferContainer.classList.add("hidden");
stopFERLoop();
stopWebcam();
}
});
async function startWebcam() {
try {
webcamStream = await navigator.mediaDevices.getUserMedia({ video: true });
webcamVideo.srcObject = webcamStream;
} catch (err) {
console.error("[FER] Webcam error:", err);
setStatus("Webcam access denied", "error");
}
}
function stopWebcam() {
if (webcamStream) {
webcamStream.getTracks().forEach((t) => t.stop());
webcamStream = null;
webcamVideo.srcObject = null;
}
}
function startFERLoop() {
if (ferInterval) return;
ferInterval = setInterval(async () => {
if (!webcamVideo.srcObject) return;
const result = await classifyEmotion(webcamVideo);
if (result) renderEmotionBars(result.scores);
}, 500);
}
function stopFERLoop() {
if (ferInterval) { clearInterval(ferInterval); ferInterval = null; }
}
function renderEmotionBars(scores) {
const labels = Object.keys(scores);
let html = "";
for (const label of labels) {
const pct = (scores[label] * 100).toFixed(0);
html += `<div class="emotion-bar">
<span class="label">${label}</span>
<div class="bar"><div class="fill" style="width:${pct}%"></div></div>
<span class="pct">${pct}%</span>
</div>`;
}
emotionBarsEl.innerHTML = html;
}
// -- Voxtral model loading ---------------------------------------------------
async function loadVoxtralModel() {
if (voxtralModel && voxtralProcessor) return true;
if (voxtralLoading) return false;
voxtralLoading = true;
voxtralTranscript.innerHTML = '<span class="spinner"></span> Loading Voxtral Realtime model (WebGPU)... This downloads ~2GB on first use.';
try {
if (!transformersModule) {
voxtralTranscript.innerHTML = '<span class="spinner"></span> Loading transformers.js library...';
transformersModule = await import("https://cdn.jsdelivr.net/npm/@huggingface/transformers@4.0.0-next.7");
}
const { VoxtralRealtimeForConditionalGeneration, VoxtralRealtimeProcessor } = transformersModule;
voxtralTranscript.innerHTML = '<span class="spinner"></span> Downloading & loading Voxtral model (q4f16, WebGPU)...<div class="progress-bar-container"><div class="progress-bar-fill" id="voxtralProgress"></div></div><div id="voxtralProgressText" style="font-size:0.75rem;color:var(--text-dim);margin-top:4px;"></div>';
const progressCallback = (progress) => {
const bar = document.getElementById("voxtralProgress");
const txt = document.getElementById("voxtralProgressText");
if (bar && progress.progress !== undefined) {
bar.style.width = progress.progress.toFixed(1) + "%";
}
if (txt && progress.file) {
const status = progress.status || "";
const pct = progress.progress !== undefined ? ` (${progress.progress.toFixed(1)}%)` : "";
txt.textContent = `${status} ${progress.file}${pct}`;
}
};
voxtralProcessor = await VoxtralRealtimeProcessor.from_pretrained(VOXTRAL_MODEL_ID, {
progress_callback: progressCallback,
});
voxtralModel = await VoxtralRealtimeForConditionalGeneration.from_pretrained(VOXTRAL_MODEL_ID, {
dtype: {
audio_encoder: "q4f16",
embed_tokens: "q4f16",
decoder_model_merged: "q4f16",
},
device: "webgpu",
progress_callback: progressCallback,
});
// Load speaker segmentation model for browser-side diarization
if (!segmentationModel) {
voxtralTranscript.innerHTML = '<span class="spinner"></span> Loading speaker segmentation model...';
const { AutoProcessor, AutoModelForAudioFrameClassification } = transformersModule;
segmentationProcessor = await AutoProcessor.from_pretrained(SEGMENTATION_MODEL_ID, { progress_callback: progressCallback });
segmentationModel = await AutoModelForAudioFrameClassification.from_pretrained(SEGMENTATION_MODEL_ID, { device: "wasm", dtype: "fp32", progress_callback: progressCallback });
}
voxtralTranscript.innerHTML = '<span style="color:var(--success)">Models loaded. Ready to transcribe.</span>';
voxtralLoading = false;
return true;
} catch (err) {
console.error("[Voxtral] Model loading error:", err);
voxtralTranscript.innerHTML = `<span style="color:var(--accent)">Failed to load model: ${escapeHtml(err.message)}</span>`;
voxtralLoading = false;
return false;
}
}
// -- Voxtral audio capture via AudioWorklet ----------------------------------
// Dual-track buffers for Voxtral (when Speaker detection OFF)
let voxtralMicChunks = [];
let voxtralMicLength = 0;
let voxtralScreenChunks = [];
let voxtralScreenLength = 0;
let voxtralDualTrack = false;
function getVoxtralMicAudio() {
if (voxtralMicChunks.length === 0) return new Float32Array(0);
if (voxtralMicChunks.length === 1) return voxtralMicChunks[0];
const c = new Float32Array(voxtralMicLength);
let o = 0;
for (const ch of voxtralMicChunks) { c.set(ch, o); o += ch.length; }
voxtralMicChunks = [c];
return c;
}
function getVoxtralScreenAudio() {
if (voxtralScreenChunks.length === 0) return new Float32Array(0);
if (voxtralScreenChunks.length === 1) return voxtralScreenChunks[0];
const c = new Float32Array(voxtralScreenLength);
let o = 0;
for (const ch of voxtralScreenChunks) { c.set(ch, o); o += ch.length; }
voxtralScreenChunks = [c];
return c;
}
async function startVoxtralRecording(stream, micOnlyStream, screenOnlyStream) {
voxtralAudioChunks = [];
voxtralAudioLength = 0;
voxtralMicChunks = [];
voxtralMicLength = 0;
voxtralScreenChunks = [];
voxtralScreenLength = 0;
voxtralStopRequested = false;
voxtralIsRunning = true;
voxtralDualTrack = !!(micOnlyStream && screenOnlyStream);
voxtralAudioContext = new AudioContext({ sampleRate: 16000 });
const workletCode = `class CaptureProcessor extends AudioWorkletProcessor {
process(inputs) {
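// Each call delivers one render quantum (128 samples per channel in current browsers).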
const input = inputs[0];
if (input.length > 0 && input[0].length > 0) {
this.port.postMessage(input[0]);
}
return true;
}
}
registerProcessor("capture-processor", CaptureProcessor);`;
const blob = new Blob([workletCode], { type: "application/javascript" });
const url = URL.createObjectURL(blob);
await voxtralAudioContext.audioWorklet.addModule(url);
URL.revokeObjectURL(url);
// Main mixed stream capture (for transcription)
voxtralMicSource = voxtralAudioContext.createMediaStreamSource(stream);
voxtralWorkletNode = new AudioWorkletNode(voxtralAudioContext, "capture-processor");
voxtralWorkletNode.port.onmessage = (event) => {
if (voxtralStopRequested) return;
const newData = new Float32Array(event.data);
if (newData.length === 0) return;
voxtralAudioChunks.push(newData);
voxtralAudioLength += newData.length;
};
voxtralMicSource.connect(voxtralWorkletNode);
const silentGain = voxtralAudioContext.createGain();
silentGain.gain.value = 0;
voxtralWorkletNode.connect(silentGain);
silentGain.connect(voxtralAudioContext.destination);
// Dual-track: separate mic and screen captures
if (voxtralDualTrack) {
// Mic-only worklet
const micSrc = voxtralAudioContext.createMediaStreamSource(micOnlyStream);
const micWork = new AudioWorkletNode(voxtralAudioContext, "capture-processor");
micWork.port.onmessage = (event) => {
if (voxtralStopRequested) return;
const d = new Float32Array(event.data);
if (d.length > 0) { voxtralMicChunks.push(d); voxtralMicLength += d.length; }
};
micSrc.connect(micWork);
micWork.connect(silentGain);
// Screen-only worklet
const scrSrc = voxtralAudioContext.createMediaStreamSource(screenOnlyStream);
const scrWork = new AudioWorkletNode(voxtralAudioContext, "capture-processor");
scrWork.port.onmessage = (event) => {
if (voxtralStopRequested) return;
const d = new Float32Array(event.data);
if (d.length > 0) { voxtralScreenChunks.push(d); voxtralScreenLength += d.length; }
};
scrSrc.connect(scrWork);
scrWork.connect(silentGain);
}
// Start the transcription loop (uses mixed stream for real-time)
runVoxtralTranscription();
}
function stopVoxtralRecording() {
voxtralStopRequested = true;
if (voxtralWorkletNode) {
voxtralWorkletNode.disconnect();
voxtralWorkletNode = null;
}
if (voxtralMicSource) {
voxtralMicSource.disconnect();
voxtralMicSource = null;
}
if (voxtralAudioContext && voxtralAudioContext.state !== "closed") {
voxtralAudioContext.close().catch(() => {});
voxtralAudioContext = null;
}
}
// -- Voxtral streaming transcription loop ------------------------------------
async function runVoxtralTranscription() {
if (!voxtralModel || !voxtralProcessor) {
console.error("[Voxtral] Model or processor not loaded");
return;
}
const { BaseStreamer } = transformersModule;
const numSamplesFirst = voxtralProcessor.num_samples_first_audio_chunk;
const numSamplesPerChunk = voxtralProcessor.num_samples_per_audio_chunk;
const { hop_length, n_fft } = voxtralProcessor.feature_extractor.config;
const winHalf = Math.floor(n_fft / 2);
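// Samples per generated audio token: mel frames per token x hop_length samples per frame.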
const samplesPerTok = voxtralProcessor.audio_length_per_tok * hop_length;
const voxtralStartTime = Date.now();
let fullText = "";
// Streamer matching reference VoxtralProvider.tsx pattern
const tokenizer = voxtralProcessor.tokenizer;
const specialIds = new Set(tokenizer.all_special_ids.map(BigInt));
let tokenCache = [];
let printLen = 0;
let isPrompt = true;
function flushDecodedText() {
if (tokenCache.length === 0) return;
const text = tokenizer.decode(tokenCache, { skip_special_tokens: true });
const printableText = text.slice(printLen);
printLen = text.length;
if (printableText.length > 0) {
fullText += printableText;
voxtralTranscript.innerHTML = `<div class="line">${escapeHtml(fullText)}</div><span class="buffer">streaming...</span>`;
voxtralTranscript.scrollTop = voxtralTranscript.scrollHeight;
}
}
const streamer = new (class extends BaseStreamer {
put(value) {
if (voxtralStopRequested) return;
if (isPrompt) { isPrompt = false; return; }
const tokens = value[0];
if (tokens.length === 1 && specialIds.has(tokens[0])) return;
tokenCache = tokenCache.concat(tokens);
flushDecodedText();
}
end() {
if (voxtralStopRequested) {
tokenCache = []; printLen = 0; isPrompt = true;
return;
}
flushDecodedText();
tokenCache = []; printLen = 0; isPrompt = true;
}
})();
voxtralTranscript.innerHTML = '<span class="buffer">Waiting for audio...</span>';
// Wait until we have enough audio for the first chunk
while (voxtralAudioLength < numSamplesFirst && !voxtralStopRequested) {
await new Promise((r) => setTimeout(r, 100));
}
if (voxtralStopRequested) {
voxtralIsRunning = false;
return;
}
// Process first chunk to get input_ids and first input_features
const voxtralAudioBuffer = getVoxtralAudio();
const firstAudio = voxtralAudioBuffer.subarray(0, numSamplesFirst);
const firstChunkInputs = await voxtralProcessor(firstAudio, {
is_streaming: true,
is_first_audio_chunk: true,
});
// Async generator yields input_features ONLY (not full processor output)
async function* inputFeaturesGenerator() {
yield firstChunkInputs.input_features;
let melFrameIdx = voxtralProcessor.num_mel_frames_first_audio_chunk;
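// Mel frame m is centered at sample m * hop_length; back up half an FFT window so each
// new chunk includes the samples its first frame's analysis window overlaps.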
let startIdx = melFrameIdx * hop_length - winHalf;
while (!voxtralStopRequested) {
const endNeeded = startIdx + numSamplesPerChunk;
while (voxtralAudioLength < endNeeded && !voxtralStopRequested) {
await new Promise((r) => setTimeout(r, 50));
}
if (voxtralStopRequested) break;
// Batch extra available audio (matching reference pattern)
const availableSamples = voxtralAudioLength;
let batchEndSample = endNeeded;
while (batchEndSample + samplesPerTok <= availableSamples) {
batchEndSample += samplesPerTok;
}
const chunkAudio = getVoxtralAudio().slice(startIdx, batchEndSample);
const chunkInputs = await voxtralProcessor(chunkAudio, {
is_streaming: true,
is_first_audio_chunk: false,
});
yield chunkInputs.input_features;
melFrameIdx += chunkInputs.input_features.dims[2];
startIdx = melFrameIdx * hop_length - winHalf;
}
}
try {
voxtralTranscript.innerHTML = '<span class="buffer">Transcribing...</span>';
// Pass input_ids and input_features separately (matching reference)
await voxtralModel.generate({
input_ids: firstChunkInputs.input_ids,
input_features: inputFeaturesGenerator(),
max_new_tokens: 4096,
streamer: streamer,
});
const elapsed = ((Date.now() - voxtralStartTime) / 1000).toFixed(1);
voxtralTiming.textContent = `Processing time: ${elapsed}s (real-time, browser)`;
if (fullText.trim()) {
voxtralTranscript.innerHTML = `<div class="line">${escapeHtml(fullText)}</div>`;
// Browser-only diarization (Xenova's method, pyannote segmentation ONNX)
// Voxtral is fully standalone - no server calls, max 3 speakers
// Runs on full audio at end, not chunked
if (segmentationModel && segmentationProcessor) {
try {
voxtralTranscript.innerHTML += '<div class="buffer">Analyzing speakers (browser)...</div>';
const audio16k = getVoxtralAudio();
const inputs = await segmentationProcessor(audio16k);
const { logits } = await segmentationModel(inputs);
const diarSegs = segmentationProcessor.post_process_speaker_diarization(logits, audio16k.length)[0];
const speakerSet = new Set();
const labeled = [];
for (const seg of diarSegs) {
const label = segmentationModel.config.id2label[seg.id];
if (label === 'NO_SPEAKER') continue;
speakerSet.add(label);
labeled.push({start: seg.start, end: seg.end, label});
}
if (speakerSet.size >= 2 && labeled.length > 0) {
const merged = [labeled[0]];
for (let i = 1; i < labeled.length; i++) {
const prev = merged[merged.length - 1];
if (labeled[i].label === prev.label && labeled[i].start - prev.end < 0.5) { prev.end = labeled[i].end; } else { merged.push({...labeled[i]}); }
}
let diarText = speakerSet.size + ' speakers detected (browser):\n';
for (const seg of merged) { diarText += '\n[' + fmtTime(seg.start) + ' - ' + fmtTime(seg.end) + '] ' + seg.label; }
voxtralTranscript.textContent = diarText + '\n\n' + fullText;
}
} catch (diarErr) {
console.warn("[Voxtral Diarization]", diarErr);
}
}
} else {
voxtralTranscript.innerHTML = '<span class="buffer">No speech detected.</span>';
}
} catch (err) {
console.error("[Voxtral] Transcription error:", err);
voxtralTranscript.innerHTML = `<span style="color:var(--accent)">Transcription error: ${escapeHtml(err.message)}</span>`;
}
// Energy-based source attribution: compare mic vs screen energy per time window
if (voxtralDualTrack && voxtralMicLength > 0 && voxtralScreenLength > 0 && fullText.trim()) {
try {
const micAudio = getVoxtralMicAudio();
const screenAudio = getVoxtralScreenAudio();
const sr = 16000;
const windowSize = Math.floor(sr * 0.5); // 0.5s windows
// Compute RMS energy per window
function rms(buf, start, len) {
let sum = 0;
const end = Math.min(start + len, buf.length);
for (let i = start; i < end; i++) sum += buf[i] * buf[i];
return Math.sqrt(sum / (end - start || 1));
}
// Build source timeline
const segments = [];
const maxLen = Math.max(micAudio.length, screenAudio.length);
for (let i = 0; i < maxLen; i += windowSize) {
const micE = i < micAudio.length ? rms(micAudio, i, windowSize) : 0;
const scrE = i < screenAudio.length ? rms(screenAudio, i, windowSize) : 0;
const t = i / sr;
if (micE < 0.005 && scrE < 0.005) continue; // both below ~-46 dBFS: treat window as silence
const src = micE >= scrE ? 'YOU' : 'SCREEN';
if (segments.length > 0 && segments[segments.length - 1].src === src) {
segments[segments.length - 1].end = t + 0.5;
} else {
segments.push({src, start: t, end: t + 0.5});
}
}
if (segments.length > 1) {
// Split transcript proportionally by segment duration
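// e.g. a YOU segment spanning 6s of 12s total voiced time gets roughly half the words.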
const totalDur = segments.reduce((s, seg) => s + (seg.end - seg.start), 0);
const words = fullText.trim().split(/\s+/);
const totalWords = words.length;
let output = '', wordIdx = 0;
for (const seg of segments) {
const dur = seg.end - seg.start;
const nWords = Math.max(1, Math.round(totalWords * dur / totalDur));
const chunk = words.slice(wordIdx, wordIdx + nWords).join(' ');
wordIdx += nWords;
if (!chunk) continue;
const m = Math.floor(seg.start / 60), s = Math.floor(seg.start % 60);
output += seg.src + ' [' + String(m).padStart(2,'0') + ':' + String(s).padStart(2,'0') + ']: ' + chunk + '\n';
}
if (wordIdx < totalWords) output += words.slice(wordIdx).join(' ');
voxtralTranscript.textContent = output.trim();
}
} catch (energyErr) {
console.warn("[Voxtral Energy]", energyErr);
}
}
voxtralIsRunning = false;
}
// -- Record button -----------------------------------------------------------
recordButton.addEventListener("click", () => {
if (isRecording) { stopRecording(); } else { startRecording(); }
});
// -- Start recording ---------------------------------------------------------
async function startRecording() {
whisperTranscript.innerHTML = "";
voxtralTranscript.innerHTML = "";
whisperTiming.textContent = "";
voxtralTiming.textContent = "";
parakeetTranscript.innerHTML = "";
parakeetTiming.textContent = "";
nemotronTranscript.innerHTML = "";
nemotronTiming.textContent = "";
setStatus("Starting...");
window._parakeetData = null;
window._diarSegments = null;
window._diarNumSpeakers = 0;
// For voxtral, check WebGPU and load model first
if (activeEngines.has("voxtral")) {
const gpuOk = await checkWebGPU();
if (!gpuOk) {
setStatus("WebGPU not available. Cannot use Voxtral Realtime.", "error");
return;
}
const loaded = await loadVoxtralModel();
if (!loaded) {
setStatus("Failed to load Voxtral model.", "error");
return;
}
}
try {
micStream = await navigator.mediaDevices.getUserMedia({ audio: true });
if (screenAudioToggle.checked) {
try {
displayStream = await navigator.mediaDevices.getDisplayMedia({ video: true, audio: true });
displayStream.getVideoTracks().forEach((t) => t.stop());
} catch (err) {
console.warn("[Recorder] Screen audio not available:", err);
setStatus("Screen audio denied - using mic only", "error");
displayStream = null;
}
}
audioContext = new AudioContext();
const dest = audioContext.createMediaStreamDestination();
const micSource = audioContext.createMediaStreamSource(micStream);
micSource.connect(dest);
if (displayStream && displayStream.getAudioTracks().length > 0) {
const displaySource = audioContext.createMediaStreamSource(displayStream);
displaySource.connect(dest);
}
mixedStream = dest.stream;
analyserNode = audioContext.createAnalyser();
analyserNode.fftSize = 256;
micSource.connect(analyserNode);
const mimeType = MediaRecorder.isTypeSupported("audio/webm;codecs=opus")
? "audio/webm;codecs=opus" : "audio/webm";
audioChunks = [];
micChunks = [];
screenChunks = [];
mediaRecorder = new MediaRecorder(mixedStream, { mimeType });
mediaRecorder.ondataavailable = (e) => {
if (e.data.size > 0) {
audioChunks.push(e.data);
if (activeEngines.has("whisper") && websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(e.data);
}
}
};
// Separate mic/screen recorders for routing-based speaker separation
if (!diarizeToggle.checked && displayStream && displayStream.getAudioTracks().length > 0) {
micRecorder = new MediaRecorder(micStream, { mimeType });
micRecorder.ondataavailable = (e) => { if (e.data.size > 0) micChunks.push(e.data); };
const screenDest = audioContext.createMediaStreamDestination();
const screenSrc = audioContext.createMediaStreamSource(displayStream);
screenSrc.connect(screenDest);
screenRecorder = new MediaRecorder(screenDest.stream, { mimeType });
screenRecorder.ondataavailable = (e) => { if (e.data.size > 0) screenChunks.push(e.data); };
micRecorder.start(250);
screenRecorder.start(250);
}
mediaRecorder.onstop = () => { onRecordingStopped(); };
if (activeEngines.has("whisper")) {
await connectWebSocket();
}
// Start Voxtral recording with its own AudioContext at 16kHz
if (activeEngines.has("voxtral")) {
if (await checkWebGPU()) {
const loaded = voxtralModel ? true : await loadVoxtralModel();
const dualVoxtral = !diarizeToggle.checked && displayStream && displayStream.getAudioTracks().length > 0;
if (loaded) await startVoxtralRecording(mixedStream || micStream, dualVoxtral ? micStream : null, dualVoxtral ? displayStream : null);
}
}
mediaRecorder.start(250);
isRecording = true;
recordButton.classList.add("recording");
recordingStartTime = Date.now();
startTimer();
startWaveform();
setStatus("Recording...");
} catch (err) {
console.error("[Recorder] Start error:", err);
setStatus("Failed to start: " + err.message, "error");
cleanupStreams();
}
}
// -- Stop recording ----------------------------------------------------------
function stopRecording() {
if (!isRecording) return;
isRecording = false;
recordButton.classList.remove("recording");
stopTimer();
stopWaveform();
if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(new Blob([]));
}
// Stop Voxtral recording
if (activeEngines.has("voxtral")) {
stopVoxtralRecording();
}
if (mediaRecorder && mediaRecorder.state !== "inactive") {
mediaRecorder.stop();
}
if (micRecorder && micRecorder.state !== "inactive") { micRecorder.stop(); }
if (screenRecorder && screenRecorder.state !== "inactive") { screenRecorder.stop(); }
setStatus("Processing...");
}
// -- After recording stops ---------------------------------------------------
async function onRecordingStopped() {
// Let WhisperLiveKit finish in background (don't block batch engines)
const whisperDone = (websocket && websocket.readyState === WebSocket.OPEN) ? new Promise((resolve) => {
const timeout = setTimeout(() => { resolve(); }, 300000);
const origHandler = websocket.onmessage;
websocket.onmessage = (event) => {
if (origHandler) origHandler(event);
try {
const data = JSON.parse(event.data);
if (data.type === "ready_to_stop") { clearTimeout(timeout); resolve(); }
} catch(e) {}
};
}).then(() => {
if (websocket && websocket.readyState === WebSocket.OPEN) { websocket.close(); }
websocket = null;
}) : Promise.resolve();
// Batch transcription for Parakeet/Nemotron
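// Gradio hosts this page in an iframe whose origin can be "null" (document.write),
// so fall back to the parent page's origin for fetch URLs (same trick as connectWebSocket).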
const baseUrl = (window.location.origin !== 'null' && window.location.host) ? '' : window.parent.location.origin;
// Routing-based separation: when Speaker detection OFF + screen audio captured
const useRouting = !diarizeToggle.checked && micChunks.length > 0 && screenChunks.length > 0;
console.log('[Routing] diarize:', diarizeToggle.checked, 'micChunks:', micChunks.length, 'screenChunks:', screenChunks.length, 'useRouting:', useRouting);
const batchEngines = [];
if (activeEngines.has('parakeet')) batchEngines.push({endpoint: '/parakeet-transcribe', el: parakeetTranscript, tim: parakeetTiming});
if (activeEngines.has('nemotron')) batchEngines.push({endpoint: '/nemotron-transcribe', el: nemotronTranscript, tim: nemotronTiming});
if (batchEngines.length > 0 && useRouting) {
// Dual-track: transcribe mic and screen separately in parallel
const micBlob = new Blob(micChunks, { type: 'audio/webm' });
const screenBlob = new Blob(screenChunks, { type: 'audio/webm' });
const promises = batchEngines.map(async ({endpoint, el, tim}) => {
el.innerHTML = '<span class="spinner"></span> Transcribing mic + screen separately...';
const t0 = Date.now();
try {
const [micResp, screenResp] = await Promise.all([
fetch(baseUrl + endpoint, { method: 'POST', body: micBlob }),
fetch(baseUrl + endpoint, { method: 'POST', body: screenBlob })
]);
const micData = await micResp.json();
const screenData = await screenResp.json();
const elapsed = ((Date.now() - t0) / 1000).toFixed(1);
tim.textContent = 'Processing time: ' + elapsed + 's (dual-track, server CPU)';
// Merge both tracks by timestamps, interleaved
const micTokens = micData.tokens || [];
const micTimestamps = micData.timestamps || [];
const screenTokens = screenData.tokens || [];
const screenTimestamps = screenData.timestamps || [];
// Build word arrays with source label
function buildWords(tokens, timestamps, label) {
const words = [];
let curWord = '', curStart = 0;
for (let i = 0; i < tokens.length; i++) {
const tok = tokens[i], ts = timestamps[i];
if (tok.startsWith(' ') || tok.startsWith('\n')) {
if (curWord.trim()) words.push({text: curWord.trim(), start: curStart, label});
curWord = tok; curStart = ts;
} else {
if (!curWord) curStart = ts;
curWord += tok;
}
}
if (curWord.trim()) words.push({text: curWord.trim(), start: curStart, label});
return words;
}
const allWords = [
...buildWords(micTokens, micTimestamps, 'YOU'),
...buildWords(screenTokens, screenTimestamps, 'SCREEN')
].sort((a, b) => a.start - b.start);
// Merge consecutive same-label words into segments
let output = '';
if (allWords.length > 0) {
let cur = {label: allWords[0].label, start: allWords[0].start, text: allWords[0].text};
for (let i = 1; i < allWords.length; i++) {
if (allWords[i].label === cur.label) { cur.text += ' ' + allWords[i].text; }
else {
const m = Math.floor(cur.start / 60), s = Math.floor(cur.start % 60);
output += cur.label + ' [' + String(m).padStart(2,'0') + ':' + String(s).padStart(2,'0') + ']: ' + cur.text + '\n';
cur = {label: allWords[i].label, start: allWords[i].start, text: allWords[i].text};
}
}
const m = Math.floor(cur.start / 60), s = Math.floor(cur.start % 60);
output += cur.label + ' [' + String(m).padStart(2,'0') + ':' + String(s).padStart(2,'0') + ']: ' + cur.text;
}
// Parenthesize + trim: '+' binds tighter than '||', so the old expression always
// produced a truthy '\n' and the 'No speech detected.' fallback was unreachable.
el.textContent = output.trim() || ((micData.text || '') + '\n' + (screenData.text || '')).trim() || 'No speech detected.';
if (endpoint.includes('parakeet') && micData.tokens) window._parakeetData = micData;
} catch (err) {
el.innerHTML = '<span style="color:var(--accent)">Error: ' + escapeHtml(err.message) + '</span>';
}
});
await Promise.all(promises);
} else if (batchEngines.length > 0) {
// Normal: single mixed audio
const blob = new Blob(audioChunks, { type: 'audio/webm' });
const promises = batchEngines.map(async ({endpoint, el, tim}) => {
el.innerHTML = '<span class="spinner"></span> Transcribing...';
const t0 = Date.now();
try {
const resp = await fetch(baseUrl + endpoint, { method: 'POST', body: blob });
const data = await resp.json();
const elapsed = ((Date.now() - t0) / 1000).toFixed(1);
tim.textContent = 'Processing time: ' + elapsed + 's (server CPU)';
el.innerHTML = '<div class="line">' + escapeHtml(data.text) + '</div>';
if (endpoint.includes('parakeet') && data.tokens) window._parakeetData = data;
} catch (err) {
el.innerHTML = '<span style="color:var(--accent)">Error: ' + escapeHtml(err.message) + '</span>';
}
});
await Promise.all(promises);
}
// Run diarization on recorded audio (if speaker detection enabled)
if (audioChunks.length > 0 && diarizeToggle.checked) {
const blob = new Blob(audioChunks, { type: 'audio/webm' });
try {
setStatus("Analyzing speakers...");
const diarUrl = baseUrl + '/diarize';
console.log("[Diarization] Posting to:", diarUrl, "blob size:", blob.size);
const resp = await fetch(diarUrl, { method: 'POST', body: blob });
if (!resp.ok) {
console.warn("[Diarization] Server error:", resp.status, await resp.text().catch(() => ''));
} else {
const data = await resp.json();
console.log("[Diarization] Result:", data.num_speakers, "speakers,", (data.segments||[]).length, "segments");
if (data.segments && data.segments.length > 0) {
window._diarSegments = data.segments;
window._diarNumSpeakers = data.num_speakers || 0;
applyDiarization();
}
}
} catch (err) {
console.warn("[Diarization] Error:", err);
}
}
// Offer audio download (before cleanup clears chunks)
if (audioChunks.length > 0) {
const audioBlob = new Blob(audioChunks, { type: 'audio/webm' });
const url = URL.createObjectURL(audioBlob);
const dl = document.getElementById('audioDownload');
if (dl) {
if (dl.href) URL.revokeObjectURL(dl.href);
dl.href = url;
dl.download = 'recording_' + new Date().toISOString().slice(0,19).replace(/:/g,'-') + '.webm';
dl.style.display = 'inline-block';
}
}
// Wait for WhisperLiveKit to finish (runs in parallel with batch engines)
await whisperDone;
cleanupStreams();
setStatus("Done.", "success");
// Short recording hint for WhisperLiveKit
if (activeEngines.has("whisper")) {
setTimeout(() => {
if (whisperTranscript.textContent.trim() === "") {
whisperTranscript.innerHTML = '<div class="short-hint">Tip: Record for 20+ seconds for best results with large models on CPU</div>';
}
}, 3000);
}
}
// -- WebSocket (WhisperLiveKit) ----------------------------------------------
function connectWebSocket() {
return new Promise((resolve, reject) => {
// Use parent window's host (iframe from document.write has about:blank origin)
const host = window.location.host || window.parent.location.host;
const proto = (window.location.protocol === "https:" || window.parent.location.protocol === "https:") ? "wss:" : "ws:";
const wsUrl = proto + "//" + host + "/asr";
websocket = new WebSocket(wsUrl);
let whisperStartTime = Date.now();
websocket.onopen = () => { console.log("[WS] Connected"); whisperStartTime = Date.now(); resolve(); };
websocket.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
if (data.type === "config") return;
if (data.type === "ready_to_stop") {
const elapsed = ((Date.now() - whisperStartTime) / 1000).toFixed(1);
whisperTiming.textContent = `Processing time: ${elapsed}s (real-time)`;
return;
}
renderWhisperResults(data);
} catch (err) { console.warn("[WS] Parse error:", err); }
};
websocket.onerror = (err) => { console.error("[WS] Error:", err); setStatus("WebSocket connection failed", "error"); reject(err); };
websocket.onclose = () => { console.log("[WS] Closed"); };
});
}
function fmtTime(s) {
if (s == null || isNaN(s) || s < 0) return "";
const m = Math.floor(s / 60), sec = Math.floor(s % 60);
return String(m).padStart(2,"0") + ":" + String(sec).padStart(2,"0");
}
function applyDiarization() {
const segs = window._diarSegments;
if (!segs || segs.length === 0) return;
const numSpeakers = window._diarNumSpeakers || 0;
const panels = [
{el: whisperTranscript, active: activeEngines.has('whisper'), data: null},
// Voxtral excluded: uses browser-only diarization (Xenova method), no server
{el: parakeetTranscript, active: activeEngines.has('parakeet'), data: window._parakeetData},
{el: nemotronTranscript, active: activeEngines.has('nemotron'), data: null},
];
for (const p of panels) {
if (!p.active || !p.el.textContent.trim()) continue;
// If we have token timestamps (Parakeet), use word-level alignment
if (p.data && p.data.tokens && p.data.timestamps && p.data.tokens.length > 0) {
// Reconstruct full words from subword tokens
const words = [];
let curWord = '', curStart = 0, curEnd = 0;
for (let i = 0; i < p.data.tokens.length; i++) {
const tok = p.data.tokens[i];
const ts = p.data.timestamps[i];
if (tok.startsWith(' ') || tok.startsWith('\n')) {
if (curWord.trim()) words.push({text: curWord, start: curStart, end: curEnd});
curWord = tok; curStart = ts; curEnd = ts;
} else {
if (!curWord) curStart = ts;
curWord += tok; curEnd = ts;
}
}
if (curWord.trim()) words.push({text: curWord, start: curStart, end: curEnd});
// Assign speaker by greatest temporal overlap
function bestSpeaker(ws, we) {
let best = null, maxOv = 0;
for (const seg of segs) {
const ovS = Math.max(ws, seg.start), ovE = Math.min(we, seg.end);
if (ovS < ovE && ovE - ovS > maxOv) { maxOv = ovE - ovS; best = seg.speakers.map(s => 'Speaker ' + s).join(' & '); }
}
return best;
}
// Merge consecutive same-speaker words
let merged = '';
let cSpk = null, cStart = 0, cText = '';
for (const w of words) {
const spk = bestSpeaker(w.start, w.end + 0.05) || cSpk || 'Unknown';
if (spk === cSpk) { cText += w.text; } else { if (cText.trim() && cSpk) merged += '\n[' + fmtTime(cStart) + '] ' + cSpk + ': ' + cText.trim(); cSpk = spk; cStart = w.start; cText = w.text; }
}
if (cText.trim() && cSpk) merged += '\n[' + fmtTime(cStart) + '] ' + cSpk + ': ' + cText.trim();
p.el.textContent = numSpeakers + ' speakers detected:\n' + merged.trim();
} else {
// Fallback: proportional split for engines without timestamps
const rawText = p.el.textContent.trim();
const words = rawText.split(/\s+/);
const totalWords = words.length;
const totalDur = segs.reduce((s, seg) => s + (seg.end - seg.start), 0);
if (totalDur <= 0 || totalWords === 0) continue;
let merged = '';
let wordIdx = 0;
let lastSpeaker = '';
for (const seg of segs) {
const dur = seg.end - seg.start;
const nWords = Math.max(1, Math.round(totalWords * dur / totalDur));
const chunk = words.slice(wordIdx, wordIdx + nWords).join(' ');
wordIdx += nWords;
if (!chunk) continue;
const speakers = seg.speakers.map(s => 'Speaker ' + s).join(' & ');
const start = fmtTime(seg.start);
if (speakers !== lastSpeaker) {
merged += '\n[' + start + '] ' + speakers + ': ' + chunk;
lastSpeaker = speakers;
} else {
merged += ' ' + chunk;
}
}
if (wordIdx < totalWords) merged += ' ' + words.slice(wordIdx).join(' ');
p.el.textContent = numSpeakers + ' speakers detected:\n' + merged.trim();
}
}
}
function renderWhisperResults(data) {
if (!data.lines && !data.buffer_transcription) return;
let html = "";
if (data.lines) {
for (const line of data.lines) {
if (!line.text && !line.translation) continue;
const tsFmt = fmtTime(line.start);
const ts = tsFmt ? `<span class="timestamp">[${tsFmt}]</span>` : "";
// speaker tag only if real diarization is active (multiple speakers detected)
const speakerTag = (line.speaker > 0 && data.lines.some(l => l.speaker !== line.speaker)) ? `<span class="speaker">Speaker ${line.speaker}</span>` : "";
const text = line.text || "";
html += `<div class="line">${ts}${speakerTag}${escapeHtml(text)}</div>`;
}
}
if (data.buffer_transcription) {
html += `<span class="buffer">${escapeHtml(data.buffer_transcription)}</span>`;
}
if (data.buffer_diarization) {
html += `<span class="buffer"> ${escapeHtml(data.buffer_diarization)}</span>`;
}
if (html) {
whisperTranscript.innerHTML = html;
whisperTranscript.scrollTop = whisperTranscript.scrollHeight;
}
}
// -- Timer -------------------------------------------------------------------
function startTimer() {
timerEl.classList.add("recording");
timerInterval = setInterval(() => {
const elapsed = Math.floor((Date.now() - recordingStartTime) / 1000);
const mins = String(Math.floor(elapsed / 60)).padStart(2, "0");
const secs = String(elapsed % 60).padStart(2, "0");
timerEl.textContent = `${mins}:${secs}`;
}, 500);
}
function stopTimer() {
timerEl.classList.remove("recording");
if (timerInterval) { clearInterval(timerInterval); timerInterval = null; }
}
// -- Waveform ----------------------------------------------------------------
function startWaveform() {
const ctx = waveCanvas.getContext("2d");
const bufferLength = analyserNode.frequencyBinCount;
const dataArray = new Uint8Array(bufferLength);
function draw() {
animFrameId = requestAnimationFrame(draw);
analyserNode.getByteTimeDomainData(dataArray);
ctx.fillStyle = getComputedStyle(document.documentElement).getPropertyValue("--surface").trim();
ctx.fillRect(0, 0, waveCanvas.width, waveCanvas.height);
ctx.lineWidth = 2;
ctx.strokeStyle = getComputedStyle(document.documentElement).getPropertyValue("--accent").trim();
ctx.beginPath();
const sliceWidth = waveCanvas.width / bufferLength;
let x = 0;
for (let i = 0; i < bufferLength; i++) {
const v = dataArray[i] / 128.0;
const y = (v * waveCanvas.height) / 2;
if (i === 0) ctx.moveTo(x, y);
else ctx.lineTo(x, y);
x += sliceWidth;
}
ctx.lineTo(waveCanvas.width, waveCanvas.height / 2);
ctx.stroke();
}
draw();
}
function stopWaveform() {
if (animFrameId) { cancelAnimationFrame(animFrameId); animFrameId = null; }
}
// -- Cleanup -----------------------------------------------------------------
function cleanupStreams() {
if (micStream) { micStream.getTracks().forEach((t) => t.stop()); micStream = null; }
if (displayStream) { displayStream.getTracks().forEach((t) => t.stop()); displayStream = null; }
if (audioContext && audioContext.state !== "closed") { audioContext.close().catch(() => {}); audioContext = null; }
mixedStream = null;
analyserNode = null;
audioChunks = [];
micChunks = [];
screenChunks = [];
micRecorder = null;
screenRecorder = null;
voxtralAudioChunks = [];
voxtralAudioLength = 0;
voxtralMicChunks = [];
voxtralMicLength = 0;
voxtralScreenChunks = [];
voxtralScreenLength = 0;
_voxtralCached = new Float32Array(0);
_voxtralCachedLen = 0;
}
// -- Helpers -----------------------------------------------------------------
function setStatus(text, type = "") {
statusEl.textContent = text;
statusEl.className = type;
}
function escapeHtml(str) {
const div = document.createElement("div");
div.textContent = str;
return div.innerHTML;
}
// -- File upload --------------------------------------------------------------
document.getElementById('fileUpload').addEventListener('change', async (e) => {
const file = e.target.files[0];
if (!file) return;
// Clear all panels
whisperTranscript.innerHTML = '<span class="buffer">File upload - streaming engines not available</span>';
voxtralTranscript.innerHTML = '<span class="buffer">File upload - browser engine not available</span>';
parakeetTranscript.innerHTML = '';
parakeetTiming.textContent = '';
nemotronTranscript.innerHTML = '';
nemotronTiming.textContent = '';
setStatus('Processing uploaded file: ' + file.name);
const baseUrl = (window.location.origin !== 'null' && window.location.host) ? '' : window.parent.location.origin;
const blob = file;
// Run batch engines in parallel
const batchEngines = [];
if (activeEngines.has('parakeet')) batchEngines.push({endpoint: '/parakeet-transcribe', el: parakeetTranscript, tim: parakeetTiming});
if (activeEngines.has('nemotron')) batchEngines.push({endpoint: '/nemotron-transcribe', el: nemotronTranscript, tim: nemotronTiming});
if (batchEngines.length > 0) {
const promises = batchEngines.map(async ({endpoint, el, tim}) => {
el.innerHTML = '<span class="spinner"></span> Transcribing...';
const t0 = Date.now();
try {
const resp = await fetch(baseUrl + endpoint, { method: 'POST', body: blob });
const data = await resp.json();
const elapsed = ((Date.now() - t0) / 1000).toFixed(1);
tim.textContent = 'Processing time: ' + elapsed + 's (server CPU)';
el.innerHTML = '<div class="line">' + escapeHtml(data.text) + '</div>';
if (endpoint.includes('parakeet') && data.tokens) window._parakeetData = data;
} catch (err) {
el.innerHTML = '<span style="color:var(--accent)">Error: ' + escapeHtml(err.message) + '</span>';
}
});
await Promise.all(promises);
}
// Run diarization (if speaker detection enabled)
if (diarizeToggle.checked) {
try {
setStatus('Analyzing speakers...');
console.log('[Upload Diarization] Posting to:', baseUrl + '/diarize', 'size:', blob.size);
const resp = await fetch(baseUrl + '/diarize', { method: 'POST', body: blob });
const text = await resp.text();
console.log('[Upload Diarization] Response:', resp.status, text.substring(0, 200));
if (resp.ok) {
const data = JSON.parse(text);
if (data.error) {
console.warn('[Upload Diarization] Server error:', data.error);
} else if (data.segments && data.segments.length > 0) {
window._diarSegments = data.segments;
window._diarNumSpeakers = data.num_speakers || 0;
applyDiarization();
}
}
} catch (err) {
console.warn('[Upload Diarization] Error:', err);
}
}
setStatus('Done.', 'success');
e.target.value = ''; // Reset file input
});
// -- Copy buttons -------------------------------------------------------------
document.querySelectorAll(".copy-btn").forEach((btn) => {
btn.addEventListener("click", () => {
const panel = btn.closest(".result-panel");
const transcript = panel.querySelector(".transcript");
const text = transcript ? transcript.textContent.trim() : "";
if (!text) return;
navigator.clipboard.writeText(text).then(() => {
const origHTML = btn.innerHTML;
btn.innerHTML = '<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M20 6L9 17l-5-5"/></svg> Copied!';
btn.classList.add("copied");
setTimeout(() => { btn.innerHTML = origHTML; btn.classList.remove("copied"); }, 1500);
}).catch(() => {});
});
});
"""
# -- Inline HTML (with embedded CSS + JS) ------------------------------------
RECORDER_HTML = f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Transcription Comparison</title>
<style>{RECORDER_CSS}</style>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.1/dist/ort.min.js"></script>
</head>
<body>
<div id="webgpuWarning"></div>
<h2 style="text-align:center;margin-bottom:8px;font-size:1.1rem;">Transcription Comparison <span style="color:var(--text-dim);font-weight:normal;font-size:0.85rem;">- For your meeting notes!</span></h2>
<div class="mode-selector" id="modeSelector">
<button class="engine-btn" data-engine="whisper">WhisperLiveKit</button>
<button class="engine-btn" data-engine="voxtral">Voxtral-Mini-4B</button>
<button class="engine-btn active" data-engine="parakeet">Parakeet TDT v3</button>
<button class="engine-btn" data-engine="nemotron">Nemotron (EN)</button>
</div>
<div class="options-row">
<label>
<input type="checkbox" id="screenAudioToggle" checked />
Screen/system audio (Chrome only)
</label>
<label>
<input type="checkbox" id="ferToggle" />
Webcam emotion detection
</label>
<label>
<input type="checkbox" id="diarizeToggle" checked />
Speaker detection
</label>
</div>
<div class="fer-container hidden" id="ferContainer">
<video id="webcamVideo" autoplay muted playsinline></video>
<div class="emotion-bars" id="emotionBars"></div>
</div>
<div class="controls">
<button id="recordButton">
<div class="inner"></div>
</button>
<label class="upload-btn" title="Upload audio file">
<input type="file" id="fileUpload" accept="audio/*,.wav,.mp3,.webm,.ogg,.flac" style="display:none" />
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 15v4a2 2 0 01-2 2H5a2 2 0 01-2-2v-4"/><polyline points="17 8 12 3 7 8"/><line x1="12" y1="3" x2="12" y2="15"/></svg>
</label>
<canvas id="waveCanvas" width="200" height="48"></canvas>
<div class="timer" id="timer">00:00</div>
</div>
<p id="status"></p>
<a id="audioDownload" style="display:none;text-align:center;color:var(--success);font-size:0.8rem;margin-bottom:8px;cursor:pointer;">Download recorded audio</a>
<div class="results-grid" id="resultsGrid">
<div class="result-panel" id="whisperPanel">
<button class="copy-btn" title="Copy transcript"><svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 01-2-2V4a2 2 0 012-2h9a2 2 0 012 2v1"/></svg> Copy</button>
<h3>WhisperLiveKit &middot; Whisper large-v3-turbo <span class="badge realtime">Real-time</span></h3>
<div class="timing" id="whisperTiming"></div>
<div class="transcript" id="whisperTranscript"></div>
</div>
<div class="result-panel" id="voxtralPanel">
<button class="copy-btn" title="Copy transcript"><svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 01-2-2V4a2 2 0 012-2h9a2 2 0 012 2v1"/></svg> Copy</button>
<h3>Voxtral-Mini-4B-Realtime-2602 <span class="badge browser">WebGPU ONNX</span></h3>
<div class="timing" id="voxtralTiming"></div>
<div class="transcript" id="voxtralTranscript"></div>
</div>
<div class="result-panel" id="parakeetPanel">
<button class="copy-btn" title="Copy transcript"><svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 01-2-2V4a2 2 0 012-2h9a2 2 0 012 2v1"/></svg> Copy</button>
<h3>Parakeet TDT v3 &middot; 25 languages <span class="badge realtime">CPU ONNX</span></h3>
<div class="timing" id="parakeetTiming"></div>
<div class="transcript" id="parakeetTranscript"></div>
</div>
<div class="result-panel" id="nemotronPanel">
<button class="copy-btn" title="Copy transcript"><svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 01-2-2V4a2 2 0 012-2h9a2 2 0 012 2v1"/></svg> Copy</button>
<h3>Nemotron Streaming &middot; English only <span class="badge realtime">CPU ONNX int8</span></h3>
<div class="timing" id="nemotronTiming"></div>
<div class="transcript" id="nemotronTranscript"></div>
</div>
</div>
<script>{FER_JS}</script>
<script>{RECORDER_JS}</script>
</body>
</html>"""
# Base64-encode the recorder HTML so we can embed it in JS without any server route
_RECORDER_HTML_B64 = base64.b64encode(RECORDER_HTML.encode("utf-8")).decode("ascii")
# Inject loader via <img onerror> trick — Gradio 6 strips <script> tags from gr.HTML()
# but preserves inline event handlers on elements like <img>.
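# The handler polls for the container every 300 ms (Gradio renders asynchronously) and gives up after 20 s.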
EMBED_HTML = (
'<div id="recorder-container" style="width:100%; height:100vh;">'
'<p style="text-align:center; padding:40px; color:#888;">Loading comparison interface...</p>'
'</div>'
'<img src="" onerror="'
"(function(){"
"function inj(){"
"var c=document.getElementById(\'recorder-container\');"
"if(!c)return false;"
"if(c.querySelector(\'iframe\'))return true;"
"try{"
"var h=atob(\'" + _RECORDER_HTML_B64 + "\');"
"var f=document.createElement(\'iframe\');"
"f.style.width=\'100%\';"
"f.style.height=\'100%\';"
"f.style.border=\'none\';"
"f.setAttribute(\'allow\',\'microphone; display-capture; camera\');"
"c.innerHTML=\'\';"
"c.appendChild(f);"
"f.contentDocument.open();"
"f.contentDocument.write(h);"
"f.contentDocument.close();"
"return true;"
"}catch(e){c.innerHTML=\'<p style=color:red>Load error: \'+e+\'</p>\';return true;}"
"}"
"var iv=setInterval(function(){if(inj())clearInterval(iv);},300);"
"setTimeout(function(){clearInterval(iv);},20000);"
"})();"
'" style="display:none" />'
)
with gr.Blocks() as demo:
gr.HTML(EMBED_HTML)
# -- FastAPI app (owns all custom routes, Gradio mounted on top) ---------------
import fastapi as _fa
import uvicorn
app = _fa.FastAPI()
# Static files (serves emotion_model_web.onnx)
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
if os.path.isdir(static_dir):
app.mount("/static", StaticFiles(directory=static_dir), name="static")
# Permissions middleware for display-capture in iframe
@app.middleware("http")
async def add_permissions_policy(request: Request, call_next):
response: Response = await call_next(request)
response.headers["Permissions-Policy"] = "display-capture=*, microphone=*, camera=*"
return response
# -- WebSocket /asr (WhisperLiveKit real-time) --------------------------------
@app.websocket("/asr")
async def websocket_endpoint(websocket: WebSocket):
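# One AudioProcessor per connection; the shared transcription_engine carries the loaded model.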
audio_processor = AudioProcessor(transcription_engine=transcription_engine)
await websocket.accept()
logger.info("WebSocket connection opened.")
try:
await websocket.send_json({"type": "config", "useAudioWorklet": False})
except Exception as e:
logger.warning(f"Failed to send config: {e}")
results_generator = await audio_processor.create_tasks()
async def send_results():
try:
async for response in results_generator:
await websocket.send_json(response.to_dict())
await websocket.send_json({"type": "ready_to_stop"})
except WebSocketDisconnect:
logger.info("Client disconnected during results.")
except Exception as e:
logger.warning(f"Results error: {e}")
results_task = asyncio.create_task(send_results())
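# Ingest raw audio bytes from the client while send_results streams transcripts back concurrently.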
try:
while True:
message = await websocket.receive_bytes()
await audio_processor.process_audio(message)
except WebSocketDisconnect:
logger.info("Client disconnected.")
except Exception as e:
logger.warning(f"WebSocket error: {e}")
finally:
if not results_task.done():
results_task.cancel()
try:
await results_task
except asyncio.CancelledError:
pass
await audio_processor.cleanup()
logger.info("WebSocket cleaned up.")
# -- Lazy-loaded sherpa-onnx recognizers (cached after first use) --------------
_parakeet_recognizer = None
_nemotron_recognizer = None
def _get_parakeet():
global _parakeet_recognizer
if _parakeet_recognizer is None:
import onnx_asr
logger.info("Loading Parakeet TDT v3 via onnx-asr...")
_parakeet_recognizer = onnx_asr.load_model(
"nemo-parakeet-tdt-0.6b-v3", providers=["CPUExecutionProvider"]
).with_timestamps()
logger.info("Parakeet TDT v3 ready.")
return _parakeet_recognizer
def _get_nemotron():
global _nemotron_recognizer
if _nemotron_recognizer is None:
import sherpa_onnx
from huggingface_hub import hf_hub_download
repo = "csukuangfj/sherpa-onnx-nemotron-speech-streaming-en-0.6b-int8-2026-01-14"
logger.info("Downloading Nemotron Streaming int8 model...")
encoder = hf_hub_download(repo, "encoder.int8.onnx")
decoder = hf_hub_download(repo, "decoder.int8.onnx")
joiner = hf_hub_download(repo, "joiner.int8.onnx")
tokens = hf_hub_download(repo, "tokens.txt")
logger.info("Loading Nemotron Streaming int8 recognizer...")
_nemotron_recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
encoder=encoder, decoder=decoder, joiner=joiner, tokens=tokens,
num_threads=2, sample_rate=16000,
)
logger.info("Nemotron Streaming ready.")
return _nemotron_recognizer
def _decode_webm_to_float32(body: bytes):
"""Decode webm/opus audio to 16kHz mono float32 numpy array."""
import numpy as np
import tempfile, os, librosa
with tempfile.NamedTemporaryFile(suffix=".webm", delete=False) as f:
f.write(body)
tmp_path = f.name
try:
audio_data, _ = librosa.load(tmp_path, sr=16000, mono=True)
finally:
os.unlink(tmp_path)
return audio_data.astype(np.float32)
# -- Parakeet TDT v3 batch endpoint (sherpa-onnx int8, cached) ----------------
@app.post("/parakeet-transcribe")
async def parakeet_transcribe(request: Request):
import numpy as np
body = await request.body()
audio_data = _decode_webm_to_float32(body)
model = _get_parakeet()
audio_int16 = (audio_data * 32767).astype(np.int16)
output = model.recognize(audio_int16)
text = output.text if hasattr(output, 'text') else str(output)
tokens = list(output.tokens) if hasattr(output, 'tokens') else []
timestamps = [round(float(t), 2) for t in output.timestamps] if hasattr(output, 'timestamps') else []
return {"text": text, "tokens": tokens, "timestamps": timestamps, "engine": "parakeet"}
# -- Nemotron Speech Streaming batch endpoint (sherpa-onnx int8, cached) -------
@app.post("/nemotron-transcribe")
async def nemotron_transcribe(request: Request):
import numpy as np
body = await request.body()
audio_data = _decode_webm_to_float32(body)
recognizer = _get_nemotron()
stream = recognizer.create_stream()
stream.accept_waveform(16000, audio_data)
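# Pad with 0.5 s of silence so the streaming decoder flushes its final tokens before input_finished().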
tail = np.zeros(int(16000 * 0.5), dtype=np.float32)
stream.accept_waveform(16000, tail)
stream.input_finished()
while recognizer.is_ready(stream):
recognizer.decode_stream(stream)
text = recognizer.get_result(stream)
try:
tokens = list(stream.result.tokens) if hasattr(stream.result, 'tokens') else []
timestamps = [round(float(t), 2) for t in stream.result.timestamps] if hasattr(stream.result, 'timestamps') else []
except Exception:
tokens = []
timestamps = []
return {"text": text, "tokens": tokens, "timestamps": timestamps, "engine": "nemotron"}
# -- Diarization: pyannote speaker-diarization-3.1 pipeline --------------------
_diarize_pipeline = None
def _setup_pyannote_cache():
"""Pre-populate HF cache with bundled model weights."""
import shutil
app_dir = os.path.dirname(os.path.abspath(__file__))
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub")
models = {
"models--pyannote--segmentation-3.0": {
"snapshot": "e66f3d3b9eb0873085418a7b813d3b369bf160bb",
"files": {"pytorch_model.bin": os.path.join(app_dir, "models", "segmentation-3.0", "pytorch_model.bin")},
},
"models--pyannote--wespeaker-voxceleb-resnet34-LM": {
"snapshot": "837717ddb9ff5507820346191109dc79c958d614",
"files": {"pytorch_model.bin": os.path.join(app_dir, "models", "wespeaker-voxceleb-resnet34-LM", "pytorch_model.bin")},
},
"models--pyannote--speaker-diarization-community-1": {
"snapshot": "3533c8cf8e369892e6b79ff1bf80f7b0286a54ee",
"files": {
"plda/plda.npz": os.path.join(app_dir, "models", "speaker-diarization-community-1", "plda", "plda.npz"),
"plda/xvec_transform.npz": os.path.join(app_dir, "models", "speaker-diarization-community-1", "plda", "xvec_transform.npz"),
},
},
}
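# Recreate the hub cache layout (snapshots/<sha> plus a refs/main pointer) so offline loading resolves these repos.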
for model_id, info in models.items():
snap_dir = os.path.join(cache_dir, model_id, "snapshots", info["snapshot"])
refs_dir = os.path.join(cache_dir, model_id, "refs")
os.makedirs(snap_dir, exist_ok=True)
os.makedirs(refs_dir, exist_ok=True)
refs_main = os.path.join(refs_dir, "main")
if not os.path.exists(refs_main):
with open(refs_main, "w") as f:
f.write(info["snapshot"])
for fname, src_path in info["files"].items():
dst = os.path.join(snap_dir, fname)
os.makedirs(os.path.dirname(dst), exist_ok=True)
if not os.path.exists(dst) and os.path.exists(src_path):
shutil.copy2(src_path, dst)
logger.info(f"Cached {model_id}/{fname}")
def _get_diarize_pipeline():
global _diarize_pipeline
if _diarize_pipeline is None:
from pyannote.audio import Pipeline
_setup_pyannote_cache()
models_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
logger.info("Loading pyannote speaker-diarization-3.1 pipeline...")
old_offline = os.environ.get("HF_HUB_OFFLINE")
os.environ["HF_HUB_OFFLINE"] = "1"
try:
_diarize_pipeline = Pipeline.from_pretrained(os.path.join(models_dir, "speaker-diarization-3.1"))
finally:
if old_offline is None:
os.environ.pop("HF_HUB_OFFLINE", None)
else:
os.environ["HF_HUB_OFFLINE"] = old_offline
logger.info("Diarization pipeline ready.")
return _diarize_pipeline
@app.post("/diarize")
async def diarize(request: Request):
import torch
try:
body = await request.body()
audio = _decode_webm_to_float32(body)
logger.info(f"Diarization: audio length={len(audio)/16000:.1f}s")
except Exception as e:
return {"segments": [], "num_speakers": 0, "error": str(e)}
try:
pipeline = _get_diarize_pipeline()
waveform = torch.tensor(audio).unsqueeze(0)
result = pipeline({"waveform": waveform, "sample_rate": 16000})
diar = result.speaker_diarization
# Post-processing: merge similar speakers (numpy cosine, no sklearn)
import numpy as np
speaker_labels = sorted(diar.labels())
merge_map = {}
if hasattr(result, "speaker_embeddings") and result.speaker_embeddings is not None and len(speaker_labels) > 1:
emb = result.speaker_embeddings
norms = np.linalg.norm(emb, axis=1, keepdims=True)
norms[norms == 0] = 1
sim = (emb / norms) @ (emb / norms).T
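# Greedily chain any label pair whose embedding cosine similarity reaches 0.6 onto the earlier speaker.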
for i in range(len(speaker_labels)):
for j in range(i + 1, len(speaker_labels)):
if sim[i][j] >= 0.6:
target = merge_map.get(speaker_labels[i], speaker_labels[i])
merge_map[speaker_labels[j]] = target
logger.info(f"Merging {speaker_labels[j]} -> {target} (sim: {sim[i][j]:.3f})")
segments = []
speakers_set = set()
for turn, _, spk in diar.itertracks(yield_label=True):
actual_spk = merge_map.get(spk, spk)
speakers_set.add(actual_spk)
speaker_id = int(actual_spk.split("_")[-1]) + 1
segments.append({"start": round(turn.start, 2), "end": round(turn.end, 2), "speakers": [speaker_id]})
num_speakers = len(speakers_set)
logger.info(f"Diarization done: {num_speakers} speakers, {len(segments)} segments")
return {"segments": segments, "num_speakers": num_speakers}
except Exception as e:
logger.error(f"Diarization failed: {traceback.format_exc()}")
return {"segments": [], "num_speakers": 0, "error": str(e)}
# -- Server mode startup ------------------------------------------------------
app = gr.mount_gradio_app(app, demo, path="/", ssr_mode=False)
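# Gradio is mounted last so the custom routes and WebSocket above take precedence over its catch-all at "/".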
uvicorn.run(app, host="0.0.0.0", port=7860)