import json
import os
import pickle
from typing import List, Optional

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document


def load_local(vectorstore_dir: str, embed_model: HuggingFaceEmbeddings) -> tuple[Optional[FAISS], Optional[List[Document]]]:
    """

    Load the vectorstore and documents from disk.

    Args:

        vectorstore_dir: The directory to load the vectorstore from.

        embed_model: The embedding model to use.

    Returns:

        vector_store: The vectorstore.

    """

    if not os.path.isdir(vectorstore_dir):
        print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new one.")
        os.makedirs(vectorstore_dir, exist_ok=True)
        return None, None

    try:
        # FAISS indexes are pickled on disk, so deserialization must be explicitly
        # allowed; only load vectorstores you created yourself.
        vector_store = FAISS.load_local(vectorstore_dir, embed_model, allow_dangerous_deserialization=True)

        docs_path = os.path.join(vectorstore_dir, "docs.pkl")
        if os.path.exists(docs_path):
            with open(docs_path, "rb") as f:
                docs = pickle.load(f)
        else:
            docs = None
            print("Warning: docs.pkl not found. BM25 search will not be available.")

        print(f"Successfully loaded RAG state from {vectorstore_dir}")
        return vector_store, docs
    except Exception as e:
        print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
        return None, None

def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
    """

    Save the vectorstore and documents to disk.

    Args:

        vectorstore_dir: The directory to save the vectorstore to.

        vectorstore: The vectorstore to save.

        docs: The documents to save.

    """
    if vectorstore is None:
        raise ValueError("vectorstore is None; nothing to save.")
    if docs is None:
        print("Warning: No documents to save. BM25 search will not be available.")

    os.makedirs(vectorstore_dir, exist_ok=True)
    vectorstore.save_local(vectorstore_dir)
    
    if docs is not None:
        with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
            pickle.dump(docs, f)

    print(f"Successfully saved RAG state to {vectorstore_dir}")

def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
    """

    Load the QA dataset. (jsonl)

    Args:

        qa_dataset_path: The path to the QA dataset.

    Returns:

        Tuple: (ids, questions, options, answers)\\

        ids: The ids of the questions\\

        questions: The questions\\

        options: The options for each question\\

        answers: The answers for each question.

    """
    if not os.path.exists(qa_dataset_path):
        raise FileNotFoundError(f"Error: File not found at {qa_dataset_path}")
    
    with open(qa_dataset_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    questions = [item["question"] for item in data]
    # Build one formatted option string per question. Blank or missing options
    # are skipped; questions with no options at all get " ".
    option_labels = ["A", "B", "C", "D", "E"]
    options = []
    for item in data:
        parts = [
            f"{label}. {item[label]} \n"
            for label in option_labels
            if item.get(label) not in (" ", "", None)
        ]
        options.append("".join(parts) if parts else " ")
    answers = [item["answer"] for item in data]
    uuids = [item["uuid"] for item in data]
    return uuids, questions, options, answers
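
# Expected JSONL record shape for load_qa_dataset (an illustration inferred from
# the keys accessed above; the values are made up):
#
#   {"uuid": "q-001", "question": "What is 2 + 2?",
#    "A": "3", "B": "4", "C": "5", "D": "", "E": "",
#    "answer": "B"}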

def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
    """

    Load the prepared retrieve docs from a file.

    Args:

        prepared_retrieve_docs_path: The path to the prepared retrieve docs.

    Returns:

        A list of lists of documents.

    """
    return safe_load_langchain_docs(prepared_retrieve_docs_path)

def paralelize(func, max_workers: int = 4, **kwargs) -> List:
    """

    Parallelizes a function call over multiple keyword argument iterables.



    Args:

        func: The function to execute in parallel.

        max_workers: The maximum number of threads to use.

        **kwargs: Keyword arguments where each value is an iterable (e.g., a list).

                  All iterables must be of the same length.

                  The keyword names do not matter, but their order does.

    Returns:

        A list of the results of the function calls.

    """
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    if not kwargs:
        return []

    arg_lists = list(kwargs.values())
    if len(set(len(lst) for lst in arg_lists)) > 1:
        raise ValueError("All iterable arguments must have the same length.")
        
    total_items = len(arg_lists[0])
    iterable = zip(*arg_lists)
    # Unpack each zipped tuple of arguments into a positional call of func.
    def unpacker_func(args_tuple):
        return func(*args_tuple)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(unpacker_func, iterable), total=total_items))
    return results
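
# Example usage of paralelize (a sketch; `add` is a hypothetical worker function):
#
#   def add(a, b):
#       return a + b
#
#   results = paralelize(add, max_workers=4, a=[1, 2, 3], b=[10, 20, 30])
#   # results == [11, 22, 33]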

def safe_save_langchain_docs(documents: List[List[Document]], filepath: str):
    """

    Converts LangChain Document objects into a serializable list of dictionaries

    and saves them to a file using pickle.



    Args:

        documents (List[List[Document]]): The nested list of LangChain Documents.

        filepath (str): The path to the file where the data will be saved.

    """
    serializable_data = []
    print(f"Preparing to save {len(documents)} lists of documents...")
    
    # Convert each Document object into a dictionary
    for doc_list in documents:
        serializable_doc_list = []
        for doc in doc_list:
            serializable_doc_list.append({
                "page_content": doc.page_content,
                "metadata": doc.metadata,
            })
        serializable_data.append(serializable_doc_list)

    print(f"Conversion complete. Saving to {filepath}...")
    try:
        # Use 'with' to ensure the file is closed properly, even if errors occur
        with open(filepath, "wb") as f:
            pickle.dump(serializable_data, f)
        print("File saved successfully.")
    except Exception as e:
        print(f"An error occurred while saving the file: {e}")

def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
    """

    Loads data from a pickle file and reconstructs the LangChain Document objects.



    Args:

        filepath (str): The path to the file to load.



    Returns:

        List[List[Document]]: The reconstructed nested list of LangChain Documents.

    """
    reconstructed_documents = []
    
    print(f"Loading data from {filepath}...")
    try:
        with open(filepath, "rb") as f:
            loaded_data = pickle.load(f)
        print("File loaded successfully. Reconstructing Document objects...")

        # Reconstruct the Document objects from the dictionaries
        for doc_list_data in loaded_data:
            reconstructed_doc_list = []
            for doc_data in doc_list_data:
                reconstructed_doc_list.append(
                    Document(
                        page_content=doc_data["page_content"],
                        metadata=doc_data["metadata"]
                    )
                )
            reconstructed_documents.append(reconstructed_doc_list)
        
        print("Document objects reconstructed successfully.")
        return reconstructed_documents

    except FileNotFoundError:
        print(f"Error: The file at {filepath} was not found.")
        return []
    except EOFError:
        print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
        print("Please re-run the script that generates this file.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred while loading the file: {e}")
        return []
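
# Example round trip for safe_save_langchain_docs / safe_load_langchain_docs
# (a sketch; the path and document contents are made up):
#
#   docs = [[Document(page_content="hello", metadata={"source": "a.txt"})]]
#   safe_save_langchain_docs(docs, "retrieved_docs.pkl")
#   restored = safe_load_langchain_docs("retrieved_docs.pkl")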