import json
import os
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from tqdm import tqdm

def load_local(vectorstore_dir: str, embed_model: HuggingFaceEmbeddings) -> tuple[Optional[FAISS], Optional[List[Document]]]:
    """
    Load the vectorstore and documents from disk.
    Args:
        vectorstore_dir: The directory to load the vectorstore from.
        embed_model: The embedding model to use.
    Returns:
        A (vector_store, docs) tuple; either element is None if it could
        not be loaded.
    """
    if not os.path.isdir(vectorstore_dir):
        print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new, empty one.")
        os.makedirs(vectorstore_dir, exist_ok=True)
        # A freshly created directory has nothing to load.
        return None, None
try:
vector_store = FAISS.load_local(vectorstore_dir, embed_model, allow_dangerous_deserialization=True)
docs_path = os.path.join(vectorstore_dir, "docs.pkl")
if os.path.exists(docs_path):
with open(docs_path, "rb") as f:
docs = pickle.load(f)
else:
docs = None
print("Warning: docs.pkl not found. BM25 search will not be available.")
print(f"Successfully loaded RAG state from {vectorstore_dir}")
return vector_store, docs
except Exception as e:
print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
return None, None
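
# Illustrative usage sketch for load_local. The directory path and model name
# below are example assumptions, not values defined in this module:
#
#     embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#     vector_store, docs = load_local("rag_state", embed_model)
#     if vector_store is not None:
#         hits = vector_store.similarity_search("example query", k=3)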
def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
"""
Save the vectorstore and documents to disk.
Args:
vectorstore_dir: The directory to save the vectorstore to.
vectorstore: The vectorstore to save.
docs: The documents to save.
"""
if vectorstore is None:
        raise ValueError("vectorstore is None; nothing to save.")
if docs is None:
print("Warning: No documents to save. BM25 search will not be available.")
os.makedirs(vectorstore_dir, exist_ok=True)
vectorstore.save_local(vectorstore_dir)
if docs is not None:
with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
pickle.dump(docs, f)
print(f"Successfully saved RAG state to {vectorstore_dir}")
def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
    """
    Load a QA dataset from a JSONL file.
    Args:
        qa_dataset_path: The path to the QA dataset.
    Returns:
        A (uuids, questions, options, answers) tuple:
            uuids: The id of each question.
            questions: The question texts.
            options: A formatted option string per question ("A. ..." through
                "E. ..."), or " " if the dataset has no option fields.
            answers: The answer for each question.
    """
    if not os.path.exists(qa_dataset_path):
        raise FileNotFoundError(f"QA dataset not found at {qa_dataset_path}")
with open(qa_dataset_path, "r", encoding="utf-8") as f:
data = [json.loads(line) for line in f]
questions = [item["question"] for item in data]
    try:
        options = [
            "".join(
                f"{letter}. {item[letter]} \n"
                for letter in "ABCDE"
                if item[letter] not in (" ", "", None)
            )
            for item in data
        ]
    except KeyError:
        # No option fields present: not a multiple-choice dataset.
        options = [" "] * len(data)
answers = [item["answer"] for item in data]
uuids = [item["uuid"] for item in data]
return uuids, questions, options, answers
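
# Sketch of one JSONL record this loader expects (field values invented for
# illustration; the "A".."E" option fields may be empty or absent entirely):
#
#     {"uuid": "q-001", "question": "What is 2 + 2?",
#      "A": "3", "B": "4", "C": "5", "D": "", "E": "",
#      "answer": "B"}
#
#     uuids, questions, options, answers = load_qa_dataset("qa.jsonl")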
def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
"""
Load the prepared retrieve docs from a file.
Args:
prepared_retrieve_docs_path: The path to the prepared retrieve docs.
Returns:
A list of lists of documents.
"""
return safe_load_langchain_docs(prepared_retrieve_docs_path)
def parallelize(func, max_workers: int = 4, **kwargs) -> List:
    """
    Parallelizes a function call over multiple keyword argument iterables.
    Args:
        func: The function to execute in parallel.
        max_workers: The maximum number of threads to use.
        **kwargs: Keyword arguments whose values are equal-length iterables
            (e.g., lists). The keyword names are ignored; the values are
            passed to func positionally, in the order the keyword arguments
            are given.
    Returns:
        A list of the results of the function calls, in input order.
    """
    if not kwargs:
        return []
    arg_lists = list(kwargs.values())
    if len({len(lst) for lst in arg_lists}) > 1:
        raise ValueError("All iterable arguments must have the same length.")
    total_items = len(arg_lists[0])
    iterable = zip(*arg_lists)

    def unpacker(args_tuple):
        # Unpack each tuple of per-call arguments into a positional call.
        return func(*args_tuple)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(unpacker, iterable), total=total_items))
    return results
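
# Illustrative usage sketch for parallelize (worker function and inputs are
# example assumptions). The two lists are zipped element-wise and passed
# positionally, so add(1, 10), add(2, 20), add(3, 30) run across threads:
#
#     def add(x, y):
#         return x + y
#
#     results = parallelize(add, max_workers=2, xs=[1, 2, 3], ys=[10, 20, 30])
#     # results == [11, 22, 33]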
def safe_save_langchain_docs(documents: List[List[Document]], filepath: str):
"""
Converts LangChain Document objects into a serializable list of dictionaries
and saves them to a file using pickle.
Args:
documents (List[List[Document]]): The nested list of LangChain Documents.
filepath (str): The path to the file where the data will be saved.
"""
serializable_data = []
print(f"Preparing to save {len(documents)} lists of documents...")
# Convert each Document object into a dictionary
for doc_list in documents:
serializable_doc_list = []
for doc in doc_list:
serializable_doc_list.append({
"page_content": doc.page_content,
"metadata": doc.metadata,
})
serializable_data.append(serializable_doc_list)
print(f"Conversion complete. Saving to {filepath}...")
try:
# Use 'with' to ensure the file is closed properly, even if errors occur
with open(filepath, "wb") as f:
pickle.dump(serializable_data, f)
print("File saved successfully.")
except Exception as e:
print(f"An error occurred while saving the file: {e}")
def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
"""
Loads data from a pickle file and reconstructs the LangChain Document objects.
Args:
filepath (str): The path to the file to load.
Returns:
List[List[Document]]: The reconstructed nested list of LangChain Documents.
"""
reconstructed_documents = []
print(f"Loading data from {filepath}...")
try:
with open(filepath, "rb") as f:
loaded_data = pickle.load(f)
print("File loaded successfully. Reconstructing Document objects...")
# Reconstruct the Document objects from the dictionaries
for doc_list_data in loaded_data:
reconstructed_doc_list = []
for doc_data in doc_list_data:
reconstructed_doc_list.append(
Document(
page_content=doc_data["page_content"],
metadata=doc_data["metadata"]
)
)
reconstructed_documents.append(reconstructed_doc_list)
print("Document objects reconstructed successfully.")
return reconstructed_documents
except FileNotFoundError:
print(f"Error: The file at {filepath} was not found.")
return []
except EOFError:
print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
print("Please re-run the script that generates this file.")
return []
except Exception as e:
print(f"An unexpected error occurred while loading the file: {e}")
return [] |
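
# Illustrative round trip for the two helpers above (the path is an example
# assumption). Storing plain dicts rather than pickled Document objects keeps
# the file loadable even if the installed langchain version changes:
#
#     nested = [[Document(page_content="chunk 1", metadata={"page": 1})]]
#     safe_save_langchain_docs(nested, "retrieved_docs.pkl")
#     restored = safe_load_langchain_docs("retrieved_docs.pkl")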