"""
MediGuard AI — OpenSearch Client

Production search-engine wrapper supporting BM25, vector (KNN), and hybrid
search with Reciprocal Rank Fusion (RRF).
"""

from __future__ import annotations

import logging
from functools import lru_cache
from typing import Any

from src.exceptions import SearchError, SearchQueryError
from src.settings import get_settings

logger = logging.getLogger(__name__)

# Guard import — opensearch-py is optional when running tests locally
try:
    from opensearchpy import NotFoundError as OSNotFoundError
    from opensearchpy import OpenSearch, RequestError
except ImportError:  # pragma: no cover
    OpenSearch = None  # type: ignore[assignment,misc]


class OpenSearchClient:
    """Thin wrapper around *opensearch-py* with medical-domain helpers.

    Every operation targets the single index ``index_name`` through the
    injected low-level ``client``.
    """

    def __init__(self, client: OpenSearch, index_name: str):
        self._client = client
        self.index_name = index_name

    # ── Health ───────────────────────────────────────────────────────────

    def health(self) -> dict[str, Any]:
        """Return the raw cluster-health response from the server."""
        return self._client.cluster.health()

    def ping(self) -> bool:
        """Return the client's ping result, or ``False`` if the ping raises."""
        try:
            return self._client.ping()
        except Exception:
            # Best-effort liveness probe: callers only want a yes/no answer,
            # but keep the cause visible when debugging connectivity.
            logger.debug("OpenSearch ping failed", exc_info=True)
            return False

    # ── Index management ─────────────────────────────────────────────────

    def ensure_index(self, mapping: dict[str, Any]) -> None:
        """Create the index if it doesn't already exist.

        Args:
            mapping: Request body passed verbatim to ``indices.create``.
        """
        if not self._client.indices.exists(index=self.index_name):
            self._client.indices.create(index=self.index_name, body=mapping)
            logger.info("Created OpenSearch index '%s'", self.index_name)
        else:
            logger.info("OpenSearch index '%s' already exists", self.index_name)

    def delete_index(self) -> None:
        """Delete the index if it exists; do nothing otherwise."""
        if self._client.indices.exists(index=self.index_name):
            self._client.indices.delete(index=self.index_name)

    def doc_count(self) -> int:
        """Return the document count of the index, or ``0`` if counting fails."""
        try:
            resp = self._client.count(index=self.index_name)
            return resp["count"]
        except Exception as exc:
            # Still best-effort (callers get 0), but no longer silent: a 0
            # caused by an outage is otherwise indistinguishable from an
            # empty index.
            logger.warning(
                "Could not count documents in '%s': %s", self.index_name, exc
            )
            return 0

    # ── Indexing ─────────────────────────────────────────────────────────

    def index_document(self, doc_id: str, body: dict[str, Any]) -> None:
        """Index a single document ``body`` under id ``doc_id``."""
        self._client.index(index=self.index_name, id=doc_id, body=body)

    def bulk_index(self, documents: list[dict[str, Any]]) -> int:
        """Bulk-index a list of dicts, each expected to carry an ``_id`` key.

        The ``_id`` key is used as the document id and is not stored in the
        document source. The caller's dicts are left unmodified. A document
        without ``_id`` is sent without an explicit id.

        Returns:
            The number of items the server reported with status 200 or 201.
        """
        if not documents:
            return 0

        actions: list[dict[str, Any]] = []
        for doc in documents:
            # Work on a shallow copy so popping ``_id`` does not mutate the
            # caller's data.
            source = dict(doc)
            doc_id = source.pop("_id", None)
            meta: dict[str, Any] = {"_index": self.index_name}
            if doc_id is not None:
                # Only send an id when one was provided, rather than an
                # explicit null.
                meta["_id"] = doc_id
            actions.append({"index": meta})
            actions.append(source)

        resp = self._client.bulk(body=actions, refresh="wait_for")
        items = resp.get("items", [])
        indexed = sum(
            1 for item in items if item.get("index", {}).get("status") in (200, 201)
        )
        if resp.get("errors"):
            # The bulk API reports per-item failures inside a successful
            # response, so surface them instead of only logging a count.
            failures = [
                item["index"] for item in items if "error" in item.get("index", {})
            ]
            logger.warning(
                "Bulk indexing reported %d failed item(s); first error: %s",
                len(failures),
                failures[0].get("error") if failures else None,
            )
        logger.info("Bulk-indexed %d / %d documents", indexed, len(documents))
        return indexed

    # ── BM25 search ──────────────────────────────────────────────────────

    def search_bm25(
        self,
        query_text: str,
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Run a boosted ``multi_match`` query over the text fields.

        Args:
            query_text: Free-text query.
            top_k: Maximum number of hits to return.
            filters: Optional ``{field: value | [values]}`` exact-match filters.

        Returns:
            Hits as ``{"_id", "_score", "_source"}`` dicts.

        Raises:
            SearchQueryError: If the search request fails.
        """
        body: dict[str, Any] = {
            "size": top_k,
            "query": {
                "bool": {
                    "must": [
                        {
                            "multi_match": {
                                "query": query_text,
                                "fields": [
                                    "chunk_text^3",
                                    "title^2",
                                    "section_title^1.5",
                                    "abstract^1",
                                ],
                                "type": "best_fields",
                            }
                        }
                    ]
                }
            },
        }
        if filters:
            body["query"]["bool"]["filter"] = self._build_filters(filters)
        return self._execute_search(body)

    # ── Vector (KNN) search ──────────────────────────────────────────────

    def search_vector(
        self,
        query_vector: list[float],
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Run a k-NN query against the ``embedding`` field.

        Args:
            query_vector: Query embedding.
            top_k: Number of neighbours to request and maximum hits returned.
            filters: Optional ``{field: value | [values]}`` exact-match filters,
                applied as a ``bool`` filter around the knn clause. This
                may yield fewer than ``top_k`` hits.

        Returns:
            Hits as ``{"_id", "_score", "_source"}`` dicts.

        Raises:
            SearchQueryError: If the search request fails.
        """
        knn_clause: dict[str, Any] = {
            "knn": {
                "embedding": {
                    "vector": query_vector,
                    "k": top_k,
                }
            }
        }
        if filters:
            # Previously ``filters`` was accepted but silently ignored, so
            # hybrid search leaked filtered-out documents via this leg.
            query: dict[str, Any] = {
                "bool": {
                    "must": [knn_clause],
                    "filter": self._build_filters(filters),
                }
            }
        else:
            query = knn_clause
        body: dict[str, Any] = {"size": top_k, "query": query}
        return self._execute_search(body)

    # ── Hybrid search (RRF) ─────────────────────────────────────────────

    def search_hybrid(
        self,
        query_text: str,
        query_vector: list[float],
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
        bm25_weight: float = 0.4,
        vector_weight: float = 0.6,
    ) -> list[dict[str, Any]]:
        """Weighted Reciprocal Rank Fusion of BM25 + KNN results.

        Both retrievers are queried with the same ``top_k`` and ``filters``.
        Their rankings are fused, with each retriever's reciprocal-rank
        contribution scaled by ``bm25_weight`` or ``vector_weight``
        respectively.

        Raises:
            SearchQueryError: If either underlying search fails.
        """
        bm25_results = self.search_bm25(query_text, top_k=top_k, filters=filters)
        vector_results = self.search_vector(query_vector, top_k=top_k, filters=filters)
        return self._rrf_fuse(
            bm25_results,
            vector_results,
            top_k=top_k,
            weight_a=bm25_weight,
            weight_b=vector_weight,
        )

    # ── Internal helpers ─────────────────────────────────────────────────

    def _execute_search(self, body: dict[str, Any]) -> list[dict[str, Any]]:
        """Send ``body`` to the index and normalise the hits.

        Raises:
            SearchQueryError: Wrapping any exception raised by the client.
        """
        try:
            resp = self._client.search(index=self.index_name, body=body)
        except Exception as exc:
            raise SearchQueryError(str(exc)) from exc

        hits = resp.get("hits", {}).get("hits", [])
        return [
            {
                "_id": h["_id"],
                "_score": h.get("_score", 0.0),
                "_source": h.get("_source", {}),
            }
            for h in hits
        ]

    @staticmethod
    def _build_filters(filters: dict[str, Any]) -> list[dict[str, Any]]:
        """Translate ``{field: value}`` into ``term`` / ``terms`` clauses.

        List values become a ``terms`` clause; any other value becomes a
        ``term`` clause.
        """
        clauses: list[dict[str, Any]] = []
        for key, value in filters.items():
            if isinstance(value, list):
                clauses.append({"terms": {key: value}})
            else:
                clauses.append({"term": {key: value}})
        return clauses

    @staticmethod
    def _rrf_fuse(
        results_a: list[dict[str, Any]],
        results_b: list[dict[str, Any]],
        *,
        k: int = 60,
        top_k: int = 10,
        weight_a: float = 1.0,
        weight_b: float = 1.0,
    ) -> list[dict[str, Any]]:
        """Weighted Reciprocal Rank Fusion of two ranked hit lists.

        Each hit contributes ``weight / (k + rank)`` (1-based rank) to its
        document's fused score, and contributions from both lists are summed.
        With the default weights of ``1.0`` this is plain RRF.

        Returns:
            At most ``top_k`` hits ordered by fused score. Each is a copy of
            the hit with ``_score`` replaced by the fused score. For a
            document present in both lists, the hit from ``results_b`` is
            returned.
        """
        scores: dict[str, float] = {}
        docs: dict[str, dict[str, Any]] = {}

        for results, weight in ((results_a, weight_a), (results_b, weight_b)):
            for rank, doc in enumerate(results, 1):
                doc_id = doc["_id"]
                scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
                docs[doc_id] = doc

        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
        return [{**docs[doc_id], "_score": score} for doc_id, score in ranked]


# ── Factory ──────────────────────────────────────────────────────────────────


@lru_cache(maxsize=1)
def make_opensearch_client() -> OpenSearchClient:
    """Build (once per process, via ``lru_cache``) the configured client.

    Raises:
        SearchError: If opensearch-py is not installed.
    """
    if OpenSearch is None:
        raise SearchError("opensearch-py is not installed")

    settings = get_settings()
    os_settings = settings.opensearch
    client = OpenSearch(
        hosts=[os_settings.host],
        http_auth=(os_settings.username, os_settings.password) if os_settings.username else None,
        verify_certs=os_settings.verify_certs,
        timeout=os_settings.timeout,
    )
    return OpenSearchClient(client, os_settings.index_name)