test-ragp / src /rag_pipelines /embeddings /sparse_pinecone_text.py
awinml's picture
Upload 107 files
336f4a9 verified
from typing import Any, Optional
import weave
from pinecone_text.sparse import SpladeEncoder
class SparseEmbeddingsSplade(weave.Model):
    """Generate sparse embeddings for documents and queries using a SPLADE encoder.

    Wraps ``pinecone_text.sparse.SpladeEncoder`` and exposes document and query
    embedding as Weave ops so each call is traced.

    Attributes:
        model_kwargs (Optional[dict[str, Any]]): Keyword arguments forwarded to
            ``SpladeEncoder`` at construction time. Normalized to ``{}`` when
            ``None`` is passed to ``__init__``.
        sparse_embedding_model (Optional[SpladeEncoder]): The ``SpladeEncoder``
            instance built from ``model_kwargs`` in ``__init__``.
    """

    model_kwargs: Optional[dict[str, Any]]
    sparse_embedding_model: Optional[SpladeEncoder] = None

    def __init__(
        self,
        model_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Initialize the SparseEmbeddingsSplade model and its SPLADE encoder.

        Args:
            model_kwargs (Optional[dict[str, Any]]): Keyword arguments forwarded
                to ``SpladeEncoder``. ``None`` means no extra arguments are
                passed, so the encoder's own defaults apply.
        """
        super().__init__(model_kwargs=model_kwargs)
        self.model_kwargs = model_kwargs if model_kwargs is not None else {}
        # Build the encoder once here so every embed_* call reuses it.
        self.sparse_embedding_model = SpladeEncoder(**self.model_kwargs)

    @weave.op()
    def embed_texts(self, texts: list[str]) -> list[dict[str, float]]:
        """Embed a list of document texts and return their sparse embeddings.

        Args:
            texts (list[str]): Document texts to embed.

        Returns:
            list[dict[str, float]]: One sparse embedding per input text, as
                produced by ``SpladeEncoder.encode_documents``. An empty input
                list yields an empty list without invoking the encoder.
                NOTE(review): pinecone_text documents its sparse vectors as
                ``{"indices": [...], "values": [...]}`` dicts rather than
                term -> weight maps; confirm and tighten the annotation.
        """
        if not texts:
            # Nothing to encode; avoid pushing an empty batch through the model.
            return []
        return self.sparse_embedding_model.encode_documents(texts)

    @weave.op()
    def embed_query(self, text: str) -> dict[str, float]:
        """Embed a single query text and return its sparse embedding.

        Args:
            text (str): The query text to embed.

        Returns:
            dict[str, float]: The sparse embedding for ``text``, as produced by
                ``SpladeEncoder.encode_queries``.
                NOTE(review): see ``embed_texts`` regarding the exact dict
                layout (``indices``/``values`` vs term -> weight); confirm.
        """
        # The query is passed as a one-element batch, so the encoder returns a
        # one-element list; unwrap it so the result is the single embedding the
        # signature promises rather than a list.
        return self.sparse_embedding_model.encode_queries([text])[0]