Spaces:
Running
Running
from typing import List, Optional

from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.schema import Document
from langchain_community.vectorstores import FAISS

from project.logger.logging import get_logger
from project.utils.config_loader import load_config
from project.utils.model_loader import ModelLoader
# Module-level logger, named after this module so log records can be traced
# back to the retriever code.
logger = get_logger(__name__)
class DocumentRetriever:
    """Build a FAISS vector store from documents and expose retrievers over it.

    Typical usage is ``create_vectorstore(...)`` followed by either
    ``setup_self_query_retriever(...)`` or ``get_base_retriever()``, and then
    ``retrieve(query)``.  Both setup methods store their result in
    ``self.retriever``, so the most recent call decides what ``retrieve`` uses.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Load configuration, the embedding model and the LLM.

        Args:
            config_path: Forwarded unchanged to ``load_config`` and
                ``ModelLoader``; ``None`` is passed through as-is.
        """
        self.config = load_config(config_path)
        self.model_loader = ModelLoader(config_path)
        self.embeddings = self.model_loader.load_embeddings()
        self.llm = self.model_loader.load_llm()
        # Both stay None until the corresponding setup method has been called.
        self.vectorstore = None
        self.retriever = None
        logger.info("DocumentRetriever initialized")

    def _retriever_config(self):
        """Return the ``'retriever'`` section of the config (empty dict if absent)."""
        return self.config.get('retriever', {})

    def create_vectorstore(self, documents: List[Document]) -> FAISS:
        """Embed ``documents`` into a new FAISS store and keep it on the instance.

        Args:
            documents: Documents to index; must contain at least one item.

        Returns:
            The newly created vector store (also stored as ``self.vectorstore``).

        Raises:
            ValueError: If ``documents`` is empty.
        """
        # Fail early with a clear message instead of letting an empty batch
        # reach the index builder.
        if not documents:
            raise ValueError("Cannot create a vector store from an empty document list.")
        self.vectorstore = FAISS.from_documents(documents, self.embeddings)
        logger.info(f"Vector store created with {len(documents)} documents")
        return self.vectorstore

    def setup_self_query_retriever(
        self,
        document_content_description: str = "Research papers and technical documents",
        metadata_field_info: Optional[List[AttributeInfo]] = None
    ):
        """Configure an LLM-driven self-query retriever over the vector store.

        Args:
            document_content_description: Description of the indexed content,
                passed to the retriever as ``document_contents``.
            metadata_field_info: Metadata attributes the retriever may filter
                on.  Defaults to a string ``source`` field and an integer
                ``page`` field.

        Returns:
            The configured retriever (also stored as ``self.retriever``).

        Raises:
            ValueError: If ``create_vectorstore`` has not been called yet.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized. Call create_vectorstore first.")

        if metadata_field_info is None:
            metadata_field_info = [
                AttributeInfo(
                    name="source",
                    description="The source file or document name",
                    type="string",
                ),
                AttributeInfo(
                    name="page",
                    description="The page number in the document",
                    type="integer",
                ),
            ]

        top_k = self._retriever_config().get('top_k', 3)
        # NOTE(review): confirm that the installed LangChain version ships a
        # self-query translator for FAISS; otherwise from_llm may reject this
        # vector store type.
        self.retriever = SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=self.vectorstore,
            document_contents=document_content_description,
            metadata_field_info=metadata_field_info,
            search_kwargs={'k': top_k},
            enable_limit=True,
        )
        logger.info("Self-query retriever configured")
        return self.retriever

    def retrieve(self, query: str) -> List[Document]:
        """Run ``query`` through the currently configured retriever.

        Args:
            query: The query string passed to the retriever's ``invoke``.

        Returns:
            The documents returned by the retriever.

        Raises:
            ValueError: If no retriever has been configured yet.
        """
        if self.retriever is None:
            raise ValueError("Retriever not initialized. Call setup_self_query_retriever first.")
        documents = self.retriever.invoke(query)
        logger.info(f"Retrieved {len(documents)} documents for query")
        return documents

    def get_base_retriever(self):
        """Configure a plain vector-store retriever from the ``retriever`` config.

        Reads ``search_type`` (default ``'similarity'``) and ``top_k``
        (default ``3``).  ``'mmr'`` selects MMR search with ``fetch_k`` set to
        twice ``top_k``; every other value results in similarity search.

        Returns:
            The configured retriever (also stored as ``self.retriever``).

        Raises:
            ValueError: If the vector store has not been created yet.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized.")

        retriever_config = self._retriever_config()
        requested_type = retriever_config.get('search_type', 'similarity')
        top_k = retriever_config.get('top_k', 3)

        if requested_type == 'mmr':
            search_type = 'mmr'
            search_kwargs = {'k': top_k, 'fetch_k': top_k * 2}
        else:
            # Anything other than 'mmr' has always meant similarity search;
            # say so explicitly rather than logging a mode that is not in use.
            if requested_type != 'similarity':
                logger.warning(
                    f"Unsupported search_type {requested_type!r}; "
                    "falling back to similarity search"
                )
            search_type = 'similarity'
            search_kwargs = {'k': top_k}

        self.retriever = self.vectorstore.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs,
        )
        logger.info(f"Base retriever configured with {search_type} search")
        return self.retriever