Spaces:
Running
Running
"""
MediGuard AI — OpenSearch Client

Production search-engine wrapper supporting BM25, vector (KNN), and
hybrid search with Reciprocal Rank Fusion (RRF).
"""
| from __future__ import annotations | |
| import logging | |
| from functools import lru_cache | |
| from typing import Any | |
| from src.exceptions import SearchError, SearchQueryError | |
| from src.settings import get_settings | |
| logger = logging.getLogger(__name__) | |
# Guard import — opensearch-py is optional when running tests locally
| try: | |
| from opensearchpy import NotFoundError as OSNotFoundError | |
| from opensearchpy import OpenSearch, RequestError | |
| except ImportError: # pragma: no cover | |
| OpenSearch = None # type: ignore[assignment,misc] | |
class OpenSearchClient:
    """Thin wrapper around *opensearch-py* with medical-domain helpers.

    Provides index management, (bulk) indexing, and three retrieval modes:
    lexical BM25, vector KNN, and hybrid search fused with weighted
    Reciprocal Rank Fusion (RRF).
    """

    def __init__(self, client: OpenSearch, index_name: str):
        # Underlying opensearch-py client; every call is delegated to it.
        self._client = client
        self.index_name = index_name

    # ── Health ───────────────────────────────────────────────────────────
    def health(self) -> dict[str, Any]:
        """Return the raw cluster-health payload from OpenSearch."""
        return self._client.cluster.health()

    def ping(self) -> bool:
        """Best-effort reachability check; never raises."""
        try:
            return self._client.ping()
        except Exception:  # any transport failure simply maps to False
            return False

    # ── Index management ─────────────────────────────────────────────────
    def ensure_index(self, mapping: dict[str, Any]) -> None:
        """Create the index with *mapping* if it doesn't already exist."""
        if not self._client.indices.exists(index=self.index_name):
            self._client.indices.create(index=self.index_name, body=mapping)
            logger.info("Created OpenSearch index '%s'", self.index_name)
        else:
            logger.info("OpenSearch index '%s' already exists", self.index_name)

    def delete_index(self) -> None:
        """Delete the index if present (no-op when it doesn't exist)."""
        if self._client.indices.exists(index=self.index_name):
            self._client.indices.delete(index=self.index_name)

    def doc_count(self) -> int:
        """Return the number of documents in the index, or 0 on any failure.

        Deliberately best-effort: a missing index or connection error is
        reported as an empty index rather than propagated.
        """
        try:
            resp = self._client.count(index=self.index_name)
            return resp["count"]
        except Exception:
            return 0

    # ── Indexing ─────────────────────────────────────────────────────────
    def index_document(self, doc_id: str, body: dict[str, Any]) -> None:
        """Index (upsert) a single document under *doc_id*."""
        self._client.index(index=self.index_name, id=doc_id, body=body)

    def bulk_index(self, documents: list[dict[str, Any]]) -> int:
        """Bulk-index a list of dicts, each must have an ``_id`` key.

        Args:
            documents: Source documents. The ``_id`` key (if present) becomes
                the OpenSearch document id and is excluded from the stored
                body. Caller dicts are NOT mutated.

        Returns:
            Number of documents the bulk response reports as indexed.
        """
        if not documents:
            return 0
        actions: list[dict[str, Any]] = []
        for doc in documents:
            # BUG FIX: previously doc.pop("_id") mutated the caller's dicts;
            # copy instead so the input list can be safely reused.
            doc_id = doc.get("_id")
            source = {key: value for key, value in doc.items() if key != "_id"}
            actions.append({"index": {"_index": self.index_name, "_id": doc_id}})
            actions.append(source)
        resp = self._client.bulk(body=actions, refresh="wait_for")
        indexed = sum(
            1
            for item in resp.get("items", [])
            if item.get("index", {}).get("status") in (200, 201)
        )
        logger.info("Bulk-indexed %d / %d documents", indexed, len(documents))
        return indexed

    # ── BM25 search ──────────────────────────────────────────────────────
    def search_bm25(
        self,
        query_text: str,
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Lexical (BM25) search over the medical text fields.

        Args:
            query_text: Free-text query.
            top_k: Number of hits to return.
            filters: Optional exact-match filters (field -> value or list of
                values).

        Returns:
            Simplified hit dicts with ``_id``, ``_score`` and ``_source``.

        Raises:
            SearchQueryError: If the search request fails.
        """
        body: dict[str, Any] = {
            "size": top_k,
            "query": {
                "bool": {
                    "must": [
                        {
                            "multi_match": {
                                "query": query_text,
                                # Boost the chunk body highest, then titles.
                                "fields": [
                                    "chunk_text^3",
                                    "title^2",
                                    "section_title^1.5",
                                    "abstract^1",
                                ],
                                "type": "best_fields",
                            }
                        }
                    ]
                }
            },
        }
        if filters:
            body["query"]["bool"]["filter"] = self._build_filters(filters)
        return self._execute_search(body)

    # ── Vector (KNN) search ──────────────────────────────────────────────
    def search_vector(
        self,
        query_vector: list[float],
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
    ) -> list[dict[str, Any]]:
        """Approximate nearest-neighbour search on the ``embedding`` field.

        Args:
            query_vector: Query embedding (must match the index dimension).
            top_k: Number of hits to return.
            filters: Optional exact-match filters, applied alongside the KNN
                clause. BUG FIX: this parameter was previously accepted but
                silently ignored.

        Raises:
            SearchQueryError: If the search request fails.
        """
        knn_clause: dict[str, Any] = {
            "knn": {
                "embedding": {
                    "vector": query_vector,
                    "k": top_k,
                }
            }
        }
        body: dict[str, Any] = {"size": top_k}
        if filters:
            # Wrap the KNN clause in a bool query so term filters apply.
            body["query"] = {
                "bool": {
                    "must": [knn_clause],
                    "filter": self._build_filters(filters),
                }
            }
        else:
            body["query"] = knn_clause
        return self._execute_search(body)

    # ── Hybrid search (RRF) ──────────────────────────────────────────────
    def search_hybrid(
        self,
        query_text: str,
        query_vector: list[float],
        *,
        top_k: int = 10,
        filters: dict[str, Any] | None = None,
        bm25_weight: float = 0.4,
        vector_weight: float = 0.6,
    ) -> list[dict[str, Any]]:
        """Weighted Reciprocal Rank Fusion of BM25 + KNN results.

        BUG FIX: ``bm25_weight`` / ``vector_weight`` were previously accepted
        but never used; they now weight the two RRF contributions.
        """
        bm25_results = self.search_bm25(query_text, top_k=top_k, filters=filters)
        vector_results = self.search_vector(query_vector, top_k=top_k, filters=filters)
        return self._rrf_fuse(
            bm25_results,
            vector_results,
            top_k=top_k,
            weight_a=bm25_weight,
            weight_b=vector_weight,
        )

    # ── Internal helpers ─────────────────────────────────────────────────
    def _execute_search(self, body: dict[str, Any]) -> list[dict[str, Any]]:
        """Run a search request and normalise the hits.

        Raises:
            SearchQueryError: Wrapping any error raised by the client.
        """
        try:
            resp = self._client.search(index=self.index_name, body=body)
        except Exception as exc:
            raise SearchQueryError(str(exc)) from exc
        hits = resp.get("hits", {}).get("hits", [])
        return [
            {
                "_id": h["_id"],
                "_score": h.get("_score", 0.0),
                "_source": h.get("_source", {}),
            }
            for h in hits
        ]

    def _build_filters(self, filters: dict[str, Any]) -> list[dict[str, Any]]:
        """Translate a field->value mapping into term/terms filter clauses.

        BUG FIX: was declared without ``self`` but called as an instance
        method, which raised ``TypeError`` at runtime.
        """
        clauses: list[dict[str, Any]] = []
        for key, value in filters.items():
            if isinstance(value, list):
                clauses.append({"terms": {key: value}})
            else:
                clauses.append({"term": {key: value}})
        return clauses

    def _rrf_fuse(
        self,
        results_a: list[dict[str, Any]],
        results_b: list[dict[str, Any]],
        *,
        k: int = 60,
        top_k: int = 10,
        weight_a: float = 1.0,
        weight_b: float = 1.0,
    ) -> list[dict[str, Any]]:
        """Weighted Reciprocal Rank Fusion: score = sum(weight / (k + rank)).

        BUG FIX: was declared without ``self`` but called as an instance
        method. ``weight_a``/``weight_b`` default to 1.0, preserving the
        original unweighted behaviour for existing callers.
        """
        scores: dict[str, float] = {}
        docs: dict[str, dict[str, Any]] = {}
        for weight, results in ((weight_a, results_a), (weight_b, results_b)):
            for rank, doc in enumerate(results, 1):
                doc_id = doc["_id"]
                scores[doc_id] = scores.get(doc_id, 0.0) + weight / (k + rank)
                docs[doc_id] = doc
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [{**docs[doc_id], "_score": score} for doc_id, score in ranked]
| # ββ Factory ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def make_opensearch_client() -> OpenSearchClient: | |
| if OpenSearch is None: | |
| raise SearchError("opensearch-py is not installed") | |
| settings = get_settings() | |
| os_settings = settings.opensearch | |
| client = OpenSearch( | |
| hosts=[os_settings.host], | |
| http_auth=(os_settings.username, os_settings.password) if os_settings.username else None, | |
| verify_certs=os_settings.verify_certs, | |
| timeout=os_settings.timeout, | |
| ) | |
| return OpenSearchClient(client, os_settings.index_name) | |