File size: 4,970 Bytes
570b60c
 
 
 
 
5b68ef9
570b60c
 
 
 
 
 
 
 
 
9d1b8d4
 
570b60c
 
9d1b8d4
 
 
 
 
 
 
570b60c
 
9d1b8d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b68ef9
570b60c
9d1b8d4
570b60c
 
9d1b8d4
 
 
 
 
 
 
570b60c
 
5b68ef9
570b60c
 
 
 
5b68ef9
570b60c
 
5b68ef9
570b60c
 
 
 
 
 
5b68ef9
570b60c
 
 
9d1b8d4
 
570b60c
 
5b68ef9
570b60c
9d1b8d4
 
 
570b60c
9d1b8d4
570b60c
 
5b68ef9
 
9d1b8d4
 
 
570b60c
9d1b8d4
 
570b60c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from langchain_community.vectorstores import FAISS
# from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from huggingface_hub import snapshot_download
import logging
log = logging.getLogger(__name__)

from termcolor import cprint


class VectorStore:
    """FAISS-backed retrieval helper: loads a vectorstore and formats retrieved
    documents into a single context string for prompting."""

    def __init__(self, 
                 embeddings_model: str, 
                 vs_local_path: str = None, 
                 vs_hf_path: str = None, 

                 # Retrieval parameters
                 number_of_contexts: int = 2,
                 embedding_score_threshold: float = None, 
                 
                 # Context formatting parameters
                 context_fmt: str = "Context document {num_document}:\n{document_content}",
                 join_str: str = "\n\n",
                 header_context_str: str = "",
                 footer_context_str: str = "",
                 no_context_str: str = "Answer 'no relevant context found'.",
                 ):
        
        """Initializes the VectorStore with the given parameters and loads the vectorstore from the specified path.

        Arguments:
        ----------
        embeddings_model : str 
            The name of the HuggingFace embeddings model to use.
        vs_local_path : str, optional
            Local path to the vectorstore. Defaults to None.    
        vs_hf_path : str, optional
            HuggingFace Hub path to the vectorstore. Takes precedence over
            vs_local_path when both are given. Defaults to None.
        number_of_contexts : int, optional
            Number of top similar contexts to retrieve. Defaults to 2.
        embedding_score_threshold : float, optional
            Minimum similarity score threshold for retrieved documents. Defaults to None.
        context_fmt : str, optional
            Template to format each retrieved document. 
            Use only {document_content} or both {num_document} and {document_content} placeholders. 
            Defaults to "Context document {num_document}:\n{document_content}".
        join_str : str, optional
            String to join multiple retrieved documents. Defaults to "\n\n".
        header_context_str : str, optional
            String to prepend to the final context. Defaults to "" (no header).
        footer_context_str : str, optional
            String to append to the final context. Defaults to "" (no footer).
        no_context_str : str, optional
            String to return if no documents are retrieved. 
            Defaults to "Answer 'no relevant context found'.".

        Raises:
        -------
        ValueError
            If neither vs_local_path nor vs_hf_path is provided.
        """
        
        # Fail fast with a clear message instead of letting FAISS.load_local
        # crash on a None path later.
        if not vs_hf_path and not vs_local_path:
            raise ValueError("Either vs_local_path or vs_hf_path must be provided.")

        log.info("Loading vectorstore...")

        # Retrieval parameters
        self.number_of_contexts = number_of_contexts
        self.embedding_score_threshold = embedding_score_threshold
        
        # Context formatting parameters
        self.context_fmt = context_fmt
        self.join_str = join_str
        self.header_context_str = header_context_str
        self.footer_context_str = footer_context_str
        self.no_context_str = no_context_str

        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model)
        log.info(f"Loaded embeddings model: {embeddings_model}")

        if vs_hf_path:
            # Download the vectorstore snapshot from the HuggingFace Hub, then
            # load it from the local cache directory returned by snapshot_download.
            hf_vectorstore = snapshot_download(repo_id=vs_hf_path)
            self.vdb = FAISS.load_local(hf_vectorstore, embeddings, allow_dangerous_deserialization=True)
            log.info(f"Loaded vectorstore from {vs_hf_path}")
        else:
            self.vdb = FAISS.load_local(vs_local_path, embeddings, allow_dangerous_deserialization=True)
            log.info(f"Loaded vectorstore from {vs_local_path}")

    
    def get_context(self, query):
        """Retrieve the top-k most similar documents for `query` and return them
        as a single formatted context string.

        Arguments:
        ----------
        query : str
            The query text to embed and search with.

        Returns:
        --------
        str
            The formatted context (or `no_context_str` if nothing is retrieved).
        """

        # Retrieve (document, relevance_score) pairs above the configured threshold.
        results = self.vdb.similarity_search_with_relevance_scores(query=query, k=self.number_of_contexts, score_threshold=self.embedding_score_threshold)
        log.info(f"Retrieved {len(results)} documents from the vectorstore.")

        # Return formatted context
        return self._beautiful_context(results)
    
    
    def _beautiful_context(self, docs):
        """Format a list of (document, score) pairs into one context string.

        Each document is rendered with `context_fmt` (1-based numbering), the
        pieces are joined with `join_str`, and the result is wrapped with
        `header_context_str` / `footer_context_str`.

        Arguments:
        ----------
        docs : list[tuple]
            (document, relevance_score) pairs; each document exposes `.page_content`.

        Returns:
        --------
        str
            The assembled context, or `no_context_str` if `docs` is empty.
        """
        
        log.info(f"Formatting {len(docs)} contexts...")

        # If no documents are retrieved, return the no_context_str
        if not docs:
            return self.no_context_str
        
        contexts = []
        for i, doc in enumerate(docs):
            
            log.info(f"Document {i+1} (score: {doc[1]:.4f}): {repr(doc[0].page_content[:100])}...")
            
            # Format each context document using the provided template
            context = self.context_fmt.format(num_document=i + 1, document_content=doc[0].page_content)
            contexts.append(context)

        # Join all contexts into a single string and add header and footer
        context = self.header_context_str + self.join_str.join(contexts) + self.footer_context_str
        

        return context