chinmayjha's picture
Enhanced conversation analysis UI with customer details and migrated to Keshav MongoDB
8223f74 unverified
from typing import Literal, Union
from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
MongoDBAtlasHybridSearchRetriever,
MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger
from second_brain_online.config import settings
from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter
# Names accepted by get_retriever(). The "*_reranked" variants select the same
# base retriever as their un-suffixed counterpart with re-ranking switched on.
RetrieverType = Literal[
    "contextual",
    "parent",
    "contextual_reranked",
    "parent_reranked",
]

# Every retriever class get_retriever() may hand back. "RerankingRetriever" is
# written as a string forward reference because that class is only imported
# inside get_retriever(), not at module level.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
    "RerankingRetriever",
]
def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
    enable_reranking: bool = False,
    rerank_model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2",
    stage1_limit: int = 50,
    final_k: int = 10,
) -> RetrieverModel:
    """Build the retriever selected by ``retriever_type``.

    Args:
        embedding_model_id: Identifier forwarded to ``get_embedding_model``.
        embedding_model_type: Embedding backend forwarded to
            ``get_embedding_model``.
        retriever_type: ``"contextual"`` builds the hybrid search retriever,
            ``"parent"`` builds the parent document retriever. The
            ``"*_reranked"`` variants build the same base retriever and force
            re-ranking on, regardless of ``enable_reranking``.
        k: Number of results requested from the base retriever.
        device: Device string forwarded to ``get_embedding_model``.
        enable_reranking: Wrap the base retriever in a ``RerankingRetriever``
            even when ``retriever_type`` has no ``"_reranked"`` suffix.
        rerank_model_name: Model name forwarded to ``RerankingRetriever``.
        stage1_limit: Forwarded to ``RerankingRetriever`` (presumably the
            candidate count fetched before re-ranking; confirm against the
            reranker implementation).
        final_k: Forwarded to ``RerankingRetriever`` (presumably the number of
            results kept after re-ranking; confirm against the reranker
            implementation).

    Returns:
        The base retriever, or a ``RerankingRetriever`` wrapping it when
        re-ranking is enabled.

    Raises:
        ValueError: If ``retriever_type`` does not resolve to ``"contextual"``
            or ``"parent"``.
    """
    # Resolve the base retriever; a "_reranked" suffix implies re-ranking.
    if retriever_type in ("contextual_reranked", "parent_reranked"):
        base_retriever_type = retriever_type.replace("_reranked", "")
        enable_reranking = True
    else:
        base_retriever_type = retriever_type

    # Fail fast: reject an unknown type before loading the embedding model,
    # so no work is spent on a request that can only end in an error.
    if base_retriever_type not in ("contextual", "parent"):
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )

    if base_retriever_type == "contextual":
        base_retriever = get_hybrid_search_retriever(embedding_model, k)
    else:
        base_retriever = get_parent_document_retriever(embedding_model, k)

    if not enable_reranking:
        return base_retriever

    # Imported lazily so the reranker is only loaded when it is needed.
    # NOTE(review): this pulls from `second_brain_offline`, while the settings
    # import in this module uses `second_brain_online` — confirm the package.
    from second_brain_offline.application.rag.reranker import RerankingRetriever

    logger.info(f"Enabling re-ranking with model: {rerank_model_name}")
    logger.info(f"Stage 1 limit: {stage1_limit}, Final k: {final_k}")

    return RerankingRetriever(
        base_retriever=base_retriever,
        rerank_model_name=rerank_model_name,
        stage1_limit=stage1_limit,
        final_k=final_k,
    )
def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int
) -> MongoDBAtlasHybridSearchRetriever:
    """Create a ``MongoDBAtlasHybridSearchRetriever`` on the configured collection.

    Args:
        embedding_model: Embedding model handed to the vector store.
        k: Value passed as the retriever's ``top_k``.

    Returns:
        A retriever bound to the ``chunk_text_search`` search index, with the
        vector and full-text penalties both set to 50.
    """
    # Namespace is "<database>.<collection>", both taken from settings.
    store = MongoDBAtlasVectorSearch.from_connection_string(
        **{
            "connection_string": settings.MONGODB_URI,
            "embedding": embedding_model,
            "namespace": f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}",
            "text_key": "chunk",
            "embedding_key": "embedding",
            "relevance_score_fn": "dotProduct",
        }
    )

    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=store,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )
def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Create a ``MongoDBAtlasParentDocumentRetriever`` on the configured collection.

    Args:
        embedding_model: Embedding model handed to the retriever.
        k: Value passed as ``search_kwargs["k"]``.

    Returns:
        A retriever built from the connection string, database and collection
        names in settings, using ``get_splitter(200)`` for child documents and
        ``get_splitter(800)`` for parent documents.
    """
    build = MongoDBAtlasParentDocumentRetriever.from_connection_string

    # Dict order mirrors the original keyword order, so the settings lookups
    # and get_splitter() calls happen in the same sequence.
    options = {
        "connection_string": settings.MONGODB_URI,
        "embedding_model": embedding_model,
        "child_splitter": get_splitter(200),
        "parent_splitter": get_splitter(800),
        "database_name": settings.MONGODB_DATABASE_NAME,
        "collection_name": settings.MONGODB_COLLECTION_NAME,
        "text_key": "chunk",
        "search_kwargs": {"k": k},
    }
    return build(**options)