# MedLLM-Assistant / utils.py
import json
import os
import pickle
from typing import List, Optional
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
def load_local(vectorstore_dir: str, embed_model: HuggingFaceEmbeddings) -> tuple[Optional[FAISS], Optional[List[Document]]]:
"""
Load the vectorstore and documents from disk.
Args:
vectorstore_dir: The directory to load the vectorstore from.
embed_model: The embedding model to use.
Returns:
vector_store: The vectorstore.
"""
    if not os.path.isdir(vectorstore_dir):
        print(f"Vectorstore directory not found at {vectorstore_dir}. Creating a new one.")
        os.makedirs(vectorstore_dir, exist_ok=True)
        # Nothing to load from a freshly created directory.
        return None, None
    try:
        vector_store = FAISS.load_local(vectorstore_dir, embed_model, allow_dangerous_deserialization=True)
docs_path = os.path.join(vectorstore_dir, "docs.pkl")
if os.path.exists(docs_path):
with open(docs_path, "rb") as f:
docs = pickle.load(f)
else:
docs = None
print("Warning: docs.pkl not found. BM25 search will not be available.")
print(f"Successfully loaded RAG state from {vectorstore_dir}")
return vector_store, docs
except Exception as e:
print(f"Could not load from {vectorstore_dir}. It might be empty or corrupted. Error: {e}")
return None, None
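# Example usage for load_local (a minimal sketch; the embedding model name
# below is an assumption for illustration, not one pinned by this file):
#
#   embed_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   vector_store, docs = load_local("vectorstore", embed_model)
#   if vector_store is not None:
#       hits = vector_store.similarity_search("What are the symptoms of anemia?", k=3)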
def save_local(vectorstore_dir: str, vectorstore: FAISS, docs: Optional[List[Document]]) -> None:
"""
Save the vectorstore and documents to disk.
Args:
vectorstore_dir: The directory to save the vectorstore to.
vectorstore: The vectorstore to save.
docs: The documents to save.
"""
    if vectorstore is None:
        raise ValueError("vectorstore is None; nothing to save.")
if docs is None:
print("Warning: No documents to save. BM25 search will not be available.")
os.makedirs(vectorstore_dir, exist_ok=True)
vectorstore.save_local(vectorstore_dir)
if docs is not None:
with open(os.path.join(vectorstore_dir, "docs.pkl"), "wb") as f:
pickle.dump(docs, f)
print(f"Successfully saved RAG state to {vectorstore_dir}")
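# Example usage for save_local (a sketch assuming an index built with
# FAISS.from_documents; the sample document is illustrative only):
#
#   docs = [Document(page_content="Aspirin inhibits platelet aggregation.")]
#   vectorstore = FAISS.from_documents(docs, embed_model)
#   save_local("vectorstore", vectorstore, docs)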
def load_qa_dataset(qa_dataset_path: str) -> tuple[List[str], List[str], List[str], List[str]]:
"""
Load the QA dataset. (jsonl)
Args:
qa_dataset_path: The path to the QA dataset.
Returns:
Tuple: (ids, questions, options, answers)\\
ids: The ids of the questions\\
questions: The questions\\
options: The options for each question\\
answers: The answers for each question.
"""
if not os.path.exists(qa_dataset_path):
raise FileNotFoundError(f"Error: File not found at {qa_dataset_path}")
with open(qa_dataset_path, "r", encoding="utf-8") as f:
data = [json.loads(line) for line in f]
questions = [item["question"] for item in data]
    option_keys = ("A", "B", "C", "D", "E")
    try:
        options = [
            "".join(
                f"{key}. {item[key]} \n"
                for key in option_keys
                if item[key] not in (" ", "", None)
            )
            for item in data
        ]
    except KeyError:
        # Datasets without option fields (open-ended questions) get a blank placeholder.
        options = [" " for _ in data]
answers = [item["answer"] for item in data]
uuids = [item["uuid"] for item in data]
return uuids, questions, options, answers
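# Each line of the JSONL file is expected to carry the fields read above
# ("uuid", "question", "answer", and optionally "A".."E"). An illustrative
# line (values are made up, not taken from the real dataset):
#
#   {"uuid": "q-001", "question": "Which vitamin deficiency causes scurvy?",
#    "A": "Vitamin A", "B": "Vitamin B12", "C": "Vitamin C", "D": "Vitamin D",
#    "E": "", "answer": "C"}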
def load_prepared_retrieve_docs(prepared_retrieve_docs_path: str) -> List[List[Document]]:
"""
Load the prepared retrieve docs from a file.
Args:
prepared_retrieve_docs_path: The path to the prepared retrieve docs.
Returns:
A list of lists of documents.
"""
return safe_load_langchain_docs(prepared_retrieve_docs_path)
def paralelize(func, max_workers: int = 4, **kwargs) -> List:
"""
Parallelizes a function call over multiple keyword argument iterables.
Args:
func: The function to execute in parallel.
max_workers: The maximum number of threads to use.
**kwargs: Keyword arguments where each value is an iterable (e.g., a list).
All iterables must be of the same length.
The keyword names do not matter, but their order does.
Returns:
A list of the results of the function calls.
"""
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
if not kwargs:
return []
arg_lists = list(kwargs.values())
if len(set(len(lst) for lst in arg_lists)) > 1:
raise ValueError("All iterable arguments must have the same length.")
total_items = len(arg_lists[0])
iterable = zip(*arg_lists)
unpacker_func = lambda args_tuple: func(*args_tuple)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
results = list(tqdm(executor.map(unpacker_func, iterable), total=total_items))
return results
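# Example usage for paralelize (a minimal sketch; the worker function here is
# hypothetical). Keyword order must match func's positional parameters, since
# the iterables are zipped and unpacked positionally:
#
#   def ask(question, options):
#       return f"{question} | {options}"
#
#   results = paralelize(ask, max_workers=4, questions=questions, options=options)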
def safe_save_langchain_docs(documents: List[List[Document]], filepath: str) -> None:
"""
Converts LangChain Document objects into a serializable list of dictionaries
and saves them to a file using pickle.
Args:
documents (List[List[Document]]): The nested list of LangChain Documents.
filepath (str): The path to the file where the data will be saved.
"""
serializable_data = []
print(f"Preparing to save {len(documents)} lists of documents...")
# Convert each Document object into a dictionary
for doc_list in documents:
serializable_doc_list = []
for doc in doc_list:
serializable_doc_list.append({
"page_content": doc.page_content,
"metadata": doc.metadata,
})
serializable_data.append(serializable_doc_list)
print(f"Conversion complete. Saving to {filepath}...")
try:
# Use 'with' to ensure the file is closed properly, even if errors occur
with open(filepath, "wb") as f:
pickle.dump(serializable_data, f)
print("File saved successfully.")
except Exception as e:
print(f"An error occurred while saving the file: {e}")
def safe_load_langchain_docs(filepath: str) -> List[List[Document]]:
"""
Loads data from a pickle file and reconstructs the LangChain Document objects.
Args:
filepath (str): The path to the file to load.
Returns:
List[List[Document]]: The reconstructed nested list of LangChain Documents.
"""
reconstructed_documents = []
print(f"Loading data from {filepath}...")
try:
with open(filepath, "rb") as f:
loaded_data = pickle.load(f)
print("File loaded successfully. Reconstructing Document objects...")
# Reconstruct the Document objects from the dictionaries
for doc_list_data in loaded_data:
reconstructed_doc_list = []
for doc_data in doc_list_data:
reconstructed_doc_list.append(
Document(
page_content=doc_data["page_content"],
metadata=doc_data["metadata"]
)
)
reconstructed_documents.append(reconstructed_doc_list)
print("Document objects reconstructed successfully.")
return reconstructed_documents
except FileNotFoundError:
print(f"Error: The file at {filepath} was not found.")
return []
except EOFError:
print(f"Error: The file at {filepath} is corrupted or incomplete (EOFError).")
print("Please re-run the script that generates this file.")
return []
except Exception as e:
print(f"An unexpected error occurred while loading the file: {e}")
return []
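# Round-trip example for the safe save/load pair (illustrative path, content,
# and metadata keys; none of these are fixed by this module):
#
#   batches = [[Document(page_content="Metformin lowers hepatic glucose output.",
#                        metadata={"source": "pharmacology"})]]
#   safe_save_langchain_docs(batches, "retrieved_docs.pkl")
#   restored = safe_load_langchain_docs("retrieved_docs.pkl")
#   assert restored[0][0].page_content == batches[0][0].page_content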