const path = require("path");
const fs = require("fs");

class NativeEmbeddingReranker {
  // Class-level singletons: the model, tokenizer, and transformers module are
  // loaded once and shared by every NativeEmbeddingReranker instance.
  static #model = null;
  static #tokenizer = null;
  static #transformers = null;
  // In-flight initialization promise so concurrent initClient() calls coalesce
  // into a single download/load instead of racing.
  static #initializationPromise = null;

  // This is a folder that Mintplex Labs hosts for those who cannot capture the HF model download
  // endpoint for various reasons. This endpoint is not guaranteed to be active or maintained
  // and may go offline at any time at Mintplex Labs's discretion.
  #fallbackHost = "https://cdn.anythingllm.com/support/models/";

  constructor() {
    // An alternative model to the mixedbread-ai/mxbai-rerank-xsmall-v1 model (speed on CPU is much slower for this model @ 18docs = 6s)
    // Model Card: https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2 (speed on CPU is much faster @ 18docs = 1.6s)
    this.model = "Xenova/ms-marco-MiniLM-L-6-v2";
    this.cacheDir = path.resolve(
      process.env.STORAGE_DIR
        ? path.resolve(process.env.STORAGE_DIR, `models`)
        : path.resolve(__dirname, `../../../storage/models`)
    );
    this.modelPath = path.resolve(this.cacheDir, ...this.model.split("/"));
    // Make directory when it does not exist in existing installations.
    // `recursive: true` also creates any missing parent directories, which
    // prevents an ENOENT crash on fresh installs where the storage dir
    // itself has not been created yet.
    if (!fs.existsSync(this.cacheDir))
      fs.mkdirSync(this.cacheDir, { recursive: true });

    // Reuse the already-resolved modelPath rather than re-resolving the
    // slash-separated model id a second time (both resolve identically).
    this.modelDownloaded = fs.existsSync(this.modelPath);
    this.log("Initialized");
  }

  /**
   * Logs a message to the console, prefixed with the class name in cyan.
   * @param {string} text - The message to log.
   * @param {...any} args - Additional values passed through to console.log.
   */
  log(text, ...args) {
    console.log(`\x1b[36m[NativeEmbeddingReranker]\x1b[0m ${text}`, ...args);
  }

  /**
   * This function will return the host of the current reranker suite.
   * If the reranker suite is not initialized, it will return the default HF host.
   * @returns {string} The host of the current reranker suite.
   */
  get host() {
    if (!NativeEmbeddingReranker.#transformers) return "https://huggingface.co";
    try {
      return new URL(NativeEmbeddingReranker.#transformers.env.remoteHost).host;
    } catch (e) {
      // remoteHost was not a parseable URL - report the fallback host instead.
      return this.#fallbackHost;
    }
  }

  /**
   * This function will preload the reranker suite and tokenizer.
   * This is useful for reducing the latency of the first rerank call and pre-downloading the models and such
   * to avoid having to wait for the models to download on the first rerank call.
   * Errors are intentionally swallowed (after logging) so a failed preload only
   * defers loading to the first rerank call instead of crashing the caller.
   * @returns {Promise<void>}
   */
  async preload() {
    try {
      this.log(`Preloading reranker suite...`);
      await this.initClient();
      this.log(
        `Preloaded reranker suite. Reranking is available as a service now.`
      );
      return;
    } catch (e) {
      console.error(e);
      this.log(
        `Failed to preload reranker suite. Reranking will be available on the first rerank call.`
      );
      return;
    }
  }

  /**
   * Lazily initializes the transformers module, model, and tokenizer singletons.
   * Safe to call repeatedly and concurrently: fully-initialized suites are
   * reused, and overlapping calls await the same in-flight promise.
   * @returns {Promise<void>}
   * @throws If both the remote host and the fallback host fail to provide the model/tokenizer.
   */
  async initClient() {
    if (
      NativeEmbeddingReranker.#transformers &&
      NativeEmbeddingReranker.#model &&
      NativeEmbeddingReranker.#tokenizer
    ) {
      this.log(`Reranker suite already fully initialized - reusing.`);
      return;
    }

    if (NativeEmbeddingReranker.#initializationPromise) {
      this.log(`Waiting for existing initialization to complete...`);
      await NativeEmbeddingReranker.#initializationPromise;
      return;
    }

    NativeEmbeddingReranker.#initializationPromise = (async () => {
      try {
        const { AutoModelForSequenceClassification, AutoTokenizer, env } =
          await import("@xenova/transformers");
        this.log(`Loading reranker suite...`);
        NativeEmbeddingReranker.#transformers = {
          AutoModelForSequenceClassification,
          AutoTokenizer,
          env,
        };
        // Attempt to load the model and tokenizer in this order:
        // 1. From local file system cache
        // 2. Download and cache from remote host (hf.co)
        // 3. Download and cache from fallback host (cdn.anythingllm.com)
        await this.#getPreTrainedModel();
        await this.#getPreTrainedTokenizer();
      } finally {
        // Always clear the in-flight marker so a failed init can be retried.
        NativeEmbeddingReranker.#initializationPromise = null;
      }
    })();

