# GradioDemo / shared / utils / reranker_api_service.py
# eigentom
# Initial Update
# 90c099b
"""
Reranker API Service
Wraps FlagReranker in an HTTP API service, supporting multi-GPU load balancing.
"""
import os
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional
import argparse
# Suppress transformers warnings
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')

# FastAPI / pydantic / uvicorn are optional at import time. HAS_FASTAPI records
# whether the whole web stack imported, so the entry point can report a missing
# dependency cleanly instead of failing with a traceback.
try:
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from pydantic import BaseModel
    import uvicorn
    HAS_FASTAPI = True
except ImportError:
    HAS_FASTAPI = False
    # Bind placeholders so the module itself stays importable. The request and
    # response models in this module subclass BaseModel at import time; without
    # a stand-in the module would raise NameError before the HAS_FASTAPI check
    # could ever run.
    FastAPI = HTTPException = CORSMiddleware = uvicorn = None
    BaseModel = object
    print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn")

# FlagEmbedding supplies the reranker model; HAS_FLAGEMBEDDING records whether
# it could be imported.
try:
    from FlagEmbedding import FlagReranker
    HAS_FLAGEMBEDDING = True
except ImportError:
    HAS_FLAGEMBEDDING = False
    # Keep the name bound so references fail on the explicit HAS_FLAGEMBEDDING
    # check rather than with a NameError.
    FlagReranker = None
    print("Warning: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
# Request/Response models
class RerankRequest(BaseModel):
    """Request body accepted by the ``/rerank`` endpoint."""
    # Query text; the handler pairs it with every paragraph before scoring.
    query: str
    # Candidate paragraphs to score against the query.
    paragraphs: List[str]
    # Forwarded unchanged to the reranker's compute_score() as ``batch_size``.
    batch_size: int = 100
class RerankResponse(BaseModel):
    """Response body returned by the ``/rerank`` endpoint."""
    # Scores produced by compute_score() for the submitted (query, paragraph)
    # pairs; empty when no paragraphs were sent.
    # NOTE(review): presumably one score per paragraph in input order — confirm
    # against the FlagReranker version in use.
    scores: List[float]
    # The handler sets this to True on every response it returns; failures are
    # raised as HTTP errors instead.
    success: bool
    # Optional human-readable note (used for the "no paragraphs" case).
    message: Optional[str] = None
# Global reranker instance
_reranker: Optional[Any] = None


def _build_reranker(model_path: str, use_fp16: bool, device: Optional[str]) -> Any:
    """Instantiate FlagReranker, forwarding *device* when the constructor accepts it.

    The constructor signature is inspected so that the device is passed under
    whichever keyword the installed release exposes. If neither keyword is
    present the device is ignored with a warning, rather than silently.

    NOTE(review): the keyword names ``devices`` / ``device`` follow
    FlagEmbedding's published constructor signatures — confirm against the
    version pinned for this project.
    """
    kwargs = {"use_fp16": use_fp16}
    if device:
        import inspect

        try:
            accepted = inspect.signature(FlagReranker).parameters
        except (TypeError, ValueError):
            # Signature not introspectable; fall through to the warning below.
            accepted = {}
        if "devices" in accepted:
            kwargs["devices"] = device
        elif "device" in accepted:
            kwargs["device"] = device
        else:
            print(
                "Warning: installed FlagReranker does not accept a device "
                f"argument; ignoring device={device}"
            )
    return FlagReranker(model_path, **kwargs)


def _normalize_scores(raw: Any) -> List[float]:
    """Convert whatever compute_score() returned into a list of built-in floats.

    Handles a bare scalar (including non-``float`` numeric scalars, which the
    previous ``isinstance(..., float)`` check missed) as well as any iterable
    of numbers.
    """
    try:
        items = iter(raw)
    except TypeError:
        # Not iterable: a single score.
        return [float(raw)]
    return [float(item) for item in items]


def create_app(model_path: str, use_fp16: bool = True, device: Optional[str] = None):
    """Create the FastAPI app that serves the reranker.

    The model is not loaded here; a startup hook loads it when the server
    starts and stores it in the module-level ``_reranker``.

    Args:
        model_path: Model name or path handed to FlagReranker.
        use_fp16: Passed to FlagReranker as ``use_fp16``.
        device: Optional device string (e.g. ``'cuda:0'``), forwarded to
            FlagReranker when its constructor supports it.

    Returns:
        The configured FastAPI application exposing ``GET /health`` and
        ``POST /rerank``.
    """
    app = FastAPI(title="Reranker API Service", version="1.0.0")

    # Add CORS middleware.
    # Credentials are disabled: the service uses no cookies or auth, and
    # wildcard origins must not be combined with allow_credentials=True.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.on_event("startup")
    async def load_reranker():
        """Load reranker model on startup"""
        global _reranker
        if not HAS_FLAGEMBEDDING:
            raise RuntimeError("FlagEmbedding not installed")
        print(f"Loading reranker model: {model_path}")
        print(f"Using FP16: {use_fp16}")
        if device:
            print(f"Using device: {device}")
        try:
            _reranker = _build_reranker(model_path, use_fp16, device)
        except Exception as e:
            # Log, then re-raise so the server fails to start without a model.
            print(f"Error loading reranker: {e}")
            raise
        print("Reranker model loaded successfully")

    @app.get("/health")
    async def health_check():
        """Health check endpoint"""
        return {
            "status": "healthy",
            "model_loaded": _reranker is not None
        }

    @app.post("/rerank", response_model=RerankResponse)
    async def rerank(request: RerankRequest):
        """Rerank paragraphs given a query.

        Raises:
            HTTPException: 503 if the model is not loaded, 400 for a
                non-positive ``batch_size``, 500 if scoring fails.
        """
        if _reranker is None:
            raise HTTPException(status_code=503, detail="Reranker not loaded")
        if not request.paragraphs:
            return RerankResponse(
                scores=[],
                success=True,
                message="No paragraphs to rerank"
            )
        # Validate outside the try block so this surfaces as a client error
        # instead of being rewrapped as a 500 by the generic handler below.
        if request.batch_size <= 0:
            raise HTTPException(
                status_code=400,
                detail="batch_size must be a positive integer"
            )
        try:
            # Prepare sentence pairs: [[query, paragraph], ...]
            sentence_pairs = [[request.query, p] for p in request.paragraphs]
            # NOTE: compute_score() is called synchronously inside this async
            # handler, so the event loop is occupied until scoring finishes.
            raw_scores = _reranker.compute_score(
                sentence_pairs,
                batch_size=request.batch_size
            )
            return RerankResponse(
                scores=_normalize_scores(raw_scores),
                success=True
            )
        except Exception as e:
            print(f"Error during reranking: {e}")
            import traceback
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=str(e)) from e

    return app
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the reranker API service."""
    parser = argparse.ArgumentParser(description="Reranker API Service")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to reranker model (e.g., 'OpenScholar/OpenScholar_Reranker')"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8004,
        help="Port to bind to (default: 8004)"
    )
    # FP16 is on by default, so --use_fp16 is effectively a no-op kept for
    # compatibility; --no_fp16 is the switch that changes behaviour.
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        default=True,
        help="Use FP16 precision (default: True)"
    )
    parser.add_argument(
        "--no_fp16",
        dest="use_fp16",
        action="store_false",
        help="Disable FP16 precision"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cuda:0', 'cuda:1')"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes (default: 1, use 1 for reranker)"
    )
    return parser


def main(argv: Optional[List[str]] = None):
    """Main entry point for reranker API service.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Exits with status 1 when FastAPI/uvicorn or FlagEmbedding is missing;
    argparse exits with status 2 on invalid arguments.
    """
    args = _build_arg_parser().parse_args(argv)

    if not HAS_FASTAPI:
        print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn")
        sys.exit(1)
    if not HAS_FLAGEMBEDDING:
        print("Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
        sys.exit(1)

    # uvicorn only supports multiple workers when the app is given as an import
    # string; with an app object it refuses to start. Run a single worker
    # instead of failing at launch.
    workers = args.workers
    if workers != 1:
        print(
            f"Warning: --workers={workers} is not supported when serving an "
            "app instance; using 1 worker"
        )
        workers = 1

    # Create app
    app = create_app(
        model_path=args.model_path,
        use_fp16=args.use_fp16,
        device=args.device
    )

    # Run server
    print(f"Starting reranker API service on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"FP16: {args.use_fp16}")
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()