"""
Medical Q&A UI - BM25 + Dense Retrieval Models WITH DISK CACHING

This version caches the indexes to disk for fast startup
(30 seconds vs 5-8 minutes!).

All startup work (corpus build check, document loading, index construction)
runs at import time; the Gradio app is exposed as ``demo`` and launched when
this module is executed as a script.
"""

import hashlib
import html
import json
import os
import pickle
import subprocess
import sys
import traceback
from pathlib import Path
from typing import Dict, List

import gradio as gr

from retriever.index_bm25 import BM25Index
from retriever.index_dense import DenseIndex
from retriever.ingest import load_jsonl
from retriever.rrf import rrf
from team.interfaces import Candidate

# Cache directory for the pickled document list and BM25 index.
CACHE_DIR = Path("cache")
CACHE_DIR.mkdir(exist_ok=True)

# Ensure embeddings cache directory exists (for Dense index)
EMBEDDINGS_CACHE_DIR = Path(".cache/embeddings")
EMBEDDINGS_CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Constant of Reciprocal Rank Fusion used to recompute per-candidate scores:
# score = sum over retrievers of 1 / (RRF_K + rank), with 1-based ranks.
RRF_K = 60

# Maximum number of characters of an answer/document shown in a result card.
PREVIEW_CHARS = 500

print("=" * 70)
print(" Medical Document Retrieval System (CACHED VERSION)")
print(" Using BM25 + Dense Embeddings + RRF Fusion")
print(" With disk caching for fast startup!")
print("=" * 70)


def _ensure_corpora_exist():
    """Build the corpora JSONL files if any of them is missing.

    The required files are the paths listed in ``_default_corpora_config``.
    When at least one is absent, ``adapters/build_corpora.py`` is run with the
    current interpreter (relative to the working directory); its output goes
    straight to the console.

    Raises:
        RuntimeError: if the build script exits with a non-zero status.
    """
    required_files = [Path(c["path"]) for c in _default_corpora_config().values()]

    if all(f.exists() for f in required_files):
        return  # All files exist

    print("\n" + "=" * 70)
    print("⚠️ Corpora files not found. Building them now...")
    print(" This will take 2-3 minutes on first launch.")
    print("=" * 70 + "\n")

    try:
        # Run build_corpora.py
        subprocess.run(
            [sys.executable, "adapters/build_corpora.py"],
            check=True,
            capture_output=False,
        )
    except subprocess.CalledProcessError as e:
        print(f"\n✗ Failed to build corpora: {e}")
        raise RuntimeError(
            "Could not build corpora files. "
            "Please run 'python adapters/build_corpora.py' manually."
        ) from e
    print("\n✓ Corpora files built successfully!\n")


def _default_corpora_config() -> Dict[str, dict]:
    """Return the corpus name -> {path, text_fields} configuration."""
    return {
        "medical_qa": {"path": "data/corpora/medical_qa.jsonl", "text_fields": ["question", "answer", "title"]},
        "miriad": {"path": "data/corpora/miriad_text.jsonl", "text_fields": ["question", "answer", "title"]},
        "unidoc": {"path": "data/corpora/unidoc_qa.jsonl", "text_fields": ["question", "answer", "title"]},
    }


def _available(cfg: Dict[str, dict]) -> Dict[str, dict]:
    """Return the subset of *cfg* whose ``path`` exists on disk."""
    return {k: v for k, v in cfg.items() if Path(v["path"]).exists()}


def _get_cache_key(corpora_config: Dict[str, dict]) -> str:
    """Return a cache key for *corpora_config* and the current corpus files.

    The key covers the config itself plus each corpus file's size and
    modification time, so rebuilding a corpus invalidates the cached
    documents/BM25 index instead of silently serving stale data. A file that
    cannot be stat'ed contributes ``None``.
    """
    fingerprint = {}
    for name, cfg in corpora_config.items():
        try:
            st = Path(cfg["path"]).stat()
            file_stamp = [st.st_size, st.st_mtime_ns]
        except OSError:
            file_stamp = None
        fingerprint[name] = {"config": cfg, "file": file_stamp}
    config_str = json.dumps(fingerprint, sort_keys=True)
    return hashlib.md5(config_str.encode()).hexdigest()


