|
|
|
|
|
import os
from typing import List, Dict, Optional

import onnxruntime as ort
from fastapi import Request
|
|
|
|
|
# Minimum reranker score a context must reach to survive ``rerank``.
# Read once at import time from the RERANK_THRESHOLD environment variable;
# falls back to 0.3 when the variable is unset.
THRESHOLD = float(os.environ.get("RERANK_THRESHOLD", "0.3"))
|
|
|
|
|
def rerank(
    request: Request,
    query: str,
    contexts: List[Dict],
    *,
    threshold: Optional[float] = None,
    max_length: int = 256,
) -> List[Dict]:
    """Score each context against ``query`` with the reranker and keep the best.

    The model and tokenizer are taken from the application state:

    * ``request.app.state.reranker_sess`` -- ONNX Runtime ``InferenceSession``.
    * ``request.app.state.reranker_tokenizer`` -- tokenizer for the reranker
      (called with ``return_tensors="np"``; presumably a Hugging Face
      tokenizer -- confirm against app startup code).

    Args:
        request: Incoming request; only ``request.app.state`` is read.
        query: The search query paired with every context.
        contexts: Candidates such as ``[{"id": ..., "text": ...}, ...]``.
            Every dict must contain a ``"text"`` key. The dicts are NOT
            modified.
        threshold: Minimum score for a context to be kept. ``None`` (default)
            uses the module-level ``THRESHOLD`` (env ``RERANK_THRESHOLD``).
        max_length: Token limit per (query, text) pair; longer pairs are
            truncated by the tokenizer.

    Returns:
        Shallow copies of the contexts whose score is ``>= threshold``, each
        with an added float ``"score"`` key, sorted by score (highest first).
        Contexts whose score is NaN are dropped.

    Raises:
        ValueError: If the model output cannot be read as exactly one score
            per (query, text) pair.
    """
    if not contexts:
        return []

    # Resolve at call time so the module constant stays the single default.
    if threshold is None:
        threshold = THRESHOLD

    tokenizer = request.app.state.reranker_tokenizer
    sess: ort.InferenceSession = request.app.state.reranker_sess

    pairs = [(query, ctx["text"]) for ctx in contexts]
    encoded = tokenizer(
        pairs,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Feed only the tensors the exported graph declares: tokenizers may emit
    # extra keys (e.g. token_type_ids) that ONNX Runtime rejects as invalid
    # input names.
    graph_inputs = {node.name for node in sess.get_inputs()}
    ort_inputs = {k: v for k, v in encoded.items() if k in graph_inputs}

    raw = sess.run(None, ort_inputs)[0]

    # Accept both (N,) and (N, 1) outputs. squeeze(-1) fails on (N,) for N > 1
    # and yields a non-iterable 0-d array for N == 1; a multi-logit (N, k)
    # output would otherwise be silently mis-zipped against the contexts.
    per_pair = raw.reshape(len(contexts), -1)
    if per_pair.shape[1] != 1:
        raise ValueError(
            f"reranker output has shape {raw.shape}; expected one score "
            f"per pair for {len(contexts)} pairs"
        )
    scores = per_pair[:, 0]

    # Filter BEFORE sorting: NaN scores fail the >= test and are dropped here;
    # leaving them in the sort would make the ordering of the rest unreliable.
    scored = ((ctx, float(sc)) for ctx, sc in zip(contexts, scores))
    reranked = [{**ctx, "score": s} for ctx, s in scored if s >= threshold]
    reranked.sort(key=lambda c: c["score"], reverse=True)
    return reranked
|
|
|