Spaces:
Sleeping
Sleeping
Enhance API security and functionality by adding authentication middleware and session management. Updated app.py to include the new auth router and integrated authentication checks for protected endpoints. Modified requirements.txt to include necessary libraries for session handling. Updated .env.example to include authentication credentials. Improved retrieval functions with query expansion for better medical term matching and enriched context in responses.
ddc9c77
| from concurrent.futures import ThreadPoolExecutor | |
| from typing import List, Optional | |
| from . import utils | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain.schema import Document | |
| from .config import logger | |
| from .tracing import traceable | |
| from .query_expansion import expand_medical_query, MultiQueryRetriever | |
# Global configuration for retrieval parameters
DEFAULT_K_VECTOR = 3  # Number of documents to retrieve from vector search
DEFAULT_K_BM25 = 2  # Number of documents to retrieve from BM25 search
# Global variables for lazy loading; all populated exactly once by
# _ensure_initialized() on first retrieval call.
_vector_store = None  # vector store built/loaded via utils (FAISS-backed per vector_search)
_company_chunks = None  # cached document chunks backing the BM25 retriever
_vector_retriever = None  # retriever view over _vector_store
_bm25_retriever = None  # BM25Retriever built from _company_chunks
_hybrid_retriever = None  # EnsembleRetriever, or single-retriever fallback
_initialized = False  # guard flag so initialization runs only once
def _ensure_initialized():
    """Initialize retrievers on first use (lazy loading for faster startup).

    Builds, in order: the vector store (processing new data first, falling back
    to an existing store, then to cached chunks), the chunk cache for BM25, the
    two individual retrievers, and finally the hybrid ensemble. Idempotent:
    returns immediately once ``_initialized`` is True.

    Raises:
        RuntimeError: if neither a vector nor a BM25 retriever could be built.
        Exception: re-raises any failure from vector-store preparation or
            chunk loading after logging it.
    """
    global _vector_store, _company_chunks, _vector_retriever, _bm25_retriever, _hybrid_retriever, _initialized
    if _initialized:
        return
    logger.info("π Initializing retrievers (first time use)...")
    # Process any new data and update vector store and chunks cache
    try:
        logger.info("π Processing new data and updating vector store if needed...")
        _vector_store = utils.process_new_data_and_update_vector_store()
        if _vector_store is None:
            # Fall back to load existing if processing found no new files
            _vector_store = utils.load_company_vector_store()
        if _vector_store is None:
            # As a last resort, create from whatever is already in cache (if any)
            logger.info("βΉοΈ No vector store found; attempting creation from cached chunks...")
            cached_chunks = utils.load_chunks() or []
            if cached_chunks:
                _vector_store = utils.create_company_vector_store(cached_chunks)
                logger.info("β Vector store created from cached chunks")
            else:
                # Deliberately non-fatal here: the final else below raises if
                # BM25 also fails, so an empty deployment gets one clear error.
                logger.warning("β οΈ No data available to build a vector store. Retrievers may not function until data is provided.")
    except Exception as e:
        logger.error(f"Error preparing vector store: {str(e)}")
        raise
    # Load merged chunks for BM25 (includes previous + new)
    try:
        logger.info("π¦ Loading chunks cache for BM25 retriever...")
        _company_chunks = utils.load_chunks() or []
        if not _company_chunks:
            logger.warning("β οΈ No chunks available for BM25 retriever. BM25 will be empty until data is processed.")
    except Exception as e:
        logger.error(f"Error loading chunks: {str(e)}")
        raise
    # Create vector retriever
    logger.info("π Creating vector retriever...")
    # NOTE(review): k=5 here differs from DEFAULT_K_VECTOR (3); the module-level
    # search helpers below bypass this retriever anyway — confirm intentional.
    _vector_retriever = _vector_store.as_retriever(search_kwargs={"k": 5}) if _vector_store else None
    # Create BM25 retriever
    logger.info("π Creating BM25 retriever...")
    _bm25_retriever = BM25Retriever.from_documents(_company_chunks) if _company_chunks else None
    if _bm25_retriever:
        _bm25_retriever.k = 5
    # Create hybrid retriever
    logger.info("π Creating hybrid retriever...")
    if _vector_retriever and _bm25_retriever:
        # Weighted ensemble: vector similarity dominates (0.8); BM25 adds keyword recall (0.2).
        _hybrid_retriever = EnsembleRetriever(
            retrievers=[_bm25_retriever, _vector_retriever],
            weights=[0.2, 0.8]
        )
    elif _vector_retriever:
        logger.warning("βΉοΈ BM25 retriever unavailable; using vector retriever only.")
        _hybrid_retriever = _vector_retriever
    elif _bm25_retriever:
        _hybrid_retriever = _bm25_retriever
    else:
        raise RuntimeError("Neither vector or BM25 retrievers could be initialized. Provide data under data/new_data and retry.")
    _initialized = True
    logger.info("β Retrievers initialized successfully.")
