""" Hugging Face Spaces에서 돌리는 Reranker API (ONNX 최적화 버전). 모델은 HF Hub에서 로드 (환경변수 HF_MODEL_ID, 예: cross-encoder/ms-marco-MiniLM-L-6-v2). ONNX Runtime으로 변환하여 CPU 추론 속도 2-3배 향상. 환경변수: - HF_MODEL_ID: 모델 ID (필수) - HF_TOKEN: 비공개 리포 접근용 토큰 (선택) - USE_ONNX: "true"면 ONNX Runtime 사용 (기본: true) - ONNX_PROVIDER: "CPUExecutionProvider" 또는 "CUDAExecutionProvider" (기본: CPU) """ import os import time from typing import List, Optional from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip() HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true" ONNX_PROVIDER = os.environ.get("ONNX_PROVIDER", "CPUExecutionProvider") model = None tokenizer = None is_onnx_mode = False app = FastAPI(title="Reranker API (ONNX)") app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) class RerankerRequest(BaseModel): query: str passages: List[str] class RerankerResponse(BaseModel): scores: List[float] latency_ms: Optional[float] = None mode: Optional[str] = None error: Optional[str] = None def load_onnx_model(): """ONNX Runtime으로 모델 로드. 최초 실행 시 자동 변환.""" global model, tokenizer, is_onnx_mode try: from optimum.onnxruntime import ORTModelForSequenceClassification from transformers import AutoTokenizer print(f"[ONNX] 모델 로딩 중: {MODEL_ID} (provider={ONNX_PROVIDER})") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN) model = ORTModelForSequenceClassification.from_pretrained( MODEL_ID, export=True, # PyTorch → ONNX 자동 변환 provider=ONNX_PROVIDER, token=HF_TOKEN, ) is_onnx_mode = True print("[ONNX] 로드 완료.") except Exception as e: print(f"[ONNX] 로드 실패, fallback to CrossEncoder: {e}") load_crossencoder_model() def load_crossencoder_model(): """기존 CrossEncoder 모델 로드 (fallback).""" global model, is_onnx_mode from sentence_transformers import CrossEncoder print(f"[CrossEncoder] 모델 로딩 중: {MODEL_ID}") model = CrossEncoder(MODEL_ID, token=HF_TOKEN) is_onnx_mode = False print("[CrossEncoder] 로드 완료.") @app.on_event("startup") def startup_load_model(): global model if not MODEL_ID: raise RuntimeError("Space 설정에서 HF_MODEL_ID 를 지정하세요") if USE_ONNX: load_onnx_model() else: load_crossencoder_model() def predict_onnx(query: str, passages: List[str]) -> List[float]: """ONNX 모델로 batch predict.""" import torch pairs = [[query, p] for p in passages] inputs = tokenizer( pairs, padding=True, truncation=True, max_length=512, return_tensors="pt", ) outputs = model(**inputs) # logits shape: (batch_size, num_labels) - 보통 1개 label (relevance score) logits = outputs.logits if logits.shape[-1] == 1: scores = logits.squeeze(-1) else: # 2개 label이면 softmax 후 positive class 확률 scores = torch.softmax(logits, dim=-1)[:, 1] return scores.tolist() def predict_crossencoder(query: str, passages: List[str]) -> List[float]: """CrossEncoder 모델로 predict.""" pairs = [(query, p) for p in passages] scores = model.predict(pairs) if hasattr(scores, "tolist"): scores = scores.tolist() return [float(s) for s in scores] @app.get("/") def root(): """브라우저/App 탭 접속 시 안내.""" return { "message": "CertWeb Reranker API (ONNX Optimized)", "model": MODEL_ID, "mode": "onnx" if is_onnx_mode else "crossencoder", "usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}", "response": "{\"scores\": [0.9, 0.2, ...], \"latency_ms\": 123.4, \"mode\": \"onnx\"}", } @app.post("/", response_model=RerankerResponse) def score(request: 
@app.post("/", response_model=RerankerResponse)
def score(request: RerankerRequest):
    query = request.query
    passages = request.passages
    if not query or not passages or model is None:
        return RerankerResponse(scores=[], error="Empty query/passages or model not loaded")
    try:
        start = time.perf_counter()
        if is_onnx_mode:
            scores = predict_onnx(query, passages)
        else:
            scores = predict_crossencoder(query, passages)
        latency_ms = (time.perf_counter() - start) * 1000
        return RerankerResponse(
            scores=scores,
            latency_ms=round(latency_ms, 2),
            mode="onnx" if is_onnx_mode else "crossencoder",
        )
    except Exception as e:
        return RerankerResponse(scores=[], error=str(e))


@app.get("/health")
def health():
    """Health check endpoint."""
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "mode": "onnx" if is_onnx_mode else "crossencoder",
    }
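

# Minimal local-run sketch (an assumption, not part of the original deployment;
# HF Spaces conventionally serves on port 7860 - adjust host/port for your setup):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)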