gsaon committed on
Commit 9e600a5 · verified · 1 Parent(s): dbc1132

Upload 3 files

Files changed (3)
  1. README.md +88 -3
  2. app.js +90 -500
  3. index.html +3 -2
README.md CHANGED
@@ -1,8 +1,93 @@
  ---
  title: Granite Speech WebGPU
- emoji: 🗣
- colorFrom: blue
- colorTo: indigo
+ emoji: 🎙
+ colorFrom: green
+ colorTo: gray
  sdk: static
+ app_file: index.html
  pinned: false
  ---
+
+ # Granite Speech WebGPU
+
+ Browser-based speech recognition and translation using IBM Granite Speech 4.0 1B with [Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU acceleration.
+
+ **Your audio and transcription never leave your device.**
+
+ ## Features
+
+ - **Speech-to-Text**: Transcribe audio in multiple languages
+ - **Translation**: Translate speech to English, French, German, Spanish, Portuguese, or Japanese
+ - **Voice Activity Detection**: Silero VAD for automatic speech segmentation
+ - **Punctuation & Capitalization**: Automatic post-processing (auto-detected language via tinyld)
+ - **Audio Input**: Record from microphone or upload/drag-and-drop audio files
+ - **Real-time Sync**: Transcript appears synchronized with audio playback
+ - **Streaming Output**: Partial results displayed as tokens are generated
+ - **Fully Client-Side**: All processing happens in your browser using WebGPU
+
+ ## Browser Requirements
+
+ - **Chrome 113+** or **Edge 113+** (required for WebGPU)
+ - Firefox and Safari do not yet have stable WebGPU support
+
+ ## Quick Start
+
+ ```bash
+ git clone git@github.ibm.com:gsaon/granite-speech-webgpu.git
+ cd granite-speech-webgpu
+ python3 -m http.server 8080
+ ```
+
+ Open http://localhost:8080. Models (~1.4 GB) are downloaded automatically from Hugging Face on first load and cached by the browser.
+
+ For non-localhost access, use the HTTPS server:
+
+ ```bash
+ python3 serve.py
+ ```
+
+ ## Architecture
+
+ The app uses [Transformers.js v4](https://huggingface.co/docs/transformers.js) to run the full inference pipeline in ~30 lines:
+
+ 1. `AutoProcessor` handles audio preprocessing (mel spectrogram, frame stacking, normalization)
+ 2. `GraniteSpeechForConditionalGeneration` manages encoder, embeddings, and autoregressive decoding with KV-cache
+ 3. `TextStreamer` provides streaming token output
+
+ ### Models
+
+ | Component | Source | Size | Purpose |
+ |-----------|--------|------|---------|
+ | Granite Speech (q4f16) | [onnx-community/granite-4.0-1b-speech-ONNX](https://huggingface.co/onnx-community/granite-4.0-1b-speech-ONNX) | ~1.4 GB | Speech recognition & translation |
+ | Silero VAD | Local | 2.1 MB | Voice activity detection |
+ | Punctuation (EN) | [1-800-BAD-CODE](https://huggingface.co/1-800-BAD-CODE/punctuation_fullstop_truecase_english) | ~200 MB | English punctuation & capitalization |
+
+ ### Dependencies (loaded from CDN)
+
+ - **Transformers.js 4.0.0-next.7**: Model loading, processing, and inference
+ - **ONNX Runtime Web 1.24.3**: VAD and punctuation models (WASM)
+ - **tinyld**: Language detection for automatic punctuation
+
+ ## Project Structure
+
+ ```
+ granite-speech-webgpu/
+ ├── index.html            # Main HTML page
+ ├── app.js                # Main app (Transformers.js v4 inference + UI)
+ ├── vad.js                # Silero VAD integration (ONNX/WASM)
+ ├── punctuator.js         # Punctuation models (ONNX/WASM)
+ ├── style.css             # Styling
+ ├── pcs_vocab.json        # Punctuator vocabulary
+ ├── silero_vad.onnx       # VAD model
+ ├── punct_cap_seg_en.onnx # English punctuator model
+ └── serve.py              # HTTPS development server
+ ```
+
+ ## Acknowledgments
+
+ - [IBM Granite Speech](https://huggingface.co/ibm-granite/granite-4.0-1b-speech)
+ - [Transformers.js](https://huggingface.co/docs/transformers.js)
+ - [ONNX Community](https://huggingface.co/onnx-community)
+ - [Silero VAD](https://github.com/snakers4/silero-vad)
+ - [Punctuation Model](https://huggingface.co/1-800-BAD-CODE/punctuation_fullstop_truecase_english)
+ - [tinyld](https://github.com/komodojp/tinyld)
app.js CHANGED
@@ -1,76 +1,39 @@
  /**
  * Granite Speech WebGPU Demo
- * Uses ONNX Runtime Web for in-browser speech recognition
  */

- import { PreTrainedTokenizer } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.4.2';
  import { detect } from 'https://cdn.jsdelivr.net/npm/tinyld/+esm';

- // Check if ONNX Runtime is loaded
- if (typeof ort === 'undefined') {
- console.error('ONNX Runtime Web not loaded! Check if the script tag is correct.');
- alert('Failed to load ONNX Runtime. Please refresh the page.');
- } else {
- // Configure WASM paths to CDN
- ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.24.3/dist/';
-
- // WASM settings - enable multi-threading for encoder performance
- ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
- ort.env.wasm.simd = true;
-
- // WebGPU settings
- ort.env.webgpu = ort.env.webgpu || {};
- }
-
- // Model paths
- // Granite Speech ONNX models hosted on HF Hub
- const HF_MODEL_BASE = 'https://huggingface.co/ibm-granite/granite-4.0-1b-speech/resolve/main/onnx';
- const ENCODER_PATH = `${HF_MODEL_BASE}/audio_encoder_q4f32.onnx`;
- const EMBED_PATH = `${HF_MODEL_BASE}/embed_tokens_q4f16.onnx`;
- const DECODER_PATH = `${HF_MODEL_BASE}/decoder_model_merged_q4f16.onnx`;
-
- // Audio config from preprocessor_config.json
  const SAMPLE_RATE = 16000;
- const N_MELS = 80;
- const N_FFT = 512;
- const HOP_LENGTH = 160;
- const WIN_LENGTH = 400;
-
- // Model config
- const HIDDEN_SIZE = 2048;
- const VOCAB_SIZE = 100353;
- const BOS_TOKEN = 100257;
- const EOS_TOKEN = 100257;
- const PAD_TOKEN = 100256;
  const MAX_NEW_TOKENS = 256;
- // Note: embedding_multiplier (12) is likely already applied in the model weights
-
- // Prompt templates
- const PROMPT_PREFIX = 'USER: ';
- const PROMPTS = {
- 'transcribe': 'Transcribe the speech to text\n ASSISTANT:',
- 'translate_en': 'Translate the speech to English\n ASSISTANT:',
- 'translate_fr': 'Translate the speech to French\n ASSISTANT:',
- 'translate_de': 'Translate the speech to German\n ASSISTANT:',
- 'translate_es': 'Translate the speech to Spanish\n ASSISTANT:',
- 'translate_pt': 'Translate the speech to Portuguese\n ASSISTANT:',
- 'translate_ja': 'Translate the speech to Japanese\n ASSISTANT:',
  };

  // State
- let encoderSession = null;
- let embedSession = null;
- let decoderSession = null;
- let tokenizer = null;
  let isModelLoading = false;
  let currentAudioData = null;

- // Pre-computed prompt embeddings (populated at init)
- const promptEmbeddings = {
- prefix: null, // "USER: "
- // suffix embeddings keyed by prompt name
- };
-
  // DOM Elements
  const statusDot = document.getElementById('statusDot');
  const statusText = document.getElementById('statusText');
@@ -145,130 +108,50 @@ async function checkWebGPU() {
  }
  }

- // Load tokenizer using transformers.js
- async function loadTokenizer() {
- const [tokenizerJson, tokenizerConfig] = await Promise.all([
- fetch('https://huggingface.co/ibm-granite/granite-4.0-1b-speech/resolve/main/tokenizer.json').then(r => r.json()),
- fetch('https://huggingface.co/ibm-granite/granite-4.0-1b-speech/resolve/main/tokenizer_config.json').then(r => r.json())
- ]);
- return new PreTrainedTokenizer(tokenizerJson, tokenizerConfig);
- }
-
- // Get embeddings for token IDs (returns Float32Array)
- async function getEmbeddings(tokenIds) {
- const idsTensor = new ort.Tensor('int64', BigInt64Array.from(tokenIds.map(BigInt)), [1, tokenIds.length]);
- const output = await embedSession.run({ input_ids: idsTensor });
- return {
- data: new Float32Array(output.inputs_embeds.data),
- seqLen: output.inputs_embeds.dims[1]
- };
- }
-
- // Pre-compute embeddings for all prompts
- async function precomputePromptEmbeddings() {
- // Prefix embedding
- const prefixTokens = tokenizer.encode(PROMPT_PREFIX, { add_special_tokens: false });
- promptEmbeddings.prefix = await getEmbeddings(prefixTokens);
-
- // Suffix embeddings for each prompt
- for (const [key, text] of Object.entries(PROMPTS)) {
- const tokens = tokenizer.encode(text, { add_special_tokens: false });
- promptEmbeddings[key] = await getEmbeddings(tokens);
- }
-
- console.log('Pre-computed embeddings for', Object.keys(promptEmbeddings).length, 'prompts');
- }
-
- // Session options - WebGPU only (no WASM fallback)
- const sessionOptions = {
- executionProviders: ['webgpu'],
- enableMemPattern: false,
- enableCpuMemArena: false,
- graphOptimizationLevel: 'basic',
- };
-
- // Force garbage collection pause
- async function gcPause() {
- // Give browser time to garbage collect
- await new Promise(resolve => setTimeout(resolve, 100));
- }
-
- // Load ONNX model with external data support
- async function loadModelWithExternalData(modelPath, options) {
- // Check if external data file exists
- const dataPath = modelPath.replace('.onnx', '.onnx_data');
-
- const modelResponse = await fetch(modelPath);
- const modelBuffer = await modelResponse.arrayBuffer();
-
- const dataResponse = await fetch(dataPath);
- if (!dataResponse.ok) {
- // No external data, load model directly
- return await ort.InferenceSession.create(modelBuffer, options);
- }
-
- const dataBuffer = await dataResponse.arrayBuffer();
-
- // Extract filename from path for external data reference
- const dataFileName = dataPath.split('/').pop();
-
- // Create session with external data
- const sessionOptionsWithData = {
- ...options,
- externalData: [
- {
- path: dataFileName,
- data: dataBuffer,
- }
- ]
- };
-
- return await ort.InferenceSession.create(modelBuffer, sessionOptionsWithData);
- }
-
- // Initialize ONNX Runtime and load models
  async function initModels() {
  if (isModelLoading) return;
  isModelLoading = true;

- setStatus('loading', 'Loading models...');
- showProgress(true);

  try {
- const hasWebGPU = await checkWebGPU();
-
- updateProgress(10, 'Initializing ONNX Runtime...');
-
- // Load tokenizer
- updateProgress(15, 'Loading tokenizer...');
- tokenizer = await loadTokenizer();
-
- // Load models one at a time with GC pauses between
- // Use loadModelWithExternalData to handle .onnx_data files
-
- // Load encoder model (q4f32 with WebGPU)
- updateProgress(20, 'Loading encoder model...');
- encoderSession = await loadModelWithExternalData(ENCODER_PATH, sessionOptions);
-
- await gcPause();
-
- // Load embed tokens model
- updateProgress(40, 'Loading embed tokens model...');
- embedSession = await loadModelWithExternalData(EMBED_PATH, sessionOptions);
-
- // Pre-compute prompt embeddings
- updateProgress(50, 'Pre-computing prompt embeddings...');
- await precomputePromptEmbeddings();
-
- await gcPause();
-
- // Load decoder model
- updateProgress(60, 'Loading decoder model...');
- decoderSession = await loadModelWithExternalData(DECODER_PATH, sessionOptions);
-
- updateProgress(100, 'Models loaded!');
- showProgress(false);
  setStatus('ready', 'Ready - Record or upload audio');
  enableControls(true);
@@ -277,7 +160,7 @@ async function initModels() {
  console.error('Error stack:', error?.stack);
  const errorMsg = error?.message || error?.toString() || 'Unknown error';
  setStatus('error', `Error: ${errorMsg}`);
- showProgress(false);
  isModelLoading = false;
  }
  }
@@ -287,335 +170,42 @@ function enableControls(enabled) {
  audioFile.disabled = !enabled;
  }

- // Mel spectrogram computation
- // Uses custom implementation matching torchaudio
- function computeMelSpectrogram(audioData) {
- // Pad signal with reflection (center=True, pad_mode='reflect')
- const padLength = Math.floor(N_FFT / 2);
- const paddedLength = audioData.length + 2 * padLength;
- const paddedAudio = new Float32Array(paddedLength);
-
- // Reflect padding at start: for position -i, use position i (not i-1)
- // numpy reflect: for index -1, reflects to index 1
- for (let i = 0; i < padLength; i++) {
- // Position -(i+1) reflects to position (i+1)
- const srcIdx = Math.min(i + 1, audioData.length - 1);
- paddedAudio[padLength - 1 - i] = audioData[srcIdx];
- }
- // Copy original audio
- for (let i = 0; i < audioData.length; i++) {
- paddedAudio[padLength + i] = audioData[i];
- }
- // Reflect padding at end
- for (let i = 0; i < padLength; i++) {
- const srcIdx = Math.max(0, audioData.length - 2 - i);
- paddedAudio[padLength + audioData.length + i] = audioData[srcIdx];
- }
-
- // Calculate number of frames
- const numFrames = Math.floor((paddedLength - N_FFT) / HOP_LENGTH) + 1;
-
- // Create mel filterbank (torchaudio HTK style)
- const melFilterbank = createMelFilterbank(N_FFT, N_MELS, SAMPLE_RATE);
-
- // Hann window (periodic=True like torchaudio)
- const window = new Float32Array(WIN_LENGTH);
- for (let i = 0; i < WIN_LENGTH; i++) {
- window[i] = 0.5 * (1 - Math.cos(2 * Math.PI * i / WIN_LENGTH));
- }
-
- const melSpec = new Float32Array(numFrames * N_MELS);
-
- // torch.stft center-pads the window when win_length < n_fft
- // Window is placed at indices padLeft to padLeft+win_length
- const padLeft = Math.floor((N_FFT - WIN_LENGTH) / 2); // = 56
-
- for (let frame = 0; frame < numFrames; frame++) {
- const start = frame * HOP_LENGTH;
-
- // Apply center-padded window (matching torch.stft behavior)
- // Read n_fft samples, apply window centered in the middle
- const windowed = new Float32Array(N_FFT); // initialized to zeros
- for (let i = 0; i < WIN_LENGTH; i++) {
- windowed[padLeft + i] = paddedAudio[start + padLeft + i] * window[i];
- }
-
- // Compute power spectrum
- const powerSpec = computePowerSpectrum(windowed);
-
- // Apply mel filterbank and log10
- for (let m = 0; m < N_MELS; m++) {
- let sum = 0;
- for (let k = 0; k < N_FFT / 2 + 1; k++) {
- sum += powerSpec[k] * melFilterbank[m * (N_FFT / 2 + 1) + k];
- }
- melSpec[frame * N_MELS + m] = Math.log10(Math.max(sum, 1e-10));
- }
- }
-
- return { data: melSpec, numFrames, numMels: N_MELS };
- }
-
- // Create mel filterbank (torchaudio HTK style)
- function createMelFilterbank(nfft, nMels, sampleRate) {
- const numBins = nfft / 2 + 1;
- const filterbank = new Float32Array(nMels * numBins);
-
- // HTK mel scale
- const hzToMel = (hz) => 2595 * Math.log10(1 + hz / 700);
- const melToHz = (mel) => 700 * (Math.pow(10, mel / 2595) - 1);
-
- const fMin = 0;
- const fMax = sampleRate / 2;
- const melMin = hzToMel(fMin);
- const melMax = hzToMel(fMax);
-
- // Create mel-spaced frequency points (n_mels + 2 points)
- const fPts = new Float32Array(nMels + 2);
- for (let i = 0; i < nMels + 2; i++) {
- fPts[i] = melToHz(melMin + (melMax - melMin) * i / (nMels + 1));
- }
-
- // Create frequency array for each FFT bin
- const allFreqs = new Float32Array(numBins);
- for (let i = 0; i < numBins; i++) {
- allFreqs[i] = i * sampleRate / nfft;
- }
-
- // Compute frequency differences
- const fDiff = new Float32Array(nMels + 1);
- for (let i = 0; i < nMels + 1; i++) {
- fDiff[i] = fPts[i + 1] - fPts[i];
- }
-
- // Create triangular filters using slopes (torchaudio style)
- for (let m = 0; m < nMels; m++) {
- for (let k = 0; k < numBins; k++) {
- const freq = allFreqs[k];
- const lowSlope = (freq - fPts[m]) / fDiff[m];
- const upSlope = (fPts[m + 2] - freq) / fDiff[m + 1];
- filterbank[m * numBins + k] = Math.max(0, Math.min(lowSlope, upSlope));
- }
- }
-
- return filterbank;
- }
-
- // Compute power spectrum using radix-2 FFT
- function computePowerSpectrum(signal) {
- const n = signal.length;
-
- // Use radix-2 FFT for power of 2 lengths
- if ((n & (n - 1)) === 0) {
- return computePowerSpectrumFFT(signal);
- }
-
- // Fallback to DFT for non-power-of-2
- const spectrum = new Float32Array(n / 2 + 1);
- for (let k = 0; k <= n / 2; k++) {
- let real = 0, imag = 0;
- for (let t = 0; t < n; t++) {
- const angle = -2 * Math.PI * k * t / n;
- real += signal[t] * Math.cos(angle);
- imag += signal[t] * Math.sin(angle);
- }
- spectrum[k] = real * real + imag * imag;
- }
- return spectrum;
- }
-
- // Radix-2 FFT for power spectrum
- function computePowerSpectrumFFT(signal) {
- const n = signal.length;
-
- // Bit-reversal permutation
- const real = new Float32Array(n);
- const imag = new Float32Array(n);
-
- for (let i = 0; i < n; i++) {
- let j = 0;
- let x = i;
- for (let k = 0; k < Math.log2(n); k++) {
- j = (j << 1) | (x & 1);
- x >>= 1;
- }
- real[j] = signal[i];
- }
-
- // Cooley-Tukey FFT
- for (let size = 2; size <= n; size *= 2) {
- const halfSize = size / 2;
- const step = Math.PI / halfSize;
-
- for (let i = 0; i < n; i += size) {
- for (let j = 0; j < halfSize; j++) {
- const angle = -j * step;
- const cos = Math.cos(angle);
- const sin = Math.sin(angle);
-
- const idx1 = i + j;
- const idx2 = i + j + halfSize;
-
- const tReal = cos * real[idx2] - sin * imag[idx2];
- const tImag = sin * real[idx2] + cos * imag[idx2];
-
- real[idx2] = real[idx1] - tReal;
- imag[idx2] = imag[idx1] - tImag;
- real[idx1] = real[idx1] + tReal;
- imag[idx1] = imag[idx1] + tImag;
- }
- }
- }
-
- // Compute power spectrum (first half + DC and Nyquist)
- const spectrum = new Float32Array(n / 2 + 1);
- for (let k = 0; k <= n / 2; k++) {
- spectrum[k] = real[k] * real[k] + imag[k] * imag[k];
- }
-
- return spectrum;
- }
-
- // Prepare audio features for encoder
- function prepareAudioFeatures(audioData) {
- const melSpec = computeMelSpectrogram(audioData);
-
- // Apply Granite Speech normalization:
- // 1. Already have log10 mel from computeMelSpectrogram
- // 2. Normalize: max(logmel, max - 8) / 4 + 1
- const logmel = melSpec.data;
- let maxVal = -Infinity;
- for (let i = 0; i < logmel.length; i++) {
- if (logmel[i] > maxVal) maxVal = logmel[i];
- }
-
- const normalized = new Float32Array(logmel.length);
- for (let i = 0; i < logmel.length; i++) {
- normalized[i] = Math.max(logmel[i], maxVal - 8) / 4 + 1;
- }
-
- // Remove last frame if odd
- let numFrames = melSpec.numFrames;
- if (numFrames % 2 === 1) {
- numFrames -= 1;
- }
-
- // Stack 2 consecutive frames -> 160 features (80 mels * 2)
- const stackedFrames = numFrames / 2;
- const features = new Float32Array(stackedFrames * 160);
-
- for (let t = 0; t < stackedFrames; t++) {
- // First frame of pair
- for (let m = 0; m < N_MELS; m++) {
- features[t * 160 + m] = normalized[(t * 2) * N_MELS + m];
- }
- // Second frame of pair
- for (let m = 0; m < N_MELS; m++) {
- features[t * 160 + N_MELS + m] = normalized[(t * 2 + 1) * N_MELS + m];
- }
- }
-
- return { data: features, shape: [1, stackedFrames, 160] };
- }
-
  // Transcribe a single audio segment and return the text
  async function transcribeSegment(audioSegment, onPartialResult) {
- // Prepare audio features
- const audioFeatures = prepareAudioFeatures(audioSegment);
-
- // Run encoder
- const encoderInput = new ort.Tensor('float32', audioFeatures.data, audioFeatures.shape);
- const encoderOutput = await encoderSession.run({ input_features: encoderInput });
- const audioEmbeddings = encoderOutput.audio_features;
-
- // Get pre-computed prompt embeddings
- const prefixEmbed = promptEmbeddings.prefix;
- const suffixEmbed = promptEmbeddings[promptSelect.value] || promptEmbeddings['transcribe'];
-
- // Concatenate embeddings using TypedArray.set()
- const prefixSeqLen = prefixEmbed.seqLen;
- const audioSeqLen = audioEmbeddings.dims[1];
- const suffixSeqLen = suffixEmbed.seqLen;
- const totalSeqLen = prefixSeqLen + audioSeqLen + suffixSeqLen;
-
- const combinedEmbeds = new Float32Array(totalSeqLen * HIDDEN_SIZE);
- combinedEmbeds.set(prefixEmbed.data, 0);
- combinedEmbeds.set(new Float32Array(audioEmbeddings.data), prefixSeqLen * HIDDEN_SIZE);
- combinedEmbeds.set(suffixEmbed.data, (prefixSeqLen + audioSeqLen) * HIDDEN_SIZE);
-
- // Autoregressive generation
- let generatedTokens = [];
- let currentEmbeds = combinedEmbeds;
- let currentSeqLen = totalSeqLen;
- let pastKeyValues = null;
- const numLayers = 40;
- let totalSeqLenSoFar = totalSeqLen;
-
- for (let step = 0; step < MAX_NEW_TOKENS; step++) {
- const attentionMask = new BigInt64Array(totalSeqLenSoFar).fill(1n);
-
- const embedsTensor = new ort.Tensor('float32', currentEmbeds, [1, currentSeqLen, HIDDEN_SIZE]);
- const maskTensor = new ort.Tensor('int64', attentionMask, [1, totalSeqLenSoFar]);
-
- const decoderInputs = {
- inputs_embeds: embedsTensor,
- attention_mask: maskTensor,
- };
-
- if (pastKeyValues) {
- for (let i = 0; i < numLayers; i++) {
- decoderInputs[`past_key_values.${i}.key`] = pastKeyValues[`present.${i}.key`];
- decoderInputs[`past_key_values.${i}.value`] = pastKeyValues[`present.${i}.value`];
- }
- } else {
- const emptyPast = new Uint16Array(0);
- for (let i = 0; i < numLayers; i++) {
- decoderInputs[`past_key_values.${i}.key`] = new ort.Tensor('float16', emptyPast, [1, 4, 0, 128]);
- decoderInputs[`past_key_values.${i}.value`] = new ort.Tensor('float16', emptyPast, [1, 4, 0, 128]);
- }
- }
-
- const decoderOutput = await decoderSession.run(decoderInputs);
- pastKeyValues = decoderOutput;
-
- const logitsFloat32 = Float32Array.from(decoderOutput.logits.data);
-
- // Extract logits for last position and find argmax
- const logitOffset = (currentSeqLen - 1) * VOCAB_SIZE;
- const lastLogits = logitsFloat32.subarray(logitOffset, logitOffset + VOCAB_SIZE);
-
- let nextToken = 0, maxVal = lastLogits[0];
- for (let i = 1; i < VOCAB_SIZE; i++) {
- if (lastLogits[i] > maxVal) { maxVal = lastLogits[i]; nextToken = i; }
- }
-
- // Avoid EOS on first token - take second best
- if (step === 0 && nextToken === EOS_TOKEN) {
- nextToken = 0; maxVal = -Infinity;
- for (let i = 0; i < VOCAB_SIZE; i++) {
- if (i !== EOS_TOKEN && lastLogits[i] > maxVal) { maxVal = lastLogits[i]; nextToken = i; }
  }
- }
-
- if (nextToken === EOS_TOKEN) {
- break;
- }
-
- generatedTokens.push(nextToken);
-
- // Callback for streaming updates
- if (onPartialResult) {
- onPartialResult(tokenizer.decode(generatedTokens));
- }

- const nextTokenTensor = new ort.Tensor('int64', BigInt64Array.from([BigInt(nextToken)]), [1, 1]);
- const nextEmbedOutput = await embedSession.run({ input_ids: nextTokenTensor });
- currentEmbeds = new Float32Array(nextEmbedOutput.inputs_embeds.data);
- currentSeqLen = 1;
- totalSeqLenSoFar += 1;
- }

- return tokenizer.decode(generatedTokens);
  }

  // Wait until audio playback reaches a specific time
@@ -634,7 +224,7 @@ function waitForPlaybackTime(targetTime) {

  // Run inference with segmentation and audio sync
  async function transcribe() {
- if (!encoderSession || !embedSession || !decoderSession || !currentAudioData) {
  setStatus('error', 'Model or audio not ready');
  return;
  }
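The deleted preprocessing above hand-rolled the torchaudio-style mel filterbank; in this commit that work moves into `AutoProcessor`. As a standalone reference, the HTK mel scale the removed `createMelFilterbank` relied on can be sketched and sanity-checked on its own (the variable names `edges`, `nMels`, `melMax` are illustrative; `edges` mirrors the removed `fPts` array):

```javascript
// HTK mel scale, as in the removed createMelFilterbank:
// mel = 2595 * log10(1 + hz / 700), and its inverse.
const hzToMel = (hz) => 2595 * Math.log10(1 + hz / 700);
const melToHz = (mel) => 700 * (Math.pow(10, mel / 2595) - 1);

// 80 mel bands over 0..8000 Hz (half the 16 kHz sample rate) need 82
// evenly mel-spaced edge frequencies; triangular filter m spans edges
// m, m+1, m+2, so adjacent filters overlap by half.
const nMels = 80;
const sampleRate = 16000;
const melMax = hzToMel(sampleRate / 2);
const edges = Array.from(
  { length: nMels + 2 },
  (_, i) => melToHz((melMax * i) / (nMels + 1))
);

console.log(edges[0], edges[nMels + 1]); // first edge 0 Hz, last edge near the 8000 Hz Nyquist
```

Because spacing is linear in mel rather than in Hz, the low-frequency triangles end up much narrower in Hz than the high-frequency ones, which is the point of the mel warping.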
 
  /**
  * Granite Speech WebGPU Demo
+ * Uses Transformers.js v4 for in-browser speech recognition
  */

+ import {
+ AutoProcessor,
+ GraniteSpeechForConditionalGeneration,
+ TextStreamer,
+ } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@4.0.0-next.7';
  import { detect } from 'https://cdn.jsdelivr.net/npm/tinyld/+esm';

+ // Model
+ const MODEL_ID = 'onnx-community/granite-4.0-1b-speech-ONNX';

+ // Audio config
  const SAMPLE_RATE = 16000;
  const MAX_NEW_TOKENS = 256;
+
+ // Task prompts — <|audio|> is expanded by the processor's chat template
+ const TASK_PROMPTS = {
+ 'transcribe': '<|audio|>Transcribe the speech to text',
+ 'translate_en': '<|audio|>Translate the speech to English',
+ 'translate_fr': '<|audio|>Translate the speech to French',
+ 'translate_de': '<|audio|>Translate the speech to German',
+ 'translate_es': '<|audio|>Translate the speech to Spanish',
+ 'translate_pt': '<|audio|>Translate the speech to Portuguese',
+ 'translate_ja': '<|audio|>Translate the speech to Japanese',
  };

  // State
+ let model = null;
+ let processor = null;
  let isModelLoading = false;
  let currentAudioData = null;

  // DOM Elements
  const statusDot = document.getElementById('statusDot');
  const statusText = document.getElementById('statusText');

  }
  }

+ // Initialize models using Transformers.js v4
  async function initModels() {
  if (isModelLoading) return;
  isModelLoading = true;

+ setStatus('loading', 'Loading processor...');

  try {
+ await checkWebGPU();
+
+ processor = await AutoProcessor.from_pretrained(MODEL_ID);
+
+ setStatus('loading', 'Downloading models...');
+ progressFill.style.width = '0%';
+ let lastProgressUpdate = 0;
+ const fileProgress = {};
+ model = await GraniteSpeechForConditionalGeneration.from_pretrained(MODEL_ID, {
+ dtype: {
+ audio_encoder: 'q4f16',
+ embed_tokens: 'q4f16',
+ decoder_model_merged: 'q4f16',
+ },
+ device: 'webgpu',
+ progress_callback: (progress) => {
+ if (progress.status === 'progress' && progress.total) {
+ fileProgress[progress.file] = { loaded: progress.loaded, total: progress.total };
+ const now = performance.now();
+ if (now - lastProgressUpdate < 100) return;
+ lastProgressUpdate = now;
+ let totalLoaded = 0, totalSize = 0;
+ for (const f of Object.values(fileProgress)) {
+ totalLoaded += f.loaded;
+ totalSize += f.total;
+ }
+ const pct = totalSize > 0 ? (totalLoaded / totalSize) * 100 : 0;
+ progressFill.style.width = `${pct}%`;
+ const mb = (totalLoaded / 1e6).toFixed(0);
+ const totalMb = (totalSize / 1e6).toFixed(0);
+ setStatus('loading', `Downloading models... ${mb} / ${totalMb} MB`);
+ }
+ },
+ });

+ progressFill.style.width = '0%';
  setStatus('ready', 'Ready - Record or upload audio');
  enableControls(true);

  console.error('Error stack:', error?.stack);
  const errorMsg = error?.message || error?.toString() || 'Unknown error';
  setStatus('error', `Error: ${errorMsg}`);
+ progressFill.style.width = '0%';
  isModelLoading = false;
  }
  }

  audioFile.disabled = !enabled;
  }

  // Transcribe a single audio segment and return the text
  async function transcribeSegment(audioSegment, onPartialResult) {
+ // Build prompt using chat template
+ const taskKey = promptSelect.value;
+ const content = TASK_PROMPTS[taskKey] || TASK_PROMPTS['transcribe'];
+ const messages = [{ role: 'user', content }];
+
+ const text = processor.tokenizer.apply_chat_template(messages, {
+ add_generation_prompt: true,
+ tokenize: false,
+ });

+ // Process text + audio into model inputs
+ const inputs = await processor(text, audioSegment, { sampling_rate: SAMPLE_RATE });
+
+ // Streaming via TextStreamer
+ let accumulated = '';
+ const streamer = new TextStreamer(processor.tokenizer, {
+ skip_prompt: true,
+ skip_special_tokens: true,
+ callback_function: (chunk) => {
+ accumulated += chunk;
+ if (onPartialResult) {
+ onPartialResult(accumulated);
  }
+ },
+ });

+ // Generate
+ await model.generate({
+ ...inputs,
+ max_new_tokens: MAX_NEW_TOKENS,
+ streamer,
+ });

+ return accumulated;
  }

  // Wait until audio playback reaches a specific time

  // Run inference with segmentation and audio sync
  async function transcribe() {
+ if (!model || !processor || !currentAudioData) {
  setStatus('error', 'Model or audio not ready');
  return;
  }
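The `progress_callback` in the new `initModels` receives one event per downloaded file, so the overall percentage has to be aggregated across files. That bookkeeping, extracted into a standalone sketch (`aggregateProgress` and the file names below are illustrative, not part of the app):

```javascript
// Aggregate per-file download progress into one overall percentage.
// `fileProgress` maps file name -> { loaded, total } in bytes, updated
// from each progress event, mirroring the logic inside progress_callback.
function aggregateProgress(fileProgress) {
  let totalLoaded = 0;
  let totalSize = 0;
  for (const f of Object.values(fileProgress)) {
    totalLoaded += f.loaded;
    totalSize += f.total;
  }
  // Guard against division by zero before any sizes are known.
  const pct = totalSize > 0 ? (totalLoaded / totalSize) * 100 : 0;
  return { totalLoaded, totalSize, pct };
}

// Example with made-up file names: one file half done, one a quarter done.
const state = {
  'decoder.onnx': { loaded: 500e6, total: 1000e6 },
  'encoder.onnx': { loaded: 100e6, total: 400e6 },
};
console.log(aggregateProgress(state).pct); // (600e6 / 1400e6) * 100, roughly 42.9
```

One caveat visible in the real callback: a file only enters `fileProgress` once its first event arrives, so the denominator grows as downloads start; the 100 ms throttle in the page merely limits DOM updates, not the bookkeeping.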
index.html CHANGED
@@ -120,16 +120,17 @@
  Made with
  <a href="https://huggingface.co/ibm-granite/granite-4.0-1b-speech" target="_blank">Granite Speech 4.0 1B</a>
  and
- <a href="https://onnxruntime.ai/docs/tutorials/web/" target="_blank">ONNX Runtime Web</a>
+ <a href="https://huggingface.co/docs/transformers.js" target="_blank">Transformers.js</a>
  <br>
  <span class="privacy-note">Your audio and transcription never leave your device</span>
  </div>
  <div class="gpu-info" id="gpuInfo"></div>
  </div>

+ <!-- ORT global is retained for VAD (vad.js) and punctuation (punctuator.js) which use WASM -->
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.24.3/dist/ort.all.min.js"></script>
  <script src="vad.js?v=1"></script>
  <script src="punctuator.js?v=3"></script>
- <script type="module" src="app.js?v=53"></script>
+ <script type="module" src="app.js?v=54"></script>
  </body>
  </html>