    await NativeEmbeddingReranker.#initializationPromise;
  }

  /**
   * This function will load the model from the local file system cache, or download and cache it from the remote host.
   * If the model is not found in the local file system cache, it will download and cache it from the remote host.
   * If the model is not found in the remote host, it will download and cache it from the fallback host.
   * @returns {Promise<any>} The loaded model.
   * @throws If the model cannot be loaded from the fallback host either.
   */
  async #getPreTrainedModel() {
    if (NativeEmbeddingReranker.#model) {
      this.log(`Loading model from singleton...`);
      return NativeEmbeddingReranker.#model;
    }

    try {
      const model =
        await NativeEmbeddingReranker.#transformers.AutoModelForSequenceClassification.from_pretrained(
          this.model,
          {
            // Only log download progress on first-time downloads; cached
            // loads would otherwise spam the console.
            progress_callback: (p) => {
              if (!this.modelDownloaded && p.status === "progress") {
                this.log(
                  `[${this.host}] Loading model ${this.model}... ${p?.progress}%`
                );
              }
            },
            cache_dir: this.cacheDir,
          }
        );
      this.log(`Loaded model ${this.model}`);
      NativeEmbeddingReranker.#model = model;
      return model;
    } catch (e) {
      this.log(
        `Failed to load model ${this.model} from ${this.host}.`,
        e.message,
        e.stack
      );
      // Already pointed at the fallback host - nothing left to try.
      if (
        NativeEmbeddingReranker.#transformers.env.remoteHost ===
        this.#fallbackHost
      ) {
        this.log(`Failed to load model ${this.model} from fallback host.`);
        throw e;
      }

      // Repoint the transformers env at the fallback CDN and retry once.
      this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
      NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost;
      NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/";
      return await this.#getPreTrainedModel();
    }
  }

  /**
   * This function will load the tokenizer from the local file system cache, or download and cache it from the remote host.
   * If the tokenizer is not found in the local file system cache, it will download and cache it from the remote host.
   * If the tokenizer is not found in the remote host, it will download and cache it from the fallback host.
   * @returns {Promise<any>} The loaded tokenizer.
   * @throws If the tokenizer cannot be loaded from the fallback host either.
   */
  async #getPreTrainedTokenizer() {
    if (NativeEmbeddingReranker.#tokenizer) {
      this.log(`Loading tokenizer from singleton...`);
      return NativeEmbeddingReranker.#tokenizer;
    }

    try {
      const tokenizer =
        await NativeEmbeddingReranker.#transformers.AutoTokenizer.from_pretrained(
          this.model,
          {
            // Only log download progress on first-time downloads (see model loader).
            progress_callback: (p) => {
              if (!this.modelDownloaded && p.status === "progress") {
                this.log(
                  `[${this.host}] Loading tokenizer ${this.model}... ${p?.progress}%`
                );
              }
            },
            cache_dir: this.cacheDir,
          }
        );
      this.log(`Loaded tokenizer ${this.model}`);
      NativeEmbeddingReranker.#tokenizer = tokenizer;
      return tokenizer;
    } catch (e) {
      this.log(
        `Failed to load tokenizer ${this.model} from ${this.host}.`,
        e.message,
        e.stack
      );
      // Already pointed at the fallback host - nothing left to try.
      if (
        NativeEmbeddingReranker.#transformers.env.remoteHost ===
        this.#fallbackHost
      ) {
        this.log(`Failed to load tokenizer ${this.model} from fallback host.`);
        throw e;
      }

      // Repoint the transformers env at the fallback CDN and retry once.
      this.log(`Falling back to fallback host. ${this.#fallbackHost}`);
      NativeEmbeddingReranker.#transformers.env.remoteHost = this.#fallbackHost;
      NativeEmbeddingReranker.#transformers.env.remotePathTemplate = "{model}/";
      return await this.#getPreTrainedTokenizer();
    }
  }

  /**
   * Reranks a list of documents based on the query.
   * Scores each (query, document.text) pair with the cross-encoder, sorts
   * descending by score, and returns the top results annotated with
   * `rerank_corpus_id` (original index) and `rerank_score` (sigmoid logit).
   * @param {string} query - The query to rerank the documents against.
   * @param {{text: string}[]} documents - The list of document text snippets to rerank. Should be output from a vector search.
   * @param {Object} options - The options for the reranking.
   * @param {number} options.topK - The number of top documents to return.
   * @returns {Promise<any[]>} - The reranked list of documents.
   */
  async rerank(query, documents, options = { topK: 4 }) {
    // Nothing to score - skip model initialization and tokenization entirely.
    if (!Array.isArray(documents) || documents.length === 0) return [];
    // Default topK to 4 even when the caller supplies a partial options
    // object (e.g. `{}`), matching the documented default.
    const topK = options?.topK ?? 4;

    await this.initClient();
    const model = NativeEmbeddingReranker.#model;
    const tokenizer = NativeEmbeddingReranker.#tokenizer;

    const start = Date.now();
    this.log(`Reranking ${documents.length} documents...`);
    // Cross-encoder input: the query is paired with every document text.
    const inputs = tokenizer(new Array(documents.length).fill(query), {
      text_pair: documents.map((doc) => doc.text),
      padding: true,
      truncation: true,
    });
    const { logits } = await model(inputs);
    const reranked = logits
      .sigmoid()
      .tolist()
      .map(([score], i) => ({
        rerank_corpus_id: i,
        rerank_score: score,
        ...documents[i],
      }))
      .sort((a, b) => b.rerank_score - a.rerank_score)
      .slice(0, topK);

    this.log(
      `Reranking ${documents.length} documents to top ${topK} took ${Date.now() - start}ms`
    );
    return reranked;
  }
}

// CommonJS export; consumers use `const { NativeEmbeddingReranker } = require(...)`.
module.exports = {
  NativeEmbeddingReranker,
};