import FileProxyCache from 'https://cdn.jsdelivr.net/gh/jasonmayes/web-ai-model-proxy-cache@main/FileProxyCache.min.js';
import * as LiteRT from 'https://cdn.jsdelivr.net/npm/@litertjs/core@0.2.1/+esm';
import * as LiteRTInterop from 'https://cdn.jsdelivr.net/npm/@litertjs/tfjs-interop@1.0.1/+esm';
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2';

/**
 * A class to handle loading and using the EmbeddingGemma model through one
 * of two runtimes: LiteRT.js ('litertjs') or Transformers.js (any other
 * runtime value takes the Transformers.js path in load()).
 * Coded by Jason Mayes 2026.
 */
export class EmbeddingModel {
  /**
   * @param {string} modelRuntime Runtime selector: 'litertjs' or
   *     'transformersjs'. Decides which branch load() and the
   *     getEmbedding* methods take.
   */
  constructor(modelRuntime) {
    this.model = undefined;
    this.runtime = modelRuntime;
  }

  /**
   * Loads (and, for LiteRT, compiles) the model, caching the download via
   * FileProxyCache.
   * @param {string} modelUrl URL to the .tflite model file (LiteRT) or
   *     Transformers.js model location.
   * @return {Promise<void>} Resolves once this.model is ready to use.
   */
  async load(modelUrl) {
    // NOTE(review): LiteRT runtime initialization (LiteRT.loadLiteRt /
    // setWebGpuDevice) is assumed to be handled by the host page before
    // load() is called — confirm with the caller.
    const dataUrl = await FileProxyCache.loadFromURL(modelUrl);

    if (this.runtime === 'litertjs') {
      this.model = await LiteRT.loadAndCompile(dataUrl, {
        accelerator: 'webgpu',
      });
    } else {
      // Transformers.js: load the feature-extraction pipeline.
      this.model = await pipeline('feature-extraction', dataUrl);
    }
  }

  /**
   * Generates an embedding for the given tokens with the LiteRT.js runtime.
   * The token sequence is zero-padded or truncated to exactly seqLength.
   * @param {Array<number>} tokens Array of token IDs.
   * @param {number} seqLength Expected sequence length for the model.
   * @return {Promise<{embedding: Object, tokens: Array<number>}|undefined>}
   *     The embedding tensor (batch of 1) and the tokens used, or
   *     undefined when this.runtime is not 'litertjs'.
   * @throws {Error} If the model has not been loaded yet.
   */
  async getEmbeddingLiteRTJS(tokens, seqLength) {
    if (!this.model) {
      throw new Error('Model not loaded. Call load() first.');
    }
    if (this.runtime !== 'litertjs') {
      // Preserves original behavior: other runtimes yield no result.
      return undefined;
    }

    // Relies on the global `tf` (TensorFlow.js) loaded by the host page.
    const rawTensor = tf.tensor1d(tokens, 'int32');

    // Pad up to the expected model token length, or trim down to it.
    // BUG FIX: the pre-pad/pre-slice tensor was previously overwritten
    // without dispose(), leaking memory on every resized call.
    let inputTensor = rawTensor;
    if (tokens.length < seqLength) {
      inputTensor = rawTensor.pad([[0, seqLength - tokens.length]]);
      rawTensor.dispose();
    } else if (tokens.length > seqLength) {
      inputTensor = rawTensor.slice([0], [seqLength]);
      rawTensor.dispose();
    }

    const expandedInput = inputTensor.expandDims(0);
    const results = LiteRTInterop.runWithTfjsTensors(this.model, expandedInput);

    inputTensor.dispose();
    expandedInput.dispose();

    return {
      embedding: results[0], // Returns batch of 1.
      tokens: tokens,
    };
  }

  /**
   * Generates an embedding for the given text with the Transformers.js
   * runtime (mean pooling, L2-normalized).
   * @param {string|Array<string>} textBatch Text or array of text to embed.
   * @return {Promise<{embedding: Array<number>}|undefined>} The generated
   *     embedding, or undefined when this.runtime is not 'transformersjs'.
   * @throws {Error} If the model has not been loaded yet.
   */
  async getEmbeddingTransformers(textBatch) {
    if (this.runtime !== 'transformersjs') {
      // Preserves original behavior: other runtimes yield no result.
      return undefined;
    }
    // BUG FIX: consistent with getEmbeddingLiteRTJS — fail clearly instead
    // of a TypeError when called before load().
    if (!this.model) {
      throw new Error('Model not loaded. Call load() first.');
    }

    const queryResult = await this.model(textBatch, {
      pooling: 'mean',
      normalize: true,
    });
    return { embedding: Array.from(queryResult.data) };
  }
}