File size: 3,799 Bytes
b27eb78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8223f74
b27eb78
8223f74
 
 
b27eb78
 
 
 
 
 
 
 
 
8223f74
 
 
 
b27eb78
 
 
 
 
 
 
 
 
 
8223f74
 
 
 
 
 
 
 
 
 
 
 
 
b27eb78
 
 
8223f74
 
 
 
 
 
 
 
 
 
 
 
 
 
b27eb78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from typing import Literal, Union

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers import (
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
)
from loguru import logger

from second_brain_online.config import settings

from .embeddings import EmbeddingModelType, EmbeddingsModel, get_embedding_model
from .splitters import get_splitter

# Retriever names accepted by `get_retriever`. The "*_reranked" variants select
# the same base retriever as their prefix and additionally force re-ranking on.
RetrieverType = Literal[
    "contextual",
    "parent",
    "contextual_reranked",
    "parent_reranked",
]

# Everything `get_retriever` may return. "RerankingRetriever" is spelled as a
# string forward reference because that class is not imported at module level;
# `get_retriever` imports it inside the function body.
RetrieverModel = Union[
    MongoDBAtlasHybridSearchRetriever,
    MongoDBAtlasParentDocumentRetriever,
    "RerankingRetriever",
]


def get_retriever(
    embedding_model_id: str,
    embedding_model_type: EmbeddingModelType = "huggingface",
    retriever_type: RetrieverType = "contextual",
    k: int = 3,
    device: str = "cpu",
    enable_reranking: bool = False,
    rerank_model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2",
    stage1_limit: int = 50,
    final_k: int = 10,
) -> RetrieverModel:
    """Build a MongoDB Atlas retriever, optionally wrapped in a re-ranker.

    Args:
        embedding_model_id: Model identifier forwarded to `get_embedding_model`.
        embedding_model_type: Embedding backend forwarded to
            `get_embedding_model`.
        retriever_type: "contextual" builds the hybrid search retriever and
            "parent" builds the parent document retriever. The "*_reranked"
            variants build the same base retriever and force re-ranking on,
            regardless of `enable_reranking`.
        k: Number of results requested from the base retriever.
        device: Device string forwarded to `get_embedding_model`.
        enable_reranking: Wrap the base retriever in a `RerankingRetriever`
            even when `retriever_type` has no "_reranked" suffix.
        rerank_model_name: Forwarded to `RerankingRetriever`.
        stage1_limit: Forwarded to `RerankingRetriever`.
        final_k: Forwarded to `RerankingRetriever`.

    Returns:
        The base retriever, or a `RerankingRetriever` wrapping it when
        re-ranking is enabled.

    Raises:
        ValueError: If `retriever_type` does not resolve to "contextual" or
            "parent".
    """
    logger.info(
        f"Getting '{retriever_type}' retriever for '{embedding_model_type}' - '{embedding_model_id}' on '{device}' "
        f"with {k} top results"
    )

    # A "*_reranked" type is shorthand for its base type with re-ranking on.
    base_retriever_type = retriever_type
    if retriever_type in ["contextual_reranked", "parent_reranked"]:
        base_retriever_type = retriever_type.replace("_reranked", "")
        enable_reranking = True

    # Resolve (and thereby validate) the retriever type before the embedding
    # model is loaded, so an invalid type fails without paying for that load.
    if base_retriever_type == "contextual":
        build_base_retriever = get_hybrid_search_retriever
    elif base_retriever_type == "parent":
        build_base_retriever = get_parent_document_retriever
    else:
        raise ValueError(f"Invalid retriever type: {retriever_type}")

    embedding_model = get_embedding_model(
        embedding_model_id, embedding_model_type, device
    )
    base_retriever = build_base_retriever(embedding_model, k)

    if not enable_reranking:
        return base_retriever

    # Imported here rather than at module top, so the import only runs when
    # re-ranking is actually requested.
    # NOTE(review): this pulls from `second_brain_offline` while the rest of
    # this module imports from `second_brain_online` -- confirm the package.
    from second_brain_offline.application.rag.reranker import RerankingRetriever

    logger.info(f"Enabling re-ranking with model: {rerank_model_name}")
    logger.info(f"Stage 1 limit: {stage1_limit}, Final k: {final_k}")
    # NOTE(review): the base retriever above is still built with `k`; whether
    # `stage1_limit` overrides that is up to RerankingRetriever -- confirm.
    return RerankingRetriever(
        base_retriever=base_retriever,
        rerank_model_name=rerank_model_name,
        stage1_limit=stage1_limit,
        final_k=final_k,
    )


def get_hybrid_search_retriever(
    embedding_model: EmbeddingsModel, k: int
) -> MongoDBAtlasHybridSearchRetriever:
    """Create a hybrid search retriever over the configured MongoDB collection.

    Args:
        embedding_model: Embedding model handed to the vector store.
        k: Passed to the retriever as `top_k`.

    Returns:
        A `MongoDBAtlasHybridSearchRetriever` using the "chunk_text_search"
        search index, with vector and full-text penalties both set to 50.
    """
    mongodb_uri = settings.MONGODB_URI
    collection_namespace = (
        f"{settings.MONGODB_DATABASE_NAME}.{settings.MONGODB_COLLECTION_NAME}"
    )

    # The vector store reads text from "chunk" and vectors from "embedding".
    store = MongoDBAtlasVectorSearch.from_connection_string(
        connection_string=mongodb_uri,
        embedding=embedding_model,
        namespace=collection_namespace,
        text_key="chunk",
        embedding_key="embedding",
        relevance_score_fn="dotProduct",
    )

    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=store,
        search_index_name="chunk_text_search",
        top_k=k,
        vector_penalty=50,
        fulltext_penalty=50,
    )


def get_parent_document_retriever(
    embedding_model: EmbeddingsModel, k: int = 3
) -> MongoDBAtlasParentDocumentRetriever:
    """Create a parent document retriever over the configured MongoDB collection.

    Args:
        embedding_model: Embedding model handed to the retriever.
        k: Passed to the retriever as `search_kwargs["k"]`.

    Returns:
        A `MongoDBAtlasParentDocumentRetriever` whose child and parent
        splitters come from `get_splitter(200)` and `get_splitter(800)`.
    """
    mongodb_uri = settings.MONGODB_URI
    # Child splitter is created before the parent one, as in the call order
    # of the keyword arguments below.
    child = get_splitter(200)
    parent = get_splitter(800)
    database = settings.MONGODB_DATABASE_NAME
    collection = settings.MONGODB_COLLECTION_NAME

    return MongoDBAtlasParentDocumentRetriever.from_connection_string(
        connection_string=mongodb_uri,
        embedding_model=embedding_model,
        child_splitter=child,
        parent_splitter=parent,
        database_name=database,
        collection_name=collection,
        text_key="chunk",
        search_kwargs={"k": k},
    )