File size: 5,449 Bytes
0423728
dadefb7
 
 
 
 
 
 
 
 
 
0423728
 
dadefb7
 
0423728
 
 
dadefb7
0423728
60e39a0
dadefb7
 
 
 
0423728
dadefb7
 
0423728
dadefb7
0423728
 
 
dadefb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0423728
dadefb7
0423728
 
dadefb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0423728
 
32b1fc8
 
dadefb7
32b1fc8
dadefb7
 
 
32b1fc8
dadefb7
32b1fc8
 
 
dadefb7
 
 
 
 
0423728
dadefb7
 
0423728
dadefb7
 
 
 
 
 
 
 
 
 
 
 
0423728
dadefb7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""

Hugging Face Spaces에서 돌리는 Reranker API (ONNX 최적화 버전).



모델은 HF Hub에서 로드 (환경변수 HF_MODEL_ID, 예: cross-encoder/ms-marco-MiniLM-L-6-v2).

ONNX Runtime으로 변환하여 CPU 추론 속도 2-3배 향상.



환경변수:

- HF_MODEL_ID: 모델 ID (필수)

- HF_TOKEN: 비공개 리포 접근용 토큰 (선택)

- USE_ONNX: "true"면 ONNX Runtime 사용 (기본: true)

- ONNX_PROVIDER: "CPUExecutionProvider" 또는 "CUDAExecutionProvider" (기본: CPU)

"""
import os
import time
from typing import List, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# --- Configuration, read once at import time ---
MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip()  # required; validated at startup
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None  # None when unset/blank
USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true"  # prefer the ONNX path by default
ONNX_PROVIDER = os.environ.get("ONNX_PROVIDER", "CPUExecutionProvider")

# --- Mutable module state, populated by the startup hook ---
model = None  # ORTModelForSequenceClassification or CrossEncoder once loaded
tokenizer = None  # only set (and used) in ONNX mode
is_onnx_mode = False  # True iff the ONNX loader succeeded

app = FastAPI(title="Reranker API (ONNX)")
# NOTE(review): CORS is wide open — presumably intentional for a public Space; confirm.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


class RerankerRequest(BaseModel):
    """Request body for POST /: one query to score against each passage."""
    query: str
    passages: List[str]


class RerankerResponse(BaseModel):
    """Response body: one score per input passage, plus timing/diagnostics."""
    scores: List[float]
    latency_ms: Optional[float] = None  # wall-clock inference time; absent on error
    mode: Optional[str] = None  # "onnx" or "crossencoder"
    error: Optional[str] = None  # set (with empty scores) when scoring failed


def load_onnx_model():
    """Load the reranker via ONNX Runtime, exporting from PyTorch on first use.

    On any failure (missing optimum, export error, download error) this falls
    back to the plain CrossEncoder loader so the service still comes up.
    """
    global model, tokenizer, is_onnx_mode
    try:
        from optimum.onnxruntime import ORTModelForSequenceClassification
        from transformers import AutoTokenizer

        print(f"[ONNX] 모델 로딩 중: {MODEL_ID} (provider={ONNX_PROVIDER})")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
        # export=True converts the PyTorch checkpoint to ONNX automatically.
        model = ORTModelForSequenceClassification.from_pretrained(
            MODEL_ID, export=True, provider=ONNX_PROVIDER, token=HF_TOKEN
        )
        is_onnx_mode = True
        print("[ONNX] 로드 완료.")
    except Exception as e:  # deliberate best-effort fallback, never crash startup
        print(f"[ONNX] 로드 실패, fallback to CrossEncoder: {e}")
        load_crossencoder_model()


def load_crossencoder_model():
    """Load the model with sentence-transformers' CrossEncoder (non-ONNX fallback)."""
    global model, is_onnx_mode
    from sentence_transformers import CrossEncoder

    print(f"[CrossEncoder] 모델 로딩 중: {MODEL_ID}")
    is_onnx_mode = False
    model = CrossEncoder(MODEL_ID, token=HF_TOKEN)
    print("[CrossEncoder] 로드 완료.")


@app.on_event("startup")
def startup_load_model():
    global model
    if not MODEL_ID:
        raise RuntimeError("Space 설정에서 HF_MODEL_ID 를 지정하세요")
    
    if USE_ONNX:
        load_onnx_model()
    else:
        load_crossencoder_model()


def predict_onnx(query: str, passages: List[str]) -> List[float]:
    """Batch-score (query, passage) pairs with the ONNX model.

    Returns one relevance score per passage, in input order.
    """
    import torch

    encoded = tokenizer(
        [[query, passage] for passage in passages],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    logits = model(**encoded).logits
    # Single-label heads emit the relevance score directly; two-label heads
    # need a softmax, from which we take the positive-class probability.
    if logits.shape[-1] == 1:
        return logits.squeeze(-1).tolist()
    return torch.softmax(logits, dim=-1)[:, 1].tolist()


def predict_crossencoder(query: str, passages: List[str]) -> List[float]:
    """Batch-score (query, passage) pairs with the CrossEncoder fallback model."""
    raw = model.predict([(query, passage) for passage in passages])
    # CrossEncoder usually returns a numpy array; plain lists pass through.
    values = raw.tolist() if hasattr(raw, "tolist") else raw
    return [float(v) for v in values]


@app.get("/")
def root():
    """브라우저/App 탭 접속 시 안내."""
    return {
        "message": "CertWeb Reranker API (ONNX Optimized)",
        "model": MODEL_ID,
        "mode": "onnx" if is_onnx_mode else "crossencoder",
        "usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}",
        "response": "{\"scores\": [0.9, 0.2, ...], \"latency_ms\": 123.4, \"mode\": \"onnx\"}",
    }


@app.post("/", response_model=RerankerResponse)
def score(request: RerankerRequest):
    query = request.query
    passages = request.passages
    
    if not query or not passages or model is None:
        return RerankerResponse(scores=[], error="Empty query/passages or model not loaded")
    
    try:
        start = time.perf_counter()
        if is_onnx_mode:
            scores = predict_onnx(query, passages)
        else:
            scores = predict_crossencoder(query, passages)
        latency_ms = (time.perf_counter() - start) * 1000
        
        return RerankerResponse(
            scores=scores,
            latency_ms=round(latency_ms, 2),
            mode="onnx" if is_onnx_mode else "crossencoder",
        )
    except Exception as e:
        return RerankerResponse(scores=[], error=str(e))


@app.get("/health")
def health():
    """Health check endpoint."""
    return {
        "status": "ok",
        "model_loaded": model is not None,
        "mode": "onnx" if is_onnx_mode else "crossencoder",
    }