File size: 5,723 Bytes
f866820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""
Reranking module for improving retrieval precision.

Uses cross-encoder models to reorder initial retrieval results
based on query-document relevance.
"""

from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass

# Model cache for lazy loading
_reranker_model = None
_reranker_model_name = None


@dataclass
class RerankResult:
    """Result from reranking operation."""
    # Output chunks; on a successful rerank each dict carries a float
    # "rerank_score" key and the list is truncated to top_k.
    chunks: List[Dict[str, Any]]
    # Model identifier used for scoring, or "none" / a fallback marker
    # when no reranking was performed.
    model_used: str
    # True only when model-based scoring actually ran; False for
    # pass-through and fallback results.
    reranked: bool


def _get_cross_encoder(model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """
    Lazy load and cache cross-encoder model.

    Args:
        model_name: HuggingFace model name for cross-encoder

    Returns:
        CrossEncoder model instance
    """
    global _reranker_model, _reranker_model_name

    if _reranker_model is None or _reranker_model_name != model_name:
        try:
            from sentence_transformers import CrossEncoder
            _reranker_model = CrossEncoder(model_name)
            _reranker_model_name = model_name
        except ImportError:
            raise ImportError(
                "sentence-transformers not installed. "
                "Install with: pip install sentence-transformers"
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load cross-encoder model: {e}")

    return _reranker_model


def rerank_chunks(
    query: str,
    chunks: List[Dict[str, Any]],
    top_k: int = 5,
    model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
) -> RerankResult:
    """
    Rerank chunks using a cross-encoder model.

    A cross-encoder scores each (query, document) pair jointly, which
    yields more nuanced relevance estimates than independently computed
    bi-encoder embeddings.

    Args:
        query: User query.
        chunks: List of chunks to rerank.
        top_k: Number of top results to return.
        model_name: Cross-encoder model to use.

    Returns:
        RerankResult with reranked chunks; on any failure the original
        chunks (truncated to top_k) come back unreranked.
    """
    # Zero or one chunk: nothing to reorder.
    if len(chunks) <= 1:
        return RerankResult(chunks=chunks, model_used="none", reranked=False)

    try:
        encoder = _get_cross_encoder(model_name)

        def _text_of(c: Dict[str, Any]) -> str:
            # Chunk text may live at the top level or nested under "metadata".
            t = c.get("text", "")
            if not t and "metadata" in c:
                t = c["metadata"].get("text", "")
            return t

        pairs = [(query, _text_of(c)) for c in chunks]
        relevance = encoder.predict(pairs)

        # Highest relevance first; stable sort preserves input order on ties.
        ranked = sorted(zip(chunks, relevance), key=lambda item: item[1], reverse=True)

        reranked_chunks = []
        for source, score in ranked[:top_k]:
            enriched = source.copy()
            enriched["rerank_score"] = float(score)
            reranked_chunks.append(enriched)

        return RerankResult(
            chunks=reranked_chunks,
            model_used=model_name,
            reranked=True,
        )

    except Exception as e:
        # Best-effort fallback: hand back the original ordering untouched.
        return RerankResult(
            chunks=chunks[:top_k],
            model_used=f"fallback (error: {str(e)[:50]})",
            reranked=False,
        )


def rerank_with_llm(
    query: str,
    chunks: List[Dict[str, Any]],
    top_k: int = 5
) -> RerankResult:
    """
    Rerank chunks by asking an LLM to score query/document relevance.

    Slower and costlier than a cross-encoder, but needs no extra models.

    Args:
        query: User query.
        chunks: List of chunks to rerank.
        top_k: Number of top results to return.

    Returns:
        RerankResult with reranked chunks; falls back to the original
        order (truncated to top_k) when scoring or parsing fails.
    """
    if not chunks:
        return RerankResult(chunks=[], model_used="none", reranked=False)
    if len(chunks) <= 1:
        return RerankResult(chunks=chunks, model_used="none", reranked=False)

    try:
        from src.llm_providers import call_llm

        # One "[i] text" line per document, truncated to keep the prompt small.
        numbered_docs = "\n".join(
            f"[{idx}] {c.get('text', '')[:500]}" for idx, c in enumerate(chunks)
        )

        prompt = f"""Rate the relevance of each document to the query on a scale of 0-10.
Return ONLY a comma-separated list of scores in order (e.g., "8,3,7,5").

Query: {query}

Documents:
{numbered_docs}

Scores:"""

        response = call_llm(prompt=prompt, temperature=0.0, max_tokens=100)
        raw_scores = response.get("text", "").strip()

        # The model must echo exactly one score per document; anything else
        # is treated as a parse failure and falls back to the input order.
        try:
            parsed = [float(part.strip()) for part in raw_scores.split(",")]
            if len(parsed) != len(chunks):
                raise ValueError("Score count mismatch")
        except (ValueError, AttributeError):
            return RerankResult(
                chunks=chunks[:top_k],
                model_used="llm_parse_failed",
                reranked=False,
            )

        # Highest score first; stable sort preserves input order on ties.
        ordered = sorted(zip(chunks, parsed), key=lambda item: item[1], reverse=True)

        best = []
        for source, score in ordered[:top_k]:
            enriched = source.copy()
            enriched["rerank_score"] = float(score)
            best.append(enriched)

        return RerankResult(chunks=best, model_used="llm", reranked=True)

    except Exception as e:
        # Best-effort fallback: original ordering, truncated to top_k.
        return RerankResult(
            chunks=chunks[:top_k],
            model_used=f"llm_fallback (error: {str(e)[:50]})",
            reranked=False,
        )