File size: 3,190 Bytes
7d0fa43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
Cross-encoder reranker.
Reranks FAISS retrieval results by true query-document relevance.

WHY cross-encoder over bi-encoder (MiniLM)?
The MiniLM bi-encoder embeds the query and each document independently — fast but approximate.
A cross-encoder reads the query and document together in a single pass — slower but much more accurate.
It is therefore applied only after retrieval, rescoring the top-15 FAISS candidates to keep the best 5.

WHY ms-marco-MiniLM-L-6-v2?
Trained on MS-MARCO passage ranking — transfers well to legal QA.
Small enough to load on HF Spaces free tier (~80MB).
Fast enough for reranking 15 candidates in ~200ms on CPU.

Interview answer:
"I added a cross-encoder reranker post-retrieval to boost precision@5
by focusing on true relevance rather than embedding similarity alone.
Legal domain papers show 8-15% precision lift from reranking."
"""

import logging
from typing import List, Dict

logger = logging.getLogger(__name__)

# Module-level reranker state, written by load_reranker() and read by rerank()
# and is_loaded(): the CrossEncoder instance (None until a successful load)
# and a flag recording whether that load succeeded.
_reranker, _reranker_loaded = None, False


def load_reranker(
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
    max_length: int = 512,
) -> bool:
    """
    Load the cross-encoder once at startup.

    Fails gracefully — retrieval keeps working on FAISS scores alone if the
    model cannot be imported or constructed. Call from api/main.py after the
    other models load.

    Args:
        model_name: model id passed to ``CrossEncoder``.
        max_length: maximum token length per (query, document) pair,
            forwarded to ``CrossEncoder``.

    Returns:
        True if the reranker is ready, False if loading failed.
    """
    global _reranker, _reranker_loaded

    try:
        # Imported lazily so this module still imports cleanly when
        # sentence-transformers is not installed.
        from sentence_transformers import CrossEncoder
        logger.info("Loading cross-encoder reranker (%s)...", model_name)
        _reranker = CrossEncoder(model_name, max_length=max_length)
        _reranker_loaded = True
        logger.info("Cross-encoder reranker ready")
    except Exception as e:
        # Deliberate best-effort boundary: any exception raised while
        # importing or constructing the model downgrades to FAISS-only
        # retrieval instead of crashing startup.
        logger.warning(
            "Reranker load failed: %s. Retrieval will use FAISS scores only.", e
        )
        # Drop any previously loaded instance so the two globals never disagree.
        _reranker = None
        _reranker_loaded = False

    return _reranker_loaded


def _chunk_text(chunk: Dict, max_chars: int) -> str:
    """Pick the richest available text of a chunk and clip it to ``max_chars`` characters."""
    # Prefer the expanded context window, then the raw chunk text, then a
    # generic "text" field. The trailing ``or ""`` also covers keys that are
    # present with a None value, which would otherwise fail on slicing.
    text = (
        chunk.get("expanded_context")
        or chunk.get("chunk_text")
        or chunk.get("text")
        or ""
    )
    return text[:max_chars]


def rerank(
    query: str,
    chunks: List[Dict],
    top_k: int = 5,
    *,
    max_chars: int = 512,
) -> List[Dict]:
    """
    Rerank chunks by cross-encoder relevance score.

    Args:
        query: user query string
        chunks: list of retrieved chunks from FAISS, in FAISS order
        top_k: number of top chunks to return after reranking
        max_chars: each chunk's text is clipped to this many *characters*
            (not tokens) before scoring; the model applies its own token
            limit on top of this.

    Returns:
        top_k chunks sorted by reranker score descending; ties keep their
        FAISS order because ``sorted`` is stable. On success every input
        chunk dict is annotated in place with a float "reranker_score".
        If the reranker is not loaded or scoring fails, returns the
        original chunks[:top_k] unmodified.
    """
    if not _reranker_loaded or _reranker is None:
        return chunks[:top_k]

    if not chunks:
        return []

    try:
        pairs = [[query, _chunk_text(chunk, max_chars)] for chunk in chunks]
        # Convert every score before touching any chunk, so a failure here
        # cannot leave the inputs partially annotated.
        scores = [float(score) for score in _reranker.predict(pairs, batch_size=16)]
    except Exception as e:
        # Best-effort: a scoring failure must not break retrieval.
        logger.warning("Reranking failed: %s. Using FAISS order.", e)
        return chunks[:top_k]

    for chunk, score in zip(chunks, scores):
        chunk["reranker_score"] = score

    reranked = sorted(chunks, key=lambda c: c.get("reranker_score", 0), reverse=True)

    logger.info(
        "Reranked %d chunks → top %d. Top score: %.3f",
        len(chunks),
        top_k,
        reranked[0].get("reranker_score", 0),
    )

    return reranked[:top_k]


def is_loaded() -> bool:
    """
    Report whether the cross-encoder reranker is available.

    Returns:
        The module-level load flag: True after load_reranker() succeeded,
        False before any load attempt or after a failed one.
    """
    return bool(_reranker_loaded)