import json
import os
import pickle
from typing import List, Optional

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document


def load_local(vectorstore_dir: str, embed_model: HuggingFaceEmbeddings) -> tuple[Optional[FAISS], Optional[List[Document]]]:
    """
    Load the vectorstore and documents from disk.

    Args:
        vectorstore_dir: The directory to load the vectorstore from.
        embed_model: The embedding model to use.

    Returns:
        A (vector_store, docs) tuple. Either element is None if it could not
        be loaded.
    """
    if not os.path.isdir(vectorstore_dir):
        print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new one.")
        os.makedirs(vectorstore_dir, exist_ok=True)

    try:
        vector_store = FAISS.load_local(vectorstore_dir, embed_model, allow_dangerous_deserialization=True)

        # The raw documents are stored alongside the FAISS index so that a
        # BM25 retriever can be rebuilt from them.
        docs_path = os.path.join(vectorstore_dir, "docs.pkl")
        if os.path.exists(docs_path):
            with open(docs_path, "rb") as f:
                docs = pickle.load(f)
        else:
            docs = None
            print("Warning: docs.pkl not found. BM25 search will not be available.")

        print(f"Successfully loaded RAG state from {vectorstore_dir}")
        return vector_store, docs
    except Exception as e:
        print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
        return None, None
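

# Example usage (a sketch; the directory name and embedding model below are
# illustrative, not part of this module):
#
#   embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   vector_store, docs = load_local("vectorstore", embed_model)
#   if vector_store is None:
#       ...  # build the index from scratch instead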


def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
    """
    Save the vectorstore and documents to disk.

    Args:
        vectorstore_dir: The directory to save the vectorstore to.
        vectorstore: The vectorstore to save.
        docs: The documents to save.
    """
    if vectorstore is None:
        raise ValueError("Nothing to save: vectorstore is None.")
    if docs is None:
        print("Warning: No documents to save. BM25 search will not be available.")

    os.makedirs(vectorstore_dir, exist_ok=True)
    vectorstore.save_local(vectorstore_dir)

    # Persist the raw documents next to the FAISS index so load_local can
    # restore them for BM25 retrieval.
    if docs is not None:
        with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
            pickle.dump(docs, f)

    print(f"Successfully saved RAG state to {vectorstore_dir}")
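

# Example round trip with load_local (a sketch; assumes `vector_store` and
# `docs` were produced by an indexing step elsewhere, and `embed_model` as in
# the load_local example above):
#
#   save_local("vectorstore", vector_store, docs)
#   restored_store, restored_docs = load_local("vectorstore", embed_model)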


def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
    """
    Load the QA dataset (JSONL).

    Args:
        qa_dataset_path: The path to the QA dataset.

    Returns:
        Tuple: (ids, questions, options, answers)
            ids: The ids of the questions.
            questions: The questions.
            options: The formatted options for each question.
            answers: The answers for each question.
    """
    if not os.path.exists(qa_dataset_path):
        raise FileNotFoundError(f"Error: File not found at {qa_dataset_path}")

    with open(qa_dataset_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    questions = [item["question"] for item in data]
    try:
        # Render the multiple-choice options as one "A. ... B. ..." block,
        # skipping letters whose value is blank.
        options = [
            "".join(
                f"{letter}. {item[letter]} \n"
                for letter in ("A", "B", "C", "D", "E")
                if item[letter] not in (" ", "", None)
            )
            for item in data
        ]
    except KeyError:
        # The dataset has no option fields (open-ended questions).
        options = [" " for _ in data]
    answers = [item["answer"] for item in data]
    uuids = [item["uuid"] for item in data]
    return uuids, questions, options, answers
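

# Each JSONL line is expected to look like the following (a sketch inferred
# from the field accesses above; the "A"-"E" keys may be absent entirely for
# open-ended datasets):
#
#   {"uuid": "q-001", "question": "What is ...?", "A": "...", "B": "...",
#    "C": "...", "D": "...", "E": "", "answer": "A"}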


def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
    """
    Load the prepared retrieve docs from a file.

    Args:
        prepared_retrieve_docs_path: The path to the prepared retrieve docs.

    Returns:
        A list of lists of documents.
    """
    return safe_load_langchain_docs(prepared_retrieve_docs_path)


def paralelize(func, max_workers: int = 4, **kwargs) -> List:
    """
    Parallelizes a function call over multiple keyword argument iterables.

    Args:
        func: The function to execute in parallel.
        max_workers: The maximum number of threads to use.
        **kwargs: Keyword arguments where each value is an iterable (e.g., a list).
            All iterables must be of the same length.
            The keyword names do not matter, but their order does: the i-th
            elements of the iterables are passed to func positionally, in the
            order the keywords were given.

    Returns:
        A list of the results of the function calls, in input order.
    """
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    if not kwargs:
        return []

    arg_lists = list(kwargs.values())
    if len(set(len(lst) for lst in arg_lists)) > 1:
        raise ValueError("All iterable arguments must have the same length.")

    total_items = len(arg_lists[0])
    iterable = zip(*arg_lists)

    def unpack_and_call(args_tuple):
        return func(*args_tuple)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(unpack_and_call, iterable), total=total_items))
    return results
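

# Example usage (a sketch with made-up inputs):
#
#   def ask(question, option_block):
#       return f"{question}\n{option_block}"
#
#   answers = paralelize(ask, max_workers=8,
#                        questions=["Q1?", "Q2?"],
#                        options=["A. x\nB. y\n", "A. p\nB. q\n"])
#   # -> ["Q1?\nA. x\nB. y\n", "Q2?\nA. p\nB. q\n"]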


def safe_save_langchain_docs(documents: List[List[Document]], filepath: str):
    """
    Converts LangChain Document objects into a serializable list of dictionaries
    and saves them to a file using pickle.

    Args:
        documents (List[List[Document]]): The nested list of LangChain Documents.
        filepath (str): The path to the file where the data will be saved.
    """
    serializable_data = []
    print(f"Preparing to save {len(documents)} lists of documents...")

    # Convert each Document to a plain dict so the pickle does not depend on
    # the langchain class definition being importable at load time.
    for doc_list in documents:
        serializable_doc_list = []
        for doc in doc_list:
            serializable_doc_list.append({
                "page_content": doc.page_content,
                "metadata": doc.metadata,
            })
        serializable_data.append(serializable_doc_list)

    print(f"Conversion complete. Saving to {filepath}...")
    try:
        with open(filepath, "wb") as f:
            pickle.dump(serializable_data, f)
        print("File saved successfully.")
    except Exception as e:
        print(f"An error occurred while saving the file: {e}")


def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
    """
    Loads data from a pickle file and reconstructs the LangChain Document objects.

    Args:
        filepath (str): The path to the file to load.

    Returns:
        List[List[Document]]: The reconstructed nested list of LangChain
        Documents, or an empty list if the file is missing or unreadable.
    """
    reconstructed_documents = []

    print(f"Loading data from {filepath}...")
    try:
        with open(filepath, "rb") as f:
            loaded_data = pickle.load(f)
        print("File loaded successfully. Reconstructing Document objects...")

        # Rebuild Document objects from the plain dicts written by
        # safe_save_langchain_docs.
        for doc_list_data in loaded_data:
            reconstructed_doc_list = []
            for doc_data in doc_list_data:
                reconstructed_doc_list.append(
                    Document(
                        page_content=doc_data["page_content"],
                        metadata=doc_data["metadata"],
                    )
                )
            reconstructed_documents.append(reconstructed_doc_list)

        print("Document objects reconstructed successfully.")
        return reconstructed_documents

    except FileNotFoundError:
        print(f"Error: The file at {filepath} was not found.")
        return []
    except EOFError:
        print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
        print("Please re-run the script that generates this file.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred while loading the file: {e}")
        return []
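

# Example round trip (a sketch; the path is illustrative):
#
#   docs_per_query = [[Document(page_content="hello", metadata={"source": "a"})]]
#   safe_save_langchain_docs(docs_per_query, "retrieved_docs.pkl")
#   restored = safe_load_langchain_docs("retrieved_docs.pkl")
#   assert restored[0][0].page_content == "hello"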