def initialize_eagerly():
    """Build all retrievers now instead of on first query.

    Intended for background warm-up at application startup so the first
    user-facing request does not pay the initialization cost.
    """
    _ensure_initialized()
def is_initialized() -> bool:
    """Report whether the lazy retriever setup has already completed."""
    return bool(_initialized)
# -----------------------------------------------
# Provider-aware retrieval helper functions
# -----------------------------------------------
# Shared worker pool so hybrid_search can run the vector and BM25 searches
# concurrently; module-level so it is created once and reused across requests.
_retrieval_pool = ThreadPoolExecutor(max_workers=4)
def _get_doc_id(doc: Document) -> str:
    """Build an in-process identifier for *doc*, used to deduplicate results.

    Combines source, page number, and a hash of the leading content so that
    the same chunk retrieved via different query variations collapses to one
    entry. Note: relies on builtin hash(), so IDs are only stable within a
    single process.
    """
    meta = doc.metadata
    origin = meta.get('source', 'unknown')
    page = meta.get('page_number', 'unknown')
    # Hashing only the first 200 chars is enough to tell chunks apart cheaply.
    content_hash = hash(doc.page_content[:200])
    return f"{origin}_{page}_{content_hash}"
def _match_provider(doc, provider: str) -> bool:
    """Return True when *doc*'s provider metadata matches *provider*.

    Comparison is case-insensitive and ignores surrounding whitespace.
    A falsy provider (None or empty string) acts as a wildcard and
    matches every document.
    """
    if not provider:
        return True
    wanted = provider.strip().lower()
    found = str(doc.metadata.get("provider", "")).strip().lower()
    return found == wanted
def vector_search(query: str, provider: str | None = None, k: int = None, use_query_expansion: bool = True):
    """Search FAISS vector store with optional provider metadata filter and query expansion.

    Args:
        query: the user query text.
        provider: if given, restrict results to this provider (backend filter
            plus a defensive post-filter).
        k: per-query result count; defaults to DEFAULT_K_VECTOR. With
            expansion enabled, up to 2*k merged results are returned.
        use_query_expansion: expand the query into medical-term variations
            and merge deduplicated results across them.

    Returns:
        A list of matching documents; [] on error or when no store exists.
    """
    _ensure_initialized()
    if not _vector_store:
        return []
    # Fall back to the module-wide default when the caller gave no k.
    k = DEFAULT_K_VECTOR if k is None else k
    try:
        def _search(text):
            # Route through the backend metadata filter when a provider is set.
            kwargs = {"filter": {"provider": provider}} if provider else {}
            return _vector_store.similarity_search(text, k=k, **kwargs)

        if use_query_expansion:
            # Expand into medical-term variations for better recall.
            variations = expand_medical_query(query, strategy="adaptive", max_variations=3)
            logger.debug(f"Expanded query '{query}' into {len(variations)} variations")
            merged, seen = [], set()
            for variation in variations:
                for doc in _search(variation):
                    ident = _get_doc_id(doc)
                    if ident in seen:
                        continue  # keep first occurrence, preserve order
                    seen.add(ident)
                    merged.append(doc)
            docs = merged[:k * 2]  # expansion earns a larger result budget
        else:
            docs = _search(query)
        # Defensive post-filter in case the backend filter is lenient.
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs
    except Exception as e:
        logger.error(f"Vector search failed: {e}")
        return []
