"""
Reranker API served from Hugging Face Spaces (ONNX-optimized version).
The model is loaded from the HF Hub (env var HF_MODEL_ID, e.g. cross-encoder/ms-marco-MiniLM-L-6-v2).
Conversion to ONNX Runtime makes CPU inference roughly 2-3x faster.

Environment variables:
- HF_MODEL_ID: model ID (required)
- HF_TOKEN: access token for private repos (optional)
- USE_ONNX: use ONNX Runtime if "true" (default: true)
- ONNX_PROVIDER: "CPUExecutionProvider" or "CUDAExecutionProvider" (default: CPU)
"""

import os
import time
from typing import List, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip()
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true"
ONNX_PROVIDER = os.environ.get("ONNX_PROVIDER", "CPUExecutionProvider")

model = None
tokenizer = None
is_onnx_mode = False

app = FastAPI(title="Reranker API (ONNX)")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


class RerankerRequest(BaseModel):
    query: str
    passages: List[str]


class RerankerResponse(BaseModel):
    scores: List[float]
    latency_ms: Optional[float] = None
    mode: Optional[str] = None
    error: Optional[str] = None


def load_onnx_model():
    """Load the model with ONNX Runtime; converts from PyTorch automatically on first run."""
    global model, tokenizer, is_onnx_mode
    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification
        from transformers import AutoTokenizer

        print(f"[ONNX] Loading model: {MODEL_ID} (provider={ONNX_PROVIDER})")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
        model = ORTModelForSequenceClassification.from_pretrained(
            MODEL_ID,
            export=True,  # automatic PyTorch -> ONNX conversion
            provider=ONNX_PROVIDER,
            token=HF_TOKEN,
        )
        is_onnx_mode = True
        print("[ONNX] Load complete.")
    except Exception as e:
        print(f"[ONNX] Load failed, falling back to CrossEncoder: {e}")
        load_crossencoder_model()


def load_crossencoder_model():
    """Load the plain CrossEncoder model (fallback path)."""
    global model, is_onnx_mode
    from sentence_transformers import CrossEncoder

    print(f"[CrossEncoder] Loading model: {MODEL_ID}")
    model = CrossEncoder(MODEL_ID, token=HF_TOKEN)
    is_onnx_mode = False
    print("[CrossEncoder] Load complete.")


@app.on_event("startup")
def startup_load_model():
    if not MODEL_ID:
        raise RuntimeError("Set HF_MODEL_ID in the Space settings")
    if USE_ONNX:
        load_onnx_model()
    else:
        load_crossencoder_model()


def predict_onnx(query: str, passages: List[str]) -> List[float]:
    """Batch predict with the ONNX model."""
    import torch

    pairs = [[query, p] for p in passages]
    inputs = tokenizer(
        pairs,
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    outputs = model(**inputs)
    # logits shape: (batch_size, num_labels) - usually a single label (relevance score)
    logits = outputs.logits
    if logits.shape[-1] == 1:
        scores = logits.squeeze(-1)
    else:
        # with two labels, apply softmax and take the positive-class probability
        scores = torch.softmax(logits, dim=-1)[:, 1]
    return scores.tolist()


def predict_crossencoder(query: str, passages: List[str]) -> List[float]:
    """Predict with the CrossEncoder model."""
    pairs = [(query, p) for p in passages]
    scores = model.predict(pairs)
    if hasattr(scores, "tolist"):
        scores = scores.tolist()
    return [float(s) for s in scores]
@app.get("/")
def root():
"""브라우저/App 탭 접속 시 안내."""
return {
"message": "CertWeb Reranker API (ONNX Optimized)",
"model": MODEL_ID,
"mode": "onnx" if is_onnx_mode else "crossencoder",
"usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}",
"response": "{\"scores\": [0.9, 0.2, ...], \"latency_ms\": 123.4, \"mode\": \"onnx\"}",
}


@app.post("/", response_model=RerankerResponse)
def score(request: RerankerRequest):
    query = request.query
    passages = request.passages
    if not query or not passages or model is None:
        return RerankerResponse(scores=[], error="Empty query/passages or model not loaded")
    try:
        start = time.perf_counter()
        if is_onnx_mode:
            scores = predict_onnx(query, passages)
        else:
            scores = predict_crossencoder(query, passages)
        latency_ms = (time.perf_counter() - start) * 1000
        return RerankerResponse(
            scores=scores,
            latency_ms=round(latency_ms, 2),
            mode="onnx" if is_onnx_mode else "crossencoder",
        )
    except Exception as e:
        return RerankerResponse(scores=[], error=str(e))
@app.get("/health")
def health():
"""Health check endpoint."""
return {
"status": "ok",
"model_loaded": model is not None,
"mode": "onnx" if is_onnx_mode else "crossencoder",
}
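

# Local debugging entry point. A minimal sketch, assuming the Space normally
# starts the server from its own launch config; 7860 is the default port a
# HF Space is expected to listen on.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)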