File size: 7,273 Bytes
8c35759
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Cross-encoder reranker for document retrieval."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple
import logging

from langchain.schema import Document

logger = logging.getLogger(__name__)

# Lazily populated, process-wide single-slot model cache. Loading is deferred
# until first use so importing this module does not import
# sentence-transformers or load any weights. The slot is reused only when both
# the model name and max_length match the request.
_cross_encoder = None
_cross_encoder_model_name = None
_cross_encoder_max_length = None


def _get_cross_encoder(
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
    max_length: int = 512,
):
    """Lazily load (and cache) a sentence-transformers ``CrossEncoder``.

    The most recently loaded model is kept in a module-level slot and returned
    again while both ``model_name`` and ``max_length`` match; any other request
    loads a new model that replaces it. A failed load leaves the slot as is.

    Args:
        model_name: HuggingFace model identifier
        max_length: Maximum sequence length (in tokens) for each
            (query, document) pair; forwarded to ``CrossEncoder``.

    Returns:
        CrossEncoder instance, or None if sentence-transformers is not
        installed or the model could not be loaded (a warning is logged).
    """
    global _cross_encoder, _cross_encoder_model_name, _cross_encoder_max_length

    if (
        _cross_encoder is not None
        and _cross_encoder_model_name == model_name
        and _cross_encoder_max_length == max_length
    ):
        return _cross_encoder

    try:
        from sentence_transformers import CrossEncoder
    except ImportError:
        logger.warning(
            "sentence-transformers not installed. "
            "Run: pip install sentence-transformers"
        )
        return None

    try:
        logger.info("Loading cross-encoder model: %s", model_name)
        model = CrossEncoder(model_name, max_length=max_length)
    except Exception as e:  # best-effort: callers fall back to original order
        logger.warning("Failed to load cross-encoder: %s", e)
        return None

    _cross_encoder = model
    _cross_encoder_model_name = model_name
    _cross_encoder_max_length = max_length
    return _cross_encoder


class FastCrossEncoderReranker:
    """Cross-encoder reranker using sentence-transformers.

    Runs locally and is faster than LLM-based reranking.
    """

    MODEL_OPTIONS = {
        "fast": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "balanced": "cross-encoder/ms-marco-MiniLM-L-12-v2",
        "tiny": "cross-encoder/ms-marco-TinyBERT-L-2-v2",
    }

    def __init__(
        self,
        model_name: str = "fast",
        max_length: int = 512,
        batch_size: int = 16,
        max_chars: Optional[int] = None,
    ) -> None:
        """Initialize cross-encoder reranker.

        Args:
            model_name: One of "fast", "balanced", "tiny", or a HuggingFace model ID
            max_length: Maximum sequence length (tokens) used by the model
            batch_size: Batch size for scoring (higher = faster but more memory)
            max_chars: Character cap applied to each document's text before it
                is scored. Defaults to ``max_length``, which is how document
                text has always been cut; pass a larger value to let the
                model's own token truncation decide instead.
        """
        # Resolve model name alias
        self.model_name = self.MODEL_OPTIONS.get(model_name, model_name)
        self.max_length = max_length
        self.batch_size = batch_size
        self.max_chars = max_length if max_chars is None else max_chars
        self._model = None

    def _ensure_model(self) -> bool:
        """Ensure model is loaded.

        Returns:
            True if model is available, False otherwise
        """
        if self._model is None:
            self._model = _get_cross_encoder(self.model_name, self.max_length)
        return self._model is not None

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Document]:
        """Rerank documents by relevance to query.

        Falls back to the original order (truncated to ``top_k``) when the
        model is unavailable or scoring raises.

        Args:
            query: User query
            documents: Documents to rerank
            top_k: Number of top documents to return

        Returns:
            Reranked documents (most relevant first)
        """
        if not documents:
            return []

        if len(documents) == 1:
            # Nothing to reorder, but still honour top_k like the other paths.
            return documents[:top_k]

        if not self._ensure_model():
            logger.warning("Cross-encoder not available, returning original order")
            return documents[:top_k]

        try:
            ranked = self._rank(query, documents)
        except Exception as e:
            logger.warning("Reranking failed: %s, returning original order", e)
            return documents[:top_k]

        return [doc for doc, _ in ranked[:top_k]]

    def rerank_with_scores(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Tuple[Document, float]]:
        """Rerank documents and return with scores.

        When the model is unavailable or scoring raises, documents keep their
        original order with synthetic scores 1.0, 0.9, 0.8, ...

        Args:
            query: User query
            documents: Documents to rerank
            top_k: Number of top documents to return

        Returns:
            List of (document, score) tuples, sorted by score descending
        """
        if not documents:
            return []

        if len(documents) == 1:
            return [(doc, 1.0) for doc in documents[:top_k]]

        if not self._ensure_model():
            return self._positional_scores(documents, top_k)

        try:
            ranked = self._rank(query, documents)
        except Exception as e:
            logger.warning("Reranking failed: %s", e)
            return self._positional_scores(documents, top_k)

        return ranked[:top_k]

    def _rank(
        self,
        query: str,
        documents: List[Document],
    ) -> List[Tuple[Document, float]]:
        """Score every document against ``query`` and sort best-first.

        Requires the model to be loaded; any exception from the model
        propagates so the public methods can apply their fallback.
        """
        pairs = [
            (query, self._get_text(doc)[: self.max_chars])
            for doc in documents
        ]
        raw_scores = self._model.predict(
            pairs,
            batch_size=self.batch_size,
            show_progress_bar=False,
        )
        # predict() may yield numpy scalars; convert so the declared float
        # return type holds (and results serialise cleanly).
        scored = [(doc, float(score)) for doc, score in zip(documents, raw_scores)]
        # list.sort is stable, so equal scores keep their input order.
        scored.sort(key=lambda item: item[1], reverse=True)
        return scored

    @staticmethod
    def _positional_scores(
        documents: List[Document],
        top_k: int,
    ) -> List[Tuple[Document, float]]:
        """Fallback: keep input order with decreasing dummy scores 1.0, 0.9, ..."""
        return [(doc, 1.0 - i * 0.1) for i, doc in enumerate(documents[:top_k])]

    def _get_text(self, doc: Document) -> str:
        """Extract text content from document.

        Args:
            doc: LangChain Document

        Returns:
            ``doc.page_content`` when present, otherwise ``str(doc)``
        """
        if hasattr(doc, 'page_content'):
            return doc.page_content
        return str(doc)


