File size: 7,910 Bytes
916dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f06d5d
916dea4
 
6f06d5d
916dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f06d5d
916dea4
 
6f06d5d
 
 
 
916dea4
 
 
6f06d5d
916dea4
6f06d5d
 
916dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f06d5d
 
 
 
 
 
 
916dea4
 
 
 
 
6f06d5d
916dea4
 
 
 
 
 
 
 
 
 
6f06d5d
 
 
 
 
916dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f06d5d
 
 
 
 
 
916dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f06d5d
916dea4
 
 
 
 
 
 
 
6f06d5d
916dea4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""LangChain vector store extensions for document-qa.

Extends ChromaDB with support for returning similarity scores **and**
raw embedding vectors alongside retrieved documents.  This enables
the Streamlit frontend to compute relevance gradients and the
``question_coefficient`` analysis mode.

"""

from typing import Any, Optional, List, Dict, Tuple, ClassVar, Collection

from langchain.schema import Document
from langchain_community.vectorstores.chroma import Chroma, DEFAULT_K
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever


class AdvancedVectorStoreRetriever(VectorStoreRetriever):
    """Retriever that can annotate documents with scores and embedding vectors.

    Adds a ``"similarity_with_embeddings"`` search type on top of LangChain's
    ``VectorStoreRetriever``.  With that search type, every returned
    document's ``metadata`` receives an ``__embeddings`` entry (the stored
    vector) and a ``__similarity`` entry (the score produced by the store's
    ``advanced_similarity_search``; for ``ChromaAdvancedRetrieval`` this is a
    raw distance).  The ``"similarity_score_threshold"`` search type is also
    handled here so that the relevance score lands in ``__similarity``.
    Existing metadata values under those keys are never overwritten.
    """

    allowed_search_types: ClassVar[Collection[str]] = (
        "similarity",
        "similarity_score_threshold",
        "mmr",
        "similarity_with_embeddings",
    )

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> List[Document]:
        """Retrieve documents for *query* according to ``self.search_type``.

        Args:
            query: Text to search for.
            run_manager: LangChain callback manager; only forwarded to the
                base implementation for the standard search types.

        Returns:
            list[Document]: The matching documents, annotated with
            ``__similarity`` / ``__embeddings`` metadata where the search
            type provides them.
        """
        mode = self.search_type

        if mode == "similarity_with_embeddings":
            # The vector store must implement ``advanced_similarity_search``
            # and yield (document, score, embedding) triples.
            triples = self.vectorstore.advanced_similarity_search(query, **self.search_kwargs)
            for document, score, vector in triples:
                document.metadata.setdefault("__embeddings", vector)
                document.metadata.setdefault("__similarity", score)
            return [document for document, _score, _vector in triples]

        if mode == "similarity_score_threshold":
            pairs = self.vectorstore.similarity_search_with_relevance_scores(query, **self.search_kwargs)
            for document, relevance in pairs:
                document.metadata.setdefault("__similarity", relevance)
            return [document for document, _relevance in pairs]

        # "similarity" and "mmr" are delegated unchanged to LangChain.
        return super()._get_relevant_documents(query, run_manager=run_manager)


class AdvancedVectorStore(VectorStore):
    """
    Extension of LangChain's VectorStore whose ``as_retriever`` returns an
    ``AdvancedVectorStoreRetriever`` (supporting the additional
    ``"similarity_with_embeddings"`` search type).
    """

    def as_retriever(self, **kwargs: Any) -> AdvancedVectorStoreRetriever:
        """Create a retriever supporting ``similarity_with_embeddings``.

        Accepts the same keyword arguments as the base ``as_retriever``
        (for example ``search_type`` and ``search_kwargs``).  Any ``tags``
        supplied by the caller are kept, and the store's own retriever tags
        are appended after them.

        Args:
            **kwargs: Forwarded to ``AdvancedVectorStoreRetriever``.

        Returns:
            AdvancedVectorStoreRetriever: A retriever bound to this store.
        """
        # Copy into a fresh list: extending the caller's list in place would
        # mutate their object and make it grow on every call.
        tags = list(kwargs.pop("tags", None) or [])
        tags.extend(self._get_retriever_tags())
        return AdvancedVectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)


class ChromaAdvancedRetrieval(Chroma, AdvancedVectorStore):
    """Chroma vector store with support for embeddings + similarity scores.

    Extends the standard LangChain ``Chroma`` store with
    ``advanced_similarity_search``, which returns ``(Document, score,
    embedding)`` triples.  The score is the raw value from Chroma's
    ``distances`` result field, not a normalized relevance score.
    """

    def __init__(self, **kwargs):
        # Keyword-only construction; everything is forwarded to ``Chroma``.
        super().__init__(**kwargs)

    @xor_args(("query_texts", "query_embeddings"))
    def __query_collection(
        self,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 4,
        where: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Query the underlying Chroma collection.

        Exactly one of *query_texts* / *query_embeddings* must be provided
        (enforced by ``xor_args``).  Extra keyword arguments such as
        ``include`` are passed straight through to ``Collection.query``.

        Returns:
            The raw result mapping returned by ``Collection.query``.

        Raises:
            ValueError: If the ``chromadb`` package cannot be imported.
        """
        try:
            import chromadb  # noqa: F401
        except ImportError as err:
            raise ValueError(
                "Could not import chromadb python package. Please install it with `pip install chromadb`."
            ) from err
        return self._collection.query(
            query_texts=query_texts,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            **kwargs,
        )

    def advanced_similarity_search(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Return documents, scores, and embeddings for *query*.

        Args:
            query: The search query.
            k: Number of results to return.
            filter: Optional Chroma metadata filter (``where`` clause).
            **kwargs: Forwarded to
                ``similarity_search_with_scores_and_embeddings``; in
                particular ``where_document`` is honoured (it was previously
                dropped here).

        Returns:
            list[tuple[Document, float, list[float]]]: Triples of
            (document, distance, embedding_vector).
        """
        return self.similarity_search_with_scores_and_embeddings(query, k, filter=filter, **kwargs)

    def similarity_search_with_scores_and_embeddings(
        self,
        query: str,
        k: int = DEFAULT_K,
        filter: Optional[Dict[str, str]] = None,
        where_document: Optional[Dict[str, str]] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float, List[float]]]:
        """Low-level search returning docs with scores and embeddings.

        Queries the Chroma collection requesting ``distances`` and
        ``embeddings`` in addition to the usual documents and metadata.

        Args:
            query: The search query.
            k: Number of results.
            filter: Optional metadata filter.
            where_document: Optional document-content filter.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            list[tuple[Document, float, list[float]]]: Triples of
            (document, distance, embedding_vector).
        """
        if self._embedding_function is None:
            # No LangChain embedder configured: hand Chroma the raw text and
            # let it embed the query itself.
            query_arg = {"query_texts": [query]}
        else:
            query_arg = {"query_embeddings": [self._embedding_function.embed_query(query)]}

        results = self.__query_collection(
            n_results=k,
            where=filter,
            where_document=where_document,
            include=["metadatas", "documents", "embeddings", "distances"],
            **query_arg,
        )
        return _results_to_docs_scores_and_embeddings(results)


def _results_to_docs_scores_and_embeddings(results: Any) -> List[Tuple[Document, float, List[float]]]:
    """Convert a raw Chroma query result into ``(Document, score, embedding)`` triples.

    Only the first query's result row (index ``0`` of each field) is read.
    A ``None`` metadata entry becomes an empty dict.  If the fields have
    different lengths, the output stops at the shortest one (``zip``
    semantics).

    Args:
        results: Mapping returned by ``Collection.query()`` with
            ``include=['documents', 'metadatas', 'distances', 'embeddings']``.

    Returns:
        list[tuple[Document, float, list[float]]]: One triple per hit, in
        the order Chroma returned them.
    """
    texts = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]
    vectors = results["embeddings"][0]

    triples = []
    for text, metadata, distance, vector in zip(texts, metadatas, distances, vectors):
        document = Document(page_content=text, metadata=metadata or {})
        triples.append((document, distance, vector))
    return triples