File size: 8,698 Bytes
1e732dd
 
 
 
 
 
 
 
 
 
 
696f787
1e732dd
696f787
1e732dd
 
 
 
 
 
696f787
 
1e732dd
 
 
 
 
 
 
696f787
1e732dd
 
 
 
 
696f787
1e732dd
 
 
 
 
 
 
 
 
 
696f787
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696f787
1e732dd
 
696f787
1e732dd
 
 
696f787
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
fd5543a
1e732dd
 
696f787
 
 
1e732dd
 
 
 
 
 
fd5543a
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696f787
1e732dd
 
696f787
 
 
1e732dd
 
 
 
fd5543a
1e732dd
 
 
 
 
 
 
 
 
 
 
fd5543a
696f787
1e732dd
 
696f787
1e732dd
 
696f787
1e732dd
fd5543a
 
1e732dd
 
 
 
696f787
1e732dd
 
 
9659593
1e732dd
 
 
 
 
fd5543a
1e732dd
 
 
 
 
696f787
 
1e732dd
 
 
 
 
 
 
 
 
696f787
 
1e732dd
 
 
696f787
1e732dd
696f787
 
1e732dd
 
 
 
 
 
 
 
 
9659593
1e732dd
 
 
 
9659593
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
MediGuard AI β€” OpenSearch Client

