File size: 1,136 Bytes
2aa7bf4 cff5d3d 2aa7bf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# rag/models/reranker.py
import os
from typing import Dict, List, Optional

import onnxruntime as ort
from fastapi import Request
# Minimum reranker score a context needs in order to be returned by rerank().
# Configurable through the RERANK_THRESHOLD environment variable, which is read
# once when this module is imported; defaults to 0.3.
THRESHOLD: float = float(os.environ.get("RERANK_THRESHOLD", "0.3"))
def rerank(
    request: Request,
    query: str,
    contexts: List[Dict],
    *,
    threshold: Optional[float] = None,
    max_length: int = 256,
) -> List[Dict]:
    """Score ``contexts`` against ``query`` with the ONNX reranker and reorder them.

    Reads two objects from ``request.app.state``:

    * ``reranker_sess`` -- an ``onnxruntime.InferenceSession`` for the reranker.
    * ``reranker_tokenizer`` -- a tokenizer called on (query, text) pairs with
      ``return_tensors="np"``; presumably a Hugging Face tokenizer (TODO confirm).

    Args:
        request: Incoming FastAPI request; only ``request.app.state`` is used.
        query: The search query.
        contexts: Candidate passages, each a dict with at least a ``"text"`` key,
            e.g. ``[{"id": ..., "text": ...}, ...]``.
        threshold: Minimum score a context must reach to be kept. ``None``
            (default) uses the module-level ``THRESHOLD`` at call time.
        max_length: Token limit per (query, text) pair; longer pairs are truncated.

    Returns:
        The contexts whose score is ``>= threshold``, sorted by score in
        descending order (ties keep their input order). NaN scores never pass.
        An empty list if ``contexts`` is empty.

    Side effects:
        Every input dict is mutated in place: a float ``"score"`` key is written
        to each context, including the ones that are filtered out.

    Raises:
        KeyError: If a context has no ``"text"`` key.
        ValueError: If the model does not yield exactly one score per context.
    """
    if not contexts:
        return []
    if threshold is None:
        threshold = THRESHOLD

    tokenizer = request.app.state.reranker_tokenizer
    sess: ort.InferenceSession = request.app.state.reranker_sess

    pairs = [(query, ctx["text"]) for ctx in contexts]
    inputs = tokenizer(
        pairs,
        return_tensors="np",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    # Feed only the tensors the exported graph declares: tokenizers commonly
    # also emit e.g. "token_type_ids", and ONNX Runtime rejects feed names the
    # model has no input for.
    input_names = {node.name for node in sess.get_inputs()}
    ort_inputs = {name: value for name, value in inputs.items() if name in input_names}

    # The first output is expected to be [batch, 1]; flattening also accepts a
    # plain [batch] output, and the size check rejects shapes such as
    # [batch, 2] that cannot be mapped to one score per context.
    scores = sess.run(None, ort_inputs)[0].reshape(-1)
    if scores.shape[0] != len(contexts):
        raise ValueError(
            f"reranker returned {scores.shape[0]} scores for {len(contexts)} contexts"
        )

    for ctx, score in zip(contexts, scores):
        ctx["score"] = float(score)

    # Filter before sorting: NaN keys break sorted()'s ordering, while
    # `NaN >= threshold` is False, so dropping them first keeps the order of the
    # survivors correct. The sort is stable, so for finite scores this equals
    # sort-then-filter.
    kept = [ctx for ctx in contexts if ctx["score"] >= threshold]
    kept.sort(key=lambda ctx: ctx["score"], reverse=True)
    return kept
|