Spaces:
Sleeping
Sleeping
| """ | |
| Reranker API Service | |
| Pack FlagReranker into an HTTP API service, supporting multi-GPU load balancing. | |
| """ | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Optional | |
| import argparse | |
# Suppress transformers warnings (setdefault keeps any value the user already exported)
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')

# Optional web-stack dependencies. HAS_FASTAPI records whether they imported so
# main() can exit with a readable message instead of a traceback.
try:
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from pydantic import BaseModel
    import uvicorn
    HAS_FASTAPI = True
except ImportError:
    HAS_FASTAPI = False
    # Fallback base class: the request/response models in this module subclass
    # BaseModel at import time, so without a binding here the module would die
    # with NameError before main() could report the missing dependency.
    BaseModel = object
    print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn")

# Optional model dependency. HAS_FLAGEMBEDDING is checked in main() and again
# when the model is loaded.
try:
    from FlagEmbedding import FlagReranker
    HAS_FLAGEMBEDDING = True
except ImportError:
    HAS_FLAGEMBEDDING = False
    print("Warning: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
| # Request/Response models | |
class RerankRequest(BaseModel):
    """Request body for a rerank call: one query scored against many paragraphs."""

    # Query text paired with every paragraph.
    query: str
    # Candidate paragraphs; one score is produced per entry.
    paragraphs: List[str]
    # Batch size forwarded to the reranker's compute_score call.
    batch_size: int = 100
class RerankResponse(BaseModel):
    """Response body for a rerank call."""

    # Scores produced for the request's paragraphs (empty when none were sent).
    scores: List[float]
    # Set to True by the handler on every successful response.
    success: bool
    # Optional human-readable note, e.g. when there was nothing to rerank.
    message: Optional[str] = None
# Global reranker instance; assigned by the startup hook registered in create_app()
_reranker: Optional[Any] = None


def create_app(model_path: str, use_fp16: bool = True, device: Optional[str] = None):
    """Create the FastAPI app that serves the reranker.

    Args:
        model_path: Model name or path passed to ``FlagReranker``.
        use_fp16: Forwarded to ``FlagReranker`` as ``use_fp16``.
        device: Requested device string. It is only logged; it is NOT passed
            to ``FlagReranker``, so device placement is whatever the library
            picks by default.

    Returns:
        A ``FastAPI`` instance with permissive CORS, a startup hook that loads
        the model into the module-level ``_reranker``, and two routes:
        ``GET /health`` and ``POST /rerank``.
    """
    app = FastAPI(title="Reranker API Service", version="1.0.0")

    # Add CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # The handlers below must be registered on the app; as plain nested
    # functions they would never be called and the app would expose no routes.
    # NOTE(review): route paths are inferred from the handler names — confirm
    # against existing clients.
    @app.on_event("startup")
    async def load_reranker():
        """Load reranker model on startup"""
        global _reranker
        if not HAS_FLAGEMBEDDING:
            raise RuntimeError("FlagEmbedding not installed")
        print(f"Loading reranker model: {model_path}")
        print(f"Using FP16: {use_fp16}")
        if device:
            print(f"Using device: {device}")
        try:
            # Note: the device is not forwarded; FlagReranker may not support
            # explicit device setting. Placeholder for future support.
            _reranker = FlagReranker(
                model_path,
                use_fp16=use_fp16,
            )
            print("Reranker model loaded successfully")
        except Exception as e:
            print(f"Error loading reranker: {e}")
            raise

    @app.get("/health")
    async def health_check():
        """Health check endpoint"""
        return {
            "status": "healthy",
            "model_loaded": _reranker is not None
        }

    @app.post("/rerank", response_model=RerankResponse)
    async def rerank(request: RerankRequest):
        """Rerank paragraphs given a query.

        Raises:
            HTTPException: 503 if the model is not loaded, 400 for a
                non-positive ``batch_size``, 500 if scoring fails.
        """
        if _reranker is None:
            raise HTTPException(status_code=503, detail="Reranker not loaded")
        if not request.paragraphs:
            return RerankResponse(
                scores=[],
                success=True,
                message="No paragraphs to rerank"
            )
        # Validate outside the try block: HTTPException is an Exception and
        # would otherwise be re-wrapped as a 500 by the handler below.
        if request.batch_size < 1:
            raise HTTPException(
                status_code=400,
                detail="batch_size must be a positive integer"
            )
        try:
            # Prepare sentence pairs: [[query, paragraph], ...]
            sentence_pairs = [[request.query, p] for p in request.paragraphs]
            # Compute scores
            scores = _reranker.compute_score(
                sentence_pairs,
                batch_size=request.batch_size
            )
            # Handle score format (can be float or list)
            if isinstance(scores, float):
                scores = [scores]
            elif not isinstance(scores, list):
                scores = list(scores)
            return RerankResponse(
                scores=scores,
                success=True
            )
        except Exception as e:
            print(f"Error during reranking: {e}")
            import traceback
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=str(e))

    return app
def main():
    """Main entry point for reranker API service.

    Parses command-line options, checks that the optional dependencies were
    imported, builds the app with ``create_app`` and serves it with uvicorn.
    Exits with status 1 when FastAPI or FlagEmbedding is missing.
    """
    parser = argparse.ArgumentParser(description="Reranker API Service")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to reranker model (e.g., 'OpenScholar/OpenScholar_Reranker')"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8004,
        help="Port to bind to (default: 8004)"
    )
    # --use_fp16 is already the default; it is kept so existing command lines
    # keep working. --no_fp16 writes False into the same destination.
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        default=True,
        help="Use FP16 precision (default: True)"
    )
    parser.add_argument(
        "--no_fp16",
        dest="use_fp16",
        action="store_false",
        help="Disable FP16 precision"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cuda:0', 'cuda:1')"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes (default: 1, use 1 for reranker)"
    )
    args = parser.parse_args()

    if not HAS_FASTAPI:
        print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn")
        sys.exit(1)
    if not HAS_FLAGEMBEDDING:
        print("Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
        sys.exit(1)

    # uvicorn can only spawn multiple workers when given an import string; with
    # an app object and workers > 1 it logs a warning and exits. Fall back to a
    # single worker instead of failing at startup.
    workers = args.workers
    if workers > 1:
        print(
            f"Warning: --workers={workers} is not supported when serving an "
            "in-process app object; falling back to 1 worker"
        )
        workers = 1

    # Create app
    app = create_app(
        model_path=args.model_path,
        use_fp16=args.use_fp16,
        device=args.device
    )

    # Run server
    print(f"Starting reranker API service on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"FP16: {args.use_fp16}")
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()