Production search-engine wrapper supporting BM25, vector (KNN), and
hybrid search with Reciprocal Rank Fusion (RRF).
"""

from __future__ import annotations

import logging
from functools import lru_cache
from typing import Any

from src.exceptions import SearchError, SearchQueryError
from src.settings import get_settings

logger = logging.getLogger(__name__)

# Guard import β€” opensearch-py is optional when running tests locally
try:
    from opensearchpy import NotFoundError as OSNotFoundError
    from opensearchpy import OpenSearch, RequestError
except ImportError:  # pragma: no cover
    OpenSearch = None  # type: ignore[assignment,misc]


class OpenSearchClient:
    """Thin wrapper around *opensearch-py* with medical-domain helpers."""

    def __init__(self, client: OpenSearch, index_name: str):
        self._client = client
        self.index_name = index_name

    # ── Health ───────────────────────────────────────────────────────────

    def health(self) -> dict[str, Any]:
        return self._client.cluster.health()

    def ping(self) -> bool:
        try:
            return self._client.ping()
        except Exception:
            return False

    # ── Index management ─────────────────────────────────────────────────

    def ensure_index(self, mapping: dict[str, Any]) -> None:
        """Create the index if it doesn't already exist."""
        if not self._client.indices.exists(index=self.index_name):
            self._client.indices.create(index=self.index_name, body=mapping)
            logger.info("Created OpenSearch index '%s'", self.index_name)
        else:
            logger.info("OpenSearch index '%s' already exists", self.index_name)

    def delete_index(self) -> None:
        if self._client.indices.exists(index=self.index_name):
            self._client.indices.delete(index=self.index_name)

    def doc_count(self) -> int:
        try:
            resp = self._client.count(index=self.index_name)
            return resp["count"]
        except Exception:
            return 0

    # ── Indexing ─────────────────────────────────────────────────────────

    def index_document(self, doc_id: str, body: dict[str, Any]) -> None:
        self._client.index(index=self.index_name, id=doc_id, body=body)

    def bulk_index(self, documents: list[dict[str, Any]]) -> int:
        """Bulk-index a list of dicts, each must have an ``_id`` key."""
        if not documents:
            return 0
        actions: list[dict[str, Any]] = []
        for doc in documents:
            doc_id = doc.pop("_id", None)
            actions.append({"index": {"_index": self.index_name, "_id": doc_id}})
            actions.append(doc)
        resp = self._client.bulk(body=actions, refresh="wait_for")
        indexed = sum(1 for item in resp.get("items", []) if item.get("index", {}).get("status") in (200, 201))
        logger.info("Bulk-indexed %d / %d documents", indexed, len(documents))
        return indexed

    # ── BM25 search ──────────────────────────────────────────────────────

    def search_bm25(
        self,
        query_text: str,
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        body: dict[str, Any] = {
            "size": top_k,
            "query": {
                "bool": {
                    "must": [
                        {
                            "multi_match": {
                                "query": query_text,
                                "fields": [
                                    "chunk_text^3",
                                    "title^2",
                                    "section_title^1.5",
                                    "abstract^1",
                                ],
                                "type": "best_fields",
                            }
                        }
                    ]
                }
            },
        }
        if filters:
            body["query"]["bool"]["filter"] = self._build_filters(filters)
        return self._execute_search(body)

    # ── Vector (KNN) search ──────────────────────────────────────────────

    def search_vector(
        self,
        query_vector: list[float],
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        body: dict[str, Any] = {
            "size": top_k,
            "query": {
                "knn": {
                    "embedding": {
                        "vector": query_vector,
                        "k": top_k,
                    }
                }
            },
        }
        return self._execute_search(body)

    # ── Hybrid search (RRF) ─────────────────────────────────────────────

    def search_hybrid(
        self,
        query_text: str,
        query_vector: list[float],
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
        bm25_weight: float = 0.4,
        vector_weight: float = 0.6,
    ) -> list[dict[str, Any]]:
        """Reciprocal Rank Fusion of BM25 + KNN results."""
        bm25_results = self.search_bm25(query_text, top_k=top_k, filters=filters)
        vector_results = self.search_vector(query_vector, top_k=top_k, filters=filters)
        return self._rrf_fuse(bm25_results, vector_results, top_k=top_k)

    # ── Internal helpers ─────────────────────────────────────────────────

    def _execute_search(self, body: dict[str, Any]) -> list[dict[str, Any]]:
        try:
            resp = self._client.search(index=self.index_name, body=body)
        except Exception as exc:
            raise SearchQueryError(str(exc)) from exc
        hits = resp.get("hits", {}).get("hits", [])
        return [
            {
                "_id": h["_id"],
                "_score": h.get("_score", 0.0),
                "_source": h.get("_source", {}),
            }
            for h in hits
        ]

    @staticmethod
    def _build_filters(filters: dict[str, Any]) -> list[dict[str, Any]]:
        clauses: list[dict[str, Any]] = []
        for key, value in filters.items():
            if isinstance(value, list):
                clauses.append({"terms": {key: value}})
            else:
                clauses.append({"term": {key: value}})
        return clauses

    @staticmethod
    def _rrf_fuse(
        results_a: list[dict[str, Any]],
        results_b: list[dict[str, Any]],
        *,
        k: int = 60,
        top_k: int = 10,
    ) -> list[dict[str, Any]]:
        """Simple Reciprocal Rank Fusion."""
        scores: dict[str, float] = {}
        docs: dict[str, dict[str, Any]] = {}
        for rank, doc in enumerate(results_a, 1):
            doc_id = doc["_id"]
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
            docs[doc_id] = doc
        for rank, doc in enumerate(results_b, 1):
            doc_id = doc["_id"]
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
            docs[doc_id] = doc
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        return [{**docs[doc_id], "_score": score} for doc_id, score in ranked]


# ── Factory ──────────────────────────────────────────────────────────────────


@lru_cache(maxsize=1)
def make_opensearch_client() -> OpenSearchClient:
    if OpenSearch is None:
        raise SearchError("opensearch-py is not installed")
    settings = get_settings()
    os_settings = settings.opensearch
    client = OpenSearch(
        hosts=[os_settings.host],
        http_auth=(os_settings.username, os_settings.password) if os_settings.username else None,
        verify_certs=os_settings.verify_certs,
        timeout=os_settings.timeout,
    )
    return OpenSearchClient(client, os_settings.index_name)