| | import { z } from "zod"; |
| | import type { EmbeddingEndpoint } from "../embeddingEndpoints"; |
| | import type { Tensor, FeatureExtractionPipeline } from "@huggingface/transformers"; |
| | import { pipeline } from "@huggingface/transformers"; |
| |
|
// Field validators for the Transformers.js embedding endpoint configuration.
const transformersJSParametersShape = {
  // Positive integer; falls back to 1 when the field is omitted.
  weight: z.number().int().positive().default(1),
  // Accepted as-is: `z.any()` performs no validation on the model value.
  model: z.any(),
  // Must be exactly the string "transformersjs".
  type: z.literal("transformersjs"),
};

/**
 * Zod schema describing the parameters of a Transformers.js embedding endpoint.
 *
 * Parsing yields an object with a `weight` (positive integer, default 1),
 * an unvalidated `model`, and the literal `type` "transformersjs".
 */
export const embeddingEndpointTransformersJSParametersSchema = z.object(
  transformersJSParametersShape
);
| |
|
| | |
/**
 * Registry of Transformers.js feature-extraction pipelines, keyed by model name.
 *
 * Despite the name, `getInstance` does not reuse a loaded pipeline: every call
 * disposes the pipeline previously registered for the model (to release its
 * memory) and then loads and registers a fresh one.
 *
 * NOTE(review): because each call disposes the previously registered pipeline,
 * overlapping calls for the same model can dispose a pipeline that an earlier
 * caller is still using — confirm that callers serialize access.
 */
class TransformersJSModelsSingleton {
  /** Registered `[modelName, pipelinePromise]` pairs. */
  static instances: Array<[string, Promise<FeatureExtractionPipeline>]> = [];

  /**
   * Disposes the pipeline currently registered for `modelName` (if any), then
   * loads, registers and returns a new feature-extraction pipeline for it.
   *
   * @param modelName - Model identifier passed to `pipeline("feature-extraction", …)`.
   * @returns The promise of the newly created pipeline.
   * @throws Whatever the new pipeline load rejects with, or whatever disposing
   *   the previous pipeline throws. In both cases the failed entry is no longer
   *   blocking: the next call for the same model starts a fresh load.
   */
  static async getInstance(modelName: string): Promise<FeatureExtractionPipeline> {
    const registered = this.instances.find(([name]) => name === modelName);

    if (registered) {
      const [, registeredPipeline] = registered;

      // Unregister before awaiting anything: if the steps below fail, the
      // broken entry must not stay behind and make every later call for this
      // model fail on the same error.
      this.instances = this.instances.filter(([name]) => name !== modelName);

      let stale: FeatureExtractionPipeline | undefined;
      try {
        stale = await registeredPipeline;
      } catch {
        // The earlier load failed. Its rejection was already delivered to the
        // caller that started it, and there is no pipeline to dispose.
        stale = undefined;
      }

      // Dispose of the previous pipeline to clear its memory before reloading.
      await stale?.dispose();
    }

    const newModelPipeline = pipeline("feature-extraction", modelName);
    this.instances.push([modelName, newModelPipeline]);

    return newModelPipeline;
  }
}
| |
|
/**
 * Runs the feature-extraction pipeline for `modelName` over `inputs` and
 * returns the resulting tensor converted with `tolist()`.
 *
 * @param modelName - Model identifier used to obtain the pipeline.
 * @param inputs - Strings to embed.
 * @returns The output tensor's contents as returned by `Tensor.tolist()`.
 */
export async function calculateEmbedding(modelName: string, inputs: string[]) {
  // Options handed to the extractor: mean pooling with normalization enabled.
  const extractionOptions = { pooling: "mean", normalize: true } as const;

  const featureExtractor = await TransformersJSModelsSingleton.getInstance(modelName);
  const embeddings: Tensor = await featureExtractor(inputs, extractionOptions);

  return embeddings.tolist();
}
| |
|
/**
 * Builds an embedding endpoint backed by a Transformers.js pipeline.
 *
 * The input is validated with `embeddingEndpointTransformersJSParametersSchema`
 * (which throws on invalid input). The returned endpoint embeds its `inputs`
 * with the model named by `model.name`.
 *
 * @param input - Raw endpoint parameters, in the schema's input shape.
 * @returns An async endpoint that resolves to the embeddings for `inputs`.
 */
export function embeddingEndpointTransformersJS(
  input: z.input<typeof embeddingEndpointTransformersJSParametersSchema>
): EmbeddingEndpoint {
  const parameters = embeddingEndpointTransformersJSParametersSchema.parse(input);
  const modelConfig = parameters.model;

  // `modelConfig.name` is read on each invocation, not at construction time.
  const endpoint: EmbeddingEndpoint = async ({ inputs }) => {
    return calculateEmbedding(modelConfig.name, inputs);
  };

  return endpoint;
}
| |
|