"""Hybrid retrieval combining Qdrant vector search with BM25 keyword search."""

import asyncio
from typing import List, Optional

from langchain_qdrant import QdrantVectorStore
from langchain_community.retrievers import BM25Retriever
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document
from qdrant_client import QdrantClient

from ingestion.embedder import embedder
from app.config import config, settings
from app.utils.logger import logger

# Smoothing constant for reciprocal rank fusion: a hit at 1-based position
# ``rank`` in a ranking contributes ``1 / (_RRF_CONSTANT + rank)`` to its score.
_RRF_CONSTANT = 60


def _fuse_rankings(rankings: List[List[Document]], limit: int) -> List[Document]:
    """Merge several ranked document lists with reciprocal rank fusion.

    Documents are identified by their full ``page_content``. A document that
    appears in more than one ranking accumulates score from each of them, so
    hits found by several retrievers rise to the top, and every ranking gets a
    fair share of the final ``limit`` slots.

    Args:
        rankings: Ranked result lists, best hit first. Earlier lists win ties
            and supply the ``Document`` object kept for a duplicated hit.
        limit: Maximum number of documents to return.

    Returns:
        At most ``limit`` distinct documents, ordered by descending fused score.
    """
    scores = {}
    first_seen = {}
    for ranking in rankings:
        counted = set()
        for rank, doc in enumerate(ranking, start=1):
            key = doc.page_content
            # A document repeated inside one ranking only scores once there.
            if key in counted:
                continue
            counted.add(key)
            first_seen.setdefault(key, doc)
            scores[key] = scores.get(key, 0.0) + 1.0 / (_RRF_CONSTANT + rank)

    # sorted() is stable, so equal scores keep first-seen (insertion) order.
    ordered_keys = sorted(first_seen, key=scores.__getitem__, reverse=True)
    return [first_seen[key] for key in ordered_keys[:limit]]


class HybridRetriever(BaseRetriever):
    """Retriever that fuses Qdrant similarity search with in-memory BM25.

    The Qdrant vector store is created lazily on first use. The BM25 index is
    rebuilt from every document passed to :meth:`add_documents` on this
    instance; until that method has been called, retrieval is vector-only.
    """

    # Qdrant-backed store; stays None until initialisation succeeds.
    vector_store: Optional[QdrantVectorStore] = None
    # Keyword retriever; stays None until add_documents() has been called.
    bm25_retriever: Optional[BM25Retriever] = None
    # Every document added through add_documents(), used to rebuild BM25.
    documents: List[Document] = []
    # Number of results requested from each retriever and returned overall.
    k: int = 10
    # True once the vector store has been constructed successfully.
    _initialized: bool = False

    def __init__(self):
        """Create the retriever and load ``top_k`` from the RAG configuration."""
        super().__init__()
        self.k = config["rag"]["retrieval"]["top_k"]

    def _initialize_vector_store(self):
        """Connect to Qdrant and build the vector store if not done yet.

        A failure while constructing the store is logged as a warning and
        leaves the retriever uninitialised, so the next call tries again.
        """
        if self._initialized:
            return

        qdrant_config = config["database"]["qdrant"]
        client = QdrantClient(
            url=qdrant_config["url"],
            api_key=settings.qdrant_api_key or None,
            timeout=60,
        )
        try:
            self.vector_store = QdrantVectorStore(
                client=client,
                collection_name=qdrant_config["collection_name"],
                embedding=embedder.get_embeddings(),
            )
        except Exception as e:
            # Best effort by design: callers check ``_initialized`` afterwards.
            logger.warning(f"Vector store init skipped: {str(e)}")
        else:
            self._initialized = True
            logger.info(f"Vector store initialized: {qdrant_config['collection_name']}")

    def add_documents(self, documents: List[Document]):
        """Index documents in the vector store and rebuild the BM25 index.

        Args:
            documents: Documents to add. An empty list is a no-op.

        Returns:
            The ids returned by the vector store for the added documents
            (an empty list when ``documents`` is empty).

        Raises:
            RuntimeError: If the vector store could not be initialised.
        """
        if not documents:
            # Nothing to index; avoid a remote call and an empty BM25 rebuild.
            return []

        self._initialize_vector_store()
        if not self._initialized:
            raise RuntimeError(
                "Vector store not available; cannot add documents"
            )

        # Write to the vector store first so a failure there leaves the local
        # document list and BM25 index untouched.
        ids = self.vector_store.add_documents(documents)
        self.documents.extend(documents)
        self.bm25_retriever = BM25Retriever.from_documents(
            self.documents,
            k=self.k,
        )
        logger.info(f"Added {len(documents)} documents (total: {len(self.documents)})")
        return ids

    def _get_relevant_documents(self, query: str) -> List[Document]:
        """Return up to ``k`` documents relevant to ``query``.

        Vector and BM25 results are merged with reciprocal rank fusion so that
        keyword-only hits are not crowded out when the vector search already
        returns ``k`` results.

        Args:
            query: The search query.

        Returns:
            An empty list if the vector store is unavailable, the vector
            results alone if BM25 has not been built, otherwise the fused
            results truncated to ``k`` documents.
        """
        self._initialize_vector_store()
        if not self._initialized:
            logger.warning("Vector store not available")
            return []

        vector_docs = self.vector_store.similarity_search(query, k=self.k)
        if self.bm25_retriever is None:
            logger.warning("BM25 not initialized, using vector-only retrieval")
            return vector_docs

        bm25_docs = self.bm25_retriever.invoke(query)
        # Vector results go first so they win ties and provide the Document
        # instance kept when both retrievers return the same content.
        results = _fuse_rankings([vector_docs, bm25_docs], self.k)
        logger.info(f"Hybrid search returned {len(results)} documents")
        return results

    async def _aget_relevant_documents(self, query: str) -> List[Document]:
        """Async variant of :meth:`_get_relevant_documents`.

        The synchronous implementation performs blocking work, so it is run
        in the default executor instead of directly on the event loop.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._get_relevant_documents, query
        )


# Shared module-level instance, created at import time.
hybrid_retriever = HybridRetriever()