File size: 2,356 Bytes
336f4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from typing import Any, Optional

import weave
from langchain_qdrant.fastembed_sparse import FastEmbedSparse


class SparseEmbeddings(weave.Model):
    """Generate sparse embeddings for documents and queries using the FastEmbedSparse model.

    ``weave.Model`` is pydantic-based, so every instance attribute is declared as a
    field on the class and populated through ``super().__init__``; assigning
    undeclared attributes directly on ``self`` is rejected by pydantic.

    Attributes:
        model_name (str): The name of the sparse embedding model to use.
        model_kwargs (dict[str, Any]): Additional configuration parameters for the model
            (an empty dict when none were supplied).
        sparse_embedding_model (FastEmbedSparse): The initialized FastEmbedSparse model with the specified parameters.
    """

    model_name: str
    model_kwargs: dict[str, Any]
    sparse_embedding_model: FastEmbedSparse

    def __init__(
        self,
        model_name: str = "prithvida/Splade_PP_en_v1",
        model_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Initialize the SparseEmbeddings class with the specified model and configurations.

        Args:
            model_name (str): The name of the sparse embedding model. Defaults to "prithvida/Splade_PP_en_v1".
            model_kwargs (Optional[dict[str, Any]]): Additional model configuration parameters for initialization.
                Defaults to None, which is treated as an empty dict.
        """
        # Normalise up front so the same mapping is stored and splatted into the model.
        resolved_kwargs: dict[str, Any] = model_kwargs if model_kwargs is not None else {}

        # Build the embedding model first, then hand every field to the pydantic
        # initialiser in one call so the instance is fully validated and set up.
        super().__init__(
            model_name=model_name,
            model_kwargs=resolved_kwargs,
            sparse_embedding_model=FastEmbedSparse(model_name=model_name, **resolved_kwargs),
        )

    @weave.op()
    def embed_texts(self, texts: list[str]) -> list[dict[str, float]]:
        """Embed a list of texts and return their sparse embeddings.

        Delegates directly to ``FastEmbedSparse.embed_documents``.

        Args:
            texts (list[str]): A list of document texts to embed.

        Returns:
            One sparse embedding per input text, exactly as returned by the underlying model.
            NOTE(review): the annotation says ``list[dict[str, float]]``, but the underlying
            library appears to return sparse-vector objects (indices/values) rather than
            term-to-weight dicts -- confirm and adjust the annotation if so.
        """
        return self.sparse_embedding_model.embed_documents(texts)

    @weave.op()
    def embed_query(self, text: str) -> dict[str, float]:
        """Embed a single query text and return its sparse embedding.

        Delegates directly to ``FastEmbedSparse.embed_query``.

        Args:
            text (str): The query text to embed.

        Returns:
            The sparse embedding for the query, exactly as returned by the underlying model.
            NOTE(review): the annotation says ``dict[str, float]``, but the underlying
            library appears to return a sparse-vector object (indices/values) -- confirm
            and adjust the annotation if so.
        """
        return self.sparse_embedding_model.embed_query(text)