File size: 15,758 Bytes
daf250b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63c61f
 
daf250b
 
a63c61f
 
 
 
 
daf250b
a63c61f
daf250b
 
 
a63c61f
 
 
daf250b
a63c61f
 
 
 
 
 
daf250b
a63c61f
 
 
 
 
 
daf250b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63c61f
daf250b
 
 
a63c61f
 
 
daf250b
a63c61f
daf250b
 
a63c61f
daf250b
a63c61f
 
daf250b
 
 
 
 
a63c61f
 
daf250b
a63c61f
 
 
daf250b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63c61f
 
daf250b
a63c61f
daf250b
a63c61f
 
daf250b
53c5af5
e43cd24
 
 
 
 
 
12d3d4d
e43cd24
 
 
 
 
 
 
12d3d4d
daf250b
 
 
 
 
 
 
 
12d3d4d
a63c61f
daf250b
a63c61f
daf250b
 
a63c61f
daf250b
a63c61f
daf250b
a63c61f
daf250b
a63c61f
 
daf250b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63c61f
 
 
daf250b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63c61f
daf250b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63c61f
 
 
 
 
 
 
daf250b
a63c61f
 
 
 
 
 
daf250b
 
a63c61f
 
 
daf250b
a63c61f
 
 
 
 
 
 
daf250b
 
 
 
 
a63c61f
 
 
daf250b
a63c61f
daf250b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
"""
Reranker Adapter β€” supports BGE-Reranker-v2-m3 AND Jina-Reranker-v3

Auto-detects which model to load based on RERANKER_MODEL setting:
  - "BAAI/bge-reranker-v2-m3"   β†’ FlagReranker (pointwise cross-encoder)
  - "jinaai/jina-reranker-v3"   β†’ Jina v3 listwise reranker

Jina v3 advantages over BGE for this project:
  - Listwise: sees ALL docs at once β†’ better cross-doc comparison
  - 131K context window β†’ reads full Jina-extracted articles (not just 512 chars)
  - +9.6% better on English news (BEIR 61.94 vs 56.51)
  - Better Arabic ranking (78.69 nDCG)
  - Same size (0.6B), same memory, same cost (free, self-hosted)

Thread-safe lazy loading β€” model loads once on first rerank call.
"""

import logging
import threading
from typing import List, Dict, Any, Optional

from src.core.config import settings
from src.core.ports.reranker_port import RerankerPort

logger = logging.getLogger(__name__)

# ── Patch transformers compatibility issue ────────────────────────────────────
# Best-effort shim: if the installed transformers version no longer exposes
# `transformers.utils.import_utils.is_torch_fx_available`, install a stub that
# always returns False.
# NOTE(review): presumably some downstream import (FlagEmbedding or remote model
# code?) still looks this name up; confirm which before removing the shim.
# Any failure here (e.g. transformers not installed) is deliberately ignored.
try:
    import transformers.utils.import_utils as _tui
    if not hasattr(_tui, "is_torch_fx_available"):
        _tui.is_torch_fx_available = lambda: False
except Exception:
    pass

# ── Try FlagEmbedding (for BGE) ───────────────────────────────────────────────
try:
    from FlagEmbedding import FlagReranker
    HAS_FLAG_RERANKER = True
except ImportError:
    HAS_FLAG_RERANKER = False

# ── Try sentence-transformers CrossEncoder (BGE fallback) ────────────────────
try:
    from sentence_transformers import CrossEncoder
    HAS_CROSS_ENCODER = True
except ImportError:
    HAS_CROSS_ENCODER = False

# ── Try transformers (for Jina v3) ────────────────────────────────────────────
try:
    import torch
    from transformers import AutoModel
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    logger.warning("transformers/torch not available β€” Jina v3 reranker disabled.")


# ═══════════════════════════════════════════════════════════════════════════════
# JINA V3 RERANKER
# ═══════════════════════════════════════════════════════════════════════════════

