"""
MediGuard AI — Retrieve Node
Performs document retrieval using the best available backend:
1. Generic retriever (FAISS, OpenSearch wrapper, etc.)
2. OpenSearch hybrid search (BM25 + KNN)
3. BM25 keyword fallback
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def _hits_to_docs(raw_hits: list) -> list:
    """Convert raw OpenSearch hits into the graph's document dicts.

    Each hit's ``chunk_text`` field becomes ``content``; every other
    ``_source`` field is carried over as metadata.
    """
    return [
        {
            "content": h.get("_source", {}).get("chunk_text", ""),
            "metadata": {k: v for k, v in h.get("_source", {}).items() if k != "chunk_text"},
            "score": h.get("_score", 0.0),
        }
        for h in raw_hits
    ]


def retrieve_node(state: dict, *, context: Any) -> dict:
    """Retrieve documents using the best available backend.

    Priority:
        1. context.retriever (generic BaseRetriever — works with FAISS & OpenSearch)
        2. context.opensearch_client + context.embedding_service (hybrid search)
        3. BM25 keyword fallback
        4. Empty list

    Args:
        state: Graph state. Reads ``rewritten_query`` (preferred) or
            ``query``, and the ``retrieval_attempts`` counter.
        context: Holder of optional services — ``tracer``, ``cache``,
            ``retriever``, ``opensearch_client``, ``embedding_service``.
            Any of them may be absent or None; missing backends are skipped.

    Returns:
        dict with ``retrieved_documents`` (list of
        ``{"content", "metadata", "score"}`` dicts, possibly empty) and the
        incremented ``retrieval_attempts`` counter.
    """
    query = state.get("rewritten_query") or state.get("query", "")
    cache_key = f"retrieve:{query}"
    # Count this attempt regardless of which backend (or cache) serves it.
    attempts = state.get("retrieval_attempts", 0) + 1

    # getattr-guard every service so a context missing an attribute degrades
    # gracefully instead of raising AttributeError.
    tracer = getattr(context, "tracer", None)
    if tracer:
        tracer.trace(name="retrieve_node", metadata={"query": query})

    # 1. Serve from cache when possible.
    cache = getattr(context, "cache", None)
    if cache:
        cached = cache.get(cache_key)
        if cached is not None:
            logger.info("Cache HIT for query: %s…", query[:50])
            return {"retrieved_documents": cached, "retrieval_attempts": attempts}

    documents: list = []

    # 2. Generic retriever (FAISS, OpenSearch wrapper, etc.)
    retriever = getattr(context, "retriever", None)
    if retriever is not None:
        try:
            results = retriever.retrieve(query, top_k=8)
            documents = [
                {
                    "content": getattr(r, "content", ""),
                    "metadata": getattr(r, "metadata", {}),
                    "score": getattr(r, "score", 0.0),
                }
                for r in results
            ]
            backend = getattr(retriever, "backend_name", "unknown")
            logger.info("Retrieved %d docs via %s", len(documents), backend)
        except Exception as exc:
            logger.warning("Retriever failed (%s), trying OpenSearch fallback…", exc)

    # 3. OpenSearch hybrid fallback (BM25 + KNN over the query embedding).
    opensearch = getattr(context, "opensearch_client", None)
    embedder = getattr(context, "embedding_service", None)
    if not documents and opensearch and embedder:
        try:
            embedding = embedder.embed_query(query)
            raw_hits = opensearch.search_hybrid(
                query_text=query,
                query_vector=embedding,
                top_k=8,
            )
            documents = _hits_to_docs(raw_hits)
            logger.info("Retrieved %d docs via OpenSearch hybrid", len(documents))
        except Exception as exc:
            logger.error("OpenSearch retrieval failed: %s", exc)

    # 4. BM25 keyword fallback if still nothing (covers a missing embedder too).
    if not documents and opensearch:
        try:
            raw_hits = opensearch.search_bm25(query_text=query, top_k=8)
            documents = _hits_to_docs(raw_hits)
            logger.info("Retrieved %d docs via BM25 fallback", len(documents))
        except Exception as exc:
            logger.error("BM25 fallback also failed: %s", exc)

    # 5. Cache non-empty results (5 min TTL); never cache an empty miss.
    if cache and documents:
        cache.set(cache_key, documents, ttl=300)

    return {"retrieved_documents": documents, "retrieval_attempts": attempts}
|