File size: 6,218 Bytes
2a8faae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import threading
from concurrent.futures import ThreadPoolExecutor

from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever

from . import utils
from .config import logger
from .tracing import traceable

# Module-level retriever state, populated lazily on first use.
_vector_store = None
_company_chunks = None
_vector_retriever = None
_bm25_retriever = None
_hybrid_retriever = None
_initialized = False
# Serializes first-time initialization: initialize_eagerly() is intended for a
# background warm-up thread while request handlers may call
# _ensure_initialized() concurrently (hybrid_search also fans work out to a
# thread pool), so the build must happen exactly once.
_init_lock = threading.Lock()


def _prepare_vector_store():
    """Build or load the vector store, processing any new data first.

    Falls back to loading the existing store, then to (re)creating one from
    cached chunks. Returns None when no data is available at all.
    """
    logger.info("🔄 Processing new data and updating vector store if needed...")
    store = utils.process_new_data_and_update_vector_store()
    if store is None:
        # Fall back to load existing if processing found no new files
        store = utils.load_company_vector_store()
    if store is None:
        # As a last resort, create from whatever is already in cache (if any)
        logger.info("ℹ️ No vector store found; attempting creation from cached chunks...")
        cached_chunks = utils.load_chunks() or []
        if cached_chunks:
            store = utils.create_company_vector_store(cached_chunks)
            logger.info("✅ Vector store created from cached chunks")
        else:
            logger.warning("⚠️ No data available to build a vector store. Retrievers may not function until data is provided.")
    return store


def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup).

    Thread-safe via double-checked locking: concurrent first calls build the
    retrievers exactly once. Subsequent calls return on the lock-free fast
    path.

    Raises:
        RuntimeError: when neither a vector store nor BM25 chunks could be
            initialized (no data available).
    """
    global _vector_store, _company_chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized

    # Fast path: already fully built, no lock needed.
    if _initialized:
        return

    with _init_lock:
        # Re-check under the lock: another thread may have won the race.
        if _initialized:
            return

        logger.info("🔄 Initializing retrievers (first time use)...")

        # Process any new data and update vector store and chunks cache
        try:
            _vector_store = _prepare_vector_store()
        except Exception as e:
            logger.error(f"Error preparing vector store: {str(e)}")
            raise

        # Load merged chunks for BM25 (includes previous + new)
        try:
            logger.info("📦 Loading chunks cache for BM25 retriever...")
            _company_chunks = utils.load_chunks() or []
            if not _company_chunks:
                logger.warning("⚠️ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
        except Exception as e:
            logger.error(f"Error loading chunks: {str(e)}")
            raise

        # Create vector retriever
        logger.info("🔍 Creating vector retriever...")
        _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

        # Create BM25 retriever
        logger.info("📝 Creating BM25 retriever...")
        _bm25_retriever = BM25Retriever.from_documents(_company_chunks) if _company_chunks else None
        if _bm25_retriever:
            _bm25_retriever.k = 5

        # Create hybrid retriever: ensemble when both exist, otherwise
        # degrade gracefully to whichever single retriever is available.
        logger.info("🔄 Creating hybrid retriever...")
        if _vector_retriever and _bm25_retriever:
            _hybrid_retriever = EnsembleRetriever(
                retrievers=[_bm25_retriever, _vector_retriever],
                weights=[0.2, 0.8]
            )
        elif _vector_retriever:
            logger.warning("ℹ️ BM25 retriever unavailable; using vector retriever only.")
            _hybrid_retriever = _vector_retriever
        elif _bm25_retriever:
            _hybrid_retriever = _bm25_retriever
        else:
            raise RuntimeError("Neither vector or BM25 retrievers could be initialized. Provide data under data/new_data and retry.")

        # Publish only after everything above succeeded, so readers of the
        # fast path never observe a half-built state.
        _initialized = True
        logger.info("✅ Retrievers initialized successfully.")


def initialize_eagerly():
    """Eagerly build the retriever stack (e.g. from a background warm-up thread)."""
    _ensure_initialized()


def is_initialized() -> bool:
    """Report whether the retriever stack has finished initializing."""
    return bool(_initialized)


# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
# Shared pool used by hybrid_search() to run the vector and BM25 lookups
# concurrently. Module-lifetime: it is never shut down explicitly, which is
# fine for a long-running service process.
_retrieval_pool = ThreadPoolExecutor(max_workers=4)


def _match_provider(doc, provider: str) -> bool:
    if not provider:
        return True
    prov = str(doc.metadata.get("provider", "")).strip().lower()
    return prov == provider.strip().lower()


@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int = 5):
    """Search FAISS vector store with optional provider metadata filter."""
    _ensure_initialized()
    if not _vector_store:
        return []
    try:
        kwargs = {"k": k}
        if provider:
            kwargs["filter"] = {"provider": provider}
        results = _vector_store.similarity_search(query, **kwargs)
        if provider:
            # Ensure provider post-filter in case backend filter is lenient
            results = [doc for doc in results if _match_provider(doc, provider)]
        return results
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []


@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int = 5):
    """Search BM25 using the global retriever and optionally filter by provider."""
    _ensure_initialized()
    try:
        if not _bm25_retriever:
            return []
        # NOTE(review): this mutates the shared retriever's k; concurrent
        # callers using different k values could race — confirm acceptable.
        _bm25_retriever.k = max(1, k)
        hits = _bm25_retriever.get_relevant_documents(query) or []
        if provider:
            hits = [doc for doc in hits if _match_provider(doc, provider)]
        return hits[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []


def hybrid_search(query: str, provider: str | None = None, k_vector: int = 5, k_bm25: int = 5):
    """Combine vector and BM25 results (provider-filtered if provided)."""
    # Initialize up front so the pooled worker threads never hit first-use init.
    _ensure_initialized()
    vector_future = _retrieval_pool.submit(vector_search, query, provider, k_vector)
    bm25_future = _retrieval_pool.submit(bm25_search, query, provider, k_bm25)

    # Deduplicate by (source, page_number, first 100 chars of content),
    # keeping first-seen order with vector hits ahead of BM25 hits.
    unique = {}
    for doc in vector_future.result() + bm25_future.result():
        key = (
            doc.metadata.get("source"),
            doc.metadata.get("page_number"),
            doc.page_content[:100],
        )
        unique.setdefault(key, doc)
    return list(unique.values())