File size: 4,088 Bytes
3ca1d38
 
 
 
 
 
 
 
 
 
 
 
 
696f787
3ca1d38
 
 
 
 
 
 
696f787
3ca1d38
 
696f787
3ca1d38
 
696f787
3ca1d38
 
696f787
 
3ca1d38
696f787
3ca1d38
 
 
 
 
 
 
 
9659593
3ca1d38
 
 
 
9659593
3ca1d38
 
 
 
696f787
3ca1d38
 
 
 
 
 
696f787
 
3ca1d38
 
9659593
3ca1d38
 
 
 
9659593
3ca1d38
 
 
 
696f787
3ca1d38
 
 
 
9659593
3ca1d38
 
 
 
696f787
3ca1d38
 
 
 
9659593
3ca1d38
 
 
 
696f787
3ca1d38
 
 
 
 
696f787
 
3ca1d38
 
9659593
3ca1d38
 
 
 
9659593
3ca1d38
 
 
 
 
696f787
3ca1d38
 
 
696f787
3ca1d38
 
696f787
3ca1d38
 
696f787
3ca1d38
 
9659593
3ca1d38
 
 
 
 
 
 
9659593
3ca1d38
 
 
 
 
696f787
3ca1d38
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
MediGuard AI — Retriever Interface

Abstract base class defining the common interface for all retriever backends:
- FAISS (local dev and HuggingFace Spaces)
- OpenSearch (production with BM25 + KNN hybrid)
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

# Module-level logger named after this module; BaseRetriever's default
# retrieve_bm25()/retrieve_hybrid() fallbacks emit their warnings through it.
logger = logging.getLogger(__name__)


@dataclass
class RetrievalResult:
    """Unified result format for retrieval operations.

    Every retriever backend returns lists of these, so callers can consume
    results without knowing which backend produced them.
    """

    doc_id: str
    """Unique identifier for the document chunk."""

    content: str
    """The actual text content of the chunk."""

    score: float
    """Relevance score (higher is better, normalized 0-1 where possible)."""

    metadata: dict[str, Any] = field(default_factory=dict)
    """Arbitrary metadata (source_file, page, section, etc.)."""

    def __repr__(self) -> str:
        """Return a compact single-line summary: score plus a content preview.

        Line breaks are rendered as spaces, and content longer than 80
        characters is truncated to its first 80 characters followed by "...".
        """
        max_chars = 80
        # Flatten line breaks *before* deciding whether to truncate so that
        # short and long content are both rendered on a single line.
        # Each replacement is one-char-for-one-char, so the length is unchanged.
        flat = self.content.replace("\r", " ").replace("\n", " ")
        preview = flat[:max_chars] + "..." if len(flat) > max_chars else flat
        return f"RetrievalResult(score={self.score:.3f}, content='{preview}')"


class BaseRetriever(ABC):
    """
    Interface that every retrieval backend implements.

    Mandatory overrides:
    - retrieve(): Semantic/hybrid search
    - health(): Health check
    - doc_count(): Number of indexed documents

    Optional overrides (the defaults defined here log a warning and then
    delegate to retrieve()):
    - retrieve_bm25(): Keyword-only search
    - retrieve_hybrid(): Combined BM25 + vector search
    """

    # ------------------------------------------------------------------
    # Required API
    # ------------------------------------------------------------------

    @abstractmethod
    def retrieve(
        self,
        query: str,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """
        Search the index and return the best matches for ``query``.

        Args:
            query: Natural language query
            top_k: Upper bound on how many results are returned
            filters: Optional metadata constraints
                (e.g., {"source_file": "guidelines.pdf"})

        Returns:
            RetrievalResult objects sorted from most to least relevant
        """

    @abstractmethod
    def health(self) -> bool:
        """
        Report whether the backend is ready to serve queries.

        Returns:
            True when operational, False otherwise
        """

    @abstractmethod
    def doc_count(self) -> int:
        """
        Report how many document chunks are indexed.

        Returns:
            The chunk count, or 0 when it cannot be determined
        """

    # ------------------------------------------------------------------
    # Optional API (default behaviour: warn, then delegate to retrieve())
    # ------------------------------------------------------------------

    def retrieve_bm25(
        self,
        query: str,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
    ) -> list[RetrievalResult]:
        """
        Keyword (BM25) search.

        This default logs a warning naming the concrete class and returns
        ``self.retrieve(...)`` with the same ``query``, ``top_k`` and
        ``filters``. Backends with a real keyword index should override it.

        Args:
            query: Natural language query
            top_k: Upper bound on how many results are returned
            filters: Optional metadata constraints

        Returns:
            RetrievalResult objects
        """
        backend = type(self).__name__
        logger.warning("%s does not support BM25, falling back to retrieve()", backend)
        return self.retrieve(query, filters=filters, top_k=top_k)

    def retrieve_hybrid(
        self,
        query: str,
        embedding: list[float] | None = None,
        *,
        top_k: int = 5,
        filters: dict[str, Any] | None = None,
        bm25_weight: float = 0.4,
        vector_weight: float = 0.6,
    ) -> list[RetrievalResult]:
        """
        Hybrid search mixing BM25 and vector similarity.

        This default logs a warning naming the concrete class and returns
        ``self.retrieve(...)`` with only ``query``, ``top_k`` and ``filters``.
        ``embedding``, ``bm25_weight`` and ``vector_weight`` are accepted for
        interface compatibility but are not used by this fallback.

        Args:
            query: Natural language query
            embedding: Pre-computed query embedding (optional)
            top_k: Upper bound on how many results are returned
            filters: Optional metadata constraints
            bm25_weight: Weight of the BM25 component
            vector_weight: Weight of the vector component

        Returns:
            RetrievalResult objects
        """
        backend = type(self).__name__
        logger.warning("%s does not support hybrid search, falling back to retrieve()", backend)
        return self.retrieve(query, filters=filters, top_k=top_k)

    @property
    def backend_name(self) -> str:
        """Class name of the concrete backend, used as a readable label for logging."""
        cls = type(self)
        return cls.__name__