File size: 3,888 Bytes
1e732dd
 
 
fd5543a
 
 
 
1e732dd
 
 
 
 
 
 
 
 
 
 
fd5543a
1e732dd
fd5543a
 
 
 
 
 
 
1e732dd
fd5543a
696f787
 
 
fd5543a
1e732dd
 
 
fd5543a
 
 
1e732dd
fd5543a
1e732dd
fd5543a
 
1e732dd
fd5543a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e732dd
fd5543a
 
1e732dd
fd5543a
 
 
9659593
fd5543a
 
 
 
 
 
 
1e732dd
fd5543a
 
 
 
 
 
 
9659593
fd5543a
 
 
 
 
 
 
1e732dd
fd5543a
 
1e732dd
 
fd5543a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
MediGuard AI — Retrieve Node

Performs document retrieval using the best available backend:
  1. Generic retriever (FAISS, OpenSearch wrapper, etc.)
  2. OpenSearch hybrid search (BM25 + KNN)
  3. BM25 keyword fallback
"""

from __future__ import annotations

import logging
from typing import Any

logger = logging.getLogger(__name__)


# Number of documents requested from every backend.
_TOP_K = 8
# Lifetime, in seconds, of a cached retrieval result (5 minutes).
_CACHE_TTL_SECONDS = 300


def _results_to_documents(results: Any) -> list[dict[str, Any]]:
    """Convert generic retriever results (attribute-style objects) to document dicts."""
    return [
        {
            "content": getattr(result, "content", ""),
            "metadata": getattr(result, "metadata", {}),
            "score": getattr(result, "score", 0.0),
        }
        for result in results
    ]


def _hits_to_documents(raw_hits: Any) -> list[dict[str, Any]]:
    """Convert raw OpenSearch hits (``_source`` / ``_score`` dicts) to document dicts."""
    documents = []
    for hit in raw_hits:
        source = hit.get("_source", {})
        documents.append(
            {
                "content": source.get("chunk_text", ""),
                # Everything in ``_source`` except the chunk text is metadata.
                "metadata": {k: v for k, v in source.items() if k != "chunk_text"},
                "score": hit.get("_score", 0.0),
            }
        )
    return documents


def retrieve_node(state: dict, *, context: Any) -> dict:
    """Retrieve documents for the current query using the best available backend.

    Steps, in order (each later step runs only if no documents were found yet):
      0. ``context.cache`` lookup under the key ``retrieve:<query>``
      1. ``context.retriever.retrieve(query, top_k=8)``
      2. ``context.opensearch_client.search_hybrid`` with a vector from
         ``context.embedding_service.embed_query``
      3. ``context.opensearch_client.search_bm25``
      4. Empty list

    Every attribute of ``context`` is optional: a missing or falsy attribute
    simply disables the corresponding step. Backend exceptions are logged and
    swallowed so that the next fallback can run.

    Args:
        state: Graph state. Reads ``rewritten_query`` (preferred), ``query``
            and ``retrieval_attempts``.
        context: Object that may expose ``tracer``, ``cache``, ``retriever``,
            ``opensearch_client`` and ``embedding_service``.

    Returns:
        A dict with ``retrieved_documents`` (list of
        ``{"content", "metadata", "score"}`` dicts) and ``retrieval_attempts``
        (previous value + 1).
    """
    # ``or ""`` also covers a ``query`` key that is present but set to None.
    query = state.get("rewritten_query") or state.get("query") or ""
    cache_key = f"retrieve:{query}"
    attempts = (state.get("retrieval_attempts") or 0) + 1

    # Resolve every optional collaborator the same way so that a context
    # lacking one of them degrades gracefully instead of raising AttributeError.
    tracer = getattr(context, "tracer", None)
    cache = getattr(context, "cache", None)
    retriever = getattr(context, "retriever", None)
    opensearch_client = getattr(context, "opensearch_client", None)
    embedding_service = getattr(context, "embedding_service", None)

    if tracer:
        tracer.trace(name="retrieve_node", metadata={"query": query})

    # 0. Try cache
    if cache:
        cached = cache.get(cache_key)
        if cached is not None:
            logger.info("Cache HIT for query: %s…", query[:50])
            return {"retrieved_documents": cached, "retrieval_attempts": attempts}

    documents: list = []

    # 1. Generic retriever (FAISS, OpenSearch wrapper, etc.)
    if retriever is not None:
        try:
            documents = _results_to_documents(retriever.retrieve(query, top_k=_TOP_K))
            backend = getattr(retriever, "backend_name", "unknown")
            logger.info("Retrieved %d docs via %s", len(documents), backend)
        except Exception as exc:  # best-effort: fall through to the next backend
            logger.warning("Retriever failed (%s), trying OpenSearch fallback…", exc)

    # 2. OpenSearch hybrid fallback
    if not documents and opensearch_client and embedding_service:
        try:
            embedding = embedding_service.embed_query(query)
            raw_hits = opensearch_client.search_hybrid(
                query_text=query,
                query_vector=embedding,
                top_k=_TOP_K,
            )
            documents = _hits_to_documents(raw_hits)
            logger.info("Retrieved %d docs via OpenSearch hybrid", len(documents))
        except Exception as exc:  # best-effort: fall through to BM25
            logger.error("OpenSearch retrieval failed: %s", exc)

    # 3. BM25 keyword fallback if still nothing
    if not documents and opensearch_client:
        try:
            raw_hits = opensearch_client.search_bm25(query_text=query, top_k=_TOP_K)
            documents = _hits_to_documents(raw_hits)
            logger.info("Retrieved %d docs via BM25 fallback", len(documents))
        except Exception as exc:  # best-effort: return an empty result
            logger.error("BM25 fallback also failed: %s", exc)

    # Only non-empty results are cached, so a transient failure is retried
    # on the next call instead of being served from cache.
    if cache and documents:
        cache.set(cache_key, documents, ttl=_CACHE_TTL_SECONDS)

    return {"retrieved_documents": documents, "retrieval_attempts": attempts}