"""
MediGuard AI — Retrieve Node

Performs document retrieval using the best available backend:
1. Generic retriever (FAISS, OpenSearch wrapper, etc.)
2. OpenSearch hybrid search (BM25 + KNN)
3. BM25 keyword fallback
"""

from __future__ import annotations

import logging
from typing import Any

logger = logging.getLogger(__name__)

# Number of documents requested from every backend.
_TOP_K = 8
# Lifetime of a cached result set, in seconds (5 minutes).
_CACHE_TTL_SECONDS = 300


def _hits_to_documents(raw_hits: Any) -> list[dict]:
    """Convert raw search hits into the node's document dictionaries.

    Each hit is read through ``hit.get``: the ``"_source"`` mapping supplies
    ``"chunk_text"`` (used as the content) and every other key becomes
    metadata; ``"_score"`` supplies the score.  Missing pieces fall back to
    ``""``, ``{}`` and ``0.0`` respectively.
    """
    documents: list[dict] = []
    for hit in raw_hits:
        source = hit.get("_source", {})
        documents.append(
            {
                "content": source.get("chunk_text", ""),
                "metadata": {k: v for k, v in source.items() if k != "chunk_text"},
                "score": hit.get("_score", 0.0),
            }
        )
    return documents


def retrieve_node(state: dict, *, context: Any) -> dict:
    """Retrieve documents for the current query using the best available backend.

    Order of attempts (each later step runs only if no documents were found):
        1. ``context.cache`` lookup — a hit is returned immediately.
        2. ``context.retriever.retrieve(query, top_k=...)``.
        3. ``context.opensearch_client.search_hybrid(...)`` using a vector from
           ``context.embedding_service.embed_query(query)``.
        4. ``context.opensearch_client.search_bm25(...)``.
        5. Empty list.

    Every context attribute is optional: a missing or falsy attribute simply
    disables the corresponding step.  Backend exceptions are logged and the
    next backend is tried; cache and tracer exceptions are not caught here.

    Args:
        state: Workflow state. Reads ``"rewritten_query"`` (preferred),
            ``"query"`` and ``"retrieval_attempts"``.
        context: Object carrying the optional ``tracer``, ``cache``,
            ``retriever``, ``opensearch_client`` and ``embedding_service``
            attributes.

    Returns:
        A dict with ``"retrieved_documents"`` (list of dicts with ``content``,
        ``metadata`` and ``score`` keys, or whatever the cache returned) and
        ``"retrieval_attempts"`` (previous count + 1).
    """
    # A "query" key that is present but None must not become the string "None".
    query = state.get("rewritten_query") or state.get("query") or ""
    cache_key = f"retrieve:{query}"
    attempts = state.get("retrieval_attempts", 0) + 1

    # Read every optional collaborator the same way so that a context which
    # omits one of them degrades gracefully instead of raising AttributeError.
    tracer = getattr(context, "tracer", None)
    cache = getattr(context, "cache", None)
    retriever = getattr(context, "retriever", None)
    opensearch = getattr(context, "opensearch_client", None)
    embedder = getattr(context, "embedding_service", None)

    if tracer:
        tracer.trace(name="retrieve_node", metadata={"query": query})

    # 1. Try cache
    if cache:
        cached = cache.get(cache_key)
        if cached is not None:
            logger.info("Cache HIT for query: %s…", query[:50])
            return {"retrieved_documents": cached, "retrieval_attempts": attempts}

    documents: list = []

    # 2. Generic retriever (FAISS, OpenSearch wrapper, etc.)
    if retriever is not None:
        try:
            results = retriever.retrieve(query, top_k=_TOP_K)
            documents = [
                {
                    "content": getattr(r, "content", ""),
                    "metadata": getattr(r, "metadata", {}),
                    "score": getattr(r, "score", 0.0),
                }
                for r in results
            ]
            backend = getattr(retriever, "backend_name", "unknown")
            logger.info("Retrieved %d docs via %s", len(documents), backend)
        except Exception as exc:
            logger.warning("Retriever failed (%s), trying OpenSearch fallback…", exc)

    # 3. OpenSearch hybrid fallback
    if not documents and opensearch and embedder:
        try:
            embedding = embedder.embed_query(query)
            raw_hits = opensearch.search_hybrid(
                query_text=query,
                query_vector=embedding,
                top_k=_TOP_K,
            )
            documents = _hits_to_documents(raw_hits)
            logger.info("Retrieved %d docs via OpenSearch hybrid", len(documents))
        except Exception as exc:
            logger.error("OpenSearch retrieval failed: %s", exc)

    # 4. Optional BM25 fallback if still nothing
    if not documents and opensearch:
        try:
            raw_hits = opensearch.search_bm25(query_text=query, top_k=_TOP_K)
            documents = _hits_to_documents(raw_hits)
            logger.info("Retrieved %d docs via BM25 fallback", len(documents))
        except Exception as exc:
            logger.error("BM25 fallback also failed: %s", exc)

    # 5. Store in cache; empty results are not cached so a later call retries.
    if cache and documents:
        cache.set(cache_key, documents, ttl=_CACHE_TTL_SECONDS)

    return {"retrieved_documents": documents, "retrieval_attempts": attempts}