import json
import os
import pickle
from typing import List, Optional

from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

def load_local(
    vectorstore_dir: str, embed_model: HuggingFaceEmbeddings
) -> tuple[Optional[FAISS], Optional[List[Document]]]:
    """
    Load the vectorstore and documents from disk.

    Args:
        vectorstore_dir: The directory to load the vectorstore from.
        embed_model: The embedding model to use.

    Returns:
        Tuple: (vector_store, docs)
            vector_store: The loaded FAISS vectorstore, or None on failure.
            docs: The documents for BM25 search, or None if docs.pkl is missing.
    """
    if not os.path.isdir(vectorstore_dir):
        print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new one.")
        os.makedirs(vectorstore_dir, exist_ok=True)
    try:
        # Loading from a freshly created (empty) directory raises and falls
        # through to the except branch, returning (None, None).
        vector_store = FAISS.load_local(
            vectorstore_dir, embed_model, allow_dangerous_deserialization=True
        )
        docs_path = os.path.join(vectorstore_dir, "docs.pkl")
        if os.path.exists(docs_path):
            with open(docs_path, "rb") as f:
                docs = pickle.load(f)
        else:
            docs = None
            print("Warning: docs.pkl not found. BM25 search will not be available.")
        print(f"Successfully loaded RAG state from {vectorstore_dir}")
        return vector_store, docs
    except Exception as e:
        print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
        return None, None

def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
    """
    Save the vectorstore and documents to disk.

    Args:
        vectorstore_dir: The directory to save the vectorstore to.
        vectorstore: The vectorstore to save.
        docs: The documents to save.
    """
    if vectorstore is None:
        raise ValueError("Nothing to save.")
    if docs is None:
        print("Warning: No documents to save. BM25 search will not be available.")
    os.makedirs(vectorstore_dir, exist_ok=True)
    vectorstore.save_local(vectorstore_dir)
    if docs is not None:
        with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
            pickle.dump(docs, f)
    print(f"Successfully saved RAG state to {vectorstore_dir}")

def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
    """
    Load the QA dataset (JSONL).

    Args:
        qa_dataset_path: The path to the QA dataset.

    Returns:
        Tuple: (uuids, questions, options, answers)
            uuids: The ids of the questions.
            questions: The questions.
            options: The formatted options string for each question.
            answers: The answers for each question.
    """
    if not os.path.exists(qa_dataset_path):
        raise FileNotFoundError(f"Error: File not found at {qa_dataset_path}")
    with open(qa_dataset_path, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f]
    questions = [item["question"] for item in data]
    try:
        # Concatenate the non-empty options into one "A. ...\nB. ..." string.
        options = [
            (f"A. {item['A']} \n" if item["A"] not in [" ", "", None] else "")
            + (f"B. {item['B']} \n" if item["B"] not in [" ", "", None] else "")
            + (f"C. {item['C']} \n" if item["C"] not in [" ", "", None] else "")
            + (f"D. {item['D']} \n" if item["D"] not in [" ", "", None] else "")
            + (f"E. {item['E']} \n" if item["E"] not in [" ", "", None] else "")
            for item in data
        ]
    except KeyError:
        # Datasets without option columns (open-ended questions).
        options = [" " for _ in data]
    answers = [item["answer"] for item in data]
    uuids = [item["uuid"] for item in data]
    return uuids, questions, options, answers
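
# Illustrative JSONL record for load_qa_dataset (field names are taken from the
# accessors above; the values are made up):
#
#     {"uuid": "q-001", "question": "Which option is correct?",
#      "A": "First", "B": "Second", "C": "Third", "D": "", "E": "",
#      "answer": "B"}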

def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
    """
    Load the prepared retrieve docs from a file.

    Args:
        prepared_retrieve_docs_path: The path to the prepared retrieve docs.

    Returns:
        A list of lists of documents.
    """
    return safe_load_langchain_docs(prepared_retrieve_docs_path)

def parallelize(func, max_workers: int = 4, **kwargs) -> List:
    """
    Parallelize a function call over multiple keyword argument iterables.

    Args:
        func: The function to execute in parallel.
        max_workers: The maximum number of threads to use.
        **kwargs: Keyword arguments where each value is an iterable (e.g., a list).
            All iterables must be of the same length. The keyword names do not
            matter, but their order does: the values are passed to func
            positionally.

    Returns:
        A list of the results of the function calls, in input order.
    """
    from concurrent.futures import ThreadPoolExecutor

    from tqdm import tqdm

    if not kwargs:
        return []
    arg_lists = list(kwargs.values())
    if len(set(len(lst) for lst in arg_lists)) > 1:
        raise ValueError("All iterable arguments must have the same length.")
    total_items = len(arg_lists[0])
    iterable = zip(*arg_lists)

    def unpacker_func(args_tuple):
        # Unpack one tuple of per-call arguments into func's positional params.
        return func(*args_tuple)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(tqdm(executor.map(unpacker_func, iterable), total=total_items))
    return results
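
# Hedged usage sketch: answer_one below is a hypothetical worker, not something
# this module defines. Only the order of the keyword arguments matters; each
# zipped tuple is unpacked into the worker's positional parameters.
#
#     def answer_one(question: str, option: str) -> str:
#         return question + "\n" + option
#
#     prompts = parallelize(answer_one, max_workers=8,
#                           questions=questions, options=options)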

def safe_save_langchain_docs(documents: List[List[Document]], filepath: str) -> None:
    """
    Converts LangChain Document objects into a serializable list of dictionaries
    and saves them to a file using pickle.

    Args:
        documents (List[List[Document]]): The nested list of LangChain Documents.
        filepath (str): The path to the file where the data will be saved.
    """
    serializable_data = []
    print(f"Preparing to save {len(documents)} lists of documents...")
    # Convert each Document object into a plain dictionary.
    for doc_list in documents:
        serializable_doc_list = []
        for doc in doc_list:
            serializable_doc_list.append(
                {
                    "page_content": doc.page_content,
                    "metadata": doc.metadata,
                }
            )
        serializable_data.append(serializable_doc_list)
    print(f"Conversion complete. Saving to {filepath}...")
    try:
        # Use 'with' to ensure the file is closed properly, even if errors occur.
        with open(filepath, "wb") as f:
            pickle.dump(serializable_data, f)
        print("File saved successfully.")
    except Exception as e:
        print(f"An error occurred while saving the file: {e}")

def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
    """
    Loads data from a pickle file and reconstructs the LangChain Document objects.

    Args:
        filepath (str): The path to the file to load.

    Returns:
        List[List[Document]]: The reconstructed nested list of LangChain Documents.
    """
    reconstructed_documents = []
    print(f"Loading data from {filepath}...")
    try:
        with open(filepath, "rb") as f:
            loaded_data = pickle.load(f)
        print("File loaded successfully. Reconstructing Document objects...")
        # Reconstruct the Document objects from the dictionaries.
        for doc_list_data in loaded_data:
            reconstructed_doc_list = []
            for doc_data in doc_list_data:
                reconstructed_doc_list.append(
                    Document(
                        page_content=doc_data["page_content"],
                        metadata=doc_data["metadata"],
                    )
                )
            reconstructed_documents.append(reconstructed_doc_list)
        print("Document objects reconstructed successfully.")
        return reconstructed_documents
    except FileNotFoundError:
        print(f"Error: The file at {filepath} was not found.")
        return []
    except EOFError:
        print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
        print("Please re-run the script that generates this file.")
        return []
    except Exception as e:
        print(f"An unexpected error occurred while loading the file: {e}")
        return []
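
if __name__ == "__main__":
    # Minimal self-check, a sketch rather than part of the library: round-trip
    # a tiny nested list of Documents through the safe save/load helpers. The
    # file name is a throwaway assumption for illustration.
    _demo_path = "demo_docs.pkl"
    _demo_docs = [[Document(page_content="hello", metadata={"source": "demo"})]]
    safe_save_langchain_docs(_demo_docs, _demo_path)
    _restored = safe_load_langchain_docs(_demo_path)
    assert _restored[0][0].page_content == "hello"
    os.remove(_demo_path)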