from typing import TYPE_CHECKING, Literal, Union

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger

from second_brain_online.config import settings

from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter

if TYPE_CHECKING:
    # Only for static type checking: lets checkers resolve the
    # "RerankingRetriever" forward reference in ``RetrieverModel``. At runtime
    # the class is imported lazily inside ``get_retriever``.
    from second_brain_offline.application.rag.reranker import RerankingRetriever

# Retriever flavours accepted by ``get_retriever``.
RetrieverType = Literal["contextual", "parent", "contextual_reranked", "parent_reranked"]
# Every object ``get_retriever`` can return.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
    "RerankingRetriever",
]

# Maps each "*_reranked" retriever type to the base retriever type it wraps.
_RERANKED_TO_BASE = {
    "contextual_reranked": "contextual",
    "parent_reranked": "parent",
}


def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
    enable_reranking: bool = False,
    rerank_model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2",
    stage1_limit: int = 50,
    final_k: int = 10,
) -> RetrieverModel:
    """Build a MongoDB Atlas retriever, optionally wrapped with a re-ranker.

    Args:
        embedding_model_id: Identifier forwarded to ``get_embedding_model``.
        embedding_model_type: Embedding backend forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` builds a hybrid-search retriever
            (``get_hybrid_search_retriever``); ``"parent"`` builds a
            parent-document retriever (``get_parent_document_retriever``).
            The ``"*_reranked"`` variants build the same base retriever and
            always enable re-ranking, regardless of ``enable_reranking``.
        k: Number of results requested from the base retriever.
        device: Device forwarded to ``get_embedding_model``.
        enable_reranking: Wrap the base retriever in a ``RerankingRetriever``.
        rerank_model_name: Model name forwarded to ``RerankingRetriever``.
        stage1_limit: Forwarded to ``RerankingRetriever`` as ``stage1_limit``.
        final_k: Forwarded to ``RerankingRetriever`` as ``final_k``.

    Returns:
        The base retriever, or a ``RerankingRetriever`` wrapping it when
        re-ranking is enabled.

    Raises:
        ValueError: If ``retriever_type`` is not a supported value.
    """
    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    # "*_reranked" types reuse the matching base retriever and force
    # re-ranking on; plain types honour the caller's ``enable_reranking``.
    if retriever_type in _RERANKED_TO_BASE:
        base_retriever_type = _RERANKED_TO_BASE[retriever_type]
        enable_reranking = True
    else:
        base_retriever_type = retriever_type

    # Validate before loading the embedding model so an invalid type fails
    # fast instead of after a potentially expensive model load.
    if base_retriever_type not in ("contextual", "parent"):
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    if base_retriever_type == "contextual":
        base_retriever = get_hybrid_search_retriever(embedding_model, k)
    else:
        base_retriever = get_parent_document_retriever(embedding_model, k)

    if not enable_reranking:
        return base_retriever

    # Imported lazily so the reranker module is only loaded when re-ranking is
    # enabled.
    # NOTE(review): this module imports from ``second_brain_online`` but the
    # reranker comes from ``second_brain_offline``; confirm that package is
    # installed wherever this code runs.
    from second_brain_offline.application.rag.reranker import RerankingRetriever

    logger.info(f"Enabling re-ranking with model: {rerank_model_name}")
    logger.info(f"Stage 1 limit: {stage1_limit}, Final k: {final_k}")
    # NOTE(review): the base retriever is built with ``k`` results. Unless
    # RerankingRetriever widens the candidate pool itself, ``stage1_limit``
    # larger than ``k`` has no effect. Confirm against the reranker.
    return RerankingRetriever(
        base_retriever=base_retriever,
        rerank_model_name=rerank_model_name,
        stage1_limit=stage1_limit,
        final_k=final_k,
    )


def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int
) -> MongoDBAtlasHybridSearchRetriever:
    """Build a hybrid (vector + full-text) Atlas search retriever.

    The vector store is bound to the collection configured in ``settings``.
    It reads chunk text from the ``"chunk"`` field and vectors from the
    ``"embedding"`` field, and uses the ``"chunk_text_search"`` index for
    the full-text leg.

    Args:
        embedding_model: Model used to embed queries.
        k: Number of documents the retriever returns (``top_k``).

    Returns:
        A configured ``MongoDBAtlasHybridSearchRetriever``.
    """
    vectorstore = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=settings.MONGODB_URI,
        embedding=embedding_model,
        namespace=f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}",
        text_key="chunk",
        embedding_key="embedding",
        relevance_score_fn="dotProduct",
    )
    # Equal penalties for both legs; see langchain_mongodb's hybrid retriever
    # docs for how these weight the vector and full-text rankings.
    retriever = MongoDBAtlasHybridSearchRetriever(
        vectorstore=vectorstore,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )

    return retriever


def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Build a parent-document retriever over the configured collection.

    Child chunks come from ``get_splitter(200)`` and parent chunks from
    ``get_splitter(800)``; the size units are whatever ``get_splitter``
    uses (TODO confirm: characters vs. tokens).

    Args:
        embedding_model: Model used to embed child chunks and queries.
        k: Forwarded to the retriever as ``search_kwargs={"k": k}``.

    Returns:
        A configured ``MongoDBAtlasParentDocumentRetriever``.
    """
    retriever = MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=settings.MONGODB_URI,
        embedding_model=embedding_model,
        child_splitter=get_splitter(200),
        parent_splitter=get_splitter(800),
        database_name=settings.MONGODB_DATABASE_NAME,
        collection_name=settings.MONGODB_COLLECTION_NAME,
        text_key="chunk",
        search_kwargs={"k": k},
    )

    return retriever