# app.py — Hugging Face Space entry point (commit dadefb7)
"""
Hugging Face Spaces에서 돌리는 Reranker API (ONNX 최적화 버전).
모델은 HF Hub에서 로드 (환경변수 HF_MODEL_ID, 예: cross-encoder/ms-marco-MiniLM-L-6-v2).
ONNX Runtime으로 변환하여 CPU 추론 속도 2-3배 향상.
환경변수:
- HF_MODEL_ID: 모델 ID (필수)
- HF_TOKEN: 비공개 리포 접근용 토큰 (선택)
- USE_ONNX: "true"면 ONNX Runtime 사용 (기본: true)
- ONNX_PROVIDER: "CPUExecutionProvider" 또는 "CUDAExecutionProvider" (기본: CPU)
"""
import os
import time
from typing import List, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip()
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true"
ONNX_PROVIDER = os.environ.get("ONNX_PROVIDER", "CPUExecutionProvider")
model = None
tokenizer = None
is_onnx_mode = False
app = FastAPI(title="Reranker API (ONNX)")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
class RerankerRequest(BaseModel):
query: str
passages: List[str]
class RerankerResponse(BaseModel):
scores: List[float]
latency_ms: Optional[float] = None
mode: Optional[str] = None
error: Optional[str] = None
def load_onnx_model():
"""ONNX Runtime으로 모델 로드. 최초 실행 시 자동 변환."""
global model, tokenizer, is_onnx_mode
try:
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer
print(f"[ONNX] 모델 로딩 중: {MODEL_ID} (provider={ONNX_PROVIDER})")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = ORTModelForSequenceClassification.from_pretrained(
MODEL_ID,
export=True, # PyTorch → ONNX 자동 변환
provider=ONNX_PROVIDER,
token=HF_TOKEN,
)
is_onnx_mode = True
print("[ONNX] 로드 완료.")
except Exception as e:
print(f"[ONNX] 로드 실패, fallback to CrossEncoder: {e}")
load_crossencoder_model()
def load_crossencoder_model():
"""기존 CrossEncoder 모델 로드 (fallback)."""
global model, is_onnx_mode
from sentence_transformers import CrossEncoder
print(f"[CrossEncoder] 모델 로딩 중: {MODEL_ID}")
model = CrossEncoder(MODEL_ID, token=HF_TOKEN)
is_onnx_mode = False
print("[CrossEncoder] 로드 완료.")
@app.on_event("startup")
def startup_load_model():
global model
if not MODEL_ID:
raise RuntimeError("Space 설정에서 HF_MODEL_ID 를 지정하세요")
if USE_ONNX:
load_onnx_model()
else:
load_crossencoder_model()
def predict_onnx(query: str, passages: List[str]) -> List[float]:
"""ONNX 모델로 batch predict."""
import torch
pairs = [[query, p] for p in passages]
inputs = tokenizer(
pairs,
padding=True,
truncation=True,
max_length=512,
return_tensors="pt",
)
outputs = model(**inputs)
# logits shape: (batch_size, num_labels) - 보통 1개 label (relevance score)
logits = outputs.logits
if logits.shape[-1] == 1:
scores = logits.squeeze(-1)
else:
# 2개 label이면 softmax 후 positive class 확률
scores = torch.softmax(logits, dim=-1)[:, 1]
return scores.tolist()
def predict_crossencoder(query: str, passages: List[str]) -> List[float]:
"""CrossEncoder 모델로 predict."""
pairs = [(query, p) for p in passages]
scores = model.predict(pairs)
if hasattr(scores, "tolist"):
scores = scores.tolist()
return [float(s) for s in scores]
@app.get("/")
def root():
"""브라우저/App 탭 접속 시 안내."""
return {
"message": "CertWeb Reranker API (ONNX Optimized)",
"model": MODEL_ID,
"mode": "onnx" if is_onnx_mode else "crossencoder",
"usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}",
"response": "{\"scores\": [0.9, 0.2, ...], \"latency_ms\": 123.4, \"mode\": \"onnx\"}",
}
@app.post("/", response_model=RerankerResponse)
def score(request: RerankerRequest):
query = request.query
passages = request.passages
if not query or not passages or model is None:
return RerankerResponse(scores=[], error="Empty query/passages or model not loaded")
try:
start = time.perf_counter()
if is_onnx_mode:
scores = predict_onnx(query, passages)
else:
scores = predict_crossencoder(query, passages)
latency_ms = (time.perf_counter() - start) * 1000
return RerankerResponse(
scores=scores,
latency_ms=round(latency_ms, 2),
mode="onnx" if is_onnx_mode else "crossencoder",
)
except Exception as e:
return RerankerResponse(scores=[], error=str(e))
@app.get("/health")
def health():
"""Health check endpoint."""
return {
"status": "ok",
"model_loaded": model is not None,
"mode": "onnx" if is_onnx_mode else "crossencoder",
}