File size: 3,175 Bytes
e3aec01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import FileProxyCache from 'https://cdn.jsdelivr.net/gh/jasonmayes/web-ai-model-proxy-cache@main/FileProxyCache.min.js';
import * as LiteRT from 'https://cdn.jsdelivr.net/npm/@litertjs/core@0.2.1/+esm';
import * as LiteRTInterop from 'https://cdn.jsdelivr.net/npm/@litertjs/tfjs-interop@1.0.1/+esm';
import { pipeline, env } from 'https://cdn.jsdelivr.net/npm/@xenova/transformers@2.17.2';

/**
 * A class to handle loading and using the EmbeddingGemma model with either
 * LiteRT.js (runtime 'litertjs') or Transformers.js (runtime 'transformersjs').
 * Coded by Jason Mayes 2026.
 */
export class EmbeddingModel {
  /**
   * @param {string} modelRuntime Runtime to use: 'litertjs' or 'transformersjs'.
   *     Note that load() treats any value other than 'litertjs' as
   *     Transformers.js, while getEmbeddingTransformers() only runs for the
   *     exact value 'transformersjs'.
   */
  constructor(modelRuntime) {
    this.model = undefined;
    this.runtime = modelRuntime;
  }

  /**
   * Fetches the model via FileProxyCache and prepares it for the configured
   * runtime.
   *
   * For the 'litertjs' runtime, LiteRT itself (e.g. LiteRT.loadLiteRt and
   * setWebGpuDevice) is expected to have been initialized by the caller
   * before this is invoked; this method only loads and compiles the model.
   *
   * @param {string} modelUrl URL to the model file (.tflite for LiteRT.js).
   * @return {Promise<void>}
   */
  async load(modelUrl) {
    const dataUrl = await FileProxyCache.loadFromURL(modelUrl);

    if (this.runtime === 'litertjs') {
      this.model = await LiteRT.loadAndCompile(dataUrl, {
        accelerator: 'webgpu',
      });
    } else {
      // Transformers.js feature-extraction pipeline.
      this.model = await pipeline('feature-extraction', dataUrl);
    }
  }

  /**
   * Generates an embedding for the given tokens using LiteRT.js.
   *
   * The token IDs are zero-padded or truncated to `seqLength`, given a leading
   * batch dimension of 1, and run through the compiled model. All intermediate
   * tensors are disposed, even if inference throws. The returned embedding
   * tensor is owned by the caller, who must dispose it.
   *
   * @param {Array<number>} tokens Array of token IDs.
   * @param {number} seqLength Expected sequence length for the model.
   * @return {Promise<{embedding: tf.Tensor, tokens: Array<number>}|undefined>}
   *     The first model output plus the original (untrimmed) `tokens` array,
   *     or `undefined` when this instance's runtime is not 'litertjs'.
   * @throws {Error} If load() has not populated the model yet.
   */
  async getEmbeddingLiteRTJS(tokens, seqLength) {
    if (!this.model) {
      throw new Error('Model not loaded. Call load() first.');
    }
    if (this.runtime !== 'litertjs') {
      return undefined;
    }

    // NOTE(review): `tf` is not imported by this module; it is assumed to be
    // the TensorFlow.js global (e.g. loaded via a <script> tag) — confirm.
    const rawInput = tf.tensor1d(tokens, 'int32');
    let sizedInput = rawInput;
    let batchedInput;
    try {
      // Fill to the expected model token length, else trim.
      if (tokens.length < seqLength) {
        sizedInput = rawInput.pad([[0, seqLength - tokens.length]]);
      } else if (tokens.length > seqLength) {
        sizedInput = rawInput.slice([0], [seqLength]);
      }

      batchedInput = sizedInput.expandDims(0);
      const results = LiteRTInterop.runWithTfjsTensors(this.model, batchedInput);

      // Only the first output is handed back; free any others so they don't leak.
      for (let i = 1; i < results.length; i += 1) {
        results[i].dispose();
      }

      return {
        embedding: results[0], // Batch of 1.
        tokens,
      };
    } finally {
      // pad/slice produce new tensors, so the raw input must be freed separately.
      batchedInput?.dispose();
      if (sizedInput !== rawInput) {
        sizedInput.dispose();
      }
      rawInput.dispose();
    }
  }

  /**
   * Generates an embedding for the given text using Transformers.js.
   *
   * @param {string|Array<string>} textBatch Text or array of text to embed.
   * @return {Promise<{embedding: Array<number>}|undefined>} The pipeline
   *     output's `data` copied into a plain array (mean pooling, normalized),
   *     or `undefined` when this instance's runtime is not 'transformersjs'.
   *     NOTE(review): for an array input, `data` is presumably the flattened
   *     [batch, dim] values — confirm before relying on per-item vectors.
   * @throws {Error} If load() has not populated the model yet.
   */
  async getEmbeddingTransformers(textBatch) {
    if (!this.model) {
      throw new Error('Model not loaded. Call load() first.');
    }
    if (this.runtime !== 'transformersjs') {
      return undefined;
    }

    const queryResult = await this.model(textBatch, { pooling: 'mean', normalize: true });
    return {
      embedding: Array.from(queryResult.data),
    };
  }
}