def _load_pickle(path: Path, label: str):
    """Return the object pickled at *path*, or ``None`` if absent/unreadable.

    A load failure is reported and treated as a cache miss so the caller
    rebuilds. NOTE: ``pickle.load`` can execute arbitrary code; this is only
    acceptable because the file is written by this app into its own local
    cache directory. Never point it at untrusted files.
    """
    if not path.exists():
        return None
    print(f"Loading {label} from cache... ({path.name})")
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except Exception as e:  # any unpickling problem just means "rebuild"
        print(f" ✗ Cache load failed: {e}")
        print(f" → Rebuilding {label}...")
        return None


def _save_pickle(obj, path: Path) -> None:
    """Pickle *obj* to *path* atomically; failures are reported, not raised.

    Data is written to a sibling ``.tmp`` file and then renamed over *path*,
    so an interrupted write never leaves a truncated cache file behind. The
    cache is only an optimisation: the caller already holds *obj* in memory,
    so a write failure (e.g. read-only disk) must not abort startup.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        tmp_path.replace(path)
    except Exception as e:
        print(f" ✗ Could not write cache {path.name}: {e}")
        try:
            tmp_path.unlink()
        except OSError:
            pass


class CachedRetriever:
    """Retriever with disk caching for BM25 and documents.

    The document list and BM25 index are pickled under ``CACHE_DIR`` with file
    names derived from ``_get_cache_key``. The Dense index is constructed on
    every start and relies on its own cache in ``.cache/embeddings/``.
    """

    def __init__(self, corpora_config: Dict[str, dict], use_reranker: bool = False):
        """Load (or build and cache) documents, BM25 and Dense indexes.

        Args:
            corpora_config: corpus name -> {"path", optional "text_fields"}.
            use_reranker: stored on the instance; not used in this module.
        """
        self.corpora_config = corpora_config
        self.use_reranker = use_reranker
        self.cache_key = _get_cache_key(corpora_config)

        # Cache file paths
        self.bm25_cache = CACHE_DIR / f"bm25_{self.cache_key}.pkl"
        self.docs_cache = CACHE_DIR / f"docs_{self.cache_key}.pkl"
        # Note: Dense index uses its own caching in .cache/embeddings/

        # Order matters: both indexes are built from self.docs_all.
        self.docs_all = self._load_or_build_docs()
        self.bm25 = self._load_or_build_bm25()
        self.dense = self._load_or_build_dense()

    def _load_or_build_docs(self) -> List:
        """Return all documents, from the pickle cache or by reading JSONL."""
        docs_all = _load_pickle(self.docs_cache, "documents")
        if isinstance(docs_all, list):
            print(f" ✓ Loaded {len(docs_all)} documents from cache")
            return docs_all

        print("Building documents from corpora files...")
        docs_all = []
        for name, cfg in self.corpora_config.items():
            print(f" Loading {name}...")
            text_fields = tuple(cfg.get("text_fields", ("question", "answer")))
            docs_all.extend(load_jsonl(cfg["path"], text_fields))

        print(f"Saving documents to cache... ({len(docs_all)} docs)")
        _save_pickle(docs_all, self.docs_cache)
        return docs_all

    def _load_or_build_bm25(self) -> BM25Index:
        """Return the BM25 index, from the pickle cache or built from docs."""
        bm25_index = _load_pickle(self.bm25_cache, "BM25 index")
        if isinstance(bm25_index, BM25Index):
            print(" ✓ BM25 index loaded from cache")
            return bm25_index

        print("Building BM25 index from scratch...")
        bm25_index = BM25Index(self.docs_all)

        print("Saving BM25 index to cache...")
        _save_pickle(bm25_index, self.bm25_cache)
        return bm25_index

    def _load_or_build_dense(self) -> DenseIndex:
        """Build the Dense index (it handles its own embedding cache)."""
        print("Initializing Dense index (uses internal caching)...")
        # DenseIndex has its own caching system in .cache/embeddings/,
        # so it is not pickled here.
        dense_index = DenseIndex(self.docs_all)
        print(" ✓ Dense index ready")
        return dense_index


# Ensure corpora files exist (auto-build if missing)
_ensure_corpora_exist()

# Initialize cached retriever (fast if cached, slow first time)
print("\nInitializing retrieval system...")
cfg = _available(_default_corpora_config())
if not cfg:
    raise RuntimeError("No corpora files found in data/corpora. Build them first.")
retriever = CachedRetriever(corpora_config=cfg, use_reranker=False)

print("\n✓ Retrieval system ready!")
print(f" Total documents indexed: {len(retriever.docs_all):,}")
print("=" * 70)


def get_candidates_cached(query: str, k_retrieve: int = 50) -> List[Candidate]:
    """
    Returns top-N fused candidates with component scores (bm25, dense, rrf).
    Uses the cached retriever for fast queries.

    Each retriever is asked for at least 100 hits before fusion.
    ``k_retrieve`` is coerced to ``int`` (UI sliders may deliver floats); a
    non-positive value returns an empty list.

    Returns:
        Up to ``k_retrieve`` candidates sorted by descending RRF score. A
        document absent from one retriever's list gets 0.0 for that
        component's score and no RRF contribution from it.
    """
    k_retrieve = int(k_retrieve)
    if k_retrieve <= 0:
        return []
    pool_size = max(k_retrieve, 100)

    # Separate result lists of (doc, score), best first.
    bm = retriever.bm25.search(query, k=pool_size)
    de = retriever.dense.search(query, k=pool_size)

    # Score and 0-based rank lookups per document id.
    bm_map = {d.id: float(s) for d, s in bm}
    de_map = {d.id: float(s) for d, s in de}
    bm_rank = {d.id: i for i, (d, _) in enumerate(bm)}
    de_rank = {d.id: i for i, (d, _) in enumerate(de)}

    # Fuse and pick candidate set
    fused = rrf([bm, de], k=max(k_retrieve, 50))

    out: List[Candidate] = []
    for doc, _ in fused[:k_retrieve]:
        rrf_score = sum(
            1.0 / (RRF_K + ranks[doc.id] + 1)
            for ranks in (bm_rank, de_rank)
            if doc.id in ranks
        )
        out.append(Candidate(
            id=doc.id,
            title=doc.title or "",
            text=doc.text,
            meta=doc.meta or {},
            bm25=bm_map.get(doc.id, 0.0),
            dense=de_map.get(doc.id, 0.0),
            rrf=rrf_score,
        ))

    # Baseline order: RRF
    out.sort(key=lambda c: c.rrf, reverse=True)
    return out


_HOW_TO_USE_HTML = """
<div style="padding: 16px;">
  <h3>How to Use</h3>
  <p>Enter a medical query and we'll find relevant documents using BM25 + Dense retrieval with RRF fusion.</p>
  <p><strong>Example:</strong> "headache with blurred vision" or "symptoms of diabetes"</p>
