Spaces:
Paused
Paused
| import { Embeddings, type EmbeddingsParams } from '@langchain/core/embeddings'; | |
| import { chunkArray } from '@langchain/core/utils/chunk_array'; | |
/**
 * Options accepted by `HuggingFaceTransformersEmbeddings`.
 *
 * `model` and `modelName` are two names for the same model identifier.
 */
export interface HuggingFaceTransformersEmbeddingsParams
  extends EmbeddingsParams {
  /** Model identifier handed to the `@xenova/transformers` pipeline. */
  model: string;
  /** Alternate name for `model`. */
  modelName: string;
  /** When true, newline characters are replaced before embedding. */
  stripNewLines?: boolean;
  /** Number of texts to group into one embedding batch. */
  batchSize?: number;
  /** Optional timeout setting; units are not defined in this file. */
  timeout?: number;
}
/**
 * LangChain `Embeddings` implementation that computes embeddings locally with a
 * `@xenova/transformers` feature-extraction pipeline (mean pooling, normalized
 * output).
 *
 * The pipeline is created lazily on first use and reused by later calls on the
 * same instance. If loading fails, the cached promise is discarded so the next
 * call retries instead of failing forever.
 */
export class HuggingFaceTransformersEmbeddings
  extends Embeddings
  implements HuggingFaceTransformersEmbeddingsParams
{
  /** Model identifier; kept equal to `model`. */
  modelName = 'Xenova/all-MiniLM-L6-v2';

  /** Model identifier passed to the transformers pipeline. */
  model = 'Xenova/all-MiniLM-L6-v2';

  /** Maximum number of texts sent to the pipeline in one call. */
  batchSize = 512;

  /** When true, `\n` characters are replaced with spaces before embedding. */
  stripNewLines = true;

  /**
   * Stored from the constructor options.
   * NOTE(review): not read anywhere in this class, so it currently has no effect.
   */
  timeout?: number;

  /**
   * Cached pipeline load. Undefined until first use, or after a failed load.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- pipeline is loaded via dynamic import and not statically typed here
  private pipelinePromise?: Promise<any>;

  /**
   * @param fields - Optional overrides. `model` takes precedence over
   *   `modelName`; both fall back to the class default.
   * @throws RangeError if `batchSize` is supplied but is not a positive integer.
   */
  constructor(fields?: Partial<HuggingFaceTransformersEmbeddingsParams>) {
    super(fields ?? {});
    this.modelName = fields?.model ?? fields?.modelName ?? this.model;
    this.model = this.modelName;
    this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
    this.timeout = fields?.timeout;

    // Honour a caller-supplied batch size. Chunking requires a positive integer.
    const batchSize = fields?.batchSize ?? this.batchSize;
    if (!Number.isInteger(batchSize) || batchSize < 1) {
      throw new RangeError(
        `batchSize must be a positive integer, got ${batchSize}`,
      );
    }
    this.batchSize = batchSize;
  }

  /**
   * Embeds a list of documents.
   *
   * Inputs are split into chunks of at most `batchSize` texts. All chunks are
   * started together via `Promise.all`, and the results are concatenated in
   * input order.
   *
   * @param texts - Documents to embed.
   * @returns One embedding vector per input text, in the same order.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    const batches = chunkArray(
      texts.map((text) => this.prepareText(text)),
      this.batchSize,
    );
    const batchResponses = await Promise.all(
      batches.map((batch) => this.runEmbedding(batch)),
    );

    const embeddings: number[][] = [];
    for (const batchResponse of batchResponses) {
      for (const vector of batchResponse) {
        embeddings.push(vector);
      }
    }
    return embeddings;
  }

  /**
   * Embeds a single query string.
   *
   * @param text - Query to embed.
   * @returns The embedding vector for `text`.
   * @throws Error if the pipeline returns no vector.
   */
  async embedQuery(text: string): Promise<number[]> {
    const [embedding] = await this.runEmbedding([this.prepareText(text)]);
    if (embedding === undefined) {
      throw new Error('Embedding pipeline returned no vector for the query');
    }
    return embedding;
  }

  /** Applies the `stripNewLines` option to one input text. */
  private prepareText(text: string): string {
    return this.stripNewLines ? text.replace(/\n/g, ' ') : text;
  }

  /**
   * Returns the shared feature-extraction pipeline, loading it on first use.
   * A failed load is not cached, so a later call can retry.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- see pipelinePromise
  private loadPipeline(): Promise<any> {
    return (this.pipelinePromise ??= (async () => {
      try {
        const { pipeline } = await import('@xenova/transformers');
        return await pipeline('feature-extraction', this.model);
      } catch (err) {
        // Drop the rejected promise so the next call starts a fresh load.
        this.pipelinePromise = undefined;
        throw err;
      }
    })());
  }

  /**
   * Runs one batch of texts through the pipeline, inside `this.caller.call`.
   * Pipeline loading happens outside the caller wrapper.
   *
   * @returns One vector per input text, as produced by `output.tolist()`.
   */
  private async runEmbedding(texts: string[]): Promise<number[][]> {
    const pipe = await this.loadPipeline();
    return this.caller.call(async () => {
      const output = await pipe(texts, { pooling: 'mean', normalize: true });
      return output.tolist();
    });
  }
}