| import { dot } from "@xenova/transformers"; |
| import type { EmbeddingBackendModel } from "$lib/server/embeddingModels"; |
| import type { Embedding } from "$lib/server/embeddingEndpoints/embeddingEndpoints"; |
|
|
| |
/**
 * Distance between two embeddings, computed as `1 - dot(embeddingA, embeddingB)`.
 *
 * Despite the name, this is a distance rather than a similarity: a larger dot
 * product yields a smaller value, so sorting ascending puts the closest
 * embeddings first.
 * NOTE(review): this is only a true cosine distance when both embeddings are
 * unit-normalized — confirm the embedding endpoint guarantees that.
 */
function innerProduct(embeddingA: Embedding, embeddingB: Embedding): number {
	const similarity = dot(embeddingA, embeddingB);
	return 1.0 - similarity;
}
|
|
/**
 * Ranks `sentences` by embedding distance to `query` and returns the indices of
 * the `topK` closest ones, nearest first.
 *
 * The query and all sentences are embedded in a single endpoint call. The query
 * is prefixed with `embeddingModel.preQuery` and every sentence with
 * `embeddingModel.prePassage` before embedding.
 *
 * @param embeddingModel - Model whose endpoint produces the embeddings.
 * @param query - Text every sentence is compared against.
 * @param sentences - Candidate texts; the returned indices refer to this array.
 * @param options.topK - Maximum number of indices to return (default 5;
 *   negative values are treated as 0).
 * @returns Indices into `sentences`, ordered from smallest to largest distance.
 *   Returns an empty array, without calling the endpoint, when `sentences` is empty.
 */
export async function findSimilarSentences(
	embeddingModel: EmbeddingBackendModel,
	query: string,
	sentences: string[],
	{ topK = 5 }: { topK?: number } = {}
): Promise<number[]> {
	// Nothing to rank: skip the embedding request entirely.
	if (sentences.length === 0) {
		return [];
	}

	const inputs = [
		`${embeddingModel.preQuery}${query}`,
		...sentences.map((sentence) => `${embeddingModel.prePassage}${sentence}`),
	];

	const embeddingEndpoint = await embeddingModel.getEndpoint();
	const output = await embeddingEndpoint({ inputs });

	// The first embedding belongs to the query; the rest line up with `sentences`.
	const queryEmbedding: Embedding = output[0];
	const sentencesEmbeddings: Embedding[] = output.slice(1);

	const distancesFromQuery = sentencesEmbeddings.map((sentenceEmbedding, index) => ({
		distance: innerProduct(queryEmbedding, sentenceEmbedding),
		index,
	}));

	// Smaller distance means more similar, so sort ascending.
	distancesFromQuery.sort((a, b) => a.distance - b.distance);

	// A negative bound would make slice() drop items from the end instead of returning none.
	return distancesFromQuery.slice(0, Math.max(0, topK)).map(({ index }) => index);
}
|
|