class JinaV3Reranker:
    """
    Jina-Reranker-v3 self-hosted reranker with thread-safe lazy loading.

    Key differences from BGE pointwise:
    - Listwise: processes all docs in one forward pass
    - 131K context window: can read long articles, not just the first 512 chars
    - Built on Qwen3-0.6B backbone with causal self-attention
    - BEIR: 61.94 nDCG@10 (vs BGE's 56.51)

    Scoring: scores are the ``relevance_score`` values produced by the model's
    own ``rerank()`` method, returned in the caller's original document order.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        # Published only after the model is fully initialised (see _load).
        self._model = None
        # Serialises the one-time load across threads.
        self._lock = threading.Lock()
        # Sticky: once a load fails it is not retried.
        self._load_failed = False
        # Device the model actually lives on (updated by _load).
        self._device = "cpu"

    def _load(self) -> None:
        """
        Load the model once, on CUDA when available, otherwise on CPU.

        Uses a lock-protected re-check so concurrent callers load at most once.
        On any failure, ``_load_failed`` is set and the error is logged; no
        exception propagates.
        """
        if self._model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if self._model is not None or self._load_failed:
                return
            if not HAS_TRANSFORMERS:
                logger.error("transformers not installed β€” cannot load Jina v3")
                self._load_failed = True
                return
            try:
                logger.info(f"Loading Jina v3 reranker: {self.model_name}")
                device = "cuda" if torch.cuda.is_available() else "cpu"

                # Jina v3 uses AutoModel (NOT AutoModelForSequenceClassification);
                # its remote code provides a .rerank() method returning
                # relevance_score directly. AutoModel is guaranteed to be
                # imported at module level whenever HAS_TRANSFORMERS is True.
                model = AutoModel.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    dtype="auto",
                )

                # Actually place the model on the detected device. Previously
                # the device was only detected and logged; the model stayed on
                # CPU even when CUDA was available.
                if device != "cpu":
                    try:
                        model.to(device)
                    except Exception as move_err:
                        logger.warning(
                            f"Could not move Jina v3 reranker to {device} "
                            f"({move_err}); falling back to cpu"
                        )
                        model.to("cpu")
                        device = "cpu"

                model.eval()
                self._device = device
                # Publish last, so the unlocked fast-path check above never
                # observes a model that is not yet in eval mode / on its device.
                self._model = model
                logger.info(
                    f"βœ… Jina v3 reranker loaded on {self._device} "
                    f"(model={self.model_name})"
                )
            except Exception as e:
                logger.error(f"Failed to load Jina v3 reranker: {e}", exc_info=True)
                self._load_failed = True

    def compute_scores(
        self,
        query: str,
        docs: List[str],
        max_length: int = 1024,
    ) -> List[float]:
        """
        Score every doc against ``query`` using Jina v3's built-in ``.rerank()``.

        Args:
            query: The search query.
            docs: Document texts to score.
            max_length: Currently unused; kept for backward compatibility.

        Returns:
            One float per doc, in the original ``docs`` order (not sorted).
            Returns ``[]`` for empty ``docs``. Returns ``[0.5] * len(docs)`` if
            the model could not be loaded. Returns ``[0.0] * len(docs)`` if
            ``rerank()`` raises; the error is logged, not re-raised.
        """
        if not docs:
            return []

        self._load()
        if self._model is None:
            return [0.5] * len(docs)

        try:
            # .rerank() returns dicts like
            #   {"document": str, "relevance_score": float, "index": int}
            # sorted by relevance_score descending; "index" maps each result
            # back to its position in `docs`.
            results = self._model.rerank(query, docs)

            # Restore the caller's original order. Docs missing from the
            # results keep a score of 0.0.
            scores = [0.0] * len(docs)
            for r in results:
                scores[r["index"]] = float(r["relevance_score"])
            return scores

        except Exception as e:
            logger.error(f"Jina v3 rerank() failed: {e}")
            return [0.0] * len(docs)

    @property
    def is_loaded(self) -> bool:
        """True once the model has been successfully loaded."""
        return self._model is not None


# ═══════════════════════════════════════════════════════════════════════════════
# UNIFIED RERANKER ADAPTER
# ═══════════════════════════════════════════════════════════════════════════════

class BgeRerankerAdapter(RerankerPort):
    """
    Unified reranker adapter - picks the Jina API, Jina v3 self-hosted, or BGE.

    RERANKER_MODEL=jinaai/jina-reranker-v3   -> Jina v3 (recommended)
    RERANKER_MODEL=BAAI/bge-reranker-v2-m3   -> BGE (legacy)

    When JINA_RERANKER_ENABLED is true and a real JINA_API_KEY is configured,
    the Jina cloud API is tried first. The self-hosted model selected by
    RERANKER_MODEL is used whenever the API is absent or reports itself
    unavailable.

    Both self-hosted models are free, ~0.6B parameters, ~1.2GB on disk, and are
    loaded lazily on the first rerank call that needs them.

    Side effect: the self-hosted paths write a ``"rerank_score"`` key into each
    scored input document dict.
    """

    # Max content chars sent to the reranker per document.
    # Jina v3: ~1024 tokens ~= 4096 chars - reads much more than BGE's 512 chars.
    MAX_CONTENT_CHARS_JINA = 4096
    MAX_CONTENT_CHARS_BGE  = 512

    def __init__(self):
        self.model_name = settings.RERANKER_MODEL
        self._is_jina_v3 = "jina-reranker-v3" in self.model_name.lower()
        self._lock = threading.Lock()
        self._load_failed = False
        self._bge_model = None
        # True when _bge_model is a FlagReranker, False for CrossEncoder.
        self._use_flag = False

        # Jina API reranker (cloud) takes priority over self-hosted models.
        self._jina_api = None
        if getattr(settings, 'JINA_RERANKER_ENABLED', False) and getattr(settings, 'JINA_API_KEY', ''):
            try:
                from src.infrastructure.adapters.jina_reranker_adapter import JinaRerankerAPIAdapter
                jina_key = settings.JINA_API_KEY
                # Ignore the placeholder value shipped in example configs.
                if jina_key and jina_key not in ("", "your-jina-api-key-here"):
                    self._jina_api = JinaRerankerAPIAdapter(
                        api_key=jina_key,
                        model=getattr(settings, 'JINA_RERANKER_MODEL', 'jina-reranker-v3'),
                        timeout=getattr(settings, 'JINA_RERANKER_TIMEOUT', 5.0),
                    )
                    logger.info("Reranker configured: Jina API (cloud, fast)")
            except Exception as e:
                logger.warning(f"Jina API reranker init failed: {e}")

        # Always create the self-hosted Jina v3 wrapper for Jina models. It loads
        # lazily, so this is free. It serves as the fallback when the cloud API
        # is configured but unavailable. Previously that case fell through to
        # the BGE path and tried to load the Jina model as a BGE cross-encoder.
        self._jina = JinaV3Reranker(self.model_name) if self._is_jina_v3 else None

        if not self._jina_api:
            if self._is_jina_v3:
                logger.info(f"Reranker configured: Jina v3 self-hosted ({self.model_name})")
            else:
                logger.info(f"Reranker configured: BGE ({self.model_name})")

    def _load_bge(self):
        """
        Lazy-load the BGE reranker (thread-safe, at most one attempt).

        Prefers FlagReranker (fp16) for "bge-reranker" models and falls back to
        sentence-transformers CrossEncoder. On failure, ``_load_failed`` is set
        and the error is logged; no exception propagates.
        """
        if self._bge_model is not None or self._load_failed:
            return
        with self._lock:
            # Re-check under the lock: another thread may have finished first.
            if self._bge_model is not None or self._load_failed:
                return
            logger.info(f"Loading BGE reranker: {self.model_name}")
            try:
                if HAS_FLAG_RERANKER and "bge-reranker" in self.model_name.lower():
                    # Patch XLMRobertaTokenizer for older transformers versions by
                    # borrowing any missing methods from PreTrainedTokenizer.
                    try:
                        from transformers import XLMRobertaTokenizer, PreTrainedTokenizer
                        for method_name in [
                            "prepare_for_model",
                            "build_inputs_with_special_tokens",
                            "create_token_type_ids_from_sequences",
                            "get_special_tokens_mask",
                            "convert_tokens_to_string",
                        ]:
                            if not hasattr(XLMRobertaTokenizer, method_name):
                                base_method = getattr(PreTrainedTokenizer, method_name, None)
                                if base_method:
                                    setattr(XLMRobertaTokenizer, method_name, base_method)
                    except Exception as patch_err:
                        logger.debug(f"Tokenizer patch skipped: {patch_err}")

                    self._bge_model = FlagReranker(
                        self.model_name,
                        use_fp16=True,
                        normalize=True,
                        trust_remote_code=True,
                    )
                    self._use_flag = True
                    logger.info("βœ… BGE loaded via FlagReranker (fp16, multilingual)")

                elif HAS_CROSS_ENCODER:
                    self._bge_model = CrossEncoder(self.model_name)
                    self._use_flag = False
                    logger.info("βœ… BGE loaded via CrossEncoder (fallback)")

                else:
                    logger.error("No BGE backend available (FlagEmbedding or sentence-transformers required)")
                    self._load_failed = True

            except Exception as e:
                logger.error(f"Failed to load BGE reranker '{self.model_name}': {e}", exc_info=True)
                self._load_failed = True

    # ── Internal helpers ──────────────────────────────────────────────────────

    @staticmethod
    def _collect_contents(docs: List[Dict[str, Any]], max_chars: int):
        """
        Return ``(kept_docs, texts)`` for docs with non-blank content.

        Each text is the stripped ``"content"`` value truncated to ``max_chars``.
        A missing or ``None`` content counts as empty, and that doc is skipped.
        """
        kept_docs: List[Dict[str, Any]] = []
        texts: List[str] = []
        for doc in docs:
            content = (doc.get("content") or "").strip()
            if content:
                texts.append(content[:max_chars])
                kept_docs.append(doc)
        return kept_docs, texts

    @staticmethod
    def _fallback_by_vector_score(
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Return the top_n docs ordered by their existing ``"score"`` (missing/None -> 0)."""
        return sorted(docs, key=lambda d: d.get("score") or 0, reverse=True)[:top_n]

    @staticmethod
    def _to_float_list(scores: Any) -> List[float]:
        """
        Normalise reranker output to a list of Python floats.

        Handles a bare scalar (Python or NumPy, or any 0-d array), which a
        backend may return for a single pair, as well as any iterable of
        numbers.
        """
        if isinstance(scores, (int, float)) or getattr(scores, "ndim", None) == 0:
            return [float(scores)]
        return [float(s) for s in scores]

    # ── Public interface ──────────────────────────────────────────────────────

    def rerank(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int = 5,
    ) -> List[Dict[str, Any]]:
        """
        Rerank documents by relevance to query.

        Priority: Jina API (cloud) > Jina v3 self-hosted > BGE.
        The Jina v3 path scores up to MAX_CONTENT_CHARS_JINA chars per doc.
        The BGE path scores only the first MAX_CONTENT_CHARS_BGE chars.
        Docs with empty content are dropped by the self-hosted paths.

        Returns up to top_n docs sorted by rerank_score descending. If a
        self-hosted model is unavailable or fails, returns them ordered by
        their existing vector ``"score"`` instead.
        """
        if not docs:
            return []

        if self._jina_api and self._jina_api.is_available():
            return self._jina_api.rerank(query, docs, top_n)
        if self._is_jina_v3 and self._jina:
            return self._rerank_jina(query, docs, top_n)
        return self._rerank_bge(query, docs, top_n)

    def _rerank_jina(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank with self-hosted Jina v3, using up to MAX_CONTENT_CHARS_JINA chars per doc."""
        self._jina._load()

        if self._jina._load_failed or not self._jina.is_loaded:
            logger.warning("Jina v3 unavailable β€” falling back to vector score ordering")
            return self._fallback_by_vector_score(docs, top_n)

        # Jina reads 8x more content per doc than BGE.
        valid_docs, doc_texts = self._collect_contents(docs, self.MAX_CONTENT_CHARS_JINA)
        if not doc_texts:
            return []

        try:
            scores = self._jina.compute_scores(query, doc_texts)

            # Index-based assignment: a length mismatch raises IndexError and
            # triggers the fallback below rather than silently skipping docs.
            for i, doc in enumerate(valid_docs):
                doc["rerank_score"] = scores[i]

            valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)

            logger.info(
                f"[Reranker] Jina v3: {len(valid_docs)} docs β†’ top {top_n} "
                f"(max_score={valid_docs[0]['rerank_score']:.3f})"
            )
            return valid_docs[:top_n]

        except Exception as e:
            logger.error(f"Jina v3 reranking failed: {e} β€” falling back to vector score")
            return self._fallback_by_vector_score(docs, top_n)

    def _rerank_bge(
        self,
        query: str,
        docs: List[Dict[str, Any]],
        top_n: int,
    ) -> List[Dict[str, Any]]:
        """Rerank with BGE, scoring only the first MAX_CONTENT_CHARS_BGE chars of each doc."""
        if self._bge_model is None:
            self._load_bge()

        if self._bge_model is None:
            logger.warning("BGE unavailable β€” falling back to vector score ordering")
            return self._fallback_by_vector_score(docs, top_n)

        valid_docs, doc_texts = self._collect_contents(docs, self.MAX_CONTENT_CHARS_BGE)
        if not doc_texts:
            return []
        pairs = [[query, text] for text in doc_texts]

        try:
            if self._use_flag:
                raw_scores = self._bge_model.compute_score(pairs, batch_size=64)
            else:
                raw_scores = self._bge_model.predict(pairs)
            scores = self._to_float_list(raw_scores)

            for i, doc in enumerate(valid_docs):
                doc["rerank_score"] = scores[i]

            valid_docs.sort(key=lambda x: x["rerank_score"], reverse=True)

            logger.info(
                f"[Reranker] BGE: {len(valid_docs)} docs β†’ top {top_n} "
                f"(max_score={valid_docs[0]['rerank_score']:.3f})"
            )
            return valid_docs[:top_n]

        except Exception as e:
            logger.error(f"BGE reranking failed: {e} β€” falling back to vector score")
            return self._fallback_by_vector_score(docs, top_n)

    @property
    def model_type(self) -> str:
        """'jina_v3' or 'bge', based on the configured RERANKER_MODEL (not the API)."""
        return "jina_v3" if self._is_jina_v3 else "bge"