File size: 2,183 Bytes
336f4a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Any, Optional

import weave
from pinecone_text.sparse import SpladeEncoder


class SparseEmbeddingsSplade(weave.Model):
    """Generate sparse embeddings for documents and queries with a SPLADE encoder.

    Thin ``weave.Model`` wrapper around ``pinecone_text.sparse.SpladeEncoder`` whose
    embedding methods are registered as weave ops.

    Attributes:
        model_kwargs (Optional[dict[str, Any]]): Keyword arguments forwarded to ``SpladeEncoder``.
            Stored as an empty dict when none were supplied.
        sparse_embedding_model (Optional[SpladeEncoder]): The encoder instance built in ``__init__``.
    """

    model_kwargs: Optional[dict[str, Any]]
    sparse_embedding_model: Optional[SpladeEncoder] = None

    def __init__(
        self,
        model_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Initialize the wrapper and build the underlying ``SpladeEncoder``.

        Args:
            model_kwargs (Optional[dict[str, Any]]): Keyword arguments passed verbatim to
                ``SpladeEncoder``. ``None`` means "use the encoder's defaults".
        """
        super().__init__(model_kwargs=model_kwargs)

        # Normalise None to {} so the value can always be unpacked with ** below.
        self.model_kwargs = model_kwargs if model_kwargs is not None else {}

        # Build the encoder once; both embedding methods reuse this instance.
        self.sparse_embedding_model = SpladeEncoder(**self.model_kwargs)

    @weave.op()
    def embed_texts(self, texts: list[str]) -> list[dict[str, Any]]:
        """Embed a list of document texts and return their sparse embeddings.

        Args:
            texts (list[str]): The document texts to embed.

        Returns:
            list[dict[str, Any]]: Whatever ``SpladeEncoder.encode_documents`` returns for a list
                input: one sparse embedding per text.
                NOTE(review): pinecone-text sparse vectors are presumably dicts with
                ``"indices"`` and ``"values"`` lists; confirm against the installed version.
        """
        return self.sparse_embedding_model.encode_documents(texts)

    @weave.op()
    def embed_query(self, text: str) -> dict[str, Any]:
        """Embed a single query text and return its sparse embedding.

        Args:
            text (str): The query text to embed.

        Returns:
            dict[str, Any]: The sparse embedding for ``text`` (a single embedding, not a list).
                NOTE(review): presumably a dict with ``"indices"`` and ``"values"`` lists, as for
                documents; confirm against the installed pinecone-text version.
        """
        # The encoder is called with a one-element batch, so it hands back a one-element
        # list; unwrap it so callers receive the single embedding the signature promises.
        return self.sparse_embedding_model.encode_queries([text])[0]