Spaces:
Running
Running
File size: 4,088 Bytes
3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 696f787 3ca1d38 9659593 3ca1d38 9659593 3ca1d38 696f787 3ca1d38 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """
MediGuard AI — Retriever Interface
Abstract base class defining the common interface for all retriever backends:
- FAISS (local dev and HuggingFace Spaces)
- OpenSearch (production with BM25 + KNN hybrid)
"""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
# Module-level logger; BaseRetriever's default retrieve_bm25()/retrieve_hybrid()
# fallbacks emit their "falling back to retrieve()" warnings through it.
logger = logging.getLogger(__name__)
@dataclass
class RetrievalResult:
    """Unified result format for retrieval operations.

    Instances are what the BaseRetriever search methods return, regardless of
    which backend produced them.
    """

    doc_id: str
    """Unique identifier for the document chunk."""

    content: str
    """The actual text content of the chunk."""

    score: float
    """Relevance score (higher is better, normalized 0-1 where possible)."""

    metadata: dict[str, Any] = field(default_factory=dict)
    """Arbitrary metadata (source_file, page, section, etc.)."""

    def __repr__(self) -> str:
        """Return a compact single-line summary with the score and a content preview.

        Newlines in ``content`` are always collapsed to spaces, and content longer
        than 80 characters is truncated and suffixed with ``...``.
        """
        # Collapse newlines *before* deciding whether to truncate, so short
        # multi-line chunks also render on one line. The replacement is 1:1,
        # so the length check and the 80-char cut are unaffected.
        preview = self.content.replace("\n", " ")
        if len(preview) > 80:
            preview = preview[:80] + "..."
        return f"RetrievalResult(score={self.score:.3f}, content='{preview}')"
class BaseRetriever(ABC):
    """
    Common contract shared by every retrieval backend.

    Subclasses must implement:
    - retrieve(): semantic/hybrid search returning ranked results
    - health(): readiness probe
    - doc_count(): number of indexed chunks

    Subclasses may additionally override:
    - retrieve_bm25(): keyword-only search
    - retrieve_hybrid(): combined BM25 + vector search

    The default versions of the optional methods log a warning and delegate
    to retrieve().
    """

    @property
    def backend_name(self) -> str:
        """Name of the concrete backend class, intended for log output."""
        return type(self).__name__

    # ------------------------------------------------------------------
    # Required interface
    # ------------------------------------------------------------------

    @abstractmethod
    def retrieve(
        self,
        query: str,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """
        Search the index for chunks relevant to ``query``.

        Args:
            query: Natural-language search text.
            top_k: Upper bound on the number of results returned.
            filters: Optional metadata constraints,
                e.g. ``{"source_file": "guidelines.pdf"}``.

        Returns:
            RetrievalResult objects sorted from most to least relevant.
        """

    @abstractmethod
    def health(self) -> bool:
        """
        Report whether the backend is ready to serve queries.

        Returns:
            True when operational, False otherwise.
        """

    @abstractmethod
    def doc_count(self) -> int:
        """
        Report how many document chunks are indexed.

        Returns:
            The chunk count, or 0 when it cannot be determined.
        """

    # ------------------------------------------------------------------
    # Optional interface (default: warn and delegate to retrieve())
    # ------------------------------------------------------------------

    def retrieve_bm25(
        self,
        query: str,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """
        Keyword (BM25) search.

        Backends without native BM25 support inherit this default, which logs
        a warning and returns whatever retrieve() returns for the same
        arguments.

        Args:
            query: Natural-language search text.
            top_k: Upper bound on the number of results returned.
            filters: Optional metadata constraints.

        Returns:
            List of RetrievalResult objects.
        """
        cls_name = type(self).__name__
        logger.warning("%s does not support BM25, falling back to retrieve()", cls_name)
        return self.retrieve(query, top_k=top_k, filters=filters)

    def retrieve_hybrid(
        self,
        query: str,
        embedding: list[float] | None = None,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
        bm25_weight: float = 0.4,
        vector_weight: float = 0.6,
    ) -> list[RetrievalResult]:
        """
        Hybrid search fusing BM25 and vector similarity.

        The inherited default logs a warning and delegates to retrieve();
        it does not use ``embedding``, ``bm25_weight`` or ``vector_weight``.

        Args:
            query: Natural-language search text.
            embedding: Optional pre-computed query embedding.
            top_k: Upper bound on the number of results returned.
            filters: Optional metadata constraints.
            bm25_weight: Weight assigned to the BM25 component.
            vector_weight: Weight assigned to the vector component.

        Returns:
            List of RetrievalResult objects.
        """
        cls_name = type(self).__name__
        logger.warning("%s does not support hybrid search, falling back to retrieve()", cls_name)
        return self.retrieve(query, top_k=top_k, filters=filters)
|