T0X1N's picture
chore: codebase audit and fixes (ruff, mypy, pytest)
9659593
"""
MediGuard AI — Retrieve Node
Performs document retrieval using the best available backend:
1. Generic retriever (FAISS, OpenSearch wrapper, etc.)
2. OpenSearch hybrid search (BM25 + KNN)
3. BM25 keyword fallback
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
# Number of documents requested from every backend.
_RETRIEVE_TOP_K = 8
# Value passed as ``ttl`` when storing results in the cache (5 minutes).
_RETRIEVE_CACHE_TTL = 300


def _results_to_documents(results: Any) -> list[dict[str, Any]]:
    """Convert generic retriever results into plain document dicts.

    Each result is read via attributes; a missing ``content``, ``metadata``
    or ``score`` attribute falls back to ``""``, ``{}`` or ``0.0``.
    """
    return [
        {
            "content": getattr(result, "content", ""),
            "metadata": getattr(result, "metadata", {}),
            "score": getattr(result, "score", 0.0),
        }
        for result in results
    ]


def _hits_to_documents(raw_hits: Any) -> list[dict[str, Any]]:
    """Convert search hits (dicts with ``_source`` / ``_score``) into document dicts.

    ``_source["chunk_text"]`` becomes the document content; every other
    ``_source`` field is copied into the document metadata.
    """
    documents = []
    for hit in raw_hits:
        source = hit.get("_source", {})
        documents.append(
            {
                "content": source.get("chunk_text", ""),
                "metadata": {k: v for k, v in source.items() if k != "chunk_text"},
                "score": hit.get("_score", 0.0),
            }
        )
    return documents


def retrieve_node(state: dict, *, context: Any) -> dict:
    """Retrieve documents for the current query using the best available backend.

    The query is ``state["rewritten_query"]`` when it is set and truthy,
    otherwise ``state["query"]`` (default ``""``).

    Lookup order (each step only runs if the previous ones produced nothing):
        0. ``context.cache`` — a non-``None`` cached value is returned as-is.
        1. ``context.retriever`` — generic retriever exposing
           ``retrieve(query, top_k=...)``.
        2. ``context.opensearch_client`` + ``context.embedding_service`` —
           hybrid search via ``search_hybrid``.
        3. ``context.opensearch_client`` alone — ``search_bm25`` keyword search.
        4. Empty list.

    Every ``context`` attribute is optional: a missing attribute is treated
    the same as one set to ``None`` and that step is skipped.

    Args:
        state: Graph state; reads ``rewritten_query``, ``query`` and
            ``retrieval_attempts``.
        context: Object carrying the optional ``tracer``, ``cache``,
            ``retriever``, ``opensearch_client`` and ``embedding_service``.

    Returns:
        A dict with ``retrieved_documents`` (list of dicts with ``content``,
        ``metadata`` and ``score`` keys, or the cached value on a cache hit)
        and ``retrieval_attempts`` (the previous count plus one).
    """
    query = state.get("rewritten_query") or state.get("query", "")
    cache_key = f"retrieve:{query}"
    attempts = state.get("retrieval_attempts", 0) + 1

    # Read every optional collaborator the same way so that a context which
    # simply lacks one of them degrades gracefully instead of raising
    # AttributeError.
    tracer = getattr(context, "tracer", None)
    cache = getattr(context, "cache", None)
    retriever = getattr(context, "retriever", None)
    opensearch_client = getattr(context, "opensearch_client", None)
    embedding_service = getattr(context, "embedding_service", None)

    if tracer:
        tracer.trace(name="retrieve_node", metadata={"query": query})

    # 0. Cache lookup — a hit short-circuits all backends.
    if cache:
        cached = cache.get(cache_key)
        if cached is not None:
            logger.info("Cache HIT for query: %s…", query[:50])
            return {"retrieved_documents": cached, "retrieval_attempts": attempts}

    documents: list[dict[str, Any]] = []

    # 1. Generic retriever.
    if retriever is not None:
        try:
            documents = _results_to_documents(
                retriever.retrieve(query, top_k=_RETRIEVE_TOP_K)
            )
            backend = getattr(retriever, "backend_name", "unknown")
            logger.info("Retrieved %d docs via %s", len(documents), backend)
        except Exception as exc:
            # Deliberately broad: any backend failure must fall through to
            # the next backend rather than abort the node.
            logger.warning("Retriever failed (%s), trying OpenSearch fallback…", exc)

    # 2. OpenSearch hybrid search — needs both the client and an embedder.
    if not documents and opensearch_client and embedding_service:
        try:
            embedding = embedding_service.embed_query(query)
            raw_hits = opensearch_client.search_hybrid(
                query_text=query,
                query_vector=embedding,
                top_k=_RETRIEVE_TOP_K,
            )
            documents = _hits_to_documents(raw_hits)
            logger.info("Retrieved %d docs via OpenSearch hybrid", len(documents))
        except Exception as exc:
            logger.error("OpenSearch retrieval failed: %s", exc)

    # 3. BM25 keyword search as the last resort.
    if not documents and opensearch_client:
        try:
            raw_hits = opensearch_client.search_bm25(
                query_text=query, top_k=_RETRIEVE_TOP_K
            )
            documents = _hits_to_documents(raw_hits)
            logger.info("Retrieved %d docs via BM25 fallback", len(documents))
        except Exception as exc:
            logger.error("BM25 fallback also failed: %s", exc)

    # Only non-empty results are cached, so a failed or empty lookup is
    # retried against the backends on the next call.
    if cache and documents:
        cache.set(cache_key, documents, ttl=_RETRIEVE_CACHE_TTL)

    return {"retrieved_documents": documents, "retrieval_attempts": attempts}