def bm25_search(query: str, provider: str | None = None, k: int = None, use_query_expansion: bool = True):
    """Search BM25 using the global retriever with query expansion and optional provider filter.

    Args:
        query: the user query text.
        provider: if given, keep only documents whose provider metadata matches.
        k: per-query result count; defaults to DEFAULT_K_BM25. With expansion
            enabled, up to 2*k merged results are returned.
        use_query_expansion: expand the query into medical-term variations
            and merge deduplicated results across them.

    Returns:
        A list of matching documents; [] on error or when no BM25 index exists.
    """
    _ensure_initialized()
    # Fall back to the module-wide default when the caller gave no k.
    k = DEFAULT_K_BM25 if k is None else k
    try:
        if not _bm25_retriever:
            return []
        if use_query_expansion:
            # Expand into medical-term variations for better recall.
            variations = expand_medical_query(query, strategy="adaptive", max_variations=3)
            logger.debug(f"BM25: Expanded query '{query}' into {len(variations)} variations")
            collected, seen = [], set()
            for variation in variations:
                # Over-fetch per variation; the dedup below trims overlap.
                _bm25_retriever.k = max(1, k * 2)
                for doc in _bm25_retriever.get_relevant_documents(variation) or []:
                    ident = _get_doc_id(doc)
                    if ident in seen:
                        continue  # keep first occurrence, preserve order
                    seen.add(ident)
                    collected.append(doc)
            docs = collected[:k * 2]
            limit = k * 2  # expansion earns a larger result budget
        else:
            _bm25_retriever.k = max(1, k)
            docs = _bm25_retriever.get_relevant_documents(query) or []
            limit = k
        if provider:
            docs = [d for d in docs if _match_provider(d, provider)]
        return docs[:limit]
    except Exception as e:
        logger.error(f"BM25 search failed: {e}")
        return []
def hybrid_search(query: str, provider: str | None = None, k_vector: int = None, k_bm25: int = None, use_query_expansion: bool = True):
    """Combine vector and BM25 results with query expansion (provider-filtered if provided).

    Runs both searches concurrently on the shared thread pool, then merges
    them — vector results first — dropping duplicates by document ID.

    Args:
        query: the user query text.
        provider: optional provider filter forwarded to both searches.
        k_vector: vector-search result count (default DEFAULT_K_VECTOR).
        k_bm25: BM25 result count (default DEFAULT_K_BM25).
        use_query_expansion: forwarded to both searches.

    Returns:
        A deduplicated list of documents from both retrieval paths.
    """
    # Initialize up front so the two workers never race on lazy init.
    _ensure_initialized()
    k_vector = DEFAULT_K_VECTOR if k_vector is None else k_vector
    k_bm25 = DEFAULT_K_BM25 if k_bm25 is None else k_bm25
    # Fan out both searches concurrently; order of this list fixes merge priority.
    futures = [
        _retrieval_pool.submit(vector_search, query, provider, k_vector, use_query_expansion),
        _retrieval_pool.submit(bm25_search, query, provider, k_bm25, use_query_expansion),
    ]
    candidates = [doc for future in futures for doc in future.result()]
    # Merge uniquely by document ID, first occurrence wins.
    seen = set()
    merged = []
    for doc in candidates:
        ident = _get_doc_id(doc)
        if ident not in seen:
            seen.add(ident)
            merged.append(doc)
    logger.info(f"Hybrid search returned {len(merged)} unique documents (query expansion: {use_query_expansion})")
    return merged