File size: 5,259 Bytes
c8e875f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Vector storage and retrieval implementation.
"""
import uuid
from typing import List, Any

from langchain_chroma                  import Chroma
from langchain.storage                 import InMemoryStore
from langchain.schema.document         import Document
from langchain_huggingface             import HuggingFaceEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

from src.config import EMBEDDING_MODEL, DEVICE, COLLECTION_NAME


class VectorStore:
    """
    Summary-indexed storage and retrieval built on a multi-vector retriever.

    Summaries are embedded and stored in a Chroma collection, while the
    original data elements are kept in an in-memory document store. Both
    sides are linked through a generated id stored under ``id_key`` in the
    summary metadata, so a query matches on summaries and returns the
    original elements.
    """

    def __init__(self, collection_name: str = COLLECTION_NAME, embedding_model: str = EMBEDDING_MODEL):
        """
        Initialize the vector store.

        Args:
            collection_name (str): Name of the vector store collection
            embedding_model (str): Name of the embedding model to use
        """
        # Order matters: the vector store needs the embedding function, and
        # the retriever needs the vector store, the doc store and the id key.
        self.embedding_function = self._create_embedding_function(embedding_model)
        self.vector_store       = self._create_vector_store(collection_name)
        self.doc_store          = InMemoryStore()
        self.id_key             = 'doc_id'
        self.retriever          = self._create_retriever()

    def _create_embedding_function(self, model_name: str) -> HuggingFaceEmbeddings:
        """
        Create an embedding function.

        Args:
            model_name (str): Name of the embedding model

        Returns:
            HuggingFaceEmbeddings: The embedding function, running on the
                configured ``DEVICE`` with embedding normalization enabled
        """
        return HuggingFaceEmbeddings(
            model_name    = model_name,
            model_kwargs  = {'device': DEVICE},
            encode_kwargs = {'normalize_embeddings': True} # Change this if use an already normalized model
        )

    def _create_vector_store(self, collection_name: str) -> Chroma:
        """
        Create a vector store.

        Args:
            collection_name (str): Name of the vector store collection

        Returns:
            Chroma: The vector store, bound to this instance's embedding function
        """
        return Chroma(
            collection_name    = collection_name,
            embedding_function = self.embedding_function,
        )

    def _create_retriever(self) -> MultiVectorRetriever:
        """
        Create a multi-vector retriever.

        Returns:
            MultiVectorRetriever: Retriever wired to the current vector store,
                document store and id key
        """
        return MultiVectorRetriever(
            vectorstore = self.vector_store,
            docstore    = self.doc_store,
            id_key      = self.id_key,
        )

    def add_to_retriever(self, data: List[Any], data_summaries: List[str]) -> None:
        """
        Add data and summaries to the retriever.

        Each data element is paired with the summary at the same position.
        The summary is stored in the vector store and the data element in
        the document store, both under the same freshly generated id.

        Args:
            data           (List[Any]): List of data elements
            data_summaries (List[str]): List of data summaries

        Raises:
            ValueError: If ``data`` is non-empty and its length differs from
                the length of ``data_summaries``
        """
        # Nothing to index; lets callers pass empty content types freely.
        if not data:
            return

        if len(data) != len(data_summaries):
            raise ValueError(f"Length mismatch: {len(data)} data but {len(data_summaries)} summaries")

        # One random id per element, shared by the summary and the raw data.
        ids = [str(uuid.uuid4()) for _ in data]

        summaries = [
            Document(
                page_content = f"passage: {summary}", # Change this to suit with model requirements if use a different model
                metadata     = {self.id_key: doc_id}
            )
            for doc_id, summary in zip(ids, data_summaries)
        ]

        self.retriever.vectorstore.add_documents(summaries)
        self.retriever.docstore.mset(list(zip(ids, data)))

    def add_contents(self,
                     texts : List[Any], text_summaries : List[str],
                     tables: List[Any], table_summaries: List[str],
                     images: List[Any], image_summaries: List[str]) -> None:
        """
        Add all content types and their summaries to the retriever.

        Content types are added in the order texts, tables, images. If a
        later pair fails validation, earlier pairs have already been added.

        Args:
            texts           (List[Any]): List of text elements
            text_summaries  (List[str]): List of text summaries
            tables          (List[Any]): List of table elements
            table_summaries (List[str]): List of table summaries
            images          (List[Any]): List of image elements
            image_summaries (List[str]): List of image summaries

        Raises:
            ValueError: If any non-empty content list and its summaries
                differ in length
        """
        self.add_to_retriever(texts , text_summaries)
        self.add_to_retriever(tables, table_summaries)
        self.add_to_retriever(images, image_summaries)

    def reset(self) -> None:
        """
        Reset the vector store and document store.

        The vector store collection is reset in place, the document store is
        replaced with a fresh in-memory store, and the retriever is rebuilt
        so that it points at the new document store.

        Raises:
            RuntimeError: If resetting the vector store collection fails; the
                underlying exception is attached as the cause
        """
        try:
            self.vector_store.reset_collection()
        except Exception as e:
            # Chain explicitly so the original failure stays visible to callers.
            raise RuntimeError(f"Failed to reset vector store: {e}") from e

        self.doc_store = InMemoryStore()
        self.retriever = self._create_retriever()

    def retrieve(self, query: str) -> List[Any]:
        """
        Retrieve relevant documents for a query.

        Args:
            query (str): The query string

        Returns:
            List[Any]: List of retrieved documents
        """
        # NOTE(review): summaries are indexed with a "passage: " prefix but the
        # query is passed through unchanged; confirm whether the embedding
        # model expects a matching query-side prefix.
        return self.retriever.invoke(query)