| from ragatouille import RAGPretrainedModel |
| from modules.vectorstore.base import VectorStoreBase |
| from langchain_core.retrievers import BaseRetriever |
| from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun |
| from langchain_core.documents import Document |
| from typing import Any, List |
| import os |
| import json |
|
|
|
|
class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a RAGatouille-style model that keeps hit scores.

    Every hit returned by ``model.search`` is converted into a ``Document``
    whose ``page_content`` is the hit's ``"content"`` and whose metadata is the
    hit's ``"document_metadata"`` (when present) plus a ``"score"`` entry.
    """

    # Object exposing ``search(query, **kwargs)``; each returned hit is read as
    # a mapping with "content", "score" and, optionally, "document_metadata".
    model: Any
    # Extra keyword arguments forwarded verbatim to ``model.search``.
    kwargs: dict = {}

    def _search_documents(self, query: str) -> List[Document]:
        """Run the search and wrap every hit in a scored ``Document``.

        Shared by the sync and async entry points so both produce identical
        results.
        """
        hits = self.model.search(query, **self.kwargs)
        return [
            Document(
                page_content=hit["content"],
                # ``or {}`` also covers a hit whose metadata is explicitly None.
                # "score" is written last so it wins over a metadata key of the
                # same name.
                metadata={
                    **(hit.get("document_metadata") or {}),
                    "score": hit["score"],
                },
            )
            for hit in hits
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: Text passed to ``model.search``.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            One ``Document`` per search hit, in the order the model returned them.
        """
        return self._search_documents(query)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        # NOTE(review): the async code path presumably receives LangChain's
        # async callback manager here - confirm and tighten the annotation.
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant to a query without blocking the event loop.

        ``model.search`` is a synchronous call, so it is executed in the
        loop's default executor and awaited, instead of being run inline
        inside the coroutine.

        Args:
            query: Text passed to ``model.search``.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            The same documents ``_get_relevant_documents`` would return.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._search_documents, query)
|
|
|
|
class RAGPretrainedModel(RAGPretrainedModel):
    """
    RAGPretrainedModel extended with a manually maintained document count.

    The count starts at zero, changes only through ``set_document_count`` and
    is what ``len(model)`` reports. The class deliberately reuses the name of
    the imported base class: the base is resolved when the ``class`` statement
    runs, before the name is rebound to this subclass.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent initialiser, then zero the count."""
        super().__init__(*args, **kwargs)
        self._document_count = 0

    def set_document_count(self, count):
        """Record ``count`` as the value that ``len(self)`` will return."""
        self._document_count = count

    def __len__(self):
        """Return the count last stored by ``set_document_count`` (0 initially)."""
        return self._document_count

    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Build a score-preserving LangChain retriever backed by this model.

        The keyword arguments are not interpreted here; they are stored on the
        retriever as its ``kwargs`` field.
        """
        search_options = kwargs
        return RAGatouilleLangChainRetrieverWithScore(
            model=self,
            kwargs=search_options,
        )
|
|
|
|
class ColbertVectorStore(VectorStoreBase):
    """Vector store backed by a ColBERT index managed through RAGPretrainedModel.

    Reads ``config["vectorstore"]["db_path"]`` and
    ``config["vectorstore"]["db_option"]`` to locate the index root.
    ``self.vectorstore`` only exists after ``load_database`` has been called, so
    ``as_retriever`` and ``len()`` must not be used before that.
    """

    # Name under which the index is written and from which it is loaded.
    _INDEX_NAME = "new_idx"

    def __init__(self, config):
        """Store ``config`` and load the pretrained ColBERT checkpoint."""
        self.config = config
        self._init_vector_db()

    def _index_root(self):
        """Return ``<db_path>/db_<db_option>`` as configured (not made absolute)."""
        settings = self.config["vectorstore"]
        return os.path.join(settings["db_path"], "db_" + settings["db_option"])

    def _init_vector_db(self):
        """Create ``self.colbert`` from the ``colbert-ir/colbertv2.0`` checkpoint."""
        self.colbert = RAGPretrainedModel.from_pretrained(
            "colbert-ir/colbertv2.0",
            index_root=self._index_root(),
        )

    def create_database(self, documents, document_names, document_metadata):
        """Index ``documents`` and record how many document names were given.

        Args:
            documents: Collection passed to the indexer as ``collection``.
            document_names: Passed as ``document_ids``; its length becomes the
                document count of ``self.colbert``.
            document_metadata: Passed as ``document_metadatas``.
        """
        index_path = self.colbert.index(
            index_name=self._INDEX_NAME,
            collection=documents,
            document_ids=document_names,
            document_metadatas=document_metadata,
        )
        print(f"Index created at {index_path}")
        # NOTE(review): this counts document names, whereas load_database()
        # stores "num_passages" from the index metadata - the two may differ.
        self.colbert.set_document_count(len(document_names))

    def load_database(self):
        """Load the stored index into ``self.vectorstore`` and return it.

        The document count is taken from the ``"num_passages"`` field of
        ``0.metadata.json`` inside the index directory.
        """
        path = os.path.join(os.getcwd(), self._index_root())
        index_dir = f"{path}/colbert/indexes/{self._INDEX_NAME}"
        self.vectorstore = RAGPretrainedModel.from_index(index_dir)

        # Context manager so the metadata file handle is closed deterministically.
        # NOTE(review): only 0.metadata.json is read; if an index can contain
        # further "<n>.metadata.json" files this count may be partial - confirm.
        with open(f"{index_dir}/0.metadata.json", encoding="utf-8") as metadata_file:
            index_metadata = json.load(metadata_file)
        num_documents = index_metadata["num_passages"]
        self.vectorstore.set_document_count(num_documents)

        return self.vectorstore

    def as_retriever(self):
        """Return a LangChain retriever for the loaded index.

        Uses ``as_langchain_retriever``, the retriever factory that the
        RAGPretrainedModel subclass in this file defines.
        """
        return self.vectorstore.as_langchain_retriever()

    def __len__(self):
        """Return the document count recorded on the loaded index."""
        return len(self.vectorstore)
|
|