Spaces:
Sleeping
Sleeping
| """ | |
| Reranker utilities for paper retrieval | |
| Based on OpenScholar's rerank_paragraphs_bge function | |
| Supports two modes: | |
| 1. Direct mode: Use FlagReranker directly (requires global lock for thread-safety) | |
| 2. API mode: Use reranker API service with load balancing (recommended for multi-GPU) | |
| """ | |
| import os | |
| import threading | |
| import time | |
| import requests | |
| from typing import List, Dict, Any, Optional, Tuple, Union | |
| # Suppress transformers progress bars | |
| os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error') | |
| # Global lock for reranker usage (FlagReranker's tokenizer is not thread-safe) | |
| # This prevents "Already borrowed" errors when multiple threads use the same reranker | |
| # NOTE: Not needed when using API mode | |
| _reranker_usage_lock = threading.Lock() | |
| # Try to import endpoint pool for API mode | |
| try: | |
| from .reranker_endpoint_pool import RerankerEndpointPool | |
| HAS_ENDPOINT_POOL = True | |
| except ImportError: | |
| HAS_ENDPOINT_POOL = False | |
| RerankerEndpointPool = None | |
def rerank_paragraphs_bge(
    query: str,
    paragraphs: List[Dict[str, Any]],
    reranker: Optional[Any] = None,
    reranker_endpoint_pool: Optional[Any] = None,
    norm_cite: bool = False,
    start_index: int = 0,
    use_abstract: bool = False,
    timeout: float = 30.0,
) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]:
    """
    Rerank paragraphs using BGE reranker (from OpenScholar).

    Supports two modes:
    1. Direct mode: Pass FlagReranker instance (uses global lock, thread-safe but serialized)
    2. API mode: Pass RerankerEndpointPool (recommended for multi-GPU, parallel requests)

    Args:
        query: Search query
        paragraphs: List of paragraph/paper dictionaries (entries without "text" are dropped)
        reranker: FlagReranker instance (for direct mode, optional if using API mode)
        reranker_endpoint_pool: RerankerEndpointPool instance (for API mode, optional
            if using direct mode; takes precedence over ``reranker`` when both are given)
        norm_cite: Whether to normalize citation counts and add them to scores
        start_index: Starting index offset applied to values in the id mapping
        use_abstract: Whether to include title + abstract in the reranking text
        timeout: Request timeout for API mode (seconds)

    Returns:
        Tuple of:
            - reranked_paragraphs: List of paragraphs sorted by descending score
            - result_dict: Dictionary mapping original index to score
            - id_mapping: Dictionary mapping new index to original index + start_index
    """
    # Drop entries that have no text at all
    paragraphs = [p for p in paragraphs if p.get("text") is not None]
    if not paragraphs:
        return [], {}, {}

    # Build the text fed to the reranker for each paragraph.
    # (A truthy p.get(key) already implies the key is present, so the
    # original "key in p and p.get(key)" checks reduce to p.get(key).)
    if use_abstract:
        paragraph_texts = [
            p["title"] + "\n" + p["abstract"] + "\n" + p["text"]
            if p.get("title") and p.get("abstract")
            else p["text"]
            for p in paragraphs
        ]
    else:
        paragraph_texts = [
            p["title"] + " " + p["text"] if p.get("title") is not None else p["text"]
            for p in paragraphs
        ]

    # Keep only non-empty string texts (the reranker cannot score empty input)
    valid_indices = []
    valid_texts = []
    for i, text in enumerate(paragraph_texts):
        if text and isinstance(text, str) and text.strip():
            valid_indices.append(i)
            valid_texts.append(text)
    if not valid_texts:
        return [], {}, {}
    if len(valid_indices) < len(paragraphs):
        paragraphs = [paragraphs[i] for i in valid_indices]
    paragraph_texts = valid_texts

    # No reranker configured: keep the original order with zero scores
    if reranker is None and reranker_endpoint_pool is None:
        id_mapping = {i: i + start_index for i in range(len(paragraphs))}
        result_dict = {i: 0.0 for i in range(len(paragraphs))}
        return paragraphs, result_dict, id_mapping

    # API mode: use the reranker API service (recommended for multi-GPU)
    if reranker_endpoint_pool is not None:
        return _rerank_via_api(
            query=query,
            paragraph_texts=paragraph_texts,
            paragraphs=paragraphs,
            reranker_endpoint_pool=reranker_endpoint_pool,
            norm_cite=norm_cite,
            start_index=start_index,
            timeout=timeout,
        )

    # Direct mode: FlagReranker's Rust tokenizer is not thread-safe, so access
    # is serialized with a global lock (prevents "Already borrowed" errors).
    # Save the env var with None as the "was unset" sentinel: the previous
    # code used '' and therefore deleted an explicitly-empty value on restore.
    original_verbosity = os.environ.get('TRANSFORMERS_VERBOSITY')
    os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
    with _reranker_usage_lock:
        try:
            # paragraph_texts is guaranteed non-empty here (early return above)
            scores = reranker.compute_score(
                [[query, p] for p in paragraph_texts], batch_size=100
            )
        finally:
            # Restore the caller's verbosity setting exactly as it was
            if original_verbosity is not None:
                os.environ['TRANSFORMERS_VERBOSITY'] = original_verbosity
            else:
                os.environ.pop('TRANSFORMERS_VERBOSITY', None)

    # compute_score returns a bare float for a single pair, a list otherwise
    if isinstance(scores, float):
        result_dict = {0: scores}
    else:
        result_dict = {p_id: score for p_id, score in enumerate(scores)}

    # Optionally blend normalized citation counts into the scores
    if norm_cite:
        citation_items = [
            item["citation_counts"]
            for item in paragraphs
            if item.get("citation_counts") is not None
        ]
        max_citations = max(citation_items, default=0)
        # Guard against ZeroDivisionError when every citation count is 0
        if max_citations > 0:
            for p_id in result_dict:
                cites = paragraphs[p_id].get("citation_counts")
                if cites is not None:
                    result_dict[p_id] += cites / max_citations

    # Sort by score (descending) and build the reranked list + id mapping
    p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True)
    new_orders = []
    id_mapping = {}
    for new_idx, (p_id, _) in enumerate(p_ids):
        new_orders.append(paragraphs[p_id])
        id_mapping[new_idx] = int(p_id) + start_index
    return new_orders, result_dict, id_mapping
