Spaces:
Running
Running
Enhance API security and functionality by adding authentication middleware and session management. Updated app.py to include the new auth router and integrated authentication checks for protected endpoints. Modified requirements.txt to include necessary libraries for session handling. Updated .env.example to include authentication credentials. Improved retrieval functions with query expansion for better medical term matching and enriched context in responses.
ddc9c77
| """ | |
| Context Enrichment Module for Medical RAG | |
| This module enriches retrieved documents with surrounding context (adjacent pages) | |
| to provide comprehensive information for expert medical professionals. | |
| """ | |
| from typing import List, Dict, Set, Optional | |
| from langchain.schema import Document | |
| from pathlib import Path | |
| from .config import logger | |
class ContextEnricher:
    """
    Enriches retrieved documents with surrounding pages for richer context.

    For each retrieved page, the enricher looks up the neighbouring pages of
    the same source document (loaded from the chunks cache) and emits them as
    separate Document objects so downstream prompting can present the
    surrounding material to expert medical professionals.
    """

    def __init__(self, cache_size: int = 100):
        """
        Initialize context enricher with document cache.

        Args:
            cache_size: Maximum number of source documents to cache.
        """
        # Full-document cache keyed by "provider_disease_source".
        # Evicted FIFO (oldest insertion first) once cache_size is reached.
        self._document_cache: Dict[str, List[Document]] = {}
        self._cache_size = cache_size

    @staticmethod
    def _to_page_number(value, default: Optional[int] = None) -> Optional[int]:
        """
        Best-effort conversion of a metadata page-number value to int.

        Args:
            value: Raw value from document metadata (int, numeric str, ...).
            default: Value returned when conversion fails.

        Returns:
            The integer page number, or ``default`` if ``value`` cannot be
            converted (e.g. None or a non-numeric string).
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            return default

    def enrich_documents(
        self,
        retrieved_docs: List[Document],
        pages_before: int = 1,
        pages_after: int = 1,
        max_enriched_docs: int = 5
    ) -> List[Document]:
        """
        Enrich retrieved documents by adding separate context pages.

        Args:
            retrieved_docs: List of retrieved documents.
            pages_before: Number of pages to include before each document.
            pages_after: Number of pages to include after each document.
            max_enriched_docs: Maximum number of documents to enrich (top results).

        Returns:
            List with original documents + separate context page documents.
            Documents beyond ``max_enriched_docs`` are passed through with
            default (empty) enrichment metadata.
        """
        if not retrieved_docs:
            return []
        result_docs: List[Document] = []
        processed_sources = set()
        enriched_count = 0
        # Only enrich top documents to avoid overwhelming context.
        for doc in retrieved_docs[:max_enriched_docs]:
            try:
                source = doc.metadata.get('source', 'unknown')
                page_num = doc.metadata.get('page_number', 1)
                # Skip if this source-page combination was already processed.
                source_page_key = f"{source}_{page_num}"
                if source_page_key in processed_sources:
                    continue
                processed_sources.add(source_page_key)
                # Collect the target page plus its neighbours.
                surrounding_docs = self._get_surrounding_pages(
                    doc,
                    pages_before,
                    pages_after
                )
                if surrounding_docs:
                    # Emit one separate document per surrounding page.
                    page_docs = self._create_separate_page_documents(
                        doc,
                        surrounding_docs,
                        pages_before,
                        pages_after
                    )
                    result_docs.extend(page_docs)
                    enriched_count += 1
                    # Use the safe converter here: raising after extend()
                    # would let the except branch append a duplicate doc.
                    page_numbers = [
                        self._to_page_number(d.metadata.get('page_number'), 0)
                        for d in page_docs
                    ]
                    logger.debug(f"Enriched {source} page {page_num} with pages: {page_numbers}")
                else:
                    # No surrounding pages found; keep the original document
                    # with default (empty) enrichment metadata.
                    result_docs.append(self._add_empty_enrichment_metadata(doc))
            except Exception as e:
                # Best-effort: never drop a retrieved document because
                # enrichment failed.
                logger.warning(f"Could not enrich document from {doc.metadata.get('source')}: {e}")
                result_docs.append(self._add_empty_enrichment_metadata(doc))
        # Add remaining documents without enrichment.
        for doc in retrieved_docs[max_enriched_docs:]:
            result_docs.append(self._add_empty_enrichment_metadata(doc))
        logger.info(f"Enriched {enriched_count} documents with surrounding context pages")
        return result_docs

    def _get_surrounding_pages(
        self,
        doc: Document,
        pages_before: int,
        pages_after: int
    ) -> List[Document]:
        """
        Get surrounding pages for a document.

        Args:
            doc: Original document.
            pages_before: Number of pages before.
            pages_after: Number of pages after.

        Returns:
            List of surrounding documents (including the original page),
            deduplicated by page number and sorted ascending. Empty list
            when the full document cannot be loaded.
        """
        source = doc.metadata.get('source', 'unknown')
        page_num = doc.metadata.get('page_number', 1)
        provider = doc.metadata.get('provider', 'unknown')
        disease = doc.metadata.get('disease', 'unknown')
        # Try to get the full document from cache or load it.
        full_doc_pages = self._get_full_document(source, provider, disease)
        if not full_doc_pages:
            return []
        # Fall back to page 1 when metadata holds an unparseable page number
        # (previously a non-numeric string raised and aborted enrichment).
        target_page = self._to_page_number(page_num, 1)
        # Deduplicate by page number, keeping the first occurrence.
        pages_dict: Dict[int, Document] = {}
        for page_doc in full_doc_pages:
            doc_page_num = self._to_page_number(
                page_doc.metadata.get('page_number', 0)
            )
            if doc_page_num is None:
                # Unparseable page number: cannot place it in the window.
                continue
            # Include pages within the requested range.
            if target_page - pages_before <= doc_page_num <= target_page + pages_after:
                pages_dict.setdefault(doc_page_num, page_doc)
        # Return sorted by page number.
        return [pages_dict[pn] for pn in sorted(pages_dict)]

    def _get_full_document(
        self,
        source: str,
        provider: str,
        disease: str
    ) -> Optional[List[Document]]:
        """
        Get full document pages from the chunks cache.

        Args:
            source: Source filename.
            provider: Provider name.
            disease: Disease name.

        Returns:
            List of all pages in the document sorted by page number,
            or None if not found or loading fails.
        """
        cache_key = f"{provider}_{disease}_{source}"
        # Check cache first.
        if cache_key in self._document_cache:
            return self._document_cache[cache_key]
        # Load from chunks cache instead of trying to reload PDFs.
        try:
            from . import utils
            all_chunks = utils.load_chunks()
            if not all_chunks:
                logger.debug("No chunks available for enrichment")
                return None
            # Filter chunks belonging to this specific document
            # (matched by source, provider and disease).
            doc_pages = [
                chunk for chunk in all_chunks
                if (chunk.metadata.get('source', '') == source and
                    chunk.metadata.get('provider', '') == provider and
                    chunk.metadata.get('disease', '') == disease)
            ]
            if not doc_pages:
                logger.debug(f"Could not find chunks for document: {source} (Provider: {provider}, Disease: {disease})")
                return None
            # Sort by page number; unparseable values sort as page 0 instead
            # of raising and discarding the whole document.
            doc_pages.sort(
                key=lambda d: self._to_page_number(d.metadata.get('page_number', 0), 0)
            )
            # Cache it with FIFO eviction once the size limit is reached.
            if len(self._document_cache) >= self._cache_size:
                self._document_cache.pop(next(iter(self._document_cache)))
            self._document_cache[cache_key] = doc_pages
            logger.debug(f"Loaded {len(doc_pages)} pages for {source} from chunks cache")
            return doc_pages
        except Exception as e:
            logger.warning(f"Error loading document from chunks cache {source}: {e}")
            return None

    def _create_separate_page_documents(
        self,
        original_doc: Document,
        surrounding_docs: List[Document],
        pages_before: int,
        pages_after: int
    ) -> List[Document]:
        """
        Create separate document objects for the original page and its
        context pages.

        Args:
            original_doc: Original retrieved document.
            surrounding_docs: List of surrounding documents.
            pages_before: Number of pages before (currently informational).
            pages_after: Number of pages after (currently informational).

        Returns:
            List of separate documents sorted by page number; context pages
            carry ``context_enrichment=True``, the original page ``False``.
        """
        # Sort by page number (safe conversion; bad values sort first).
        sorted_docs = sorted(
            surrounding_docs,
            key=lambda d: self._to_page_number(d.metadata.get('page_number', 0), 0)
        )
        original_page = self._to_page_number(
            original_doc.metadata.get('page_number', 1), 1
        )
        result_docs = []
        for doc in sorted_docs:
            page_num = self._to_page_number(doc.metadata.get('page_number', 0), 0)
            # A page differing from the original's number is context.
            is_context_page = (page_num != original_page)
            page_doc = Document(
                page_content=doc.page_content,
                metadata={
                    **doc.metadata,
                    'context_enrichment': is_context_page,
                    'enriched': False,
                    'pages_included': [],
                    'primary_page': None,
                    'context_pages_before': None,
                    'context_pages_after': None,
                }
            )
            result_docs.append(page_doc)
        return result_docs

    def _add_empty_enrichment_metadata(self, doc: Document) -> Document:
        """
        Add empty enrichment metadata fields to a document.

        Args:
            doc: Original document.

        Returns:
            Copy of the document with enrichment metadata fields set to
            default values.
        """
        return Document(
            page_content=doc.page_content,
            metadata={
                **doc.metadata,
                'enriched': False,
                'pages_included': [],
                'primary_page': None,
                'context_pages_before': None,
                'context_pages_after': None,
            }
        )
# Global enricher instance
# Module-level singleton shared by enrich_retrieved_documents();
# retrieve it directly via get_context_enricher().
_context_enricher = ContextEnricher(cache_size=100)
def enrich_retrieved_documents(
    documents: List[Document],
    pages_before: int = 1,
    pages_after: int = 1,
    max_enriched: int = 5
) -> List[Document]:
    """
    Convenience wrapper around the module-level ContextEnricher.

    Args:
        documents: Retrieved documents.
        pages_before: Number of pages to include before each document.
        pages_after: Number of pages to include after each document.
        max_enriched: Maximum number of documents to enrich.

    Returns:
        Enriched documents with surrounding context.
    """
    enricher = _context_enricher
    return enricher.enrich_documents(
        documents,
        pages_before=pages_before,
        pages_after=pages_after,
        max_enriched_docs=max_enriched,
    )
def get_context_enricher() -> ContextEnricher:
    """Return the shared module-level ContextEnricher singleton."""
    return _context_enricher