File size: 9,852 Bytes
90c099b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
"""
Reranker utilities for paper retrieval
Based on OpenScholar's rerank_paragraphs_bge function

Supports two modes:
1. Direct mode: Use FlagReranker directly (requires global lock for thread-safety)
2. API mode: Use reranker API service with load balancing (recommended for multi-GPU)
"""
import os
import threading
import time
import requests
from typing import List, Dict, Any, Optional, Tuple, Union

# Suppress transformers progress bars
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')

# Global lock for reranker usage (FlagReranker's tokenizer is not thread-safe)
# This prevents "Already borrowed" errors when multiple threads use the same reranker
# NOTE: Not needed when using API mode
_reranker_usage_lock = threading.Lock()

# Try to import endpoint pool for API mode
try:
    from .reranker_endpoint_pool import RerankerEndpointPool
    HAS_ENDPOINT_POOL = True
except ImportError:
    HAS_ENDPOINT_POOL = False
    RerankerEndpointPool = None


def rerank_paragraphs_bge(
    query: str,
    paragraphs: List[Dict[str, Any]],
    reranker: Optional[Any] = None,
    reranker_endpoint_pool: Optional[Any] = None,
    norm_cite: bool = False,
    start_index: int = 0,
    use_abstract: bool = False,
    timeout: float = 30.0,
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
    """
    Rerank paragraphs using BGE reranker (from OpenScholar)

    Supports two modes:
    1. Direct mode: Pass FlagReranker instance (uses global lock, thread-safe but serialized)
    2. API mode: Pass RerankerEndpointPool (recommended for multi-GPU, parallel requests)

    If neither ``reranker`` nor ``reranker_endpoint_pool`` is given, the
    paragraphs are returned in their original order with zero scores.

    Args:
        query: Search query
        paragraphs: List of paragraph/paper dictionaries (entries without a
            non-empty "text" field are dropped)
        reranker: FlagReranker instance (for direct mode, optional if using API mode)
        reranker_endpoint_pool: RerankerEndpointPool instance (for API mode,
            optional if using direct mode; takes precedence over ``reranker``)
        norm_cite: Whether to normalize citation counts and add to scores
        start_index: Starting index for id mapping
        use_abstract: Whether to include abstract in reranking text
        timeout: Request timeout for API mode (seconds)

    Returns:
        Tuple of:
        - reranked_paragraphs: List of reranked paragraphs
        - result_dict: Dictionary mapping original index to score
        - id_mapping: Dictionary mapping new index to original index (offset
          by ``start_index``)
    """
    # Drop entries with no text at all; downstream text building requires it.
    paragraphs = [p for p in paragraphs if p.get("text") is not None]

    if not paragraphs:
        return [], {}, {}

    # Build the text fed to the reranker: optionally prepend title/abstract.
    if use_abstract:
        paragraph_texts = [
            p["title"] + "\n" + p["abstract"] + "\n" + p["text"]
            if "title" in p and "abstract" in p and p.get("title") and p.get("abstract")
            else p["text"]
            for p in paragraphs
        ]
    else:
        paragraph_texts = [
            p["title"] + " " + p["text"]
            if "title" in p and p.get("title") is not None
            else p["text"]
            for p in paragraphs
        ]

    # Filter out empty / whitespace-only / non-string texts; keep paragraphs
    # and texts aligned by index.
    valid_indices = []
    valid_texts = []
    for i, text in enumerate(paragraph_texts):
        if text and isinstance(text, str) and text.strip():
            valid_indices.append(i)
            valid_texts.append(text)

    if not valid_texts:
        return [], {}, {}

    if len(valid_indices) < len(paragraphs):
        paragraphs = [paragraphs[i] for i in valid_indices]
        paragraph_texts = valid_texts

    # No scoring backend: preserve original order with zero scores.
    if reranker is None and reranker_endpoint_pool is None:
        id_mapping = {i: i + start_index for i in range(len(paragraphs))}
        result_dict = {i: 0.0 for i in range(len(paragraphs))}
        return paragraphs, result_dict, id_mapping

    # API mode: Use reranker API service (recommended for multi-GPU)
    if reranker_endpoint_pool is not None:
        return _rerank_via_api(
            query=query,
            paragraph_texts=paragraph_texts,
            paragraphs=paragraphs,
            reranker_endpoint_pool=reranker_endpoint_pool,
            norm_cite=norm_cite,
            start_index=start_index,
            timeout=timeout
        )

    # Direct mode: Use FlagReranker directly (requires global lock)
    # Suppress transformers warnings and progress bars during computation
    original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY', '')
    os.environ['TRANSFORMERS_VERBOSITY'] = 'error'

    # Use lock to prevent "Already borrowed" errors from the Rust tokenizer:
    # FlagReranker's tokenizer is not thread-safe, so access is serialized.
    with _reranker_usage_lock:
        try:
            if not paragraph_texts:
                return [], {}, {}
            scores = reranker.compute_score([[query, p] for p in paragraph_texts], batch_size=100)
        finally:
            # Restore original verbosity even if scoring raised
            if original_verbosity:
                os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
            elif 'TRANSFORMERS_VERBOSITY' in os.environ:
                del os.environ['TRANSFORMERS_VERBOSITY']

    # compute_score returns a bare float for a single pair, a list otherwise.
    if isinstance(scores, float):
        result_dict = {0: scores}
    else:
        result_dict = {p_id: score for p_id, score in enumerate(scores)}

    if norm_cite:
        _add_citation_bonus(result_dict, paragraphs)

    return _sort_and_map(result_dict, paragraphs, start_index)


def _add_citation_bonus(
    result_dict: Dict[int, float],
    paragraphs: List[Dict[str, Any]],
) -> None:
    """Add max-normalized citation counts to scores in-place.

    BUG FIX: guards against division by zero when every paper's
    "citation_counts" is 0 (previously raised ZeroDivisionError).
    """
    citation_items = [
        item["citation_counts"]
        for item in paragraphs
        if "citation_counts" in item and item["citation_counts"] is not None
    ]
    if not citation_items:
        return
    max_citations = max(citation_items)
    if max_citations <= 0:
        # All counts are zero (or negative garbage): no bonus to distribute.
        return
    for p_id in result_dict:
        counts = paragraphs[p_id].get("citation_counts")
        if counts is not None:
            result_dict[p_id] = result_dict[p_id] + counts / max_citations


def _sort_and_map(
    result_dict: Dict[int, float],
    paragraphs: List[Dict[str, Any]],
    start_index: int,
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
    """Sort paragraphs by descending score and build the new-index -> original-index map."""
    p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
    new_orders = []
    id_mapping = {}
    for i, (p_id, _) in enumerate(p_ids):
        new_orders.append(paragraphs[p_id])
        id_mapping[i] = int(p_id) + start_index
    return new_orders, result_dict, id_mapping


def _rerank_via_api(
    query: str,
    paragraph_texts: List[str],
    paragraphs: List[Dict[str, Any]],
    reranker_endpoint_pool: Any,
    norm_cite: bool = False,
    start_index: int = 0,
    timeout: float = 30.0,
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
    """
    Rerank paragraphs via API service (supports load balancing across multiple GPUs)
    
    Args:
        query: Search query
        paragraph_texts: List of paragraph texts (already formatted)
        paragraphs: List of paragraph dictionaries
        reranker_endpoint_pool: RerankerEndpointPool instance
        norm_cite: Whether to normalize citation counts
        start_index: Starting index for id mapping
        timeout: Request timeout
        
    Returns:
        Tuple of reranked paragraphs, result dict, and id mapping
    """
    if not paragraph_texts:
        return [], {}, {}
    
    # Get endpoint from pool (round-robin load balancing)
    endpoint = reranker_endpoint_pool.get_endpoint()
    api_url = f"{endpoint}/rerank"
    
    # Prepare request
    request_data = {
        "query": query,
        "paragraphs": paragraph_texts,
        "batch_size": 100
    }
    
    start_time = time.time()
    try:
        # Make API request
        response = requests.post(
            api_url,
            json=request_data,
            timeout=timeout
        )
        response.raise_for_status()
        
        result = response.json()
        scores = result.get("scores", [])
        response_time = time.time() - start_time
        
        # Mark success
        reranker_endpoint_pool.mark_success(endpoint, response_time)
        
    except requests.exceptions.RequestException as e:
        # Mark error
        reranker_endpoint_pool.mark_error(endpoint, str(e))
        raise RuntimeError(f"Reranker API request failed: {e}")
    
    # Handle score format (should be list from API)
    if isinstance(scores, float):
        result_dict = {0: scores}
    else:
        result_dict = {p_id: score for p_id, score in enumerate(scores)}
    
    # Add normalized citation counts if enabled
    if norm_cite:
        citation_items = [
            item["citation_counts"]
            for item in paragraphs
            if "citation_counts" in item and item["citation_counts"] is not None
        ]
        if len(citation_items) > 0:
            max_citations = max(citation_items)
            for p_id in result_dict:
                if (
                    "citation_counts" in paragraphs[p_id]
                    and paragraphs[p_id]["citation_counts"] is not None
                ):
                    result_dict[p_id] = result_dict[p_id] + (
                        paragraphs[p_id]["citation_counts"] / max_citations
                    )
    
    # Sort by score
    p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
    
    # Build reranked list and id mapping
    new_orders = []
    id_mapping = {}
    for i, (p_id, _) in enumerate(p_ids):
        new_orders.append(paragraphs[p_id])
        id_mapping[i] = int(p_id) + start_index
    
    return new_orders, result_dict, id_mapping