| def _rerank_via_api( | |
| query: str, | |
| paragraph_texts: List[str], | |
| paragraphs: List[Dict[str, Any]], | |
| reranker_endpoint_pool: Any, | |
| norm_cite: bool = False, | |
| start_index: int = 0, | |
| timeout: float = 30.0, | |
| ) -> Tuple[List[Dict[str, Any]], Dict[int, float], Dict[int, int]]: | |
| """ | |
| Rerank paragraphs via API service (supports load balancing across multiple GPUs) | |
| Args: | |
| query: Search query | |
| paragraph_texts: List of paragraph texts (already formatted) | |
| paragraphs: List of paragraph dictionaries | |
| reranker_endpoint_pool: RerankerEndpointPool instance | |
| norm_cite: Whether to normalize citation counts | |
| start_index: Starting index for id mapping | |
| timeout: Request timeout | |
| Returns: | |
| Tuple of reranked paragraphs, result dict, and id mapping | |
| """ | |
| if not paragraph_texts: | |
| return [], {}, {} | |
| # Get endpoint from pool (round-robin load balancing) | |
| endpoint = reranker_endpoint_pool.get_endpoint() | |
| api_url = f"{endpoint}/rerank" | |
| # Prepare request | |
| request_data = { | |
| "query": query, | |
| "paragraphs": paragraph_texts, | |
| "batch_size": 100 | |
| } | |
| start_time = time.time() | |
| try: | |
| # Make API request | |
| response = requests.post( | |
| api_url, | |
| json=request_data, | |
| timeout=timeout | |
| ) | |
| response.raise_for_status() | |
| result = response.json() | |
| scores = result.get("scores", []) | |
| response_time = time.time() - start_time | |
| # Mark success | |
| reranker_endpoint_pool.mark_success(endpoint, response_time) | |
| except requests.exceptions.RequestException as e: | |
| # Mark error | |
| reranker_endpoint_pool.mark_error(endpoint, str(e)) | |
| raise RuntimeError(f"Reranker API request failed: {e}") | |
| # Handle score format (should be list from API) | |
| if isinstance(scores, float): | |
| result_dict = {0: scores} | |
| else: | |
| result_dict = {p_id: score for p_id, score in enumerate(scores)} | |
| # Add normalized citation counts if enabled | |
| if norm_cite: | |
| citation_items = [ | |
| item["citation_counts"] | |
| for item in paragraphs | |
| if "citation_counts" in item and item["citation_counts"] is not None | |
| ] | |
| if len(citation_items) > 0: | |
| max_citations = max(citation_items) | |
| for p_id in result_dict: | |
| if ( | |
| "citation_counts" in paragraphs[p_id] | |
| and paragraphs[p_id]["citation_counts"] is not None | |
| ): | |
| result_dict[p_id] = result_dict[p_id] + ( | |
| paragraphs[p_id]["citation_counts"] / max_citations | |
| ) | |
| # Sort by score | |
| p_ids = sorted(result_dict.items(), key=lambda x: x[1], reverse=True) | |
| # Build reranked list and id mapping | |
| new_orders = [] | |
| id_mapping = {} | |
| for i, (p_id, _) in enumerate(p_ids): | |
| new_orders.append(paragraphs[p_id]) | |
| id_mapping[i] = int(p_id) + start_index | |
| return new_orders, result_dict, id_mapping | |