| | import { config } from "$lib/server/config"; |
| |
|
| | import { z } from "zod"; |
| | import { sum } from "$lib/utils/sum"; |
| | import { |
| | embeddingEndpoints, |
| | embeddingEndpointSchema, |
| | type EmbeddingEndpoint, |
| | } from "$lib/server/embeddingEndpoints/embeddingEndpoints"; |
| | import { embeddingEndpointTransformersJS } from "$lib/server/embeddingEndpoints/transformersjs/embeddingEndpoints"; |
| |
|
| | import JSON5 from "json5"; |
| |
|
/** Builds a validator for an optional string that must be non-empty when present. */
const optionalNonEmptyText = () => z.string().min(1).optional();

/** Builds a validator for an optional string that must be a URL when present. */
const optionalUrl = () => z.string().url().optional();

/**
 * Schema for a single entry of the embedding-model configuration list.
 */
const modelConfig = z.object({
	// Optional explicit identifier; `processEmbeddingModel` substitutes `name` when it is missing.
	id: z.string().optional(),
	// Required, non-empty model name.
	name: z.string().min(1),
	displayName: optionalNonEmptyText(),
	description: optionalNonEmptyText(),
	websiteUrl: optionalUrl(),
	modelUrl: optionalUrl(),
	// At least one endpoint definition is required.
	endpoints: z.array(embeddingEndpointSchema).nonempty(),
	// Must be a number greater than zero.
	chunkCharLength: z.number().positive(),
	// When given, must be a number greater than zero.
	maxBatchSize: z.number().positive().optional(),
	// Both prefixes default to the empty string.
	preQuery: z.string().default(""),
	prePassage: z.string().default(""),
});
| |
|
| | |
/**
 * Configuration used when `config.TEXT_EMBEDDING_MODELS` is falsy:
 * a single `Xenova/gte-small` model served through a `transformersjs` endpoint.
 */
const DEFAULT_EMBEDDING_MODELS_JSON = `[
	{
	  "name": "Xenova/gte-small",
	  "chunkCharLength": 512,
	  "endpoints": [
	    { "type": "transformersjs" }
	  ]
	}
]`;

// `||` (not `??`) so that an empty string also selects the default.
const rawEmbeddingModelJSON = config.TEXT_EMBEDDING_MODELS || DEFAULT_EMBEDDING_MODELS_JSON;

/** Schema for the whole configuration: an array of model entries. */
const embeddingModelListSchema = z.array(modelConfig);

// Parse the JSON5 text, then validate it; either step throws on bad input.
const embeddingModelsRaw = embeddingModelListSchema.parse(JSON5.parse(rawEmbeddingModelJSON));
| |
|
/**
 * Produces a copy of a validated model config whose `id` is always set.
 *
 * @param m - One parsed entry of the embedding-model configuration.
 * @returns A promise of `m` with `id` set to `m.id` when truthy, otherwise `m.name`.
 */
const processEmbeddingModel = async (m: z.infer<typeof modelConfig>) => {
	const resolvedId = m.id || m.name;
	return { ...m, id: resolvedId };
};
| |
|
/**
 * Attaches a `getEndpoint` factory to a processed embedding model.
 *
 * Each call to `getEndpoint` picks one of the model's configured endpoints at
 * random, with probability proportional to the endpoint's `weight`, and returns
 * the result of the matching `embeddingEndpoints` factory.
 *
 * @param m - Model config as returned by `processEmbeddingModel`.
 * @returns The fields of `m` plus an async `getEndpoint` method.
 * @throws Error from `getEndpoint` if the chosen endpoint has an unrecognised
 *   `type`, or if no endpoint is selected by the weighted draw.
 */
const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
	...m,
	getEndpoint: async (): Promise<EmbeddingEndpoint> => {
		// Defensive fallback: with no endpoint list, use a local transformersjs endpoint.
		if (!m.endpoints) {
			return embeddingEndpointTransformersJS({
				type: "transformersjs",
				weight: 1,
				model: m,
			});
		}

		const totalWeight = sum(m.endpoints.map((e) => e.weight));

		// Draw a point in [0, totalWeight) and walk the endpoints, subtracting each
		// weight until the point falls inside the current endpoint's slice.
		let random = Math.random() * totalWeight;

		for (const endpoint of m.endpoints) {
			if (random < endpoint.weight) {
				const args = { ...endpoint, model: m };

				switch (args.type) {
					case "tei":
						return embeddingEndpoints.tei(args);
					case "transformersjs":
						return embeddingEndpoints.transformersjs(args);
					case "openai":
						return embeddingEndpoints.openai(args);
					case "hfapi":
						return embeddingEndpoints.hfapi(args);
					default: {
						// Interpolating `args` itself would print "[object Object]";
						// report the offending type tag instead.
						const unknownType = (args as { type?: unknown }).type;
						throw new Error(`Unknown endpoint type: ${String(unknownType)}`);
					}
				}
			}

			random -= endpoint.weight;
		}

		// Reached when no slice contained the drawn point (e.g. all weights are 0).
		throw new Error(`Failed to select embedding endpoint`);
	},
});
| |
|
/**
 * Every configured embedding model, with `id` filled in and a `getEndpoint`
 * factory attached. Resolved at module load through top-level `await`.
 */
export const embeddingModels = await Promise.all(
	embeddingModelsRaw.map(async (rawModel) => addEndpoint(await processEmbeddingModel(rawModel)))
);

/** The first entry of `embeddingModels`. */
export const defaultEmbeddingModel = embeddingModels[0];
| |
|
/**
 * Builds a zod enum whose allowed values are the given field of each model.
 *
 * @param _models - Models to take the allowed values from; the first element is read unconditionally.
 * @param key - Which field (`"id"` or `"name"`) supplies the enum values.
 * @returns A `z.enum` over the collected values, in list order.
 */
const validateEmbeddingModel = (_models: EmbeddingBackendModel[], key: "id" | "name") => {
	const [firstModel, ...otherModels] = _models;
	const otherValues = otherModels.map((model) => model[key]);
	return z.enum([firstModel[key], ...otherValues]);
};
| |
|
/**
 * Builds a zod enum that accepts only the `id` values of the given models.
 *
 * @param _models - Models whose ids form the allowed set.
 * @returns The enum schema produced by `validateEmbeddingModel` for the `"id"` key.
 */
export const validateEmbeddingModelById = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "id");
| |
|
/**
 * Builds a zod enum that accepts only the `name` values of the given models.
 *
 * @param _models - Models whose names form the allowed set.
 * @returns The enum schema produced by `validateEmbeddingModel` for the `"name"` key.
 */
export const validateEmbeddingModelByName = (_models: EmbeddingBackendModel[]) =>
	validateEmbeddingModel(_models, "name");
| |
|
/**
 * Shape of a fully processed embedding model, derived from the type of
 * `defaultEmbeddingModel`.
 */
export type EmbeddingBackendModel = typeof defaultEmbeddingModel;
| |
|