| from langchain_community.vectorstores import Chroma |
| from modules.vectorstore.base import VectorStoreBase |
| import os |
|
|
|
|
class ChromaVectorStore(VectorStoreBase):
    """Vector store backed by a LangChain ``Chroma`` collection on disk.

    The persist directory is derived from ``config["vectorstore"]``: the
    ``db_path`` directory joined with ``"db_<db_option>_<model>"``.
    """

    def __init__(self, config):
        """Store ``config`` and prepare the Chroma handle.

        Args:
            config: Mapping with a ``"vectorstore"`` section that contains
                the ``"db_path"``, ``"db_option"`` and ``"model"`` keys.
        """
        self.config = config
        # Set by create_database() / load_database(); None until then so
        # misuse produces a clear error instead of an AttributeError.
        self.vectorstore = None
        self._init_vector_db()

    def _init_vector_db(self):
        """Keep a handle to the ``Chroma`` class used to build new stores.

        ``from_documents`` is a classmethod, so the class itself is enough.
        Instantiating ``Chroma()`` here would create an unused default
        client/collection as a side effect.
        """
        self.chroma = Chroma

    def _persist_directory(self):
        """Return the directory where this store's Chroma data lives."""
        vs_config = self.config["vectorstore"]
        return os.path.join(
            vs_config["db_path"],
            f"db_{vs_config['db_option']}_{vs_config['model']}",
        )

    def _require_vectorstore(self):
        """Return the active store, or raise if none was created/loaded."""
        if self.vectorstore is None:
            raise RuntimeError(
                "Vector store is not initialised; call create_database() "
                "or load_database() first."
            )
        return self.vectorstore

    def create_database(self, document_chunks, embedding_model):
        """Build a new Chroma store from ``document_chunks``.

        Args:
            document_chunks: Documents passed to ``Chroma.from_documents``.
            embedding_model: Embedding object passed as ``embedding``.

        Returns:
            The created store (also kept on ``self.vectorstore``).
        """
        self.vectorstore = self.chroma.from_documents(
            documents=document_chunks,
            embedding=embedding_model,
            persist_directory=self._persist_directory(),
        )
        return self.vectorstore

    def load_database(self, embedding_model):
        """Open the Chroma store located at the configured persist directory.

        Args:
            embedding_model: Embedding object passed as ``embedding_function``.

        Returns:
            The loaded store (also kept on ``self.vectorstore``).
        """
        self.vectorstore = Chroma(
            persist_directory=self._persist_directory(),
            embedding_function=embedding_model,
        )
        return self.vectorstore

    def as_retriever(self, **kwargs):
        """Return a retriever over the active store.

        Keyword arguments are forwarded unchanged to the underlying
        ``as_retriever`` call (none by default, matching the old behaviour).

        Raises:
            RuntimeError: if no store has been created or loaded yet.
        """
        return self._require_vectorstore().as_retriever(**kwargs)

    def __len__(self):
        """Return ``len()`` of the active store.

        Presumably this is the number of stored entries; confirm against
        the installed ``Chroma.__len__``.

        Raises:
            RuntimeError: if no store has been created or loaded yet.
        """
        return len(self._require_vectorstore())
|
|