# Lung-Cancer-AI-Advisor / core / retrievers.py
# Last commit 0a5dcf9 by moazx: Update .env.example with OpenAI and LangSmith
# configuration, modify app.py to dynamically set the port for deployment,
# enhance CORS middleware to support additional local development origins, and
# improve document retrieval settings for more comprehensive context in responses.
# (File viewer residue: "raw", "history blame"; file size 9.84 kB.)
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional

from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.schema import Document

from . import utils
from .config import logger
from .tracing import traceable
from .query_expansion import expand_medical_query, MultiQueryRetriever
# Global configuration for retrieval parameters.
# Increased for more comprehensive context and complete answers.
DEFAULT_K_VECTOR: int = 10  # default `k` for vector_search (and hybrid_search's k_vector)
DEFAULT_K_BM25: int = 5     # default `k` for bm25_search (and hybrid_search's k_bm25)

# Lazily-populated module state. Everything below is filled in by
# _ensure_initialized() on first use, so importing this module stays cheap.
_vector_store = None        # vector store from the utils helpers; stays None if no data exists
_company_chunks = None      # cached chunk list (possibly empty) that backs the BM25 retriever
_vector_retriever = None    # _vector_store.as_retriever(...), or None
_bm25_retriever = None      # BM25Retriever over _company_chunks, or None
_hybrid_retriever = None    # ensemble of the two retrievers, or whichever one exists
_initialized: bool = False  # flipped to True once _ensure_initialized() completes
# Serialises lazy initialisation. initialize_eagerly() is meant for background
# loading, so without a lock a concurrent first search could start a second,
# overlapping build of the vector store and retrievers. The lock is re-entrant,
# so an unexpected recursive call on the same thread cannot deadlock.
_init_lock = threading.RLock()


