Spaces:
Running
Running
File size: 6,218 Bytes
2a8faae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import threading
from concurrent.futures import ThreadPoolExecutor

from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever

from . import utils
from .config import logger
from .tracing import traceable
# Global variables for lazy loading.
# All of these are populated exactly once by _ensure_initialized(); they stay
# None/False until the first retrieval call (or initialize_eagerly()) runs.
_vector_store = None        # FAISS vector store loaded/built via utils
_company_chunks = None      # document chunks cache backing the BM25 retriever
_vector_retriever = None    # similarity retriever derived from _vector_store
_bm25_retriever = None      # keyword (BM25) retriever built from _company_chunks
_hybrid_retriever = None    # ensemble of BM25 + vector (or whichever is available)
_initialized = False        # guard flag so initialization runs only once
# Guards one-time initialization: initialize_eagerly() may run in a background
# thread while request handlers hit the search functions concurrently.
_init_lock = threading.Lock()


def _prepare_vector_store():
    """Build or load the company vector store.

    Order of preference:
      1. Process any new data and update the persisted store.
      2. Fall back to loading the existing persisted store.
      3. As a last resort, create a store from cached chunks.

    Returns:
        The vector store, or None when no data is available at all.
    """
    logger.info("Processing new data and updating vector store if needed...")
    store = utils.process_new_data_and_update_vector_store()
    if store is None:
        # Processing found no new files; fall back to the persisted store.
        store = utils.load_company_vector_store()
    if store is None:
        # Last resort: build from whatever chunks are already cached.
        logger.info("No vector store found; attempting creation from cached chunks...")
        cached_chunks = utils.load_chunks() or []
        if cached_chunks:
            store = utils.create_company_vector_store(cached_chunks)
            logger.info("Vector store created from cached chunks")
        else:
            logger.warning(
                "No data available to build a vector store. "
                "Retrievers may not function until data is provided."
            )
    return store


def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup).

    Thread-safe: a double-checked lock allows concurrent callers (including a
    background initialize_eagerly() thread) without duplicating the expensive
    setup work.

    Raises:
        RuntimeError: if neither a vector store nor BM25 chunks are available.
        Exception: re-raises any failure from vector-store or chunk loading
            after logging it.
    """
    global _vector_store, _company_chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized
    if _initialized:  # fast path: already set up, no lock needed
        return
    with _init_lock:
        if _initialized:  # re-check under the lock (another thread may have won)
            return
        logger.info("Initializing retrievers (first time use)...")
        # Prepare the vector store (process new data / load / rebuild from cache).
        try:
            _vector_store = _prepare_vector_store()
        except Exception as e:
            logger.error(f"Error preparing vector store: {str(e)}")
            raise
        # Load merged chunks (previous + new) for the BM25 retriever.
        try:
            logger.info("Loading chunks cache for BM25 retriever...")
            _company_chunks = utils.load_chunks() or []
            if not _company_chunks:
                logger.warning(
                    "No chunks available for BM25 retriever. "
                    "BM25 will be empty until data is processed."
                )
        except Exception as e:
            logger.error(f"Error loading chunks: {str(e)}")
            raise
        # Vector retriever (top-5 similarity hits).
        logger.info("Creating vector retriever...")
        _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None
        # BM25 retriever.
        logger.info("Creating BM25 retriever...")
        _bm25_retriever = BM25Retriever.from_documents(_company_chunks) if _company_chunks else None
        if _bm25_retriever:
            _bm25_retriever.k = 5
        # Hybrid retriever: ensemble when both exist, otherwise whichever is available.
        logger.info("Creating hybrid retriever...")
        if _vector_retriever and _bm25_retriever:
            _hybrid_retriever = EnsembleRetriever(
                retrievers=[_bm25_retriever, _vector_retriever],
                weights=[0.2, 0.8],
            )
        elif _vector_retriever:
            logger.warning("BM25 retriever unavailable; using vector retriever only.")
            _hybrid_retriever = _vector_retriever
        elif _bm25_retriever:
            _hybrid_retriever = _bm25_retriever
        else:
            raise RuntimeError(
                "Neither vector nor BM25 retrievers could be initialized. "
                "Provide data under data/new_data and retry."
            )
        _initialized = True
        logger.info("Retrievers initialized successfully.")
def initialize_eagerly():
    """Run the one-time retriever setup now instead of waiting for first use.

    Intended to be called from a background thread at application startup so
    the first user request does not pay the initialization cost.
    """
    _ensure_initialized()
def is_initialized() -> bool:
    """Report whether the lazy retriever setup has already completed."""
    return bool(_initialized)
# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
# Shared worker pool used by hybrid_search() to run the vector and BM25
# lookups concurrently; module-lifetime, never shut down explicitly.
_retrieval_pool = ThreadPoolExecutor(max_workers=4)
def _match_provider(doc, provider: str) -> bool:
if not provider:
return True
prov = str(doc.metadata.get("provider", "")).strip().lower()
return prov == provider.strip().lower()
@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int = 5):
    """Run a FAISS similarity search, optionally filtered to one provider.

    Returns up to *k* documents; an empty list when the store is missing or
    the search raises (the error is logged, never propagated).
    """
    _ensure_initialized()
    if not _vector_store:
        return []
    try:
        search_kwargs = {"k": k}
        if provider:
            search_kwargs["filter"] = {"provider": provider}
        hits = _vector_store.similarity_search(query, **search_kwargs)
        # Post-filter defensively in case the backend applies the metadata
        # filter leniently.
        return [doc for doc in hits if _match_provider(doc, provider)] if provider else hits
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []
@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int = 5):
    """Run a BM25 keyword search, optionally filtered to one provider.

    Returns at most *k* documents; an empty list when no BM25 retriever
    exists or the search raises (the error is logged, never propagated).
    """
    _ensure_initialized()
    try:
        retriever = _bm25_retriever
        if not retriever:
            return []
        retriever.k = max(1, k)  # BM25Retriever exposes k as mutable state
        results = retriever.get_relevant_documents(query) or []
        if provider:
            results = [doc for doc in results if _match_provider(doc, provider)]
        return results[:k]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []
def hybrid_search(query: str, provider: str | None = None, k_vector: int = 5, k_bm25: int = 5):
    """Fan out to vector and BM25 search in parallel and merge the results.

    Vector hits come first, then BM25 hits; duplicates are dropped by a
    (source, page_number, leading-100-chars) fingerprint, keeping the first
    occurrence.
    """
    _ensure_initialized()  # must happen before fanning out to worker threads
    pending = [
        _retrieval_pool.submit(vector_search, query, provider, k_vector),
        _retrieval_pool.submit(bm25_search, query, provider, k_bm25),
    ]
    vector_docs, bm25_docs = (future.result() for future in pending)
    # Insertion-ordered dict gives us order-preserving de-duplication.
    unique = {}
    for doc in vector_docs + bm25_docs:
        fingerprint = (
            doc.metadata.get("source"),
            doc.metadata.get("page_number"),
            doc.page_content[:100],
        )
        if fingerprint not in unique:
            unique[fingerprint] = doc
    return list(unique.values())
|