</div>
"""

_NO_RESULTS_HTML = """
<div style="padding: 16px;">
  <h3>No Results Found</h3>
  <p>Try rephrasing your query or using different medical terms.</p>
</div>
"""


def _preview(text) -> str:
    """Return *text* cut to PREVIEW_CHARS characters, with "..." if cut."""
    text = str(text or "")
    return text[:PREVIEW_CHARS] + ("..." if len(text) > PREVIEW_CHARS else "")


def _render_hit(rank: int, hit: Candidate) -> str:
    """Return the HTML card for one candidate; all corpus text is escaped."""
    meta = hit.meta or {}
    title = hit.title if hit.title and hit.title.strip() else None
    source = meta.get("source", "Unknown")

    # Prefer separate question/answer fields from metadata when both exist.
    question = meta.get("question", "")
    answer = meta.get("answer", "")
    if question and answer:
        content_html = (
            "<p><strong>Question:</strong></p>"
            f"<p>{html.escape(str(question))}</p>"
            "<p><strong>Answer:</strong></p>"
            f"<p>{html.escape(_preview(answer))}</p>"
        )
    else:
        # Fallback to combined text
        content_html = f"<p>{html.escape(_preview(hit.text))}</p>"

    # Build title HTML only if title exists
    title_html = f"<h4>{html.escape(title)}</h4>" if title else ""

    return f"""