def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup).

    On the first successful call this populates the module-level state:

    * ``_vector_store``: from ``utils.process_new_data_and_update_vector_store()``,
      falling back to ``utils.load_company_vector_store()`` and then to a store
      built from ``utils.load_chunks()``. It may remain None if no data exists.
    * ``_company_chunks``: the cached chunk list (``[]`` if none), used for BM25.
    * ``_vector_retriever`` / ``_bm25_retriever``: built only when their source
      data exists. Both are configured to return 5 documents per query.
    * ``_hybrid_retriever``: an EnsembleRetriever (BM25 weight 0.2, vector 0.8)
      when both retrievers exist, otherwise whichever single one is available.

    Later calls return immediately. Concurrent callers are serialised by
    ``_init_lock``, so only one of them performs the build.

    Raises:
        Exception: whatever the ``utils`` helpers raise while preparing the vector
            store or loading chunks (logged, then re-raised). ``_initialized``
            stays False, so the next call retries.
        RuntimeError: if neither a vector nor a BM25 retriever could be built.
    """
    global _vector_store, _company_chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized
    # Fast path: once initialisation has completed, no locking is needed.
    if _initialized:
        return
    with _init_lock:
        # Another thread may have finished initialising while we waited on the lock.
        if _initialized:
            return
        # NOTE(review): the emoji prefixes in the log messages below look
        # mis-encoded (UTF-8 decoded as Windows-1253); confirm against the repository.
        logger.info("πŸ”„ Initializing retrievers (first time use)...")

        # Process any new data and update vector store and chunks cache
        try:
            logger.info("πŸ”„ Processing new data and updating vector store if needed...")
            _vector_store = utils.process_new_data_and_update_vector_store()
            if _vector_store is None:
                # Fall back to load existing if processing found no new files
                _vector_store = utils.load_company_vector_store()
                if _vector_store is None:
                    # As a last resort, create from whatever is already in cache (if any)
                    logger.info("ℹ️ No vector store found; attempting creation from cached chunks...")
                    cached_chunks = utils.load_chunks() or []
                    if cached_chunks:
                        _vector_store = utils.create_company_vector_store(cached_chunks)
                        logger.info("βœ… Vector store created from cached chunks")
                    else:
                        logger.warning("⚠️ No data available to build a vector store. Retrievers may not function until data is provided.")
        except Exception as e:
            logger.error(f"Error preparing vector store: {str(e)}")
            raise

        # Load merged chunks for BM25 (includes previous + new)
        try:
            logger.info("πŸ“¦ Loading chunks cache for BM25 retriever...")
            _company_chunks = utils.load_chunks() or []
            if not _company_chunks:
                logger.warning("⚠️ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
        except Exception as e:
            logger.error(f"Error loading chunks: {str(e)}")
            raise

        # NOTE(review): both retrievers below use a hard-coded k of 5, independent of
        # DEFAULT_K_VECTOR / DEFAULT_K_BM25 (which only set the *_search defaults);
        # confirm this is intended.
        logger.info("πŸ” Creating vector retriever...")
        _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None

        logger.info("πŸ“ Creating BM25 retriever...")
        _bm25_retriever = BM25Retriever.from_documents(_company_chunks) if _company_chunks else None
        if _bm25_retriever:
            _bm25_retriever.k = 5

        logger.info("πŸ”„ Creating hybrid retriever...")
        if _vector_retriever and _bm25_retriever:
            _hybrid_retriever = EnsembleRetriever(
                retrievers=[_bm25_retriever, _vector_retriever],
                weights=[0.2, 0.8]
            )
        elif _vector_retriever:
            logger.warning("ℹ️ BM25 retriever unavailable; using vector retriever only.")
            _hybrid_retriever = _vector_retriever
        elif _bm25_retriever:
            _hybrid_retriever = _bm25_retriever
        else:
            raise RuntimeError("Neither vector or BM25 retrievers could be initialized. Provide data under data/new_data and retry.")

        # Set the flag only after every global is assigned, so the lock-free
        # fast path above never observes partially built state.
        _initialized = True
        logger.info("βœ… Retrievers initialized successfully.")
def initialize_eagerly():
    """Build the retrievers now rather than on the first search.

    Intended for background warm-up at startup. If the retrievers are
    already built, it returns immediately.
    """
    _ensure_initialized()
def is_initialized() -> bool:
    """Report whether retriever initialisation has completed.

    Returns False before the first initialisation and while a failed attempt
    has not yet been retried successfully.
    """
    return _initialized
# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
# Shared worker pool that hybrid_search uses to run the vector and BM25 searches
# concurrently. It is created at import time; no explicit shutdown is visible in
# this module. Worker threads are named "retrieval_N" so they are easy to spot in
# thread dumps and log records.
_retrieval_pool = ThreadPoolExecutor(max_workers=4, thread_name_prefix="retrieval")
def _get_doc_id(doc: Document) -> str:
    """Build a deduplication key for *doc* that is valid within the current process.

    The key joins the ``source`` and ``page_number`` metadata (``'unknown'`` when
    absent) with ``hash()`` of the first 200 characters of ``page_content``.
    Python salts ``str`` hashes per process (see PYTHONHASHSEED), so keys are
    only comparable within one process and should not be persisted.
    """
    meta = doc.metadata
    parts = (
        meta.get('source', 'unknown'),
        meta.get('page_number', 'unknown'),
        hash(doc.page_content[:200]),  # only the leading text contributes
    )
    return "{}_{}_{}".format(*parts)
def _match_provider(doc, provider: str) -> bool:
    """Check whether *doc* belongs to *provider*.

    Both sides are compared after ``strip()`` and ``lower()``, and a missing
    ``provider`` metadata entry counts as an empty string. A falsy *provider*
    (None or "") disables filtering, so every document matches.
    """
    wanted = provider.strip().lower() if provider else None
    if wanted is None:
        return True
    actual = str(doc.metadata.get("provider", "")).strip().lower()
    return actual == wanted
@traceable(name="VectorRetriever")
def vector_search(query: str, provider: str | None = None, k: int = None, use_query_expansion: bool = True):
    """Run a similarity search against the FAISS vector store.

    Args:
        query: User query text.
        provider: If truthy, passed to the store as ``filter={"provider": provider}``
            and then re-checked locally with _match_provider.
        k: Documents requested per similarity search; defaults to DEFAULT_K_VECTOR.
        use_query_expansion: When True, search with every variation returned by
            expand_medical_query. The hits are merged in order, duplicates are
            dropped (by _get_doc_id), and at most ``k * 2`` are kept.

    Returns:
        List of documents, or ``[]`` if no vector store is loaded or any error
        occurs (the error is logged).
    """
    _ensure_initialized()
    if not _vector_store:
        return []
    top_k = DEFAULT_K_VECTOR if k is None else k
    # Only hand a metadata filter to the store when a provider was requested.
    filter_kwargs = {"filter": {"provider": provider}} if provider else {}
    try:
        if use_query_expansion:
            variations = expand_medical_query(query, strategy="adaptive", max_variations=3)
            logger.debug(f"Expanded query '{query}' into {len(variations)} variations")
            # Insertion-ordered dict: the first hit for each id wins.
            unique = {}
            for variant in variations:
                for hit in _vector_store.similarity_search(variant, k=top_k, **filter_kwargs):
                    unique.setdefault(_get_doc_id(hit), hit)
            results = list(unique.values())[:top_k * 2]  # expansion yields more candidates
        else:
            results = _vector_store.similarity_search(query, k=top_k, **filter_kwargs)
        # Post-filter in case the backend applies the metadata filter leniently.
        if provider:
            results = [hit for hit in results if _match_provider(hit, provider)]
        return results
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []
# Serialises the temporary override of the shared BM25 retriever's ``k``.
# Without it, concurrent searches (bm25_search runs on _retrieval_pool and may
# overlap with other requests) could clobber each other's result count.
_bm25_lock = threading.Lock()


def _bm25_retrieve(query: str, n: int) -> list:
    """Return up to ``max(1, n)`` BM25 matches for *query* from the shared retriever.

    ``k`` is overridden only for the duration of the call and then restored.
    Its configured value (set in _ensure_initialized) is therefore unchanged
    for other users of the same retriever object.
    """
    with _bm25_lock:
        previous_k = _bm25_retriever.k
        _bm25_retriever.k = max(1, n)
        try:
            return _bm25_retriever.get_relevant_documents(query) or []
        finally:
            _bm25_retriever.k = previous_k


@traceable(name="BM25Retriever")
def bm25_search(query: str, provider: str | None = None, k: int | None = None, use_query_expansion: bool = True):
    """Search the BM25 index, optionally expanding the query and filtering by provider.

    Args:
        query: User query text.
        provider: If truthy, keep only documents whose ``provider`` metadata
            matches (case- and whitespace-insensitive, see _match_provider).
        k: Result budget; defaults to DEFAULT_K_BM25. With query expansion, each
            variation fetches ``k * 2`` documents and up to ``k * 2`` merged
            results are returned. Without expansion, up to ``k`` are returned.
        use_query_expansion: Whether to search with expand_medical_query variations.

    Returns:
        List of documents, deduplicated by _get_doc_id when expansion is used.
        Returns ``[]`` if no BM25 retriever exists or any error occurs (errors
        are logged).
    """
    _ensure_initialized()
    # Use global default if k is not specified
    if k is None:
        k = DEFAULT_K_BM25
    try:
        if not _bm25_retriever:
            return []
        if use_query_expansion:
            query_variations = expand_medical_query(query, strategy="adaptive", max_variations=3)
            logger.debug(f"BM25: Expanded query '{query}' into {len(query_variations)} variations")
            limit = k * 2  # return more results due to expansion
            all_docs = []
            seen_ids = set()
            for var_query in query_variations:
                # Deduplicate while preserving order
                for doc in _bm25_retrieve(var_query, limit):
                    doc_id = _get_doc_id(doc)
                    if doc_id not in seen_ids:
                        seen_ids.add(doc_id)
                        all_docs.append(doc)
            docs = all_docs
        else:
            limit = k
            docs = _bm25_retrieve(query, limit)
        # Filter before truncating, so that matching documents ranked below
        # non-matching ones are not discarded by the slice.
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:limit]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []
def hybrid_search(query: str, provider: str | None = None, k_vector: int = None, k_bm25: int = None, use_query_expansion: bool = True):
    """Run vector and BM25 search concurrently and merge their results.

    Both searches run on _retrieval_pool with the same provider filter and
    query-expansion setting. Vector hits come first, followed by BM25 hits.
    A document whose _get_doc_id has already been seen is skipped, so the
    first occurrence wins.

    Args:
        query: User query text.
        provider: Optional provider filter forwarded to both searches.
        k_vector: ``k`` for vector_search; defaults to DEFAULT_K_VECTOR.
        k_bm25: ``k`` for bm25_search; defaults to DEFAULT_K_BM25.
        use_query_expansion: Forwarded to both searches.

    Returns:
        Deduplicated list of documents.
    """
    # Initialise on the calling thread before fanning out to the pool workers.
    _ensure_initialized()
    vector_k = DEFAULT_K_VECTOR if k_vector is None else k_vector
    bm25_k = DEFAULT_K_BM25 if k_bm25 is None else k_bm25

    pending = (
        _retrieval_pool.submit(vector_search, query, provider, vector_k, use_query_expansion),
        _retrieval_pool.submit(bm25_search, query, provider, bm25_k, use_query_expansion),
    )
    vector_hits, keyword_hits = (future.result() for future in pending)

    # Insertion-ordered dict keeps the first document seen for each id.
    by_id = {}
    for doc in [*vector_hits, *keyword_hits]:
        by_id.setdefault(_get_doc_id(doc), doc)
    merged = list(by_id.values())

    logger.info(f"Hybrid search returned {len(merged)} unique documents (query expansion: {use_query_expansion})")
    return merged