| import type { Tensor, Pipeline } from "@xenova/transformers"; |
| import { pipeline, dot } from "@xenova/transformers"; |
|
|
| |
/**
 * Returns one minus the dot product of the two tensors' flat `data` arrays.
 *
 * Despite its name, this is a distance, not a similarity: a larger dot product
 * yields a smaller result. When both inputs are unit-length vectors the value
 * is the cosine distance (0 for the same direction, 2 for opposite directions).
 */
function innerProduct(tensor1: Tensor, tensor2: Tensor) {
  const { data: left } = tensor1;
  const { data: right } = tensor2;
  const similarity = dot(left, right);
  return 1.0 - similarity;
}
|
|
| |
/**
 * Lazily creates and caches one feature-extraction pipeline for `modelId`.
 *
 * The pending promise itself is cached, so callers arriving while the model is
 * still loading share that single load instead of starting their own.
 */
class PipelineSingleton {
  static modelId = "Xenova/gte-small";
  static instance: Promise<Pipeline> | null = null;

  /**
   * Returns the shared pipeline, starting the load on first use.
   *
   * If loading rejects, the cached promise is cleared before the rejection
   * reaches the caller, so the next call retries the load instead of receiving
   * the same rejected promise forever.
   */
  static async getInstance(): Promise<Pipeline> {
    if (this.instance !== null) {
      return this.instance;
    }

    const loading: Promise<Pipeline> = pipeline("feature-extraction", this.modelId);
    this.instance = loading;

    // Side branch only resets the cache; the rejection still propagates to the
    // caller through the returned promise. The identity check avoids clobbering
    // a newer load that may have replaced this one.
    void loading.catch(() => {
      if (this.instance === loading) {
        this.instance = null;
      }
    });

    return loading;
  }
}
|
|
| |
/**
 * Exported sequence-length limit (literal type `512`).
 * NOTE(review): presumably the embedding model's maximum token count — it is
 * not referenced by the code shown here; confirm against its consumers.
 */
export const MAX_SEQ_LEN = 512;
|
|
/**
 * Ranks `sentences` by embedding similarity to `query` and returns the indices
 * of the closest matches.
 *
 * The query and every sentence are embedded in one batch with mean pooling and
 * normalisation. Each sentence is then scored with `innerProduct` (one minus the
 * dot product), so a lower score means a closer match.
 *
 * @param query - Text every sentence is compared against.
 * @param sentences - Candidate texts; returned indices refer to this array.
 * @param options.topK - Maximum number of indices to return (defaults to 5).
 * @returns Up to `topK` indices into `sentences`, most similar first. Returns an
 *   empty array (without running the model) when `sentences` is empty or
 *   `topK` is not positive.
 */
export async function findSimilarSentences(
  query: string,
  sentences: string[],
  { topK = 5 }: { topK?: number } = {}
): Promise<number[]> {
  // Nothing to rank. This also avoids slicing an invalid [1, 0] range below.
  if (sentences.length === 0 || topK <= 0) {
    return [];
  }

  const input = [query, ...sentences];

  const extractor = await PipelineSingleton.getInstance();
  const output: Tensor = await extractor(input, { pooling: "mean", normalize: true });

  // Row 0 is the query embedding; rows 1 .. input.length - 1 are the sentences.
  // Tensor.slice ranges are end-exclusive, so the end bound must be
  // input.length. An end of input.length - 1 would drop the last sentence.
  const queryTensor: Tensor = output[0];
  const sentencesTensor: Tensor = output.slice([1, input.length]);

  const distancesFromQuery: { distance: number; index: number }[] = [...sentencesTensor].map(
    (sentenceTensor: Tensor, index: number) => ({
      distance: innerProduct(queryTensor, sentenceTensor),
      index,
    })
  );

  // Ascending distance: the closest sentences come first.
  distancesFromQuery.sort((a, b) => a.distance - b.distance);

  return distancesFromQuery.slice(0, topK).map(({ index }) => index);
}
|
|