|
|
from typing import Any, Optional |
|
|
|
|
|
import weave |
|
|
from langchain_qdrant.fastembed_sparse import FastEmbedSparse |
|
|
|
|
|
|
|
|
class SparseEmbeddings(weave.Model):
    """Generate sparse embeddings for documents and queries using FastEmbedSparse.

    ``weave.Model`` is a pydantic model, so every instance attribute is declared
    below as a typed field; assigning undeclared attributes in ``__init__`` is
    rejected by pydantic.

    Attributes:
        model_name (str): The name of the sparse embedding model to use.
        model_kwargs (Optional[dict[str, Any]]): Additional keyword arguments
            forwarded to ``FastEmbedSparse``. Stored as an empty dict when the
            caller passes ``None``.
        sparse_embedding_model (Optional[FastEmbedSparse]): The FastEmbedSparse
            instance built from ``model_name`` and ``model_kwargs``. It is only
            ``None`` transiently, while ``__init__`` is running.
    """

    # Declared as fields so the pydantic base class accepts and tracks them.
    model_name: str
    model_kwargs: Optional[dict[str, Any]] = None
    sparse_embedding_model: Optional[FastEmbedSparse] = None

    def __init__(
        self,
        model_name: str = "prithvida/Splade_PP_en_v1",
        model_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Initialize the model fields and build the underlying FastEmbedSparse model.

        Args:
            model_name (str): The name of the sparse embedding model.
                Defaults to "prithvida/Splade_PP_en_v1".
            model_kwargs (Optional[dict[str, Any]]): Additional keyword arguments
                passed to ``FastEmbedSparse``. Defaults to None, which is treated
                as no extra arguments.
        """
        # Initialise the pydantic machinery first; attribute assignment on a
        # pydantic model is not valid before the base __init__ has run.
        super().__init__(
            model_name=model_name,
            model_kwargs=model_kwargs if model_kwargs is not None else {},
        )

        self.sparse_embedding_model = FastEmbedSparse(model_name=self.model_name, **self.model_kwargs)

    @weave.op()
    def embed_texts(self, texts: list[str]) -> list[dict[str, float]]:
        """Embed a list of texts and return their sparse embeddings.

        Args:
            texts (list[str]): A list of document texts to embed.

        Returns:
            The result of ``FastEmbedSparse.embed_documents(texts)``: one sparse
            embedding per input text.

        NOTE(review): the return annotation says ``list[dict[str, float]]``, but
        langchain_qdrant's FastEmbedSparse appears to return sparse-vector objects
        (indices/values) rather than term-to-weight dicts. Confirm against the
        installed version and tighten the annotation.
        """
        return self.sparse_embedding_model.embed_documents(texts)

    @weave.op()
    def embed_query(self, text: str) -> dict[str, float]:
        """Embed a single query text and return its sparse embedding.

        Args:
            text (str): The query text to embed.

        Returns:
            The result of ``FastEmbedSparse.embed_query(text)``: the sparse
            embedding for the query.

        NOTE(review): the return annotation says ``dict[str, float]``, but
        langchain_qdrant's FastEmbedSparse appears to return a sparse-vector
        object (indices/values). Confirm against the installed version and
        tighten the annotation.
        """
        return self.sparse_embedding_model.embed_query(text)
|
|
|