# RAG-Project / project/model/retriever.py
# Last commit: dfa6a46 "Deploy from GitHub Actions" (github-actions[bot])
from typing import List, Optional
from langchain.schema import Document
from langchain_community.vectorstores import FAISS
from langchain.chains.query_constructor.base import AttributeInfo
from langchain.retrievers.self_query.base import SelfQueryRetriever
from project.utils.model_loader import ModelLoader
from project.utils.config_loader import load_config
from project.logger.logging import get_logger
# Module-level logger for this file, obtained from the project's get_logger helper.
logger = get_logger(__name__)
class DocumentRetriever:
    """Build a FAISS vector store over documents and expose a retriever for it.

    Embeddings and the LLM come from ``ModelLoader``. Retrieval settings are read
    from the ``retriever`` section of the loaded config:

    * ``top_k`` (default 3): number of documents requested per query.
    * ``search_type`` (default ``'similarity'``): ``'similarity'`` or ``'mmr'``;
      only used by :meth:`get_base_retriever`.
    * ``fetch_k`` (default ``2 * top_k``): MMR candidate pool size; only used
      when ``search_type`` is ``'mmr'``.

    Typical use: :meth:`create_vectorstore`, then either
    :meth:`setup_self_query_retriever` or :meth:`get_base_retriever`, then
    :meth:`retrieve`.
    """

    # Search types get_base_retriever knows how to configure.
    _SUPPORTED_SEARCH_TYPES = ("similarity", "mmr")

    def __init__(self, config_path: Optional[str] = None):
        """Load the config, embeddings and LLM.

        Args:
            config_path: Forwarded to both ``load_config`` and ``ModelLoader``;
                ``None`` leaves the choice of config to those helpers.
        """
        self.config = load_config(config_path)
        self.model_loader = ModelLoader(config_path)
        self.embeddings = self.model_loader.load_embeddings()
        self.llm = self.model_loader.load_llm()
        self.vectorstore = None
        # Set by either setup_self_query_retriever() or get_base_retriever().
        self.retriever = None
        logger.info("DocumentRetriever initialized")

    def _retriever_settings(self) -> dict:
        """Return the ``retriever`` config section, or ``{}`` if it is missing or null."""
        # `.get('retriever', {})` would hand back None for a key that is present
        # with a null value; `or {}` normalises that case too.
        return self.config.get('retriever') or {}

    def create_vectorstore(self, documents: List[Document]) -> FAISS:
        """Embed ``documents`` into a new FAISS index and keep it on the instance.

        Args:
            documents: Documents to index; must be non-empty.

        Returns:
            The new vector store (also stored as ``self.vectorstore``).

        Raises:
            ValueError: If ``documents`` is empty.
        """
        if not documents:
            # An index cannot be sized from zero embeddings; FAISS construction
            # would otherwise fail with a low-level error far from the cause.
            raise ValueError("Cannot create a vector store from an empty document list.")
        self.vectorstore = FAISS.from_documents(documents, self.embeddings)
        logger.info("Vector store created with %d documents", len(documents))
        return self.vectorstore

    def setup_self_query_retriever(
        self,
        document_content_description: str = "Research papers and technical documents",
        metadata_field_info: Optional[List[AttributeInfo]] = None,
    ):
        """Configure a LangChain ``SelfQueryRetriever`` over the current vector store.

        Args:
            document_content_description: Short description of the indexed
                content, passed to ``SelfQueryRetriever.from_llm`` as
                ``document_contents``.
            metadata_field_info: Metadata fields the retriever may filter on.
                Defaults to ``source`` (string) and ``page`` (integer).

        Returns:
            The configured retriever (also stored as ``self.retriever``).

        Raises:
            ValueError: If :meth:`create_vectorstore` has not been called.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized. Call create_vectorstore first.")
        if metadata_field_info is None:
            # Built fresh on every call so callers never share a mutable default.
            metadata_field_info = [
                AttributeInfo(
                    name="source",
                    description="The source file or document name",
                    type="string",
                ),
                AttributeInfo(
                    name="page",
                    description="The page number in the document",
                    type="integer",
                ),
            ]
        top_k = self._retriever_settings().get('top_k', 3)
        self.retriever = SelfQueryRetriever.from_llm(
            llm=self.llm,
            vectorstore=self.vectorstore,
            document_contents=document_content_description,
            metadata_field_info=metadata_field_info,
            search_kwargs={'k': top_k},
            # Lets the LLM-generated structured query carry its own result limit.
            enable_limit=True,
        )
        logger.info("Self-query retriever configured")
        return self.retriever

    def retrieve(self, query: str) -> List[Document]:
        """Run ``query`` through the configured retriever.

        Args:
            query: Natural-language query string.

        Returns:
            The documents returned by ``self.retriever.invoke``.

        Raises:
            ValueError: If no retriever has been configured yet.
        """
        if self.retriever is None:
            # Either setup method initialises self.retriever, so name both.
            raise ValueError(
                "Retriever not initialized. Call setup_self_query_retriever "
                "or get_base_retriever first."
            )
        documents = self.retriever.invoke(query)
        logger.info("Retrieved %d documents for query", len(documents))
        return documents

    def get_base_retriever(self):
        """Configure a plain vector-store retriever from the ``retriever`` config.

        Uses ``search_type`` (``'similarity'`` or ``'mmr'``; anything else logs
        a warning and falls back to ``'similarity'``), ``top_k`` and, for MMR,
        ``fetch_k``.

        Returns:
            The configured retriever (also stored as ``self.retriever``).

        Raises:
            ValueError: If :meth:`create_vectorstore` has not been called.
        """
        if self.vectorstore is None:
            raise ValueError("Vector store not initialized.")
        settings = self._retriever_settings()
        search_type = settings.get('search_type', 'similarity')
        top_k = settings.get('top_k', 3)
        if search_type not in self._SUPPORTED_SEARCH_TYPES:
            # Previously this fell through to similarity silently while the log
            # still reported the unsupported value; make the fallback visible.
            logger.warning(
                "Unsupported search_type %r; falling back to 'similarity'", search_type
            )
            search_type = 'similarity'
        if search_type == 'mmr':
            # MMR picks top_k results from a fetch_k candidate pool; a pool
            # smaller than top_k would silently return fewer documents.
            fetch_k = max(settings.get('fetch_k', top_k * 2), top_k)
            search_kwargs = {'k': top_k, 'fetch_k': fetch_k}
        else:
            search_kwargs = {'k': top_k}
        self.retriever = self.vectorstore.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs,
        )
        logger.info("Base retriever configured with %s search", search_type)
        return self.retriever