# GradioDemo/shared/utils/reranker.py
# Last change: "Initial Update" by eigentom (commit 90c099b)
"""
Reranker utilities for paper retrieval
Based on OpenScholar's rerank_paragraphs_bge function
Supports two modes:
1. Direct mode: Use FlagReranker directly (requires global lock for thread-safety)
2. API mode: Use reranker API service with load balancing (recommended for multi-GPU)
"""
import os
import threading
import time
import requests
from typing import List, Dict, Any, Optional, Tuple, Union
# Suppress transformers progress bars
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')
# Global lock for reranker usage (FlagReranker's tokenizer is not thread-safe)
# This prevents "Already borrowed" errors when multiple threads use the same reranker
# NOTE: Not needed when using API mode
_reranker_usage_lock = threading.Lock()
# Try to import endpoint pool for API mode
try:
from .reranker_endpoint_pool import RerankerEndpointPool
HAS_ENDPOINT_POOL = True
except ImportError:
HAS_ENDPOINT_POOL = False
RerankerEndpointPool = None
def rerank_paragraphs_bge(
    query: str,
    paragraphs: List[Dict[str, Any]],
    reranker: Optional[Any] = None,
    reranker_endpoint_pool: Optional[Any] = None,
    norm_cite: bool = False,
    start_index: int = 0,
    use_abstract: bool = False,
    timeout: float = 30.0,
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
    """
    Rerank paragraphs using a BGE reranker (adapted from OpenScholar).

    Supports two modes:
    1. Direct mode: pass a FlagReranker instance (serialized behind a global
       lock because its Rust tokenizer is not thread-safe).
    2. API mode: pass a RerankerEndpointPool (recommended for multi-GPU;
       requests can run in parallel across endpoints).

    Args:
        query: Search query.
        paragraphs: Paragraph/paper dictionaries; entries whose "text" is
            None or whose formatted text is blank are dropped.
        reranker: FlagReranker instance (direct mode). When both this and
            ``reranker_endpoint_pool`` are None, the input order is kept and
            every score is 0.0.
        reranker_endpoint_pool: RerankerEndpointPool instance (API mode);
            takes precedence over ``reranker`` when both are given.
        norm_cite: If True, add citation_counts / max(citation_counts) to
            each score.
        start_index: Offset added to the original indices in the id mapping.
        use_abstract: If True, rerank on title, abstract and text joined by
            newlines when both title and abstract are present; otherwise on
            "title text".
        timeout: Request timeout for API mode (seconds).

    Returns:
        Tuple of:
            - reranked_paragraphs: paragraphs sorted by descending score
            - result_dict: maps index (into the filtered list) -> score
            - id_mapping: maps new rank -> filtered index + start_index
    """
    # Drop paragraphs that have no text at all.
    paragraphs = [p for p in paragraphs if p.get("text") is not None]
    if not paragraphs:
        return [], {}, {}

    paragraph_texts = _build_rerank_texts(paragraphs, use_abstract)

    # Drop entries whose formatted text is empty, non-string, or whitespace.
    kept = [
        (p, text)
        for p, text in zip(paragraphs, paragraph_texts)
        if text and isinstance(text, str) and text.strip()
    ]
    if not kept:
        return [], {}, {}
    paragraphs = [p for p, _ in kept]
    paragraph_texts = [text for _, text in kept]

    # No reranker configured: keep the original order with zero scores.
    if reranker is None and reranker_endpoint_pool is None:
        id_mapping = {i: i + start_index for i in range(len(paragraphs))}
        result_dict = {i: 0.0 for i in range(len(paragraphs))}
        return paragraphs, result_dict, id_mapping

    # API mode: load-balanced reranker service (recommended for multi-GPU).
    if reranker_endpoint_pool is not None:
        return _rerank_via_api(
            query=query,
            paragraph_texts=paragraph_texts,
            paragraphs=paragraphs,
            reranker_endpoint_pool=reranker_endpoint_pool,
            norm_cite=norm_cite,
            start_index=start_index,
            timeout=timeout
        )

    # Direct mode. Serialize access: FlagReranker's Rust tokenizer raises
    # "Already borrowed" when used from several threads concurrently.
    original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY', '')
    os.environ['TRANSFORMERS_VERBOSITY'] = 'error'  # silence progress bars
    with _reranker_usage_lock:
        try:
            scores = reranker.compute_score(
                [[query, text] for text in paragraph_texts], batch_size=100
            )
        finally:
            # Restore the caller's verbosity setting.
            if original_verbosity:
                os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
            elif 'TRANSFORMERS_VERBOSITY' in os.environ:
                del os.environ['TRANSFORMERS_VERBOSITY']

    result_dict = _scores_to_dict(scores)
    if norm_cite:
        _add_citation_bonus(result_dict, paragraphs)
    reranked, id_mapping = _order_by_score(result_dict, paragraphs, start_index)
    return reranked, result_dict, id_mapping


def _build_rerank_texts(
    paragraphs: List[Dict[str, Any]], use_abstract: bool
) -> List[str]:
    """Format each paragraph into the text that is fed to the reranker."""
    if use_abstract:
        # Prepend title and abstract only when both are present and truthy.
        return [
            p["title"] + "\n" + p["abstract"] + "\n" + p["text"]
            if "title" in p and "abstract" in p and p.get("title") and p.get("abstract")
            else p["text"]
            for p in paragraphs
        ]
    return [
        p["title"] + " " + p["text"]
        if "title" in p and p.get("title") is not None
        else p["text"]
        for p in paragraphs
    ]


def _scores_to_dict(scores: Union[float, List[float]]) -> Dict[int, float]:
    """Normalize compute_score output (scalar for a single pair, else a list)."""
    if isinstance(scores, float):
        return {0: scores}
    return {p_id: score for p_id, score in enumerate(scores)}


def _add_citation_bonus(
    result_dict: Dict[int, float], paragraphs: List[Dict[str, Any]]
) -> None:
    """Add citation_counts / max(citation_counts) to each score, in place."""
    citation_items = [
        item["citation_counts"]
        for item in paragraphs
        if "citation_counts" in item and item["citation_counts"] is not None
    ]
    if not citation_items:
        return
    max_citations = max(citation_items)
    for p_id in result_dict:
        counts = paragraphs[p_id].get("citation_counts")
        if counts is not None:
            result_dict[p_id] = result_dict[p_id] + (counts / max_citations)


def _order_by_score(
    result_dict: Dict[int, float],
    paragraphs: List[Dict[str, Any]],
    start_index: int,
) -> Tuple[List[Dict[str, Any]], Dict[int, int]]:
    """Sort paragraphs by descending score; map new rank -> old index + offset."""
    ranked = sorted(result_dict.items(), key=lambda item: item[1], reverse=True)
    new_orders = []
    id_mapping = {}
    for rank, (p_id, _) in enumerate(ranked):
        new_orders.append(paragraphs[p_id])
        id_mapping[rank] = int(p_id) + start_index
    return new_orders, id_mapping
def _rerank_via_api(
query: str,
paragraph_texts: List[str],
paragraphs: List[Dict[str, Any]],
reranker_endpoint_pool: Any,
norm_cite: bool = False,
start_index: int = 0,
timeout: float = 30.0,
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
"""
Rerank paragraphs via API service (supports load balancing across multiple GPUs)
Args:
query: Search query
paragraph_texts: List of paragraph texts (already formatted)
paragraphs: List of paragraph dictionaries
reranker_endpoint_pool: RerankerEndpointPool instance
norm_cite: Whether to normalize citation counts
start_index: Starting index for id mapping
timeout: Request timeout
Returns:
Tuple of reranked paragraphs, result dict, and id mapping
"""
if not paragraph_texts:
return [], {}, {}
# Get endpoint from pool (round-robin load balancing)
endpoint = reranker_endpoint_pool.get_endpoint()
api_url = f"{endpoint}/rerank"
# Prepare request
request_data = {
"query": query,
"paragraphs": paragraph_texts,
"batch_size": 100
}
start_time = time.time()
try:
# Make API request
response = requests.post(
api_url,
json=request_data,
timeout=timeout
)
response.raise_for_status()
result = response.json()
scores = result.get("scores", [])
response_time = time.time() - start_time
# Mark success
reranker_endpoint_pool.mark_success(endpoint, response_time)
except requests.exceptions.RequestException as e:
# Mark error
reranker_endpoint_pool.mark_error(endpoint, str(e))
raise RuntimeError(f"Reranker API request failed: {e}")
# Handle score format (should be list from API)
if isinstance(scores, float):
result_dict = {0: scores}
else:
result_dict = {p_id: score for p_id, score in enumerate(scores)}
# Add normalized citation counts if enabled
if norm_cite:
citation_items = [
item["citation_counts"]
for item in paragraphs
if "citation_counts" in item and item["citation_counts"] is not None
]
if len(citation_items) > 0:
max_citations = max(citation_items)
for p_id in result_dict:
if (
"citation_counts" in paragraphs[p_id]
and paragraphs[p_id]["citation_counts"] is not None
):
result_dict[p_id] = result_dict[p_id] + (
paragraphs[p_id]["citation_counts"] / max_citations
)
# Sort by score
p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
# Build reranked list and id mapping
new_orders = []
id_mapping = {}
for i, (p_id, _) in enumerate(p_ids):
new_orders.append(paragraphs[p_id])
id_mapping[i] = int(p_id) + start_index
return new_orders, result_dict, id_mapping