<div style="border: 1px solid #ddd; border-radius: 8px; padding: 12px; margin-bottom: 12px;">
  <div><strong>Document #{rank}</strong> &middot; <span>{html.escape(str(source))}</span></div>
  {title_html}
  {content_html}
  <div>
    <span><strong>BM25</strong> {hit.bm25:.4f}</span> &nbsp;
    <span><strong>Dense</strong> {hit.dense:.4f}</span> &nbsp;
    <span><strong>RRF Fusion</strong> {hit.rrf:.4f}</span>
  </div>
</div>
"""


def retrieve_documents(query, num_results=5):
    """Retrieve relevant medical documents using your team's models.

    Args:
        query: free-text query from the UI textbox.
        num_results: number of documents to show (coerced to ``int``).

    Returns:
        An HTML string for ``gr.HTML``: usage help for an empty query, a
        "no results" notice, the result cards, or an error box. Corpus text
        and error messages are HTML-escaped before being embedded.
    """
    if not query or not query.strip():
        return _HOW_TO_USE_HTML

    try:
        # Use cached retrieval system (fast!)
        hits = get_candidates_cached(query=query, k_retrieve=int(num_results))
        if not hits:
            return _NO_RESULTS_HTML

        parts = [f"""
<div style="padding: 8px 0;">
  <h3>Found {len(hits)} Relevant Medical Documents</h3>
  <p>Retrieved using: BM25 + Dense Embeddings + RRF Fusion (CACHED)</p>
</div>
"""]
        parts.extend(_render_hit(i, hit) for i, hit in enumerate(hits, 1))
        return "".join(parts)

    except Exception as e:  # UI boundary: show the error instead of crashing
        traceback.print_exc()
        return f"""
<div style="padding: 16px; border: 1px solid #c00; border-radius: 8px;">
  <h3>Error</h3>
  <p>{html.escape(str(e))}</p>
</div>
"""


# Create Gradio interface
with gr.Blocks(title="Medical Document Retrieval (Cached)") as demo:
    gr.Markdown("""
# Medical Document Retrieval System (CACHED VERSION)

**Models:**
- BM25 Index (keyword-based retrieval)
- Dense Embeddings (embeddinggemma-300m-medical)
- RRF Fusion (combines both approaches)

### Features:
- Searches across 10,000+ medical documents
- Shows relevance scores from each model component
- Returns the most relevant medical information
""")

    with gr.Row():
        with gr.Column():
            query_input = gr.Textbox(
                label="Enter your medical query",
                placeholder="Example: headache with blurred vision",
                lines=2,
            )
            num_results = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                label="Number of results to retrieve",
            )
            submit_btn = gr.Button("Retrieve Documents", variant="primary", size="lg")

    output_html = gr.HTML(label="Search Results")

    submit_btn.click(
        fn=retrieve_documents,
        inputs=[query_input, num_results],
        outputs=output_html,
    )

    gr.Examples(
        examples=[
            "headache with blurred vision",
            "symptoms of diabetes",
            "chest pain when exercising",
            "treatment for high blood pressure",
            "causes of chronic fatigue",
        ],
        inputs=query_input,
        label="Try these example queries:",
    )

    gr.Markdown("""
---
### Technical Details
- **BM25**: Statistical keyword matching (TF-IDF based)
- **Dense**: Semantic search using transformer embeddings
- **RRF Fusion**: Reciprocal Rank Fusion combines both methods
- **Caching**: Documents and the BM25 index are saved to disk in the `cache/` folder (dense embeddings in `.cache/embeddings/`) for fast reloading

*Note: First launch builds and caches indexes (5-8 min). After that, startup takes only ~30 seconds!*
""")

print("\nOpening web interface...")
print("=" * 70)

if __name__ == "__main__":
    # Auto-detect environment: HuggingFace Spaces vs local
    is_spaces = os.getenv("SPACE_ID") is not None

    if is_spaces:
        # HuggingFace Spaces: listen on all interfaces, default port
        demo.launch(server_name="0.0.0.0", server_port=7860)
    else:
        # Local: standard config
        print(" Local access: http://127.0.0.1:7863")
        demo.launch(server_name="127.0.0.1", server_port=7863)