Spaces:
Sleeping
Sleeping
"""
Reranker API served on Hugging Face Spaces (ONNX-optimized build).

The model is loaded from the HF Hub (env var HF_MODEL_ID, e.g.
cross-encoder/ms-marco-MiniLM-L-6-v2) and converted with ONNX Runtime for
roughly 2-3x faster CPU inference.

Environment variables:
    - HF_MODEL_ID: model id (required)
    - HF_TOKEN: access token for private repos (optional)
    - USE_ONNX: use ONNX Runtime when "true" (default: true)
    - ONNX_PROVIDER: "CPUExecutionProvider" or "CUDAExecutionProvider" (default: CPU)
"""
| import os | |
| import time | |
| from typing import List, Optional | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip() | |
| HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None | |
| USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true" | |
| ONNX_PROVIDER = os.environ.get("ONNX_PROVIDER", "CPUExecutionProvider") | |
| model = None | |
| tokenizer = None | |
| is_onnx_mode = False | |
| app = FastAPI(title="Reranker API (ONNX)") | |
| app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) | |
| class RerankerRequest(BaseModel): | |
| query: str | |
| passages: List[str] | |
| class RerankerResponse(BaseModel): | |
| scores: List[float] | |
| latency_ms: Optional[float] = None | |
| mode: Optional[str] = None | |
| error: Optional[str] = None | |
| def load_onnx_model(): | |
| """ONNX Runtime으로 모델 로드. 최초 실행 시 자동 변환.""" | |
| global model, tokenizer, is_onnx_mode | |
| try: | |
| from optimum.onnxruntime import ORTModelForSequenceClassification | |
| from transformers import AutoTokenizer | |
| print(f"[ONNX] 모델 로딩 중: {MODEL_ID} (provider={ONNX_PROVIDER})") | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN) | |
| model = ORTModelForSequenceClassification.from_pretrained( | |
| MODEL_ID, | |
| export=True, # PyTorch → ONNX 자동 변환 | |
| provider=ONNX_PROVIDER, | |
| token=HF_TOKEN, | |
| ) | |
| is_onnx_mode = True | |
| print("[ONNX] 로드 완료.") | |
| except Exception as e: | |
| print(f"[ONNX] 로드 실패, fallback to CrossEncoder: {e}") | |
| load_crossencoder_model() | |
| def load_crossencoder_model(): | |
| """기존 CrossEncoder 모델 로드 (fallback).""" | |
| global model, is_onnx_mode | |
| from sentence_transformers import CrossEncoder | |
| print(f"[CrossEncoder] 모델 로딩 중: {MODEL_ID}") | |
| model = CrossEncoder(MODEL_ID, token=HF_TOKEN) | |
| is_onnx_mode = False | |
| print("[CrossEncoder] 로드 완료.") | |
| def startup_load_model(): | |
| global model | |
| if not MODEL_ID: | |
| raise RuntimeError("Space 설정에서 HF_MODEL_ID 를 지정하세요") | |
| if USE_ONNX: | |
| load_onnx_model() | |
| else: | |
| load_crossencoder_model() | |
| def predict_onnx(query: str, passages: List[str]) -> List[float]: | |
| """ONNX 모델로 batch predict.""" | |
| import torch | |
| pairs = [[query, p] for p in passages] | |
| inputs = tokenizer( | |
| pairs, | |
| padding=True, | |
| truncation=True, | |
| max_length=512, | |
| return_tensors="pt", | |
| ) | |
| outputs = model(**inputs) | |
| # logits shape: (batch_size, num_labels) - 보통 1개 label (relevance score) | |
| logits = outputs.logits | |
| if logits.shape[-1] == 1: | |
| scores = logits.squeeze(-1) | |
| else: | |
| # 2개 label이면 softmax 후 positive class 확률 | |
| scores = torch.softmax(logits, dim=-1)[:, 1] | |
| return scores.tolist() | |
| def predict_crossencoder(query: str, passages: List[str]) -> List[float]: | |
| """CrossEncoder 모델로 predict.""" | |
| pairs = [(query, p) for p in passages] | |
| scores = model.predict(pairs) | |
| if hasattr(scores, "tolist"): | |
| scores = scores.tolist() | |
| return [float(s) for s in scores] | |
| def root(): | |
| """브라우저/App 탭 접속 시 안내.""" | |
| return { | |
| "message": "CertWeb Reranker API (ONNX Optimized)", | |
| "model": MODEL_ID, | |
| "mode": "onnx" if is_onnx_mode else "crossencoder", | |
| "usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}", | |
| "response": "{\"scores\": [0.9, 0.2, ...], \"latency_ms\": 123.4, \"mode\": \"onnx\"}", | |
| } | |
| def score(request: RerankerRequest): | |
| query = request.query | |
| passages = request.passages | |
| if not query or not passages or model is None: | |
| return RerankerResponse(scores=[], error="Empty query/passages or model not loaded") | |
| try: | |
| start = time.perf_counter() | |
| if is_onnx_mode: | |
| scores = predict_onnx(query, passages) | |
| else: | |
| scores = predict_crossencoder(query, passages) | |
| latency_ms = (time.perf_counter() - start) * 1000 | |
| return RerankerResponse( | |
| scores=scores, | |
| latency_ms=round(latency_ms, 2), | |
| mode="onnx" if is_onnx_mode else "crossencoder", | |
| ) | |
| except Exception as e: | |
| return RerankerResponse(scores=[], error=str(e)) | |
| def health(): | |
| """Health check endpoint.""" | |
| return { | |
| "status": "ok", | |
| "model_loaded": model is not None, | |
| "mode": "onnx" if is_onnx_mode else "crossencoder", | |
| } | |