class NoOpReranker:
    """Pass-through reranker that keeps the incoming order and only truncates.

    Meant as a fallback when the cross-encoder is unavailable; it offers the
    same ``rerank`` / ``rerank_with_scores`` methods as
    FastCrossEncoderReranker.
    """

    def rerank(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Document]:
        """Return the first ``top_k`` documents as-is; ``query`` is ignored."""
        head = documents[:top_k]
        return head

    def rerank_with_scores(
        self,
        query: str,
        documents: List[Document],
        top_k: int = 6,
    ) -> List[Tuple[Document, float]]:
        """Pair the first ``top_k`` documents with synthetic scores.

        Scores start at 1.0 and drop by 0.05 per position, so they are
        strictly decreasing and mirror the input order.
        """
        head = documents[:top_k]
        return [(item, 1.0 - position * 0.05) for position, item in enumerate(head)]


def get_reranker(
    model_name: str = "fast",
    fallback_to_noop: bool = True,
    **reranker_kwargs: Any,
) -> FastCrossEncoderReranker | NoOpReranker:
    """Factory function to get a reranker instance.

    Builds a FastCrossEncoderReranker and loads its model eagerly, so an
    unavailable cross-encoder is detected here rather than on the first
    ``rerank()`` call.

    Args:
        model_name: Model name or alias
        fallback_to_noop: If True, return NoOpReranker when cross-encoder fails
        **reranker_kwargs: Extra keyword arguments forwarded to
            FastCrossEncoderReranker (e.g. ``max_length``, ``batch_size``).

    Returns:
        A FastCrossEncoderReranker with its model loaded, or a NoOpReranker
        when the cross-encoder is unavailable and ``fallback_to_noop`` is True.

    Raises:
        RuntimeError: If the cross-encoder is unavailable and
            ``fallback_to_noop`` is False.
    """
    try:
        reranker = FastCrossEncoderReranker(model_name, **reranker_kwargs)
        # Test model loading now rather than lazily on first use.
        if reranker._ensure_model():
            return reranker
    except Exception as e:
        logger.warning("Failed to create cross-encoder reranker: %s", e)

    if fallback_to_noop:
        logger.info("Using no-op reranker as fallback")
        return NoOpReranker()

    raise RuntimeError("Cross-encoder reranker not available")