LFM2-VL-450M-WebGPU / vl-model.js
shubeydoo's picture
Liquid AI LFM2-VL-450M-WebGPU Demo
accf76b
/**
* LFM2-VL Model Runner for ONNX Runtime Web
*
* Runs VL model inference using three ONNX models:
* 1. embed_tokens.onnx - Text token embeddings
* 2. vision_encoder.onnx - Image embeddings from patches
* 3. decoder_model_merged.onnx - Autoregressive decoder with conv state cache
*/
import * as ort from 'onnxruntime-web';
import { AutoTokenizer, env } from '@huggingface/transformers';
import { processImage, loadImage } from './vl-processor.js';
// Debug logging - set to false for production, toggle via setDebug(true) in console
let DEBUG = false;
export function setDebug(value) { DEBUG = value; console.log(`Debug logging ${value ? 'enabled' : 'disabled'}`); }
const log = (...args) => { if (DEBUG) console.log(...args); };
/**
* Convert float32 to float16 (IEEE 754 half-precision)
* @param {number} float32 - Float32 value
* @returns {number} - Float16 value as uint16
*/
function float32ToFloat16(float32) {
const view = new DataView(new ArrayBuffer(4));
view.setFloat32(0, float32, true);
const f32 = view.getUint32(0, true);
const sign = (f32 >> 31) & 0x1;
const exp = (f32 >> 23) & 0xff;
const frac = f32 & 0x7fffff;
let f16;
if (exp === 0) {
// Zero or denormal
f16 = (sign << 15) | (frac >> 13);
} else if (exp === 0xff) {
// Inf or NaN
f16 = (sign << 15) | 0x7c00 | (frac ? (frac >> 13) : 0);
} else {
// Normalized
const newExp = exp - 127 + 15;
if (newExp >= 31) {
// Overflow to infinity
f16 = (sign << 15) | 0x7c00;
} else if (newExp <= 0) {
// Underflow to zero
f16 = (sign << 15);
} else {
f16 = (sign << 15) | (newExp << 10) | (frac >> 13);
}
}
return f16;
}
/**
* Convert Float32Array to float16 Uint16Array
* @param {Float32Array} float32Array
* @returns {Uint16Array}
*/
function convertToFloat16(float32Array) {
const result = new Uint16Array(float32Array.length);
for (let i = 0; i < float32Array.length; i++) {
result[i] = float32ToFloat16(float32Array[i]);
}
return result;
}
/**
* Convert a float32 tensor to float16 tensor
* @param {ort.Tensor} tensor - Float32 tensor
* @returns {ort.Tensor} - Float16 tensor
*/
function tensorToFloat16(tensor) {
const float16Data = convertToFloat16(tensor.data);
return new ort.Tensor('float16', float16Data, tensor.dims);
}
// Cache configuration
const CACHE_NAME = 'onnx-models-v1';
// Threshold for URL-based ONNX loading (files too large for JS memory)
// Set to 2GB - files larger than this will stream instead of loading into memory
const LARGE_FILE_THRESHOLD = 2 * 1024 * 1024 * 1024; // 2GB
/**
* Fetch with streaming progress tracking
* @param {string} url - URL to fetch
* @param {object} options - Fetch options
* @param {function} onProgress - Progress callback (received, total) => void
* @returns {Promise<Response>} - Response with complete body
*/
async function fetchWithProgress(url, options = {}, onProgress) {
const response = await fetch(url, options);
if (!response.ok) {
throw new Error(`Fetch failed: ${response.status}`);
}
const contentLength = parseInt(response.headers.get('content-length') || '0', 10);
if (!contentLength || !onProgress) {
// No size info or no callback - return as-is
return response;
}
const reader = response.body.getReader();
const chunks = [];
let received = 0;
while (true) {
const { done, value } = await reader.read();
if (done) break;
chunks.push(value);
received += value.length;
onProgress(received, contentLength);
}
// Combine chunks into single buffer
const buffer = new Uint8Array(received);
let offset = 0;
for (const chunk of chunks) {
buffer.set(chunk, offset);
offset += chunk.length;
}
// Create new Response with fresh Headers for Cache API compatibility
// Using the original headers object from a consumed response can cause issues
return new Response(new Blob([buffer]), {
status: response.status,
headers: new Headers(response.headers),
});
}
/**
* Fetch with caching support using Cache API
* @param {string} url - URL to fetch
* @param {object} options - Fetch options
* @param {function} onProgress - Optional progress callback (received, total) => void
* @returns {Promise<Response>} - Response (from cache or network)
*/
async function fetchWithCache(url, options = {}, onProgress = null) {
// Skip caching for local files
if (!url.startsWith('http://') && !url.startsWith('https://')) {
return fetch(url, options);
}
const fileName = url.split('/').pop();
// 1. Try cache read with validation
try {
const cache = await caches.open(CACHE_NAME);
const cached = await cache.match(url);
if (cached) {
// Validate by reading body - catches corrupted entries from failed cache.put()
try {
const buffer = await cached.clone().arrayBuffer();
log(`[Cache HIT] ${fileName} (${(buffer.byteLength / 1024 / 1024).toFixed(1)} MB)`);
// Return a new Response with the validated buffer
return new Response(buffer, {
status: cached.status,
statusText: cached.statusText,
headers: cached.headers,
});
} catch (bodyError) {
// Corrupted cache entry - delete it and re-fetch
log(`[Cache CORRUPT] ${fileName} - deleting and re-fetching`);
await cache.delete(url);
}
}
} catch (e) {
log(`[Cache ERROR] ${e.message}`);
}
// 2. Fetch from network with progress tracking
log(`[Network] Fetching ${fileName}...`);
const response = await fetchWithProgress(url, options, onProgress);
// 3. Try to cache successful response (fire-and-forget)
if (response.ok) {
tryCacheResponse(url, response.clone());
}
return response;
}
/**
* Try to cache a response (non-blocking, best-effort)
* @param {string} url - URL to cache
* @param {Response} response - Response to cache
*/
async function tryCacheResponse(url, response) {
try {
// Check available space before caching
if (navigator.storage?.estimate) {
const { usage = 0, quota = 0 } = await navigator.storage.estimate();
const available = quota - usage;
const responseSize = parseInt(response.headers.get('content-length') || '0', 10);
// Skip if we don't have space for this file + 100MB buffer
const BUFFER = 100 * 1024 * 1024;
if (responseSize > 0 && available < responseSize + BUFFER) {
log(`[Cache SKIP] Not enough space (need ${((responseSize + BUFFER) / 1e9).toFixed(2)} GB, have ${(available / 1e9).toFixed(2)} GB)`);
return;
}
}
const cache = await caches.open(CACHE_NAME);
await cache.put(url, response);
log(`[Cached] ${url.split('/').pop()}`);
} catch (e) {
// Caching failed, but download succeeded - that's fine
console.warn(`[Cache WRITE ERROR] ${url.split('/').pop()}:`, e.name, e.message, e);
}
}
/**
* Clear the model cache
* @returns {Promise<boolean>} - True if cache was deleted
*/
export async function clearModelCache() {
const deleted = await caches.delete(CACHE_NAME);
log(deleted ? 'Model cache cleared' : 'No cache to clear');
return deleted;
}
/**
* Get cache storage usage info (specifically for model cache)
* @returns {Promise<{used: number, available: number}|null>}
*/
export async function getCacheInfo() {
try {
// Calculate actual size of just the model cache
const cache = await caches.open(CACHE_NAME);
const keys = await cache.keys();
let totalSize = 0;
for (const request of keys) {
const response = await cache.match(request);
if (response) {
// Get the response body as blob to measure size
const blob = await response.clone().blob();
totalSize += blob.size;
}
}
// Get quota info for available space
let available = 0;
if ('storage' in navigator && 'estimate' in navigator.storage) {
const estimate = await navigator.storage.estimate();
available = estimate.quota || 0;
}
return {
used: totalSize,
available: available,
};
} catch (e) {
console.warn('Error getting cache info:', e);
return null;
}
}
/**
* Load tokenizer from model path (local or S3)
* @param {string} modelPath - Path to model directory (local or S3 URL)
* @returns {Promise<{tokenizer: object, specialTokens: object}>} - Tokenizer instance and special token IDs
*/
async function loadTokenizerFromPath(modelPath) {
const isRemote = modelPath.startsWith('http://') || modelPath.startsWith('https://');
log(`Loading tokenizer from ${isRemote ? 'remote' : 'local'}: ${modelPath}`);
const fetchOptions = isRemote ? { mode: 'cors', credentials: 'omit' } : {};
// Fetch tokenizer files (with caching)
const [tokenizerResponse, configResponse] = await Promise.all([
fetchWithCache(`${modelPath}/tokenizer.json`, fetchOptions),
fetchWithCache(`${modelPath}/tokenizer_config.json`, fetchOptions),
]);
if (!tokenizerResponse.ok) {
throw new Error(`Failed to fetch tokenizer.json: ${tokenizerResponse.status}`);
}
if (!configResponse.ok) {
throw new Error(`Failed to fetch tokenizer_config.json: ${configResponse.status}`);
}
const tokenizerJSON = await tokenizerResponse.text();
const configJSON = await configResponse.text();
log('Tokenizer files fetched, creating tokenizer...');
// Parse tokenizer.json to extract special token IDs from added_tokens
const tokenizerData = JSON.parse(tokenizerJSON);
const specialTokens = {};
if (tokenizerData.added_tokens) {
for (const token of tokenizerData.added_tokens) {
specialTokens[token.content] = token.id;
}
log('Found special tokens:', Object.keys(specialTokens).length);
}
// Create a unique fake model ID
const fakeModelId = `tokenizer-${Date.now()}`;
// Cache of files to serve
const fileCache = {
'tokenizer.json': tokenizerJSON,
'tokenizer_config.json': configJSON,
};
// Intercept fetch to serve our cached files
const originalFetch = globalThis.fetch;
globalThis.fetch = async (input, init) => {
const url = typeof input === 'string' ? input : input.url;
// Check if this is a request for our fake model
if (url.includes(fakeModelId)) {
for (const [filename, content] of Object.entries(fileCache)) {
if (url.includes(filename)) {
log(`Serving cached ${filename}`);
return new Response(content, {
status: 200,
headers: { 'Content-Type': 'application/json' },
});
}
}
// Return 404 for other files (like config.json which tokenizer doesn't need)
return new Response('Not found', { status: 404 });
}
return originalFetch(input, init);
};
// Disable local model check
const originalAllowLocal = env.allowLocalModels;
env.allowLocalModels = false;
try {
const tokenizer = await AutoTokenizer.from_pretrained(fakeModelId);
log('Tokenizer created successfully');
return { tokenizer, specialTokens };
} finally {
// Restore original state
globalThis.fetch = originalFetch;
env.allowLocalModels = originalAllowLocal;
}
}
export class VLModel {
constructor() {
this.tokenizer = null;
this.embedTokensSession = null;
this.visionEncoderSession = null;
this.decoderSession = null;
this.config = null;
this.imageTokenId = null;
this.eosTokenId = null;
this.hiddenSize = 1024; // Default for 450M
// Image embedding cache (persists between turns)
this.imageCache = new Map(); // URL -> { embeddings, numTokens }
}
/**
* Clear the image embedding cache (call when starting a new conversation)
*/
clearImageCache() {
this.imageCache.clear();
}
/**
* Load the VL model from a directory
* @param {string} modelPath - Path to model directory (S3 URL)
* @param {object} options - Loading options
* @param {function} options.progressCallback - Progress callback
* @param {string} options.device - Device to use ('webgpu' or 'wasm')
* @param {string} options.quantization - Quantization type ('q4', 'q8', or null for fp32)
*/
async load(modelPath, options = {}) {
const { progressCallback, device = 'webgpu', quantization = null } = options;
const report = (status, progress = 0, file = '') => {
if (progressCallback) {
progressCallback({ status, progress, file });
}
};
// Determine execution provider
const executionProviders = device === 'webgpu'
? ['webgpu', 'wasm']
: ['wasm'];
try {
// Load tokenizer and extract special token IDs
report('loading', 0, 'tokenizer');
const { tokenizer, specialTokens } = await loadTokenizerFromPath(modelPath);
this.tokenizer = tokenizer;
// Load chat template from S3 if not already set in tokenizer
if (!this.tokenizer.chat_template) {
try {
const templateResponse = await fetch(`${modelPath}/chat_template.jinja`, {
mode: 'cors',
credentials: 'omit',
});
if (templateResponse.ok) {
const template = await templateResponse.text();
this.tokenizer.chat_template = template;
log('Loaded chat template from model path');
}
} catch (e) {
console.warn('Could not load chat template:', e);
}
}
// Get special token IDs from parsed tokenizer.json
this.imageTokenId = specialTokens['<image>'] ?? null;
this.imageStartTokenId = specialTokens['<|image_start|>'] ?? null;
this.imageEndTokenId = specialTokens['<|image_end|>'] ?? null;
this.imageSplitTokenId = specialTokens['<|image_split|>'] ?? null;
this.eosTokenId = this.tokenizer.eos_token_id;
log('Image token ID:', this.imageTokenId);
log('Image start token ID:', this.imageStartTokenId);
log('Image end token ID:', this.imageEndTokenId);
log('EOS token ID:', this.eosTokenId);
if (this.imageTokenId === null) {
console.warn('Warning: <image> token not found in tokenizer');
}
// Load config
report('loading', 10, 'config');
const configResponse = await fetch(`${modelPath}/config.json`, {
mode: 'cors',
credentials: 'omit',
});
this.config = await configResponse.json();
// VL models have config in text_config
const textConfig = this.config.text_config || this.config;
this.hiddenSize = textConfig.hidden_size || 1024;
this.numKVHeads = textConfig.num_key_value_heads || 8;
this.headDim = Math.floor(this.hiddenSize / (textConfig.num_attention_heads || 16));
log('Model config:', { hiddenSize: this.hiddenSize, numKVHeads: this.numKVHeads, headDim: this.headDim });
// Get external data files (single file per component for 450M)
const getExternalDataFiles = async (basePath, fileName, fetchOptions) => {
const files = [];
// Get primary file
const primaryUrl = `${basePath}/onnx/${fileName}.onnx_data`;
try {
const headResp = await fetch(primaryUrl, { method: 'HEAD', ...fetchOptions });
if (!headResp.ok) return []; // No external data
files.push({
path: `${fileName}.onnx_data`,
url: primaryUrl,
size: parseInt(headResp.headers.get('content-length') || '0', 10)
});
} catch (e) {
return []; // No external data
}
return files;
};
// Helper to load ONNX model with external data (with caching and progress)
// customProviders allows overriding execution providers for specific sessions
const loadOnnxWithExternalData = async (name, progress, quantSuffix = quantization, customProviders = null) => {
// Build filename with optional quantization suffix
const suffix = quantSuffix ? `_${quantSuffix}` : '';
const fileName = `${name}${suffix}`;
report('loading', progress, `${fileName}.onnx`);
const onnxPath = `${modelPath}/onnx/${fileName}.onnx`;
const fetchOptions = { mode: 'cors', credentials: 'omit' };
log(`Loading ${fileName}...`);
// Progress callback for download progress
const makeProgressCallback = (file) => (received, total) => {
const mb = (received / 1024 / 1024).toFixed(0);
const totalMb = (total / 1024 / 1024).toFixed(0);
report('loading', progress, `${file}: ${mb} / ${totalMb} MB`);
};
// Get external data files (uses size-based format detection)
const dataFiles = await getExternalDataFiles(modelPath, fileName, fetchOptions);
const totalDataSize = dataFiles.reduce((sum, f) => sum + f.size, 0);
log(`Found ${dataFiles.length} external data file(s) for ${fileName}, total: ${(totalDataSize / 1024 / 1024).toFixed(1)} MB`);
// Use custom providers if specified, otherwise use default
const providers = customProviders || executionProviders;
const sessionOptions = {
executionProviders: providers,
};
// Fetch ONNX file (with caching and progress)
const onnxResponse = await fetchWithCache(onnxPath, fetchOptions, makeProgressCallback(`${fileName}.onnx`));
if (!onnxResponse.ok) {
throw new Error(`Failed to fetch ${fileName}.onnx: ${onnxResponse.status}`);
}
const onnxBuffer = await onnxResponse.arrayBuffer();
log(`Loaded ${fileName}.onnx: ${(onnxBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`);
if (dataFiles.length > 0) {
// Load each file individually - use memory for cacheable files, URL for oversized
sessionOptions.externalData = [];
for (const f of dataFiles) {
if (f.size > LARGE_FILE_THRESHOLD) {
// File too large for JS memory - let ONNX Runtime stream it
log(`Large file ${f.path} (${(f.size / 1024 / 1024 / 1024).toFixed(2)} GB), using URL-based loading`);
report('loading', progress, `${fileName} (streaming ${f.path}...)`);
sessionOptions.externalData.push({
path: f.path,
data: f.url,
});
} else {
// File fits in memory - fetch with caching and progress
const dataResponse = await fetchWithCache(f.url, fetchOptions, makeProgressCallback(f.path));
if (!dataResponse.ok) {
throw new Error(`Failed to fetch ${f.path}: ${dataResponse.status}`);
}
const dataBuffer = await dataResponse.arrayBuffer();
log(`Loaded ${f.path}: ${(dataBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`);
sessionOptions.externalData.push({
path: f.path,
data: new Uint8Array(dataBuffer),
});
}
}
report('loading', progress, `${fileName} (initializing)`);
} else {
report('loading', progress, `${fileName} (initializing)`);
}
const session = await ort.InferenceSession.create(new Uint8Array(onnxBuffer), sessionOptions);
log(`Session created for ${fileName}`);
return session;
};
// Parse quantization config (can be string for legacy or object for new format)
const quantConfig = typeof quantization === 'object' ? quantization : {
decoder: quantization,
visionEncoder: quantization,
};
// Load embed_tokens (use fp16 suffix if decoder is fp16, otherwise no suffix)
const embedTokensQuant = quantConfig.decoder || null;
this.embedTokensSession = await loadOnnxWithExternalData('embed_tokens', 20, embedTokensQuant);
// Load vision_encoder (use specified quantization)
const visionEncoderQuant = quantConfig.visionEncoder || null;
this.visionEncoderSession = await loadOnnxWithExternalData('vision_encoder', 40, visionEncoderQuant);
// Load decoder_model_merged (use specified quantization)
const decoderQuant = quantConfig.decoder || null;
this.decoderSession = await loadOnnxWithExternalData('decoder_model_merged', 60, decoderQuant);
report('done', 100, '');
return true;
} catch (error) {
// Better error reporting for ORT errors
let errorMessage = error;
if (typeof error === 'number') {
errorMessage = `ONNX Runtime error code: ${error}. This may indicate a WebGPU memory or compatibility issue.`;
} else if (error instanceof Error) {
errorMessage = error.message;
}
console.error('Failed to load VL model:', errorMessage);
throw new Error(errorMessage);
}
}
/**
* Process images and get embeddings (with caching)
* @param {string[]} imageInputs - Array of image URLs or data URLs
* @returns {Promise<{embeddings: Float32Array, numTokens: number, tokensPerImage: number[]}>}
*/
async getImageEmbeddings(imageInputs) {
const allEmbeddings = [];
const tokensPerImage = [];
let totalTokens = 0;
let cacheHits = 0;
let cacheMisses = 0;
for (const input of imageInputs) {
// Check cache first
if (this.imageCache.has(input)) {
const cached = this.imageCache.get(input);
allEmbeddings.push(cached.embeddings);
tokensPerImage.push(cached.numTokens);
totalTokens += cached.numTokens;
cacheHits++;
continue;
}
// Cache miss - load and process the image
cacheMisses++;
const img = await loadImage(input);
const processed = await processImage(img);
log(`Image processed: ${processed.numTiles} tiles, shape [${processed.shape.join(', ')}]`);
// Create tensors - use shape from processed output
const patchesPerTile = processed.shape[1]; // 1024
const pixelValuesTensor = new ort.Tensor(
'float32',
processed.pixelValues,
processed.shape // [num_tiles, patches_per_tile, 768]
);
const attentionMaskTensor = new ort.Tensor(
'int64',
processed.attentionMask, // BigInt64Array
[processed.numTiles, patchesPerTile] // [num_tiles, patches_per_tile]
);
const spatialShapesTensor = new ort.Tensor(
'int64',
processed.spatialShapes, // BigInt64Array
[processed.numTiles, 2] // [num_tiles, 2]
);
// Run vision_encoder
let outputs = await this.visionEncoderSession.run({
pixel_values: pixelValuesTensor,
pixel_attention_mask: attentionMaskTensor,
spatial_shapes: spatialShapesTensor,
});
// Output shape: [num_image_tokens, hidden_dim] (already flattened)
let embeddings = outputs.image_features;
log('Image embeddings shape:', embeddings.dims);
// Output is 2D: [num_tokens, hidden_dim]
const numTokens = embeddings.dims[0];
// Store in cache (copy the data since tensor might be reused)
const embeddingsCopy = new Float32Array(embeddings.data);
this.imageCache.set(input, { embeddings: embeddingsCopy, numTokens });
tokensPerImage.push(numTokens);
totalTokens += numTokens;
allEmbeddings.push(embeddingsCopy);
}
if (DEBUG && (cacheHits > 0 || cacheMisses > 1)) {
log(`Image embeddings: ${cacheHits} cached, ${cacheMisses} computed, ${totalTokens} total tokens`);
}
// Concatenate all image embeddings
const totalLength = allEmbeddings.reduce((sum, e) => sum + e.length, 0);
const combined = new Float32Array(totalLength);
let offset = 0;
for (const emb of allEmbeddings) {
combined.set(emb, offset);
offset += emb.length;
}
return { embeddings: combined, numTokens: totalTokens, tokensPerImage };
}
/**
* Get text embeddings from token IDs
* @param {number[]} inputIds - Token IDs as regular numbers
* @returns {Promise<ort.Tensor>} - Text embeddings tensor
*/
async getTextEmbeddings(inputIds) {
const inputTensor = new ort.Tensor(
'int64',
new BigInt64Array(inputIds.map(id => BigInt(id))),
[1, inputIds.length]
);
const outputs = await this.embedTokensSession.run({ input_ids: inputTensor });
return outputs.inputs_embeds;
}
/**
* Build combined embeddings by replacing image tokens with image embeddings (1:1)
* Each <image> token position gets replaced with exactly one image embedding.
* The sequence length remains the same.
*
* @param {number[]} inputIds - Token IDs
* @param {ort.Tensor} textEmbeddings - Text embeddings tensor
* @param {Float32Array} imageEmbeddings - Concatenated image embeddings
*/
buildCombinedEmbeddings1to1(inputIds, textEmbeddings, imageEmbeddings) {
const [, seqLen, hiddenDim] = textEmbeddings.dims;
const textEmb = textEmbeddings.data;
const imgEmb = imageEmbeddings;
// Find all image token positions
const imagePositions = [];
for (let i = 0; i < inputIds.length; i++) {
if (inputIds[i] === this.imageTokenId) {
imagePositions.push(i);
}
}
const numImageEmbeddings = imgEmb.length / hiddenDim;
if (imagePositions.length !== numImageEmbeddings) {
console.warn(`Image token mismatch: ${imagePositions.length} <image> tokens vs ${numImageEmbeddings} embeddings`);
}
// Copy text embeddings and replace image token positions
const result = new Float32Array(textEmb);
for (let i = 0; i < Math.min(imagePositions.length, numImageEmbeddings); i++) {
const pos = imagePositions[i];
const embStart = i * hiddenDim;
const dstStart = pos * hiddenDim;
result.set(imgEmb.slice(embStart, embStart + hiddenDim), dstStart);
}
return new ort.Tensor('float32', result, [1, seqLen, hiddenDim]);
}
/**
* Initialize cache for decoder (both conv states and KV cache)
* Uses float16 tensors as required by the 450M ONNX model
*/
initializeCache() {
const cache = {};
for (const name of this.decoderSession.inputNames) {
if (name.startsWith('past_conv')) {
// Conv states: [batch, hidden_size, kernel_size-1]
// Kernel size is 4, so we need 3 states
// Use float16 (Uint16Array) for 450M model compatibility
cache[name] = new ort.Tensor(
'float16',
new Uint16Array(1 * this.hiddenSize * 3),
[1, this.hiddenSize, 3]
);
} else if (name.startsWith('past_key_values')) {
// KV cache: [batch, num_kv_heads, past_seq_len, head_dim]
// Initialize with 0 length sequence
// Use float16 (Uint16Array) for 450M model compatibility
cache[name] = new ort.Tensor(
'float16',
new Uint16Array(0), // Empty cache initially
[1, this.numKVHeads, 0, this.headDim]
);
}
}
return cache;
}
/**
* Update cache from decoder outputs
*/
updateCache(cache, outputs) {
for (const name of Object.keys(outputs)) {
if (name.startsWith('present_conv')) {
// Conv states: present_conv.X -> past_conv.X
const cacheName = name.replace('present_conv', 'past_conv');
if (cacheName in cache) {
cache[cacheName] = outputs[name];
}
} else if (name.startsWith('present.')) {
// KV cache: present.X.key -> past_key_values.X.key
const cacheName = name.replace('present.', 'past_key_values.');
if (cacheName in cache) {
cache[cacheName] = outputs[name];
}
}
}
}
/**
* Generate text given messages with optional images
* @param {Array} messages - Chat messages
* @param {object} options - Generation options
*/
async generate(messages, options = {}) {
const { maxNewTokens = 256, onToken, images = [], messageImageMap = new Map() } = options;
log(`=== VL Generate: ${messages.length} messages, ${images.length} images ===`);
// Process images FIRST to get patch counts
let imageEmbeddings = null;
let tokensPerImage = [];
let totalImageTokens = 0;
if (images.length > 0) {
const result = await this.getImageEmbeddings(images);
imageEmbeddings = result.embeddings;
tokensPerImage = result.tokensPerImage;
totalImageTokens = result.numTokens;
log(`Image tokens: ${totalImageTokens} (per-image: [${tokensPerImage.join(', ')}])`);
}
// Build prompt with <image> tokens placed in EACH message that has images
// This is critical: each user message that sent an image needs its <image> token(s)
let promptMessages = messages;
if (images.length > 0) {
promptMessages = messages.map((msg, idx) => {
// Check if this message has images via messageImageMap
if (msg.role === 'user' && messageImageMap.has(idx)) {
const messageImages = messageImageMap.get(idx);
const imageTokens = messageImages.map(() => '<image>').join('');
return { ...msg, content: imageTokens + msg.content };
}
return msg;
});
}
// Apply chat template
const prompt = this.tokenizer.apply_chat_template(promptMessages, {
add_generation_prompt: true,
tokenize: false,
});
// Tokenize
const encoded = this.tokenizer.encode(prompt);
let inputIds = [...encoded];
// Expand each <image> token to the correct count for that image
// Add boundary tokens if available: <image_start> [tokens] <image_end>
if (images.length > 0) {
const expandedIds = [];
let imageIdx = 0;
for (const id of inputIds) {
if (id === this.imageTokenId && imageIdx < tokensPerImage.length) {
// Add start boundary if available
if (this.imageStartTokenId) {
expandedIds.push(this.imageStartTokenId);
}
// Replace single <image> with N copies
const count = tokensPerImage[imageIdx];
for (let i = 0; i < count; i++) {
expandedIds.push(this.imageTokenId);
}
// Add end boundary if available
if (this.imageEndTokenId) {
expandedIds.push(this.imageEndTokenId);
}
imageIdx++;
} else {
expandedIds.push(id);
}
}
inputIds = expandedIds;
}
// Get text embeddings for expanded sequence
const textEmbeddings = await this.getTextEmbeddings(inputIds);
// Replace image token embeddings with actual image embeddings (1:1)
let inputsEmbeds;
if (images.length > 0) {
inputsEmbeds = this.buildCombinedEmbeddings1to1(inputIds, textEmbeddings, imageEmbeddings);
} else {
inputsEmbeds = textEmbeddings;
}
log(`Input sequence: ${inputsEmbeds.dims[1]} tokens, ${(inputsEmbeds.data.length * 4 / 1024 / 1024).toFixed(1)} MB`);
// Initialize fresh cache for this generation
// (KV cache is used within generation for autoregressive decoding)
const cache = this.initializeCache();
// Generation loop
const seqLen = inputsEmbeds.dims[1];
let curLen = seqLen;
let currentEmbeds = inputsEmbeds;
const generatedTokens = [];
for (let step = 0; step < maxNewTokens; step++) {
// Prepare attention mask
const attentionMask = new ort.Tensor(
'int64',
new BigInt64Array(curLen).fill(1n),
[1, curLen]
);
// Run decoder (LFM2 models don't use position_ids - position is implicit from attention)
const feeds = {
inputs_embeds: currentEmbeds,
attention_mask: attentionMask,
...cache,
};
const outputs = await this.decoderSession.run(feeds);
// Get logits - shape is [batch, seq_len, vocab_size]
const logits = outputs.logits;
const vocabSize = logits.dims[2];
const logitsData = logits.data;
// Get last token logits
const lastLogitStart = (logits.dims[1] - 1) * vocabSize;
const lastLogits = logitsData.slice(lastLogitStart, lastLogitStart + vocabSize);
// Greedy decoding - find max
let maxIdx = 0;
let maxVal = lastLogits[0];
for (let i = 1; i < vocabSize; i++) {
if (lastLogits[i] > maxVal) {
maxVal = lastLogits[i];
maxIdx = i;
}
}
generatedTokens.push(maxIdx);
// Callback with token
if (onToken) {
const tokenText = this.tokenizer.decode([maxIdx]);
const shouldStop = onToken(tokenText, maxIdx);
if (shouldStop) break;
}
// Check for EOS
if (maxIdx === this.eosTokenId) {
break;
}
// Update cache for next token
this.updateCache(cache, outputs);
// Get embedding for next token
const nextEmbeds = await this.getTextEmbeddings([maxIdx]);
currentEmbeds = nextEmbeds;
curLen++;
}
return this.tokenizer.decode(generatedTokens, { skip_special_tokens: true });
}
/**
* Free resources
*/
async dispose() {
this.clearImageCache();
this.tokenizer = null;
// Properly release ONNX sessions to free GPU resources
if (this.embedTokensSession) {
try {
await this.embedTokensSession.release();
} catch (e) {
console.warn('Error releasing embedTokensSession:', e);
}
this.embedTokensSession = null;
}
if (this.visionEncoderSession) {
try {
await this.visionEncoderSession.release();
} catch (e) {
console.warn('Error releasing visionEncoderSession:', e);
}
this.visionEncoderSession = null;
}
if (this.decoderSession) {
try {
await this.decoderSession.release();
} catch (e) {
console.warn('Error releasing decoderSession:', e);
}
this.decoderSession = null;
}
}
}
export default VLModel;