pls-rag / models /reranker.py
m97j's picture
Initial codes commit
cff5d3d
raw
history blame
1.14 kB
# rag/models/reranker.py
import os
from typing import List, Dict
import onnxruntime as ort
from fastapi import Request
# Minimum score a context must reach to survive reranking. Read once at
# import time from the RERANK_THRESHOLD environment variable (default 0.3).
THRESHOLD = float(os.environ.get("RERANK_THRESHOLD", "0.3"))
def rerank(request: Request, query: str, contexts: List[Dict]) -> List[Dict]:
    """Score ``contexts`` against ``query`` with the ONNX reranker and filter them.

    Two objects are read from the application state:

    * ``request.app.state.reranker_sess`` -- ONNX Runtime ``InferenceSession``
    * ``request.app.state.reranker_tokenizer`` -- tokenizer, called with a
      list of ``(query, text)`` pairs and ``return_tensors="np"``

    Args:
        request: Incoming request; only ``request.app.state`` is used.
        query: Query text paired with every context.
        contexts: Candidates shaped like ``[{"id": ..., "text": ...}, ...]``.
            Each dict is updated in place with a float ``"score"`` entry.

    Returns:
        The contexts whose score is at least the module-level ``THRESHOLD``,
        ordered by descending score. An empty list when ``contexts`` is empty.
    """
    if not contexts:
        return []

    tokenizer = request.app.state.reranker_tokenizer
    sess: ort.InferenceSession = request.app.state.reranker_sess

    pairs = [(query, ctx["text"]) for ctx in contexts]
    inputs = tokenizer(
        pairs,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=256,
    )

    # Feed only the tensors the graph declares. Tokenizers can emit extras
    # (e.g. token_type_ids) that an exported model does not accept, and
    # ONNX Runtime rejects feed names it does not know.
    expected_names = {node.name for node in sess.get_inputs()}
    ort_inputs = {k: v for k, v in inputs.items() if k in expected_names}

    # First output is expected to be [batch, 1]; a model that already returns
    # a flat [batch] vector is accepted as-is instead of failing in squeeze.
    scores = sess.run(None, ort_inputs)[0]
    if scores.ndim > 1:
        scores = scores.squeeze(-1)

    for ctx, sc in zip(contexts, scores):
        ctx["score"] = float(sc)

    # NOTE(review): THRESHOLD is compared against the raw model output. If the
    # model emits unbounded logits rather than probabilities, confirm the
    # threshold is expressed on that scale.
    reranked = sorted(contexts, key=lambda x: x["score"], reverse=True)
    return [c for c in reranked if c["score"] >= THRESHOLD]