Spaces:
Running
Running
File size: 4,210 Bytes
6912ad8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from abc import ABC, abstractmethod
from langchain_core.runnables.config import run_in_executor
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from collections.abc import Sequence
class SparseVector(BaseModel, extra="forbid"):
    """Sparse vector structure."""

    # `extra="forbid"` is pydantic model configuration: constructing the model
    # with a field other than `indices` / `values` is a validation error.

    # Integer indices of the stored entries. The uniqueness requirement is
    # stated in the field description only; nothing in this class enforces it.
    indices: list[int] = Field(..., description="indices must be unique")
    # Float values, presumably paired positionally with `indices`. The
    # equal-length requirement is likewise descriptive only and is not
    # validated in this class.
    values: list[float] = Field(
        ..., description="values and indices must be the same length"
    )
class SparseEmbeddings(ABC):
    """Abstract contract for sparse embedding models used with Qdrant.

    Subclasses implement the two synchronous methods. The coroutine variants
    defined here hand those synchronous methods to `run_in_executor`.
    """

    @abstractmethod
    def embed_documents(self, texts: list[str]) -> list[SparseVector]:
        """Return the sparse vectors for a list of document texts."""

    @abstractmethod
    def embed_query(self, text: str) -> SparseVector:
        """Return the sparse vector for a single query text."""

    async def aembed_documents(self, texts: list[str]) -> list[SparseVector]:
        """Awaitable counterpart of `embed_documents`."""
        vectors = await run_in_executor(None, self.embed_documents, texts)
        return vectors

    async def aembed_query(self, text: str) -> SparseVector:
        """Awaitable counterpart of `embed_query`."""
        vector = await run_in_executor(None, self.embed_query, text)
        return vector
class FastEmbedSparse(SparseEmbeddings):
    """Sparse embeddings computed with FastEmbed's `SparseTextEmbedding`."""

    def __init__(
        self,
        model_name: str = "Qdrant/bm25",
        batch_size: int = 256,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[Any] | None = None,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Sparse encoder implementation using FastEmbed.

        Uses [FastEmbed](https://qdrant.github.io/fastembed/) for sparse text
        embeddings.
        For a list of available models, see [the Qdrant docs](https://qdrant.github.io/fastembed/examples/Supported_Models/).

        Args:
            model_name: The name of the model to use.
            batch_size: Batch size for encoding. Passed to the model by
                `embed_documents`.
            cache_dir: The path to the model cache directory.
                Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
            threads: The number of threads onnxruntime session can use.
            providers: List of ONNX execution providers.
            parallel: If `>1`, data-parallel encoding will be used.
                Recommended for encoding of large datasets.
                If `0`, use all available cores.
                If `None`, don't use data-parallel processing,
                use default onnxruntime threading instead.
                Passed to the model by `embed_documents`.
            kwargs: Additional options to pass to `fastembed.SparseTextEmbedding`

        Raises:
            ValueError: If the `fastembed` package is not installed, or if the
                `model_name` is not supported in `SparseTextEmbedding`.
        """
        # Imported lazily so that this module can be imported without the
        # optional `fastembed` dependency installed.
        try:
            from fastembed import (  # type: ignore[import-not-found] # noqa: PLC0415
                SparseTextEmbedding,
            )
        except ImportError as err:
            msg = (
                "The 'fastembed' package is not installed. "
                "Please install it with "
                "`pip install fastembed` or `pip install fastembed-gpu`."
            )
            # Kept as ValueError (not ImportError) for backward compatibility
            # with callers that already catch it.
            raise ValueError(msg) from err

        self._batch_size = batch_size
        self._parallel = parallel
        self._model = SparseTextEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=threads,
            providers=providers,
            **kwargs,
        )

    @staticmethod
    def _to_sparse_vector(embedding: Any) -> SparseVector:
        """Convert one model result into a `SparseVector`.

        The result must expose `indices` and `values` attributes that each
        provide a `tolist()` method.
        """
        return SparseVector(
            indices=embedding.indices.tolist(),
            values=embedding.values.tolist(),
        )

    def embed_documents(self, texts: list[str]) -> list[SparseVector]:
        """Embed document texts using the configured batch size and parallelism.

        Args:
            texts: The document texts to embed.

        Returns:
            One `SparseVector` per result yielded by the model.
        """
        embeddings = self._model.embed(
            texts, batch_size=self._batch_size, parallel=self._parallel
        )
        return [self._to_sparse_vector(embedding) for embedding in embeddings]

    def embed_query(self, text: str) -> SparseVector:
        """Embed a single query text.

        Args:
            text: The query text to embed.

        Returns:
            The `SparseVector` built from the first result yielded by the model.
        """
        # NOTE(review): queries go through the same `embed` call as documents
        # (without the batch/parallel settings). Confirm whether a
        # query-specific encoding method of the model should be used instead.
        # `iter()` keeps `next()` valid even if `embed` returns a plain
        # iterable rather than an iterator.
        embedding = next(iter(self._model.embed(text)))
        return self._to_sparse_vector(embedding)