| | from ragatouille import RAGPretrainedModel |
| | from modules.vectorstore.base import VectorStoreBase |
| | from langchain_core.retrievers import BaseRetriever |
| | from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun |
| | from langchain_core.documents import Document |
| | from typing import Any, List |
| | import os |
| | import json |
| |
|
| |
|
class RAGatouilleLangChainRetrieverWithScore(BaseRetriever):
    """LangChain retriever over a RAGatouille model that keeps the search score.

    Every hit returned by ``model.search`` is turned into a ``Document``
    whose metadata is the hit's ``document_metadata`` (when present) plus a
    ``"score"`` entry copied from the hit.
    """

    # Object exposing ``search(query, **kwargs)`` that returns an iterable of
    # dict-like hits with at least ``"content"`` and ``"score"`` keys.
    model: Any
    # Extra keyword arguments forwarded verbatim to ``model.search``.
    kwargs: dict = {}

    def _search_documents(self, query: str) -> List[Document]:
        """Run the search and convert each hit into a ``Document``.

        Shared by the sync and async entry points so both produce exactly
        the same output.
        """
        hits = self.model.search(query, **self.kwargs)
        return [
            Document(
                page_content=hit["content"],
                # ``or {}`` covers hits whose "document_metadata" key is
                # missing as well as hits where it is present but None.
                metadata={
                    **(hit.get("document_metadata") or {}),
                    "score": hit["score"],
                },
            )
            for hit in hits
        ]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: Text passed to ``model.search``.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            One ``Document`` per search hit, with the hit's score stored
            under ``metadata["score"]``.
        """
        return self._search_documents(query)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get documents relevant to a query (async entry point).

        NOTE(review): the search itself is still invoked synchronously, so
        this coroutine does not yield control while the model is searching.

        Args:
            query: Text passed to ``model.search``.
            run_manager: Callback manager supplied by the caller; not used here.

        Returns:
            The same list ``_get_relevant_documents`` would return.
        """
        return self._search_documents(query)
| |
|
| |
|
class RAGPretrainedModel(RAGPretrainedModel):
    """RAGPretrainedModel variant that answers ``len()`` with a stored count.

    The count is not derived from the index by this class: it starts at
    zero and is whatever value was last given to ``set_document_count``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and start with a count of 0."""
        super().__init__(*args, **kwargs)
        self._document_count = 0

    def set_document_count(self, count) -> None:
        """Record ``count`` as the value later reported by ``len(self)``."""
        self._document_count = count

    def __len__(self) -> int:
        """Return the count most recently stored via ``set_document_count``."""
        return self._document_count

    def as_langchain_retriever(self, **kwargs: Any) -> BaseRetriever:
        """Wrap this model in a retriever that attaches scores to documents.

        Args:
            **kwargs: Stored on the retriever as its ``kwargs`` field.

        Returns:
            A ``RAGatouilleLangChainRetrieverWithScore`` bound to this model.
        """
        retriever = RAGatouilleLangChainRetrieverWithScore(
            model=self,
            kwargs=kwargs,
        )
        return retriever
| |
|
| |
|
class ColbertVectorStore(VectorStoreBase):
    """Vector store backed by a ColBERT index managed through RAGatouille.

    ``config["vectorstore"]["db_path"]`` and
    ``config["vectorstore"]["db_option"]`` determine the directory
    (``<db_path>/db_<db_option>``) used as the index root.

    ``self.colbert`` is the pretrained model created at construction time and
    used for indexing. ``self.vectorstore`` is only assigned by
    ``load_database``, so ``as_retriever`` and ``len()`` require
    ``load_database`` to have been called first.
    """

    # Name of the index created by ``create_database`` and opened by
    # ``load_database``.
    _INDEX_NAME = "new_idx"

    def __init__(self, config):
        """Store ``config`` and instantiate the pretrained ColBERT model."""
        self.config = config
        self._init_vector_db()

    def _db_dir(self):
        """Return ``<db_path>/db_<db_option>`` built from the config."""
        vectorstore_config = self.config["vectorstore"]
        return os.path.join(
            vectorstore_config["db_path"],
            "db_" + vectorstore_config["db_option"],
        )

    def _init_vector_db(self):
        """Load the ``colbert-ir/colbertv2.0`` checkpoint rooted at the db dir."""
        self.colbert = RAGPretrainedModel.from_pretrained(
            "colbert-ir/colbertv2.0",
            index_root=self._db_dir(),
        )

    def create_database(self, documents, document_names, document_metadata):
        """Index ``documents`` and record how many document names were given.

        Args:
            documents: Collection passed to the model's ``index`` call.
            document_names: Document ids, one per document; its length is
                stored as the model's document count.
            document_metadata: Per-document metadata passed to ``index``.
        """
        index_path = self.colbert.index(
            index_name=self._INDEX_NAME,
            collection=documents,
            document_ids=document_names,
            document_metadatas=document_metadata,
        )
        print(f"Index created at {index_path}")
        self.colbert.set_document_count(len(document_names))

    def load_database(self):
        """Open the existing index from disk and return the loaded model.

        The document count is taken from the ``num_passages`` field of the
        index's ``0.metadata.json`` file.

        Returns:
            The loaded model, also stored as ``self.vectorstore``.
        """
        index_dir = os.path.join(
            os.getcwd(),
            self._db_dir(),
            "colbert",
            "indexes",
            self._INDEX_NAME,
        )
        self.vectorstore = RAGPretrainedModel.from_index(index_dir)

        # Use a context manager so the metadata file handle is always closed.
        metadata_path = os.path.join(index_dir, "0.metadata.json")
        with open(metadata_path, encoding="utf-8") as metadata_file:
            index_metadata = json.load(metadata_file)
        self.vectorstore.set_document_count(index_metadata["num_passages"])

        return self.vectorstore

    def as_retriever(self, **kwargs):
        """Return a LangChain retriever over the loaded index.

        The loaded model exposes ``as_langchain_retriever`` rather than
        ``as_retriever``, so that is the method called here.

        Args:
            **kwargs: Optional arguments forwarded to
                ``as_langchain_retriever``.
        """
        return self.vectorstore.as_langchain_retriever(**kwargs)

    def __len__(self):
        """Return the document count of the loaded index."""
        return len(self.vectorstore)
| |
|