Spaces:
Sleeping
Sleeping
Enhanced conversation analysis UI with customer details and migrated to Keshav MongoDB
8223f74
unverified
| from typing import Literal, Union | |
| from langchain_mongodb import MongoDBAtlasVectorSearch | |
| from langchain_mongodb.retrievers import ( | |
| MongoDBAtlasHybridSearchRetriever, | |
| MongoDBAtlasParentDocumentRetriever, | |
| ) | |
| from loguru import logger | |
| from second_brain_online.config import settings | |
| from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model | |
| from .splitters import get_splitter | |
# Retriever flavours accepted by `get_retriever`. The `*_reranked` names select
# the same base retriever as their un-suffixed counterpart, with re-ranking
# forced on.
RetrieverType = Literal[
    "contextual",
    "parent",
    "contextual_reranked",
    "parent_reranked",
]

# Every concrete type `get_retriever` may return. "RerankingRetriever" stays a
# string forward reference because that class is only imported lazily inside
# `get_retriever`.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
    "RerankingRetriever",
]
def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
    enable_reranking: bool = False,
    rerank_model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2",
    stage1_limit: int = 50,
    final_k: int = 10,
) -> RetrieverModel:
    """Build a retriever, optionally wrapped in a re-ranking stage.

    Args:
        embedding_model_id: Identifier forwarded to ``get_embedding_model``.
        embedding_model_type: Embedding backend forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` builds the hybrid-search retriever and
            ``"parent"`` builds the parent-document retriever. The
            ``"contextual_reranked"`` / ``"parent_reranked"`` variants build the
            same base retriever and force re-ranking on.
        k: Result count handed to the base retriever.
        device: Device string forwarded to ``get_embedding_model``.
        enable_reranking: Wrap the base retriever in ``RerankingRetriever`` even
            when ``retriever_type`` has no ``_reranked`` suffix.
        rerank_model_name: Forwarded to ``RerankingRetriever``; only used when
            re-ranking is active.
        stage1_limit: Forwarded to ``RerankingRetriever``; only used when
            re-ranking is active.
        final_k: Forwarded to ``RerankingRetriever``; only used when re-ranking
            is active.

    Returns:
        The base retriever, or a ``RerankingRetriever`` wrapping it when
        re-ranking is active.

    Raises:
        ValueError: If ``retriever_type`` is not one of the supported values.
    """
    # Resolve and validate the requested type first, so that a bad value fails
    # immediately instead of after the embedding model has been loaded.
    base_retriever_type = retriever_type
    use_reranking = enable_reranking
    if retriever_type in ("contextual_reranked", "parent_reranked"):
        base_retriever_type = retriever_type.replace("_reranked", "")
        use_reranking = True
    if base_retriever_type not in ("contextual", "parent"):
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )
    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    # Create base retriever
    if base_retriever_type == "contextual":
        base_retriever = get_hybrid_search_retriever(embedding_model, k)
    else:
        base_retriever = get_parent_document_retriever(embedding_model, k)

    if not use_reranking:
        return base_retriever

    # Imported lazily so the reranker is only required when it is actually used.
    # NOTE(review): the rest of this module imports from `second_brain_online`,
    # while the reranker comes from `second_brain_offline` — confirm this
    # cross-package import is intentional.
    from second_brain_offline.application.rag.reranker import RerankingRetriever

    logger.info(f"Enabling re-ranking with model: {rerank_model_name}")
    logger.info(f"Stage 1 limit: {stage1_limit}, Final k: {final_k}")
    return RerankingRetriever(
        base_retriever=base_retriever,
        rerank_model_name=rerank_model_name,
        stage1_limit=stage1_limit,
        final_k=final_k,
    )
def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int
) -> MongoDBAtlasHybridSearchRetriever:
    """Create a hybrid (vector + full-text) retriever over the configured collection.

    Args:
        embedding_model: Embedding model handed to the Atlas vector store.
        k: Passed to the retriever as ``top_k``.

    Returns:
        A ``MongoDBAtlasHybridSearchRetriever`` bound to the
        ``chunk_text_search`` search index.
    """
    mongodb_uri = settings.MONGODB_URI
    # Atlas namespaces are written as "<database>.<collection>".
    namespace = f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}"

    store = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=mongodb_uri,
        embedding=embedding_model,
        namespace=namespace,
        text_key="chunk",
        embedding_key="embedding",
        relevance_score_fn="dotProduct",
    )

    # Both penalties are set to the same value, so neither the vector nor the
    # full-text ranking is favoured over the other.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=store,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )
def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Create a parent-document retriever over the configured collection.

    Args:
        embedding_model: Embedding model handed to the retriever.
        k: Passed to the retriever as ``search_kwargs["k"]``.

    Returns:
        A ``MongoDBAtlasParentDocumentRetriever`` connected via
        ``settings.MONGODB_URI``.
    """
    mongodb_uri = settings.MONGODB_URI
    # Child splitter is built with 200 and parent splitter with 800.
    # NOTE(review): presumably these are chunk sizes — confirm in `get_splitter`.
    child_splitter = get_splitter(200)
    parent_splitter = get_splitter(800)

    return MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=mongodb_uri,
        embedding_model=embedding_model,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        database_name=settings.MONGODB_DATABASE_NAME,
        collection_name=settings.MONGODB_COLLECTION_NAME,
        text_key="chunk",
        search_kwargs={"k": k},
    )