Spaces:
Sleeping
Sleeping
File size: 6,092 Bytes
90c099b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 | """
Reranker API Service
Wrap FlagReranker in an HTTP API service with support for multi-GPU load balancing.
"""
import os
import sys
from pathlib import Path
from typing import List, Dict, Any, Optional
import argparse
# Suppress transformers warnings (only if the caller has not already chosen a level)
os.environ.setdefault('TRANSFORMERS_VERBOSITY', 'error')

# Optional web-stack dependencies; main() checks HAS_FASTAPI before serving.
try:
    from fastapi import FastAPI, HTTPException
    from fastapi.middleware.cors import CORSMiddleware
    from pydantic import BaseModel
    import uvicorn
    HAS_FASTAPI = True
except ImportError:
    HAS_FASTAPI = False
    # Placeholder base class so the module-level request/response model classes
    # can still be defined. Without it, importing this module raises NameError
    # and main() never reaches its "FastAPI not installed" message.
    BaseModel = object
    print("Warning: FastAPI not installed. Install with: pip install fastapi uvicorn")
try:
from FlagEmbedding import FlagReranker
HAS_FLAGEMBEDDING = True
except ImportError:
HAS_FLAGEMBEDDING = False
print("Warning: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
# Request/Response models
class RerankRequest(BaseModel):
    """JSON body accepted by POST /rerank."""

    # Query that every paragraph is paired with for scoring.
    query: str
    # Candidate paragraphs; each becomes one [query, paragraph] pair.
    paragraphs: List[str]
    # Passed through to FlagReranker.compute_score as its batch_size.
    batch_size: int = 100
class RerankResponse(BaseModel):
    """JSON body returned by POST /rerank."""

    # Scores produced by the reranker for the submitted paragraphs.
    scores: List[float]
    # True on a handled request; failures are raised as HTTP errors instead.
    success: bool
    # Optional note, e.g. "No paragraphs to rerank" for an empty request.
    message: Optional[str] = None
# Global reranker instance; set by the startup hook registered in create_app().
_reranker: Optional[Any] = None


def normalize_scores(scores: Any) -> List[float]:
    """Convert the output of ``FlagReranker.compute_score`` to a list of floats.

    ``compute_score`` may return a single number or a sequence of numbers.
    Array-like results that expose ``tolist`` (e.g. NumPy scalars/arrays) are
    unwrapped first. Every element is then coerced to a built-in ``float`` so
    the response model can serialize it.

    Args:
        scores: A single numeric score or an iterable of numeric scores.

    Returns:
        The scores as a list of Python floats; a single score becomes a
        one-element list.
    """
    if hasattr(scores, "tolist"):
        # NumPy scalar/array -> plain Python number/list.
        scores = scores.tolist()
    if isinstance(scores, (int, float)):
        return [float(scores)]
    return [float(score) for score in scores]


def create_app(model_path: str, use_fp16: bool = True, device: Optional[str] = None):
    """Build the FastAPI application that serves the reranker.

    The model is not loaded here. A startup hook loads it and stores it in the
    module-level ``_reranker``.

    Args:
        model_path: Model name or local path passed to ``FlagReranker``.
        use_fp16: Whether to load the model with FP16 enabled.
        device: Optional device string (e.g. ``"cuda:0"``). It is forwarded to
            ``FlagReranker`` as ``devices`` when the installed FlagEmbedding
            accepts that keyword. Otherwise a warning is printed and the
            library's default device selection is used.

    Returns:
        The configured ``FastAPI`` instance.
    """
    import threading
    import traceback

    app = FastAPI(title="Reranker API Service", version="1.0.0")

    # Add CORS middleware.
    # NOTE(review): origins are wide open; consider restricting them if this
    # service is reachable from outside a trusted network.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # /rerank runs in FastAPI's threadpool. This lock keeps inference to one
    # request at a time, as it was when the handler blocked the event loop.
    inference_lock = threading.Lock()

    def _build_reranker():
        """Instantiate FlagReranker, pinning it to `device` when supported."""
        if device:
            try:
                return FlagReranker(model_path, use_fp16=use_fp16, devices=device)
            except TypeError as exc:
                # Releases without a `devices` keyword reject it; anything else is real.
                if "devices" not in str(exc):
                    raise
                print(
                    "Warning: installed FlagEmbedding does not accept 'devices'; "
                    f"ignoring device={device!r}"
                )
        return FlagReranker(model_path, use_fp16=use_fp16)

    @app.on_event("startup")
    async def load_reranker():
        """Load reranker model on startup"""
        global _reranker
        if not HAS_FLAGEMBEDDING:
            raise RuntimeError("FlagEmbedding not installed")
        print(f"Loading reranker model: {model_path}")
        print(f"Using FP16: {use_fp16}")
        if device:
            print(f"Using device: {device}")
        try:
            _reranker = _build_reranker()
            print("Reranker model loaded successfully")
        except Exception as e:
            print(f"Error loading reranker: {e}")
            raise

    @app.get("/health")
    async def health_check():
        """Health check endpoint"""
        return {
            "status": "healthy",
            "model_loaded": _reranker is not None
        }

    @app.post("/rerank", response_model=RerankResponse)
    def rerank(request: RerankRequest):
        """Score every paragraph in the request against its query.

        The endpoint is a plain ``def``, so FastAPI runs it in a worker thread.
        The blocking ``compute_score`` call therefore does not stall the event
        loop, and ``/health`` stays responsive during inference.

        Raises:
            HTTPException: 503 if the model is not loaded yet, 422 for a
                non-positive ``batch_size``, 500 if scoring fails.
        """
        reranker = _reranker
        if reranker is None:
            raise HTTPException(status_code=503, detail="Reranker not loaded")

        if not request.paragraphs:
            return RerankResponse(
                scores=[],
                success=True,
                message="No paragraphs to rerank"
            )

        if request.batch_size <= 0:
            raise HTTPException(
                status_code=422,
                detail="batch_size must be a positive integer"
            )

        # Prepare sentence pairs: [[query, paragraph], ...]
        sentence_pairs = [[request.query, p] for p in request.paragraphs]

        try:
            with inference_lock:
                raw_scores = reranker.compute_score(
                    sentence_pairs,
                    batch_size=request.batch_size
                )
            scores = normalize_scores(raw_scores)
        except Exception as e:
            print(f"Error during reranking: {e}")
            traceback.print_exc()
            raise HTTPException(status_code=500, detail=str(e)) from e

        return RerankResponse(scores=scores, success=True)

    return app
def parse_cli_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse and validate the reranker service's command-line options.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``model_path``, ``host``, ``port``, ``use_fp16``,
        ``device`` and ``workers``.

    Raises:
        SystemExit: via ``parser.error`` for invalid options, including any
            ``--workers`` value other than 1.
    """
    parser = argparse.ArgumentParser(description="Reranker API Service")
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to reranker model (e.g., 'OpenScholar/OpenScholar_Reranker')"
    )
    parser.add_argument(
        "--host",
        type=str,
        default="0.0.0.0",
        help="Host to bind to (default: 0.0.0.0)"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8004,
        help="Port to bind to (default: 8004)"
    )
    # FP16 is already on by default; --use_fp16 is kept so existing command
    # lines keep working, and --no_fp16 is the switch that changes anything.
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        default=True,
        help="Use FP16 precision (default: True)"
    )
    parser.add_argument(
        "--no_fp16",
        dest="use_fp16",
        action="store_false",
        help="Disable FP16 precision"
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cuda:0', 'cuda:1')"
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of worker processes (default: 1, use 1 for reranker)"
    )
    args = parser.parse_args(argv)

    # main() hands uvicorn an app *object* built from these CLI args. uvicorn
    # only supports multiple workers when given an import string, and exits
    # otherwise, so reject the option here with an actionable message.
    if args.workers != 1:
        parser.error(
            "--workers must be 1: the app is built from CLI arguments and "
            "cannot be replicated across uvicorn worker processes; run one "
            "service instance per port/GPU instead"
        )
    return args


def main():
    """Run the reranker API service from the command line.

    Parses and validates CLI options, then checks that FastAPI and
    FlagEmbedding are importable, exiting with status 1 if not. Finally it
    builds the app with create_app() and serves it with uvicorn.
    """
    args = parse_cli_args()

    if not HAS_FASTAPI:
        print("Error: FastAPI not installed. Install with: pip install fastapi uvicorn")
        sys.exit(1)

    if not HAS_FLAGEMBEDDING:
        print("Error: FlagEmbedding not installed. Install with: pip install FlagEmbedding")
        sys.exit(1)

    # Create app
    app = create_app(
        model_path=args.model_path,
        use_fp16=args.use_fp16,
        device=args.device
    )

    # Run server
    print(f"Starting reranker API service on {args.host}:{args.port}")
    print(f"Model: {args.model_path}")
    print(f"FP16: {args.use_fp16}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        workers=args.workers,
        log_level="info"
    )


if __name__ == "__main__":
    main()
|