| | from modules.vectorstore.faiss import FaissVectorStore |
| | from modules.vectorstore.chroma import ChromaVectorStore |
| | from modules.vectorstore.colbert import ColbertVectorStore |
| | from modules.vectorstore.raptor import RAPTORVectoreStore |
| | from huggingface_hub import snapshot_download |
| | import os |
| | import shutil |
| |
|
| |
|
class VectorStore:
    """Facade over the supported vector-store backends.

    The backend is selected by ``config["vectorstore"]["db_option"]``, which
    must be one of the keys of ``vectorstore_classes`` ("FAISS", "Chroma",
    "RAGatouille" or "RAPTOR").
    """

    def __init__(self, config):
        """Store the configuration; no backend is instantiated yet.

        Args:
            config: Mapping that provides at least
                ``config["vectorstore"]["db_option"]`` and, for
                ``_load_from_HF``, ``config["vectorstore"]["db_path"]``.
        """
        self.config = config
        # Backend instance; stays None until a database is created or loaded.
        self.vectorstore = None
        self.vectorstore_classes = {
            "FAISS": FaissVectorStore,
            "Chroma": ChromaVectorStore,
            "RAGatouille": ColbertVectorStore,
            "RAPTOR": RAPTORVectoreStore,
        }

    def _activate_backend(self):
        """Instantiate the configured backend and return its option name.

        Shared by ``_create_database`` and ``_load_database`` so that both
        validate and build the backend in exactly the same way.

        Returns:
            The ``db_option`` string that was read from the configuration.

        Raises:
            ValueError: If ``db_option`` is not a key of
                ``vectorstore_classes``. ``self.vectorstore`` is left
                untouched in that case.
        """
        db_option = self.config["vectorstore"]["db_option"]
        vectorstore_class = self.vectorstore_classes.get(db_option)
        if vectorstore_class is None:
            raise ValueError(f"Invalid db_option: {db_option}")

        self.vectorstore = vectorstore_class(self.config)
        return db_option

    def _create_database(
        self,
        document_chunks,
        document_names,
        documents,
        document_metadata,
        embedding_model,
    ):
        """Build a new database with the configured backend.

        The "RAGatouille" backend is handed the full ``documents`` together
        with ``document_names`` and ``document_metadata``; every other backend
        is handed ``document_chunks`` and ``embedding_model``.

        Raises:
            ValueError: If the configured ``db_option`` is not supported.
        """
        db_option = self._activate_backend()

        if db_option == "RAGatouille":
            self.vectorstore.create_database(
                documents, document_names, document_metadata
            )
        else:
            self.vectorstore.create_database(document_chunks, embedding_model)

    def _load_database(self, embedding_model):
        """Load an existing database with the configured backend.

        Args:
            embedding_model: Passed to the backend's ``load_database`` for
                every backend except "RAGatouille", which is called without
                arguments.

        Returns:
            Whatever the backend's ``load_database`` returns.

        Raises:
            ValueError: If the configured ``db_option`` is not supported.
        """
        db_option = self._activate_backend()

        if db_option == "RAGatouille":
            return self.vectorstore.load_database()
        return self.vectorstore.load_database(embedding_model)

    def _load_from_HF(self, HF_PATH) -> str:
        """Download a dataset snapshot from the Hugging Face Hub into the db directory.

        The snapshot contents are copied into
        ``<db_path>/db_<db_option>``, merging with anything already there.

        Args:
            HF_PATH: Repository id of the dataset repo to download.

        Returns:
            The directory the snapshot was copied into.
        """
        # force_download=True re-fetches the files on every call, even when a
        # cached copy already exists locally.
        snapshot_path = snapshot_download(
            repo_id=HF_PATH,
            repo_type="dataset",
            force_download=True,
        )

        vectorstore_config = self.config["vectorstore"]
        target_path = os.path.join(
            vectorstore_config["db_path"],
            "db_" + vectorstore_config["db_option"],
        )
        os.makedirs(target_path, exist_ok=True)

        # Copy entry by entry so existing content in target_path is merged
        # with, not replaced by, the snapshot.
        for item in os.listdir(snapshot_path):
            source = os.path.join(snapshot_path, item)
            destination = os.path.join(target_path, item)
            if os.path.isdir(source):
                shutil.copytree(source, destination, dirs_exist_ok=True)
            else:
                shutil.copy2(source, destination)

        return target_path

    def _as_retriever(self):
        """Return the active backend's retriever.

        Raises:
            AttributeError: If no database has been created or loaded yet
                (``self.vectorstore`` is still ``None``).
        """
        return self.vectorstore.as_retriever()

    def _get_vectorstore(self):
        """Return the active backend instance, or ``None`` if none is active."""
        return self.vectorstore

    def __len__(self) -> int:
        """Return the size reported by the active backend, or 0 if none is active."""
        if self.vectorstore is None:
            return 0
        return len(self.vectorstore)
| |
|