multifuly commited on
Commit
dadefb7
·
verified ·
1 Parent(s): 32b1fc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -26
app.py CHANGED
@@ -1,54 +1,165 @@
1
  """
2
- Hugging Face Spaces에서 돌리는 Reranker API.
3
- 모델은 HF Hub에서 로드 (환경변수 HF_MODEL_ID, 예: username/certweb-reranker).
4
- 무료 CPU Basic(16GB RAM)에서 동작. 슬립 후 첫 요청은 콜스타트 30초~1분 걸릴 있음.
 
 
 
 
 
 
 
5
  """
6
  import os
 
 
7
 
8
  from fastapi import FastAPI
9
  from fastapi.middleware.cors import CORSMiddleware
 
10
 
11
  MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip()
12
- HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None # 비공개 리포일 때만 설정
 
 
 
13
  model = None
 
 
14
 
15
- app = FastAPI(title="Reranker API")
16
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @app.on_event("startup")
20
- def load_model():
21
  global model
22
  if not MODEL_ID:
23
- raise RuntimeError("Space 설정에서 HF_MODEL_ID 를 지정하세요 (예: username/certweb-reranker)")
24
- from sentence_transformers import CrossEncoder
25
- print(f"모델 로딩 중: {MODEL_ID}")
26
- # 비공개 리포면 HF_TOKEN (Secrets) 설정 필요
27
- model = CrossEncoder(MODEL_ID, token=HF_TOKEN)
28
- print("로드 완료.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
  @app.get("/")
32
  def root():
33
- """브라우저/App 탭 접속 시 안내. 실제 점수는 POST / 로 요청."""
34
  return {
35
- "message": "CertWeb Reranker API",
 
 
36
  "usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}",
37
- "response": "{\"scores\": [0.9, 0.2, ...]}",
38
  }
39
 
40
 
41
- @app.post("/")
42
- def score(request: dict):
43
- query = request.get("query") or ""
44
- passages = request.get("passages") or []
 
45
  if not query or not passages or model is None:
46
- return {"scores": []}
 
47
  try:
48
- pairs = [(query, p) for p in passages]
49
- scores = model.predict(pairs)
50
- if hasattr(scores, "tolist"):
51
- scores = scores.tolist()
52
- return {"scores": [float(s) for s in scores]}
 
 
 
 
 
 
 
53
  except Exception as e:
54
- return {"scores": [], "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Hugging Face Spaces에서 돌리는 Reranker API (ONNX 최적화 버전).
3
+
4
+ 모델은 HF Hub에서 (환경변수 HF_MODEL_ID, 예: cross-encoder/ms-marco-MiniLM-L-6-v2).
5
+ ONNX Runtime으로 변환하여 CPU 추론 속도 2-3배 향상.
6
+
7
+ 환경변수:
8
+ - HF_MODEL_ID: 모델 ID (필수)
9
+ - HF_TOKEN: 비공개 리포 접근용 토큰 (선택)
10
+ - USE_ONNX: "true"면 ONNX Runtime 사용 (기본: true)
11
+ - ONNX_PROVIDER: "CPUExecutionProvider" 또는 "CUDAExecutionProvider" (기본: CPU)
12
  """
13
  import os
14
+ import time
15
+ from typing import List, Optional
16
 
17
  from fastapi import FastAPI
18
  from fastapi.middleware.cors import CORSMiddleware
19
+ from pydantic import BaseModel
20
 
21
  MODEL_ID = (os.environ.get("HF_MODEL_ID") or "").strip()
22
+ HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip() or None
23
+ USE_ONNX = os.environ.get("USE_ONNX", "true").lower() == "true"
24
+ ONNX_PROVIDER = os.environ.get("ONNX_PROVIDER", "CPUExecutionProvider")
25
+
26
  model = None
27
+ tokenizer = None
28
+ is_onnx_mode = False
29
 
30
+ app = FastAPI(title="Reranker API (ONNX)")
31
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
32
 
33
 
34
+ class RerankerRequest(BaseModel):
35
+ query: str
36
+ passages: List[str]
37
+
38
+
39
+ class RerankerResponse(BaseModel):
40
+ scores: List[float]
41
+ latency_ms: Optional[float] = None
42
+ mode: Optional[str] = None
43
+ error: Optional[str] = None
44
+
45
+
46
+ def load_onnx_model():
47
+ """ONNX Runtime으로 모델 로드. 최초 실행 시 자동 변환."""
48
+ global model, tokenizer, is_onnx_mode
49
+ try:
50
+ from optimum.onnxruntime import ORTModelForSequenceClassification
51
+ from transformers import AutoTokenizer
52
+
53
+ print(f"[ONNX] 모델 로딩 중: {MODEL_ID} (provider={ONNX_PROVIDER})")
54
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
55
+ model = ORTModelForSequenceClassification.from_pretrained(
56
+ MODEL_ID,
57
+ export=True, # PyTorch → ONNX 자동 변환
58
+ provider=ONNX_PROVIDER,
59
+ token=HF_TOKEN,
60
+ )
61
+ is_onnx_mode = True
62
+ print("[ONNX] 로드 완료.")
63
+ except Exception as e:
64
+ print(f"[ONNX] 로드 실패, fallback to CrossEncoder: {e}")
65
+ load_crossencoder_model()
66
+
67
+
68
+ def load_crossencoder_model():
69
+ """기존 CrossEncoder 모델 로드 (fallback)."""
70
+ global model, is_onnx_mode
71
+ from sentence_transformers import CrossEncoder
72
+ print(f"[CrossEncoder] 모델 로딩 중: {MODEL_ID}")
73
+ model = CrossEncoder(MODEL_ID, token=HF_TOKEN)
74
+ is_onnx_mode = False
75
+ print("[CrossEncoder] 로드 완료.")
76
+
77
+
78
  @app.on_event("startup")
79
+ def startup_load_model():
80
  global model
81
  if not MODEL_ID:
82
+ raise RuntimeError("Space 설정에서 HF_MODEL_ID 를 지정하세요")
83
+
84
+ if USE_ONNX:
85
+ load_onnx_model()
86
+ else:
87
+ load_crossencoder_model()
88
+
89
+
90
+ def predict_onnx(query: str, passages: List[str]) -> List[float]:
91
+ """ONNX 모델로 batch predict."""
92
+ import torch
93
+ pairs = [[query, p] for p in passages]
94
+ inputs = tokenizer(
95
+ pairs,
96
+ padding=True,
97
+ truncation=True,
98
+ max_length=512,
99
+ return_tensors="pt",
100
+ )
101
+ outputs = model(**inputs)
102
+ # logits shape: (batch_size, num_labels) - 보통 1개 label (relevance score)
103
+ logits = outputs.logits
104
+ if logits.shape[-1] == 1:
105
+ scores = logits.squeeze(-1)
106
+ else:
107
+ # 2개 label이면 softmax 후 positive class 확률
108
+ scores = torch.softmax(logits, dim=-1)[:, 1]
109
+ return scores.tolist()
110
+
111
+
112
+ def predict_crossencoder(query: str, passages: List[str]) -> List[float]:
113
+ """CrossEncoder 모델로 predict."""
114
+ pairs = [(query, p) for p in passages]
115
+ scores = model.predict(pairs)
116
+ if hasattr(scores, "tolist"):
117
+ scores = scores.tolist()
118
+ return [float(s) for s in scores]
119
 
120
 
121
  @app.get("/")
122
  def root():
123
+ """브라우저/App 탭 접속 시 안내."""
124
  return {
125
+ "message": "CertWeb Reranker API (ONNX Optimized)",
126
+ "model": MODEL_ID,
127
+ "mode": "onnx" if is_onnx_mode else "crossencoder",
128
  "usage": "POST / with JSON body: {\"query\": \"질문\", \"passages\": [\"문단1\", \"문단2\", ...]}",
129
+ "response": "{\"scores\": [0.9, 0.2, ...], \"latency_ms\": 123.4, \"mode\": \"onnx\"}",
130
  }
131
 
132
 
133
+ @app.post("/", response_model=RerankerResponse)
134
+ def score(request: RerankerRequest):
135
+ query = request.query
136
+ passages = request.passages
137
+
138
  if not query or not passages or model is None:
139
+ return RerankerResponse(scores=[], error="Empty query/passages or model not loaded")
140
+
141
  try:
142
+ start = time.perf_counter()
143
+ if is_onnx_mode:
144
+ scores = predict_onnx(query, passages)
145
+ else:
146
+ scores = predict_crossencoder(query, passages)
147
+ latency_ms = (time.perf_counter() - start) * 1000
148
+
149
+ return RerankerResponse(
150
+ scores=scores,
151
+ latency_ms=round(latency_ms, 2),
152
+ mode="onnx" if is_onnx_mode else "crossencoder",
153
+ )
154
  except Exception as e:
155
+ return RerankerResponse(scores=[], error=str(e))
156
+
157
+
158
+ @app.get("/health")
159
+ def health():
160
+ """Health check endpoint."""
161
+ return {
162
+ "status": "ok",
163
+ "model_loaded": model is not None,
164
+ "mode": "onnx" if is_onnx_mode else "crossencoder",
165
+ }